[MLton-user] Unified Argument Fold (OptionalArg + Literal + Num)

Vesa Karvonen vesa.a.j.k at gmail.com
Wed Jan 30 04:07:32 PST 2008


Stephen was right and I was wrong :-).  But since you have no idea what
I'm talking about, let me start with a historical note.  Before ICFP'06,
Matthew Fluet urged Stephen Weeks and me to write a paper about the Fold
technique (http://mlton.org/Fold).  I spent about a week trying to write
such a paper (based on the Fold pages that Stephen had already written),
but ultimately that didn't happen.

Instead of writing a paper, I came up with many small gradual improvements
to the techniques and examples.  One of the improvements was realizing
that the array literal (http://mlton.org/ArrayLiteral) and numeric literal
(http://mlton.org/NumericLiteral) examples could be written in terms of a
VarArg module for writing variable argument functions.  The advantage, of
course, is that instead of having two incompatible specifiers, Literal.`
and Num.`, there is just one specifier, VarArg.`, which could be exposed
at the top-level for convenience.
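
As a rough illustration of the variable argument idea, here is a minimal
sketch that uses only the core fold and step1 definitions (the same
definitions as in the Fold structure at the end of this message).  The
names sum and ` are local to the sketch and assume integer arguments; the
real VarArg module is more general.

```sml
(* Terminator: apply the finisher f to the final state a. *)
fun $ (a, f) = f a

(* A fold hands its state and finisher to the next step (or to $). *)
fun fold (a, f) g = g (a, f)

(* A step taking one argument b: update the state and keep folding. *)
fun step1 h (a, f) b = fold (h (b, a), f)

(* Hypothetical variadic sum: the state starts at 0 and the finisher
   is the identity function. *)
fun sum ? = fold (0, fn s => s) ?

(* The specifier: each use adds its argument to the running total. *)
fun ` ? = step1 (op +) ?

val total = sum `1 `2 `3 $
```

Because every step is just function application, the number of arguments
is determined by how many times the specifier is applied before $.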

Another improvement was writing the modules FoldAlt and FoldPair for
combining folds: FoldAlt builds a sum of folds (fold either this or that)
and FoldPair a product of folds (fold both at the same time).  Using
FoldAlt, a combined format fold could be written (trivially) in terms of
the previously defined scanf and printf folds.  It would also be nice if
we could combine VarArg with the optional argument fold
(http://mlton.org/OptionalArguments).  I tried to write such a thing using
FoldAlt, but it didn't quite work.  The problem was that VarArg requires
all arguments to have the same type, while OptArg allows arguments of
different types.  Written in terms of an ordinary sum datatype, FoldAlt
was simply not flexible enough: the combined fold required all arguments
to be of the same type, as in VarArg.  To me it seemed like a fundamental
restriction, but Stephen wasn't convinced.

Yet another improvement was an arguably simpler implementation and design
(smaller type expressions, fewer tokens to implement, and perhaps easier
to use) for the Fold01N fold (http://mlton.org/Fold01N).  The difficulty
with Fold01N is that it needs to distinguish between having been given no
steps and having been given one or more steps.  The idea was to use a
selector, either #1 or #2, as part of the folded state (see the NSZ
substructure in
http://mlton.org/cgi-bin/viewsvn.cgi/mltonlib/trunk/com/ssh/extended-basis/unstable/detail/fold/fold.sml?rev=6320&view=auto).
This allows the next step or the finisher to flexibly choose what to do.
Stephen then realized that this could be generalized to what ultimately
became StaticSum and was used by Stephen in his Basic library
(http://mlton.org/cgi-bin/viewsvn.cgi/mltonlib/trunk/com/sweeks/basic/unstable/).
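
The selector idea can be sketched in isolation as follows.  This is not
the NSZ code itself: the names first, second and finish are hypothetical,
and which selector stands for which case is an arbitrary choice made for
the sketch.

```sml
(* Two selectors over the same pair; the folded state would carry one
   of them.  Here first stands for "no steps given" and second for
   "one or more steps given" (an assumption of this sketch). *)
fun first (x, _) = x
fun second (_, y) = y

(* A hypothetical finisher: it offers both outcomes and lets the
   selector pick.  The result type follows the selector, so the two
   cases are free to have different types. *)
fun finish select = select ("no steps", 42)

val whenEmpty : string = finish first
val whenNonEmpty : int = finish second
```

The point is that the choice is made by an ordinary function carried in
the state, so the type checker sees which case applies.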

Fast forward to the present.

Recently, I added StaticSum to the extended basis, because I wanted to use
it to implement the monad pick notation described on the
http://mlton.org/StaticSum page.  A few days ago it occurred to me that
static sums might just be what is needed to combine folds more flexibly
than with ordinary sums.  So, I dug up my old fold sandbox and rewrote
FoldAlt using static sums.  Then I wrote an Arg module that combines both
VarArg and OptArg folds.  And... it worked!
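
To make the difference concrete, here is a small sketch contrasting an
ordinary sum with a static sum.  The inL, inR and match definitions
follow the StaticSum structure in the code below (without its opaque
signature); the handlers and the val names are made up for illustration.

```sml
(* Ordinary sum: both handlers must return the same type. *)
datatype ('a, 'b) alt = INL of 'a | INR of 'b
fun alt (f, g) = fn INL x => f x | INR y => g y

val viaAlt : int =
    alt (String.size, fn b => if b then 1 else 0) (INL "abc")

(* Static sum: a value is a function that takes both handlers and
   applies exactly one of them. *)
fun inL x (f, _) = f x
fun inR x (_, g) = g x
fun match s (f, g) = s (f, g)

(* Hypothetical handlers with different result types. *)
val handlers = (String.size, fn b => if b then "yes" else "no")

(* The result type depends on which injection was used. *)
val fromLeft : int = match (inL "abc") handlers
val fromRight : string = match (inR true) handlers
```

With the ordinary datatype the second handler had to be bent to return
an int; with the static sum each side keeps its own result type, which is
the flexibility the combined fold needs.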

Full code is included below.  (Sorry, I don't have time to strip it down
to the minimum.)  Here is a simple example.  First, a simple variable
argument sum function:

fun sum ? = Arg.foldl op + 0 ?

Then a toy function with optional arguments:

fun foo ? = Arg.opt (D #"1") (D 2) (D 3) (D "four") $
  (fn a & b & c & d => (a, b, c, d)) ?

Now, both can be called by supplying arguments with the same specifier, `:

val 3 = sum `1 `2 $

val (#"1", 2, 3, "four") = foo $
val (#"-", 2, 3, "four") = foo ` #"-" $

Also, partial argument "lists" can be shared:

fun four_five ? = pass ? `4 `5

val 9 = sum four_five $

val (#"3", 4, 5, "six") = foo ` #"3" four_five `"six" $
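
Why sharing works can be seen in a stripped-down setting.  The sketch
below assumes a plain integer sum built on the core fold and step1
definitions instead of the Arg module, so sum and ` here are simplified
stand-ins; four_five is written exactly as above.

```sml
fun $ (a, f) = f a
fun pass x f = f x
fun fold (a, f) g = g (a, f)
fun step1 h (a, f) b = fold (h (b, a), f)

(* Simplified stand-ins for Arg.foldl op + 0 and Arg.` *)
fun sum ? = fold (0, fn s => s) ?
fun ` ? = step1 (op +) ?

(* four_five receives the current fold state, like any other step,
   and simply continues the fold with two more arguments. *)
fun four_five ? = pass ? `4 `5

val nine = sum four_five $
val fifteen = sum `1 `2 `3 four_five $
```

Since a partial argument "list" has the same shape as a step, it can
appear anywhere a specifier could.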

Finally, as promised, numeric and array literals also work with the same
specifiers:

val anArray = ArrayLiteral.array `1`2`3 four_five`6 $
val 0x12345 = Num.I 16 `1`2`3 four_five $
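
The arithmetic behind the numeric literal is an ordinary left fold over
the digits, most significant digit first.  A sketch of just that step
over a plain list, where fromDigits is a hypothetical helper and not part
of the code below:

```sml
(* Accumulate digits in the given base, most significant first.
   Digits outside 0 .. base-1 are rejected, as in Num.make. *)
fun fromDigits base digits =
    List.foldl
       (fn (d, acc) =>
           if 0 <= d andalso d < base then
              d + acc * base
           else
              raise Fail ("invalid digit " ^ Int.toString d))
       0
       digits

val hex = fromDigits 16 [1, 2, 3, 4, 5]
val dec = fromDigits 10 [4, 2]
```

Num.make does the same accumulation, except that the digits arrive one
at a time through the ` specifier rather than as a list.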

Well, that's about it.

-Vesa Karvonen

type 'a uop = 'a -> 'a
type 'a predicate = 'a -> bool
type 'a thunk = unit -> 'a
type 'a effect = 'a -> unit

infixr 2 |<
fun f |< x = f x

infix   4  <\ \>
infixr  4  </ />
fun op<\ (x, f) y = f (x, y)
fun op\> (f, y) = f y

fun op/> (f, y) x = f (x, y)
fun op</ (x, f) = f x

fun id x = x
fun const x _ = x
fun pass x f = f x
fun fst (x, _) = x
fun snd (_, y) = y
fun swap (x, y) = (y, x)
fun fail m = raise Fail m
fun failing m _ = raise Fail m
fun cross (f, g) (l, r) = (f l, g r)
fun curry f x y = f (x, y)
fun uncurry f (x, y) = f x y

fun opt (f, g) = fn NONE => f () | SOME x => g x
fun isNone ? = not o isSome |< ?

datatype ('a, 'b) alt = INL of 'a | INR of 'b
val outL = fn INL x => x | _ => raise Match
val outR = fn INR x => x | _ => raise Match
fun alt (f, g) = fn INL x => f x | INR y => g y
fun pipe (f, g) = alt (INL o f, INR o g)

infix &
datatype ('a, 'b) product = & of 'a * 'b

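(* Static sums: a value is encoded as a function that receives a handler
 * for each side and applies exactly one of them. *)
structure StaticSum :> sig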
   type ('dL, 'cL, 'dR, 'cR, 'c) t
   val inL : 'dL -> ('dL, 'cL, 'dR, 'cR, 'cL) t
   val inR : 'dR -> ('dL, 'cL, 'dR, 'cR, 'cR) t
   val match : ('dL, 'cL, 'dR, 'cR, 'c) t -> ('dL -> 'cL) * ('dR -> 'cR) -> 'c
   val sum : ('dL -> 'cL) * ('dR -> 'cR) -> ('dL, 'cL, 'dR, 'cR, 'c) t -> 'c
   val map : ('a -> 'b) * ('c -> 'd) ->
             ('a, ('b, 'e, 'f, 'g, 'e) t,
              'c, ('h, 'i, 'd, 'j, 'j) t, 'k) t -> 'k
   val out : ('a, 'a, 'b, 'b, 'c) t -> 'c
end = struct
   type ('dL, 'cL, 'dR, 'cR, 'c) t = ('dL -> 'cL) * ('dR -> 'cR) -> 'c
   fun inL x (f, _) = f x
   fun inR x (_, g) = g x
   fun map (f, g) = pass (inL o f, inR o g)
   val match = id
   val sum = pass
   fun out s = s (id, id)
end

structure Fold = struct
   type ('a, 'b, 'c, 'd) step =
        'a * ('b -> 'c) -> 'd
   type ('a, 'b, 'c, 'd) t =
        ('a, 'b, 'c, 'd) step -> 'd
   type ('a, 'b, 'c, 'd, 'e) step0 =
        ('a, 'c, 'd, ('b, 'c, 'd, 'e) t) step
   type ('a, 'b, 'c, 'd, 'e, 'f) step1 =
        ('b, 'd, 'e, 'a -> ('c, 'd, 'e, 'f) t) step
end

signature FOLD = sig
   type ('a, 'b, 'c, 'd) step =
        ('a, 'b, 'c, 'd) Fold.step
   type ('a, 'b, 'c, 'd) t =
        ('a, 'b, 'c, 'd) Fold.t
   type ('a, 'b, 'c, 'd, 'e) step0 =
        ('a, 'b, 'c, 'd, 'e) Fold.step0
   type ('a, 'b, 'c, 'd, 'e, 'f) step1 =
        ('a, 'b, 'c, 'd, 'e, 'f) Fold.step1

   val fold : 'a * ('b -> 'c) -> ('a, 'b, 'c, 'd) t
   val unfold : ('a, 'b, 'c, 'a * ('b -> 'c)) t
                -> 'a * ('b -> 'c)
   val lift : ('a, 'b, 'c, 'a * ('b -> 'c)) t
              -> ('a, 'b, 'c, 'd) t

   val post : ('a -> 'd)
              -> ('b, 'c, 'a, 'b * ('c -> 'a)) t
              -> ('b, 'c, 'd, 'e) t

   val step0 : ('a -> 'b)
               -> ('a, 'b, 'c, 'd, 'e) step0
   val step1 : ('a * 'b -> 'c)
               -> ('a, 'b, 'c, 'd, 'e, 'f) step1

   val unstep0 : ('a, 'b, 'b, 'b, 'b) step0
                 -> 'a -> 'b
   val unstep1 : ('a, 'b, 'c, 'c, 'c, 'c) step1
                 -> 'a * 'b -> 'c

   val lift0 : ('a, 'b, 'b, 'b, 'b) step0
               -> ('a, 'b, 'c, 'd, 'e) step0
   val lift1 : ('a, 'b, 'c, 'c, 'c, 'c) step1
               -> ('a, 'b, 'c, 'd, 'e, 'f) step1
   val lift0to1 : ('b, 'c, 'c, 'c, 'c) step0
                  -> ('a, 'b, 'c, 'd, 'e, 'f) step1
end

(* Terminates a fold by applying the finisher f to the final state x. *)
fun $ (x, f) = f x

structure Fold :> FOLD = struct
   open Fold

   val fold = pass
   fun unfold f = f id
   fun lift ? = (fold o unfold) ?

   fun post g =
       fold o cross (id, fn f => g o f) o unfold

   fun step0 h (a1, f) = fold (h a1, f)
   fun step1 h (a2, f) a1 = fold (h (a1, a2), f)

   fun unstep0 s a1 = fold (a1, id) s $
   fun unstep1 s (a1, a2) = fold (a2, id) s a1 $

   fun lift0 ? = (step0 o unstep0) ?
   fun lift1 ? = (step1 o unstep1) ?

   fun lift0to1 s = step1 (unstep0 s o snd)
end

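(* Sum of folds: the state is a static sum of the two underlying states,
 * and each step dispatches to the step of whichever fold is active. *)
structure FoldAlt = struct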
   fun fold ? =
       (Fold.fold o
        StaticSum.sum
           (cross (StaticSum.inL, fn f => f o StaticSum.out) o Fold.unfold,
            cross (StaticSum.inR, fn f => f o StaticSum.out) o Fold.unfold)) ?

   fun step0 (l, r) =
       Fold.step0 (StaticSum.map (Fold.unstep0 l,
                                  Fold.unstep0 r))

   fun step1 (l, r) =
       Fold.step1 (StaticSum.map (Fold.unstep1 l,
                                  Fold.unstep1 r)
                   o (fn (a11, a12) =>
                         StaticSum.match a12
                                         (fn a12 => StaticSum.inL (a11, a12),
                                          fn a12 => StaticSum.inR (a11, a12))))
end

structure FoldPair = struct
   type ('a, 'b, 'c, 'd, 'e, 'f) t =
        ('a * 'b, 'c * 'd, 'e, 'f) Fold.t
   type ('a, 'b, 'c, 'd, 'e, 'f, 'g) step0 =
        ('a * 'c, 'b * 'd, 'e, 'f, 'g) Fold.step0
   type ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) step1 =
        ('a, 'b * 'd, 'c * 'e, 'f, 'g, 'h) Fold.step1
end

signature FOLD_PAIR = sig
   type ('a, 'b, 'c, 'd, 'e, 'f) t =
        ('a, 'b, 'c, 'd, 'e, 'f) FoldPair.t
   type ('a, 'b, 'c, 'd, 'e, 'f, 'g) step0 =
        ('a, 'b, 'c, 'd, 'e, 'f, 'g) FoldPair.step0
   type ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) step1 =
        ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) FoldPair.step1

   val fold :
       ('a, 'b, 'c, 'a * ('b -> 'c)) Fold.t
       * ('d, 'e, 'f, 'd * ('e -> 'f)) Fold.t
       -> ('c * 'f -> 'g)
       -> ('a, 'd, 'b, 'e, 'g, 'h) t
   val step0 :
       ('a, 'b, 'b, 'b, 'b) Fold.step0
       * ('c, 'd, 'd, 'd, 'd) Fold.step0
       -> ('a, 'b, 'c, 'd, 'e, 'f, 'g) step0
   val step1 :
       ('a, 'b, 'c, 'c, 'c, 'c) Fold.step1
       * ('a, 'd, 'e, 'e, 'e, 'e) Fold.step1
       -> ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) step1
end

structure FoldPair :> FOLD_PAIR = struct
   open FoldPair

   fun fold (l, r) f = let
      val (la, lf) = Fold.unfold l
      val (ra, rf) = Fold.unfold r
   in
      Fold.fold ((la, ra), f o cross (lf, rf))
   end

   fun step0 (l, r) =
       Fold.step0 (cross (Fold.unstep0 l,
                          Fold.unstep0 r))

   fun step1 (l, r) =
       Fold.step1 (cross (Fold.unstep1 l,
                          Fold.unstep1 r)
                   o (fn (a11, (a12l, a12r)) =>
                         ((a11, a12l),
                          (a11, a12r))))
end

structure OptArg = struct
  local
     fun colDefFold ? = Fold.fold (id, pass ()) ?
     fun colDefStep d =
         Fold.step0 (fn f => fn ds => f (d & ds))

     fun mkRevFold ? = Fold.fold (id, pass id) ?
     fun mkRevStep ? =
         Fold.step0
            (fn r =>
                fn f =>
                   fn a & b =>
                      r (fn x => f a & x) b) ?

     fun givenFold ? = Fold.fold (id, id) ?
     fun givenStep ? =
         Fold.step1
            (fn (x, f) => fn d => f (x & d)) ?

     fun restFold d = Fold.fold (d, id)
     fun restStep ? = Fold.step0 (fn _ & d => d) ?
  in
    fun make ? =
        FoldPair.fold
           (colDefFold, mkRevFold)
           (fn (d, r) =>
               fn func =>
                  FoldPair.fold
                     (givenFold, restFold d)
                     (fn (f, d) =>
                         case r (f d) of
                            d & () => func d)) ?

    fun D d = FoldPair.step0
                 (colDefStep d, mkRevStep)
    fun ` ? = FoldPair.step1
                 (givenStep, Fold.lift0to1 restStep) ?
  end
end

val D = OptArg.D

signature VAR_ARG = sig
   type ('p, 's) ac
   type ('p, 's, 'a) t =
        (('p, 's) ac, ('p, 's) ac, 's, 'a) Fold.t

   val foldl : ('p * 's -> 's) -> 's
               -> ('p, 's, 'a) t
   val foldr : ('p * 's -> 's) -> 's
               -> ('p, 's, 'a) t

   val ` : ('p, ('p, 's) ac, ('p, 's) ac,
            'a, 'b, 'c) Fold.step1
end

structure VarArg :> VAR_ARG = struct
   datatype ('p, 's) ac =
            IN of 's uop * ('p * 's uop -> 's uop)
   type ('p, 's, 'a) t =
        (('p, 's) ac, ('p, 's) ac, 's, 'a) Fold.t

   local
      fun out (IN x) = x
      fun make d plus zero =
          Fold.fold
             (IN (id, fn (p, f) =>
                         op o (d (p <\plus, f))),
              pass zero o fst o out)
   in
      fun foldl ? = make id ?
      fun foldr ? = make swap ?
   end

   fun ` ? =
       Fold.step1 (fn (p, IN (s, plus)) =>
                      IN (plus (p, s), plus)) ?
end

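(* Combines VarArg (left summand) and OptArg (right summand) via FoldAlt,
 * so that a single ` specifier serves both kinds of fold. *)
structure Arg = struct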
   fun opt ? =
       Fold.post (fn f => FoldAlt.fold o StaticSum.inR o f) OptArg.make ?
   fun foldl plus zero = FoldAlt.fold (StaticSum.inL (VarArg.foldl plus zero))
   fun foldr plus zero = FoldAlt.fold (StaticSum.inL (VarArg.foldr plus zero))
   fun ` ? = FoldAlt.step1 (VarArg.`, OptArg.`) ?
end

val ` = Arg.`

structure Num = struct
   fun make op * op + i2x iBase = let
      val xBase = i2x iBase
      val i2x =
          fn i =>
             if 0 <= i andalso i < iBase then
                i2x i
             else
                fail ("Num: " ^ Int.toString i ^
                      " is not a valid digit in \
                      \base " ^ Int.toString iBase)
   in
      Arg.foldl (fn (i, x) => i2x i + x * xBase)
                (i2x 0)
   end

   fun I  ? = make op * op + id ?
   fun LI ? = make op * op + LargeInt.fromInt ?
   fun W  ? = make op * op + Word.fromInt ?

   val a = 10
   val b = 11
   val c = 12
   val d = 13
   val e = 14
   val f = 15
end

structure ArrayLiteral = struct
   datatype 'a ac =
            IN of int * 'a option * 'a array effect

   fun array ? =
       Fold.post
          (fn IN (_, NONE, _) =>
              Array.tabulate (0, failing "array0")
            | IN (n, SOME x, fill) =>
              let val a = Array.array (n, x)
              in fill a ; a
              end)
          (Arg.foldl
              (fn (x, IN (i, _, fill)) =>
                  IN (i+1, SOME x,
                      fn a => (Array.update (a, i, x)
                             ; fill a)))
              (IN (0, NONE, ignore))) ?
end


