[MLton] cvs commit: refFlatten fix and improvement

Stephen Weeks sweeks@mlton.org
Fri, 24 Sep 2004 15:22:00 -0700


sweeks      04/09/24 15:21:59

  Modified:    mlton/ssa ref-flatten.fun
               regression ref-flatten.2.sml
  Added:       regression ref-flatten.3.ok ref-flatten.3.sml
                        ref-flatten.4.ok ref-flatten.4.sml ref-flatten.5.ok
                        ref-flatten.5.sml
  Log:
  MAIL refFlatten fix and improvement
  
  Fixed a bug in refFlatten.  The bug is now tested for by the
  regression test ref-flatten.3.sml.  The bug was that refFlatten mistakenly
  flattened a ref cell allocated prior to a loop into an object
  allocation inside the loop, thus losing sharing of the ref.  The fix
  was to require the object allocation to be in the same block as the
  ref allocation.  This is pretty draconian, and it would be nice to
  generalize it some day to allow flattening as long as the ref
  allocation and object allocation "line up one-to-one" in the same
  loop-free chunk of code.
  
  I also put in Matthew's suggested improvement, which allows updates to
  the ref cell provided that they occur in the same basic block in which
  the ref is allocated (this is safe-for-space because the containing
  object is still alive).  It would be nice to relax this to allow
  updates as long as it can be proved that the container is live.
  
  Added several refFlatten regression tests.

Revision  Changes    Path
1.30      +177 -105  mlton/mlton/ssa/ref-flatten.fun

Index: ref-flatten.fun
===================================================================
RCS file: /cvsroot/mlton/mlton/mlton/ssa/ref-flatten.fun,v
retrieving revision 1.29
retrieving revision 1.30
diff -u -r1.29 -r1.30
--- ref-flatten.fun	17 Sep 2004 01:37:02 -0000	1.29
+++ ref-flatten.fun	24 Sep 2004 22:21:58 -0000	1.30
@@ -252,6 +252,45 @@
 structure Size = TwoPointLattice (val bottom = "small"
 				  val top = "large")
 
+structure VarInfo =
+   struct
+      datatype useStatus =
+	 InTuple of {object: Object.t,
+		     objectVar: Var.t,
+		     offset: int}
+       | Unused
+	 
+      datatype t =
+	 Flattenable of {components: Var.t vector,
+			 defBlock: Label.t,
+			 useStatus: useStatus ref}
+       | Unflattenable
+
+      fun layout (i: t): Layout.t =
+	 let
+	    open Layout
+	 in
+	    case i of
+	       Flattenable {components, defBlock, useStatus} =>
+		  seq [str "Flattenable ",
+		       record [("components",
+				Vector.layout Var.layout components),
+			       ("defBlock", Label.layout defBlock),
+			       ("useStatus",
+				(case !useStatus of
+				    InTuple {object, objectVar, offset} =>
+				       seq [str "InTuple ",
+					    record [("object",
+						     Object.layout object),
+						    ("objectVar",
+						     Var.layout objectVar),
+						    ("offset",
+						     Int.layout offset)]]
+				  | Unused => str "Unused"))]]
+	     | Unflattenable => str "Unflattenable"
+	 end
+   end
+
 fun flatten (program as Program.T {datatypes, functions, globals, main}) =
    let
       val {get = conValue: Con.t -> Value.t option ref, ...} =
@@ -461,37 +500,30 @@
 		  update = update,
 		  useFromTypeOnBinds = false}
       val varObject = Value.deObject o varValue
-      (* Mark a variable as flat if it is used only once and that use is in an
-       * object allocation.
+      (* Mark a variable as Flattenable if all its uses are contained in a single
+       * basic block, there is a single use in an object construction, and
+       * all other uses follow the object construction.
+       *
+       * ...
+       * r: (t ref) = (t)
+       * ... <no uses of r> ...
+       * x: (... * (t ref) * ...) = (..., r, ...)
+       * ... <other assignments to r> ...
+       *
        *)
-      datatype varInfo =
-	 NonObject
-	| Object of {components: Var.t vector,
-		     flat: Flat.t ref}
-      val layoutVarInfo =
-	 let
-	    open Layout
-	 in
-	    fn NonObject => str "NonObject"
-	     | Object {components, flat} =>
-		  seq [str "Object ",
-		       record [("components",
-				Vector.layout Var.layout components),
-			       ("flat", Flat.layout (!flat))]]
-	 end
-      val {get = varInfo: Var.t -> varInfo ref, ...} =
-	 Property.get (Var.plist, Property.initFun (fn _ => ref NonObject))
+      datatype z = datatype VarInfo.t
+      datatype z = datatype VarInfo.useStatus
+      val {get = varInfo: Var.t -> VarInfo.t ref, ...} =
+	 Property.get (Var.plist,
+		       Property.initFun (fn _ => ref VarInfo.Unflattenable))
       val varInfo =
 	 Trace.trace ("RefFlatten.varInfo",
-		      Var.layout, Ref.layout layoutVarInfo)
+		      Var.layout, Ref.layout VarInfo.layout)
 	 varInfo
-      fun use x =
-	 case ! (varInfo x) of
-	    Object {flat, ...} => flat := Flat.NotFlat
-	  | _ => ()
+      fun use x = varInfo x := Unflattenable
       val use = Trace.trace ("RefFlatten.use", Var.layout, Unit.layout) use
       fun uses xs = Vector.foreach (xs, use)
-      fun loopStatement (s: Statement.t): unit =
+      fun loopStatement (s: Statement.t, current: Label.t): unit =
 	 case s of
 	    Bind {exp = Exp.Object {args, ...}, var, ...} =>
 	       (case var of
@@ -503,38 +535,65 @@
 			    let
 			       val () =
 				  varInfo var
-				  := Object {components = args,
-					     flat = ref Flat.Unknown}
+				  := Flattenable {components = args,
+						  defBlock = current,
+						  useStatus = ref Unused}
 			    in
 			       Vector.foreachi
 			       (args, fn (offset, x) =>
-				case ! (varInfo x) of
-				   NonObject => ()
-				 | Object {flat, ...} => 
-				      let
-					 datatype z = datatype Flat.t
-				      in
-					 case !flat of
-					    Unknown =>
-					       flat :=
-					       Offset {object = object,
-						       offset = offset}
-					  | _ => flat := NotFlat
-				      end)
+				let
+				   val r = varInfo x
+				in
+				   case !r of
+				      Flattenable {defBlock, useStatus, ...} =>
+					 (if Label.equals (current, defBlock)
+					     andalso (case !useStatus of
+							 InTuple _ => false
+						       | Unused => true)
+					     then (useStatus
+						   := (InTuple
+						       {object = object,
+							objectVar = var,
+							offset = offset}))
+					  else r := Unflattenable)
+				    | Unflattenable => ()
+				end)
 			    end)
+	  | Statement.Updates (base, us) =>
+	       (Vector.foreach (us, use o #value)
+		; (case base of
+		      Base.Object r =>
+			 let
+			    val i = varInfo r
+			 in
+			    case ! i of
+			       Flattenable {defBlock, useStatus, ...} =>
+				  if Label.equals (current, defBlock)
+				     andalso (case !useStatus of
+						 InTuple _ => true
+					       | Unused => false)
+				     then ()
+				  else i := Unflattenable
+			     | Unflattenable => ()
+			 end
+		    | Base.VectorSub _ => ()))
 	  | _ => Statement.foreachUse (s, use)
       val loopStatement =
-	 Trace.trace ("RefFlatten.loopStatement", Statement.layout, Unit.layout)
+	 Trace.trace2
+	 ("RefFlatten.loopStatement", Statement.layout, Label.layout,
+	  Unit.layout)
 	 loopStatement
-      fun loopStatements ss = Vector.foreach (ss, loopStatement)
+      fun loopStatements (ss, label) =
+	 Vector.foreach (ss, fn s => loopStatement (s, label))
       fun loopTransfer t = Transfer.foreachVar (t, use)
-      val () = loopStatements globals
+      val globalLabel = Label.newNoname ()
+      val () = loopStatements (globals, globalLabel)
       val () =
 	 List.foreach
 	 (functions, fn f =>
 	  Function.dfs
-	  (f, fn Block.T {statements, transfer, ...} =>
-	   (loopStatements statements
+	  (f, fn Block.T {label, statements, transfer, ...} =>
+	   (loopStatements (statements, label)
 	    ; loopTransfer transfer
 	    ; fn () => ())))
       fun foreachObject (f): unit =
@@ -566,46 +625,39 @@
       (* Try to flatten each ref. *)
       val () =
 	 foreachObject
-	 (fn (var, _, Obj {flat, ...}) =>
+	 (fn (var, args, obj as Obj {flat, ...}) =>
 	  let
 	     datatype z = datatype Flat.t
-	     val flat'Ref as ref flat' =
-		case ! (varInfo var) of
-		   NonObject => Error.bug "Object with NonObject"
-		 | Object {flat, ...} => flat
-	     fun notFlat () =
-		(flat := NotFlat
-		 ; flat'Ref := NotFlat)
+	     (* Check that all arguments that are represented by flattening them
+	      * into the object are available as an explicit allocation.
+	      *)
+	     val () =
+		Vector.foreach
+		(args, fn a =>
+		 case Value.deFlat {inner = varValue a, outer = obj} of
+		    NONE => ()
+		  | SOME (Obj {flat, ...}) =>
+		       case ! (varInfo a) of
+			  Flattenable _ => ()
+			| Unflattenable =>
+			     flat := NotFlat)
+	     fun notFlat () = flat := NotFlat
 	  in
-	     case flat' of
-		Offset {object = obj, offset = i} =>
-		   (case ! flat of
-		       NotFlat => notFlat ()
-		     | Offset {object = obj', offset = i'} =>
-			  if i = i' andalso Object.equals (obj, obj')
-			     then ()
-			  else notFlat ()
-		     | Unknown => flat := flat')
-	      | _ => notFlat ()
+	     case ! (varInfo var) of
+		Flattenable {useStatus, ...} =>
+		   (case !useStatus of
+		       InTuple {object = obj, offset = i, ...} =>
+			  (case ! flat of
+			      NotFlat => ()
+			    | Offset {object = obj', offset = i'} =>
+				 if i = i' andalso Object.equals (obj, obj')
+				    then ()
+				 else notFlat ()
+			    | Unknown => flat := Offset {object = obj,
+							 offset = i})
+		     | Unused => notFlat ())
+	      | Unflattenable => notFlat ()
 	  end)
-      (* Disallow flattening into object components that aren't explicitly
-       * constructed.
-       *)
-      val () =
-	 foreachObject
-	 (fn (_, args, obj) =>
-	  Vector.foreach
-	  (args, fn arg =>
-	   case ! (varInfo arg) of
-	      NonObject =>
-		 let
-		    val v = varValue arg
-		 in
-		    if isSome (Value.deFlat {inner = v, outer = obj})
-		       then Value.dontFlatten v
-		    else ()
-		 end
-	    | Object _ => ()))
       (*
        * The following code disables flattening of some refs to ensure
        * space safety.  Flattening a ref into an object that has
@@ -686,10 +738,7 @@
 			      Flat.Offset {object, offset} =>
 				 if objectHasAnotherLarge (object,
 							   {offset = offset})
-				    andalso
-				    (case ! (varInfo x) of
-					Object _ => false
-				      | _ => not (containerIsLive x))
+				    andalso not (containerIsLive x)
 				    then flat := Flat.NotFlat
 				 else ()
 			    | _ => ())
@@ -701,6 +750,26 @@
 		 ()
 	      end)
 	  end)
+      (* Mark varInfo as Unflattenable if the object in varValue is NotFlat.  This
+       * is done after all the other parts of the analysis so that varInfo is
+       * consistent with varValue.
+       *)
+      val () =
+	 Program.foreachVar
+	 (program, fn (x, _) =>
+	  let
+	     val r = varInfo x
+	  in
+	     case !r of
+		Flattenable _ =>
+		   (case Value.deObject (varValue x) of
+		       NONE => ()
+		     | SOME (Obj {flat, ...}) =>
+			  (case !flat of
+			      Flat.NotFlat => r := Unflattenable
+			    | _ => ()))
+	      | Unflattenable => ()
+	  end)
       val () =
 	 Control.diagnostics
 	 (fn display =>
@@ -715,18 +784,10 @@
 	     val () =
 		Program.foreachVar
 		(program, fn (x, _) =>
-		 let
-		    val vi =
-		       case ! (varInfo x) of
-			  NonObject => str "NonObject"
-			| Object {flat, ...} =>
-			     seq [str "Object ", Flat.layout (!flat)]
-		 in
-		    display
-		    (seq [Var.layout x, str " ",
-			  record [("value", Value.layout (varValue x)),
-				  ("varInfo", vi)]])
-		 end)
+		 display
+		 (seq [Var.layout x, str " ",
+		       record [("value", Value.layout (varValue x)),
+			       ("varInfo", VarInfo.layout (! (varInfo x)))]]))
 	  in
 	     ()
 	  end)
@@ -826,9 +887,9 @@
 		NONE => x :: ac
 	      | SOME obj =>
 		   (case ! (varInfo x) of
-		       NonObject => flattenValues (x, obj, ac)
-		     | Object {components, ...} =>
-			  flattenArgs (components, obj, ac))
+		       Flattenable {components, ...} =>
+			  flattenArgs (components, obj, ac)
+		     | Unflattenable => flattenValues (x, obj, ac))
 	  end)
       val flattenArgs =
 	 Trace.trace3 ("flattenArgs",
@@ -905,11 +966,22 @@
 		      (case varObject object of
 			  NONE => s
 			| SOME obj =>
-			     Updates
-			     (base,
-			      Vector.map (us, fn {offset, value} =>
-					  {offset = objectOffset (obj, offset),
-					   value = value})))
+			     let
+				val base =
+				   case ! (varInfo object) of
+				      Flattenable {useStatus, ...} =>
+					 (case ! useStatus of
+					     InTuple {objectVar, ...} =>
+						Base.Object objectVar
+					   | _ => base)
+				    | Unflattenable => base
+			     in
+				Updates
+				(base,
+				 Vector.map (us, fn {offset, value} =>
+					     {offset = objectOffset (obj, offset),
+					      value = value}))
+			     end)
 		 | Base.VectorSub _ => Vector.new1 s)
       val transformStatement =
 	 Trace.trace ("transformStatement",



1.2       +7 -3      mlton/regression/ref-flatten.2.sml

Index: ref-flatten.2.sml
===================================================================
RCS file: /cvsroot/mlton/mlton/regression/ref-flatten.2.sml,v
retrieving revision 1.1
retrieving revision 1.2
diff -u -r1.1 -r1.2
--- ref-flatten.2.sml	23 Jul 2004 22:50:42 -0000	1.1
+++ ref-flatten.2.sml	24 Sep 2004 22:21:58 -0000	1.2
@@ -3,9 +3,13 @@
  | B
 
 val a = Array.tabulate (100, fn i =>
-			case i mod 2 of
-			   0 => A (ref 0w13, ref 0w123, [100 + i, 2, 3])
-			 | 1 => B)
+			let
+			   val l = [100 + i, 2, 3]
+			in
+			   case i mod 2 of
+			      0 => A (ref 0w13, ref 0w123, l)
+			    | 1 => B
+			end)
 
 val _ =
    Array.app



1.1                  mlton/regression/ref-flatten.3.ok

Index: ref-flatten.3.ok
===================================================================
0
0



1.1                  mlton/regression/ref-flatten.3.sml

Index: ref-flatten.3.sml
===================================================================
(*
 * This example tests for a bug that was in refFlatten at one point.  The idea
 * is to allocate a ref cell outside a loop, and then allocate a tuple containing
 * the ref cell at each iteration of the loop.  At one point, refFlatten
 * mistakenly flattened the ref cell, which meant that it wasn't shared across
 * all the tuples allocated in the loop, as it should have been.
 *
 * Here the inner loop is the List.tabulate below: the ref r is allocated once,
 * before it, and every one of the ten tuples it builds holds that same r.
 * Expected output (ref-flatten.3.ok) is 0 on two lines, one per call of the
 * outer function loop.
 *)
fun loop i =
   if i = 0
      then ()
   else
      let
	 (* Allocated once, before the tabulate, so all tuples must share it. *)
	 val r = ref 13
	 (* Ten tuples (r, ref i) for i = 0..9: every tuple shares r, while each
	  * second component is a fresh ref.  This i shadows loop's argument. *)
	 val l = List.tabulate (10, fn i => (r, ref i))
	 (* Write 0 (the contents of tuple 0's second ref) through tuple 0's
	  * first component, which should be r itself. *)
	 val (r1, r2) = List.nth (l, 0)
	 val () = r1 := !r2
	 (* Read r back through tuple 1.  With r shared this prints 0; had r been
	  * flattened into each tuple, tuple 1 would still hold 13. *)
	 val (r1, _) = List.nth (l, 1)
	 val () = print (concat [Int.toString (!r1), "\n"])
      in
	 loop (i - 1)
      end

(* Run the check twice. *)
val () = loop 2



1.1                  mlton/regression/ref-flatten.4.ok

Index: ref-flatten.4.ok
===================================================================
NONE
SOME 9



1.1                  mlton/regression/ref-flatten.4.sml

Index: ref-flatten.4.sml
===================================================================
(*
 * CList: a list whose tails are mutable.  Each Cons cell holds its tail in a
 * ref, and the empty list is NONE.  Those tail refs, allocated inside a Cons,
 * are what this refFlatten regression test exercises.
 *)
structure CList =
   struct
      datatype 'a clist' = Cons of 'a * 'a clist ref
      withtype 'a clist = 'a clist' option

      (* The empty list. *)
      fun cnil () = NONE
      (* Prepend h to t; the fresh tail ref is allocated directly inside the
       * Cons. *)
      fun ccons (h, t) = SOME (Cons (h, ref t))

      (* Eliminator: calls nilCase () on the empty list, or consCase with the
       * head and the current contents of the tail ref. *)
      fun match cl nilCase consCase =
	 case cl of
	    NONE => nilCase ()
	  | SOME (Cons (h, t)) => consCase (h, !t)

      (* Build a CList from an ordinary list, preserving order. *)
      fun fromList l =
	 case l of
	    [] => cnil ()
	  | h::t => ccons (h, fromList t)

      (* Build a one-cell cyclic list whose tail points back to the list
       * itself.  The write goes through r after r has been placed in the
       * Cons, so the cycle exists only if the Cons's tail is that same ref.
       * NOTE(review): presumably this targets refFlatten's handling of an
       * update that follows the object construction. *)
      fun repeat x =
	 let
	    val r = ref NONE
	    val cl = SOME (Cons (x, r))
	    val () = r := cl
	 in
	    cl
	 end

      (* length cl counts the cells of cl, giving up and returning NONE once
       * max cells have been seen, so that cyclic lists from repeat still
       * terminate. *)
      local
	 val max = 1000
	 fun length' (cl, n) =
	    if n >= max
	       then NONE
	       else match cl
		          (fn () => SOME n)
			  (fn (_,t) => length' (t, n + 1))
      in
	 fun length cl = length' (cl, 0)
      end
   end

(* The two checks below together should print ref-flatten.4.ok.
 * First: a cyclic list never reaches its end, so length hits its cutoff and
 * this prints NONE. *)
val cl = CList.repeat #"x"
val n = CList.length cl
val () =
   case n of
      NONE => print "NONE\n"
    | SOME n => print (concat ["SOME ", Int.toString n, "\n"])

(* Second: a finite nine-element list, so this prints SOME 9. *)
val cl = CList.fromList [1,2,3,4,5,6,7,8,9]
val n = CList.length cl
val () =
   case n of
      NONE => print "NONE\n"
    | SOME n => print (concat ["SOME ", Int.toString n, "\n"])



1.1                  mlton/regression/ref-flatten.5.ok

Index: ref-flatten.5.ok
===================================================================
59



1.1                  mlton/regression/ref-flatten.5.sml

Index: ref-flatten.5.sml
===================================================================
(*
 * Tests that a ref cell shared between two different constructors stays
 * shared.  Each A in a holds a fresh ref; a' is built by putting that same
 * ref into an A'.  The refs are then updated through a, and the update must
 * be visible through a'.
 *
 * Expected output (ref-flatten.5.ok) is 59: element 1 of a is A (ref 13, 14),
 * so element 1 of a' is A' (r, 15); the Array.app sets r := 17 + 14 + 13 = 44,
 * and 44 + 15 = 59.  Were r not shared, this would print 13 + 15 = 28.
 *)
datatype t =
   A of int ref * int
 | B

val n = 100

(* Elements cycle through B, A (ref 13, 14), A (ref 15, 16); each A gets its
 * own freshly allocated ref. *)
val a = Array.tabulate (n, fn i =>
			case i mod 3 of
			   0 => B
			 | 1 => A (ref 13, 14)
			 | 2 => A (ref 15, 16))

(* A second datatype.  It shadows the type name t, but the constructors A and
 * B above remain in scope. *)
datatype t =
   A' of int ref * int
 | B'

(* Mirror a, reusing each A's ref r itself (not a copy) and bumping the int.
 * The n bound by the pattern shadows the top-level n. *)
val a' =
   Array.tabulate (n, fn i =>
		   case Array.sub (a, i) of
		      B => B'
		    | A (r, n) => A' (r, n + 1))

(* Update every ref through a. *)
val _ = Array.app (fn A (r, n) => r := 17 + n + !r  | B => ()) a

(* Read element 1 back through a'; the update above must be visible here. *)
val _ =
   case Array.sub (a', 1) of
      A' (r, n) => print (concat [Int.toString (!r + n), "\n"])
    | B' => ()