[MLton] cvs commit: improved analysis for ref flattening

Stephen Weeks sweeks@mlton.org
Wed, 19 May 2004 15:34:03 -0700


sweeks      04/05/19 15:34:02

  Modified:    mlton/ssa ref-flatten.fun
  Log:
  MAIL improved analysis for ref flattening
  
  After looking at some more examples, I decided that the local analysis
  that I had implemented was too brittle and didn't get enough cases.
  The main problem was the syntactic check to see if a ref that was
  extracted from a tuple was only ever assigned to or dereferenced.
  There were many situations where the ref would be passed as an
  argument, and so would not be flattened, despite the fact that the
  place it was passed would only assign or dereference the ref.
  
  So, I decided to switch to a less syntactic, more data-flow-analysis-based
  approach.  This checkin implements the new analysis.
  
  The idea is to compute for each occurrence of a ref type in the program
  whether or not that ref can be represented as an offset of some object
  (conapp or tuple).  As before, a unification-based whole-program analysis
  with deep abstract values makes sure the result is consistent.
  
  The only syntactic part of the analysis that remains is the part that
  checks that a variable bound to a value constructed by Ref_ref only
  occurs once and that that occurrence is in the construction of an
  object.  If so, and it is consistent with the rest of the flattening
  annotations, we flatten that ref.  That's it.
  
  I am more worried about space safety with this analysis, since it lets
  a ref's containing object be kept alive quite a bit longer than the
  old analysis.  This will be worth revisiting at some point.
  
  I looked through the results of this analysis on the benchmarks and
  some other tests, and it seems to get what I want.  Here are a few nice
  examples.
  
  As before, it gets some of the refs in the IO stack.
  
  It gets singly-linked lists and should get Dan Wang's mutable cons
  example.
  
  For the barnes-hut benchmark, it flattens all four refs in the main
  Space.body record type.
  
  For boyer, there are two data structures with refs and both are
  flattened.
  
  For raytrace, it flattens the ref in the second field of the OObj
  constructor.
  
  For tsp, it flattens the prev and next fields of the ND constructor in
  the tree datatype.
  
  Feel free to try out the new analysis (via -diag-pass refFlatten) on
  your code and see what happens.

Revision  Changes    Path
1.2       +119 -148  mlton/mlton/ssa/ref-flatten.fun

Index: ref-flatten.fun
===================================================================
RCS file: /cvsroot/mlton/mlton/mlton/ssa/ref-flatten.fun,v
retrieving revision 1.1
retrieving revision 1.2
diff -u -r1.1 -r1.2
--- ref-flatten.fun	16 May 2004 18:12:28 -0000	1.1
+++ ref-flatten.fun	19 May 2004 22:34:01 -0000	1.2
@@ -16,13 +16,43 @@
 datatype z = datatype Transfer.t
 
 structure Set = DisjointSet
-   
+
+structure Flat =
+   struct
+      datatype t =
+	 ConOffset of {con: Con.t,
+		       offset: int}
+       | NotFlat
+       | TupleOffset of {offset: int,
+			 tuple: Type.t}
+       | Unknown
+
+      val layout: t -> Layout.t =
+	 fn f =>
+	 let
+	    open Layout
+	 in
+	    case f of
+	       ConOffset {con, offset} =>
+		  seq [str "ConOffset ",
+		       record [("con", Con.layout con),
+			       ("offset", Int.layout offset)]]
+	     | NotFlat => str "NotFlat"
+	     | TupleOffset {offset, tuple} =>
+		  seq [str "TupleOffset ",
+		       record [("offset", Int.layout offset),
+			       ("tuple", Type.layout tuple)]]
+	     | Unknown => str "Unknown"
+	 end
+   end
+
 structure Value =
    struct
       datatype t =
 	 Ground
-       | Ref of t
-       | Tuple of {elt: t, mayFlatten: bool Set.t} vector
+       | Ref of {arg: t,
+		 flat: Flat.t Set.t}
+       | Tuple of t vector
        | Unary of t
 
       fun layout (v: t): Layout.t =
@@ -31,13 +61,11 @@
 	 in
 	    case v of
 	       Ground => str "Ground"
-	     | Ref v => seq [str "Ref ", layout v]
-	     | Tuple v =>
-		  tuple (Vector.toListMap
-			 (v, fn {elt, mayFlatten} =>
-			  record [("elt", layout elt),
-				  ("mayFlatten",
-				   Bool.layout (Set.value mayFlatten))]))
+	     | Ref {arg, flat} =>
+		  seq [str "Ref ",
+		       record [("arg", layout arg),
+			       ("flat", Flat.layout (Set.value flat))]]
+	     | Tuple v => tuple (Vector.toListMap (v, layout))
 	     | Unary v => seq [str "Unary ", layout v]
 	 end
 
@@ -53,9 +81,11 @@
 	  | _ => Unary v
 
       val array = unary
-      val reff = Ref
       val vector = unary
       val weak = unary
+	 
+      val reff: t -> t =
+	 fn arg => Ref {arg = arg, flat = Set.singleton Flat.Unknown}
 
       val deUnary: t -> t =
 	 fn Ground => Ground
@@ -64,7 +94,7 @@
 
       val deArray = deUnary
       val deref =
-	 fn Ref v => v
+	 fn Ref {arg, ...} => arg
 	  | _ => Error.bug "deref"
       val deVector = deUnary
       val deWeak = deUnary
@@ -73,14 +103,13 @@
 	 fn vs =>
 	 if Vector.forall (vs, isGround)
 	    then ground
-	 else Tuple (Vector.map (vs, fn v =>
-				 {elt = v, mayFlatten = Set.singleton true}))
+	 else Tuple vs
 
       val select: t * int -> t =
 	 fn (v, i) =>
 	 case v of
 	    Ground => ground
-	  | Tuple v => #elt (Vector.sub (v, i))
+	  | Tuple v => Vector.sub (v, i)
 	  | _ => Error.bug "Value.select"
 	       
       fun fromType (t: Type.t) =
@@ -98,20 +127,16 @@
 
       val rec unify: t * t -> unit =
 	 fn (Ground, Ground) => ()
-	  | (Ref v, Ref v') => unify (v, v')
-	  | (Tuple v, Tuple v') =>
-	       Vector.foreach2
-	       (v, v', fn ({elt = e, mayFlatten = m},
-			   {elt = e', mayFlatten = m'}) =>
-		(Set.union (m, m'); unify (e, e')))
+	  | (Ref {arg = a, flat = f}, Ref {arg = a', flat = f'}) =>
+	       (Set.union (f, f'); unify (a, a'))
+	  | (Tuple v, Tuple v') => Vector.foreach2 (v, v', unify)
 	  | (Unary v, Unary v') => unify (v, v')
 	  | _ => Error.bug "Value.unify"
    end
 
 fun flatten (program as Program.T {datatypes, functions, globals, main}) =
    let
-      val {get = conInfo: Con.t -> {args: {mayFlatten: bool ref,
-					   value: Value.t} vector},
+      val {get = conInfo: Con.t -> {args: Value.t vector},
 	   set = setConInfo, ...} =
 	 Property.getSetOnce 
 	 (Con.plist, Property.initRaise ("conInfo", Con.layout))
@@ -121,16 +146,14 @@
 	 (datatypes, fn Datatype.T {cons, tycon} =>
 	  Vector.foreach
 	  (cons, fn {args, con} =>
-	   setConInfo (con, {args = Vector.map (args, fn t =>
-						{mayFlatten = ref true,
-						 value = Value.fromType t})})))
+	   setConInfo (con, {args = Vector.map (args, Value.fromType)})))
       fun coerce {from, to} = Value.unify (from, to)
       fun conApp {args, con} =
-	 (Vector.foreach2 (args, conArgs con, fn (v, {value = v', ...}) =>
+	 (Vector.foreach2 (args, conArgs con, fn (v, v') =>
 			   coerce {from = v, to = v'})
 	  ; Value.ground)
       fun filter (_, con, args) =
-	 Vector.foreach2 (conArgs con, args, fn ({value = v, ...}, v') =>
+	 Vector.foreach2 (conArgs con, args, fn (v, v') =>
 			  coerce {from = v, to = v'})
       fun primApp {args, prim, resultVar, resultType, targs = _} =
 	 let
@@ -180,83 +203,95 @@
        * Flag indicating if a variable is used in something other than
        * a ! or :=.
        *)
-      val {get = varInfo: Var.t -> {isDirectRef: bool ref,
-				    isUsedOnlyAsRef: bool ref,
-				    numOccurrences: int ref}, ...} =
+      val {get = varInfo: Var.t -> {flat: Flat.t ref}, ...} =
 	 Property.get (Var.plist,
-		       Property.initFun
-		       (fn _ => {isDirectRef = ref false,
-				 isUsedOnlyAsRef = ref true,
-				 numOccurrences = ref 0}))
-      fun use x =
-	 let
-	    val {isUsedOnlyAsRef, numOccurrences, ...} = varInfo x
-	 in
-	    Int.inc numOccurrences
-	    ; isUsedOnlyAsRef := false
-	 end
+		       Property.initFun (fn _ => {flat = ref Flat.Unknown}))
+      fun use x = #flat (varInfo x) := Flat.NotFlat
       fun uses xs = Vector.foreach (xs, use)
+      fun object (xs, make) =
+	 Vector.foreachi
+	 (xs, fn (i, x) =>
+	  let
+	     val {flat, ...} = varInfo x
+	  in
+	     case !flat of
+		Flat.Unknown => flat := make {offset = i}
+	      | _ => flat := Flat.NotFlat
+	  end)
       fun loopStatements ss =
 	 Vector.foreach
-	 (ss, fn Statement.T {exp, var, ...} =>
+	 (ss, fn Statement.T {exp, ty, var} =>
 	  case exp of
-	     ConApp {args, ...} => uses args
+	     ConApp {args, con, ...} =>
+		object (args, fn {offset} =>
+			Flat.ConOffset {con = con, offset = offset})
 	   | Const _ => ()
-	   | PrimApp {args, prim, ...} =>
-		let
-		   fun arg i = Vector.sub (args, i)
-		   fun asRef () =
-		      Int.inc (#numOccurrences (varInfo (arg 0)))
-		   datatype z = datatype Prim.Name.t
-		in
-		   case Prim.name prim of
-		      Ref_assign => (asRef (); use (arg 1))
-		    | Ref_deref => asRef ()
-		    | Ref_ref =>
-			 (uses args
-			  ;  Option.app (var, fn x =>
-					 #isDirectRef (varInfo x) := true))
-		    | _ => uses args
-		end
+	   | PrimApp {args, ...} => uses args
 	   | Profile _ => ()
 	   | Select {tuple, ...} => use tuple
-	   | Tuple xs => uses xs
+	   | Tuple xs =>
+		object (xs, fn {offset} =>
+			Flat.TupleOffset {offset = offset, tuple = ty})
 	   | Var x => use x)
       fun loopTransfer t = Transfer.foreachVar (t, use)
-      val () = loopStatements globals
-      fun loopFormals xts =
-	 Vector.foreach (xts, fn (x, _) =>
-			 let
-			    val {isUsedOnlyAsRef, ...} = varInfo x
-			 in
-			    isUsedOnlyAsRef := false
-			 end)
       val {get = labelInfo: Label.t -> {args: (Var.t * Type.t) vector},
 	   set = setLabelInfo, ...} =
 	 Property.getSetOnce (Label.plist,
 			      Property.initRaise ("info", Label.layout))
       val labelArgs = #args o labelInfo
+      val () = loopStatements globals
       val () =
 	 List.foreach
 	 (functions, fn f =>
 	  let
 	     val {args, blocks, ...} = Function.dest f
-	     val () = loopFormals args
 	  in
 	     Vector.foreach
 	     (blocks, fn Block.T {args, label, statements, transfer, ...} =>
 	      (setLabelInfo (label, {args = args})
-	       ; loopFormals args
 	       ; loopStatements statements
 	       ; loopTransfer transfer))
 	  end)
-      val isLoneRef: Var.t -> bool =
-	 fn x =>
-	 let
-	    val {isDirectRef, numOccurrences, ...} = varInfo x
-	 in
-	    !isDirectRef andalso 1 = !numOccurrences
-	 end
+      (* Now, walk over the whole program and try to flatten each ref. *)
+      fun loopStatement (Statement.T {exp, var, ...}) =
+	 case exp of
+	    PrimApp {prim, ...} =>
+	       (case Prim.name prim of
+		   Prim.Name.Ref_ref =>
+		      Option.app
+		      (var, fn var =>
+		       case value var of
+			  Value.Ref {flat, ...} =>
+			     let
+				datatype z = datatype Flat.t
+				val {flat = flat'} = varInfo var
+				val flat' = !flat'
+				fun notFlat () = Set.setValue (flat, NotFlat)
+			     in
+				case flat' of
+				   ConOffset {con = c, offset = i} =>
+				      (case Set.value flat of
+					  ConOffset {con = c', offset = i'} =>
+					     if Con.equals (c, c') andalso i = i'
+						then ()
+					     else notFlat ()
+					| Unknown => Set.setValue (flat, flat')
+					| _ => notFlat ())
+				 | NotFlat => notFlat ()
+				 | TupleOffset {offset = i, tuple = t} =>
+				      (case Set.value flat of
+					  TupleOffset {offset = i', tuple = t'} =>
+					     if i = i' andalso Type.equals (t, t')
+						then ()
+					     else notFlat ()
+					| Unknown => Set.setValue (flat, flat')
+					| _ => notFlat ())
+				 | Unkonwn => notFlat ()
+			     end
+			| _ => Error.bug "Ref_ref with strange value")
+		 | _ => ())
+	  | _ => ()
+      val () = Vector.foreach (globals, loopStatement)
       val () =
 	 List.foreach
 	 (functions, fn f =>
@@ -264,64 +299,8 @@
 	     val {blocks, ...} = Function.dest f
 	  in
 	     Vector.foreach
-	     (blocks, fn Block.T {statements, transfer, ...} =>
-	      let
-		 val () =
-		    Vector.foreach
-		    (statements, fn Statement.T {exp, var, ...} =>
-		     case exp of
-			ConApp {args, con, ...} =>
-			   Vector.foreach2
-			   (args, conArgs con, fn (x, {mayFlatten, ...}) =>
-			    if isLoneRef x
-			       then ()
-			    else mayFlatten := false)
-		      | Select {offset, tuple} =>
-			   Option.app
-			   (var, fn x =>
-			    let
-			       val {isUsedOnlyAsRef, ...} = varInfo x
-			    in
-			       if !isUsedOnlyAsRef
-				  then ()
-			       else (case value tuple of
-					Value.Tuple v =>
-					   Set.setValue
-					   (#mayFlatten (Vector.sub (v, offset)),
-					    false)
-				      | _ => ())
-			    end)
-		      | Tuple xs =>
-			   Option.app
-			   (var, fn x =>
-			    case value x of
-			       Value.Tuple v =>
-				  Vector.foreach2
-				  (xs, v, fn (x, {mayFlatten, ...}) =>
-				   if isLoneRef x
-				      then ()
-				   else Set.setValue (mayFlatten, false))
-			     | _ => ())
-		      | _ => ())
-		 val () =
-		    case transfer of
-		       Case {cases = Cases.Con v, ...} =>
-			  Vector.foreach
-			  (v, fn (con, l) =>
-			   Vector.foreach2
-			   (conArgs con, labelArgs l,
-			    fn ({mayFlatten, ...}, (x, _)) =>
-			    let
-			       val {isUsedOnlyAsRef, ...} = varInfo x
-			    in
-			       if !isUsedOnlyAsRef
-				  then ()
-			       else mayFlatten := false
-			    end))
-		     | _ => ()
-	      in
-		 ()
-	      end)
+	     (blocks, fn Block.T {statements, ...} =>
+	      Vector.foreach (statements, loopStatement))
 	  end)
       val () =
 	 Control.diagnostics
@@ -335,25 +314,17 @@
 		 (cons, fn {con, ...} =>
 		  display
 		  (seq [Con.layout con, str " ",
-			tuple (Vector.toListMap
-			       (conArgs con, fn {mayFlatten, value} =>
-				record
-				[("mayFlatten", Bool.layout (!mayFlatten)),
-				 ("value", Value.layout value)]))])))
+			tuple (Vector.toListMap (conArgs con, Value.layout))])))
 	     val () =
 		Program.foreachVar
 		(program, fn (x, _) =>
 		 let
-		    val {isDirectRef, isUsedOnlyAsRef, numOccurrences} =
-		       varInfo x
+		    val {flat} = varInfo x
 		 in
 		    display
 		    (seq [Var.layout x, str " ",
-			  record
-			  [("isDirectRef", Bool.layout (!isDirectRef)),
-			   ("isUsedOnlyAsRef", Bool.layout (!isUsedOnlyAsRef)),
-			   ("numOccurrences", Int.layout (!numOccurrences)),
-			   ("value", Value.layout (value x))]])
+			  record [("flat", Flat.layout (!flat)),
+				  ("value", Value.layout (value x))]])
 		 end)
 	  in
 	     ()