;; -*- Mode: Irken -*-
;;; where to store properties?
;;; Some properties belong to the node itself. For example, a RECURSIVE
;;; flag can apply to a particular call, but not all of them.
;;; Other properties belong to a variable or a function - for example a
;;; function can also be RECURSIVE. But the property of escaping belongs
;;; to the function/variable, not to a particular node - although you could
;;; fake it by attaching it to the definition of the function/variable.
;;;
;; XXX consider combining these passes
(define (find-recursion exp context)
(define (walk exp fenv)
(match exp.t with
(node:call)
-> (let ((rator (nth exp.subs 0)))
(match rator.t with
(node:varref name)
-> (let ((var (vars-get-var context name)))
(cond ((member-eq? name fenv)
(set! var.flags (bit-set var.flags VFLAG-RECURSIVE))
(node-set-flag! exp NFLAG-RECURSIVE)))
(set! var.calls (+ 1 var.calls)))
_ -> #u
))
(node:function name _)
-> (begin
(PUSH fenv name)
(set! context.funs (tree/insert context.funs symbol-index<? name exp))
(vars-set-flag! context name VFLAG-FUNCTION))
;; a convenient place to detect this.
(node:primapp name _)
-> (if (or (eq? name '%getcc) (eq? name '%putcc))
(vars-set-flag! context (car fenv) VFLAG-GETCC))
_ -> #u)
(for-each (lambda (x) (walk x fenv)) exp.subs))
(walk exp '()))
(define (find-leaves node)
(define (search exp)
(if (every? search exp.subs)
(let ((leaf?
(match exp.t with
(node:call) -> (node-get-flag exp NFLAG-RECURSIVE)
_ -> #t)))
(if leaf?
(node-set-flag! exp NFLAG-LEAF))
leaf?)
#f))
(search node))
(define (find-refs node context)
(define (add-ref name)
(let ((var (vars-get-var context name)))
(set! var.refs (+ 1 var.refs))))
(define (add-set name)
(let ((var (vars-get-var context name)))
(set! var.sets (+ 1 var.sets))))
(define (walk node)
(match node.t with
(node:varref name) -> (add-ref name)
(node:varset name) -> (add-set name)
_ -> #u)
(for-each walk node.subs))
(walk node))
(define (symbol-add-suffix sym suffix)
(string->symbol (format (symbol->string sym) suffix)))
(define inline-threshold 13)
(define (do-inlining root context)
(let ((inline-counter (make-counter 0))
(rename-counter (make-counter 0))
(multiplier (alist:nil)))
(define (set-multiplier name calls)
;; when we inline <name>, each function that it calls must have its call-count
;; raised by a factor of <calls>.
(let ((g context.dep-graph))
(match (g::get name) with
(maybe:yes deps)
-> (deps::iterate
(lambda (dep)
(match (alist/lookup multiplier dep) with
(maybe:no) -> (alist/push multiplier dep calls)
(maybe:yes _) -> #u)))
(maybe:no) -> #u)))
;; XXX no protection against infinite aliases
(define (follow-aliases fenv name)
(match (alist/lookup fenv name) with
(maybe:no) -> (maybe:no)
(maybe:yes fun)
-> (match fun.t with
(node:varref name0)
-> (follow-aliases fenv name0)
_ -> (maybe:yes (:pair name fun)))))
(define (get-fun-calls name calls)
(match (alist/lookup multiplier name) with
(maybe:yes num) -> (* num calls)
(maybe:no) -> calls))
(define (inline node fenv)
(let/cc return
(match node.t with
(node:fix names)
-> (for-range
i (length names)
(set! fenv (alist:entry (nth names i)
(nth node.subs i)
fenv)))
(node:call)
-> (match node.subs with
() -> (impossible)
({t=(node:varref name) ...} . rands)
-> (match (follow-aliases fenv name) with
(maybe:no) -> #u
(maybe:yes (:pair name fun))
-> (let ((var (vars-get-var context name))
(escapes (bit-get var.flags VFLAG-ESCAPES))
(recursive (bit-get var.flags VFLAG-RECURSIVE))
(getputcc (bit-get var.flags VFLAG-GETCC))
;; this will spin unwanted extra copies
;;(recursive (node-get-flag node NFLAG-RECURSIVE))
(calls (get-fun-calls name var.calls)))
;; (print-string (format "testing " (sym name) " calls " (int calls)
;; " escapes " (bool escapes) " recursive " (bool recursive) "\n"))
(cond ((and (function? fun)
(not (eq? (string-ref (symbol->string name) 0) #\^))
(not getputcc) ;; don't inline functions that use getcc/putcc
(> calls 0)
(and (or (<= fun.size inline-threshold)
(and (= calls 1) (not escapes)))
(not recursive)))
(if (> calls 1)
(set-multiplier name calls))
;; (print-string (format "inline: " (sym name) " calls " (int calls)" escapes " (bool escapes) " recursive " (bool recursive) "\n"))
(let ((r (inline-application fun rands)))
;; record the new variables...
(add-vars r context)
(return (inline r fenv)))))))
;; always inline ((lambda (x) ...) ...)
({t=(node:function name formals) ...} . rands)
-> (let ((r (inline-application (car node.subs) rands)))
;; (print-string (format "inlined lambda: final size = " (int r.size) "\n"))
(add-vars r context)
(return (inline r fenv)))
_ -> #u)
_ -> #u)
(set! node.subs (map (lambda (x) (inline x fenv)) node.subs))
node
))
(define (instantiate fun)
(let ((new-vars '())
(suffix (format "_i" (int (inline-counter.inc))))
(body (nth fun.subs 0)))
(define (append-suffix sym)
(string->symbol (format (symbol->string sym) suffix)))
(define (rename node lenv)
(define (get-new-name name)
(let ((name0 (append-suffix name)))
(set! new-vars (list:cons name0 new-vars))
(set! lenv (list:cons name lenv))
name0))
(define (get-new-names names)
(map get-new-name names))
;; start with a copy of this node.
(set! node (node-copy node))
(match node.t with
(node:let names) -> (set! node.t (node:let (get-new-names names)))
(node:fix names) -> (set! node.t (node:fix (get-new-names names)))
(node:function name formals) -> (set! node.t (node:function (get-new-name name) (get-new-names formals)))
(node:varref name) -> (if (member-eq? name lenv)
(set! node.t (node:varref (append-suffix name))))
(node:varset name) -> (if (member-eq? name lenv)
(set! node.t (node:varset (append-suffix name))))
_ -> #u)
(set! node.subs (map (lambda (x) (rename x lenv)) node.subs))
node)
(rename body '())
))
(define (safe-nvget-inline rands)
(match rands with
({t=(node:primapp '%nvget _) subs=({t=(node:varref name) ...} . _) ...} . _)
-> (let ((var (vars-get-var context name)))
(= 0 var.sets))
_ -> #f))
(define (inline-application fun rands)
(let ((simple '())
(complex '())
(n (length rands)))
(match fun.t with
(node:function name formals)
-> (cond ((not (= n (length formals))) (error1 "inline: bad arity" name))
(else
;;(print-string (format "inlining function " (sym name) " has " (int n) " formals\n"))
(for-range
i n
(let ((formal (nth formals i))
(fvar (vars-get-var context formal))
(rand (nth rands i)))
(if (> fvar.sets 0)
(PUSH complex i) ;; if a formal is assigned to, it must go into a let.
(match rand.t with
(node:literal _) -> (PUSH simple i)
(node:varref arg)
-> (let ((avar (vars-get-var context arg)))
;;(print-string (format "formal: " (sym formal) " avar.sets=" (int avar.sets) " fvar.sets=" (int fvar.sets) "\n"))
(if (> avar.sets 0)
(PUSH complex i)
(PUSH simple i)))
_ -> (if (and (= 1 fvar.refs) (safe-nvget-inline rands))
(PUSH simple i)
(PUSH complex i))))))
;;(print-string " simple, complex=") (print simple) (printn complex) (newline)
(let ((body (instantiate fun)) ;; alpha converted copy of the function
(substs
(if (not (null? simple))
(map (lambda (i) (:pair (nth formals i) (nth rands i))) simple)
'())))
(if (eq? complex (list:nil))
;; simple - substitute arguments directly
(substitute body substs)
;; complex - bind args into (let ...), then inline body
;; generate new names for complex args
(let ((names '())
(inits '())
(nc (length complex)))
(for-each
(lambda (i)
(let ((name (symbol-add-suffix
(nth formals i)
(format "_i" (int (rename-counter.inc))))))
(PUSH names name)
(PUSH inits (nth rands i))
(PUSH substs (:pair (nth formals i) (node/varref name)))
))
(reverse complex))
;;(print-string "substs = ") (printn substs)
(let ((body (substitute body substs)))
(node/let (reverse names) (reverse inits) body)))
))))
_ -> (error1 "inline-application - inlining non-function?" fun)
)))
(define (substitute body substs)
(define (lookup name)
(let loop ((substs substs))
(match substs with
() -> (maybe:no)
((:pair from to) . tl)
-> (if (eq? name from)
(maybe:yes to)
(loop tl)))))
(define (walk node)
(let ((node0
(match node.t with
(node:varref name)
-> (match (lookup name) with
(maybe:yes val) -> val
(maybe:no) -> node)
(node:varset name)
-> (match (lookup name) with
(maybe:yes val) -> (node/varset (varref->name val.t) (car node.subs))
(maybe:no) -> node)
_ -> node)))
(set! node0.subs (map walk node0.subs))
node0))
(walk body))
;; body of do-inlining
(inline root (alist:nil))
))
(define (escape-analysis root context)
(let ((escaping-funs '()))
;; for each variable, we need to know if it might potentially
;; escape. a variable 'escapes' when it is referenced while free
;; inside a function that escapes (i.e., any function that is
;; varref'd outside of the operator position).
(define (fun-escapes name)
(vars-set-flag! context name VFLAG-ESCAPES)
(PUSH escaping-funs name))
(define (find-escaping-functions node parent)
(match node.t with
(node:function name _)
-> (match parent.t with
(node:fix _) -> #u
;; any function defined outside a fix (i.e., a lambda) is by
;; definition an escaping function - because we always reduce
;; ((lambda ...) ...) to (let ...)
_ -> (fun-escapes name))
(node:varref name)
-> (if (vars-get-flag context name VFLAG-FUNCTION)
(match parent.t with
(node:call)
;; any function referenced in a non-rator position
-> (if (not (eq? (first parent.subs) node))
(fun-escapes name))
_ -> (fun-escapes name)))
_ -> #u)
(for-each (lambda (x) (find-escaping-functions x node)) node.subs))
(define (maybe-var-escapes name lenv)
;;(print-string (format "maybe-var-escapes: " (sym name) "\n"))
(if (not (member-eq? name lenv))
;; reference to a free variable. flag it as escaping.
(vars-set-flag! context name VFLAG-ESCAPES)))
;; XXX make sure we still need to know if particular variables escape.
;; within each escaping function, we search for escaping variables.
(define (find-escaping-variables node lenv)
(match node.t with
;; the three binding constructs extend the environment...
(node:function _ formals) -> (set! lenv (append formals lenv))
(node:fix names) -> (set! lenv (append names lenv))
(node:let names) -> (set! lenv (append names lenv))
;; ... and here we search the environment.
(node:varref name) -> (maybe-var-escapes name lenv)
(node:varset name) -> (maybe-var-escapes name lenv)
_ -> #u)
(for-each (lambda (x) (find-escaping-variables x lenv)) node.subs))
;; first we identify escaping functions
(find-escaping-functions root (node/literal (literal:int 0)))
(for-each
(lambda (name)
;;(print-string (format "searching escaping fun " (sym name) "\n"))
(let ((fun (match (tree/member context.funs symbol-index<? name) with
(maybe:yes fun) -> fun
(maybe:no) -> (error1 "find-escaping-funs: failed lookup" name))))
(find-escaping-variables fun '())))
escaping-funs)
))
;; simple cascading optimizations - these only work from the
;; outside-in, not the inside-out, so we make repeated passes
;; until we don't get any more.
(define simpleopt-hits 0)
(define (do-simple-optimizations node)
(set! simpleopt-hits 0)
(let loop ((r (simpleopt node)))
(print-string (format "simpleopt: " (int simpleopt-hits) "\n"))
(cond ((> simpleopt-hits 0)
(set! simpleopt-hits 0)
(loop (simpleopt r)))
(else r))))
(define (simpleopt node)
;; early exit means we've rewritten the node, and will recur without
;; the default processing at the end.
(let/cc return
;; assume we'll hit an optimization, undo it down below if not.
(set! simpleopt-hits (+ 1 simpleopt-hits))
(match node.t with
(node:fix ())
;; empty fix
-> (return (simpleopt (first node.subs)))
(node:fix names0)
-> (match (unpack-fix node.subs) with
;; body of fix is another fix...
(:fixsubs {t=(node:fix names1) subs=subs1 ...} inits0)
-> (match (unpack-fix subs1) with
(:fixsubs body1 inits1)
-> (return
(simpleopt
(node/fix (append names0 names1)
(append inits0 inits1)
body1))))
_ -> #u)
(node:let ())
;; empty let
-> (return (simpleopt (first node.subs)))
(node:let names0)
-> (match (unpack-fix node.subs) with
;; body of let is another let...
(:fixsubs {t=(node:let names1) subs=subs1 ...} inits0)
-> (match (unpack-fix subs1) with
(:fixsubs body1 inits1)
-> (return
(simpleopt
(node/let (append names0 names1)
(append inits0 inits1)
body1))))
(:fixsubs body0 inits0)
;; search for any let in the inits
-> (let ((n (length inits0)))
(for-range
i n
(match (nth inits0 i) with
{t=(node:let names1) subs=subs1 ...}
-> (match (unpack-fix subs1) with
(:fixsubs body1 inits1)
-> (return
(simpleopt
;; (let (a b (c (let (d e) <body1>)) f g) <body0>)
;; => (let (a b d e (c <body1>) f g) <body0>)
;; Note: this looks like it makes <d> and <e> visible where they shouldn't be,
;; but access to those variables in <body0> will have already been caught
;; during earlier phases of the compiler.
(node/let (append (append (slice names0 0 i) names1) (slice names0 i n))
(append (append (slice inits0 0 i) inits1) (list:cons body1 (slice inits0 (+ i 1) n)))
body0))))
_ -> #u))
#u))
(node:if)
-> (match node.subs with
;; (if #t a b) => a
({t=(node:literal (literal:cons 'bool b _)) ...} then else)
-> (if (eq? b 'true) (return (simpleopt then)) (return (simpleopt else)))
_ -> #u)
(node:sequence)
-> (if (= 1 (length node.subs))
;; single-item sequence
(return (simpleopt (nth node.subs 0)))
;; sequence within sequence
(let ((subs0 node.subs)
(n (length subs0)))
(for-range
i n
(match (nth subs0 i) with
{t=(node:sequence) subs=subs1 ...}
-> (return
(simpleopt
(node/sequence
(append (append (slice subs0 0 i) subs1) (slice subs0 (+ 1 i) n)))))
_ -> #u))
#u))
_ -> #u)
(set! simpleopt-hits (- simpleopt-hits 1))
;; if we get here, there was no early exit, so just recurse
;; onto the sub-expressions
(let ((new-subs (map simpleopt node.subs)))
(set! node.subs new-subs)
node
)))
(define removed-count 0)
(define (do-trim top context)
(let ((g context.dep-graph)
(seen (symbol-set-maker '())))
;;(print-string "do-trim:\n")
;;(print-graph g)
(define (walk name)
(match (g::get name) with
(maybe:no) -> #u ;; not a function / no deps
(maybe:yes deps)
-> (begin
(seen::add name)
(for-each
(lambda (dep)
(if (not (seen::in dep))
(walk dep)))
(deps::get)))))
(define (trim node)
(let ((node0
(match node.t with
(node:fix names)
-> (let ((n (length names))
(inits node.subs)
(remove '()))
(for-range
i n
(if (not (seen::in (nth names i)))
(PUSH remove i)))
(if (null? remove)
node
(let ((new-names '())
(new-inits '())
;; XXX remove when happy, this var only for the print
(trimmed (map (lambda (i) (nth names i)) remove)))
;;(print-string (format "trimming: " (join symbol->string ", " trimmed) "\n"))
(for-range
i n
(cond ((not (member-eq? i remove))
(PUSH new-names (nth names i))
(PUSH new-inits (nth inits i)))))
(node/fix (reverse new-names)
(reverse new-inits)
(last inits))
)))
_ -> node)))
(set! node0.subs (map trim node0.subs))
node0))
(walk 'top)
(trim top)
))
(define (analyze exp context)
;; clear the variable table
(set! context.vars (tree/empty))
(set! context.funs (tree/empty))
;; rebuild it
(build-vars exp context)
(find-recursion exp context)
(find-refs exp context)
(escape-analysis exp context)
)
(define (do-one-round node context)
(analyze node context)
;;(print-vars context)
(build-dependency-graph node context)
;;(print-graph context.dep-graph)
;; trim, simple, inline, simple
(do-simple-optimizations
(do-inlining
(do-simple-optimizations
(do-trim node context))
context)))