[MLton] improved overloading for SML

Stephen Weeks sweeks@sweeks.com
Mon, 17 Oct 2005 17:16:24 -0700


One problem with the implementation of Num that I sent earlier is that
one must read all of the code inside the Num structure to convince
oneself of the invariant that a value of type "x t" is represented by a
value of type x underneath (where x is either real or int).  Below is an
improved implementation that uses a functor to confine the necessary
reasoning to a few lines of reusable code.  The trick is to define a
typecase construct that can be used to implement all of the overloaded
functions.  One need only convince oneself that typeCase maintains the
invariant to believe that the code works.

Using a generic typecase functor also isolates the trick needed to
make MLton do the right thing, making it easier to create new
overloaded types and to extend existing ones.  It also makes clearer
the expressive power of this technique.

Too bad we don't have fold at the functor level; otherwise there would
be a way to avoid defining a separate TypeCase<N> functor for each N
in order to create overloadable N-element type families.

----------------------------------------------------------------------

(* TypeCase2 builds a two-way overloadable type family over x1 and x2.
 *
 * Invariant that clients rely on: a value of type x1 t always carries an
 * X1 payload, and a value of type x2 t always carries an X2 payload.
 * from1 and from2 establish it.  typeCase preserves it because the rebuild
 * function handed to each branch re-wraps with the same constructor and
 * the same 'a option component it unpacked.
 *)
functor TypeCase2 (type x1
                   type x2):
   sig
      type 'a t

      val from1: x1 -> x1 t
      val from2: x2 -> x2 t
      val typeCase:
         'a t
         * (x1 * (x1 -> 'a t) -> 'b)
         * (x2 * (x2 -> 'a t) -> 'b)
         -> 'b
   end =
   struct
      datatype x = X1 of x1 | X2 of x2

      (* The 'a option component is only ever NONE (see from1/from2), and
       * typeCase copies it unchanged.  It makes 'a appear in the
       * representation instead of being a pure phantom.  NOTE(review):
       * presumably this is the MLton-specific trick this functor isolates;
       * confirm before changing the representation. *)
      datatype 'a t = T of 'a option * x

      (* Injections that tag a raw value with its constructor. *)
      val from1: x1 -> x1 t = fn v => T (NONE, X1 v)
      val from2: x2 -> x2 t = fn v => T (NONE, X2 v)

      (* Select the branch that matches the payload's constructor.  The
       * branch receives the raw payload and a function that wraps a new
       * raw value back up with the same constructor and tag. *)
      fun typeCase (T (tag, X1 v), onX1, _) = onX1 (v, fn v' => T (tag, X1 v'))
        | typeCase (T (tag, X2 v), _, onX2) = onX2 (v, fn v' => T (tag, X2 v'))
   end
   
(* Interface to numbers overloaded over int and real.
 *
 * The type index 'a of 'a t records which underlying kind a number has:
 * fromInt produces an int t and fromReal produces a real t.  Operations
 * typed 'a t * 'a t accept either kind but require both operands to share
 * the same index, so mixing an int t with a real t is a type error.
 *)
signature NUM =
   sig
      type 'a t

      (* Comparisons between two numbers of the same kind. *)
      val < : 'a t * 'a t -> bool
      val <= : 'a t * 'a t -> bool
      val > : 'a t * 'a t -> bool
      val >= : 'a t * 'a t -> bool
      (* Kind-preserving arithmetic: the result has the operands' kind. *)
      val ~ : 'a t -> 'a t
      val + : 'a t * 'a t -> 'a t
      val - : 'a t * 'a t -> 'a t
      val * : 'a t * 'a t -> 'a t
      (* Division always yields a real t; the operands may differ in kind. *)
      val / : 'a t * 'b t -> real t
      val abs: 'a t -> 'a t
      (* Integer-only quotient; both operands must be int t. *)
      val div: int t * int t -> int t
      val e: real t
      (* Injections from the raw int and real types. *)
      val fromInt: int -> int t
      val fromReal: real -> real t
      val max: 'a t * 'a t -> 'a t
      val min: 'a t * 'a t -> 'a t
      (* Integer-only remainder; both operands must be int t. *)
      val mod: int t * int t -> int t
      val pi: real t
      (* Conversions: real and sqrt always yield a real t, round an int t. *)
      val real: 'a t -> real t
      val round: 'a t -> int t
      val sqrt: 'a t -> real t
      val toString: 'a t -> string
      (* Dispatch on the underlying kind of a number.  Each branch receives
       * the raw value together with a function that wraps a raw value of
       * the same kind back into an 'a t. *)
      val typeCase:
         'a t
         * (int * (int -> 'a t) -> 'b)
         * (real * (real -> 'a t) -> 'b)
         -> 'b
   end

(* Num implements NUM on top of TypeCase2 instantiated at int and real.
 * Every operation is written via typeCase, so its correctness depends only
 * on TypeCase2's invariant that an int t holds an int and a real t holds
 * a real. *)
structure Num:> NUM =
   struct
      structure Z = TypeCase2 (type x1 = int
                               type x2 = real)
      open Z

      val fromInt = from1
      val fromReal = from2

      val e = fromReal Real.Math.e
      val pi = fromReal Real.Math.pi

      (* Reached only if two operands that share the type index 'a carry
       * different representations, which the TypeCase2 invariant should
       * rule out.  The Fail message is kept exactly as before. *)
      fun impossible _ = raise Fail "bug"

      (* Eliminate a number into a result type common to both kinds. *)
      fun project (onInt, onReal) n =
         typeCase (n, fn (i, _) => onInt i, fn (r, _) => onReal r)

      (* Apply a unary operation that keeps the number's kind. *)
      fun lift (onInt, onReal) n =
         typeCase (n,
                   fn (i, rebuild) => rebuild (onInt i),
                   fn (r, rebuild) => rebuild (onReal r))

      (* Dispatch on a pair of numbers of the same kind.  Each handler gets
       * both raw values plus the rebuild function of the FIRST operand. *)
      fun both (onInts, onReals) (n1, n2) =
         typeCase
         (n1,
          fn (i1, mkI) =>
             typeCase (n2, fn (i2, _) => onInts (i1, i2, mkI), impossible),
          fn (r1, mkR) =>
             typeCase (n2, impossible, fn (r2, _) => onReals (r1, r2, mkR)))

      fun toString n = project (Int.toString, Real.toString) n

      (* Not exported: widen either kind to a raw real. *)
      fun toReal n = project (Real.fromInt, fn x => x) n

      fun real n = fromReal (toReal n)

      fun sqrt n = fromReal (Real.Math.sqrt (toReal n))

      (* An int is already integral; a real is rounded with Real.round. *)
      fun round n = project (fromInt, fn x => fromInt (Real.round x)) n

      (* Operands are widened to reals first, so their kinds may differ. *)
      fun x / y = fromReal (Real./ (toReal x, toReal y))

      fun abs n = lift (Int.abs, Real.abs) n
      fun ~ n = lift (Int.~, Real.~) n

      local
         (* Comparisons ignore the rebuild function. *)
         fun test (fi, fr) =
            both (fn (a, b, _) => fi (a, b), fn (a, b, _) => fr (a, b))
      in
         fun x < y = test (Int.<, Real.<) (x, y)
         fun x <= y = test (Int.<=, Real.<=) (x, y)
         fun x > y = test (Int.>, Real.>) (x, y)
         fun x >= y = test (Int.>=, Real.>=) (x, y)
      end

      local
         (* Kind-preserving binary operations re-wrap their raw result. *)
         fun arith (fi, fr) =
            both (fn (a, b, wrap) => wrap (fi (a, b)),
                  fn (a, b, wrap) => wrap (fr (a, b)))
      in
         fun x + y = arith (Int.+, Real.+) (x, y)
         fun x - y = arith (Int.-, Real.-) (x, y)
         fun x * y = arith (Int.*, Real.* ) (x, y)
         fun max (x, y) = arith (Int.max, Real.max) (x, y)
         fun min (x, y) = arith (Int.min, Real.min) (x, y)
      end

      local
         (* Int-only operations; NUM types div and mod at int t, so the
          * real branch should be unreachable. *)
         fun intOnly f = both (fn (a, b, wrap) => wrap (f (a, b)), impossible)
      in
         fun x div y = intOnly Int.div (x, y)
         fun x mod y = intOnly Int.mod (x, y)
      end
   end

(* Smoke test: exercise a NUM implementation and print one result per line. *)
functor Test (Num: NUM) =
   struct
      open Num

      val i = fromInt
      val r = fromReal

      (* Print a number followed by a newline. *)
      fun p n = print (toString n ^ "\n")

      (* Works at either kind because + is typed 'a t * 'a t -> 'a t. *)
      fun double x = x + x

      (* The five results, printed in order. *)
      val () =
         (p (i 1 + i 2)
          ; p (r 1.5 + r 2.5)
          ; p (round ((i 1 + i 2) / r 3.5))
          ; p (double (i 1))
          ; p (double (r 1.5 + pi)))
   end

structure Z = Test (Num)