;; -*- Mode: Irken -*-

(define (unify exp t0 t1)

  ;; Report a unification failure for <exp>: print the node, then both
  ;;   offending types with the current substitution applied, and abort.
  (define (type-error t0 t1)
    (newline)
    (pp-node exp)
    (let ((left (type-repr (apply-subst t0)))
          (right (type-repr (apply-subst t1))))
      (print-string (format "\nType Error:\n\t" left "\n\t" right "\n"))
      (error "type error")))

  ;; U: one unification step.  Both types are first resolved to their
  ;;   union-find representatives (type-find); if these are the same object
  ;;   there is nothing to do.  Otherwise:
  ;;   - a tvar on either side is simply unioned with the other type
  ;;     (no occurs check is performed here);
  ;;   - a moo (recursive-type wrapper) against a non-moo, or any row
  ;;     predicate (rlabel/rdefault), is handed off and U exits early via
  ;;     <return>, so u and v are NOT unioned in those cases;
  ;;   - any other pair of predicates (including moo/moo) is unioned and
  ;;     their arguments are unified pairwise; ordinary predicates must
  ;;     first agree on name and arity or type-error is raised.
  (define (U t0 t1)
    ;;(print-string (format "    ----U " (type-repr t0) " -- " (type-repr t1) "\n"))
    (let/cc return
	(let ((u (type-find t0))
	      (v (type-find t1)))
	  ;;(print-string (format "    ----: " (type-repr u) " -- " (type-repr v) "\n"))
	  (if (not (eq? u v))
	      (begin
		;; first pass: validate shapes / divert special predicates
		(match u v with
		  (type:tvar _ _) _ -> #u
		  ;; original line
		  ;;_ (type:tvar _ _) -> #u
		  ;; missing optimize-nvcase version:
		  (type:pred _ _ _) (type:tvar _ _) -> #u
		  (type:pred pu su _) (type:pred pv sv _)
		  -> (match pu pv with
		       'moo 'moo   -> #u
		       ;; row and moo vars - early exit to avoid union
		       ;; a moo is (moo <placeholder-tvar> <body>), so (car su) /
		       ;;   (car sv) is its placeholder tvar
		       'moo _	   -> (return (U (car su) v))
		       _ 'moo	   -> (return (U (car sv) u))
		       ;; U-row always receives the rlabel/rdefault side first
		       'rlabel _   -> (return (U-row u v))
		       _ 'rlabel   -> (return (U-row v u))
		       'rdefault _ -> (return (U-row u v))
		       _ 'rdefault -> (return (U-row v u))
		       ;; ordinary predicates: names and arities must match
		       _ _ -> (if (or (not (eq? pu pv))
				      (not (= (length su) (length sv))))
				  (type-error t0 t1)
				  #u)))
		(type-union u v)
		;; second pass: unify predicate arguments pairwise
		(match u v with
		  (type:pred _ su _) (type:pred _ sv _) -> (for-each2 U su sv)
		  _ _ -> #u))
	      ))))

  ;; Row labels are nullary predicates; two labels are equal exactly when
  ;;   their predicate names are the same symbol.  Anything else is unequal.
  (define (label=? a b)
    (match a b with
      (type:pred name-a _ _) (type:pred name-b _ _) -> (eq? name-a name-b)
      _ _ -> #f))

  ;; U-row: unify two row types.  U calls this with the rlabel/rdefault
  ;;   side as <u>; when <u> is an rdefault, <v> is not an rlabel.
  ;;   An rlabel row is (rlabel <label> <field-type> <rest-of-row>).
  ;;   An rdefault row is (rdefault <type>); as the C-MUTATE-DL case below
  ;;   shows, it stands for "<type> at every remaining label".
  ;;   Shapes that fit none of the cases are reported via type-error.
  (define (U-row u v)
    (match u v with
      ;; u and v are both rlabel
      (type:pred 'rlabel (l0 t0 d0) _) (type:pred 'rlabel (l1 t1 d1) _)
      -> (cond ((label=? l0 l1)
		;; identical head labels, normal unify
		(U t0 t1)
		(U d0 d1))
	       (else
		;; distinct head labels, C-MUTATE-LL
		;;   each row's tail must contain the other's head label, and
		;;   both tails share a fresh common remainder x
		(let ((x (new-tvar)))
		  (U d0 (rlabel l1 t1 x))
		  (U d1 (rlabel l0 t0 x)))))
      ;; u is rlabel, v is not
      (type:pred 'rlabel (l0 t0 d0) _) (type:pred p1 s1 _)
      -> (cond ((eq? p1 'rdefault)
		;; C-MUTATE-DL
		;;   (rdefault t) = (l0 : t0 ; d0)  ==>  t0 = t, d0 = (rdefault t)
		(U (car s1) t0)
		(U v d0))
	       (else
		;; some other predicate
		;; C-MUTATE-GL
		;;   push the row through p1: t0 and d0 become p1 applied to
		;;   fresh tvars, and each argument i of v becomes the row
		;;   (l0 : tvars0[i] ; tvars1[i])
		(let ((n (length s1))
		      (tvars0 (map-range i n (new-tvar)))
		      (tvars1 (map-range i n (new-tvar))))
		  (U (pred p1 tvars0) t0)
		  (U (pred p1 tvars1) d0)
		  (for-range i n
		     (U (nth s1 i) (rlabel l0 (nth tvars0 i) (nth tvars1 i)))
		     ))))
      ;; both are rdefault
      (type:pred 'rdefault (t0) _) (type:pred 'rdefault (t1) _)
      -> (U t0 t1)
      ;; u is rdefault, v is some other predicate
      ;;   (rdefault t0) = p1(s1...)  ==>  t0 = p1(tvars...), s1[i] = (rdefault tvars[i])
      (type:pred 'rdefault (t0) _) (type:pred p1 s1 _)
      -> (let ((n (length s1))
	       (tvars (map-range i n (new-tvar))))
	   (U t0 (pred p1 tvars))
	   (for-range i n
	      (U (nth s1 i) (rdefault (nth tvars i)))))
      ;; anything else is an error
      _ _ -> (type-error u v)
      ))

  ;; body of unify: unify the two argument types (aborts via type-error on failure)
  (U t0 t1)
  )

;; Apply the current substitution to <t>: return a freshly built type in
;;   which every node is replaced by its union-find representative
;;   (type-find).  Tvars are returned as their representatives, and an
;;   existing moo node is returned as-is.
;; Cycles (recursive types) are detected with the per-node 'pending' flag.
;;   When a node is reached again while it is still being expanded, a
;;   placeholder tvar is returned and recorded in the node's 'moo' slot.
;;   The node's result is then wrapped as (moo <placeholder> <type>).
;;   Every back-reference to the same node reuses that single placeholder.
;;   Previously each back-reference allocated a new tvar and overwrote
;;   'moo', so a type with several self-references ended up with unbound
;;   placeholders.
(define (apply-subst t)
  (define (p t)
    (let ((t (type-find t))
	  (trec (type->trec t)))
      (if trec.pending
	  ;; back-reference into a node still being expanded
	  (match trec.moo with
	    (maybe:yes tv) -> tv
	    (maybe:no)
	    -> (let ((tv (new-tvar)))
		 (set! trec.moo (maybe:yes tv))
		 tv))
	  (match t with
	    (type:tvar _ _)	 -> t
	    (type:pred 'moo _ _) -> t
	    (type:pred name subs _)
	    -> (begin
		 (set! trec.pending #t)
		 (let ((r (pred name (map p subs))))
		   (set! trec.pending #f)
		   ;; if any back-reference was seen, bind it with a moo wrapper
		   (match trec.moo with
		     (maybe:yes mv) -> (pred 'moo (LIST mv r))
		     (maybe:no) -> r)))))))
  (p t))

;; Render a type scheme as "forall(<gen>,<gen>,...).<type>".
(define (scheme-repr sch)
  (match sch with
    (:scheme gens type)
    -> (let ((vars (join type-repr "," gens))
             (body (type-repr type)))
         (format "forall(" vars ")." body))))

;; (define (apply-subst t)
;;   (let ((t0 (apply-subst* t)))
;;     (print-string (format "apply-subst: " (p type-repr t) "\n"))
;;     (print-string (format " = " (p type-repr t0) "\n"))
;;     t0))

(define (type-program node context)

  ;; Instantiate the type scheme (gens, type): return a copy of <type> in
  ;;   which each generic tvar listed in <gens> is replaced by its own fresh
  ;;   tvar.  Tvars not listed in <gens> are shared with the original.
  ;;   Every element of <gens> must be a tvar.
  (define (instantiate-type-scheme gens type)
    ;; add a (gen-id -> fresh tvar) entry for one generic variable
    (define (bind-fresh gen env)
      (match gen with
        (type:tvar id _) -> (alist:entry id (new-tvar) env)
        _ -> (error "instantiate-type-scheme")))
    (let ((fresh (foldr bind-fresh (alist:nil) gens)))
      ;; rebuild the type, substituting fresh tvars for the generic ones
      (define (replace-gens t)
        (match t with
          (type:tvar id _)
          -> (match (alist/lookup fresh id) with
               (maybe:yes tv) -> tv
               (maybe:no) -> t)
          (type:pred name args _) -> (pred name (map replace-gens args))))
      (replace-gens type)))

  ;; Does <tvar> occur (by identity) anywhere inside <type>?
  ;;   Compares nodes with eq?; the search stops at the first hit.
  (define (occurs-in-type tvar type)
    ;; true if <tvar> occurs in any type of the list <l>
    (define (occurs-in-any l)
      (match l with
        () -> #f
        (hd . tl) -> (if (occurs-in-type tvar hd) #t (occurs-in-any tl))))
    (if (eq? tvar type)
        #t
        (match type with
          (type:tvar _ _) -> #f
          (type:pred _ args _) -> (occurs-in-any args))))

  ;; occurs-free and build-type-scheme could obviously
  ;;   be made more efficient - build-type-scheme walks
  ;;   the entire environment repeatedly.

  ;; Is <tvar> free in some scheme of <tenv>, i.e. does it occur in a
  ;;   scheme's body without being one of that scheme's generic variables?
  ;;   <tvar> is expected to be a union-find representative (as produced by
  ;;   apply-subst in build-type-scheme).  Each scheme body is passed through
  ;;   apply-subst before the check.  Comparing against the raw body would
  ;;   miss tvars that have since been unified with <tvar>, or bound to a
  ;;   type containing it, and build-type-scheme would then wrongly
  ;;   generalize <tvar>.
  (define (occurs-free-in-tenv tvar tenv)
    (let/cc return
	(alist/iterate
	 (lambda (name scheme)
	   (match scheme with
	     (:scheme gens type)
	     -> (if (not (member-eq? tvar gens))
		    (if (occurs-in-type tvar (apply-subst type))
			(return #t)))))
	 tenv)
      #f))

  ;; Generalize <type> with respect to <tenv>: apply the current
  ;;   substitution, then make generic every tvar of the result that is not
  ;;   free in <tenv>.  Returns (:scheme <generic-tvars> <substituted-type>).
  (define (build-type-scheme type tenv)
    (let ((gens (set-maker '()))
          (subst-type (apply-subst type)))
      ;; collect the generalizable tvars of a type into <gens>
      (define (collect t)
        (match t with
          (type:pred _ args _) -> (for-each collect args)
          (type:tvar _ _)
          -> (if (not (occurs-free-in-tenv t tenv))
                 (gens::add t))))
      (collect subst-type)
      ;;(print-string (format "build-type-scheme type=" (type-repr subst-type) " gens = " (join type-repr "," (gens::get)) "\n"))
      (:scheme (gens::get) subst-type)))

  ;; Compute the type of <exp> in <tenv> by dispatching on its node kind.
  ;;   Callers normally go through type-of, which also stores the result in
  ;;   exp.type.  Unhandled node kinds are printed and reported as an error.
  (define (type-of* exp tenv)
    (match exp.t with
      (node:literal lit)	    -> (type-of-literal lit exp tenv)
      (node:cexp gens sig _)	    -> (type-of-cexp gens sig exp tenv)
      (node:if)			    -> (type-of-conditional exp tenv)
      (node:sequence)		    -> (type-of-sequence exp.subs tenv)
      (node:function _ formals)	    -> (type-of-function formals (car exp.subs) exp.type tenv)
      (node:varref name)	    -> (type-of-varref name tenv)
      (node:varset name)	    -> (type-of-varset name exp tenv)
      (node:call)		    -> (type-of-call exp tenv)
      (node:fix names)		    -> (type-of-fix names exp tenv)
      (node:let names)		    -> (type-of-let names exp tenv)
      (node:primapp name parms)	    -> (type-of-primapp name parms exp.subs tenv)
      (node:nvcase dt tags arities) -> (type-of-vcase dt tags arities exp tenv)
      _ -> (begin
	     (pp-node exp)
	     (error1 "typing NYI" exp))))

  ;; Type <exp> in <tenv>, remember the result in exp.type, and return it.
  (define (type-of exp tenv)
    (let ((result (type-of* exp tenv)))
      (set! exp.type result)
      result))

  ;; Type a literal.  Atomic literals have fixed types.
  ;;   A constructor literal (literal:cons dt v l) instantiates the
  ;;   constructor scheme of alternative <v> of datatype <dt>.  Each argument
  ;;   literal in <l> is unified with the matching parameter type, and the
  ;;   argument count must equal the constructor's arity.
  ;;   A vector literal unifies every element's type with a single fresh
  ;;   element tvar and has type (vector <tvar>).
  (define (type-of-literal lit exp tenv)
    (match lit with
      (literal:string _)    -> string-type
      (literal:int _)       -> int-type
      (literal:char _)      -> char-type
      (literal:undef)       -> undefined-type
      (literal:symbol _)    -> symbol-type
      (literal:cons dt v l)
      -> (let ((dto (alist/get context.datatypes dt "no such datatype")))
           (match (dto.get-alt-scheme v) with
             (:scheme gens type)
             -> (match (instantiate-type-scheme gens type) with
                  (type:pred 'arrow (result-type . arg-types) _)
                  -> (begin
                       ;; without this check, surplus arguments would go
                       ;;   untyped and missing ones would fail inside nth
                       (if (not (= (length arg-types) (length l)))
                           (error1 "wrong number of args to constructor literal" lit))
                       (for-each2
                        (lambda (arg arg-type)
                          (unify exp (type-of-literal arg exp tenv) arg-type))
                        l arg-types)
                       result-type)
                  x -> (error1 "strange constructor scheme" x))))
      (literal:vector l)
      -> (let ((tv (new-tvar)))
           (for-each
            (lambda (x)
              (let ((tx (type-of-literal x exp tenv)))
                (unify exp tv tx)))
            l)
           (pred 'vector (LIST tv)))
      ))

  ;; HACK: remove a raw predicate if present
  ;;   (raw T) yields T; any other type is returned unchanged.
  (define (unraw t)
    (match t with
      (type:pred 'raw (arg) _) -> arg
      _ -> t))

  ;; Type a %%cexp node.  <gens>/<sig> form its declared type scheme, which
  ;;   is instantiated first.
  ;;   If the result is an arrow (result . params), the number of params
  ;;   must equal the number of argument nodes in exp.subs.  Each argument's
  ;;   type is unified with its param type (after stripping any 'raw'
  ;;   wrapper from the param), and the result type is returned.
  ;;   A non-arrow signature is itself the expression's type.
  (define (type-of-cexp gens sig exp tenv)
    (let ((type (instantiate-type-scheme gens sig)))
      (match type with
        (type:pred 'arrow pargs _)
        -> (match pargs with
             ;; test for an empty arrow before the arity check; otherwise
             ;;   (length pargs) - 1 = -1 never matches and this is unreachable
             () -> (error1 "malformed arrow type" sig)
             (result-type . parg-types)
             -> (if (not (= (length parg-types) (length exp.subs)))
                    (error1 "wrong number of args to cexp" exp)
                    (let ((arg-types (map (lambda (x) (type-of x tenv)) exp.subs)))
                      (for-each2 (lambda (a b)
                                   (unify exp (unraw a) b))
                                 parg-types arg-types)
                      result-type)))
        _ -> type)))

  ;; (if test then else): the test must be bool and both arms must have the
  ;;   same type, which is the type of the whole expression.
  (define (type-of-conditional exp tenv)
    (let ((sub-types (map (lambda (sub) (type-of sub tenv)) exp.subs)))
      (match sub-types with
        (test-type then-type else-type)
        -> (begin
             (unify exp test-type bool-type)
             (unify exp then-type else-type)
             else-type)
        _ -> (error1 "malformed conditional" exp))))

  ;; A sequence has the type of its final expression.  Every earlier
  ;;   expression is still typed, in order, for its side effects on the
  ;;   substitution and on the nodes' recorded types.
  (define (type-of-sequence exps tenv)
    (match exps with
      ()           -> (error "empty sequence?")
      (only)       -> (type-of only tenv)
      (head . rest)
      -> (begin
           (type-of head tenv)
           (type-of-sequence rest tenv))))

  ;; XXX TODO - handle user-supplied types (on formals?)
  ;; Type to assign to a lambda formal.  Currently always a fresh tvar;
  ;;   <formal> and <tenv> are not consulted yet.
  (define (optional-type formal tenv)
    (new-tvar))

  ;; Type a lambda.  If <sig> already holds a type, it is returned as-is.
  ;;   Otherwise each formal gets a monomorphic type from optional-type and
  ;;   is pushed onto <tenv>; the result is arrow(body-type, formal-types)
  ;;   with the formal types in declaration order.
  (define (type-of-function formals body sig tenv)
    (if (not (no-type? sig))
        ;; user-supplied type (do we need to instantiate?)
        sig
        (let ((formal-types '()))
          (for-each
           (lambda (formal)
             (let ((ftype (optional-type formal tenv)))
               (alist/push tenv formal (:scheme '() ftype))
               (PUSH formal-types ftype)))
           formals)
          (arrow (type-of body tenv) (reverse formal-types)))))

  ;; Look up <name> in <tenv> and return a fresh instance of its type
  ;;   scheme.  An unbound name is an error.
  (define (apply-tenv name tenv)
    (match (alist/lookup tenv name) with
      (maybe:yes (:scheme gens type)) -> (instantiate-type-scheme gens type)
      (maybe:no) -> (error1 "apply-tenv: unbound variable" name)))

  ;; A variable reference has a fresh instance of the variable's scheme.
  (define (type-of-varref name tenv)
    ;;(print-string (format "varref: " (sym name) " type = " (type-repr (apply-tenv name tenv)) "\n"))
    (apply-tenv name tenv))

  ;; (set! name val): the value's type must unify with the variable's type.
  ;;   The assignment itself has the undefined type.
  (define (type-of-varset name exp tenv)
    (let ((rhs (car exp.subs))
          (var-type (apply-tenv name tenv))
          (rhs-type (type-of rhs tenv)))
      (unify exp var-type rhs-type)
      undefined-type))

  ;; Type an application: exp.subs is (rator rand ...).  The rator's type
  ;;   must unify with arrow(result, rand-types) for a fresh result tvar,
  ;;   which is returned.
  (define (type-of-call exp tenv)
    (let ((sub-types (map (lambda (sub) (type-of sub tenv)) exp.subs)))
      (match sub-types with
        () -> (error "empty call?")
        (fn-type . arg-types)
        -> (let ((result (new-tvar)))
             (unify exp fn-type (arrow result arg-types))
             result))))

  ;; Type a fix (letrec) form.  reorder-fix, given <names>, exp.subs and
  ;;   context.scc-graph, returns the bindings reordered (names0/inits0), the
  ;;   body, and a partition: a list of parts, each a list of indices into
  ;;   names0/inits0.  (Presumably strongly-connected components in
  ;;   dependency order - confirm against reorder-fix.)
  ;;   Each part is processed in turn:
  ;;   1. its names are bound monomorphically to fresh tvars in a temporary
  ;;      environment layered over <tenv>;
  ;;   2. each init is typed there and unified with its name's tvar;
  ;;   3. each name is generalized against <tenv> (not the temporary
  ;;      environment) and pushed onto <tenv>, so later parts and the body
  ;;      see polymorphic schemes.
  ;;   Returns the type of the body.
  (define (type-of-fix names exp tenv)
    ;; reorder fix into dependency order
    (match (reorder-fix names exp.subs context.scc-graph) with
      (:reordered names0 inits0 body partition)
      -> (let ((n (length names0))
	       (names (list->vector names0))
	       (inits (list->vector inits0))
	       (init-tvars (list->vector (map-range i n (new-tvar))))
	       (init-types (make-vector n no-type))
	       )
	   ;;(print-string "reordered: ") (printn names0)
	   (for-each
	    (lambda (part)
	      ;; build temp tenv for typing the inits
	      (let ((temp-tenv
		     (foldr
		      (lambda (i al)
			(alist:entry names[i] (:scheme '() init-tvars[i]) al))
		      tenv (reverse part))))
		;; type each init in temp-tenv
		(for-each
		 (lambda (i)
		   (let ((ti (type-of inits[i] temp-tenv))
			 (_ (unify inits[i] ti init-tvars[i]))
			 (ti (apply-subst ti)))
		     (set! init-types[i] ti)))
		 part)
		;; now extend the environment with type schemes instead
		;;   (generalized against the outer tenv, so tvars shared only
		;;   among this part's bindings become generic)
		(for-each
		 (lambda (i)
		   (let ((scheme (build-type-scheme init-types[i] tenv)))
		     (alist/push tenv names[i] scheme)))
		 part)))
	    partition)
	   ;; type the body in the new polymorphic environment
	   (type-of body tenv))))

  ;; Type a let.  exp.subs holds one init per name followed by the body.
  ;;   Inits are typed in order, each in the environment extended by the
  ;;   preceding bindings; each name is bound monomorphically (no
  ;;   generalization).  Returns the type of the body.
  ;; XXX user-supplied types are not handled.
  (define (type-of-let names exp tenv)
    (let ((inits exp.subs)
          (n-names (length names)))
      (for-range i n-names
        (let ((init-type (type-of (nth inits i) tenv)))
          ;; extend environment
          (alist/push tenv (nth names i) (:scheme '() init-type))))
      ;; type body in the new env
      (type-of (nth inits n-names) tenv)))

  ;; A variant case over datatype 'nil is a polymorphic-variant case;
  ;;   anything else is a case over a declared datatype.
  (define (type-of-vcase dt tags arities exp tenv)
    (match dt with
      'nil -> (type-of-pvcase tags arities exp tenv)
      _    -> (type-of-nvcase dt tags exp tenv)))

  ;; Type a case over the declared datatype <dt>.
  ;;   exp.subs is (value else-exp alt ...).  The scrutinee must have the
  ;;   datatype's type, instantiated when the datatype has type parameters.
  ;;   Every alternative and the else expression share one fresh result
  ;;   tvar, which is returned.
  (define (type-of-nvcase dt tags exp tenv)
    (let ((dto (alist/get context.datatypes dt "no such datatype"))
          (subs exp.subs)
          (value (nth subs 0))
          (else-exp (nth subs 1))
          (alts (cdr (cdr subs)))
          (value-type (type-of value tenv))
          (dt-scheme (dto.get-scheme))
          (result (new-tvar)))
      (match dt-scheme with
        (:scheme tvars type)
        -> (unify exp value-type
                  (if (null? tvars)
                      type
                      (instantiate-type-scheme tvars type))))
      ;; each alt has the same type
      (for-each
       (lambda (alt)
         (unify alt result (type-of alt tenv)))
       alts)
      ;; this will work even when else-exp is a dummy %%match-error
      (unify else-exp result (type-of else-exp tenv))
      result))

  ;; Type a case over polymorphic variants.
  ;;   exp.subs is (value else-exp alt ...); tags/arities give each
  ;;   alternative's label and field count.  A row is built with one label
  ;;   per alternative, tag_i : pre(field-type_i), where the field type is
  ;;   the single field's tvar for arity 1 and a product of tvars otherwise.
  ;;   The scrutinee must have the sum type (rsum row).
  ;;   The row's tail depends on whether the case has a real else clause
  ;;   (else-exp is not %fail / %match-error):
  ;;   - with an else, unlisted labels are caught by it, so the tail is a
  ;;     fresh tvar (open row);
  ;;   - without one, only the listed labels may occur, so the tail is
  ;;     (rdefault (rabs)) (closed row).
  ;;   Every alternative, and the else clause if present, must have the
  ;;   same type, which is returned as the type of the whole expression.
  (define (type-of-pvcase tags arities exp tenv)
    (match exp.subs with
      (value else-exp . alts)
      -> (let ((tv-exp (new-tvar))
	       (else? (match else-exp.t with
			(node:primapp '%fail _) -> #f
			(node:primapp '%match-error _) -> #f
			_ -> #t))
	       ;; an else clause absorbs any other label -> open row;
	       ;;   no else -> closed row (closing it with an else present
	       ;;   would make the else unreachable and reject extra labels)
	       (row (if else? (new-tvar) (rdefault (rabs)))))
	   (for-range
	       i (length tags)
	       (let ((alt (nth alts i))
		     (tag (nth tags i))
		     (arity (nth arities i))
		     (argvars (n-tvars arity)))
		 (set! row (rlabel (make-label tag)
				   (rpre (if (= arity 1)
					     (car argvars)
					     (pred 'product argvars)))
				   row))
		 ;; each alt must have the same type
		 (unify alt tv-exp (type-of alt tenv))))
	   (if else?
	       (unify else-exp tv-exp (type-of else-exp tenv)))
	   ;; the value must have the row type determined
	   ;;   by the set of polyvariant alternatives
	   (unify value (rsum row) (type-of value tenv))
	   ;; this is the type of the entire expression
	   tv-exp
	   )
      _ -> (error1 "malformed pvcase" exp)
      ))

  ;; Give <label> the next free index in context.variant-labels, unless it
  ;;   has already been recorded.
  (define (remember-variant-label label)
    (match (alist/lookup context.variant-labels label) with
      (maybe:no)
      -> (alist/push context.variant-labels label (alist/length context.variant-labels))
      (maybe:yes _) -> #u))

  ;; Shared type variables used as the generic parameters of the primitive
  ;;   schemes built in lookup-primapp.  Every scheme there that mentions
  ;;   them also lists them as generic, and type-of-primapp instantiates
  ;;   each scheme before use, so they are only ever used through fresh
  ;;   copies.
  (define T0 (new-tvar))
  (define T1 (new-tvar))
  (define T2 (new-tvar))

  ;; Return a list of <n> fresh type variables.
  (define (n-tvars n)
    (if (= n 0)
        '()
        (list:cons (new-tvar) (n-tvars (- n 1)))))

  ;; Abort: primitive <name> was given parameters of an unexpected shape.
  (define (prim-error name)
    (error1 "bad parameters to primop" name))

  ;; this function acts like a lookup table for the type signatures of primitives.
  ;; it is a function because some of the prims have parameters that affect their
  ;;   type signatures in a way that requires them to be generated on the fly -
  ;;   for example accessing a record field requires a row type containing a label
  ;;   for the field.
  ;;
  ;; Returns a (:scheme gens type) for primitive <name>.  <params> carries the
  ;;   primitive's static parameters: a record label, a (datatype . alt)
  ;;   pair, a (label arity) list, an nvget descriptor, or a C type for
  ;;   %callocate.  The schemes built here are arrow types (result first,
  ;;   then argument types); %dtcon delegates to the datatype's constructor
  ;;   scheme.  The shared tvars T0/T1/T2 appear only as generic variables.
  ;;   Unknown primitives and malformed params are errors.

  (define (lookup-primapp name params)
    (match name with
      '%fatbar	    -> (:scheme (LIST T0) (arrow T0 (LIST T0 T0)))
      '%fail	    -> (:scheme (LIST T0) (arrow T0 '()))
      '%match-error -> (:scheme (LIST T0) (arrow T0 '()))
      '%make-vector -> (:scheme (LIST T0) (arrow (pred 'vector (LIST T0)) (LIST int-type T0)))
      '%array-ref   -> (:scheme (LIST T0) (arrow T0 (LIST (pred 'vector (LIST T0)) int-type)))
      '%array-set   -> (:scheme (LIST T0) (arrow undefined-type (LIST (pred 'vector (LIST T0)) int-type T0)))
      '%rmake       -> (:scheme '() (arrow (rproduct (rdefault (rabs))) '()))
      '%ensure-heap -> (:scheme '() (arrow undefined-type (LIST int-type)))
      ;; record extension: the label's field (in any state T0) becomes pre(T2)
      '%rextend     -> (match params with
			 (sexp:symbol label)
			 -> (let ((plabel (make-label label)))
			      (:scheme (LIST T0 T1 T2)
				       (arrow (rproduct (rlabel plabel (rpre T2) T1))
					      (LIST
					       (rproduct (rlabel plabel T0 T1))
					       T2))))
			 _ -> (prim-error name))
      '%raccess     -> (match params with
			 (sexp:symbol label)
			 -> (:scheme (LIST T0 T1)
				     (arrow T0
					    (LIST (rproduct (rlabel (make-label label)
								    (rpre T0)
								    T1)))))
			 _ -> (prim-error name))
      '%rset        -> (match params with
			 (sexp:symbol label)
			 -> (:scheme (LIST T0 T1)
				     (arrow undefined-type
					    (LIST (rproduct (rlabel (make-label label)
								    (rpre T0)
								    T1))
						  T0)))
			 _ -> (prim-error name))
      '%dtcon       -> (match params with
			 (sexp:cons dtname altname)
			 -> (match (alist/lookup context.datatypes dtname) with
			      (maybe:no) -> (error1 "lookup-primapp: no such datatype" dtname)
			      (maybe:yes dt) ->
			      (dt.get-alt-scheme altname))
			 _ -> (prim-error name))
      ;; polymorphic variant constructor; also records the label's index
      '%vcon        -> (match params with
			 (sexp:list ((sexp:symbol label) (sexp:int arity)))
			 -> (let ((plabel (make-label label)))
			      (remember-variant-label label)
			      (match arity with
				;; ∀X.() → Σ(l:pre (Π());X)
				0 -> (:scheme (LIST T0) (arrow (rsum (rlabel plabel (rpre (pred 'product '())) T0)) '()))
				;; ∀XY.X → Σ(l:pre X;Y)
				1 -> (:scheme (LIST T0 T1) (arrow (rsum (rlabel plabel (rpre T0) T1)) (LIST T0)))
				;; ∀ABCD.Π(A,B,C) → Σ(l:pre (Π(A,B,C));D)
				_ -> (let ((tdflt (new-tvar))
					   (targs (n-tvars arity)))
				       (:scheme (list:cons tdflt targs)
						(arrow (rsum (rlabel plabel (rpre (pred 'product targs)) tdflt))
						       targs)))))
			 _ -> (prim-error name))
      ;; field <index> of a variant value with <arity> fields
      '%nvget       -> (match params with
			 (sexp:list ((sexp:cons dtname altname) (sexp:int index) (sexp:int arity)))
			 -> (if (eq? dtname 'nil)
				;; polymorphic variant
				(let ((argvars (n-tvars arity))
				      (tdflt (new-tvar))
				      (plabel (make-label altname))
				      (vtype (rsum (rlabel plabel
							   (rpre (if (> arity 1)
								     (pred 'product argvars)
								     (car argvars)))
							   tdflt))))
				  (:scheme (list:cons tdflt argvars)
					   ;; e.g., to pick the second arg:
					   ;; ∀0123. Σ(l:pre (0,1,2);3) → 1
					   (arrow (nth argvars index) (LIST vtype))))
				;; normal variant
				(match (alist/lookup context.datatypes dtname) with
				  (maybe:no) -> (error1 "lookup-primapp: no such datatype" dtname)
				  (maybe:yes dt)
				  -> (let ((alt (dt.get altname))
					   (tvars (dt.get-tvars))
					   (dtscheme (pred dtname tvars)))
				       (:scheme tvars (arrow (nth alt.types index) (LIST dtscheme))))))
			 _ -> (prim-error name))
      '%callocate  -> (let ((type (parse-type params)))
			;; int -> (buffer <type>)
			(:scheme '() (arrow (pred 'buffer (LIST type)) (LIST int-type))))
      '%exit       -> (:scheme (LIST T0 T1) (arrow T0 (LIST T1)))
      '%cget       -> (:scheme (LIST T0) (arrow T0 (LIST (pred 'buffer (LIST T0)) int-type)))
      ;; NOTE(review): unlike %array-set, the stored value (T1) is not tied to
      ;;   the buffer's element type (T0) - confirm this is intended.
      '%cset       -> (:scheme (LIST T0 T1) (arrow undefined-type (LIST (pred 'buffer (LIST T0)) int-type T1)))
      ;; these both can be done with %%cexp, but we need to be able to detect their usage in order to
      ;;   disable inlining of functions that use them.
      '%getcc      -> (:scheme (LIST T0) (arrow (pred 'continuation (LIST T0)) (LIST)))
      '%putcc      -> (:scheme (LIST T0 T1) (arrow T1 (LIST (pred 'continuation (LIST T0)) T0)))
      _ -> (error1 "lookup-primapp" name)))

  ;; each exception is stored in a global table along with a tvar
  ;;  that will unify with each use.
  ;; Returns the tvar kept in context.exceptions for <name>, creating and
  ;;  registering a fresh one the first time <name> is seen.
  (define (get-exn-type name)
    (match (alist/lookup context.exceptions name) with
      (maybe:no)
      -> (let ((fresh (new-tvar)))
           (alist/push context.exceptions name fresh)
           fresh)
      (maybe:yes known) -> known))

  ;; given an exception row type for <exp> look up its name in the global
  ;;   table and unify each element of the sum.
  ;; Walks the row label by label.  For each (rlabel <exn-name> <field> <rest>),
  ;;   the field type is unified with the tvar that get-exn-type keeps for
  ;;   that exception name.  The walk stops at the row's tail: either a tvar
  ;;   (open row) or an rdefault (closed row).  Any other tail is an error.
  ;;   <tenv> is currently unused.
  (define (unify-exception-types exp row tenv)
    (let loop ((row row))
      (match row with
	(type:pred 'rlabel ((type:pred exn-name _ _) exn-type rest) _)
	-> (let ((global-type (get-exn-type exn-name)))
	     (print-string (format "global type of " (sym exn-name) " is " (type-repr global-type) "\n"))
	     (print-string (format "  unifying with " (type-repr exn-type) "\n"))
	     (unify exp global-type exn-type)
	     (loop rest))
	;; open row: no more labels
	(type:tvar _ _) -> #u
	;; closed row: no more labels either (previously reported as an error)
	(type:pred 'rdefault _ _) -> #u
	_ -> (error1 "unify-exception-types: bad type" (type-repr row))
	)))

  ;; unify the label from this row with the global table
  ;; (raise val): val must have a sum type.  Each label of its row is unified
  ;;   with that exception's entry in context.exceptions, and the
  ;;   substituted type of val is returned.
  (define (type-of-raise val tenv)
    (let ((exn-type (apply-subst (type-of val tenv))))
      (match exn-type with
        (type:pred 'rsum (row) _)
        -> (begin
             (unify-exception-types val row tenv)
             exn-type)
        _ -> (begin
               (print-string "bad exception type:\n")
               (pp-node val)
               (newline)
               (error1 "bad exception type:" val)))))

  ;; unify each label from this row with the global table
  ;; Types the handler's match expression and the exception value, in that
  ;;   order.  The value must have a sum type, and each of its labels is
  ;;   unified with the global exception table.  Returns the type of the
  ;;   match expression.
  (define (type-of-handle exn-val exn-match tenv)
    (let ((result-type (type-of exn-match tenv))
          (exn-type (apply-subst (type-of exn-val tenv))))
      (print-string (format "type-of-handle: " (type-repr (apply-subst exn-type)) "\n"))
      (match exn-type with
        (type:pred 'rsum (row) _)
        -> (begin
             (unify-exception-types exn-match row tenv)
             result-type)
        _ -> (error1 "unify-handlers: expected row sum type" (type-repr exn-type)))))

  ;; Type a primitive application <name> with static <params> and argument
  ;;   nodes <subs>.  %exn-raise and %exn-handle are special-cased; every
  ;;   other primitive is typed from its signature in lookup-primapp.
  (define (type-of-primapp name params subs tenv)
    ;; Table-driven case (very similar to type-of-cexp): instantiate the
    ;;   scheme, check the arity, unify each argument with its parameter
    ;;   type, and return the result type.
    (define (type-of-table-primapp)
      (match (lookup-primapp name params) with
        (:scheme gens type)
        -> (match (instantiate-type-scheme gens type) with
             (type:pred 'arrow (result-type . arg-types) _)
             -> (begin
                  (if (not (= (length arg-types) (length subs)))
                      (error1 "wrong number of args to primapp" subs))
                  (for-each2
                   (lambda (arg arg-type)
                     (unify arg (type-of arg tenv) arg-type))
                   subs arg-types)
                  result-type)
             _ -> (error1 "type-of-primapp" name))))
    (match name with
      ;; we need a map of exception => tvar, then cross-verify with each handle expression.
      '%exn-raise
      -> (match subs with
           (exn-val) -> (type-of-raise exn-val tenv)
           _ -> (error1 "%exn-raise: bad arity" subs))
      ;; %exn-handle is wrapped around the match expression of the exception handler
      '%exn-handle
      -> (match subs with
           (exn-val exn-match) -> (type-of-handle exn-val exn-match tenv)
           _ -> (error1 "%exn-handle: bad arity" subs))
      ;; normal primapps
      _ -> (type-of-table-primapp)))

  ;; Replace every node's recorded type with its fully substituted form,
  ;;   recursing through n.subs.
  (define (apply-subst-to-program n)
    (set! n.type (apply-subst n.type))
    (for-each apply-subst-to-program n.subs))

  ;; body of type-program: type the whole program in an empty environment,
  ;;   then resolve the types recorded on all nodes.  The returned type is
  ;;   the raw result of type-of (it is not itself passed through apply-subst).
  (let ((t (type-of node (alist/make))))
    (apply-subst-to-program node)
    t))

;; (define (test-typing)
;;   (let ((context (make-context))
;; 	(transform (transformer context))
;; 	;;(exp0 (sexp:list (read-string "(%%cexp (int int -> int) \"%0+%1\" 3 #\\a)")))
;; 	;;(exp0 (sexp:list (read-string "(begin #\\A (if #t 3 4))")))
;; 	(exp0 (sexp:list (read-string "((lambda (a b) (%%cexp (int int -> int) \"%0+%1\" a b)) 3 4)")))
;; 	(exp1 (transform exp0))
;; 	(node0 (walk exp1))
;; 	(graph0 (build-dependency-graph node0))
;; 	(ignore (print-graph graph0))
;; 	(strong (strongly graph0))
;; 	(_ (set! context.scc-graph strong))
;; 	(type0 (type-program node0 context))
;; 	)
;;     (pp-node node0)
;;     (newline)
;;     ))

;; uncomment to test
;(include "self/nodes.scm")
;(test-typing)