;; -*- Mode: Irken -*-

(include "self/context.scm")
(include "self/transform.scm")

;; Literal values carried by a `node:literal` node.
;; `walk` and `build-literal` construct them from s-expressions;
;; `literal->string` renders them as text.
(datatype literal
  (:string string)
  (:int int)
  (:char char)
  (:undef)
  (:symbol symbol)
  (:cons symbol symbol (list literal)) ;; datatype-name alt-name args
  (:vector (list literal))
  )

;; The node datatype holds only the per-variant metadata of a node;
;;  the sub-nodes themselves are kept in the `subs` field of the
;;  record built by make-node.  The `subs:` notes below give the
;;  layout each node/* constructor in this file uses.
(datatype node
  (:varref symbol)                 ;; subs: none
  (:varset symbol)                 ;; subs: (value)
  (:literal literal)               ;; subs: none
  (:cexp (list type) type string) ;; generic-tvars type template
                                   ;; subs: args
  (:nvcase symbol (list symbol) (list int))  ;; datatype alts arities
                                   ;; subs: (value else . alts)
  (:sequence)                      ;; subs: the expressions, in order
  (:if)                            ;; subs: (test then else)
  (:function symbol (list symbol)) ;; name formals
                                   ;; subs: (body)
  (:call)                          ;; subs: (rator . rands)
  (:let (list symbol))             ;; subs: (init0 init1 ... body)
  (:fix (list symbol))             ;; subs: (init0 init1 ... body)
  (:subst symbol symbol)           ;; from to; subs: (body)
  (:primapp symbol sexp) ;; name params
                                   ;; subs: args
  )

;; counter that make-node bumps to give each new node record its `id`
(define node-counter (make-counter 0))

;; Total size of a node that has the sub-nodes <l>: the sum of every
;; sub-node's `size` field, plus one for the node itself.  An empty
;; list therefore yields 1.
(define (sum-size l)
  (let loop ((rest l)
             (total 1))
    (match rest with
      ()        -> total
      (hd . tl) -> (loop tl (+ hd.size total)))))

;; placeholder stored in the `type` field of every record built by make-node
;; (presumably the value that no-type? is meant to recognize - confirm `pred`
;; builds a `type:pred`)
(define no-type (pred '? '()))

;; #t when <t> is a `type:pred` named `?` with an empty argument list
;; (whatever its third field holds); #f for every other type.
(define (no-type? t)
  (match t with
    (type:pred '? () _) -> #t
    _                   -> #f))

;; A cleaner way to do this might be an alist, which would make sense
;;   if most flags are clear most of the time.
;; flags
;; read bit <i> of the node's `flags` field (result is whatever bit-get returns)
(define (node-get-flag node i)
  (let ((flags node.flags))
    (bit-get flags i)))
;; set bit <i> in the node's `flags` field; the node record is mutated in place
(define (node-set-flag! node i)
  (let ((updated (bit-set node.flags i)))
    (set! node.flags updated)))

;; defined node flags
;; (presumably each value is the bit index passed as <i> to
;;  node-get-flag / node-set-flag! - confirm against callers)
(define NFLAG-RECURSIVE 0)
(define NFLAG-ESCAPES   1)
(define NFLAG-LEAF      2)
;; one past the highest flag index above; pp-node adds 2 to it to get
;; the width of the flags column
(define NFLAG-NFLAGS    3)

;; Build a fresh node record.
;;   t    : the node-datatype tag carrying the per-variant metadata
;;   subs : list of sub-node records
;; The record's size is computed by sum-size, its id comes from
;; node-counter, its type starts out as no-type and its flags as 0.
(define (make-node t subs)
  (let ((size (sum-size subs))
        (id (node-counter.inc)))
    {t=t subs=subs size=size id=id type=no-type flags=0}))

;; variable reference node: no sub-nodes
(define (node/varref name)
  (let ((tag (node:varref name)))
    (make-node tag '())))

;; assignment node: the single sub-node is the value being assigned
(define (node/varset name val)
  (let ((tag (node:varset name))
        (subs (LIST val)))
    (make-node tag subs)))

;; Extract the variable name from a `node:varref` tag (the value kept in
;; a node record's `t` field, not the record itself).  Any other tag is
;; an error.
(define (varref->name tag)
  (match tag with
    (node:varref name) -> name
    _                  -> (error "varref->name")))

;; literal node: no sub-nodes
(define (node/literal lit)
  (let ((tag (node:literal lit)))
    (make-node tag '())))

;; cexp node: tag carries generic tvars, type and template string;
;; the sub-nodes are the argument nodes
(define (node/cexp gens type template args)
  (let ((tag (node:cexp gens type template)))
    (make-node tag args)))

;; sequence node: the sub-nodes are the sequenced expressions
(define (node/sequence subs)
  (let ((tag (node:sequence)))
    (make-node tag subs)))

;; conditional node: sub-nodes are (test then else), in that order
(define (node/if test then else)
  (make-node (node:if) (LIST test then else)))

;; Function node: tag carries name and formals, the single sub-node is
;; the body.  <type> is an s-expression: `(sexp:bool #f)` means no
;; declared type and leaves the node's type alone; anything else is run
;; through parse-type and stored as the node's type.
(define (node/function name formals type body)
  (let ((fn (make-node (node:function name formals) (LIST body))))
    (match type with
      (sexp:bool #f) -> #u
      _              -> (set! fn.type (parse-type type)))
    fn))

;; #t when the node record's tag is `node:function`, #f otherwise
(define (function? node)
  (define function-tag?
    (node:function _ _) -> #t
    _                   -> #f)
  (function-tag? node.t))

;; name stored in a function node's tag; error for any other kind of node
(define (function->name node)
  (let ((tag node.t))
    (match tag with
      (node:function name _) -> name
      _                      -> (error "function->name"))))

;; call node: sub-nodes are the operator followed by the operands
(define (node/call rator rands)
  (make-node (node:call) (list:cons rator rands)))

;; fix node: tag carries the bound names; sub-nodes are the inits
;; followed by the body as the last element
(define (node/fix names inits body)
  (make-node (node:fix names)
             (append inits (LIST body))))

;; let node: tag carries the bound names; sub-nodes are the inits
;; followed by the body as the last element
(define (node/let names inits body)
  (make-node (node:let names)
             (append inits (LIST body))))

;; nvcase node: tag carries datatype name, alt tags and arities;
;; sub-nodes are laid out as (value else alt0 alt1 ...)
(define (node/nvcase dt tags arities value alts else)
  (let ((tag (node:nvcase dt tags arities))
        (branches (list:cons else alts)))
    (make-node tag (list:cons value branches))))

;; subst node: tag carries the from/to names; the single sub-node is the body
(define (node/subst from to body)
  (let ((tag (node:subst from to)))
    (make-node tag (LIST body))))

;; primapp node: tag carries the primitive's name and its params
;; s-expression; the sub-nodes are the argument nodes
(define (node/primapp name params args)
  (let ((tag (node:primapp name params)))
    (make-node tag args)))

;; Copy a node record.  Literal nodes are returned as-is (no copy).
;; Any other node gets a new record from make-node - so a new id and a
;; recomputed size - that shares the original's tag and sub-node list
;; and has the original's flags and type copied over.
(define (node-copy node0)
  (define (clone)
    (let ((dup (make-node node0.t node0.subs)))
      (set! dup.type node0.type)
      (set! dup.flags node0.flags)
      dup))
  (match node0.t with
    (node:literal _) -> node0 ;; no deep copy on literals
    _                -> (clone)))

;; Split the sub-node list of a fix or let node, (init0 init1 ... body),
;; into (:fixsubs body inits) with the inits in their original order.
;; An empty list is an error.
(define (unpack-fix subs)
  (match (reverse subs) with
    (body . rinits) -> (:fixsubs body (reverse rinits))
    _               -> (error "unpack-fix: no body?")))

;; Render a literal as text.  Strings are wrapped in double quotes with
;; no escaping; constructor literals print as (dt:alt arg ...), vectors
;; as #(item ...), and the undefined value as #u.
(define (literal->string lit)
  (match lit with
    (literal:string text)      -> (format (char #\") text (char #\")) ;; XXX correct string repr
    (literal:symbol name)      -> (format (sym name))
    (literal:int num)          -> (format (int num))
    (literal:char ch)          -> (format (char #\#) (char #\\) (char ch)) ;; printable?
    (literal:undef)            -> (format "#u")
    ;; a constructor with no arguments gets no trailing space
    (literal:cons dt alt ())   -> (format "(" (sym dt) ":" (sym alt) ")")
    (literal:cons dt alt args) -> (format "(" (sym dt) ":" (sym alt) " " (join literal->string " " args) ")")
    (literal:vector items)     -> (format "#(" (join literal->string " " items) ")")
    ))

;; Binary representation of the flags word <n> as a string of 1/0
;; characters, most significant bit first.  Zero gives the empty string.
(define (flags-repr n)
  (let loop ((digits '())
             (rest n))
    (if (= rest 0)
        (list->string digits)
        (let ((digit (if (= (logand rest 1) 1) #\1 #\0)))
          ;; consing onto the front leaves the low bit last
          (loop (list:cons digit digits) (>> rest 1))))))

;; Print <n> spaces, one per indentation level, and return #t.
;; pp-node uses this to indent a node by its depth.
;; The recursive step calls indent1 itself; it previously called
;; `indent`, a different function not defined in this file, so the
;; output past the first space was decided elsewhere.
(define indent1
  0 -> #t
  n -> (begin (print-string " ") (indent1 (- n 1))))

;; Pretty-print the node tree rooted at <n>, one node per line.
;; Each line starts with the node's id, size and flags bits in
;; left-padded columns, then indentation by depth (via indent1), then a
;; description of the node and - for every kind except subst - its type.
;; A final newline is printed after the whole tree.
(define (pp-node n)
  ;; print one node at depth <d>, then its sub-nodes one level deeper
  (define (PP n d)
    (let ((tr (type-repr (apply-subst n.type))) ;; temporary apply-subst
          (head (format (lpad 6 (int n.id)) (lpad 5 (int n.size)) (lpad (+ 2 NFLAG-NFLAGS) (flags-repr n.flags)))))
      (newline)
      (print-string head)
      (indent1 d)
      (print-string
       (match n.t with
         (node:varref name)             -> (format "varref " (sym name) " : " tr)
         (node:varset name)             -> (format "varset " (sym name) " : " tr)
         (node:literal lit)             -> (format "literal " (p literal->string lit) " : " tr)
         (node:cexp gens type template) -> (format "cexp " (p type-repr type) " " template " : " tr)
         (node:sequence)                -> (format "sequence : " tr)
         (node:if)                      -> (format "conditional : " tr)
         (node:call)                    -> (format "call : " tr)
         (node:function name formals)   -> (format "function " (sym name) " (" (join symbol->string " " formals) ") : " tr)
         (node:fix formals)             -> (format "fix (" (join symbol->string " " formals) ") : " tr)
         (node:nvcase dt tags arities)  -> (format "nvcase " (sym dt) "(" (join symbol->string " " tags) ") (" (join int->string " " arities) ") : " tr)
         (node:subst from to)           -> (format "subst " (sym from) "->" (sym to))
         (node:primapp name params)     -> (format "primapp " (sym name) " " (p repr params) " : " tr)
         (node:let formals)             -> (format "let (" (join symbol->string " " formals) ") : " tr)
         ))
      (for-each (lambda (sub) (PP sub (+ 1 d))) n.subs)
      ))
  (PP n 0)
  (newline)
  )

;; Turn a list of `sexp:symbol` s-expressions into the list of their
;; symbols, preserving order.  Any element that is not a symbol raises
;; "malformed formal", reporting the whole input list.
(define (get-formals l)
  (let loop ((rest l)
             (racc '()))
    (match rest with
      ()                          -> (reverse racc)
      ((sexp:symbol formal) . tl) -> (loop tl (list:cons formal racc))
      _                           -> (error1 "malformed formal" l))))

;; Split a binding list ((name0 init0) (name1 init1) ...) into
;; (:pair names inits), both in source order.  The inits are left as
;; s-expressions.  If an element is not a two-item list headed by a
;; symbol, raise "unpack-bindings" with the not-yet-consumed tail.
(define (unpack-bindings bindings)
  (define collect
    () names inits
    -> (:pair (reverse names) (reverse inits))
    ((sexp:list ((sexp:symbol name) init)) . tl) names inits
    -> (collect tl (list:cons name names) (list:cons init inits))
    rest _ _
    -> (error1 "unpack-bindings" rest))
  (collect bindings '() '()))

;; Parse the type signature of a %%cexp form.  A fresh alist-maker is
;; handed to parse-type* along with <sig>; the result is
;; (:scheme tvars type), where tvars are the values that alist holds
;; once parsing is done.
(define (parse-cexp-sig sig)
  (let ((tvars (alist-maker)))
    (let ((sig-type (parse-type* sig tvars)))
      (:scheme (tvars::values) sig-type))))

;; Reorder the parallel lists <names> and <inits> (init nodes) so that
;; every binding whose init is a function node comes first.  Relative
;; order inside each of the two groups is kept.
;; Returns (:sorted-fix names inits).
(define (sort-fix-inits names inits)
  (let ((fun-names '())
        (fun-inits '())
        (other-names '())
        (other-inits '())
        (count (length names)))
    (for-range
        i count
        (let ((name (nth names i))
              (init (nth inits i)))
          ;; PUSH builds each group in reverse; undone below
          (if (function? init)
              (begin (PUSH fun-names name) (PUSH fun-inits init))
              (begin (PUSH other-names name) (PUSH other-inits init)))))
    (:sorted-fix (append (reverse fun-names) (reverse other-names))
                 (append (reverse fun-inits) (reverse other-inits)))))

;; XXX not stable for some reason...
;; (define (sort-fix-inits names inits)
;;   (let ((n (length names))
;; 	(l0 (map2 (lambda (a b) {name=a init=b}) names inits))
;; 	(l1 (sort
;; 	     (lambda (a b)
;; 	       (match a.init.t b.init.t with
;; 		 (node:function _ _) (node:function _ _) -> #f ;; stable?
;; 		 (node:function _ _) _			 -> #t
;; 		 _ _					 -> #f
;; 		 ))
;; 	     l0))
;; 	(names0 (map (lambda (x) x.name) l1))
;; 	(inits0 (map (lambda (x) x.init) l1)))
;;     (print-string "sort-fix-inits unsorted=") (printn names)
;;     (print-string "sort-fix-inits   sorted=") (printn names0)
;;     (:sorted-fix names0 inits0)
;;     ;;(:sorted-fix names inits)
;;     ))

;; Translate an s-expression into a node tree.
;;
;; Atoms become varref or literal nodes.  Records and attribute access
;; become %rmake/%rextend/%raccess primapps.  A list is first checked
;; against the special forms below (begin, set!, quote, literal, if,
;; %%cexp, %nvcase, function, fix, let-splat, letrec, let_subst); any
;; other non-empty list is an application.  Anything unrecognized
;; raises a "syntax error".
(define walk
  (sexp:symbol s)  -> (node/varref s)
  (sexp:string s)  -> (node/literal (literal:string s))
  (sexp:int n)	   -> (node/literal (literal:int n))
  (sexp:char c)	   -> (node/literal (literal:char c))
  ;; booleans are constructor literals of the `bool` datatype
  (sexp:bool b)	   -> (node/literal (literal:cons 'bool (if b 'true 'false) '()))
  (sexp:undef)	   -> (node/literal (literal:undef))
  ;; a record literal: start from an empty %rmake and wrap it in one
  ;; %rextend per field (foldr, so the first field is the outermost)
  (sexp:record fl) -> (foldr (lambda (field rest)
			       (match field with
				 (field:t name value)
				 -> (node/primapp '%rextend
						  (sexp:symbol name)
						  (LIST rest (walk value)))))
			     (node/primapp '%rmake (sexp:bool #f) '())
			     fl)
  (sexp:attr exp sym) -> (node/primapp '%raccess (sexp:symbol sym) (LIST (walk exp)))
  (sexp:vector l)     -> (node/literal (build-literal (sexp:vector l))) ;; some day work out issues with QUOTE/LITERAL etc.
  ;; a bare constructor reference becomes a varref to the symbol `dt:alt`
  (sexp:cons dt alt)  -> (node/varref (string->symbol (format (sym dt) ":" (sym alt))))
  (sexp:list l)
  -> (match l with
       ((sexp:symbol 'begin) . exps)		    -> (node/sequence (map walk exps))
       ((sexp:symbol 'set!) (sexp:symbol name) arg) -> (node/varset name (walk arg))
       ((sexp:symbol 'quote) arg)		    -> (node/literal (build-literal arg))
       ((sexp:symbol 'literal) arg)		    -> (node/literal (build-literal arg))
       ((sexp:symbol 'if) test then else)	    -> (node/if (walk test) (walk then) (walk else))
       ((sexp:symbol '%%cexp) sig (sexp:string template) . args)
       -> (let ((scheme (parse-cexp-sig sig)))
	    (match scheme with
	      (:scheme gens type)
	      -> (node/cexp gens type template (map walk args))))
       ((sexp:symbol '%nvcase) (sexp:symbol dt) val-exp (sexp:list tags) (sexp:list arities) (sexp:list alts) ealt)
       -> (node/nvcase dt (map sexp->symbol tags) (map sexp->int arities) (walk val-exp) (map walk alts) (walk ealt))
       ;; the function body expressions are wrapped in a single sequence node
       ((sexp:symbol 'function) (sexp:symbol name) (sexp:list formals) type . body)
       -> (node/function name (get-formals formals) type (node/sequence (map walk body)))
       ;; ----------------------------------------------------------
       ;; HUGE typing problem here, when I accidentally did this:
       ;; ((sexp:symbol 'fix) (sexp:list names) (sexp:list inits) . body)
       ;; -> (match (sort-fix-inits names inits) with
       ;;      (:sorted-fix names inits)
       ;;      -> (node/fix (get-formals names) (map walk inits) (node/sequence (map walk body))))
       ;; the typer let it fly. 
       ;; ----------------------------------------------------------
       ;; fix: names and inits arrive as two parallel lists; the inits are
       ;; walked first and then reordered so function inits come first
       ((sexp:symbol 'fix) (sexp:list names) (sexp:list inits) . body)
       -> (match (sort-fix-inits (get-formals names) (map walk inits)) with
	    (:sorted-fix names inits)
	    -> (node/fix names inits (node/sequence (map walk body))))
       ((sexp:symbol 'let-splat) (sexp:list bindings) . body)
       -> (match (unpack-bindings bindings) with
	    (:pair names inits)
	    -> (node/let names (map walk inits) (node/sequence (map walk body))))
       ;; letrec also becomes a fix node, but its inits are not reordered
       ((sexp:symbol 'letrec) (sexp:list bindings) . body)
       -> (match (unpack-bindings bindings) with
	    (:pair names inits)
	    -> (node/fix names (map walk inits) (node/sequence (map walk body))))
       ((sexp:symbol 'let_subst) (sexp:list ((sexp:symbol from) (sexp:symbol to))) body)
       -> (node/subst from to (walk body))
       ;; everything else with at least one element is an application
       (rator . rands)
       -> (match rator with
	    ;; an operator whose name starts with % is a primapp; its first
	    ;; operand is the params s-expression and is not walked
	    (sexp:symbol name)
	    -> (if (eq? (string-ref (symbol->string name) 0) #\%)
		   (match rands with
		     (params . rands)
		     -> (node/primapp name params (map walk rands))
		     _ -> (error1 "null primapp missing params?" l))
		   (node/call (walk rator) (map walk rands)))
	    ;; a constructor in operator position: datatype `nil` becomes a
	    ;; %vcon primapp carrying the alt name and the argument count;
	    ;; any other datatype is an ordinary call
	    (sexp:cons dt alt)
	    -> (if (eq? dt 'nil)
		   (node/primapp '%vcon (sexp (sexp:symbol alt) (sexp:int (length rands))) (map walk rands))
		   ;; automatically inline all constructors:
		   ;;(node/primapp '%dtcon rator (map walk rands))
		   ;; let the inliner do it, correctly.
		   (node/call (walk rator) (map walk rands)))
	    _ -> (node/call (walk rator) (map walk rands)))
       ;; only the empty list reaches this arm
       _ -> (error1 "syntax error: " l)
       )
  x -> (error1 "syntax error 2: " x)
  )

;; Build a literal from the elements of a quoted list.
;;  - a list headed by a constructor reference dt:alt becomes that
;;    constructor applied to the remaining elements as literals
;;  - any other non-empty list becomes a list:cons chain
;;  - the empty list becomes list:nil
(define (build-list-literal items)
  (match items with
    ((sexp:cons dt alt) . args) -> (literal:cons dt alt (map build-literal args))
    (hd . tl)                   -> (literal:cons 'list 'cons (LIST (build-literal hd) (build-list-literal tl)))
    ()                          -> (literal:cons 'list 'nil '())
    ))

;; Convert a quoted s-expression into a literal.  Strings, ints, chars,
;; the undefined value, symbols, lists and vectors are handled; any
;; other kind of s-expression raises "unhandled literal type".
(define build-literal
  (sexp:string s) -> (literal:string s)
  (sexp:int n)    -> (literal:int n)
  (sexp:char c)   -> (literal:char c)
  (sexp:undef)    -> (literal:undef)
  (sexp:symbol s) -> (literal:symbol s)
  (sexp:list l)   -> (build-list-literal l)
  (sexp:vector l) -> (literal:vector (map build-literal l))
  ;; XXX the rest
  other           -> (error1 "unhandled literal type" other)
  )

;; make the symbol <name>_<num>, e.g. x and 3 give x_3
(define (frob name num)
  (let ((text (format (sym name) "_" (int num))))
    (string->symbol text)))

;; Record describing one variable definition: its original name, the
;; renamed form built by frob from the name and serial number, empty
;; `assigns` and `refs` lists, and the serial number itself.
(define (make-vardef name serial)
  {name=name name2=(frob name serial) assigns='() refs='() serial=serial})

;; Create a variable map: a tree keyed by symbol (ordered with
;; symbol-index<?) plus a private serial-number counter.  The result is
;; a record of three closures:
;;   add    : make a vardef for a symbol with the next serial number,
;;            insert it into the tree and return it
;;   lookup : tree/member on the current tree for a symbol
;;   get    : the current tree
(define (make-var-map)
  (let ((vars (tree/empty))
        (serials (make-counter 0)))
    (define (add-var sym)
      (let ((vd (make-vardef sym (serials.inc))))
        (set! vars (tree/insert vars symbol-index<? sym vd))
        vd))
    (define (find-var sym)
      (tree/member vars symbol-index<? sym))
    (define (all-vars) vars)
    {add=add-var lookup=find-var get=all-vars}
    ))

;; Rename the variables bound in the tree rooted at <n>, in place.
;;
;; Every name bound by a function, fix or let node gets a vardef from a
;; fresh var-map, and the node's tag is rewritten to use the vardef's
;; `name2` (the original name with its serial number appended by frob).
;; varref/varset tags that refer to a name found in the lexical
;; environment are rewritten the same way; names not found there are
;; left untouched.
;;
;; The lexical environment <lenv> is a list of ribs, innermost first;
;; each rib is a list of vardefs.
(define (rename-variables n)

  (let ((varmap (make-var-map)))

    ;; rename each node in <exps> under the same environment
    (define (rename-all exps lenv)
      (for-each (lambda (exp) (rename exp lenv)) exps))
    
    (define (rename exp lenv)

      ;; find the vardef for <name>, searching ribs from the innermost
      ;; outward and each rib front to back; (maybe:no) if it is unbound
      (define (lookup name)
	(let loop0 ((lenv lenv))
	  (match lenv with
	    ()		 -> (maybe:no)
	    (rib . next) -> (let loop1 ((l rib))
			      (match l with
				()	  -> (loop0 next)
				(vd . tl) -> (if (eq? name vd.name)
						 (maybe:yes vd)
						 (loop1 tl)))))))
      
      (match exp.t with
	;; the function's own name is looked up in the enclosing
	;; environment (its formals are not yet in scope).  If unbound it
	;; is kept, except that the name `lambda` becomes lambda_<node id>.
	(node:function name formals)
	-> (let ((rib (map varmap.add formals))
		 (name2 (match (lookup name) with
			  (maybe:no) -> (if (eq? name 'lambda)
					    (string->symbol (format "lambda_" (int exp.id)))
					    name)
			  (maybe:yes vd) -> vd.name2)))
	     (set! exp.t (node:function name2 (map (lambda (x) x.name2) rib)))
	     (rename-all exp.subs (list:cons rib lenv)))
	(node:fix names)
	-> (let ((rib (map varmap.add names)))
	     ;; in this one, the <inits> namespace is renamed, too
	     (set! exp.t (node:fix (map (lambda (x) x.name2) rib)))
	     (rename-all exp.subs (list:cons rib lenv)))
	(node:let names) ;; this one is tricky!
	-> (match (unpack-fix exp.subs) with
	     (:fixsubs body inits)
	     -> (let ((rib '())
		      (n (length inits)))
		  (for-range
		      i n
		      ;; add each name only after its init
		      (rename (nth inits i) (list:cons rib lenv))
		      (set! rib (list:cons (varmap.add (nth names i)) rib)))
		  ;; now that all the inits are done, rename the bindings and body
		  ;; (rib was built newest-first, hence the reverse for the tag)
		  (set! exp.t (node:let (map (lambda (x) x.name2) (reverse rib))))
		  (rename body (list:cons rib lenv))))
	(node:varref name)
	-> (match (lookup name) with
	     (maybe:no) -> #u ;; can't rename it if we don't know what it is
	     (maybe:yes vd) -> (set! exp.t (node:varref vd.name2)))
	;; a varset renames its target and then descends into its value
	(node:varset name)
	-> (begin (match (lookup name) with
		    (maybe:no) -> #u
		    (maybe:yes vd) -> (set! exp.t (node:varset vd.name2)))
		  (rename-all exp.subs lenv))
	;; every other node kind binds nothing: just descend
	_ -> (rename-all exp.subs lenv)
	))

    (rename n '())
    ))

;; Walk the node tree, applying subst nodes.
;;
;; Each (node:subst from to) node is dropped from the tree: the walk
;; returns its rewritten body in its place, and within that body varref
;; and varset tags naming <from> are rewritten to name <to>.  Other
;; nodes are mutated in place (their tag and `subs` field) and returned.
;; The environment <lenv> is a list of {from to} records, innermost
;; first.
(define (apply-substs exp)
  
  ;; could we do this more easily by flattening the environment and just using member?
  ;; shadow: drop from the environment every substitution whose `from`
  ;; is one of <names>, so that a rebinding of the name hides it
  (define shadow
    names ()		-> (list:nil)
    names (pair . tail) -> (if (not (member-eq? pair.from names))
			       (list:cons pair (shadow names tail))
			       (shadow names tail)))

  ;; lookup: the `to` of the first substitution whose `from` is <name>,
  ;; or <name> itself when there is none
  (define lookup
    name ()	       -> name
    name (pair . tail) -> (if (eq? name pair.from)
			      pair.to
			      (lookup name tail)))

  (define (walk exp lenv)
    (let/cc return
	(match exp.t with
	  ;; binders: the shadowed environment is used for all of the
	  ;; node's sub-nodes (for fix and let that includes the inits)
	  (node:fix formals)	    -> (set! lenv (shadow formals lenv))
	  (node:let formals)	    -> (set! lenv (shadow formals lenv))
	  (node:function _ formals) -> (set! lenv (shadow formals lenv))
	  ;; <to> is resolved through the current environment before the
	  ;; new entry is added; `return` then leaves walk with the body,
	  ;; skipping the code after the match
	  (node:subst from to)	    -> (begin
					 (set! lenv (list:cons {from=from to=(lookup to lenv)} lenv))
					 (return (walk (car exp.subs) lenv)))
	  (node:varref name)	    -> (set! exp.t (node:varref (lookup name lenv)))
	  (node:varset name)	    -> (set! exp.t (node:varset (lookup name lenv)))
	  _ -> #u
	  )
      ;; replace the sub-node list with the walked sub-nodes
      (set! exp.subs (map (lambda (x) (walk x lenv)) exp.subs))
      exp))

  (walk exp '())
  )