[MLton] improved overloading for SML

Stephen Weeks MLton@mlton.org
Sun, 16 Oct 2005 15:56:43 -0700


This note describes an approach to overloading in SML that has the
following properties.

  * It can be expressed in SML'97 without any special front-end
    support (no _overload declarations, no overloading resolution).
  * Type checking catches as many errors as current SML overloading does.
  * Functions can be defined using overloaded operators and retain 
    the ability to be used at all overloaded types.  There is no
    "default" type for operators.
  * Overloading can be extended in different ways in different client
    code to support additional types.
  * With appropriate compiler optimizations (already present in
    MLton), there is no runtime space or time cost.  That is, all
    overloading is resolved at compile time and there is no runtime
    tagging or tag dispatch.

The only drawback of this approach w.r.t. the current ad-hoc approach
is that there is no automatic constant overloading, and so constants
require some additional syntax.  I personally don't find that very
debilitating, since most constants are named and not sprinkled
throughout code.

The idea of the approach is to use phantom subtyping, where the
supertype is a sum type of all of the overloaded types, and each
phantom subtype is one of the overloaded types.  I'll demonstrate the
approach below and build overloaded functions on integers and reals.
First, here's the signature, where "'a t" means integer or real, "int
t" means integer, and "real t" means real.

----------------------------------------------------------------------
signature NUM =
   sig
      type 'a t (* 'a is int or real *)

      val < : 'a t * 'a t -> bool
      val <= : 'a t * 'a t -> bool
      val > : 'a t * 'a t -> bool
      val >= : 'a t * 'a t -> bool
      val ~ : 'a t -> 'a t
      val + : 'a t * 'a t -> 'a t
      val - : 'a t * 'a t -> 'a t
      val * : 'a t * 'a t -> 'a t
      val / : 'a t * 'b t -> real t
      val abs: 'a t -> 'a t
      val div: int t * int t -> int t
      val e: real t
      val fromInt: int -> int t
      val fromReal: real -> real t
      val max: 'a t * 'a t -> 'a t
      val min: 'a t * 'a t -> 'a t
      val mod: int t * int t -> int t
      val pi: real t
      val real: 'a t -> real t
      val round: 'a t -> int t
      val sqrt: 'a t -> real t
      val toString: 'a t -> string
   end
----------------------------------------------------------------------

This expresses all the operators that are overloaded in SML, and with
as much precision.  It also offers more: for example, the / operator
accepts ints in addition to reals.  Even better, one can define new
functions that use the overloaded operators while retaining the
ability to work on both types.  And the SML type checker will even
infer the correct overloaded type for new functions.  Here's a simple
Test functor to show what I mean.

----------------------------------------------------------------------
functor Test (Num: NUM) =
   struct
      open Num
      val i = fromInt
      val r = fromReal
      fun p n = print (concat [toString n, "\n"])
      val () = p (i 1 + i 2)
      val () = p (r 1.5 + r 2.5)
      val () = p (round ((i 1 + i 2) / r 3.5))
      fun double x = x + x
      val () = p (double (i 1))
      val () = p (double (r 1.5 + pi))
   end
----------------------------------------------------------------------

The implementation of NUM is straightforward, with a sum type for the
two kinds of numbers and operators dispatching on the variants of the
sum.  There is a little extra complexity to trick MLton into doing
some code duplication, which I will explain later.

----------------------------------------------------------------------
structure Num:> NUM =
   struct
      datatype 'a t =
         I of 'a * int
       | R of 'a * real
         
      fun fromInt i = I (0, i)

      fun fromReal r = R (0.0, r)

      val e = fromReal Real.Math.e

      val pi = fromReal Real.Math.pi

      fun unary (fi, fr) =
         fn I (_, i) => fi i
          | R (_, r) => fr r

      val toString = fn $ => unary (Int.toString, Real.toString) $

      val toReal = fn $ => unary (Real.fromInt, fn r => r) $

      fun sqrt n = fromReal (Real.Math.sqrt (toReal n))

      val real = fn $ => (fromReal o toReal) $

      val round = fn $ => unary (fromInt, fromInt o Real.round) $

      local
         fun make (fi, fr) =
            fn I (z, i) => I (z, fi i)
             | R (z, r) => R (z, fr r)
      in
         val abs = fn $ => make (Int.abs, Real.abs) $
         val ~ = fn $ => make (Int.~, Real.~) $
      end

      fun bug _ = raise Fail "bug"

      local
         fun make (fi, fr) =
            fn (I (_, i1), I (_, i2)) => fi (i1, i2)
             | (R (_, r1), R (_, r2)) => fr (r1, r2)
             | _ => bug ()
      in
         val op < = fn $ => make (Int.<, Real.<) $
         val op <= = fn $ => make (Int.<=, Real.<=) $
         val op > = fn $ => make (Int.>, Real.>) $
         val op >= = fn $ => make (Int.>=, Real.>=) $
      end

      local
         fun make (fi, fr) =
            fn (I (z, i1), I (_, i2)) => I (z, fi (i1, i2))
             | (R (z, r1), R (_, r2)) => R (z, fr (r1, r2))
             | _ => bug ()
      in
         val op + = fn $ => make (Int.+, Real.+) $
         val op - = fn $ => make (Int.-, Real.-) $
         val op * = fn $ => make (Int.*, Real.* ) $
         val max = fn $ => make (Int.max, Real.max) $
         val min = fn $ => make (Int.min, Real.min) $
      end

      fun a / b = fromReal (Real./ (toReal a, toReal b))

      local
         fun make f =
            fn (I (_, i1), I (_, i2)) => fromInt (f (i1, i2))
             | _ => bug ()
      in
         val op div = fn $ => make Int.div $
         val op mod = fn $ => make Int.mod $
      end
   end
----------------------------------------------------------------------

Testing Num with "structure Z = Test (Num)" gives the expected
results.

----------------------------------------------------------------------
3
4
1
2
9.28318530718
----------------------------------------------------------------------
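
As a further illustration of how new functions keep their overloaded
types, here is a self-contained sketch.  MINI_NUM and MiniNum are
hypothetical cut-down stand-ins for NUM and Num, and two, square and
average are names made up for this example.  It also shows the
named-constant style that stands in for constant overloading.

```sml
(* MINI_NUM and MiniNum: hypothetical cut-down NUM, for this sketch only. *)
signature MINI_NUM =
   sig
      type 'a t
      val + : 'a t * 'a t -> 'a t
      val * : 'a t * 'a t -> 'a t
      val / : 'a t * 'b t -> real t
      val fromInt: int -> int t
      val fromReal: real -> real t
      val toString: 'a t -> string
   end

structure MiniNum:> MINI_NUM =
   struct
      datatype 'a t = I of 'a * int | R of 'a * real
      fun fromInt i = I (0, i)
      fun fromReal r = R (0.0, r)
      fun bug () = raise Fail "bug"
      fun toReal (I (_, i)) = Real.fromInt i
        | toReal (R (_, r)) = r
      fun toString (I (_, i)) = Int.toString i
        | toString (R (_, r)) = Real.toString r
      fun binary (fi, fr) =
         fn (I (z, i1), I (_, i2)) => I (z, fi (i1, i2))
          | (R (z, r1), R (_, r2)) => R (z, fr (r1, r2))
          | _ => bug ()
      val op + = fn $ => binary (Int.+, Real.+) $
      val op * = fn $ => binary (Int.*, Real.* ) $
      fun a / b = fromReal (Real./ (toReal a, toReal b))
   end

local
   open MiniNum
in
   (* Constants need explicit injection; naming them once keeps it tidy. *)
   val two = fromInt 2

   (* The annotations are optional; they only document the inferred types. *)
   fun square (x: 'a t): 'a t = x * x
   fun average (x: 'a t, y: 'a t): real t = (x + y) / two

   val squaredInt = toString (square (fromInt 7))
   val squaredReal = toString (square (fromReal 1.5))
   val avg = toString (average (fromInt 1, fromInt 2))
end
```

Both square and average are written once and used at int and at real.
An expression that mixes the two, such as square applied to
fromInt 1 + fromReal 2.0, is rejected by the type checker.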

To see how to add support for another type, suppose that you have a
program that deals with complex numbers as well.  You can extend the
operators to work there too by reusing the trick.  Define a new sum
type

  datatype 'a t = Num of 'a Num.t | Complex of 'a * Complex.t

and dispatching operators, like:

  val op + : 'a t * 'a t -> 'a t =
    fn (Num n, Num n') => Num (Num.+ (n, n'))
     | (Complex (z, c), Complex (_, c')) => Complex (z, Complex.+ (c, c'))
     | _ => bug ()

Then, within the scope of these definitions, + works on ints, reals,
and complex numbers.  Of course, one needs to coerce manually between
this world and the Int+Real world, but I can imagine this extension
approach still being quite useful.
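
A minimal sketch of this extension, under stated assumptions: MiniNum
is a hypothetical cut-down stand-in for Num, Complex is a toy structure
written for this example (pairs of reals, with only + and toString),
and Ext, fromComplex and the use of Complex.zero as the tag value are
likewise made up here.

```sml
(* MiniNum: hypothetical cut-down stand-in for Num. *)
structure MiniNum:>
   sig
      type 'a t
      val + : 'a t * 'a t -> 'a t
      val fromInt: int -> int t
      val fromReal: real -> real t
      val toString: 'a t -> string
   end =
   struct
      datatype 'a t = I of 'a * int | R of 'a * real
      fun fromInt i = I (0, i)
      fun fromReal r = R (0.0, r)
      fun toString (I (_, i)) = Int.toString i
        | toString (R (_, r)) = Real.toString r
      val op + =
         fn (I (z, i1), I (_, i2)) => I (z, Int.+ (i1, i2))
          | (R (z, r1), R (_, r2)) => R (z, Real.+ (r1, r2))
          | _ => raise Fail "bug"
   end

(* Complex: toy complex numbers as (re, im) pairs, assumed for this sketch. *)
structure Complex =
   struct
      type t = real * real
      val zero: t = (0.0, 0.0)
      val op + : t * t -> t =
         fn ((a, b), (c, d)) => (Real.+ (a, c), Real.+ (b, d))
      fun toString ((re, im): t) =
         concat [Real.toString re, " + ", Real.toString im, "i"]
   end

(* Ext: the extended sum type, with dispatching + and toString. *)
structure Ext:>
   sig
      type 'a t
      val + : 'a t * 'a t -> 'a t
      val fromInt: int -> int t
      val fromReal: real -> real t
      val fromComplex: Complex.t -> Complex.t t
      val toString: 'a t -> string
   end =
   struct
      datatype 'a t = Num of 'a MiniNum.t | Complex of 'a * Complex.t
      fun bug () = raise Fail "bug"
      (* The manual coercions from the Int+Real world into this one. *)
      fun fromInt i = Num (MiniNum.fromInt i)
      fun fromReal r = Num (MiniNum.fromReal r)
      (* The tag component plays the role of the bogus zero in Num. *)
      fun fromComplex c = Complex (Complex.zero, c)
      val op + =
         fn (Num n, Num n') => Num (MiniNum.+ (n, n'))
          | (Complex (z, c), Complex (_, c')) => Complex (z, Complex.+ (c, c'))
          | _ => bug ()
      fun toString (Num n) = MiniNum.toString n
        | toString (Complex (_, c)) = Complex.toString c
   end
```

With "open Ext" in scope, the same infix + then applies to values built
by fromInt, fromReal and fromComplex, and mixing them is a type error.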


To close, I'll explain how the optimizations already in place in MLton
will completely eliminate the Num.t sum type, as well as its I and R
variants, leaving only the raw int and real types.  The Num code is
written carefully to make this happen.

First off, MLton's monomorphisation pass will create different copies
of the Num types and functions, one for ints and one for reals
(actually there will be potentially four copies for functions like /
which have more than one type variable).  To make the monomorphisation
do the right thing, I've made sure that the type variable is used
within the Num.t type.  If it weren't, MLton's pass to eliminate
unused type variables would eliminate it, making the type monomorphic,
and then monomorphisation wouldn't duplicate anything.  So, although I
could have written the Num.t datatype as

  datatype 'a t = I of int | R of real

this would have, oddly enough, inhibited the optimizer.  An
alternative is to disable the pass in MLton that eliminates unused
type variables, by compiling with "-drop-pass xmlSimplifyTypes".  I
chose to include the 'a and the bogus zero values so that the program
will compile efficiently out of the box.
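
For comparison, here is a sketch of the simpler variant, in which 'a is
a pure phantom that does not occur in the constructor arguments.
PureNum is a hypothetical name, and the signature is cut down to a few
operations.  To the type checker this version is as safe as the one
above; the difference lies only in how the optimizer treats it.

```sml
structure PureNum:>
   sig
      type 'a t
      val + : 'a t * 'a t -> 'a t
      val fromInt: int -> int t
      val fromReal: real -> real t
      val toString: 'a t -> string
   end =
   struct
      (* 'a does not occur on the right-hand side: a pure phantom. *)
      datatype 'a t = I of int | R of real
      fun fromInt i = I i
      fun fromReal r = R r
      fun toString (I i) = Int.toString i
        | toString (R r) = Real.toString r
      val op + =
         fn (I i1, I i2) => I (Int.+ (i1, i2))
          | (R r1, R r2) => R (Real.+ (r1, r2))
          | _ => raise Fail "bug"
   end

(* Accepted: both arguments are int PureNum.t. *)
val three = PureNum.+ (PureNum.fromInt 1, PureNum.fromInt 2)

(* Rejected by the type checker if uncommented (int t versus real t):
   val bad = PureNum.+ (PureNum.fromInt 1, PureNum.fromReal 2.0)
*)
```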

In any case, once monomorphisation duplicates the datatype and
functions at both the int and real types, subsequent MLton ILs
will see two datatypes:

  datatype tI = IntI of int * int | RealI of int * real
  datatype tR = IntR of real * int | RealR of real * real

For each function, e.g. +, there will be two functions:

  val +I : tI * tI -> tI
  val +R : tR * tR -> tR

Then, MLton's useless-variant elimination will kick in.  This is an
optimization that eliminates a variant of a datatype if it can prove
that the variant is never constructed.  In this case, useless-variant
elimination will notice that the RealI and IntR variants are useless.
As humans, we can convince ourselves of that fact by studying the Num
implementation and observing that the "I" constructor is always used
with "I (0, ...)" and that the "R" constructor is always used with "R
(0.0, ...)".  Furthermore, the signature constraint ensures that
clients maintain this invariant.

MLton convinces itself that RealI is useless by observing that the
only uses of RealI as a constructor are on the right-hand side of
case branches that already match on a RealI value.  For example, look
at the definition of +I:

  val +I = fn (IntI (z, i1), IntI (_, i2)) => IntI (z, Int.+ (i1, i2))
            | (RealI (z, r1), RealI (_, r2)) => RealI (z, Real.+ (r1, r2))
            | _ => bug ()

Useless-variant elimination is smart enough to understand the
circularity, and, since there is no "base case" that creates a RealI
value, is smart enough to eliminate RealI entirely.

Once the RealI variant is eliminated, tI becomes:

  datatype tI = IntI of int * int

MLton's useless component analysis will then eliminate the first int,
which is always zero, leaving it free to represent tI as a raw int.
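
The three stages for the int copy can be written out by hand as
ordinary SML.  This is only an illustration of the steps described
above, not MLton's actual IL; the names plusI, tI2, tI3 and so on are
made up, since +I is not a legal SML identifier.

```sml
(* Stage 1: after monomorphisation at int (the real copy is symmetric). *)
datatype tI = IntI of int * int | RealI of int * real

fun fromIntI i = IntI (0, i)

fun plusI (IntI (z, i1), IntI (_, i2)) = IntI (z, Int.+ (i1, i2))
  | plusI (RealI (z, r1), RealI (_, r2)) = RealI (z, Real.+ (r1, r2))
  | plusI _ = raise Fail "bug"

(* Stage 2: RealI is only ever rebuilt from an existing RealI, and nothing
   creates one from scratch, so the variant and its branches are dropped. *)
datatype tI2 = IntI2 of int * int

fun fromIntI2 i = IntI2 (0, i)

fun plusI2 (IntI2 (z, i1), IntI2 (_, i2)) = IntI2 (z, Int.+ (i1, i2))

(* Stage 3: the first component is always 0 and is never inspected, so it
   is dropped as well; a one-variant datatype over int can be a raw int. *)
type tI3 = int

fun fromIntI3 (i: int): tI3 = i

val plusI3: tI3 * tI3 -> tI3 = Int.+
```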

Whew.

Provided that my analysis of MLton's optimizer is correct, the Num
implementation will *never* use the full sum type, and will always use
the raw int or raw real.  The phantom types and the type system
guarantee that clients of Num.t never mix "int Num.t" and "real
Num.t".  The only thing that can be done is to write overloaded functions
that deal with type 'a Num.t, and those will be specialized
appropriately.  I took a look at the ILs for a simple program and it
seems to bear my analysis out.