;; -*- Mode: Irken -*-
(include "lib/counter.scm")
(include "lib/stack.scm")
;; See "The Implementation of Functional Programming Languages",
;; Chapter 5: "Efficient Compilation of Pattern-Matching".
;; http://research.microsoft.com/en-us/um/people/simonpj/papers/slpj-book-1987/
;;
;; Thanks for the hint, OCaml people! (Xavier Leroy?) They were kind
;; enough to put this reference in their source code
;; (ocaml/bytecomp/matching.ml), otherwise I may have never found
;; out about this book. And thanks to Simon Peyton-Jones for
;; putting his book online.
;; One field of a record pattern: the field's label paired with the
;; sub-pattern that the field's value is matched against.
(datatype fieldpair
(:t symbol pattern)
)
;; The parsed form of a single pattern, as produced by `parse-pattern`.
(datatype pattern
;; a literal value; constant-rule compares the matched variable against it
(:literal sexp)
;; a variable to bind; the symbol '_ is treated as a wildcard
(:variable symbol)
;; datatype name, alternative (tag) name, and one sub-pattern per argument
(:constructor symbol symbol (list pattern))
;; a record pattern: one fieldpair per named field
(:record (list fieldpair))
)
;; One rule of a match: the list of patterns written to the left of
;; '->' and the (already expanded) code expression to its right.
(datatype rule
(:t (list pattern) sexp))
;; accessor: the code expression carried by a rule.
(define (rule->code r)
  (match r with
    (rule:t _ code) -> code))
;; accessor: the list of patterns carried by a rule.
(define (rule->pats r)
  (match r with
    (rule:t pats _) -> pats))
;; Expression standing for "no rule matched".  It is the default handed
;; to the outermost compile-match call, and constructor-rule compares
;; its default against this object to decide whether a fall-through
;; case still has to be generated.
(define match-error (sexp (sexp:symbol '%match-error) (sexp:bool #f)))
;; Expression standing for "this branch failed".  fatbar drops it when
;; it shows up on either side (compared with eq?, so this one shared
;; object must be used everywhere), and constant-rule/constructor-rule
;; pass it as the default of their nested matches.
(define match-fail (sexp (sexp:symbol '%fail) (sexp:bool #f)))
;; Counter (starting from 0) that new-match-var advances to number the
;; variables it generates.
(define match-counter (make-counter 0))
;; generate a fresh variable name: "m" followed by the next value
;; handed out by match-counter.
(define (new-match-var)
  (let ((n (match-counter.inc)))
    (string->symbol (format "m" (int n)))))
;; filter predicate for record-pattern fields: #f for a field whose
;; name is '..., #t for every other field.
(define (notdotdotdot f)
  (match f with
    (field:t '... _) -> #f
    _ -> #t))
(define (compile-pattern context expander vars exp)
;; convert the s-expression <exp> into a <pattern>.
(define (parse-pattern exp)
  ;; one field of a record pattern: keep the label, parse the value.
  (define (field->pat f)
    (match f with
      (field:t label sub) -> (fieldpair:t label (sexp->pat sub))))
  ;; patterns written with list syntax: the empty list, a quoted
  ;; symbol, a (dt:alt arg ...) constructor, and (hd . tl) list sugar.
  (define (list->pat l)
    (match l with
      () -> (pattern:constructor 'list 'nil '())
      ((sexp:symbol 'quote) (sexp:symbol s)) -> (pattern:literal (sexp:symbol s))
      ((sexp:cons dt alt) . args) -> (pattern:constructor dt alt (map sexp->pat args))
      ((sexp:symbol '.) last) -> (sexp->pat last)
      (hd . tl) -> (pattern:constructor 'list 'cons (LIST (sexp->pat hd) (sexp->pat (sexp:list tl))))
      _ -> (error1 "malformed pattern" l)))
  ;; dispatch on the shape of the s-expression; anything that is not a
  ;; symbol, record, bool or list is taken as a literal.
  (define (sexp->pat x)
    (match x with
      (sexp:symbol s) -> (pattern:variable s)
      ;; for now, ignore '...' in record patterns
      (sexp:record fields) -> (pattern:record (map field->pat (filter notdotdotdot fields)))
      (sexp:bool b) -> (pattern:constructor 'bool (if b 'true 'false) '())
      (sexp:list l) -> (list->pat l)
      other -> (pattern:literal other)))
  (sexp->pat exp))
;; (p0 p1 p2 -> r0 ...)
;;
;; Split the flat body of a match into a list of rules, in source order.
;; Every element up to a '->' is parsed as a pattern; the element that
;; follows the '->' is run through <expander> and becomes the rule's code.
;;
;; It is an error for the body to end while patterns are still pending,
;; i.e. patterns that were never closed by '-> <code>' (this includes a
;; trailing '->' with nothing after it, which is itself read as a
;; pattern).  Such patterns used to be discarded silently.
(define (parse-match expander body)
  (let loop ((patterns '())
             (rules '())
             (l body))
    (match l with
      ()
      -> (if (null? patterns)
             (reverse rules)
             ;; leftover patterns: report them rather than drop them.
             (error1 "match: patterns with no '->' clause" (reverse patterns)))
      ((sexp:symbol '->) code . tl)
      -> (loop '() (list:cons (rule:t (reverse patterns) (expander code)) rules) tl)
      (pat . tl)
      -> (loop (list:cons (parse-pattern pat) patterns) rules tl))))
;; XXX redo with format after writing <sexp-repr>
;; debugging aid: print a compact rendering of the pattern <p>.
(define (dump-pat p)
  ;; prints a record field as name=pattern
  (define (dump-field fp)
    (match fp with
      (fieldpair:t name vpat)
      -> (begin (print name)
                (print-string "=")
                (dump-pat vpat))))
  (match p with
    (pattern:literal exp)
    -> (begin (print-string "L")
              (print-string (repr exp)))
    (pattern:variable var)
    -> (print var)
    (pattern:constructor dt alt args)
    -> (begin (print-string "(")
              (print dt)
              (print-string ":")
              (print alt)
              (print-string " ")
              (for-each (lambda (arg) (dump-pat arg) (print-string " ")) args)
              (print-string ")"))
    (pattern:record fpats)
    -> (begin (print-string "{")
              (for-each (lambda (fp) (dump-field fp) (print-string " ")) fpats)
              (print-string "}"))
    _ -> (error1 "NYI" p)))
;; map a pattern to the symbol naming its variant; rules are grouped
;; and dispatched on this symbol.
(define (pattern->kind p)
  (match p with
    (pattern:record _) -> 'record
    (pattern:constructor _ _ _) -> 'constructor
    (pattern:variable _) -> 'variable
    (pattern:literal _) -> 'literal))
;; drop the leading pattern of a rule, leaving its code untouched.
(define (remove-first-pat r)
  (match r with
    (rule:t (_ . rest) code) -> (rule:t rest code)
    _ -> (error "remove-first-pat: empty pats?")))
;; the kind (see pattern->kind) of the leading pattern of a rule.
(define (first-pattern-kind r)
  (match r with
    (rule:t (pat0 . _) _) -> (pattern->kind pat0)
    _ -> (error "empty pattern list?")))
;; do rules <a> and <b> start with the same kind of pattern?
(define (compare-first-patterns a b)
  (let ((kind-a (first-pattern-kind a))
        (kind-b (first-pattern-kind b)))
    (eq? kind-a kind-b)))
;; (define thingum-counter (make-counter 0))
;; (define (compile-match vars rules default)
;; (let ((n (thingum-counter.inc)))
;; (print-string "compile-match: ") (printn n)
;; (for-each dump-rule rules)
;; (let ((r (compile-match* vars rules default)))
;; (print-string "\n -- result=") (printn n)
;; (pp 0 r) (newline)
;; r)))
;; compile <rules> against the variables <vars>; <default> is the
;; expression to use when nothing matches.
(define (compile-match vars rules default)
  ;; mixture rule: compile the groups one after another, each taking
  ;; the result of the previous one as its default.  the caller hands
  ;; the groups over reversed (the python code iteratively calls pop(),
  ;; which is equivalent to iterating over the list in reverse).
  (define (mixture groups acc)
    (match groups with
      () -> acc
      (group . rest) -> (mixture rest (compile-group vars group acc))))
  (match vars rules with
    ;; the 'empty rule'
    () () -> default
    () (rule . _) -> (rule->code rule)
    ;; otherwise group the rules by kind of first pattern
    _ _ -> (let ((groups (pack rules compare-first-patterns)))
             (if (= 1 (length groups))
                 ;; one of the standard rules
                 (compile-group vars (car groups) default)
                 (mixture (reverse groups) default)))))
;; all of <rules> start with the same kind of pattern: hand them to
;; the rule that deals with that kind.
(define (compile-group vars rules default)
  (match (first-pattern-kind (car rules)) with
    'record -> (record-rule vars rules default)
    'constructor -> (constructor-rule vars rules default)
    'variable -> (variable-rule vars rules default)
    'literal -> (constant-rule vars rules default)
    _ -> (impossible)))
;; combine two alternatives into (%fatbar #f e1 e2), except that when
;; either side is the match-fail object the other side is returned
;; as-is.
(define (fatbar e1 e2)
  (if (eq? e1 match-fail)
      e2
      (if (eq? e2 match-fail)
          e1
          (sexp1 '%fatbar (LIST (sexp:bool #f) e1 e2)))))
;; wrap <code> in (let_subst (<pattern var> <var0>) <code>), recording
;; a substitution to be applied during node building.  a wildcard
;; pattern binds nothing, so <code> comes back unchanged.  only
;; variable patterns are expected here.
(define (subst var0 pat code)
  (match pat with
    (pattern:variable '_)
    -> code
    (pattern:variable var1)
    -> (sexp (sexp:symbol 'let_subst)
             (sexp (sexp:symbol var1)
                   (sexp:symbol var0))
             code)
    _ -> (impossible)))
;; variable rule: when every rule starts with a variable pattern the
;; first column needs no test.  each rule loses its first pattern and
;; gets the variable substituted into its body; matching then carries
;; on with the remaining variables.
(define (variable-rule vars rules default)
  (let ((var0 (car vars))
        (drop-column
         (lambda (rule)
           (let ((pats (rule->pats rule)))
             (rule:t (cdr pats)
                     (subst var0 (car pats) (rule->code rule)))))))
    (compile-match (cdr vars) (map drop-column rules) default)))
;; accessor: the label of a record-pattern field.
(define (fieldpair->label fp)
  (match fp with
    (fieldpair:t label _) -> label))
;; accessor: the sub-pattern of a record-pattern field.
(define (fieldpair->pattern fp)
  (match fp with
    (fieldpair:t _ pattern) -> pattern))
;; accessor: the fields of a record pattern (an error for any other
;; kind of pattern).
(define (pattern->fieldpairs p)
  (match p with
    (pattern:record fields) -> fields
    _ -> (error "not a record pattern")))
;; the 'signature' of a record pattern: its field labels, in order.
(define (pattern->record-sig p)
  (let ((fields (pattern->fieldpairs p)))
    (map fieldpair->label fields)))
;; compare two label signatures pairwise with eq?.
(define (equal-sigs? a b)
  (every2? (lambda (label-a label-b) (eq? label-a label-b))
           a
           b))
;; record rule: every rule starts with a record pattern.  bind one new
;; variable per field (named <var>_<label>) to the matching attribute
;; of the record, replace the record pattern by its field patterns,
;; and carry on matching inside the resulting (let ...).
(define (record-rule vars rules default)
  ;; label signature of the record pattern leading <rule>
  (define (rule->sig rule)
    (pattern->record-sig (car (rule->pats rule))))
  ;; swap the leading record pattern for the patterns of its fields
  (define (explode rule)
    (let ((pats (rule->pats rule))
          (field-pats (map fieldpair->pattern (pattern->fieldpairs (car pats)))))
      (rule:t (append field-pats (cdr pats)) (rule->code rule))))
  (let ((sig0 (rule->sig (car rules))))
    ;; sanity check: every other rule must carry the same labels as
    ;; the first one.
    (for-each
     (lambda (rule)
       (if (not (equal-sigs? sig0 (rule->sig rule)))
           (error1 "record pattern with different label sigs" rules)))
     (cdr rules))
    ;; translate
    (let ((var0 (nth vars 0))
          (vars0 (map (lambda (label)
                        (string->symbol (format (sym var0) "_" (sym label))))
                      sig0))
          (rules0 (map explode rules))
          ;; (<var>_<label> <var>.<label>) for each field
          (bindings
           (map-range
               i (length vars0)
               (sexp (sexp:symbol (nth vars0 i))
                     (sexp:attr (sexp:symbol var0) (nth sig0 i))))))
      (sexp (sexp:symbol 'let)
            (sexp:list bindings)
            (compile-match (append vars0 (cdr vars))
                           rules0
                           default)))))
;; accessor: the value of a literal pattern (an error for any other
;; kind of pattern).
(define (pattern->literal p)
  (match p with
    (pattern:literal exp) -> exp
    _ -> (error "not a literal pattern")))
;; do two rules start with the same literal (compared with sexp=?)?
(define (first-literal=? r0 r1)
  (let ((lit0 (pattern->literal (car (rule->pats r0))))
        (lit1 (pattern->literal (car (rule->pats r1)))))
    (sexp=? lit0 lit1)))
;; constant rule: every rule starts with a literal.  each run of rules
;; sharing a literal becomes
;;   (if (<compare> <var> (quote <literal>)) <rest of the match> <fail>)
;; and the tests are chained together with fatbar, ending in <default0>.
(define (constant-rule vars rules default0)
  ;; strings are compared with string=?, everything else with eq?
  (define (lit->comp lit)
    (match lit with
      (sexp:string _) -> (sexp:symbol 'string=?)
      _ -> (sexp:symbol 'eq?)))
  ;; the test for one run of rules that all start with the same literal
  (define (compile-run run)
    (let ((lit (pattern->literal (car (rule->pats (car run))))))
      (sexp (sexp:symbol 'if)
            (sexp (lit->comp lit) (sexp:symbol (car vars)) (sexp1 'quote (LIST lit)))
            (compile-match (cdr vars) (map remove-first-pat run) match-fail)
            match-fail)))
  ;; group runs of the same literal together
  (let loop ((runs (pack rules first-literal=?))
             (default default0))
    (match runs with
      () -> default
      (run . rest) -> (loop rest (fatbar (compile-run run) default)))))
;; bucket the elements of <l> by the key that <p> computes for them.
;; <p> must return an eq?-compatible object.  the result is an alist
;; from key to a stack holding the elements with that key.
(define (collect p l)
  (let loop ((buckets (alist/make))
             (items l))
    (match items with
      () -> buckets
      (item . rest)
      -> (let ((key (p item)))
           (match (alist/lookup buckets key) with
             ;; known key: push onto the stack already in the alist
             (maybe:yes stack)
             -> (begin (stack.push item)
                       (loop buckets rest))
             ;; new key: start a stack for it
             (maybe:no)
             -> (let ((stack (make-stack)))
                  (stack.push item)
                  (loop (alist:entry key stack buckets) rest)))))))
;; accessor: the datatype name of a constructor pattern (an error for
;; any other kind of pattern).
(define (pattern->dt p)
  (match p with
    (pattern:constructor dt _ _) -> dt
    _ -> (error "not a constructor pattern")))
;; accessor: the alternative (tag) name of a constructor pattern (an
;; error for any other kind of pattern).
(define (pattern->alt p)
  (match p with
    (pattern:constructor _ alt _) -> alt
    _ -> (error "not a constructor pattern")))
;; accessor: the argument sub-patterns of a constructor pattern (an
;; error for any other kind of pattern).
(define (pattern->subs p)
  (match p with
    (pattern:constructor _ _ subs) -> subs
    _ -> (error "not a constructor pattern")))
;; the datatype named by the constructor pattern that leads a rule.
(define (rule->constructor-dt r)
  (pattern->dt (car (rule->pats r))))
;; the alternative named by the constructor pattern that leads a rule.
(define (rule->constructor-alt r)
  (pattern->alt (car (rule->pats r))))
;; bucket constructor rules by the alternative they start with, after
;; checking that they all name one and the same datatype.
(define (sort-constructor-rules rules)
  (let ((keys (alist->keys (collect rule->constructor-dt rules))))
    (if (= (length keys) 1)
        (collect rule->constructor-alt rules)
        (error1 "more than one datatype in pattern match" keys))))
;; this handles normal constructors *and* polymorphic variants.
;;
;; constructor rule: every rule starts with a constructor pattern.  the
;; rules are bucketed by alternative; each alternative becomes one case
;; of a (vcase ...) on the first variable, binding a variable per
;; constructor argument and carrying on with the sub-patterns followed
;; by the remaining patterns.
;;
;; <vars>    - variables still to be matched; (car vars) is tested here.
;; <rules>   - rules whose first pattern is a constructor pattern.
;; <default> - expression to use when none of the rules match.
;;
;; when the datatype is found in context.datatypes the vcase carries
;; the datatype name, and an else clause is added only if some
;; alternative is not covered.  when it is not found, arities are taken
;; from the patterns themselves and an else clause is added whenever
;; <default> is not match-error.
(define (constructor-rule vars rules default)
  (let ((dtname (rule->constructor-dt (car rules)))
        (alts (sort-constructor-rules rules))
        (nalts 0)
        (mdt (alist/lookup context.datatypes dtname))
        ;; the nested matches only propagate <default> when it is
        ;; match-error; otherwise they fail with match-fail and the
        ;; real default is attached once, by the fatbar at the bottom.
        (default0 (if (sexp=? default match-error) default match-fail))
        (cases '())
        )
    (alist/iterate
     (lambda (tag rules-stack)
       (let ((arity (match mdt with
                      (maybe:no) -> (length (pattern->subs (car (rule->pats (rules-stack.top)))))
                      (maybe:yes dt) -> (let ((alt (dt.get tag)))
                                          alt.arity)))
             ;; one generated variable per constructor argument
             (vars0 (nthunk arity new-match-var))
             ;; wild[i] stays #t only while every rule has '_' for argument i
             (wild (make-vector arity #t))
             (rules1 '()))
         (set! nalts (+ nalts 1))
         ;; replace the leading constructor pattern of a rule by its
         ;; sub-patterns, and note which arguments are not wildcards.
         (define frob-rule
           (rule:t pats code)
           -> (let ((subs (pattern->subs (car pats))))
                (if (not (= (length subs) arity))
                    (error1 "arity mismatch in variant pattern" rules))
                (PUSH rules1 (rule:t (append subs (cdr pats)) code))
                (for-range
                    i arity
                    (match (nth subs i) with
                      (pattern:variable '_) -> #u
                      _ -> (set! wild[i] #f))
                    )))
         (for-each frob-rule (rules-stack.get))
         ;; if every pattern has a wildcard for this arg of the constructor,
         ;; then use '_' rather than the symbol we generated.
         (let ((vars1 (map-range i arity (if wild[i] '_ (nth vars0 i)))))
           (PUSH cases
                 ;; ((:tag var0 var1 ...) (match ...))
                 (sexp
                  (sexp:list
                   (list:cons (sexp:cons 'nil tag) (map sexp:symbol vars1)))
                  ;; we don't reverse rules1 because we popped it off a reversed stack already
                  (compile-match (append vars0 (cdr vars)) rules1 default0))))))
     alts)
    (let ((result
           (match mdt with
             (maybe:yes dt)
             -> (begin (if (< nalts (dt.get-nalts))
                           (PUSH cases (sexp (sexp:symbol 'else) default0)))
                       (sexp:list (append (LIST (sexp:symbol 'vcase) (sexp:symbol dt.name) (sexp:symbol (car vars)))
                                          (reverse cases))))
             (maybe:no)
             -> (begin (if (not (eq? default match-error))
                           (PUSH cases (sexp (sexp:symbol 'else) match-fail)))
                       (sexp:list (append (LIST (sexp:symbol 'vcase) (sexp:symbol (car vars)))
                                          (reverse cases)))))
           ))
      (if (not (eq? default match-error))
          (fatbar result default)
          result))))
;; debugging aid: print a rule as its patterns (each followed by a
;; space), then "-> ", then the pretty-printed code and a newline.
(define (dump-rule r)
  (match r with
    (rule:t pats code)
    -> (let ((dump-one (lambda (p)
                         (dump-pat p)
                         (print-string " "))))
         (for-each dump-one pats)
         (print-string "-> ")
         (pp 0 code)
         (newline))))
;; build a list of <n> elements by calling the thunk <p> once per
;; element.
(define (nthunk n p)
  (match n with
    0 -> '()
    _ -> (list:cons (p) (nthunk (- n 1) p))))
;; driver: parse the rules, choose the variables to match on (the
;; caller's <vars>, or one generated variable per pattern of the first
;; rule when none were given), compile with match-error as the
;; outermost default, and return the variables together with the
;; compiled expression.
(let ((rules (parse-match expander exp))
      ;; (print-string "compiling match:\n")
      ;; (for-each dump-rule rules) (newline)
      (npats (length (rule->pats (car rules))))
      (vars (if (null? vars)
                (nthunk npats new-match-var)
                vars))
      (result (compile-match vars rules match-error)))
  ;; (print-string "match compiler result:\n")
  ;; (pp 0 result) (newline)
  ;; (print-string " ---\n")
  (:pair vars result))
)