possible bug

Stephen Weeks sweeks@research.nj.nec.com
Tue, 16 Feb 1999 03:12:09 -0500


> In running count-graphs, I came across the following
> error.  At size threshold 50, the compile runs
> smoothly.  Run immediately again (with no intervening
> computation), an error is raised in the contifier
> claiming a duplicate function label.  Type checking
> the inlined output succeeds.  
> 
> The reason why I believe this is a MLton bug (as opposed 
> to an inliner problem) is the error manifests only 
> after the first compile leading me to suspect some 
> global property not being cleared.

There is indeed a bug in the Cps shrinker, which fails on code like
the following:

  let fun f x = ... f ...
  in ... <f not called > ...
  end

The shrinker gets confused and tries to inline f inside itself,
because it sees only one occurrence of f (the recursive call in f's own
body).  The body of f is then duplicated, leading to the duplicate
function label reported above.  The fix was to change the shrinker so
that it keeps track of the number of occurrences of f within itself.
Then the condition for deleting the definition of f is that the number
of occurrences of f is the same as the number of occurrences of f
within itself.

I would still like to hear an explanation of why the cps inliner
behaves differently on multiple runs.

Here is the new version of cps/shrink.fun.

--------------------------------------------------------------------------------

functor Shrink(S: SHRINK_STRUCTS): SHRINK = 
struct

open S
open Dec PrimExp Transfer

val traceSimplifyExp = Trace.trace("simplifyExp", Exp.layout, Exp.layout)
val traceSimplifyTransfer =
   Trace.trace("simplifyTransfer", Transfer.layout, Exp.layout)

(* shrinkExp globals
 *
 * Records what is known about the bindings of the global variables, then
 * returns a function that performs one shrinking pass over an expression.
 * The returned function takes (exp, mayDelete) and returns the shrunk
 * expression.  When mayDelete is false no binding is ever dropped for
 * being unused (see isUseless); the other simplifications still happen.
 *)
fun shrinkExp globals =
   let
      (* Known bindings of global variables.  Only Tuple and ConApp bindings
       * are recorded: those are the only shapes the Select, ConSelect and
       * Case simplifications below know how to use, and those
       * simplifications raise Error.bug on any other binding.  Recording,
       * e.g., a PrimApp or ConSelect binding here used to make a Case or
       * Select on such a global crash the shrinker.  A global bound to
       * another variable inherits that variable's entry, which is only
       * present if that variable appears earlier in the globals list.
       * NOTE(review): nothing in this file clears this property from the
       * globals' property lists; confirm whether a caller does.
       *)
      val {get = globalBind: Var.t -> PrimExp.t option, set} =
         Property.new(Var.plist, Property.initConst NONE)
   in List.foreach(globals, fn {var, exp, ty} =>
                   set(var,
                       case exp of
                          Var y => globalBind y
                        | Tuple _ => SOME exp
                        | ConApp _ => SOME exp
                        | _ => NONE))
      ;
      fn (exp: Exp.t, mayDelete: bool) =>
      let
         (* ---------------------------------- *)
         (*             variables              *)
         (* ---------------------------------- *)

         (* Info for every variable bound inside exp; variables free in exp
          * (e.g. globals) have NONE.
          *   numOccurrences  current number of uses of the variable
          *   replacement     the variable that uses are rewritten to
          *   bind            the known binding of the variable, if any
          *)
         val {get = varInfo: Var.t -> {
                                       numOccurrences: int ref,
                                       replacement: Var.t ref,
                                       bind: PrimExp.t option ref
                                       } option,
              set = setVarInfo} =
            Property.new(Var.plist, Property.initConst NONE)

         (* The variable that a use of x should now refer to. *)
         fun simplifyVar x =
            case varInfo x of
               NONE => x
             | SOME{replacement, ...} => !replacement

         (*       val simplifyVar =
          *        Trace.trace("simplifyVar", Var.layout, Var.layout) simplifyVar
          *)

         fun simplifyVars xs = List.map(xs, simplifyVar)

         (* The known binding of x: local info if x is bound in exp,
          * otherwise whatever was recorded for it as a global.
          *)
         fun bind(x: Var.t): PrimExp.t option =
            case varInfo x of
               NONE => globalBind x
             | SOME{bind, ...} => !bind

         (*       val bind =
          *        Trace.trace("bind", Var.layout, Option.layout PrimExp.layout) bind
          *)

         (* Adds n to the occurrence count of x (no-op for free variables).
          * A negative result means the counts are inconsistent, which is
          * reported and treated as a compiler bug.
          *)
         fun incNumOccurrences(x: Var.t, n: int): unit =
            case varInfo x of
               NONE => ()
             | SOME{numOccurrences, ...} =>
                  let val new = n + !numOccurrences
                  in if new < 0
                        then (Control.message
                              let open Layout
                              in tuple[Var.layout x, Int.layout new]
                              end
                              ; Error.bug "incNumOccurrences")
                     else ()
                     ; numOccurrences := new
                  end

         (*       val incNumOccurrences =
          *        Trace.trace2("incNumOccurrences", Var.layout, Int.layout, Unit.layout)
          *        incNumOccurrences
          *)

         (* Removes one use of x. *)
         fun deleteVar x = incNumOccurrences(x, ~1)

         (*       val deleteVar = Trace.trace("deleteVar", Var.layout, Unit.layout) deleteVar
          *)
         fun deleteVars xs = List.foreach(xs, deleteVar)

         local
            (* Applies f to the info of x; x must be bound inside exp. *)
            fun doit(x: Var.t, f) =
               case varInfo x of
                  NONE => Error.bug "attempt to use varInfo of free variable"
                | SOME r => f r
         in
            (* Makes every use of x a use of y: x's binding becomes y's and
             * x's uses are transferred onto y's count.
             *)
            fun setReplacement(x: Var.t, y: Var.t): unit =
               doit(x, fn {replacement, bind = b, numOccurrences} =>
                    (replacement := y
                     ; b := bind y
                     ; incNumOccurrences(y, !numOccurrences)))

            fun setBind(x: Var.t, e: PrimExp.t): unit =
               doit(x, fn {bind, ...} => bind := SOME e)

            (* True when deletion is enabled and x currently has no uses. *)
            fun isUseless(x: Var.t): bool =
               doit(x, fn {numOccurrences, ...} =>
                    mayDelete andalso !numOccurrences = 0)
         end

         (*       val isUseless = Trace.trace("isUseless", Var.layout, Bool.layout) isUseless
          *)

         (* ---------------------------------- *)
         (*               jumps                *)
         (* ---------------------------------- *)

         (* What is known about a local function (jump).
          *   Unknown  no info (not a function declared inside exp)
          *   Eta j    the function just forwards its arguments to j
          *   Useful   anything else; numOccurrences counts all current uses,
          *            numBodyOccurrences the uses inside the function's own
          *            body (as counted by the initial walk), inlined records
          *            whether the body has been inlined at a use, and isTail
          *            whether the body just returns the parameters.
          *)
         datatype jumpInfo =
            Unknown
           | Eta of Jump.t
           | Useful of  {
                         numOccurrences: int ref,
                         numBodyOccurrences: int ref,
                         inlined: bool ref,
                         params: Var.t list,
                         body: Exp.t,
                         isTail: bool
                         }

         val {get = jumpInfo: Jump.t -> jumpInfo, set = setJumpInfo} =
            Property.new(Jump.plist, Property.initConst Unknown)

         (* Follows Eta links to the jump that a use of j really refers to. *)
         fun replaceJump j =
            case jumpInfo j of
               Eta j' => replaceJump j'
             | _ => j

         (* Removes one use of j (from the jump that j forwards to, if Eta). *)
         fun deleteJump j =
            case jumpInfo j of
               Eta j' => deleteJump j'
             | Unknown => ()
             | Useful{numOccurrences, ...} => IntRef.dec numOccurrences

         (* ---------------------------------- *)
         (*          occurence counts          *)
         (* ---------------------------------- *)

         (* Walks exp, adding delta to the count of every variable use (via
          * its current replacement) and every jump use.  bind is called on
          * every variable bound by a Bind or as a Fun argument.  For each
          * Fun, walkFun is applied before its body is walked and the
          * resulting thunk is called after the body has been walked.
          *)
         fun walkExp{
                     exp: Exp.t,
                     delta: int,
                     bind: Var.t -> unit,
                     walkFun: {name: Jump.t,
                               args: (Var.t * Type.t) list,
                               body: Exp.t} -> unit -> unit
                     } =
            let
               fun var x =
                  case varInfo x of
                     NONE => ()
                   | SOME{replacement, ...} => incNumOccurrences(!replacement, delta)
               fun vars xs = List.foreach(xs, var)
               fun jump j =
                  case jumpInfo j of
                     Unknown => ()
                   | Eta j' => jump j'
                   | Useful{numOccurrences, ...} =>
                        numOccurrences := delta + !numOccurrences
               val loopPrimExp =
                  fn Const _ => ()
                   | Var x => var x
                   | Tuple xs => vars xs
                   | Select{tuple, ...} => var tuple
                   | ConSelect{variant, ...} => var variant
                   | ConApp{args, ...} => vars args
                   | PrimApp{args, ...} => vars args
               val loopTransfer =
                  fn Call{args, cont, ...} => (vars args; Option.map' jump cont)
                   | Jump{dst, args, ...} => (jump dst; vars args)
                   | Return xs => vars xs
                   | Case{test, cases, default} =>
                        (var test
                         ; (case default of
                               NONE => ()
                             | SOME j => jump j)
                         ; List.foreach(cases, jump o #2))
                   | Raise xs => vars xs
                   | Halt => ()
                   | Bug => ()
               fun loopDec d =
                  case d of
                     Bind{var, exp, ...} => (bind var; loopPrimExp exp)
                   | Fun(f as {args, body, ...}) =>
                        let val walkFun = walkFun f
                        in List.foreach(args, bind o #1)
                           ; loopExp body
                           ; walkFun()
                        end
                   | HandlerPush h => jump h
                   | HandlerPop => ()
               and loopExp e =
                  let val {decs, transfer} = Exp.dest e
                  in List.foreach(decs, loopDec)
                     ; loopTransfer transfer
                  end
            in loopExp exp
            end

         (* Compute occurrence counts.  Every bound variable gets fresh info
          * and every local function gets a jumpInfo.  For a Useful function
          * the count starts at zero when its Fun is reached, so the snapshot
          * taken after walking its body (numBodyOccurrences) is the number of
          * uses of the function within its own body.
          *)
         val _ =
            let
               fun bind x = setVarInfo(x, SOME{replacement = ref x,
                                               numOccurrences = ref 0,
                                               bind = ref NONE})

               fun walkFun{name, args, body} =
                  let
                     fun set info = setJumpInfo(name, info)
                     fun normal() =
                        let val params = List.map(args, fn (x, _) => x)
                           val numOccurrences = ref 0
                           val numBodyOccurrences = ref 0
                        in set(Useful
                               {numOccurrences = numOccurrences,
                                numBodyOccurrences = numBodyOccurrences,
                                inlined = ref false,
                                params = params,
                                body = body,
                                isTail = (case Exp.dest body of
                                             {decs = [], transfer = Return xs} =>
                                                List.equals(params, xs, Var.equals)
                                           | _ => false)})
                           ; fn () => numBodyOccurrences := !numOccurrences
                        end
                  in case Exp.dest body of
                     {decs = [], transfer = Jump{dst, args = args'}} =>
                        (* Eta-reducible: the body forwards the arguments
                         * unchanged to another jump.
                         *)
                        if List.equals(args, args', fn ((x, _), x') =>
                                       Var.equals(x, x'))
                           andalso not(Jump.equals(dst, name))
                           then (set(Eta(replaceJump dst))
                                 ; fn () => ())
                        else normal()
                      | _ => normal()
                  end

            in walkExp{exp = exp,
                       delta = 1,
                       bind = bind,
                       walkFun = walkFun}
            end

         (* Removes the uses contributed by e from all counts. *)
         fun deleteExp(e: Exp.t): unit =
            let fun ignore _ = ()
            in walkExp{exp = e, delta = ~1, bind = ignore,
                       walkFun = fn _ => ignore}
            end

         (* ---------------------------------- *)
         (*               shrink               *)
         (* ---------------------------------- *)

         (* Builds a jump to dst with already-simplified args.  Eta targets
          * are followed.  A Useful target whose only remaining use is this
          * one is inlined (its body is simplified in place).  Otherwise the
          * jump is kept.  When the target takes no parameters, the args are
          * dropped and their uses removed; this matters for jumps created
          * from case branches, which are passed the case test but may
          * ignore it.
          *)
         fun jump{dst, args}: Exp.t =
            let fun normal args = Exp.make{decs = [],
                                           transfer = Jump{dst = dst, args = args}}
            in case jumpInfo dst of
               Unknown => normal args
             | Eta j => jump{dst = j, args = args}
             | Useful{inlined, numOccurrences, params, body, ...} =>
                  if !numOccurrences = 1
                     then (numOccurrences := 0
                           ; inlined := true
                           ;
                           (* The special case for [] is here because a case may
                            * get turned into a jump if the test is known, and
                            * jumps for case branches are allowed to ignore their
                            * argument.
                            *)
                           case params of
                              [] => ()
                            | _ => List.foreach2(params, args, setReplacement)
                           ; deleteVars args
                           ; simplifyExp body)
                  else (case params of
                           (* The dropped args are no longer used. *)
                           [] => (deleteVars args; normal [])
                         | _ => normal args)
            end

         (* Simplifies a transfer, returning an expression because a jump
          * may be replaced by an inlined body.
          *)
         and simplifyTransfer arg : Exp.t =
            traceSimplifyTransfer
            (fn (t: Transfer.t) =>
             let fun trans t = Exp.make{decs = [], transfer = t}
             in case t of
                Call{func, args, cont} =>
                   trans(Call{func = func, args = simplifyVars args,
                              cont = (case cont of
                                         NONE => NONE
                                       | SOME c =>
                                            (case jumpInfo c of
                                                Useful{isTail = true, ...} =>
                                                   (* The continuation just
                                                    * returns its arguments, so
                                                    * make this a tail call; c
                                                    * loses this use.
                                                    *)
                                                   (deleteJump c; NONE)
                                              | Eta c => SOME c
                                              | _ => SOME c))})
              | Jump{dst, args} => jump{dst = dst, args = simplifyVars args}
              | Return xs => trans(Return(simplifyVars xs))
              | Raise xs => trans(Raise(simplifyVars xs))
              | Halt => trans Halt
              | Bug => trans Bug
              | Case{test, cases, default} =>
                   let val test = simplifyVar test
                      val cases = List.map(cases, fn (c, j) => (c, replaceJump j))
                      val default =
                         case default of
                            NONE => NONE
                          | SOME j => SOME(replaceJump j)
                   in case (cases, default) of
                      ([], NONE) => (deleteVar test; trans Bug)
                    | ([], SOME j) => (deleteVar test; jump{dst = j, args = []})
                    | _ =>
                         case bind test of
                            NONE => trans(Case{test = test,
                                               cases = cases,
                                               default = default})
                          | SOME(ConApp{con, ...}) =>
                               (* The constructor is known: jump to its
                                * branch, removing the uses of all others.
                                *)
                               let
                                  val rec loop =
                                     fn [] =>
                                          (case default of
                                              NONE => (deleteVar test; trans Bug)
                                            | SOME dst =>
                                                 (deleteVar test
                                                  ; jump{dst = dst, args = []}))
                                      | (c, j) :: cases =>
                                           if Con.equals(con, c)
                                              then (List.foreach
                                                    (cases, deleteJump o #2)
                                                    ; (case default of
                                                          NONE => ()
                                                        | SOME j => deleteJump j)
                                                    ; jump{dst = j, args = [test]})
                                           else (deleteJump j; loop cases)
                               in loop cases
                               end
                          | _ => Error.bug "strange bind for case test"
                   end
             end) arg

         (* Simplifies the declarations of an expression front to back, then
          * its transfer.
          *)
         and simplifyExp arg : Exp.t =
            traceSimplifyExp
            (fn (e: Exp.t) =>
             let val {decs, transfer} = Exp.dest e
                fun simplifyDecs(decs: Dec.t list): Exp.t =
                   case decs of
                      [] => simplifyTransfer transfer
                    | dec :: decs =>
                         let fun keep d = Exp.prefix(simplifyDecs decs, d)
                         in case dec of
                            Fun{name, args, body} =>
                               let
                                  (* True when the declaration can be dropped.
                                   * An Eta function is always dropped; its
                                   * body's single jump loses a use.  A Useful
                                   * function is dropped once all remaining
                                   * uses are inside its own body.  Its body's
                                   * uses are then removed unless the body has
                                   * already been inlined.
                                   *)
                                  fun filter() =
                                     case jumpInfo name of
                                        Unknown => Error.bug "missing jumpInfo"
                                      | Eta j => (deleteJump j; true)
                                      | Useful{numBodyOccurrences,
                                               numOccurrences, inlined, ...} =>
                                           if !numOccurrences = !numBodyOccurrences
                                              then (if !inlined
                                                       then ()
                                                    else deleteExp body
                                                    ; true)
                                           else false
                                  (* Checked before simplifying the rest, so a
                                   * function used only recursively is dropped
                                   * before its body is simplified (and so is
                                   * never inlined into itself).  Checked again
                                   * afterwards, because simplifying the rest
                                   * may remove or inline the outside uses.
                                   *)
                                  val isGone = filter()
                                  val rest = simplifyDecs decs
                               in if isGone orelse filter()
                                     then rest
                                  else
                                     Exp.prefix(rest, Fun{name = name, args = args,
                                                          body = simplifyExp body})
                               end
                          | Bind{var, ty, exp} =>
                               let
                                  fun finish(exp, rest): Exp.t =
                                     Exp.prefix(rest,
                                                Bind{var = var, ty = ty, exp = exp})
                                  (* For a binding without side effects: delete
                                   * removes the uses of its operands; set
                                   * records any binding info and returns the
                                   * expression to emit, or NONE if var has been
                                   * replaced by another variable.  The binding
                                   * is dropped if var is useless either before
                                   * or after the rest is simplified.
                                   *)
                                  fun nonExpansive
                                     (delete: unit -> unit,
                                      set: unit -> (unit -> PrimExp.t) option): Exp.t =
                                     if isUseless var
                                        then (delete(); simplifyDecs decs)
                                     else let val s = set()
                                              val rest = simplifyDecs decs
                                          in if isUseless var
                                                then (delete(); rest)
                                             else (case s of
                                                      NONE => rest
                                                    | SOME e => finish(e(), rest))
                                          end
                                  fun nonExpansiveCon(delete, exp: PrimExp.t) =
                                     nonExpansive(delete,
                                                  fn () => (setBind(var, exp)
                                                            ; SOME(fn () => exp)))
                               in case exp of
                                  Const _ => nonExpansive(fn () => (),
                                                          fn () => SOME(fn () => exp))
                                | Var x =>
                                     (* Copy propagation: uses of var become
                                      * uses of x.
                                      *)
                                     let val x = simplifyVar x
                                     in setReplacement(var, x)
                                        ; deleteVar x
                                        ; simplifyDecs decs
                                     end
                                | Tuple xs =>
                                     let val xs = simplifyVars xs
                                     in nonExpansiveCon(fn () => deleteVars xs,
                                                        Tuple xs)
                                     end
                                | ConApp{con, args} =>
                                     let val args = simplifyVars args
                                     in nonExpansiveCon
                                        (fn () => deleteVars args,
                                         ConApp{con = con, args = args})
                                     end
                                | Select{tuple, offset} =>
                                     let val tuple = simplifyVar tuple
                                     in nonExpansive
                                        (fn () => deleteVar tuple,
                                         fn () =>
                                         case bind tuple of
                                            NONE => SOME(fn () =>
                                                         Select{tuple = tuple,
                                                                offset = offset})
                                          | SOME(Tuple xs) =>
                                               (deleteVar tuple
                                                ; (setReplacement
                                                   (var, List.nth(xs, offset)))
                                                ; NONE)
                                          | _ => Error.bug "select of non-Tuple")
                                     end
                                | ConSelect{variant, con, offset} =>
                                     let val variant = simplifyVar variant
                                     in nonExpansive
                                        (fn () => deleteVar variant,
                                         fn () =>
                                         case bind variant of
                                            NONE => SOME(fn () =>
                                                         ConSelect{variant = variant,
                                                                   con = con,
                                                                   offset = offset})
                                          | SOME(ConApp{con, args}) =>
                                               (deleteVar variant
                                                ; (setReplacement
                                                   (var, List.nth(args, offset)))
                                                ; NONE)
                                          | _ => Error.bug "conSelect of non-ConApp")
                                     end
                                | PrimApp{prim, targs, args} =>
                                     let val args = simplifyVars args
                                        val e =
                                           PrimApp{prim = prim, targs = targs, args = args}
                                     in if Prim.maySideEffect prim
                                           (* Always kept, even if unused. *)
                                           then finish(e, simplifyDecs decs)
                                        else nonExpansive(fn () => deleteVars args,
                                                          fn () => SOME(fn () => e))
                                     end
                               end
                          | _ => keep dec
                         end
             in simplifyDecs decs
             end) arg

         val exp = simplifyExp exp

      (* NOTE(review): Exp.clear presumably resets the property lists of the
       * variables and jumps in the result.  Variables and jumps that occur
       * only in deleted code (e.g. dropped functions) are not part of the
       * result; confirm whether their properties need clearing too.
       *)
      in Exp.clear exp
         ; exp
      end
   end

(* Shrinks without global info and without dropping unused bindings. *)
val shrinkExpNoDelete = fn e => shrinkExp [] (e, false)

val traceShrinkExp = Trace.trace("shrinkExp", Exp.layout, Exp.layout)

(* shrinkExp globals: a traced shrinker that may drop unused bindings. *)
val shrinkExp = fn globals => let val shrinkExp = shrinkExp globals
                              in traceShrinkExp(fn e => shrinkExp(e, true))
                              end

(* Applies simplifyExp and then the shrinker to the body of every function;
 * the global info is computed once for the whole program.
 *)
fun simplifyProgram simplifyExp (Program.T{datatypes, globals, functions, main}) =
   let
      val shrinkExp = shrinkExp globals
      val functions =
         List.map
         (functions, fn {name, args, body, returns} =>
          {name = name, args = args,
           body = shrinkExp(simplifyExp body),
           returns = returns})

   in Program.T{datatypes = datatypes,
                globals = globals,
                functions = functions,
                main = main}
   end

(* Shrinks every function body of p. *)
fun shrink p = simplifyProgram (fn x => x) p

end