[MLton] improved overloading for SML

Stephen Weeks sweeks@sweeks.com
Sat, 22 Oct 2005 14:16:14 -0700


Here's a slightly different approach to doing typeCase.  This approach
doesn't require any fold technology.  Rather, it uses a generally
applicable trick for exposing dataflow information down the branch of
a conditional via a "witness" value.  Here's a snippet of the
signature.

      structure Equiv:
         sig
            type ('a, 'b) t
         end
      type u1
      type u2
      type ('a, 'b1, 'b2) t (* 'a is u1 or u2 *)
      val cast: ('a, 'b1, 'b2) t * ('a, 'c) Equiv.t -> ('c, 'b1, 'b2) t
      val to1: (u1, 'b1, 'b2) t -> 'b1
      val typeCase:
         ('a, 'b1, 'b2) t
         * (('a, u1) Equiv.t -> 'c)
         * (('a, u2) Equiv.t -> 'c)
         -> 'c

Here, "('a, 'b1, 'b2) t" is either a value of type 'b1 or a value of
type 'b2.  The "'a" specifies which type the value is.  If the 'a is
u1 (a type constant) then the value is of type 'b1 and if 'a is u2
then the value is of type 'b2.  In "typeCase (v, f1, f2)", if v is of
type 'b1, we proceed down the f1 branch, supplying f1 with a witness
(of type ('a, u1) Equiv.t) that 'a is the same as u1.  Similarly for
the other branch.  Within the first branch, one can use the witness w
to extract the value of type 'b1 with "to1 (cast (v, w))".  Underneath
the signature, cast is just the identity function (which is easy to
prove), but outside it, the witnesses force us to always do type-safe
casts.

The complete code is below.  I've also shown the changes required to
implement Num using this new approach.

One reason I like this approach is the very simple reasoning needed to
convince oneself that Fail is never raised.

BTW, all the Equiv stuff could be completely phantom, except that I
want to force MLton to duplicate code.

--------------------------------------------------------------------------------
(* const c is a function that ignores its argument and always yields c. *)
val const = fn c => fn _ => c
(* The identity function. *)
val id = fn x => x

(* A two-way type case.  A value of type ('a, 'b1, 'b2) t holds either a
 * 'b1 (when the index 'a is u1) or a 'b2 (when 'a is u2).  typeCase
 * inspects such a value and hands the chosen branch a witness that 'a
 * equals u1 or u2; cast, to1 and to2 use that witness to get at the
 * payload in a type-safe way.
 *)
signature TYPE_CASE2 =
   sig
      (* ('a, 'b) Equiv.t is evidence that 'a and 'b are the same type.
       * The type is abstract, so clients can only obtain witnesses from
       * the combinators below or from typeCase. *)
      structure Equiv:
         sig
            type ('a, 'b) t

            val reflexive: ('a, 'a) t
            val symmetric: ('a, 'b) t -> ('b, 'a) t
            val transitive: ('a, 'b) t * ('b, 'c) t -> ('a, 'c) t
         end

      (* Type-level tags: u1 selects the 'b1 case, u2 the 'b2 case. *)
      type u1
      type u2
      type ('a, 'b1, 'b2) t (* 'a is u1 or u2 *)

      (* Re-index a value using a witness that 'a equals 'c. *)
      val cast: ('a, 'b1, 'b2) t * ('a, 'c) Equiv.t -> ('c, 'b1, 'b2) t
      (* Injections: build a value of the first or second case. *)
      val from1: 'b1 -> (u1, 'b1, 'b2) t
      val from2: 'b2 -> (u2, 'b1, 'b2) t
      (* Projections: only applicable once the index is statically known,
       * e.g. after cast with a witness supplied by typeCase. *)
      val to1: (u1, 'b1, 'b2) t -> 'b1
      val to2: (u2, 'b1, 'b2) t -> 'b2
      (* typeCase (v, f1, f2) calls f1 with a witness that 'a is u1 when v
       * holds a 'b1, and otherwise calls f2 with a witness that 'a is u2.
       * Both branches return the same type 'c. *)
      val typeCase:
         ('a, 'b1, 'b2) t
         * (('a, u1) Equiv.t -> 'c)
         * (('a, u2) Equiv.t -> 'c)
         -> 'c
    end

(* Implementation of TYPE_CASE2.  Each value carries a runtime tag (U1 or
 * U2) whose type is the phantom index exposed by the signature, plus the
 * payload in an ordinary two-constructor sum.
 *)
structure TypeCase2:> TYPE_CASE2 =
   struct
      (* Singleton tag types; their only values mark which case is held. *)
      datatype u1 = U1
      datatype u2 = U2

      (* The untagged sum that actually stores the payload. *)
      datatype ('b1, 'b2) x = X1 of 'b1 | X2 of 'b2

      (* A tag of type 'a paired with its payload. *)
      datatype ('a, 'b1, 'b2) t = T of 'a * ('b1, 'b2) x

      structure Equiv =
         struct
            (* A type-equality witness, kept as a pair of conversions in
             * both directions rather than as a pure phantom. *)
            datatype ('a, 'b) t = T of ('a -> 'b) * ('b -> 'a)

            val reflexive = T (id, id)

            fun symmetric (T (forward, backward)) = T (backward, forward)

            fun transitive (T (ab, ba), T (bc, cb)) =
               T (fn a => bc (ab a), fn c => ba (cb c))
         end

      (* Replace the tag using the witness's forward conversion; the
       * payload is carried over untouched. *)
      fun cast (T (tag, payload), Equiv.T (coerce, _)) =
         T (coerce tag, payload)

      fun from1 b1 = T (U1, X1 b1)
      fun from2 b2 = T (U2, X2 b2)

      (* Only reachable if a tag and its payload were ever mismatched. *)
      fun bug () = raise Fail "bug"

      fun to1 (T (_, X1 b1)) = b1
        | to1 _ = bug ()

      fun to2 (T (_, X2 b2)) = b2
        | to2 _ = bug ()

      fun typeCase (T (tag, payload), ifFirst, ifSecond) =
         let
            (* Witness relating the value's tag type to the type of u. *)
            fun witness u = Equiv.T (const u, const tag)
         in
            case payload of
               X1 _ => ifFirst (witness U1)
             | X2 _ => ifSecond (witness U2)
         end
   end

(* Numbers that are either ints or reals, with the choice tracked in the
 * phantom index: i = TypeCase2.u1 t holds an int and r = TypeCase2.u2 t
 * holds a real.  Operations whose arguments share the index 'a work on
 * both kinds, but only accept arguments of the same kind.
 *)
signature NUM =
   sig
      type 'a t = ('a, int, real) TypeCase2.t
      type i = TypeCase2.u1 t
      type r = TypeCase2.u2 t

      (* Comparisons, for two ints or two reals. *)
      val < : 'a t * 'a t -> bool
      val <= : 'a t * 'a t -> bool
      val > : 'a t * 'a t -> bool
      val >= : 'a t * 'a t -> bool
      (* Arithmetic whose result has the same kind as its arguments. *)
      val ~ : 'a t -> 'a t
      val + : 'a t * 'a t -> 'a t
      val - : 'a t * 'a t -> 'a t
      val * : 'a t * 'a t -> 'a t
      (* Division always yields a real; each argument may independently be
       * an int or a real. *)
      val / : 'a t * 'b t -> r
      val abs: 'a t -> 'a t
      (* Integer-only division and remainder. *)
      val div: i * i -> i
      val e: r
      (* Conversions between Num values and the underlying int and real. *)
      val fromInt: int -> i
      val fromReal: real -> r
      val max: 'a t * 'a t -> 'a t
      val min: 'a t * 'a t -> 'a t
      val mod: i * i -> i
      val pi: r
      (* Convert either kind to a real. *)
      val real: 'a t -> r
      (* Convert either kind to an int. *)
      val round: 'a t -> i
      val sqrt: 'a t -> r
      val toInt: i -> int
      val toReal: r -> real
      val toString: 'a t -> string
   end

(* Num implements NUM on top of TypeCase2: an 'a t carries either an int
 * (index u1) or a real (index u2).  Operations that work on both kinds
 * dispatch with typeCase and use the resulting witness to cast their
 * arguments to a known representation.  When they return a number, the
 * result is cast back to index 'a.
 *)
structure Num:> NUM =
   struct
      open TypeCase2

      type 'a t = ('a, int, real) t

      type i = u1 t
      type r = u2 t

      val fromInt: int -> i = from1

      val fromReal: real -> r = from2

      val toInt: i -> int = to1

      val toReal: r -> real = to2

      val e = fromReal Real.Math.e

      val pi = fromReal Real.Math.pi

      (* unary (fi, fr) n applies fi to the int held in n, or fr to the
       * real held in n.  fi and fr must have the same result type. *)
      fun unary (fi, fr) n =
         let
            fun one (f, to) e n = f (to (cast (n, e)))
         in
            typeCase (n, one (fi, toInt), one (fr, toReal)) n
         end

      val toString = fn $ => unary (Int.toString, Real.toString) $

      (* Ints are converted with Real.fromInt; reals are returned as is. *)
      fun real n = fromReal (unary (Real.fromInt, fn r => r) n)

      fun sqrt n = fromReal (Real.Math.sqrt (toReal (real n)))

      (* Ints are returned unchanged; reals are rounded with Real.round. *)
      val round = fn $ => unary (fromInt, fromInt o Real.round) $

      local
         (* Lifts an int/real operation pair to 'a t -> 'a t.  After
          * computing on the concrete value, the result is cast back from
          * u1 or u2 to 'a with the symmetric witness. *)
         fun make (fi, fr) n =
            let
               fun one (f, from, to) e =
                  cast (from (f (to (cast (n, e)))), Equiv.symmetric e)
            in
               typeCase (n,
                         one (fi, fromInt, toInt),
                         one (fr, fromReal, toReal))
            end
      in
         val abs = fn $ => make (Int.abs, Real.abs) $
         val ~ = fn $ => make (Int.~, Real.~) $
      end

      local
         (* Lifts an int/real predicate pair to 'a t * 'a t -> bool.
          * Dispatching on n1 alone suffices: n2 has the same index 'a, so
          * the witness obtained for n1 also casts n2.  The result is a
          * plain bool, so no injection back into 'a t is needed. *)
         fun make (fi, fr) (n1, n2) =
            let
               fun one (f, to) e = f (to (cast (n1, e)), to (cast (n2, e)))
            in
               typeCase (n1, one (fi, toInt), one (fr, toReal))
            end
      in
         val op < = fn $ => make (Int.<, Real.<) $
         val op <= = fn $ => make (Int.<=, Real.<=) $
         val op > = fn $ => make (Int.>, Real.>) $
         val op >= = fn $ => make (Int.>=, Real.>=) $
      end

      local
         (* Lifts an int/real binary operation pair to
          * 'a t * 'a t -> 'a t.  As with the predicates, the witness for
          * n1 also casts n2.  The result is cast back to index 'a with the
          * symmetric witness. *)
         fun make (fi, fr) (n1, n2) =
            let
               fun one (f, from, to) e =
                  cast (from (f (to (cast (n1, e)), to (cast (n2, e)))),
                        Equiv.symmetric e)
            in
               typeCase (n1,
                         one (fi, fromInt, toInt),
                         one (fr, fromReal, toReal))
            end
      in
         val op + = fn $ => make (Int.+, Real.+) $
         val op - = fn $ => make (Int.-, Real.-) $
         val op * = fn $ => make (Int.*, Real.* ) $
         val max = fn $ => make (Int.max, Real.max) $
         val min = fn $ => make (Int.min, Real.min) $
      end

      (* Both operands are first converted to reals. *)
      fun a / b = fromReal (Real./ (toReal (real a), toReal (real b)))

      local
         (* Lifts an int-only binary operation to i * i -> i. *)
         fun make f (n1, n2) = fromInt (f (toInt n1, toInt n2))
      in
         val op div = fn $ => make Int.div $
         val op mod = fn $ => make Int.mod $
      end
   end

(* Smoke test: applies several NUM operations to int and real values and
 * prints each result on its own line.
 *)
functor Test (Num: NUM) =
   struct
      open Num

      val i = fromInt
      val r = fromReal

      (* Print a number followed by a newline. *)
      fun p n = print (toString n ^ "\n")

      (* Works on ints and reals alike, since + is overloaded on 'a t. *)
      fun double x = x + x

      val () =
         (p (i 1 + i 2)
          ; p (r 1.5 + r 2.5)
          ; p (round ((i 1 + i 2) / r 3.5))
          ; p (double (i 1))
          ; p (double (r 1.5 + pi)))
   end

(* Instantiate the smoke test with Num; the prints inside Test are
 * performed when this functor application is evaluated. *)
structure Z = Test (Num)