[MLton] improved overloading for SML

Stephen Weeks MLton@mlton.org
Tue, 18 Oct 2005 17:27:50 -0700


There is an aspect of the latest Num implementation that I did not
like.  There were calls to the "bug" function, which raises an
exception, inside the implementation of Num.  These corresponded to
code that should be unreachable, where the unreachability was
guaranteed by the invariants enforced by the phantom types and
signature constraints.

This mail has an implementation of Num that fixes the problem.  All of
the potentially exception-raising code has been moved to the generic
TypeCase2 module, where it can be proved correct (i.e. that it doesn't
raise any exceptions) once and for all.  Furthermore, TypeCase2 has
been improved to be more general, by using varargs fold to allow
typecase over any number of arguments (of the same type).  This is not
just a syntactic convenience -- it is essential for implementing operators
like + that require both arguments to have the same type.  With typecase
only operating on a single argument, the code for + had to account for
the possibility of getting, say, a real and an int.  Now, with the
code below, the implementation of plus looks like

  fun n1 + n2 =
     typeCase`n1`n2 $
     (fn (i1 & i2, I) => I (Int.+ (i1, i2)),
      fn (r1 & r2, R) => R (Real.+ (r1, r2)))

The type system guarantees that + gets either two reals or two
integers.

The complete code is below.  It makes use of the Fold01N module that I
just explained in an earlier mail.

----------------------------------------------------------------------

(* Binary product whose constructor, &, is used infix (left-associative,
 * precedence 4), so a & b & c reads as (a & b) & c. *)
datatype ('a, 'b) product = & of 'a * 'b
infix 4 &

(* $ (a, f) applies f to a. *)
val $ = fn (a, f) => f a

(* const c is the function that discards its argument and returns c. *)
fun const c = fn _ => c

(* curry h turns a function on pairs into a curried two-argument function. *)
fun curry h = fn x => fn y => h (x, y)

(* Identity function. *)
val id = fn x => x

(* Discards its argument and returns unit. *)
val ignore = fn _ => ()

(* pass x f applies f to x (reverse application). *)
fun pass x = fn f => f x

(* Variable-argument fold.  A fold in progress holds a state pair
 * (accumulator, finisher) and hands it to whichever step is applied next. *)
structure Fold =
   struct
      (* A step receives the state pair and produces the continuation 'd. *)
      type ('a, 'b, 'c, 'd) step = 'a * ('b -> 'c) -> 'd
      (* A fold in progress: a function waiting for the next step. *)
      type ('a, 'b, 'c, 'd) t = ('a, 'b, 'c, 'd) step -> 'd
      (* A step with no extra argument; the accumulator goes from 'a1 to 'a2. *)
      type ('a1, 'a2, 'b, 'c, 'd) step0 =
         ('a1, 'b, 'c, ('a2, 'b, 'c, 'd) t) step
      (* A step that first consumes one extra argument of type 'a1. *)
      type ('a1, 'a2, 'a3, 'b, 'c, 'd) step1 =
         ('a2, 'b, 'c, 'a1 -> ('a3, 'b, 'c, 'd) t) step

      (* fold state step feeds the state pair to the step. *)
      fun fold state step = step state

      (* step0 h updates the accumulator with h and keeps the finisher. *)
      fun step0 h (acc, finish) = fold (h acc, finish)

      (* step1 h is like step0, except that h also receives the argument x
       * that follows the step. *)
      fun step1 h state x = step0 (fn acc => h (x, acc)) state
   end

(* A fold built on Fold whose state is a pair (current, inject):
 *   current ()  yields the value accumulated so far (unit before any
 *               argument has been supplied);
 *   inject x    turns the next argument x into the new accumulated value.
 * The first argument is injected unchanged; every later argument is merged
 * into the accumulated value with the combine function given to step1. *)
structure Fold01N =
   struct
      type ('a, 'b, 'c, 'd, 'e) t =
         ((unit -> unit) * ('c -> 'c), (unit -> 'a) * 'd, 'b, 'e) Fold.t
      type ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
         ('z1,
          'z2 * ('z1 -> 'a),
          (unit -> 'a) * ('b -> 'c),
          'z3, 'z4, 'z5) Fold.step1

      (* fold finish starts a fold; when the fold ends, finish is applied to
       * the accumulated value. *)
      val fold: ('a -> 'b) -> ('a, 'b, 'c, 'd, 'e) t =
         fn finish =>
         let
            fun conclude (current, _) = finish (current ())
         in
            Fold.fold ((ignore, id), conclude)
         end

      (* step1 combine is a step that consumes one argument x.  The new
       * accumulated value is inject x; a following argument, if any, is
       * merged in as combine (inject x, next). *)
      val step1
         : ('a * 'b -> 'c) -> ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
         fn combine =>
         let
            fun advance (x, (_, inject)) =
               (fn () => inject x,
                fn next => combine (inject x, next))
         in
            Fold.step1 advance
         end
   end

(* Interface for a two-alternative "typecase": values of type
 * ('a, 'b1, 'b2) t carry a payload of type 'b1 or 'b2, and the phantom
 * tag 'a (u1 or u2) records which alternative it is. *)
signature TYPE_CASE2 =
   sig
      (* Tag of values built by from1. *)
      type u1
      (* Tag of values built by from2. *)
      type u2
      (* 'a is the tag; 'b1 and 'b2 are the two possible payload types. *)
      type ('a, 'b1, 'b2) t

      (* Injections: the result tag is fixed by the injection used. *)
      val from1: 'b1 -> (u1, 'b1, 'b2) t
      val from2: 'b2 -> (u2, 'b1, 'b2) t
      (* Projections: only accept a value carrying the matching tag. *)
      val to1: (u1, 'b1, 'b2) t -> 'b1
      val to2: (u2, 'b1, 'b2) t -> 'b2
      (* Starts a Fold01N fold over t values.  The finished fold is a
       * function taking a pair of handlers, one per alternative.  Each
       * handler receives the payload together with a function that builds
       * a t with the same tag 'a from a new payload; both handlers must
       * return the same type 'e. *)
      val typeCase:
         (('a, 'c1, 'c2) t,
          (  ('c1 * ('b1 -> ('a, 'b1, 'b2) t) -> 'e)
           * ('c2 * ('b2 -> ('a, 'b1, 'b2) t) -> 'e)) -> 'e,
          'z1, 'z2, 'z3) Fold01N.t
      (* Fold step that supplies one more t value with the same tag 'a.
       * It is merged with the value accumulated so far into a t whose
       * payloads are products of the old and new payload types. *)
      val ` :
         (('a, 'c1, 'c2) t,
          ('a, 'b1, 'b2) t,
          ('a, ('c1, 'b1) product, ('c2, 'b2) product) t,
          'z1, 'z2, 'z3, 'z4, 'z5) Fold01N.step1
   end

(* Implementation of TYPE_CASE2.  All code that can raise lives here, guarded
 * by the representation invariant stated below. *)
structure TypeCase2:> TYPE_CASE2 =
   struct
      (* Invariant: every value is X1 (U1, _) or X2 (U2, _), i.e. the tag
       * component always agrees with the constructor. *)
      datatype u1 = U1
      datatype u2 = U2
      datatype ('a, 'b1, 'b2) t =
          X1 of 'a * 'b1
        | X2 of 'a * 'b2

      (* Signals that a case believed unreachable under the invariant was
       * reached. *)
      fun bug () = raise Fail "bug"

      fun from1 payload = X1 (U1, payload)
      fun from2 payload = X2 (U2, payload)

      fun to1 (X1 (_, payload)) = payload
        | to1 _ = bug ()
      fun to2 (X2 (_, payload)) = payload
        | to2 _ = bug ()

      (* The fold finishes by dispatching on the constructor of the
       * accumulated value: the matching handler gets the payload and a
       * function that re-wraps a new payload under the same constructor
       * and tag. *)
      fun typeCase step =
         let
            fun dispatch packed (onFirst, onSecond) =
               case packed of
                  X1 (tag, payload) =>
                     onFirst (payload, fn x => X1 (tag, x))
                | X2 (tag, payload) =>
                     onSecond (payload, fn x => X2 (tag, x))
         in
            Fold01N.fold dispatch step
         end

      (* Step that merges one more value into the accumulated one.  Both
       * must use the same constructor; their payloads are paired with &
       * and the tag of the accumulated value is kept. *)
      fun ` state =
         let
            fun pair (X1 (tag, p), X1 (_, x)) = X1 (tag, p & x)
              | pair (X2 (tag, p), X2 (_, x)) = X2 (tag, p & x)
              | pair _ = bug ()
         in
            Fold01N.step1 pair state
         end
   end
   
(* Interface of a number type covering both int and real, where a phantom
 * tag parameter records which of the two a value is. *)
signature NUM =
   sig
      (* 'a is the tag; TypeCase2.t is instantiated at int and real. *)
      type 'a t = ('a, int, real) TypeCase2.t
      (* Tag u1: the type fromInt produces and toInt accepts. *)
      type i = TypeCase2.u1 t
      (* Tag u2: the type fromReal produces and toReal accepts. *)
      type r = TypeCase2.u2 t

      (* Comparisons: both operands must carry the same tag 'a. *)
      val < : 'a t * 'a t -> bool
      val <= : 'a t * 'a t -> bool
      val > : 'a t * 'a t -> bool
      val >= : 'a t * 'a t -> bool
      (* Negation and arithmetic: same-tag operands, result keeps the tag. *)
      val ~ : 'a t -> 'a t
      val + : 'a t * 'a t -> 'a t
      val - : 'a t * 'a t -> 'a t
      val * : 'a t * 'a t -> 'a t
      (* Division accepts operands with independent tags and yields r. *)
      val / : 'a t * 'b t -> r
      val abs: 'a t -> 'a t
      (* div and mod are only available on i. *)
      val div: i * i -> i
      val e: r
      val fromInt: int -> i
      val fromReal: real -> r
      val max: 'a t * 'a t -> 'a t
      val min: 'a t * 'a t -> 'a t
      val mod: i * i -> i
      val pi: r
      (* Conversions that accept either tag and fix the result tag. *)
      val real: 'a t -> r
      val round: 'a t -> i
      val sqrt: 'a t -> r
      val toInt: i -> int
      val toReal: r -> real
      val toString: 'a t -> string
   end

(* Numbers that are either int or real.  Every operation is written with
 * TypeCase2.typeCase, so nothing in this structure raises by itself. *)
structure Num:> NUM =
   struct
      open TypeCase2

      type 'a t = ('a, int, real) t

      type i = u1 t
      type r = u2 t

      (* Injections and projections are TypeCase2's, at int and real. *)
      val fromInt: int -> i = from1
      val toInt: i -> int = to1

      val fromReal: real -> r = from2
      val toReal: r -> real = to2

      val pi = fromReal Real.Math.pi
      val e = fromReal Real.Math.e

      (* unary (onInt, onReal) n applies the handler matching n to the raw
       * payload; the re-wrapping function typeCase also supplies is not
       * needed here. *)
      fun unary (onInt, onReal) n =
         typeCase`n $
         (fn (k, _) => onInt k,
          fn (x, _) => onReal x)

      fun toString n = unary (Int.toString, Real.toString) n

      fun real n = fromReal (unary (Real.fromInt, fn x => x) n)

      fun sqrt n = fromReal (Real.Math.sqrt (toReal (real n)))

      fun round n = unary (fromInt, fn x => fromInt (Real.round x)) n

      local
         (* Apply the matching operation, then re-wrap the result with the
          * function supplied by typeCase so the tag is preserved. *)
         fun lift (onInt, onReal) n =
            typeCase`n $
            (fn (k, rebuild) => rebuild (onInt k),
             fn (x, rebuild) => rebuild (onReal x))
      in
         fun abs n = lift (Int.abs, Real.abs) n
         val ~ = fn n => lift (Int.~, Real.~) n
      end

      local
         (* Both operands share a tag, so the handlers receive a product of
          * two ints or of two reals. *)
         fun compare (onInt, onReal) (lhs, rhs) =
            typeCase`lhs`rhs $
            (fn (a & b, _) => onInt (a, b),
             fn (a & b, _) => onReal (a, b))
      in
         val op < = fn args => compare (Int.<, Real.<) args
         val op <= = fn args => compare (Int.<=, Real.<=) args
         val op > = fn args => compare (Int.>, Real.>) args
         val op >= = fn args => compare (Int.>=, Real.>=) args
      end

      local
         (* As for comparisons, but the result is re-wrapped under the
          * operands' tag. *)
         fun arith (onInt, onReal) (lhs, rhs) =
            typeCase`lhs`rhs $
            (fn (a & b, rebuild) => rebuild (onInt (a, b)),
             fn (a & b, rebuild) => rebuild (onReal (a, b)))
      in
         val op + = fn args => arith (Int.+, Real.+) args
         val op - = fn args => arith (Int.-, Real.-) args
         val op * = fn args => arith (Int.*, Real.* ) args
         fun max args = arith (Int.max, Real.max) args
         fun min args = arith (Int.min, Real.min) args
      end

      (* Division converts both operands to real first, so the operands may
       * carry different tags. *)
      fun a / b =
         let
            val numerator = toReal (real a)
            val denominator = toReal (real b)
         in
            fromReal (Real./ (numerator, denominator))
         end

      local
         (* Integer-only operations: unwrap, apply, wrap again. *)
         fun onInts f (lhs, rhs) = fromInt (f (toInt lhs, toInt rhs))
      in
         val op div = fn args => onInts Int.div args
         val op mod = fn args => onInts Int.mod args
      end
   end

(* Smoke test, parameterised over any NUM implementation.  Applying the
 * functor prints five results, one per line. *)
functor Test (Num: NUM) =
   struct
      open Num
      val (i, r) = (fromInt, fromReal)
      (* Print a number followed by a newline. *)
      fun p n = print (toString n ^ "\n")
      (* Usable at either tag, because + is polymorphic in the tag. *)
      fun double x = x + x
      val () =
         (p (i 1 + i 2)
          ; p (r 1.5 + r 2.5)
          ; p (round ((i 1 + i 2) / r 3.5))
          ; p (double (i 1))
          ; p (double (r 1.5 + pi)))
   end

(* Apply Test to Num.  Elaborating Z evaluates the functor body, which
 * runs its print statements. *)
structure Z = Test (Num)