[MLton] How to fix a family of picklers or to tie a product of knots

Vesa Karvonen vesa.karvonen@cs.helsinki.fi
Thu, 13 Oct 2005 02:37:16 +0300


I recently read Andrew Kennedy's functional pearl Pickler Combinators [1].
In its discussion section, the pearl notes that

  "[...] it is necessary to be explicit about recursion, using a fixpoint
   operator whose type is ('a PU -> 'a PU) -> 'a PU. This is somewhat
   cumbersome, especially with mutual recursion, for a family of fixpoint
   combinators fix_n are required, where n is the number of functions
   defined by mutual recursion."

Upon reading this, it struck me that instead of having a family of fixpoint
combinators, we could have combinators for building a fixpoint combinator
of any desired type. Concretely, instead of

  val (v1, ..., vN) = fixN (fn (v1, ..., vN) => ...)

we would write (using the varargs Fold notation [2] and the infix
product type [3])

  val v1 & ... & vN = fix T1 ... TN $ (fn v1 & ... & vN => ...)

The varargs combinators T1, ..., TN would be "tier" combinators used for
tying the knots. Each tier combinator may provide a tier for a different
type, so it is possible to express mutual recursion over multiple
different types.
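Each Ti is an ordinary varargs argument in the sense of [2]: a fold threads an accumulator through the arguments until $ applies a finisher to it, which is what lets fix accept any number of tiers. The following is a minimal sketch, assuming definitions of Fold.fold, Fold.step0 and $ in the style of [2] (reproduced so the example stands alone); count and I are hypothetical names used only for illustration.

```sml
(* Minimal Fold definitions in the style of [2], assumed here so that the
   example stands alone. *)
structure Fold =
   struct
      (* Begin a fold: pass the accumulator and finisher to the next argument. *)
      fun fold (a, f) g = g (a, f)
      (* A step updates the accumulator with h and continues the fold. *)
      fun step0 h (a, f) = fold (h a, f)
   end

(* $ ends the fold by applying the finisher to the final accumulator. *)
fun $ (a, f) = f a

(* Hypothetical example: count the I arguments between count and $. *)
fun count $ = Fold.fold (0, fn n => n) $
fun I $ = Fold.step0 (fn n => n + 1) $

(* prints 3 *)
val () = print (Int.toString (count I I I $) ^ "\n")
```

For fix, the accumulator plays the role of the combined tier state, and the finisher is the function that creates the slots, applies the functional, and ties the knots.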

I have implemented the idea and tried it on Andrew Kennedy's pickler
combinators as implemented in the SML.NET compiler [4], as well as on a
couple of other things. A copy of just the signature and the structure of
the FixProduct library is at the end of this message; it relies on the
Fold [2] and product type [3] definitions and on a few small utilities
such as thunk, effect and Product.map.

The fix combinators used in the SML.NET compiler are implemented using
mutually recursive functions:

  fun fix F =
  let
    fun p x = let val (p',_) = F (p,u) in p' x end
    and u x = let val (_,u') = F (p,u) in u' x end
  in
    (p,u)
  end
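For a concrete view of how this behaves, here is a self-contained copy of the fix above, applied to a hypothetical pair of mutually recursive functions isEven and isOdd. The counter is for illustration only; it shows that F is re-applied on every call of p or u rather than once.

```sml
(* The SML.NET-style fix for a pair of functions: p and u are mutually
   recursive and re-apply F each time either one is called. *)
fun fix F =
    let
       fun p x = let val (p', _) = F (p, u) in p' x end
       and u x = let val (_, u') = F (p, u) in u' x end
    in
       (p, u)
    end

(* Instrumentation (illustration only): count applications of the functional. *)
val applications = ref 0

(* Hypothetical example: mutually recursive even/odd as a pair of functions. *)
val (isEven, isOdd) =
    fix (fn (isEven, isOdd) =>
            (applications := !applications + 1 ;
             (fn n => n = 0 orelse isOdd (n - 1),
              fn n => n <> 0 andalso isEven (n - 1))))

(* prints "true after 11 applications" *)
val () = print (Bool.toString (isEven 10) ^ " after " ^
                Int.toString (!applications) ^ " applications\n")
```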

Such an implementation isn't supported by the FixProduct library, so the
recursion must be broken using ref cells:

  fun fix F =
      let
         val p = ref (fn _ => raise Fail "fix P")
         val u = ref (fn _ => raise Fail "fix U")
         val k as (p', u') = F (fn x => !p x, fn x => !u x)
      in
         p := p' ;
         u := u' ;
         k
      end
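A self-contained copy of the ref-cell fix, checked with a hypothetical isEven/isOdd pair and an illustration-only counter, shows that the functional is now applied exactly once, when the knot is tied.

```sml
(* The ref-cell fix: F is applied exactly once, and the knot is tied by
   assigning the resulting functions into the cells. *)
fun fix F =
    let
       val p = ref (fn _ => raise Fail "fix P")
       val u = ref (fn _ => raise Fail "fix U")
       val k as (p', u') = F (fn x => !p x, fn x => !u x)
    in
       p := p' ;
       u := u' ;
       k
    end

(* Instrumentation (illustration only): count applications of the functional. *)
val applications = ref 0

(* Hypothetical example: mutually recursive even/odd as a pair of functions. *)
val (isEven, isOdd) =
    fix (fn (isEven, isOdd) =>
            (applications := !applications + 1 ;
             (fn n => n = 0 orelse isOdd (n - 1),
              fn n => n <> 0 andalso isEven (n - 1))))

(* prints "true after 1 application(s)" *)
val () = print (Bool.toString (isEven 10) ^ " after " ^
                Int.toString (!applications) ^ " application(s)\n")
```

One consequence of this design: if F calls p or u eagerly, before returning, the forwarding functions still find the placeholders in the cells and raise Fail "fix P" or Fail "fix U".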

We can then break the process of tying the knot into separate steps and
use the FixProduct library to make a tier for pairs of functions:

  structure F = FixProduct

  fun PU $ =
      F.makeTier
         {new = fn () => (ref (fn _ => raise Fail "fix P"),
                          ref (fn _ => raise Fail "fix U")),
          knot = fn (p, u) => (fn x => !p x, fn x => !u x),
          tie = fn ((sp, su), (p, u)) => (sp := p ; su := u)} $
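The new/knot/tie protocol that makeTier expects can be illustrated without the varargs machinery. The following is a minimal sketch, not the FixProduct implementation: tiers are plain records with the same three fields, prodTier (a hypothetical name) combines two tiers much as successive tier arguments are combined, fixWith (also hypothetical) performs the create-slot, apply, tie sequence, and fnTier is a hypothetical tier for single functions analogous to PU. The product type is assumed to be defined as in [3]. Because the two knots have different types, the example also shows mutual recursion across types.

```sml
(* Infix product type, assumed to match [3]. *)
datatype ('a, 'b) product = & of 'a * 'b
infix &

(* A tier: new creates a slot, knot reads an untied (forwarding) knot out
   of the slot, and tie stores the finished value into the slot. *)
type ('s, 'k) tier = {new : unit -> 's,
                      knot : 's -> 'k,
                      tie : 's * 'k -> unit}

(* Hypothetical tier for single functions, analogous to the PU tier. *)
val fnTier : (('a -> 'b) ref, 'a -> 'b) tier =
    {new = fn () => ref (fn _ => raise Fail "fnTier: knot used before tie"),
     knot = fn r => fn x => !r x,
     tie = fn (r, f) => r := f}

(* Combine two tiers into a tier for the product of their knots. *)
fun prodTier (t1 : ('s1, 'k1) tier, t2 : ('s2, 'k2) tier)
    : (('s1, 's2) product, ('k1, 'k2) product) tier =
    {new = fn () => #new t1 () & #new t2 (),
     knot = fn s1 & s2 => #knot t1 s1 & #knot t2 s2,
     tie = fn (s1 & s2, k1 & k2) => (#tie t1 (s1, k1) ; #tie t2 (s2, k2))}

(* Compute a fixed point of f, tying the knot(s) described by tier t. *)
fun fixWith (t : ('s, 'k) tier) (f : 'k -> 'k) : 'k =
    let
       val slot = #new t ()
       val k = f (#knot t slot)
    in
       #tie t (slot, k) ;
       k
    end

(* Mutual recursion across two types: int -> bool and int -> string. *)
val isEven & parityName =
    fixWith (prodTier (fnTier, fnTier))
            (fn isEven & parityName =>
                (fn n => n = 0 orelse parityName (n - 1) = "odd") &
                (fn n => if isEven n then "even" else "odd"))

(* prints "false odd" *)
val () = print (Bool.toString (isEven 7) ^ " " ^ parityName 7 ^ "\n")
```

Three knots would use prodTier (prodTier (fnTier, fnTier), fnTier) together with a pattern a & b & c, since & associates to the left.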

Using the tier PU, it is now possible to compute fixpoints over arbitrary
products of pairs of functions. Instead of

  fix (fn x => ...),
  fix2 (fn (x, y) => ...), or
  fix3 (fn (x, y, z) => ...)

you would write

  fix PU $ (fn x => ...),
  fix PU PU $ (fn x & y => ...), or
  fix PU PU PU $ (fn x & y & z => ...)

where fix = FixProduct.fix and PU would come from the Pickle structure.

The use of ref cells instead of mutually recursive function definitions
also seems to bring a considerable performance improvement for pickler
combinators compiled with MLton. On my Centrino laptop, in a simple test
that generates a largish tree and then pickles and unpickles it, the
version using ref cells runs more than twice as fast as the original
version using mutually recursive functions. Presumably this is because
the mutually recursive version re-applies F on every call of p or u,
while the ref-cell version applies F only once.

(I can prepare a snapshot of the simple test code and full implementation
if anyone is interested.)

-Vesa Karvonen

[1] http://mlton.org/References#Kennedy04
[2] http://mlton.org/pipermail/mlton/2005-August/027907.html
[3] http://mlton.org/ProductType
[4] http://www.cl.cam.ac.uk/Research/TSG/SMLNET/download.html

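(* Assumed helper definitions, not part of the FixProduct code itself:
 * minimal stand-ins for the Fold library [2], the infix product type [3],
 * and the small utilities the code below relies on. *)
datatype ('a, 'b) product = & of 'a * 'b
infix &

structure Product =
   struct
      fun snd (_ & b) = b
      fun map (f & g) (a & b) = f a & g b
   end

structure Fold =
   struct
      fun fold (a, f) g = g (a, f)
      fun step0 h (a, f) = fold (h a, f)
   end

fun $ (a, f) = f a

type 'a thunk = unit -> 'a
type 'a effect = 'a -> unit
type 'a fix = ('a -> 'a) -> 'a
type ('a, 'b) va = ('a -> 'b) -> 'b

(**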
 * Extensible fixpoint combinator over products.
 *
 * In a strict language you sometimes want to provide a fixpoint
 * combinator for an abstract type {t} to make it possible to write
 * recursive definitions. Unfortunately, a single combinator {fix} of the
 * type {(t -> t) -> t} does not cover mutual recursion. To allow mutual
 * recursion, you need to provide a family of fixpoint combinators having
 * types of the form {(u -> u) -> u} where {u} is a type of the form {t *
 * ... * t}. Moreover, even such a family of fixpoint combinators
 * does not cover mutual recursion over different abstract types. This
 * library provides an extensible general purpose fixpoint combinator that
 * allows you to write recursive definitions of products.
 *)
signature FIX_PRODUCT =
   sig
      type ('kl, 'rl, 'kr, 'rr, 'k, 'r) fix_st
      (** The {fix_st} type hides the details of how knots are tied. *)

      val fix : ((unit, unit, 'a, 'b, 'a, 'b) fix_st
                 * (('c, 'd, 'e, 'f, 'g, 'h) fix_st -> 'd fix), 'i) va
      (**
       * {fix T1 ... TN $ f} computes a fixed point of {f} over a product
       * {v1 & ... & vN} using the tiers {T1 ... TN}.
       *)

      val makeTier : {new: 'a thunk,
                      knot: 'a -> 'b,
                      tie: ('a * 'b) effect}
                     -> ('c, 'd, 'a, 'b, 'e, 'f) fix_st * 'g
                     -> (('e, 'f, 'h, 'i, ('e, 'h) product, ('f, 'i) product) fix_st
                         * 'g, 'j) va
      (**
       * {makeTier {new = ..., knot = ..., tie = ...}} makes a new tier.
       *
       * - {new ()} is supposed to create a fresh "slot" for a "knot".
       * - {knot slot} is supposed to extract the (untied) knot out of the
       *   slot created by {new ()}.
       * - {tie (slot, knot)} is supposed to tie the knot, which is
       *   returned by the function whose fixed point is being computed,
       *   into the slot.
       *
       * Due to the value restriction, you will usually need to eta-expand
       * tier definitions:
       *
       *> fun T $ = makeTier {new = ...,
       *>                     knot = ...,
       *>                     tie = ...} $
       *)
   end

structure FixProduct :> FIX_PRODUCT =
   struct
      datatype ('kl, 'rl, 'kr, 'rr, 'k, 'r) fix_st =
               T of {new: 'kl thunk,
                     knot: 'kl -> 'rl,
                     tie: ('kl * 'rl) effect} *
                    {new: ('kl thunk, 'kr thunk) product -> 'k thunk,
                     knot: ('kl -> 'rl, 'kr -> 'rr) product -> 'k -> 'r,
                     tie: (('kl * 'rl) effect, ('kr * 'rr) effect) product -> ('k * 'r) effect}

      local
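         (* Before any tier has been added: combining the placeholder unit
          * operations with the first tier just yields that tier's
          * operations, so a single tier gives a bare knot. *)
         val noneMake =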
             {new = Product.snd,
              knot = Product.snd,
              tie = Product.snd}

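         (* After the first tier: pairs the accumulated operations with the
          * next tier's operations using {&}. *)
         val someMake =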
             {new = fn na & nb =>
                       fn () =>
                          na () & nb (),
              knot = Product.map,
              tie = fn ta & tb =>
                       fn (ka & kb, ra & rb) =>
                          (ta (ka, ra) : unit ; tb (kb, rb) : unit)}
      in
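         (* Start from the unit placeholder; the finisher creates the slots,
          * applies the function to the untied knots, and then ties them. *)
         fun fix $ =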
             Fold.fold
                (T ({new = fn () => (),
                     knot = fn () => (),
                     tie = fn ((), ()) => ()},
                    noneMake),
                 fn T ({new, knot, tie}, _) =>
                    fn functional =>
                       let
                          val slot = new ()
                          val knot = functional (knot slot)
                       in
                          tie (slot, knot) ;
                          knot
                       end) $

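         (* Each tier combines its operations with the accumulated ones and
          * switches to someMake for subsequent tiers. *)
         fun makeTier {new, knot, tie} =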
             Fold.step0
                (fn T (l, m) =>
                    T ({new = #new m (#new l & new),
                        knot = #knot m (#knot l & knot),
                        tie = #tie m (#tie l & tie)},
                       someMake))
      end
   end