the extension to continuation based contification

Stephen Weeks MLton@sourcelight.com
Mon, 27 Nov 2000 14:39:47 -0800 (PST)


> Is it possible to only introduce the "return to X" continuations if the
> analysis hits top for the function X?  Once it is determined that some
> top-level function (say F) has more than one continuation, there's no hope
> for contifying it; so, rather than propagating that non-contification
> through to functions whose continuations are F, establish a "return to F"
> continuation for the analysis of the body of F, maybe allowing other
> functions to be contified into F.

This makes sense.  I don't quite see how to get a least-fixed-point computation
out of the idea though.  I'm thinking about it.

> In any event, I'd be interested in looking at the implementation of the
> continuation contification pass for MLton, at least to see if my quick
> sketch of translating Reppy's analysis to the CPS IL is close.

Here's the current incarnation.  It's a bit messy, both because it handles all
the possibilities (so that I could do the benchmarking) and because I started
adding the "return to F" stuff, which is not operational yet.

--------------------------------------------------------------------------------

(* Copyright (C) 1997-1999 NEC Research Institute.
 * Please see the file LICENSE for license information.
 *)
(*
 * This pass combines two analyses and transformations.
 *
 * The first analysis, "call based", notices whether a toplevel function
 *       1. is called from exactly one place outside itself
 *   and 2. makes only tail calls to itself within its own body.
 * If (1) and (2) are true, then that function can be made local at the point
 * of its outer call.
 *
 * The second analysis is based on the following paper.
 *   Local CPS conversion in a direct-style compiler.  John Reppy.
 * It determines the set of continuations that a function is called with.  If
 * that set is a singleton, then the function can be prefixed onto the
 * continuation it is called with.
 *)

functor Contify (S: CONTIFY_STRUCTS): CONTIFY = 
struct

open S
open Dec Transfer

structure ContSet =
   struct
      (* Elements of the analysis lattice.  Empty is the bottom element and
       * All the top; MainCont, One j and Return f each stand for a single
       * known continuation.
       *)
      datatype set =
         Empty
       | MainCont
       | One of Jump.t
       | Return of Func.t
       | All

      (* Pretty-print a lattice element. *)
      fun layoutSet s =
         let
            open Layout
            fun braces inner = seq [str "{ ", inner, str " }"]
         in
            case s of
               All => str "All"
             | Empty => str "{}"
             | MainCont => str "MainCont"
             | One j => braces (Jump.layout j)
             | Return f => braces (Func.layout f)
         end

      (* A constraint variable: its current value, plus the variables that
       * receive every value this one takes (see <= below).
       *)
      datatype t = T of {set: set ref,
                         lessThan: t list ref}

      (* Current value of a variable. *)
      fun set (T r) = !(#set r)

      fun layout c = layoutSet (set c)

      (* A fresh variable, starting at Empty with no dependents. *)
      fun new () = T {lessThan = ref [], set = ref Empty}

      (* Join incoming into the variable.  Whenever the stored value changes,
       * the new value is forwarded to every variable on lessThan.  Two
       * distinct non-Empty elements join to All.
       *)
      fun up (T {set = value, lessThan = above}, incoming) =
         let
            (* Store v and push it on to every dependent variable. *)
            fun assign v =
               (value := v
                ; List.foreach (!above, fn c => up (c, v)))
            fun unlessEqual same = if same then () else assign All
         in
            case (incoming, !value) of
               (Empty, _) => ()
             | (_, All) => ()
             | (_, Empty) => assign incoming
             | (MainCont, MainCont) => ()
             | (One j', One j) => unlessEqual (Jump.equals (j, j'))
             | (Return f', Return f) => unlessEqual (Func.equals (f, f'))
             | _ => assign All
         end

      (* From here on, calls made through the name up are traced; the
       * recursive calls inside the definition above are not.
       *)
      val up = Trace.trace2 ("up", layout, layoutSet, Unit.layout) up

      (* Record that upper must receive everything lower holds, now and in
       * the future.
       *)
      val op <= =
         fn (T {set = lower, lessThan = dependents}, upper) =>
            (List.push (dependents, upper)
             ; up (upper, !lower))

      fun addReturn (c, f) = up (c, Return f)
      fun addJump (c, j) = up (c, One j)
      fun addMain c = up (c, MainCont)
      fun makeTop c = up (c, All)
   end
	 
structure CallInfo =
   struct
      (* Result of the call-based analysis for one function. *)
      datatype t =
         NoOuterCalls
       | OneOuterCall
       | NotCont

      (* Name of the constructor, as a string. *)
      fun toString info =
         case info of
            NoOuterCalls => "NoOuterCalls"
          | OneOuterCall => "OneOuterCall"
          | NotCont => "NotCont"

      fun layout info = Layout.str (toString info)
   end
   
structure Graph = DirectedGraph
structure Node = Graph.Node
  
fun contify (program as Program.T {datatypes, globals, functions, main}) =
   let
      val strategy = !Control.contifyStrategy
   in case strategy of
      Control.None => program
    | _ => 
	 let
	    datatype 'a replace =
	       None
	     | OneCall of 'a
	     | OneCont of 'a
	    val {get = funcInfo:
		 Func.t -> {
			    callers: Func.t list ref,
			    callInfo: CallInfo.t ref,
			    canPrefix: bool ref,
			    contSet: ContSet.t,
			    isLocal: bool ref,
			    nested: Func.t list ref,
			    node: Graph.Node.t option ref,
			    possiblePrefixes: Func.t list ref,
			    prefixes: Func.t list option ref,
			    replace: {args: (Var.t * Type.t) list,
				      body: Exp.t,
				      jump: Jump.t} option ref
			    }} =
	       Property.get (Func.plist,
			     fn _ => {callers = ref [],
				      callInfo = ref CallInfo.NoOuterCalls,
				      canPrefix = ref false,
				      contSet = ContSet.new (),
				      isLocal = ref false,
				      nested = ref [],
				      node = ref NONE,
				      possiblePrefixes = ref [],
				      prefixes = ref NONE,
				      replace = ref NONE})
	    (* Compute the contSet and callInfo for each function.
	     * The contSet is an over-approximation to the set of continuations
	     * with which the function is called.
	     * The callInfo tells at how many places outside of itself each
	     * function is called.
	     *
	     * For every call found in the body of a function name:
	     *  - a tail self call contributes nothing;
	     *  - a nontail self call marks name as NotCont and sends its
	     *    contSet to top;
	     *  - a call to another function func bumps func's callInfo
	     *    (NoOuterCalls -> OneOuterCall -> NotCont), records name as a
	     *    caller of func, and then either makes func's contSet receive
	     *    everything in name's contSet (tail call) or adds the explicit
	     *    continuation j to func's contSet (nontail call).
	     *)
	    (* main starts out with the distinguished MainCont element. *)
	    val _ = ContSet.addMain (#contSet (funcInfo main))
	    val _ =
	       List.foreach
	       (functions, fn {name, body, ...} =>
		let val {callInfo, contSet = c, ...} = funcInfo name
		in Exp.foreachCall
		   (body, fn {func, cont, ...} =>
		    if Func.equals(name, func)
		       then
			  case cont of
			     NONE => ()
			   | SOME _ => (callInfo := CallInfo.NotCont
					; ContSet.makeTop c)
		    else
		       let
			  val {callers, callInfo = callInfo', contSet = c', ...} =
			     funcInfo func
			  val _ =
			     let
				datatype z = datatype CallInfo.t
			     in case !callInfo' of
				NoOuterCalls => callInfo' := OneOuterCall
			      | OneOuterCall => callInfo' := NotCont
			      | NotCont => ()
			     end
			  (* One entry is pushed per call site, so a caller may
			   * appear more than once in the list.
			   *)
			  val _ = List.push (callers, name)
		       in case cont of
			  NONE => ContSet.<= (c, c')
			(* ContSet.addReturn (c, name) *)
			| SOME j => ContSet.addJump (c', j)
		       end)
		end)
	    (* Record for each jump the functions that might be turned into
	     * continuations as its prefixes.
	     * Record for each function the functions that might prefix its
	     * return.
	     *)
	    val {get = jumpInfo: Jump.t ->
		 {possiblePrefixes: Func.t list ref,
		  prefixes: Func.t list option ref}} =
	       Property.get (Jump.plist, fn _ => {possiblePrefixes = ref [],
						  prefixes = ref NONE})
	    val todo = ref []
	    val _ =
	       List.foreach
	       (functions, fn {name, ...} =>
		let
		   val {contSet, ...} = funcInfo name
		   fun doit (possiblePrefixes, prefixes) =
		      let
			 val _ =
			    case !possiblePrefixes of
			       [] => List.push (todo,
						(possiblePrefixes, prefixes))
			     | _ => ()
		      in List.push (possiblePrefixes, name)
		      end

		in case ContSet.set contSet of
		   ContSet.One j => 
		      let val {possiblePrefixes, prefixes, ...} = jumpInfo j
		      in doit (possiblePrefixes, prefixes)
		      end
		 | ContSet.Return f =>
		      let val {possiblePrefixes, prefixes, ...} = funcInfo f
		      in doit (possiblePrefixes, prefixes)
		      end
		 | _ => ()
		end)
	    (* Strongly connected components of a group of functions.
	     * A temporary graph is built with one node per function of fs and
	     * an edge from each function to every caller that also belongs to
	     * fs.  The node field of funcInfo is used only while the graph is
	     * under construction and is reset to NONE before returning.
	     *)
	    fun sccs (fs: Func.t list) =
	       let
		  val graph = Graph.new ()
		  val {get = funcOf, set = setFuncOf} =
		     Property.getSetOnce
		     (Node.plist, Property.initRaise ("func", Node.layout))
		  fun nodeRef f = #node (funcInfo f)
		  (* Create the node for f and remember it in both directions. *)
		  fun attach f =
		     let
			val r = nodeRef f
			val n = Graph.newNode graph
		     in
			setFuncOf (n, f)
			; r := SOME n
		     end
		  (* Add an edge from callee to each of its callers that has a
		   * node, i.e. that is a member of fs.
		   *)
		  fun addCallerEdges callee =
		     let
			val {node, callers, ...} = funcInfo callee
			val from = valOf (!node)
			fun edgeTo caller =
			   case !(nodeRef caller) of
			      NONE => ()
			    | SOME to =>
				 (Graph.addEdge (graph, {from = from, to = to})
				  ; ())
		     in
			List.foreach (!callers, edgeTo)
		     end
		  fun detach f = nodeRef f := NONE
		  val _ = List.foreach (fs, attach)
		  val _ = List.foreach (fs, addCallerEdges)
		  val _ = List.foreach (fs, detach)
		  val components = Graph.stronglyConnectedComponents graph
	       in
		  List.map (components, fn ns => List.revMap (ns, funcOf))
	       end
	    
	    (* For each collection of functions that are going to prefix a cont,
	     * do a strongly connected components computation.  In order for the
	     * functions to be contified, there can be at most one function in
	     * each component that is called from outside.  This is due to the
	     * fact that mutually recursive continuations cannot be directly
	     * declared in CPS -- the only way to do so is to nest one within the
	     * other.  Thus only one can be "exported".
	     *)
	    exception Nope
	    (* sccHeads returns the head function for each scc in the list of
	     * functions.  The head of an scc is its one member that has a
	     * caller outside the scc; the remaining members are grouped
	     * recursively and stored in the head's nested field.
	     * It raises Nope if there is an scc with more than one
	     * function called from outside the scc.
	     *
	     * Membership in the current scc is tracked with the isLocal flags.
	     * They are cleared again before sccHeads returns AND before any
	     * exception leaves the scan: the caller handles Nope and goes on
	     * to the next group, and stale flags would make members of the
	     * failed group look like local callers there.
	     *)
	    fun sccHeads (fs: Func.t list): Func.t list =
	       List.map
	       (sccs fs, fn fs =>
		let
		   fun setLocal b =
		      List.foreach (fs, fn f => #isLocal (funcInfo f) := b)
		   val _ = setLocal true
		   val outsideCaller = ref NONE
		   val _ =
		      List.foreach
		      (fs, fn f =>
		       let val {callers, ...} = funcInfo f
		       in if List.exists (!callers, fn f' =>
					  not (!(#isLocal (funcInfo f'))))
			     then (case !outsideCaller of
				      SOME _ => raise Nope
				    | NONE => outsideCaller := SOME f)
			  else ()
		       end)
		      (* Undo the marking before letting the failure escape. *)
		      handle e => (setLocal false; raise e)
		   val _ = setLocal false
		in case !outsideCaller of
		   NONE => Error.bug "no outside caller"
		 | SOME f =>
		      let
			 val {nested, ...} = funcInfo f
			 val rest = List.removeFirst (fs, fn f' =>
						      Func.equals (f, f'))
			 (* The flags of this scc are already cleared, so the
			  * recursive call starts from a clean state.
			  *)
			 val _ = nested := sccHeads rest
		      in f
		      end
		end)
	    (* Decide each prefix group recorded in todo.  If sccHeads succeeds,
	     * its result is stored in prefixes and every member of the group is
	     * marked canPrefix.  If it raises Nope, the handler abandons just
	     * this group: prefixes keeps its previous value and no member is
	     * marked.
	     *)
	    val _ =
	       List.foreach
	       (!todo, fn (possiblePrefixes, prefixes) =>
		let
		   val fs = !possiblePrefixes
		   val _ = prefixes := SOME (sccHeads fs)
		   val _ = List.foreach (fs, fn f =>
					 #canPrefix (funcInfo f) := true)
		in ()
		end handle Nope => ())
	    (* Diagnostics. *)
	    val _ =
	       if false
		  then 
		     let
			val _ =
			   Program.layouts
			   (program, fn l => 
			    (Layout.output (l, Out.error)
			     ; Out.newline Out.error))
			val old = ref 0
			val new = ref 0
			val newNo = ref 0
			val same = ref 0
			val sameNo = ref 0
		     in List.foreach
			(functions, fn {name, ...} =>
			 let
			    val {callInfo, canPrefix, contSet, ...} = funcInfo name
			    fun doit (r, s) =
			       (Int.inc r
				; if false
				     then print (concat [s, " ",
							 Func.toString name, "\n"])
				  else ())
			    datatype z = datatype CallInfo.t
			    datatype z = datatype ContSet.set
			 in case (!callInfo, ContSet.set contSet) of
			    (OneOuterCall, One _) =>
			       if !canPrefix
				  then doit (same, "same")
			       else doit (sameNo, "sameNo")
			  | (OneOuterCall, _) => doit (old, "old")
			  | (_, One _) =>
			       if !canPrefix
				  then doit (new, "new")
			       else doit (newNo, "newNo")
			  | _ => ()
			 end)
			; print (concat
				 ["num functions ",
				  Int.toString (List.length functions),
				  "  same ", Int.toString (!same),
				  "  sameNo ", Int.toString (!sameNo),
				  "  num new ", Int.toString (!new),
				  "  num newNo ", Int.toString (!newNo),
				  "  num old ", Int.toString (!old),
				  "\n"])
		     end
	       else ()
	    (* For functions turned into continuations, record their
	     * args, body, and new name.  Which analysis results qualify a
	     * function depends on the selected strategy.
	     *)
	    val _ =
	       let
		  (* Does the current strategy accept a function with these
		   * analysis results?
		   *)
		  fun accepted (oneCall: bool, oneCont: bool): bool =
		     case strategy of
			Control.None => false
		      | Control.Call => oneCall
		      | Control.Cont => oneCont
		      | Control.Both => oneCall orelse oneCont
	       in
		  List.foreach
		  (functions, fn {name, args, body, ...} =>
		   let
		      val {callInfo, canPrefix, replace, ...} = funcInfo name
		      val oneCall = !callInfo = CallInfo.OneOuterCall
		   in
		      if accepted (oneCall, !canPrefix)
			 then
			    replace :=
			    SOME {args = args,
				  body = body,
				  jump = Jump.newString (Func.originalName name)}
		      else ()
		   end)
	       end
	    (* Walk over all functions, removing those that aren't toplevel, and
	     * descending those that are, inserting local functions
	     * where necessary.
	     *  - turn tail calls into nontail calls
	     *  - turn returns into jumps
	     *)
	    (* walkExp (f, e, c) rewrites expression e.
	     *   f is the function whose body e belongs to; it is only used to
	     *     recognize self calls.
	     *   c, when present, is the jump that replaces a return and that
	     *     tail calls pick up as their continuation.
	     *)
	    fun walkExp (f: Func.t, e: Exp.t, c: Jump.t option): Exp.t =
	       let
		  val {decs, transfer} = Exp.dest e
		  (* Rebuild the declarations in their original order.  Each
		   * local Fun has its body walked with the same c; under the
		   * Both and Cont strategies the functions recorded in that
		   * jump's prefixes are inserted right after it, returning to
		   * it.
		   *)
		  val decs = 
		     List.fold
		     (rev decs, [], fn (d, ds) =>
		      case d of
			 Bind _ => d :: ds
		       | Fun {name, args, body} =>
			    Fun {name = name,
				 args = args,
				 body = walkExp (f, body, c)}
			    :: (if (case strategy of
				       Control.Both => true
				     | Control.Cont => true
				     | _ => false)
				   then
				      let val {prefixes, ...} = jumpInfo name
				      in case !prefixes of
					 NONE => ds
				       | SOME fs => nest (fs, SOME name, ds)
				      end
				else ds)
		       | HandlerPush h => HandlerPush h :: ds
		       | HandlerPop => HandlerPop :: ds)
		  fun make transfer = Exp.make {decs = decs,
						transfer = transfer}
	       in
		  case transfer of
		     Call {func, args, cont} =>
			let
			   (* A tail call inherits c; a nontail call keeps its
			    * own continuation.
			    *)
			   val newCont: Jump.t option =
			      case cont of
				 NONE => c
			       | SOME _ => cont
			   val {callInfo, canPrefix, replace, ...} =
			      funcInfo func
			in
			   case !replace of
			      NONE =>
				 make (Call {func = func,
					     args = args,
					     cont = newCont})
			    | SOME {jump, args = formals, body} =>
				 let
				    (* The callee's definition is appended here
				     * only when it has exactly one outer call,
				     * this is not a self call, and the strategy
				     * is Call, or Both with canPrefix unset.
				     * In every other case only the jump below
				     * is emitted.
				     *)
				    val decs =
				       if !callInfo = CallInfo.OneOuterCall
					  andalso not (Func.equals (f, func))
					  andalso (case strategy of
						      Control.Both =>
							 not (!canPrefix)
						    | Control.Call => true
						    | _ => false)
					  then
					     decs
					     @ [Fun {name = jump,
						     args = formals,
						     body = walkExp (func, body,
								     newCont)}]
					      
				       else decs
				 in
				    Exp.make
				    {decs = decs,
				     transfer = Jump {dst = jump, args = args}}
				 end
			end
		   | Return xs =>
			make (case c of
				 NONE => transfer
			       | SOME c => Jump {dst = c, args = xs})
		   | _ => make transfer
	       end
	    (* nest (fs, cont, ds) puts one local Fun declaration in front of
	     * ds for each function in fs, in the order of fs.  Each body is
	     * walked with cont as the return continuation, and the function's
	     * own nested functions are declared at the start of that body.
	     * Every function in fs must have replace set (valOf is applied to
	     * it).
	     *)
	    and nest (fs: Func.t list,
		      cont: Jump.t option,
		      ds: Dec.t list): Dec.t list =
	       List.fold
	       (rev fs, ds, fn (f, ds) =>
		let
		   val {replace, nested, prefixes, ...} = funcInfo f
		   (* NOTE(review): this fs is computed but never used; the
		    * recursive call below passes !nested, so the functions in
		    * prefixes are not nested here.  Presumably part of the
		    * unfinished "return to F" work -- confirm before relying
		    * on prefixes of functions.
		    *)
		   val fs =
		      case !prefixes of
			 NONE => !nested
		       | SOME fs => List.appendRev (fs, !nested)
		   val {jump, args, body} = valOf (!replace)
		   val {decs, transfer} = Exp.dest (walkExp (f, body, cont))
		   val body = Exp.make {decs = nest (!nested, cont, decs),
					transfer = transfer}
		in Fun {name = jump, args = args, body = body} :: ds
		end)
	    val shrinkExp = shrinkExp globals
	    val functions =
	       List.fold
	       (functions, [], fn ({name, args, body, returns}, functions) =>
		let val {replace, prefixes, ...} = funcInfo name
		in case !replace of
		   NONE =>
		      let
			 val body = shrinkExp (walkExp (name, body, NONE))
			 val body =
			    case !prefixes of
			       NONE => body
			     | SOME fs =>
				  let val {decs, transfer} = Exp.dest body
				  in Exp.make {decs = nest (fs, NONE, decs),
					       transfer = transfer}
				  end
		      in {name = name, args = args, returns = returns,
			  body = body}
		      end :: functions
		 | _ => functions
		end)
	    val program =
	       Program.T {datatypes = datatypes,
			  globals = globals,
			  functions = functions,
			  main = main}
	    val _ = Program.clear program
	 in
	    program
	 end
   end

end