forwarded message from Henry Cejtin

Suresh Jagannathan suresh@research.nj.nec.com
Wed, 11 Oct 2000 12:08:00 -0400


Here's the latest version of the uncurrier.  It worked
on an old version of MLton (7-12-1999), but
I don't believe any changes to SXML have occurred, so
it might work on the latest release.  I never got around
to testing the uncurrier on the bigger programs because
it appeared to offer only marginal improvements on the
smaller ones, but your observation that our smaller
benchmarks don't use curried functions often might
account for this.

     -- Suresh

===============================
uncurry.sig

(* Copyright (C) 1997-1999 NEC Research Institute.
 * Please see the file LICENSE for license information.
 *)
(* Parameter signature of the UnCurry functor: the SXML structure whose
 * programs the pass consumes and produces.
 *)
signature UNCURRY_STRUCTS =
sig
   structure Sxml: SXML
end

(* Result signature of the UnCurry functor. *)
signature UNCURRY =
sig
   include UNCURRY_STRUCTS

   (* Program-to-program rewrite over Sxml programs. *)
   val uncurry: Sxml.Program.t -> Sxml.Program.t
end



================================

functor UnCurry(S: UNCURRY_STRUCTS): UNCURRY = 
struct

open S
open Sxml
open Dec PrimExp

fun uncurry(program as Program.T{datatypes, body}) =
   let
       (* A function binding: the bound variable together with its lambda. *)
       datatype D = T of {var: Var.t, lambda : Lambda.t}

       (* Per-variable arity, written by uncurryFun; a variable that was never
	* set reads as 0.
	*)
       val {get = getArity: Var.t -> int, set = setArity} =
	       Property.new(Var.plist, Property.initConst 0)

       (* Per-variable pair of bindings recorded by uncurryFun for functions of
	* arity > 1: the tupled (uncurried) version and a curried replacement
	* bound to the original name.  NONE for every other variable.
	*)
       val {get = curriedRep: Var.t -> {unCurriedFun: D, curriedFun: D} option,
	    set = setCurriedRep} = Property.new(Var.plist, Property.initConst NONE)

       (* Per-variable argument types and result type, written by uncurryFun.
	* The argument types are stored in the order uncurryFun collects them
	* (innermost lambda first); uncurryApp reverses them before building the
	* tuple type.  Unset variables read as {args = [unit], result = unit}.
	*)
       val {get = getType: Var.t -> {args: Type.t list, result: Type.t}, set = setType} =
	       Property.new(Var.plist, Property.initConst {args = [Type.unit],
							   result = Type.unit})

       (* getResultType exp
	*
	* Look through the top-level declarations of exp for the MonoVal or Fun
	* binding whose variable is the result variable of exp, and return the
	* type recorded on that binding.
	*
	* Returns Type.unit when none of those declarations binds the result
	* variable (for example when the result is bound outside exp).
	*)
       fun getResultType(exp) =
	  let
	     val {decs, result} = Exp.dest(exp)
	     val resultVar = VarExp.var(result)
	     (* Keep the type found so far unless this binding is the result. *)
	     fun pick(var, ty, found) =
		if Var.equals(var, resultVar) then ty else found
	  in
	     List.fold
	     (decs,
	      Type.unit,
	      fn (d, found) =>
	      (case d of
		  MonoVal {var, ty, exp} => pick(var, ty, found)
		| Fun {tyvars, decs} =>
		     (* Seed the inner fold with the running answer.  Starting
		      * over from Type.unit here would throw away a type that
		      * an earlier declaration had already supplied whenever
		      * this group does not bind the result variable.
		      *)
		     List.fold(decs,
			       found,
			       fn ({var, ty, lambda}, found) =>
			       pick(var, ty, found))
		| _ => found))
	  end

       (* buildLambda(f, args, types, resultType)
	*
	* Build a chain of one-argument lambdas, one per element of args, whose
	* innermost body packs all the arguments into a tuple and applies f to
	* it.  args and types are expected innermost-first: the last element
	* becomes the outermost lambda's parameter and the head becomes the
	* innermost one.  resultType is the type given to the final application.
	* Needs at least two arguments, since the middle section is computed
	* with List.tail(List.allButLast ...).
	*)
       fun buildLambda(f,args,types,resultType) =
	  let
	     (* Fresh names are created in the same order as before: tuple
	      * holder, call result, innermost closure, then one per wrapper.
	      *)
	     val tupleVar = Var.newString("c")
	     val callVar = Var.newString("c")
	     val outermostArg = List.head(List.reverse(args))
	     val outermostType = List.head(List.reverse(types))
	     val middleArgs = List.tail(List.allButLast(args))
	     val middleTypes = List.tail(List.allButLast(types))
	     val innermostVar = Var.newString("c")
	     val innermostType = List.head(types)

	     (* let tupleVar = (a1, ..., an); callVar = f tupleVar in callVar *)
	     val callBody =
		Exp.new
		{decs = [MonoVal {var = tupleVar,
				  ty = Type.tuple(List.reverse(types)),
				  exp = Tuple(List.map(List.reverse(args),
						       fn a => VarExp.mono(a)))},
			 MonoVal {var = callVar,
				  ty = resultType,
				  exp = App {func = f,
					     arg = VarExp.mono(tupleVar)}}],
		 result = VarExp.mono(callVar)}

	     (* Expression that binds and returns the innermost closure. *)
	     val innermost =
		Exp.new
		{decs = [MonoVal {var = innermostVar,
				  ty = Type.arrow(innermostType, resultType),
				  exp = Lambda (Lambda.new
						{arg = List.head(args),
						 argType = innermostType,
						 body = callBody})}],
		 result = VarExp.mono(innermostVar)}

	     (* Wrap inner in one more closure taking a : atype. *)
	     fun wrap(a, atype, inner) =
		let
		   val closureVar = Var.newString("c")
		in
		   Exp.new
		   {decs = [MonoVal {var = closureVar,
				     ty = Type.arrow(atype, getResultType(inner)),
				     exp = Lambda (Lambda.new {arg = a,
							       argType = atype,
							       body = inner})}],
		    result = VarExp.mono(closureVar)}
		end
	  in
	     Lambda.new {arg = outermostArg,
			 argType = outermostType,
			 body = List.fold2(middleArgs, middleTypes, innermost, wrap)}
	  end

       (* uncurryFun (T{var, lambda})
	*
	* Analyse one function binding.  Starting from lambda, lamExp peels off
	* directly nested lambdas (a lambda whose body, after only Const / Var /
	* Select bindings, ends by returning another lambda) and collects their
	* parameters and parameter types, innermost first.
	*
	* Side effects only (returns unit):
	*   - records the number of collected parameters with setArity;
	*   - records the parameter types and the flattened body's result type
	*     with setType;
	*   - when the arity exceeds 1, records via setCurriedRep a tupled
	*     version of the function plus a curried wrapper bound to var.
	*)
       fun uncurryFun(dec) =
	  let
	     (* lamExp(decs, result, args, types, newDecs, e)
	      *   decs     declarations still to scan at the current level
	      *   result   result VarExp of the current level
	      *   args     parameters collected so far (innermost first)
	      *   types    their types, same order
	      *   newDecs  simple declarations skipped at this level, most
	      *            recent first
	      *   e        flattened body built so far; returned when scanning
	      *            stops
	      *)
	     fun lamExp(decs,result,args,types,newDecs,e) =
		case decs of
		   [] => (args,types,e)
		 | d::rest =>
		      (case d of
			  Dec.MonoVal{var, ty, exp = Const c} =>
			     lamExp(rest, result, args,types,d::newDecs,e)
			| Dec.MonoVal{var, ty, exp = Var v} =>
			     lamExp(rest, result, args,types,d::newDecs,e)
			| Dec.MonoVal{var, ty, exp = Select tuple} =>
			     lamExp(rest, result, args,types,d::newDecs,e)
			| Dec.MonoVal{var, ty, exp = Lambda l} =>
			     let
				val body = Lambda.body(l)
				val r = result
				val {decs,result} = Exp.dest(body)
				(* newDecs was built by consing, so it is in
				 * reverse source order.  Restore the original
				 * order before hoisting these bindings in front
				 * of the inner body; otherwise a binding that
				 * refers to an earlier one would end up ahead
				 * of its definition.
				 *)
				val newDecs =
				   List.append(List.reverse(newDecs),decs)
				val new = Exp.new{decs = newDecs,result = result}
			     in
				(* Only merge when this lambda is the last
				 * declaration and is exactly what the enclosing
				 * level returns.
				 *)
				if Var.equals(var, VarExp.var(r))
				   andalso List.isEmpty(rest)
				   then lamExp(newDecs,
					       result,
					       Lambda.arg(l)::args,
					       Lambda.argType(l)::types,
					       [],
					       new)
				else (args,types,e)
			     end
			| _ => (args,types,e))

	     val T{var,lambda} = dec
	     val (f, r) =
		let
		   val arg = Lambda.arg(lambda)
		   val argType = Lambda.argType(lambda)
		   val body = Lambda.body(lambda)
		   val {decs,result} = Exp.dest(body)
		in
		   (var, lamExp(decs, result, [arg], [argType], [],body))
		end

	     (* Record the tupled function and the curried wrapper for f.
	      * args/types are innermost first; e is the flattened body.
	      *)
	     fun buildCurried (f,args,types,e) =
		let
		   val newVar = Var.newString("c")
		   val newArg = Var.newString("c")
		   (* One Select per original parameter, rebinding it from the
		    * tuple argument; offsets count up from 0 starting with the
		    * outermost parameter.
		    *)
		   val (newDecs,n) =
		      List.fold2(List.reverse(args),
				 List.reverse(types),
				 ([],0),
				 fn (a, mtype, (l, i)) =>
				 (MonoVal { var = a,
					    ty = mtype,
					    exp =
					    PrimExp.Select {tuple =
							    VarExp.mono(newArg),
							    offset = i }}::l,
				  i+1))
		   val newExp = Exp.new {decs = List.append(newDecs, Exp.decs(e)),
					 result = Exp.result(e)}
		   val resultType = getResultType(newExp)
		   val unCurriedFun =
		      T{var = newVar,
			lambda = Lambda.new { arg = newArg,
					      argType = Type.tuple(List.reverse(types)),
					      body = newExp }}
		   (* The wrapper gets fresh parameter names of its own. *)
		   val newArgs = List.map(args, fn z => Var.newString("c"))
		   val newFun =
		      buildLambda(VarExp.mono(newVar),newArgs,types,resultType)
		   val newFunBinding = T{ var = f,  lambda = newFun }
		in
		   setCurriedRep(f, SOME {unCurriedFun = unCurriedFun,
					  curriedFun =  newFunBinding})
		end
	  in
	     case r of
		(args,types,e) =>
		   (setArity(f, length(args));
		    setType(f, {args = types, result = getResultType(e)});
		    if getArity(f) > 1
		       then buildCurried(f,args,types,e)
		    else ())
	  end

       (* replaceVar(decs, old, new)
	*
	* Return decs with every VarExp operand that refers to the variable old
	* replaced by the VarExp new.  The rewrite descends into lambda bodies,
	* case branches, handlers and Fun groups, and also rewrites the result
	* variable of every nested expression.  Binding occurrences are left
	* untouched, as are declaration and expression forms not listed below.
	*)
       fun replaceVar(decs,old,new) =
	  let
	     (* Substitute a single operand. *)
	     fun compare(v1) =
		if Var.equals(VarExp.var(v1),old)
		   then new
		else v1

	     (* Substitute inside a nested expression, result included. *)
	     fun replaceExp(e) =
		let
		   val {decs,result} = Exp.dest(e)
		   val newDecs = replaceVar(decs,old,new)
		   val newResult = compare(result)
		in
		   Exp.new {decs = newDecs,
			    result = newResult}
		end

	     (* Substitute inside a lambda body.  Used for both MonoVal-bound
	      * and Fun-bound lambdas so that the body's result variable is
	      * rewritten in either case; previously the MonoVal path rewrote
	      * the body's declarations but kept its result unchanged.
	      *)
	     fun replaceLambda(l) =
		let
		   val {arg,argType,body} = Lambda.dest(l)
		in
		   Lambda.new {arg=arg,
			       argType=argType,
			       body = replaceExp(body)}
		end

	     fun replacePrimExp(exp) =
		case exp of
		   Var v => PrimExp.Var(compare(v))
		 | Tuple vs => Tuple(List.map(vs, fn v => compare(v)))
		 | Select {tuple,offset} =>
		      Select {tuple=compare(tuple),
			      offset=offset}
		 | Lambda l => Lambda (replaceLambda(l))
		 | ConApp {con,targs,arg} =>
		      (case arg of
			  NONE => exp
			| SOME v => ConApp {con=con,
					    targs=targs,
					    arg = SOME (compare(v))})
		 | PrimApp {prim,targs,args} =>
		      PrimApp {prim=prim,
			       targs=targs,
			       args = List.map(args, fn a => compare(a))}
		 | App {func,arg} =>
		      App {func = compare(func),
			   arg = compare(arg)}
		 | Raise v => Raise (compare(v))
		 | Case {test,cases,default} =>
		      Case {test=compare(test),
			    cases = List.map(cases,
					     fn (p,e) => (p,replaceExp(e))),
			    default = (case default of
					  NONE => NONE
					| SOME e => SOME (replaceExp(e)))}
		 | Handle {try,catch,handler} =>
		      Handle {try=replaceExp(try),
			      catch = catch,
			      handler = replaceExp(handler)}
		 | _ => exp
	  in
	     List.map
	     (decs,
	      fn d =>
	      (case d of
		  MonoVal {var, ty, exp} =>
		     MonoVal {var=var,
			      ty = ty,
			      exp = replacePrimExp(exp)}
		| Fun {tyvars,decs} =>
		     Fun {tyvars=tyvars,
			  decs = List.map(decs,
					  fn {var,ty,lambda} =>
					  {var=var,
					   ty=ty,
					   lambda = replaceLambda(lambda)})}
		| _ => d))
	  end

       (* uncurryApp(decs, expResult)
	*
	* decs must start with an application of a function that has a
	* curriedRep; anything else is reported through Error.error.
	* expResult is the result VarExp of the enclosing expression.
	*
	* If the leading application is followed by enough further
	* applications, each applying the previous one's result, to supply all
	* of the function's recorded arguments, the whole chain is replaced by
	* a tuple construction plus one call of the tupled function.
	*
	* Returns (newDecs, rest, newResult):
	*   newDecs    replacement declarations, LAST declaration FIRST;
	*   rest       the declarations after the consumed prefix, with uses of
	*              the chain's final variable redirected to the new call;
	*   newResult  SOME v when the chain's final variable was expResult and
	*              v must now be returned instead, otherwise NONE.
	* When the chain is too short the leading declaration is returned
	* unchanged as ([d], tail, NONE).
	*)
       fun uncurryApp(decs,expResult) =
	  let
	     (* Build "tupleVar = (arguments); callVar = tupled tupleVar".
	      * arguments arrive last-applied first.
	      *)
	     fun makeUncurryApp(f,arguments,lastCall) =
		let
		   val tupleVar = Var.newString("c")
		   val callVar = Var.newString("c")
		   val fVar = VarExp.var(f)
		   val {args = argTypes, result = resultTy} = getType(fVar)
		   val tupled =
		      case curriedRep(fVar) of
			 NONE => Error.error "in uncurryApp"
		       | SOME {unCurriedFun = T{var,lambda}, curriedFun} => var
		   val packDec =
		      MonoVal{var = tupleVar,
			      ty = Type.tuple(List.reverse(argTypes)),
			      exp = Tuple(List.reverse(arguments))}
		   val callDec =
		      MonoVal{var = callVar,
			      ty = resultTy,
			      exp = App {func = VarExp.mono(tupled),
					 arg = VarExp.mono(tupleVar)}}
		   val replacement =
		      if Var.equals(lastCall, VarExp.var(expResult))
			 then SOME callVar
		      else NONE
		in
		   ([callDec, packDec], replacement, callVar)
		end

	     (* Consume `remaining` more applications from ds, each of which
	      * must apply the variable bound by the previous step.  Yields
	      * the gathered arguments, the unconsumed declarations and the
	      * variable bound by the final application.
	      *)
	     fun gather(gathered,remaining,ds,prev) =
		if remaining = 0
		   then SOME (gathered,ds,prev)
		else
		   case ds of
		      MonoVal {var,ty,exp = App {func,arg}}::tl =>
			 if Var.equals(VarExp.var(func),prev)
			    then gather(arg::gathered, remaining-1, tl, var)
			 else NONE
		    | _ => NONE
	  in
	     case decs of
		[] => Error.error("in uncurryApp")
	      | (d as MonoVal {var, ty, exp = App {func,arg}})::r =>
		   (case curriedRep(VarExp.var(func)) of
		       NONE => Error.error("in uncurryApp")
		     | SOME _ =>
			  let
			     val arity = getArity(VarExp.var(func))
			  in
			     case gather([arg],arity-1,r,var) of
				NONE => ([d],r,NONE)
			      | SOME (callArgs,remainder,lastCall) =>
				   let
				      val (newDecs,newR,newArg) =
					 makeUncurryApp(func,callArgs,lastCall)
				      val remainder =
					 replaceVar(remainder,lastCall,
						    VarExp.mono(newArg))
				   in
				      (newDecs,remainder,newR)
				   end
			  end)
	      | _ => Error.error("in uncurryApp")
	  end

       (* singleUse(var, decs)
	*
	* Folds over decs asking of each declaration whether it is a MonoVal
	* whose expression applies var as the function.  The fold discards its
	* accumulator, so the answer is whatever the final declaration visited
	* by List.fold yields; an empty list gives false.
	*
	* NOTE(review): the name suggests a "used exactly once" check, which
	* this is not -- confirm the intended meaning before relying on it.
	*)
       fun singleUse(var,decs) =
	  let
	     val target = var
	     fun appliesTarget(d) =
		case d of
		   MonoVal {var = _, ty = _, exp = App {func, arg = _}} =>
		      Var.equals(VarExp.var(func), target)
		 | _ => false
	  in
	     List.fold(decs, false, fn (d, _) => appliesTarget(d))
	  end
				       
		     
       (* transform body
	*
	* Rewrite one expression using the information recorded by uncurryFun:
	*   - a lambda binding with a curriedRep is replaced by two bindings,
	*     the tupled function followed by the curried wrapper;
	*   - an application of such a function is handed to uncurryApp when
	*     singleUse accepts it;
	*   - lambda bodies, case branches, handle expressions and Fun groups
	*     are transformed recursively;
	*   - every other declaration is kept as is.
	* If a rewritten application chain used to define the expression's
	* result variable, the result is redirected to the replacement call.
	*)
       fun transform(body) =
	  let
	     val {decs,result} = Exp.dest(body)
	     (* Replacement for the result variable, if a rewritten chain
	      * defined it.
	      *)
	     val newR = ref NONE

	     (* Walk the declarations; newDecs accumulates the output in
	      * reverse order.
	      *)
	     fun loop(decs,newDecs) =
		case decs of
		   [] => newDecs
		 | d::rest =>
		      (case d of
			  MonoVal {var,ty, exp = Lambda l} =>
			     (case curriedRep(var) of
				 NONE =>
				    let
				       val lamBody = Lambda.body(l)
				       val arg = Lambda.arg(l)
				       val argType = Lambda.argType(l)
				       val newLam =
					  Lambda.new{arg=arg,
						     argType = argType,
						     body = transform(lamBody)}
				       val newDec = MonoVal{var=var,
							    ty=ty,
							    exp = Lambda newLam}
				    in
				       loop(rest,newDec::newDecs)
				    end
			       | SOME {unCurriedFun,curriedFun} =>
				    let
				       (* b1: the tupled function, with its
					* body transformed.
					*)
				       val T{var,lambda} = unCurriedFun
				       val body = Lambda.body(lambda)
				       val newBody = transform(body)
				       val resultType = getResultType(newBody)
				       val argType = Lambda.argType(lambda)
				       val l = Lambda(Lambda.new
						      {arg = Lambda.arg(lambda),
						       argType = argType,
						       body = newBody})
				       val b1 = MonoVal{var=var,
							ty = Type.arrow(argType,resultType),
							exp = l}
				       (* b2: the curried wrapper, bound to the
					* original name.
					*)
				       val T{var,lambda} = curriedFun
				       val argType = Lambda.argType(lambda)
				       val resultType =
					  getResultType(Lambda.body(lambda))
				       val b2 = MonoVal{var=var,
							ty = Type.arrow(argType, resultType),
							exp = Lambda lambda}
				    in
				       (* Reverse accumulation: b1 ends up
					* before b2 in the output.
					*)
				       loop(rest,b2::(b1::newDecs))
				    end)
			| MonoVal {var,ty,exp = App {func,arg}} =>
			     (case curriedRep(VarExp.var(func)) of
				 NONE => loop(rest,d::newDecs)
			       | SOME _ =>
				    if singleUse(var,rest)
				       then
					  let
					     val (appDecs,r,newResult) =
						uncurryApp(decs,result)
					     (* Remember a redirected result,
					      * but never let a later chain
					      * that did not touch the result
					      * erase an earlier redirection.
					      *)
					     val () =
						(case newResult of
						    NONE => ()
						  | SOME _ => newR := newResult)
					  in
					     loop(r,List.append(appDecs,newDecs))
					  end
				    else loop(rest,d::newDecs))
			| MonoVal {var,ty,exp = Case {test,cases,default}} =>
			     let
				val newCases =
				   List.map(cases,
					    fn (pat,e) => (pat, transform(e)))
				val default =
				   (case default of
				       NONE => NONE
				     | SOME e => SOME (transform(e)))
			     in
				loop(rest,
				     (MonoVal{var=var,
					      ty=ty,
					      exp = Case {test=test,
							  cases=newCases,
							  default=default}}::
				      newDecs))
			     end
			| MonoVal {var,ty, exp = Handle {try,catch,handler}} =>
			     loop(rest,
				  (MonoVal{var=var,
					   ty=ty,
					   exp = Handle {try = transform(try),
							 catch = catch,
							 handler = transform(handler)}}::
				   newDecs))
			| Fun {tyvars,decs} =>
			     loop(rest,
				  (* Keep the group's own tyvars rather than
				   * resetting them to the empty list.
				   *)
				  Fun {tyvars = tyvars,
				       decs =
				       List.fold
				       (decs,
					[]:{var:Var.t,
					    ty:Type.t,
					    lambda:Lambda.t} list,
					fn (d as {var,
						  ty,
						  lambda:Lambda.t},
					    acc) =>
					(case curriedRep(var) of
					    NONE =>
					       let
						  val body = Lambda.body(lambda)
						  val arg = Lambda.arg(lambda)
						  val argType = Lambda.argType(lambda)
						  val newBody = transform(body)
						  val newLam =
						     Lambda.new{arg = arg,
								argType = argType,
								body = newBody}
					       in
						  {var=var,
						   ty=ty,
						   lambda=newLam}::acc
					       end
					  | SOME {unCurriedFun,curriedFun} =>
					       let
						  val T{var,lambda} = unCurriedFun
						  val body = Lambda.body(lambda)
						  val newBody = transform(body)
						  val argType = Lambda.argType(lambda)
						  val resultType = getResultType(newBody)
						  val b1 =
						     {var=var,
						      ty = Type.arrow(argType,resultType),
						      lambda =
						      Lambda.new{arg = Lambda.arg(lambda),
								 argType = argType,
								 body = newBody}}
						  val T{var,lambda} = curriedFun
						  val argType = Lambda.argType(lambda)
						  val newBody =
						     transform(Lambda.body(lambda))
						  val resultType = getResultType(newBody)
						  val b2 =
						     {var=var,
						      ty = Type.arrow(argType,resultType),
						      lambda = lambda}
					       in
						  b1::(b2::acc)
					       end))}::newDecs)
			| _ => loop(rest,d::newDecs))

	     (* Run the walk before looking at newR: loop is what fills it. *)
	     val newDecs = List.reverse(loop(decs,[]))
	  in
	     Exp.new
	     {decs = newDecs,
	      result = (case !newR of
			   NONE => result
			 | SOME r => VarExp.mono(r))}
	  end
   in
      let
	 (* Analysis: run uncurryFun on every lambda bound by a MonoVal or a
	  * Fun group, so the per-variable properties are filled in before
	  * any rewriting starts.
	  *)
	 fun analyse(d) =
	    case d of
	       MonoVal {var,ty,exp = Lambda l} =>
		  uncurryFun(T{var=var,lambda=l})
	     | Fun {tyvars,decs} =>
		  List.foreach(decs,
			       fn {var,ty,lambda} =>
			       uncurryFun(T{var=var,lambda=lambda}))
	     | _ => ()
	 val _ =
	    Exp.foreachExp(body,
			   fn e => List.foreach(Exp.decs(e), analyse))
      in
	 (* Rewriting: rebuild the program around the transformed body. *)
	 Program.T{datatypes = datatypes, body = transform(body)}
      end
   end
end