[MLton] recursive generics in SML

Stephen Weeks MLton@mlton.org
Sun, 23 Oct 2005 23:36:02 -0700


I've sent earlier mails showing how to do overloading in SML via
typecase, where the overloaded operators worked on a flat sum type,
such as int + real.  This note shows how to extend that approach to
inductively defined type families via a safe (I think) typeRec
function.  For example, suppose one wants to represent the type
family

  'b + 'b list + 'b list list + 'b list list list + ...

This approach represents the infinite sum with a single type

  ('a, 'b) t

For a particular value of type "(u, b) t", u is an index indicating
which summand the value is in and b indicates the base type.  The
summand index is represented by a family of types

  type u1
  type 'a u2

For example, a "string list list" is represented by the type

  (u1 u2 u2, string) t

As with the earlier non-recursive sum types, there are injections into
and projections from the sum.

  val from1: 'b -> (u1, 'b) t
  val from2: ('a, 'b) t list -> ('a u2, 'b) t
  val to1: (u1, 'b) t -> 'b
  val to2: ('a u2, 'b) t -> ('a, 'b) t list
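A minimal sketch of this representation, specialized to lists and using only the Basis library, may help.  The index types are phantom, so underneath every level shares one datatype, while opaque ascription keeps u1 and 'a u2 distinct for clients.  The structure Nested and its constructors Base and Nest are hypothetical names for illustration; the TypeRec functor in the code below is the general version.

```sml
(* Hypothetical sketch: the index 'a in ('a, 'b) t records the list depth. *)
structure Nested :>
   sig
      type u1
      type 'a u2
      type ('a, 'b) t
      val from1: 'b -> (u1, 'b) t
      val from2: ('a, 'b) t list -> ('a u2, 'b) t
      val to1: (u1, 'b) t -> 'b
      val to2: ('a u2, 'b) t -> ('a, 'b) t list
   end =
   struct
      (* One universal datatype underneath; the indices are phantom. *)
      datatype 'b rep = Base of 'b | Nest of 'b rep list
      type u1 = unit
      type 'a u2 = unit
      type ('a, 'b) t = 'b rep
      val from1 = Base
      val from2 = Nest
      (* Unreachable for values built with from1/from2, because the
         opaque index types keep u1 and 'a u2 apart. *)
      fun to1 (Base b) = b
        | to1 (Nest _) = raise Fail "to1: not a base value"
      fun to2 (Nest l) = l
        | to2 (Base _) = raise Fail "to2: not a nested value"
   end

(* A "string list list" lives at index u1 u2 u2. *)
val sll : (Nested.u1 Nested.u2 Nested.u2, string) Nested.t =
   Nested.from2 [Nested.from2 [Nested.from1 "a", Nested.from1 "b"],
                 Nested.from2 []]

(* Projecting level by level recovers an ordinary string list list.
   Nested.to1 sll would be rejected by the type checker. *)
val back : string list list =
   map (fn l => map Nested.to1 (Nested.to2 l)) (Nested.to2 sll)
```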

The fun part is the typeRec function, which allows one to define an
overloaded function that works over the entire sum type, via recursion
over the types.  For example, one can define a generic print
function as follows.

  val print: ('a, string) t -> unit =
     fn x => 
     typeRec`x $ (TextIO.print, fn (l, print) => List.foreach (l, print))

There is no direct recursion here.  After being supplied as many
values as desired, each introduced by `, and closed with $, typeRec
takes a pair of functions.  The first, in this case TextIO.print,
gives the behavior at the base type.  The second gives the behavior at
higher types; along with the value(s) at the current level, it
receives a function, provided by typeRec, for making recursive calls
at the next lower type.  The definition of typeRec uses the same
Fold01N technique as one of the earlier typeCase examples.
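The ` and $ tokens are the Fold idiom that typeRec is built on: a fold threads an accumulator through as many ` steps as the caller supplies, and $ applies the finishing function.  As a rough sketch of just that mechanism, mirroring the pass/step0/step1 definitions of the Fold structure in the code, here is a variadic integer sum.  The names sum and ` in this sketch are hypothetical and are not part of typeRec.

```sml
(* The Fold idiom, with the same helper definitions as the code below. *)
fun $ (a, f) = f a
fun curry h x y = h (x, y)
fun pass x f = f x

val fold = pass
fun step0 h (a, f) = fold (h a, f)
(* $ here is just a parameter name (the state), as in the original Fold. *)
fun step1 h $ x = step0 (curry h x) $

(* Hypothetical variadic sum: start at 0 and finish with the identity. *)
fun sum z = fold (0, fn s => s) z
(* Each ` consumes the next int and adds it to the accumulator. *)
fun ` z = step1 (fn (x : int, s) => s + x) z

val six = sum `1 `2 `3 $
val zero = sum $
```

typeRec works the same way, except that its accumulator collects the supplied values (together with helpers for splitting them in lockstep), and its finisher returns a function expecting the pair of base and higher-type behaviors.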

Although this looks like polymorphic recursion from the outside, SML
doesn't have polymorphic recursion, so underneath there is a universal
type (hence we don't get the nice code duplication that we got with
the flat sum type).  But that is completely hidden by the interface.
What one sees from the outside is the infinite family of types, along
with a way of defining certain (apparently) polymorphically recursive
functions.
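To make this concrete, here is a stripped-down, hypothetical typeRec1 that handles a single value and omits the Fold machinery.  The phantom index 'a appears only in client-visible types; the implementation is ordinary monomorphic recursion over one universal datatype.  As in the TYPE_REC signature below, z stands for "some lower level" in the type of the recursive helper.  The names NESTED_REC, NestedRec, typeRec1, concatAll, and count are assumptions made for this sketch.

```sml
(* Hypothetical single-value typeRec, specialized to lists. *)
signature NESTED_REC =
   sig
      type u1
      type 'a u2
      type z   (* "some lower level", used in the helper's type *)
      type ('a, 'b) t
      val from1: 'b -> (u1, 'b) t
      val from2: ('a, 'b) t list -> ('a u2, 'b) t
      val typeRec1:
         ('a, 'b) t
         -> ('b -> 'r) * ((z, 'b) t list * ((z, 'b) t -> 'r) -> 'r)
         -> 'r
   end

structure NestedRec :> NESTED_REC =
   struct
      datatype 'b rep = Base of 'b | Nest of 'b rep list
      type u1 = unit
      type 'a u2 = unit
      type z = unit
      type ('a, 'b) t = 'b rep
      val from1 = Base
      val from2 = Nest
      (* Plain recursion on the universal type; no polymorphic recursion. *)
      fun typeRec1 x (f1, f2) =
         let
            fun loop (Base b) = f1 b
              | loop (Nest l) = f2 (l, loop)
         in
            loop x
         end
   end

open NestedRec

(* Concatenate every string at every level, in order. *)
fun concatAll (x : ('a, string) t) : string =
   typeRec1 x (fn s => s, fn (l, recur) => String.concat (map recur l))

(* Count the base values at every level. *)
fun count (x : ('a, 'b) t) : int =
   typeRec1 x (fn _ => 1, fn (l, recur) => foldl (fn (y, n) => n + recur y) 0 l)

val v = from2 [from2 [from1 "hello, "], from2 [from1 "world"]]
val hello = concatAll v
val n = count v
```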

The code below also includes a couple of other examples from the
extensional polymorphism/GCaml papers: generic flatten and generic
equals.  I think the other examples that recur on arrow types are
pretty easy to emulate as well.  

There is one weakness of typeRec as given here: the same result type
must be returned for each member of the family.  I think this
restriction is present in GCaml as well.  It rules out some functions
that could be written with a more powerful typeRec (such as the one in
TIL), for example, reversing all the lists at every level within a value of type
('a, 'b) t.  The essential missing ingredient is the ability to
construct a value whose type depends (in the same way at every level)
on the level of the value currently being processed.  I thought that
perhaps the ideas about witnesses to type equivalences that I sent
earlier might help, but I couldn't figure out how to make that work
with the recursive helper provided by typeRec.

Here's the code.

----------------------------------------------------------------------

(* Wrappers giving the Basis List functions a tupled argument order. *)
structure List =
   struct
      fun fold (l, b, f) = List.foldl f b l
      fun foreach (l, f) = List.app f l
      fun map (l, f) = List.map f l
   end

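(* Small combinators: & pairs values (used to walk several values in
   lockstep), and $ ends a fold by applying the finishing function. *)
datatype ('a, 'b) product = & of 'a * 'b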
infix 4 &
fun $ (a, f) = f a
fun const c _ = c
fun curry h x y = h (x, y)
fun id x = x
fun ignore _ = ()
fun pass x f = f x

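(* Fold: threads an accumulator through any number of steps; $ then
   applies the finishing function to it. *)
structure Fold =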
   struct
      type ('a, 'b, 'c, 'd) step = 'a * ('b -> 'c) -> 'd
      type ('a, 'b, 'c, 'd) t = ('a, 'b, 'c, 'd) step -> 'd
      type ('a1, 'a2, 'b, 'c, 'd) step0 =
         ('a1, 'b, 'c, ('a2, 'b, 'c, 'd) t) step
      type ('a1, 'a2, 'a3, 'b, 'c, 'd) step1 =
         ('a2, 'b, 'c, 'a1 -> ('a3, 'b, 'c, 'd) t) step

      val fold = pass
      fun step0 h (a1, f) = fold (h a1, f)
      fun step1 h $ x = step0 (curry h x) $
   end

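(* Fold01N: a fold over any number of arguments.  The first is passed to
   one, each later one is merged in with combine, and finish is applied
   to the result at $. *)
structure Fold01N =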
   struct
      type ('a, 'b, 'c, 'd, 'e, 'f) t =
         ((unit -> unit) * ('a -> 'b), (unit -> 'c) * 'e, 'd, 'f) Fold.t
      type ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
         ('z1,
          'z2 * ('z1 -> 'a),
          (unit -> 'a) * ('b -> 'c),
          'z3, 'z4, 'z5) Fold.step1
          
      val fold: ('a -> 'b) * ('c -> 'd) -> ('a, 'b, 'c, 'd, 'e, 'f) t =
         fn (one, finish) =>
         Fold.fold ((ignore, one), fn (p, _) => finish (p ()))
         
      val step1
         : ('a * 'b -> 'c) -> ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
         fn combine =>
         Fold.step1 (fn (x, (_, f)) =>
                     (fn () => f x, fn x' => combine (f x, x')))
   end

signature TYPE_REC =
   sig
      type u1
      type 'a u2
      type 'a tr
      type z
      type ('a, 'b) t

      val from1: 'b -> (u1, 'b) t
      val from2: ('a, 'b) t tr -> ('a u2, 'b) t
      val to1: (u1, 'b) t -> 'b
      val to2: ('a u2, 'b) t -> ('a, 'b) t tr

      type ('a, 'b, 'c, 'd, 'e1, 'e2) split
      val typeRec:
         (('a, 'b) t,
          ('a, 'b) t
          * ('a, 'b, (z, 'b) t, ('a, 'b) t, 'b, (z, 'b) t tr) split,
          'd * ('a, 'b, 'c, 'd, 'e1, 'e2) split,
          (('e1 -> 'f) * ('e2 * ('c -> 'f) -> 'f)) -> 'f,
          'z1, 'z2) Fold01N.t
      val ` :
         ('d * ('a, 'b, 'c, 'd, 'e1, 'e2) split,
          ('a, 'b) t,
          ('d, ('a, 'b) t) product
          * ('a, 'b,
             ('c, (z, 'b) t) product,
             ('d, ('a, 'b) t) product,
             ('e1, 'b) product,
             ('e2, (z, 'b) t tr) product) split,
          'z1, 'z2, 'z3, 'z4, 'z5) Fold01N.step1
   end

functor TypeRec (type 'a t):> TYPE_REC where type 'a tr = 'a t =
   struct
      type 'a tr = 'a t

      datatype ('b1, 'b2) s =
          X1 of 'b1
        | X2 of 'b2

      (* The universal type: every level of the family is a 'b t. *)
      datatype 'b t = T of ('b, 'b t tr) s
         
      fun from1 b = T (X1 b)
      fun from2 b = T (X2 b)

      fun bug () = raise Fail "bug"

      val to1 = fn T (X1 b) => b | _ => bug ()
      val to2 = fn T (X2 b) => b | _ => bug ()
         
      type ('a, 'b, 'c, 'd, 'e1, 'e2) split =
         ('c -> 'd) * ('d -> ('e1, 'e2) s)

      fun typeRec $ =
         Fold01N.fold
         (fn x => (x, (id, fn T x => x)),
          fn (p, (cast, split)) => fn (f1, f2) =>
          let
             fun loop p =
                case split p of
                   X1 p => f1 p
                 | X2 p => f2 (p, loop o cast)
          in
             loop p
          end)
         $

      fun ` $ =
         Fold01N.step1
         (fn ((p, (cast, split)), x) =>
          (p & x,
           (fn p & x => cast p & x,
            fn p & T x =>
            case (split p, x) of
               (X1 p, X1 x) => X1 (p & x)
             | (X2 p, X2 x) => X2 (p & x)
             | _ => bug ())))
         $

      (* The indices are phantom; all levels share the universal 'b t. *)
      type ('a, 'b) t = 'b t
      type u1 = unit
      type 'a u2 = unit
      type z = unit
   end

structure Test =
   struct
      structure ListRec = TypeRec (type 'a t = 'a list)
      open ListRec

      val flatOnto: ('a, 'b) t -> 'b list -> 'b list =
         fn x =>
         typeRec`x $
         (fn b => fn ac => b :: ac,
          fn (l, flatOnto) => fn ac =>
          List.fold (rev l, ac, fn (x, ac) => flatOnto x ac))

      val flat: ('a, 'b) t -> 'b list = fn x => flatOnto x []

      val equals: ('a, 'b) t * ('a, 'b) t * ('b * 'b -> bool) -> bool =
         fn (x1, x2, equals) =>
         typeRec`x1`x2 $
         (fn y1 & y2 => equals (y1, y2),
          fn (l1 & l2, equals) =>
          let
             val rec loop =
                fn ([], []) => true
                 | (x1 :: l1, x2 :: l2) =>
                      equals (x1 & x2) andalso loop (l1, l2)
                 | _ => false
          in
             loop (l1, l2)
          end)

      val print: ('a, string) t -> unit =
         fn x =>
         typeRec`x $
         (TextIO.print, fn (l, print) => List.foreach (l, print))

      val S = from1
      val L = from2
      val x1 = S "s\n"
      val x2 = L [S "hello, ", S "world\n"]
      val x3 = L [L [], L [S "hello, "], L [S "world\n"]]
      val x4 = L [L [], L [S "hello, "], L [S "world"]]

      (* Prints "s", then "hello, world" three times, one per line. *)
      val () = (print x1; print x2; print x3)
      val () = List.foreach (flat x3, TextIO.print)

      fun test (x1, x2) =
         TextIO.print (concat [Bool.toString (equals (x1, x2, op =)), "\n"])

      (* Prints true, true, true, false: x3 and x4 differ only in
         "world\n" versus "world". *)
      val () = test (x1, x1)
      val () = test (x2, x2)
      val () = test (x3, x3)
      val () = test (x3, x4)
   end