;; -*- Mode: Irken -*-

(include "lib/counter.scm")
(include "lib/stack.scm")

;; See "The Implementation of Functional Programming Languages",
;; Chapter 5: "Efficient Compilation of Pattern-Matching".
;; http://research.microsoft.com/en-us/um/people/simonpj/papers/slpj-book-1987/
;;
;; Thanks for the hint, OCaml people! (Xavier Leroy?) They were kind
;; enough to put this reference in their source code
;; (ocaml/bytecomp/matching.ml), otherwise I may have never found
;; out about this book.  And thanks to Simon Peyton-Jones for
;; putting his book online.

;; a single 'label=pattern' entry of a record pattern.
(datatype fieldpair
  (:t symbol pattern)
  )

;; parsed form of one pattern (see parse-pattern inside compile-pattern).
(datatype pattern
  (:literal sexp)
  (:variable symbol)
  (:constructor symbol symbol (list pattern))   ;; datatype, alternative, sub-patterns
  (:record (list fieldpair))
  )

;; one match rule: a row of patterns (one per matched variable) and its body.
(datatype rule
  (:t (list pattern) sexp))

(define rule->code
  (rule:t _ code) -> code)

(define rule->pats
  (rule:t pats _) -> pats)

;; default of the outermost match (passed to compile-match at the end of compile-pattern).
(define match-error (sexp (sexp:symbol '%match-error) (sexp:bool #f)))
;; 'this branch failed' marker; fatbar drops it instead of emitting a %fatbar node.
(define match-fail (sexp (sexp:symbol '%fail) (sexp:bool #f)))

;; fresh variable names: the symbol 'm' followed by the counter value.
(define match-counter (make-counter 0))
(define (new-match-var)
  (string->symbol (format "m" (int (match-counter.inc)))))

;; #f for the '...' field of a record pattern, #t for every other field.
(define notdotdotdot
  (field:t '... _) -> #f
  _ -> #t
  )

;; compile-pattern: compile a list of match rules into a single expression.
;;   context  - compiler context; only context.datatypes is consulted here, to find
;;              the arity/number of alternatives of a constructor's datatype.
;;   expander - applied to the body (right-hand side) of every rule.
;;   vars     - the variables being matched, one per pattern column.  When it is
;;              empty, one fresh variable per column is generated (new-match-var).
;;   exp      - the rules, in the form: p0 p1 ... -> code p0 p1 ... -> code ...
;; returns (:pair vars result), where result is the compiled match expression.
(define (compile-pattern context expander vars exp)

  (define (parse-pattern exp)

    (define parse-fieldpair
      (field:t name pat) -> (fieldpair:t name (kind pat)))

    (define kind
      (sexp:symbol s) -> (pattern:variable s)
      ;; for now, ignore '...' in record patterns
      (sexp:record fields) -> (pattern:record (map parse-fieldpair (filter notdotdotdot fields)))
      (sexp:bool b) -> (pattern:constructor 'bool (if b 'true 'false) '())
      (sexp:list l)
      -> (match l with
           () -> (pattern:constructor 'list 'nil '())
           ((sexp:symbol 'quote) (sexp:symbol s)) -> (pattern:literal (sexp:symbol s))
           ((sexp:cons dt alt) . args) -> (pattern:constructor dt alt (map kind args))
           ((sexp:symbol '.) last) -> (kind last)
           (hd . tl) -> (pattern:constructor 'list 'cons (LIST (kind hd) (kind (sexp:list tl))))
           _ -> (error1 "malformed pattern" l))
      x -> (pattern:literal x))

    (kind exp))

  ;; (p0 p1 p2 -> r0 ...)
  ;; patterns are accumulated until '->', then the next sexp is the rule body.
  (define (parse-match expander body)
    (let loop ((patterns '())
               (rules '())
               (l body))
      (match l with
        () -> (reverse rules)
        ((sexp:symbol '->) code . tl)
        -> (loop '() (list:cons (rule:t (reverse patterns) (expander code)) rules) tl)
        (pat . tl)
        -> (loop (list:cons (parse-pattern pat) patterns) rules tl))))

  ;; XXX redo with format after writing
  (define (dump-pat p)
    (define ps print-string)
    (define dump-field
      (fieldpair:t name vpat)
      -> (begin (print name) (ps "=") (dump-pat vpat)))
    (match p with
      (pattern:literal exp) -> (begin (ps "L") (ps (repr exp)))
      (pattern:variable var) -> (print var)
      (pattern:constructor dt alt args)
      -> (begin (ps "(") (print dt) (ps ":") (print alt) (ps " ")
                (for-each (lambda (x) (dump-pat x) (ps " ")) args)
                (ps ")"))
      (pattern:record fpats)
      -> (begin (ps "{")
                (for-each (lambda (fp) (dump-field fp) (ps " ")) fpats)
                (ps "}"))
      _ -> (error1 "NYI" p)))

  (define pattern->kind
    (pattern:literal _) -> 'literal
    (pattern:variable _) -> 'variable
    (pattern:constructor _ _ _ ) -> 'constructor
    (pattern:record _) -> 'record
    )

  ;; pull the first pattern out of each rule
  (define remove-first-pat
    (rule:t (pat . pats) code) -> (rule:t pats code)
    _ -> (error "remove-first-pat: empty pats?"))

  (define first-pattern-kind
    (rule:t (pat0 . pats) _) -> (pattern->kind pat0)
    _ -> (error "empty pattern list?"))

  (define (compare-first-patterns a b)
    (eq? (first-pattern-kind a) (first-pattern-kind b)))

  ;; (define thingum-counter (make-counter 0))
  ;; (define (compile-match vars rules default)
  ;;   (let ((n (thingum-counter.inc)))
  ;;     (print-string "compile-match: ") (printn n)
  ;;     (for-each dump-rule rules)
  ;;     (let ((r (compile-match* vars rules default)))
  ;;       (print-string "\n -- result=") (printn n)
  ;;       (pp 0 r) (newline)
  ;;       r)))

  ;; core of the algorithm: <vars> are the columns still to be matched,
  ;; <rules> the remaining rows, <default> the expression used when no row matches.
  (define (compile-match vars rules default)
    (match vars rules with
      ;; the 'empty rule'
      () () -> default
      () (rule . _) -> (rule->code rule)
      _ _ ->
      ;; group the rules by kind of first pattern
      (let ((groups (pack rules compare-first-patterns)))
        (if (= (length groups) 1)
            ;; one of the standard rules
            (compile-group vars (car groups) default)
            ;; mixture rule
            (begin
              ;;(print-string " -- kind = mixture\n")
              (for-each
               (lambda (group)
                 (set! default (compile-group vars group default)))
               ;; the python code iteratively calls pop(), which
               ;; is equivalent to iterating over it in reverse.
               (reverse groups))
              default)
            ))))

  ;; we know the rules are of identical kind
  (define (compile-group vars rules default)
    (let ((kind (first-pattern-kind (car rules))))
      ;;(print-string " -- kind = ") (printn kind)
      (match kind with
        'literal     -> (constant-rule vars rules default)
        'variable    -> (variable-rule vars rules default)
        'constructor -> (constructor-rule vars rules default)
        'record      -> (record-rule vars rules default)
        _ -> (impossible))))

  ;; the book's FATBAR: try e1, fall back to e2.  A match-fail on either
  ;; side makes the combination unnecessary, so the other side is returned as-is.
  (define (fatbar e1 e2)
    (cond ((eq? e1 match-fail) e2)
          ((eq? e2 match-fail) e1)
          (else
           (sexp1 '%fatbar (LIST (sexp:bool #f) e1 e2)))))

  (define (subst var0 pat code)
    (match pat with
      (pattern:variable var1)
      ;; record a subst to be applied during node building (unless it's a wildcard pattern)
      -> (if (not (eq? var1 '_))
             (sexp (sexp:symbol 'let_subst) (sexp (sexp:symbol var1) (sexp:symbol var0)) code)
             code)
      _ -> (impossible)
      ))

  ;; if every rule begins with a variable, we can remove that column
  ;;  from the set of patterns and substitute the var within each body
  (define (variable-rule vars rules default)
    (let ((var0 (car vars))
          (rules0
           (map
            (lambda (rule)
              (match rule with
                (rule:t pats code)
                -> (rule:t (cdr pats) (subst var0 (car pats) code))))
            rules)))
      (compile-match (cdr vars) rules0 default)))

  (define fieldpair->label
    (fieldpair:t label _) -> label)
  (define fieldpair->pattern
    (fieldpair:t _ pattern) -> pattern)

  (define pattern->fieldpairs
    (pattern:record fields) -> fields
    _ -> (error "not a record pattern"))

  (define (pattern->record-sig p)
    (map fieldpair->label (pattern->fieldpairs p)))

  (define (equal-sigs? a b)
    (every2? eq? a b))

  ;; a column of record patterns is replaced by one column per field:
  ;; each field is bound to a new variable named <var>_<label>.
  (define (record-rule vars rules default)
    ;; first - sanity check, make sure each sig matches.
    (let ((sig0 (pattern->record-sig (car (rule->pats (car rules))))))
      (for-each
       (lambda (rule)
         (if (not (equal-sigs? sig0 (pattern->record-sig (car (rule->pats rule)))))
             (error1 "record pattern with different label sigs" rules)))
       (cdr rules))
      ;; translate
      ;; (print-string "record-rule, vars=") (printn vars)
      ;; (print-string "record-rule, sig0=") (printn sig0)
      (let ((var0 (nth vars 0))
            (vars0 (map (lambda (field) (string->symbol (format (sym var0) "_" (sym field)))) sig0))
            (rules0
             (map
              (lambda (rule)
                (let ((pats (rule->pats rule))
                      (pats0 (map fieldpair->pattern (pattern->fieldpairs (car pats)))))
                  (rule:t (append pats0 (cdr pats)) (rule->code rule))))
              rules))
            (bindings
             (map-range
                 i (length vars0)
                 (sexp (sexp:symbol (nth vars0 i))
                       (sexp:attr (sexp:symbol var0) (nth sig0 i))))))
        (sexp (sexp:symbol 'let)
              (sexp:list bindings)
              (compile-match (append vars0 (cdr vars)) rules0 default)))))

  (define pattern->literal
    (pattern:literal exp) -> exp
    _ -> (error "not a literal pattern"))

  (define (first-literal=? r0 r1)
    (match r0 r1 with
      (rule:t pats0 _) (rule:t pats1 _)
      -> (sexp=? (pattern->literal (car pats0)) (pattern->literal (car pats1)))))

  (define (constant-rule vars rules default0)
    ;; group runs of the same literal together.
    ;; each step wraps the accumulated default as (fatbar <group> default), so the
    ;; group processed last is tried first.  Walk the groups in reverse (as the
    ;; mixture rule in compile-match does) so that the first group is tried first
    ;; and rule order is preserved when a literal appears in non-adjacent groups.
    (let loop ((groups (reverse (pack rules first-literal=?)))
               (default default0))
      (match groups with
        () -> default
        (rules0 . groups)
        -> (let ((lit (pattern->literal (car (rule->pats (car rules0)))))
                 (comp-fun (match lit with
                             (sexp:string _) -> (sexp:symbol 'string=?)
                             _ -> (sexp:symbol 'eq?))))
             (loop groups
                   (fatbar
                    (sexp (sexp:symbol 'if)
                          (sexp comp-fun (sexp:symbol (car vars)) (sexp1 'quote (LIST lit)))
                          (compile-match (cdr vars) (map remove-first-pat rules0) match-fail)
                          match-fail)
                    default))))))

  ;; sort a collection into stacks of items with matching keys, where the
  ;; key of an item is (p item).  <p> must return an eq?-compatible object.
  ;; returns an alist of stacks.  Items are pushed in input order, and new keys
  ;; are added to the front of the alist.
  (define (collect p l)
    (let loop ((acc (alist/make))
               (l l))
      (match l with
        () -> acc
        (hd . tl)
        -> (let ((key (p hd)))
             (match (alist/lookup acc key) with
               (maybe:no)
               -> (let ((stack (make-stack)))
                    (stack.push hd)
                    (loop (alist:entry key stack acc) tl))
               (maybe:yes stack)
               -> (begin
                    (stack.push hd)
                    (loop acc tl)))))))

  (define pattern->dt
    (pattern:constructor dt _ _) -> dt
    _ -> (error "not a constructor pattern"))

  (define pattern->alt
    (pattern:constructor _ alt _) -> alt
    _ -> (error "not a constructor pattern"))

  (define pattern->subs
    (pattern:constructor _ _ subs) -> subs
    _ -> (error "not a constructor pattern"))

  (define rule->constructor-dt
    (rule:t pats _) -> (pattern->dt (car pats)))

  (define rule->constructor-alt
    (rule:t pats _) -> (pattern->alt (car pats)))

  (define (sort-constructor-rules rules)
    ;; first, make sure we're all on the same datatype
    (let ((by-dt (collect rule->constructor-dt rules))
          (keys (alist->keys by-dt)))
      (if (not (= (length keys) 1))
          (error1 "more than one datatype in pattern match" keys)
          (collect rule->constructor-alt rules))))

  ;; this handles normal constructors *and* polymorphic variants.
  ;; (when the datatype is not found in context.datatypes, the arity is taken
  ;;  from the patterns themselves and no datatype name is emitted in the vcase.)
  (define (constructor-rule vars rules default)
    (let ((dtname (rule->constructor-dt (car rules)))
          (alts (sort-constructor-rules rules))
          (nalts 0)
          (mdt (alist/lookup context.datatypes dtname))
          (default0 (if (sexp=? default match-error) default match-fail))
          (cases '())
          )
      (alist/iterate
       (lambda (tag rules-stack)
         (let ((arity (match mdt with
                        (maybe:no) -> (length (pattern->subs (car (rule->pats (rules-stack.top)))))
                        (maybe:yes dt) -> (let ((alt (dt.get tag))) alt.arity)))
               (vars0 (nthunk arity new-match-var))
               (wild (make-vector arity #t))
               (rules1 '()))
           (set! nalts (+ nalts 1))
           ;; expand the constructor's sub-patterns into new columns, and note
           ;; which argument positions are wildcards in every rule.
           (define frob-rule
             (rule:t pats code)
             -> (let ((subs (pattern->subs (car pats))))
                  (if (not (= (length subs) arity))
                      (error1 "arity mismatch in variant pattern" rules))
                  (PUSH rules1 (rule:t (append subs (cdr pats)) code))
                  (for-range
                      i arity
                      (match (nth subs i) with
                        (pattern:variable '_) -> #u
                        _ -> (set! wild[i] #f))
                      )))
           (for-each frob-rule (rules-stack.get))
           ;; if every pattern has a wildcard for this arg of the constructor,
           ;;  then use '_' rather than the symbol we generated.
           (let ((vars1 (map-range i arity (if wild[i] '_ (nth vars0 i)))))
             (PUSH cases
                   ;; ((:tag var0 var1 ...) (match ...))
                   (sexp (sexp:list (list:cons (sexp:cons 'nil tag) (map sexp:symbol vars1)))
                         ;; we don't reverse rules1 because we popped it off a reversed stack already
                         (compile-match (append vars0 (cdr vars)) rules1 default0))))))
       alts)
      (let ((result
             (match mdt with
               (maybe:yes dt)
               -> (begin
                    ;; not every alternative is covered: add an else clause.
                    (if (< nalts (dt.get-nalts))
                        (PUSH cases (sexp (sexp:symbol 'else) default0)))
                    (sexp:list (append (LIST (sexp:symbol 'vcase) (sexp:symbol dt.name) (sexp:symbol (car vars)))
                                       (reverse cases))))
               (maybe:no)
               -> (begin
                    (if (not (eq? default match-error))
                        (PUSH cases (sexp (sexp:symbol 'else) match-fail)))
                    (sexp:list (append (LIST (sexp:symbol 'vcase) (sexp:symbol (car vars)))
                                       (reverse cases)))))
             ))
        (if (not (eq? default match-error))
            (fatbar result default)
            result))))

  (define dump-rule
    (rule:t pats code)
    -> (begin
         (for-each (lambda (p) (dump-pat p) (print-string " ")) pats)
         (print-string "-> ")
         (pp 0 code) (newline)
         ))

  ;; a list of the results of calling thunk <p> <n> times.
  (define nthunk
    0 p -> '()
    n p -> (list:cons (p) (nthunk (- n 1) p)))

  (let ((rules (parse-match expander exp)))
    ;; (print-string "compiling match:\n")
    ;; (for-each dump-rule rules) (newline)
    (if (null? rules)
        ;; without a rule there is no first rule to take the column count from.
        (error1 "match expression with no rules" exp)
        (let ((npats (length (rule->pats (car rules))))
              (vars (if (null? vars) (nthunk npats new-match-var) vars))
              (result (compile-match vars rules match-error)))
          ;; (print-string "match compiler result:\n")
          ;; (pp 0 result) (newline)
          ;; (print-string " ---\n")
          (:pair vars result))))
  )