[MLton] Mzton an MzScheme<->MLton FFI

Jens Axel Søgaard jensaxel at soegaard.net
Thu Aug 10 03:58:04 PDT 2006


Vesa Karvonen wrote:
> Quoting Jens Axel Søgaard <jensaxel at soegaard.net>:
> [http://mlton.org/pipermail/mlton/2005-September/028042.html]
>> I have attached something to test.
>>
>>    sh build-test.sh
>>
>> will build the shared library, and
>>
>>    mzscheme -M errortrace -f test.ss
>>
>> will import and test it.
> [...]
> 
> Is there a place where one could download the latest version of
> Mzton or does the attached example (which I haven't yet tried to
> compile) contain the (then current) whole of Mzton?  I tried
> googling for Mzton, but found only a couple of messages on the
> MLton list.

The rest of the messages can be found here:

<http://www.google.com/search?hl=en&lr=&c2coff=1&safe=off&q=+site:mlton.org+jens+axel+mlton>

The main issue in letting Mzton produce shared libraries was
registering roots in MLton from the outside.

It has been a while since I worked on Mzton, and I can't really
remember the state of the code. I tried it briefly today, and
at least the test program works. The source hasn't been prettified
into a releasable state, though. It is also unclear how far I got
towards a complete set of bindings for the exportable types.

At some point there were discussions on the list about changing the
MLton FFI, so it might be relevant to know that I used an MLton
version built on August 21st, 2005.

Note: At some point I lost the original source, so this is
a rewritten version. (I think I got further in the old version;
at least I had a Mandelbrot program working, where the Scheme side
did the drawing and the ML side did the calculating.)


Here is how to run the test program on Linux.

   1. Install Fluet's patch
   2. sh build-test.sh
   3. mzscheme -f test.ss

The output of the test program is:

Welcome to MzScheme version 299.200, Copyright (c) 2004-2005 PLT Scheme, Inc.
* Opening library
Before atomicBegin
Before returnToC
* Opening library functions
* Done opening library functions
42
43
A
#(1 #<primitive:ffi:getIntArray> #<primitive:ffi:unregIntArray> #<ctype>)
5
0
1
2
3
4
#(0 1 2 3 4)

42
43
double test
#(0.0 2.0 4.0 6.0)
 >


/Jens Axel

-------------- next part --------------
/home/js/mzton/mlton-shared/mlton/build/bin/mlton\
 -export-header exported.h \
 -shared-library true \
 -cc-opt "-symbolic -shared -fPIC"\
 -link-opt "-shared -Bsymbolic"\
 -codegen c -default-ann 'allowFFI true' -keep g -verbose 1\
 -link-opt --verbose -link-opt -Wl,-Bsymbolic test.cm

# -link-opt -Wl,-export-dynamic 
# ln -s test test.so
-------------- next part --------------
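# -rdynamic makes gcc pass -export-dynamic to the linker (a bare
# -export-dynamic is parsed by gcc as "-e xport-dynamic");
# -ldl is placed after the source file that uses dlopen.
gcc -rdynamic -o test-from-c test-from-c.c -ldl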
/home/js/mzton/mlton-shared/mlton/build/bin/mlton \
  -shared-library true\
  -cc-opt "-symbolic -shared -fPIC"\
  -link-opt "-shared -Bsymbolic"\
  -codegen c -default-ann 'allowFFI true' -keep g -verbose 1\
  -link-opt --verbose -link-opt -Wl,-Bsymbolic test.sml

-------------- next part --------------
int int42 () 
{
  return 42;
}
-------------- next part --------------

Mzton - An MzScheme to MLton FFI  --  Jens Axel Søgaard
-------------------------------------------------------


OVERVIEW
========

The Mzton FFI gives the Scheme programmer the chance to use SML instead
of C to implement foreign libraries. The FFI follows the same "stay in
the fun world" philosophy as the PLT C FFI by Eli Barzilay: as many
low-level technicalities as possible are concealed from the user of
the FFI.

Mzton allows Scheme code to call functions written in ML. Conversions
between Scheme and ML values are transparent. ML arrays, vectors and
references returned from ML to Scheme are automatically garbage
collected (on both sides). It is possible to pass a callback (a Scheme
function) to an ML function.


EXAMPLE
=======

To give a first impression of the usage of this FFI, let's for a
moment consider the implementation of a small library for calculations
with n-dimensional vectors. Since we want to allow mutation, we will
represent an n-dimensional vector as an array of reals in ML. 

The library consists of the following three functions:

  dotVec       : (real array) * (real array) -> real
  makeUnitVec  : int * int -> real array
  modifyVec    : (real array) * (real -> real) -> unit

The dot product is calculated by the following ML function:

    fun dotVec (A,B) = Array.foldli (fn (i, a, x) => (Array.sub(B,i)*a+x)) 0.0 A;
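Readers more used to C may find it helpful to see the fold spelled out as a loop. The following is a minimal sketch of what dotVec computes: the accumulator starts at 0.0 and each step adds B[i]*A[i], visiting the entries of A from left to right as Array.foldli does. The name dot_vec and the explicit length parameter n are illustrative assumptions, not part of Mzton.

```c
#include <assert.h>
#include <stddef.h>

/* Sketch of dotVec as a C loop (hypothetical name dot_vec).  Mirrors
 *   fun dotVec (A,B) = Array.foldli (fn (i, a, x) => Array.sub(B,i)*a + x) 0.0 A
 * Both a and b must point to at least n doubles. */
double dot_vec(const double *a, const double *b, size_t n)
{
    assert(n == 0 || (a != NULL && b != NULL)); /* basic sanity check */
    double x = 0.0;                /* the initial accumulator 0.0 */
    for (size_t i = 0; i < n; i++) /* left to right, like Array.foldli */
        x = b[i] * a[i] + x;       /* body of the fold function */
    return x;
}
```

Unlike the ML version, which raises Subscript when B is shorter than A, the C loop trusts the caller to pass two arrays with at least n entries.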

The type of dotVec is  (real array) * (real array) -> real. Since the
return value has type real, which is a base type, nothing needs to be
registered with the MLton garbage collector, and one exports it directly:

   val _ = _export "dotVec" : ( (real array)*(real array) -> real ) -> unit; dotVec;

In the Scheme program it is imported with get-ffi-obj:

    (define dot-vec 
      (get-ffi-obj "dotVec" lib (_fun $RealArray $RealArray -> $real) signal-error))

Here "dotVec" is the name that occurs in the above _export. The
variable lib holds information about the shared library in question. 


If an ML function allocates an array, vector or reference that is to
be returned to the Scheme side, it is important to register it with
the MLton garbage collector. The FFI provides return functions for
that purpose. 

The function makeUnitVec returns a newly allocated  real array
representing the i'th basis vector of R^n, counting i from 1 (as in
the REPL session below).

    fun makeUnitVec (n,i) = Array.tabulate (n, fn j => if i=j+1 then 1.0 else 0.0);  
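In the REPL session below, (make-unit-vec 2 1) yields #2(1.0 0.0), so i counts from 1 while array indices count from 0. The following C sketch makes that off-by-one explicit. The name make_unit_vec and the malloc/free ownership convention are assumptions of this sketch only; Mzton itself registers the ML array with the MLton garbage collector instead.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical C analogue of makeUnitVec with a 1-based i:
 * entry j (0-based) is 1.0 exactly when i == j + 1.
 * Returns a malloc'ed array of n doubles that the caller must free(),
 * or NULL if allocation fails. */
double *make_unit_vec(size_t n, size_t i)
{
    assert(n > 0);                        /* avoid the malloc(0) corner case */
    double *v = malloc(n * sizeof *v);
    if (v == NULL)
        return NULL;
    for (size_t j = 0; j < n; j++)
        v[j] = (i == j + 1) ? 1.0 : 0.0;  /* i counts from 1, j from 0 */
    return v;
}
```

An out-of-range i gives the zero vector, just as the ML version does.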

The type of this function is  int*int -> real array  so to export it,
one writes

    val _ = _export "makeUnitVec" : ( int*int -> real array ) -> unit; (RealArrayRoot.return o makeUnitVec);

The important thing to note here is that we are not exporting
makeUnitVec, but (RealArrayRoot.return o makeUnitVec). The function
RealArrayRoot.return is nothing but the identity function, but it has
the side effect of registering the returned array with the MLton
garbage collector. When the Scheme object representing the returned ML
array is garbage collected, the Scheme garbage collector will
automatically tell the MLton garbage collector that the array has
become garbage.
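The vector printed in the REPL session below, #4(4 #<primitive:ffi:getRealArray> #<primitive:ffi:unregRealArray> #<ctype>), suggests that the Scheme side holds an integer id together with the get and unreg functions rather than a raw pointer. That matches the getX/regX/unregX prototypes in the exported.h attachment. The C sketch below illustrates that id discipline with a hypothetical fixed-size table. The names root_reg, root_get and root_unreg are illustrative and are not the Root functor. One plausible reason for ids is that MLton's collector can move objects, so the current address has to be looked up on each access. In Mzton the table lives on the ML side, where the collector presumably traces it. A plain C array like this one is not traced, so the sketch only shows the bookkeeping.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical root table: illustrates reg/get/unreg bookkeeping only. */
#define ROOT_CAPACITY 64

static void *roots[ROOT_CAPACITY];   /* slot -> registered object, NULL if free */

/* Register p in the lowest free slot and return its id, or -1 if full. */
int root_reg(void *p)
{
    assert(p != NULL);
    for (int id = 0; id < ROOT_CAPACITY; id++) {
        if (roots[id] == NULL) {
            roots[id] = p;
            return id;
        }
    }
    return -1;
}

/* Look up the object registered under id (NULL if the slot is free). */
void *root_get(int id)
{
    assert(id >= 0 && id < ROOT_CAPACITY);
    return roots[id];
}

/* Forget the root; the object is no longer kept alive through this table. */
void root_unreg(int id)
{
    assert(id >= 0 && id < ROOT_CAPACITY);
    roots[id] = NULL;
}
```

Freed slots are reused, so an id is only meaningful between its reg and its unreg.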

To import the function in Scheme, one writes:

    (define make-unit-vec 
      (get-ffi-obj "makeUnitVec" lib (_fun $int $int -> $RealArray) signal-error))

Given these functions we can now experiment in the REPL.

  > (make-unit-vec 2 1)
  #4(4 #<primitive:ffi:getRealArray> #<primitive:ffi:unregRealArray> #<ctype>)
  > (define v (make-unit-vec 2 1))
  > (define w (make-unit-vec 2 2))
  > (dot-vec v v)
  1.0
  > (dot-vec v w)
  0.0
  > (dot-vec w w)
  1.0

The Scheme side of Mzton provides primitives to work with ML
arrays. We can reference entries in an ML array: 

  > (ml-array-ref v 0)
  1.0
  > (ml-array-ref v 1)
  0.0

We can convert an ML array to a Scheme vector:

  > (ml-array->vector v)
  #2(1.0 0.0)

We can alter entries in an ML array:

  > (ml-array-set! v 0 3.0)
  > (dot-vec v v)
  9.0

This is even safe, since Scheme checks the types!

  > (ml-array-set! v 0 "foo")
  Scheme->C: expects argument of type <double>; given "foo"


The third function demonstrates how to pass a Scheme function to the
ML side. The function modifyVec (A,f) replaces each entry x in A with
f(x) and returns nothing.

    fun modifyVec (A,f) = Array.modify f A;

One would expect that this could be exported as 

    val _ = _export "modifyVec" : ( (real array)*(real -> real) -> unit ) -> unit; modifyVec;

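but, alas, the current MLton C FFI does not allow real->real as the
type of an argument. However, MLton can make indirect function calls:
the wrapper below receives the Scheme callback as a C function pointer
and turns it into an ML function with an indirect _import.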

    fun exportedModifyVec (A,pointerToF)
      = let val f = (_import * : MLton.Pointer.t -> real -> real;) pointerToF
        in 
          modifyVec (A,f) 
        end;

    val _ = _export "modifyVec" : ( (real array)*MLton.Pointer.t -> unit ) -> unit; exportedModifyVec;

The import in Scheme uses the type  (real array)*(real -> real) -> unit 

  (define modify-vec!   
    (get-ffi-obj "modifyVec" lib (_fun $RealArray (_fun $real -> $real) -> $unit) signal-error))

Let us try it in the REPL:


  > (define v (make-unit-vec 2 1))
  > (ml-array->vector v)
  #2(1.0 0.0)
  > (modify-vec! v (lambda (x) (+ (* 2.0 x) 3.0)))
  > (ml-array->vector v)
  #2(5.0 3.0)
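The indirect call can be pictured in C. The callback arrives as a plain function pointer, and the wrapper applies it to each entry in place, which is what Array.modify does with the ML function obtained from _import *. The sketch below is hypothetical: modify_vec and two_x_plus_three are illustrative names. It mirrors the REPL session above, where 2x+3 turns #2(1.0 0.0) into #2(5.0 3.0).

```c
#include <assert.h>
#include <stddef.h>

/* Type of the callback, corresponding to (_fun $real -> $real) in Scheme. */
typedef double (*real_to_real)(double);

/* Hypothetical analogue of exportedModifyVec: apply f to every entry
 * of a in place, like Array.modify f A. */
void modify_vec(double *a, size_t n, real_to_real f)
{
    assert(f != NULL);
    for (size_t i = 0; i < n; i++)
        a[i] = f(a[i]);   /* indirect call through the function pointer */
}

/* Stand-in for the Scheme lambda (lambda (x) (+ (* 2.0 x) 3.0)). */
double two_x_plus_three(double x)
{
    return 2.0 * x + 3.0;
}
```

With double v[2] = {1.0, 0.0}, calling modify_vec(v, 2, two_x_plus_three) leaves v as {5.0, 3.0}.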


WRITING AN MLTON SHARED LIBRARY
===============================


EXPORTABLE TYPES
----------------

The current MLton FFI allows values of certain types to be
exported. The simplest exportable types are the base types:


   SIMPLE BASE TYPES
 --------------------------
   SML             Scheme

   bool            $bool 
   char            $char
   int             $int
   Int8.int        $int8
   Int16.int       $int16
   Int32.int       $int32
   Int64.int       $int64
   MLton.Pointer.t $pointer (for C pointers) 
   real            $real
   Real32.real     $real32
   Real64.real     $real64
   string          $string  (read only)
   word            $word
   Word8.word      $word8
   Word16.word     $word16
   Word32.word     $word32
   Word64.word     $word64

   array           $array   (not recommended)
   ref             $ref     (not recommended)
   vector          $vector  (not recommended, read only)


The exportable compound types are arrays, vectors and references to
base types. This means that the types "int array" and "real ref" are
exportable. The type "(int array) array" is not exportable, but "array
array" is. 

         COMPOUND TYPES
 -------------------------------
  'base array      $BaseArray
  'base vector     $BaseVector
  'base ref        $BaseRef


Furthermore, functions from a product of exportable non-function
types to an exportable non-function type are also exportable.


                             FUNCTIONS 
 ------------------------------------------------------------------
  'from1 * ... * 'fromn  -> 'to    (_fun $from1 ... $fromn -> $to)


-------------- next part --------------
/* Copyright (C) 2004-2005 Henry Cejtin, Matthew Fluet, Suresh
 *    Jagannathan, and Stephen Weeks.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 */

/* Can't use _TYPES_H_ because MSVCRT uses it.  So, we use _MLTON_TYPES_H_. */

#ifndef _MLTON_TYPES_H_
#define _MLTON_TYPES_H_

/* We need these because in header files for exported SML functions, types.h is
 * included without platform.h.
 */
#ifndef _ISOC99_SOURCE
#define _ISOC99_SOURCE
#endif
#if (defined (__OpenBSD__))
#include <inttypes.h>
#elif (defined (__sun__))
#include <sys/int_types.h>
#else
#include <stdint.h>
#endif

typedef int8_t Int8;
typedef int16_t Int16;
typedef int32_t Int32;
typedef int64_t Int64;
typedef char *Pointer;
typedef Pointer pointer;
typedef float Real32;
typedef double Real64;
typedef uint8_t Word8;
typedef uint16_t Word16;
typedef uint32_t Word32;
typedef uint64_t Word64;

typedef Int8 WordS8;
typedef Int16 WordS16;
typedef Int32 WordS32;
typedef Int64 WordS64;

typedef Word8 WordU8;
typedef Word16 WordU16;
typedef Word32 WordU32;
typedef Word64 WordU64;

/* !!! this stuff is all wrong: */
typedef Int32 Int;
typedef Real64 Real;
typedef Word8 Char;
typedef Word32 Word;
typedef Int64 Position;

typedef Int Bool;
typedef Word Cpointer;
typedef Word Cstring;
typedef Word CstringArray;
typedef Word Dirstream;
typedef Int Fd;
typedef Word Flag;
typedef Word Gid;
typedef Word Mode;
typedef Word NullString;
typedef Int Pid;
typedef Int Resource;
typedef Word Rlimit;
typedef Int Signal;
typedef Int Size;
typedef Int Speed;
typedef Int Ssize;
typedef Int Status;
typedef Int Syserror;
typedef Pointer Thread;
typedef Word Uid;

#endif /* _MLTON_TYPES_H_ */

Int32 getCurrentRootId ();
Pointer getBoolArray (Int32 x0);
void unregBoolArray (Int32 x0);
Int32 regBoolArray (Pointer x0);
Pointer getCharArray (Int32 x0);
void unregCharArray (Int32 x0);
Int32 regCharArray (Pointer x0);
Pointer getInt8Array (Int32 x0);
void unregInt8Array (Int32 x0);
Int32 regInt8Array (Pointer x0);
Pointer getInt16Array (Int32 x0);
void unregInt16Array (Int32 x0);
Int32 regInt16Array (Pointer x0);
Pointer getInt32Array (Int32 x0);
void unregInt32Array (Int32 x0);
Int32 regInt32Array (Pointer x0);
Pointer getInt64Array (Int32 x0);
void unregInt64Array (Int32 x0);
Int32 regInt64Array (Pointer x0);
Pointer getIntArray (Int32 x0);
void unregIntArray (Int32 x0);
Int32 regIntArray (Pointer x0);
Pointer getPointerArray (Int32 x0);
void unregPointerArray (Int32 x0);
Int32 regPointerArray (Pointer x0);
Pointer getReal32Array (Int32 x0);
void unregReal32Array (Int32 x0);
Int32 regReal32Array (Pointer x0);
Pointer getReal64Array (Int32 x0);
void unregReal64Array (Int32 x0);
Int32 regReal64Array (Pointer x0);
Pointer getRealArray (Int32 x0);
void unregRealArray (Int32 x0);
Int32 regRealArray (Pointer x0);
Pointer getWord8Array (Int32 x0);
void unregWord8Array (Int32 x0);
Int32 regWord8Array (Pointer x0);
Pointer getWord16Array (Int32 x0);
void unregWord16Array (Int32 x0);
Int32 regWord16Array (Pointer x0);
Pointer getWord32Array (Int32 x0);
void unregWord32Array (Int32 x0);
Int32 regWord32Array (Pointer x0);
Pointer getWord64Array (Int32 x0);
void unregWord64Array (Int32 x0);
Int32 regWord64Array (Pointer x0);
Pointer getWordArray (Int32 x0);
void unregWordArray (Int32 x0);
Int32 regWordArray (Pointer x0);
Pointer getStringArray (Int32 x0);
void unregStringArray (Int32 x0);
Int32 regStringArray (Pointer x0);
Pointer getBoolRef (Int32 x0);
void unregBoolRef (Int32 x0);
Int32 regBoolRef (Pointer x0);
Pointer getCharRef (Int32 x0);
void unregCharRef (Int32 x0);
Int32 regCharRef (Pointer x0);
Pointer getInt8Ref (Int32 x0);
void unregInt8Ref (Int32 x0);
Int32 regInt8Ref (Pointer x0);
Pointer getInt16Ref (Int32 x0);
void unregInt16Ref (Int32 x0);
Int32 regInt16Ref (Pointer x0);
Pointer getInt32Ref (Int32 x0);
void unregInt32Ref (Int32 x0);
Int32 regInt32Ref (Pointer x0);
Pointer getInt64Ref (Int32 x0);
void unregInt64Ref (Int32 x0);
Int32 regInt64Ref (Pointer x0);
Pointer getIntRef (Int32 x0);
void unregIntRef (Int32 x0);
Int32 regIntRef (Pointer x0);
Pointer getPointerRef (Int32 x0);
void unregPointerRef (Int32 x0);
Int32 regPointerRef (Pointer x0);
Pointer getReal32Ref (Int32 x0);
void unregReal32Ref (Int32 x0);
Int32 regReal32Ref (Pointer x0);
Pointer getReal64Ref (Int32 x0);
void unregReal64Ref (Int32 x0);
Int32 regReal64Ref (Pointer x0);
Pointer getRealRef (Int32 x0);
void unregRealRef (Int32 x0);
Int32 regRealRef (Pointer x0);
Pointer getWord8Ref (Int32 x0);
void unregWord8Ref (Int32 x0);
Int32 regWord8Ref (Pointer x0);
Pointer getWord16Ref (Int32 x0);
void unregWord16Ref (Int32 x0);
Int32 regWord16Ref (Pointer x0);
Pointer getWord32Ref (Int32 x0);
void unregWord32Ref (Int32 x0);
Int32 regWord32Ref (Pointer x0);
Pointer getWord64Ref (Int32 x0);
void unregWord64Ref (Int32 x0);
Int32 regWord64Ref (Pointer x0);
Pointer getWordRef (Int32 x0);
void unregWordRef (Int32 x0);
Int32 regWordRef (Pointer x0);
Pointer getStringRef (Int32 x0);
void unregStringRef (Int32 x0);
Int32 regStringRef (Pointer x0);
Pointer getBoolVector (Int32 x0);
void unregBoolVector (Int32 x0);
Int32 regBoolVector (Pointer x0);
Pointer getCharVector (Int32 x0);
void unregCharVector (Int32 x0);
Int32 regCharVector (Pointer x0);
Pointer getInt8Vector (Int32 x0);
void unregInt8Vector (Int32 x0);
Int32 regInt8Vector (Pointer x0);
Pointer getInt16Vector (Int32 x0);
void unregInt16Vector (Int32 x0);
Int32 regInt16Vector (Pointer x0);
Pointer getInt32Vector (Int32 x0);
void unregInt32Vector (Int32 x0);
Int32 regInt32Vector (Pointer x0);
Pointer getInt64Vector (Int32 x0);
void unregInt64Vector (Int32 x0);
Int32 regInt64Vector (Pointer x0);
Pointer getIntVector (Int32 x0);
void unregIntVector (Int32 x0);
Int32 regIntVector (Pointer x0);
Pointer getPointerVector (Int32 x0);
void unregPointerVector (Int32 x0);
Int32 regPointerVector (Pointer x0);
Pointer getReal32Vector (Int32 x0);
void unregReal32Vector (Int32 x0);
Int32 regReal32Vector (Pointer x0);
Pointer getReal64Vector (Int32 x0);
void unregReal64Vector (Int32 x0);
Int32 regReal64Vector (Pointer x0);
Pointer getRealVector (Int32 x0);
void unregRealVector (Int32 x0);
Int32 regRealVector (Pointer x0);
Pointer getWord8Vector (Int32 x0);
void unregWord8Vector (Int32 x0);
Int32 regWord8Vector (Pointer x0);
Pointer getWord16Vector (Int32 x0);
void unregWord16Vector (Int32 x0);
Int32 regWord16Vector (Pointer x0);
Pointer getWord32Vector (Int32 x0);
void unregWord32Vector (Int32 x0);
Int32 regWord32Vector (Pointer x0);
Pointer getWord64Vector (Int32 x0);
void unregWord64Vector (Int32 x0);
Int32 regWord64Vector (Pointer x0);
Pointer getWordVector (Int32 x0);
void unregWordVector (Int32 x0);
Int32 regWordVector (Pointer x0);
Pointer getStringVector (Int32 x0);
void unregStringVector (Int32 x0);
Int32 regStringVector (Pointer x0);
Pointer makeBoolArray (Int32 x0, Int32 x1);
Pointer makeBoolVector (Int32 x0, Int32 x1);
Pointer makeBoolRef (Int32 x0);
Pointer makeCharArray (Int32 x0, Int8 x1);
Pointer makeCharVector (Int32 x0, Int8 x1);
Pointer makeCharRef (Int8 x0);
Pointer makeInt8Array (Int32 x0, Int8 x1);
Pointer makeInt8Vector (Int32 x0, Int8 x1);
Pointer makeInt8Ref (Int8 x0);
Pointer makeInt16Array (Int32 x0, Int16 x1);
Pointer makeInt16Vector (Int32 x0, Int16 x1);
Pointer makeInt16Ref (Int16 x0);
Pointer makeInt32Array (Int32 x0, Int32 x1);
Pointer makeInt32Vector (Int32 x0, Int32 x1);
Pointer makeInt32Ref (Int32 x0);
Pointer makeInt64Array (Int32 x0, Int64 x1);
Pointer makeInt64Vector (Int32 x0, Int64 x1);
Pointer makeInt64Ref (Int64 x0);
Pointer makeIntArray (Int32 x0, Int32 x1);
Pointer makeIntVector (Int32 x0, Int32 x1);
Pointer makeIntRef (Int32 x0);
Pointer makePointerArray (Int32 x0, Pointer x1);
Pointer makePointerVector (Int32 x0, Pointer x1);
Pointer makePointerRef (Pointer x0);
Pointer makeReal32Array (Int32 x0, Real32 x1);
Pointer makeReal32Vector (Int32 x0, Real32 x1);
Pointer makeReal32Ref (Real32 x0);
Pointer makeReal64Array (Int32 x0, Real64 x1);
Pointer makeReal64Vector (Int32 x0, Real64 x1);
Pointer makeReal64Ref (Real64 x0);
Pointer makeRealArray (Int32 x0, Real64 x1);
Pointer makeRealVector (Int32 x0, Real64 x1);
Pointer makeRealRef (Real64 x0);
Pointer makeWord8Array (Int32 x0, Word8 x1);
Pointer makeWord8Vector (Int32 x0, Word8 x1);
Pointer makeWord8Ref (Word8 x0);
Pointer makeWord16Array (Int32 x0, Word16 x1);
Pointer makeWord16Vector (Int32 x0, Word16 x1);
Pointer makeWord16Ref (Word16 x0);
Pointer makeWord32Array (Int32 x0, Word32 x1);
Pointer makeWord32Vector (Int32 x0, Word32 x1);
Pointer makeWord32Ref (Word32 x0);
Pointer makeWord64Array (Int32 x0, Word64 x1);
Pointer makeWord64Vector (Int32 x0, Word64 x1);
Pointer makeWord64Ref (Word64 x0);
Pointer makeWordArray (Int32 x0, Word32 x1);
Pointer makeWordVector (Int32 x0, Word32 x1);
Pointer makeWordRef (Word32 x0);
Pointer makeStringArray (Int32 x0, Pointer x1);
Pointer makeStringVector (Int32 x0, Pointer x1);
Pointer makeStringRef (Pointer x0);
Int32 testIntToInt (Int32 x0);
Int32 testUnitToInt ();
Int8 testUnitToChar ();
Int32 testUnitToBool ();
Pointer testUnitToIntArray ();
Pointer testUnitToIntRef ();
Int32 iterate (Real64 x0, Real64 x1);
Pointer iterateLine (Real64 x0, Real64 x1, Real64 x2, Real64 x3);
void doubleArray (Pointer x0, Pointer x1);
Pointer makeUnitVec (Int32 x0, Int32 x1);
Pointer sumVec (Pointer x0, Pointer x1);
Pointer scaleVec (Pointer x0, Real64 x1);
Real64 dotVec (Pointer x0, Pointer x1);
void clearVec (Pointer x0);
void modifyVec (Pointer x0, Pointer x1);
-------------- next part --------------
#;
(define-base-types
  (; ml-ffi   C-FFI    SML type        C typedef   C type         Scheme->C     C->Scheme
   ($bool     _bool     bool             Int32     "long")
   ($char     _int8     char             Int8      "char"       char->integer  integer->char)
   ($int8     _int8     Int8.int         Int8      "char")
   ($int16    _int16    Int16.int        Int16     "short")
   ($int32    _int32    Int32.int        Int32     "long")
   ($int64    _int64    Int64.int        Int64     "long long")
   ($int      _int32    int              Int32     "long")
   ($pointer  _pointer  MLton.Pointer.t  Pointer   "char *")
   ($real32   _float    Real32.real      Real32    "float")
   ($real64   _double   Real64.real      Real64    "double")
   ($real     _double   real             Real64    "double")
   ($word8    _uint8    Word8.word       Word8     "unsigned char")
   ($word16   _uint16   Word16.word      Word16    "unsigned short")
   ($word32   _uint32   Word32.word      Word32    "unsigned long")
   ($word64   _uint64   Word64.word      Word64    "unsigned long long")
   ($word     _uint32   word             Word32    "unsigned int")
   
   ($string   _pointer  string           Pointer   "char *")           ; READ ONLY
   ($vector   _pointer  vector           Pointer   "char *")           ; READ ONLY
   ($array    _pointer  array            Pointer   "char *")
   ($ref      _pointer  ref              Pointer   "char *")))

(define full/short-names 
  '((bool            bool)
    (char            char)
    (Int8.int        Int8)
    (Int16.int       Int16)
    (Int32.int       Int32)
    (Int64.int       Int64)
    (int             int)
    (MLton.Pointer.t Pointer)
    (Real32.real     Real32)
    (Real64.real     Real64)
    (real            real)
    (Word8.word      Word8)
    (Word16.word     Word16)
    (Word32.word     Word32)
    (Word64.word     Word64)
    (word            word)
    (string          string)
    
    ; TODO: MLton allows exports of say int vector array, but
    ;       how to register such a root in the ML code ? 

    ;(vector          vector)
    ;(array           array)
    ;(ref             ref))
  ))

(define short-names
  (map second full/short-names))

(define (short-name->full-name s)
  (let ((a (assoc s (map reverse full/short-names))))
    (if a
        (second a)
        (error #f "huh?" s))))

(require (lib "13.ss" "srfi"))

(define-struct compound-description 
  (name type make-constructor))

(define (generate base compound)
  (let* ((Base          (string-titlecase (symbol->string base)))
         (Compound      (string-titlecase (symbol->string compound)))
         (BaseCompound  (format "~a~a" Base Compound))
         (type          (format "~a ~a" (short-name->full-name base) compound)))
    (string-append
     (format "structure ~aRoot = Root(struct type t = ~a end);\n"                   BaseCompound type)
     (format "val _ = _export \"get~a\":   (int -> ~a)   -> unit; ~aRoot.get;\n"   BaseCompound type BaseCompound)
     (format "val _ = _export \"unreg~a\": (int -> unit) -> unit; ~aRoot.unreg;\n" BaseCompound BaseCompound)
     (format "val _ = _export \"reg~a\" :  (~a  -> int)  -> unit; ~aRoot.reg;\n"   BaseCompound type BaseCompound)
     )))

(define (generate-array-allocator base)
  (let* ((Base          (string-titlecase (symbol->string base)))
         (Compound      (string-titlecase (symbol->string 'array)))
         (BaseCompound  (format "~a~a" Base Compound)))
    (string-append
     (format "fun make~A (n, fill) = Array.array (n, fill);\n" BaseCompound)
     (format "val _ = _export \"make~a\" : (int*~a -> ~a Array.array) -> unit; (~aRoot.return o make~a);\n"
             BaseCompound (short-name->full-name base) (short-name->full-name base) BaseCompound BaseCompound))))

(define (generate-vector-allocator base)
  (let* ((Base          (string-titlecase (symbol->string base)))
         (Compound      (string-titlecase (symbol->string 'vector)))
         (BaseCompound  (format "~a~a" Base Compound)))
    (string-append
     (format "fun make~A (n, fill) = Array.vector (Array.array (n, fill));\n" BaseCompound)
     (format "val _ = _export \"make~a\" : (int*~a -> ~a Vector.vector) -> unit; (~aRoot.return o make~a);\n"
             BaseCompound (short-name->full-name base) (short-name->full-name base) BaseCompound BaseCompound))))

(define (generate-ref-allocator base)
  (let* ((Base          (string-titlecase (symbol->string base)))
         (Compound      (string-titlecase (symbol->string 'ref)))
         (BaseCompound  (format "~a~a" Base Compound)))
    (string-append
     (format "fun make~A v = ref v;\n" BaseCompound)
     (format "val _ = _export \"make~a\" : (~a -> ~a ref) -> unit; (~aRoot.return o make~a);\n"
             BaseCompound (short-name->full-name base) (short-name->full-name base) BaseCompound BaseCompound))))

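(define (emit)
  ;; for-each rather than map: only the side effect of display matters,
  ;; and for-each guarantees left-to-right order
  (for-each (lambda (c)
              (for-each (lambda (b) (display (generate b c)))
                        short-names))
            '(array ref vector))
  (for-each (lambda (b) 
              (display (generate-array-allocator b))
              (display (generate-vector-allocator b))
              (display (generate-ref-allocator b)))
            short-names)
  (void))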

(emit)

-------------- next part --------------
(* lib-base-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *)

signature LIB_BASE =
  sig

    exception Unimplemented of string
	(* raised to report unimplemented features *)
    exception Impossible of string
	(* raised to report internal errors *)

    exception NotFound
	(* raised by searching operations *)

    val failure : {module : string, func : string, msg : string} -> 'a
	(* raise the exception Fail with a standard format message. *)

    val version : {date : string, system : string, version_id : int list}
    val banner : string

  end (* LIB_BASE *)


(* lib-base.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *)

structure LibBase : LIB_BASE =
  struct

  (* raised to report unimplemented features *)
    exception Unimplemented of string

  (* raised to report internal errors *)
    exception Impossible of string

  (* raised by searching operations *)
    exception NotFound

  (* raise the exception Fail with a standard format message. *)
    fun failure {module, func, msg} =
	  raise (Fail(concat[module, ".", func, ": ", msg]))

    val version = {
	    date = "June 1, 1996", 
	    system = "SML/NJ Library",
	    version_id = [1, 0]
	  }

    fun f ([], l) = l
      | f ([x : int], l) = (Int.toString x)::l
      | f (x::r, l) = (Int.toString x) :: "." :: f(r, l)

    val banner = concat (
	    #system version :: ", Version " ::
	    f (#version_id version, [", ", #date version]))

  end (* LibBase *)


(* splaytree-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Signature for a splay tree data structure.
 *
 *)

signature SPLAY_TREE = 
  sig
    datatype 'a splay = 
      SplayObj of {
        value : 'a,
        right : 'a splay,
        left : 'a splay
      }
    | SplayNil


    val splay : (('a -> order) * 'a splay) -> (order * 'a splay)
      (* (r,tree') = splay (cmp,tree) 
       * where tree' is tree adjusted using the comparison function cmp
       * and, if tree' = SplayObj{value,...}, r = cmp value.
       * tree' = SplayNil iff tree = SplayNil, in which case r is undefined.
       *)

    val join : 'a splay * 'a splay -> 'a splay
      (* join(t,t') returns a new splay tree formed of t and t'
       *)

  end (* SPLAY_TREE *)


(* ord-key-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Abstract linearly ordered keys.
 *
 *)

signature ORD_KEY =
  sig
    type ord_key

    val compare : ord_key * ord_key -> order

  end (* ORD_KEY *)


(* ord-map-sig.sml
 *
 * COPYRIGHT (c) 1996 by AT&T Research.  See COPYRIGHT file for details.
 *
 * Abstract signature of an applicative-style finite maps (dictionaries)
 * structure over ordered monomorphic keys.
 *)

signature ORD_MAP =
  sig

    structure Key : ORD_KEY

    type 'a map

    val empty : 'a map
	(* The empty map *)

    val insert  : 'a map * Key.ord_key * 'a -> 'a map
    val insert' : ((Key.ord_key * 'a) * 'a map) -> 'a map
	(* Insert an item. *)

    val find : 'a map * Key.ord_key -> 'a option
	(* Look for an item, return NONE if the item doesn't exist *)

    val remove : 'a map * Key.ord_key -> 'a map * 'a
	(* Remove an item, returning new map and value removed.
         * Raises LibBase.NotFound if not found.
	 *)

    val numItems : 'a map ->  int
	(* Return the number of items in the map *)

    val listItems  : 'a map -> 'a list
    val listItemsi : 'a map -> (Key.ord_key * 'a) list
	(* Return an ordered list of the items (and their keys) in the map.
         *)

    val collate : ('a * 'a -> order) -> ('a map * 'a map) -> order
	(* given an ordering on the map's range, return an ordering
	 * on the map.
	 *)

    val unionWith  : ('a * 'a -> 'a) -> ('a map * 'a map) -> 'a map
    val unionWithi : (Key.ord_key * 'a * 'a -> 'a) -> ('a map * 'a map) -> 'a map
	(* return a map whose domain is the union of the domains of the two input
	 * maps, using the supplied function to define the map on elements that
	 * are in both domains.
	 *)

    val intersectWith  : ('a * 'b -> 'c) -> ('a map * 'b map) -> 'c map
    val intersectWithi : (Key.ord_key * 'a * 'b -> 'c) -> ('a map * 'b map) -> 'c map
	(* return a map whose domain is the intersection of the domains of the
	 * two input maps, using the supplied function to define the range.
	 *)

    val app  : ('a -> unit) -> 'a map -> unit
    val appi : ((Key.ord_key * 'a) -> unit) -> 'a map -> unit
	(* Apply a function to the entries of the map in map order. *)

    val map  : ('a -> 'b) -> 'a map -> 'b map
    val mapi : (Key.ord_key * 'a -> 'b) -> 'a map -> 'b map
	(* Create a new map by applying a map function to the
         * name/value pairs in the map.
         *)

    val foldl  : ('a * 'b -> 'b) -> 'b -> 'a map -> 'b
    val foldli : (Key.ord_key * 'a * 'b -> 'b) -> 'b -> 'a map -> 'b
	(* Apply a folding function to the entries of the map
         * in increasing map order.
         *)

    val foldr  : ('a * 'b -> 'b) -> 'b -> 'a map -> 'b
    val foldri : (Key.ord_key * 'a * 'b -> 'b) -> 'b -> 'a map -> 'b
	(* Apply a folding function to the entries of the map
         * in decreasing map order.
         *)

    val filter  : ('a -> bool) -> 'a map -> 'a map
    val filteri : (Key.ord_key * 'a -> bool) -> 'a map -> 'a map
	(* Filter out those elements of the map that do not satisfy the
	 * predicate.  The filtering is done in increasing map order.
	 *)

    val mapPartial  : ('a -> 'b option) -> 'a map -> 'b map
    val mapPartiali : (Key.ord_key * 'a -> 'b option) -> 'a map -> 'b map
	(* map a partial function over the elements of a map in increasing
	 * map order.
	 *)

  end (* ORD_MAP *)


(* ordset-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Signature for a set of values with an order relation.
 *)

signature ORD_SET =
  sig

    structure Key : ORD_KEY

    type item = Key.ord_key
    type set

    val empty : set
	(* The empty set *)

    val singleton : item -> set
	(* Create a singleton set *)

    val add  : set * item -> set
    val add' : (item * set) -> set
	(* Insert an item. *)

    val addList : set * item list -> set
	(* Insert items from list. *)

    val delete : set * item -> set
	(* Remove an item. Raise NotFound if not found. *)

    val member : set * item -> bool
	(* Return true if and only if item is an element in the set *)

    val isEmpty : set -> bool
	(* Return true if and only if the set is empty *)

    val equal : (set * set) -> bool
	(* Return true if and only if the two sets are equal *)

    val compare : (set * set) -> order
	(* does a lexical comparison of two sets *)

    val isSubset : (set * set) -> bool
	(* Return true if and only if the first set is a subset of the second *)

    val numItems : set ->  int
	(* Return the number of items in the table *)

    val listItems : set -> item list
	(* Return an ordered list of the items in the set *)

    val union : set * set -> set
        (* Union *)

    val intersection : set * set -> set
        (* Intersection *)

    val difference : set * set -> set
        (* Difference *)

    val map : (item -> item) -> set -> set
	(* Create a new set by applying a map function to the elements
	 * of the set.
         *)
     
    val app : (item -> unit) -> set -> unit
	(* Apply a function to the entries of the set
         * in increasing order
         *)

    val foldl : (item * 'b -> 'b) -> 'b -> set -> 'b
	(* Apply a folding function to the entries of the set 
         * in increasing order
         *)

    val foldr : (item * 'b -> 'b) -> 'b -> set -> 'b
	(* Apply a folding function to the entries of the set 
         * in decreasing order
         *)

    val filter : (item -> bool) -> set -> set

    val exists : (item -> bool) -> set -> bool

    val find : (item -> bool) -> set -> item option

  end (* ORD_SET *)



(* splaytree.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Splay tree structure.
 *
 *)

structure SplayTree : SPLAY_TREE = 
  struct

    datatype 'a splay = 
      SplayObj of {
        value : 'a,
        right : 'a splay,
        left : 'a splay
      }
    | SplayNil

    datatype 'a ans_t = No | Eq of 'a | Lt of 'a | Gt of 'a

    fun splay (compf, root) = let
        fun adj SplayNil = (No,SplayNil,SplayNil)
          | adj (arg as SplayObj{value,left,right}) =
              (case compf value of
                EQUAL => (Eq value, left, right)
              | GREATER =>
                  (case left of
                    SplayNil => (Gt value,SplayNil,right)
                  | SplayObj{value=value',left=left',right=right'} =>
                      (case compf value' of
                        EQUAL => (Eq value',left',
                                    SplayObj{value=value,left=right',right=right})
                      | GREATER =>
                          (case left' of 
                            SplayNil => (Gt value',left',SplayObj{value=value,left=right',right=right})
                          | _ => 
                            let val (V,L,R) = adj left'
                                val rchild = SplayObj{value=value,left=right',right=right}
                            in
                              (V,L,SplayObj{value=value',left=R,right=rchild})
                            end
                          ) (* end case *)
                      | _ =>
                          (case right' of 
                            SplayNil => (Lt value',left',SplayObj{value=value,left=right',right=right})
                          | _ =>
                            let val (V,L,R) = adj right'
                                 val rchild = SplayObj{value=value,left=R,right=right}
                                 val lchild = SplayObj{value=value',left=left',right=L}
                            in
                              (V,lchild,rchild)
                            end
                          ) (* end case *)
                      ) (* end case *)
                  ) (* end case *)
              | _ =>
                 (case right of
                   SplayNil => (Lt value,left,SplayNil)
                 | SplayObj{value=value',left=left',right=right'} =>
                     (case compf value' of
                       EQUAL =>
                         (Eq value',SplayObj{value=value,left=left,right=left'},right')
                     | LESS =>
                         (case right' of
                           SplayNil => (Lt value',SplayObj{value=value,left=left,right=left'},right')
                         | _ =>
                           let val (V,L,R) = adj right'
                               val lchild = SplayObj{value=value,left=left,right=left'}
                           in
                             (V,SplayObj{value=value',left=lchild,right=L},R)
                           end
                         ) (* end case *)
                     | _ =>
                         (case left' of
                           SplayNil => (Gt value',SplayObj{value=value,left=left,right=left'},right')
                         | _ =>
                           let val (V,L,R) = adj left'
                               val rchild = SplayObj{value=value',left=R,right=right'}
                               val lchild = SplayObj{value=value,left=left,right=L}
                           in
                             (V,lchild,rchild)
                           end
                         ) (* end case *)
                     ) (* end case *)
                 ) (* end case *)
              ) (* end case *)
      in
        case adj root of
          (No,_,_) => (GREATER,SplayNil)
        | (Eq v,l,r) => (EQUAL,SplayObj{value=v,left=l,right=r})
        | (Lt v,l,r) => (LESS,SplayObj{value=v,left=l,right=r})
        | (Gt v,l,r) => (GREATER,SplayObj{value=v,left=l,right=r})
      end

    fun lrotate SplayNil = SplayNil
      | lrotate (arg as SplayObj{value,left,right=SplayNil}) = arg
      | lrotate (SplayObj{value,left,right=SplayObj{value=v,left=l,right=r}}) = 
          lrotate (SplayObj{value=v,left=SplayObj{value=value,left=left,right=l},right=r})

    fun join (SplayNil,SplayNil) = SplayNil
      | join (SplayNil,t) = t
      | join (t,SplayNil) = t
      | join (l,r) =
          case lrotate l of
            SplayNil => r      (* impossible as l is not SplayNil *)
          | SplayObj{value,left,right} => SplayObj{value=value,left=left,right=r}

  end (* SplayTree *)
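
(* ----------------

   Hypothetical usage sketch of SplayTree.splay (the names t, ord and t'
   are illustrative only). The comparison function reports how a tree
   element compares to the key being sought, here the int 2; splay moves
   the matching element to the root.

val t = SplayTree.SplayObj{value = 1, left = SplayTree.SplayNil,
          right = SplayTree.SplayObj{value = 2, left = SplayTree.SplayNil,
                                     right = SplayTree.SplayNil}};

(* ord = EQUAL; t' has 2 at its root and 1 as its left child *)
val (ord, t') = SplayTree.splay (fn v => Int.compare (v, 2), t);

   ---------------- *)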



(* splay-set-fn.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Functor implementing ordered sets using splay trees.
 *
 *)

functor SplaySetFn (K : ORD_KEY) : ORD_SET =
  struct
    structure Key = K
    open SplayTree

    type item = K.ord_key
  
    datatype set = 
        EMPTY
      | SET of {
        root : item splay ref,
        nobj : int
      }

    fun cmpf k = fn k' => K.compare(k',k)

    val empty = EMPTY

    fun singleton v = SET{root = ref(SplayObj{value=v,left=SplayNil,right=SplayNil}),nobj=1}
    
	(* Primitive insertion.
	 *)
    fun insert (v,(nobj,root)) =
          case splay (cmpf v, root) of
            (EQUAL,SplayObj{value,left,right}) => 
              (nobj,SplayObj{value=v,left=left,right=right})
          | (LESS,SplayObj{value,left,right}) => 
              (nobj+1,
               SplayObj{
                 value=v,
                 left=SplayObj{value=value,left=left,right=SplayNil},
                 right=right})
          | (GREATER,SplayObj{value,left,right}) => 
              (nobj+1,
               SplayObj{
                  value=v,
                  left=left,
                  right=SplayObj{value=value,left=SplayNil,right=right}})
          | (_,SplayNil) => (1,SplayObj{value=v,left=SplayNil,right=SplayNil})

	(* Add an item.  
	 *)
    fun add (EMPTY,v) = singleton v
      | add (SET{root,nobj},v) = let
          val (cnt,t) = insert(v,(nobj,!root))
          in
            SET{nobj=cnt,root=ref t}
          end
    fun add' (s, x) = add(x, s)

	(* Insert a list of items.
	 *)
    fun addList (set,[]) = set
      | addList (set,l) = let
          val arg = case set of EMPTY => (0,SplayNil) 
                              | SET{root,nobj} => (nobj,!root)
          val (cnt,t) = List.foldl insert arg l
          in
            SET{nobj=cnt,root=ref t}
          end

	(* Remove an item.
         * Raise LibBase.NotFound if not found
	 *)
    fun delete (EMPTY,_) = raise LibBase.NotFound
      | delete (SET{root,nobj},key) =
          case splay (cmpf key, !root) of
            (EQUAL,SplayObj{value,left,right}) => 
              if nobj = 1 then EMPTY
              else SET{root=ref(join(left,right)),nobj=nobj-1}
          | (_,r) => (root := r; raise LibBase.NotFound)

  (* return true if the item is in the set *)
    fun member (EMPTY, key) = false
      | member (SET{root,nobj}, key) = (case splay (cmpf key, !root)
           of (EQUAL, r) => (root := r; true)
            | (_, r) => (root := r; false)
	  (* end case *))

    fun isEmpty EMPTY = true
      | isEmpty _ = false

    local
      fun member (x,tree) = let
            fun mbr SplayNil = false
              | mbr (SplayObj{value,left,right}) =
                  case K.compare(x,value) of
                    LESS => mbr left
                  | GREATER => mbr right
                  | _ => true
          in mbr tree end

        (* true if every item in t is in t' *)
      fun treeIn (t,t') = let
            fun isIn SplayNil = true
              | isIn (SplayObj{value,left=SplayNil,right=SplayNil}) =
                  member(value, t')
              | isIn (SplayObj{value,left,right=SplayNil}) =
                  member(value, t') andalso isIn left
              | isIn (SplayObj{value,left=SplayNil,right}) =
                  member(value, t') andalso isIn right
              | isIn (SplayObj{value,left,right}) =
                  member(value, t') andalso isIn left andalso isIn right
            in
              isIn t
            end
    in
    fun equal (SET{root=rt,nobj=n},SET{root=rt',nobj=n'}) =
          (n=n') andalso treeIn (!rt,!rt')
      | equal (EMPTY, EMPTY) = true
      | equal _ = false

    fun isSubset (SET{root=rt,nobj=n},SET{root=rt',nobj=n'}) =
          (n<=n') andalso treeIn (!rt,!rt')
      | isSubset (EMPTY,_) = true
      | isSubset _ = false
    end

    local
      fun next ((t as SplayObj{right, ...})::rest) = (t, left(right, rest))
	| next _ = (SplayNil, [])
      and left (SplayNil, rest) = rest
	| left (t as SplayObj{left=l, ...}, rest) = left(l, t::rest)
    in
    fun compare (EMPTY, EMPTY) = EQUAL
      | compare (EMPTY, _) = LESS
      | compare (_, EMPTY) = GREATER
      | compare (SET{root=s1, ...}, SET{root=s2, ...}) = let
	  fun cmp (t1, t2) = (case (next t1, next t2)
		 of ((SplayNil, _), (SplayNil, _)) => EQUAL
		  | ((SplayNil, _), _) => LESS
		  | (_, (SplayNil, _)) => GREATER
		  | ((SplayObj{value=e1, ...}, r1), (SplayObj{value=e2, ...}, r2)) => (
		      case Key.compare(e1, e2)
		       of EQUAL => cmp (r1, r2)
			| order => order
		      (* end case *))
		(* end case *))
	  in
	    cmp (left(!s1, []), left(!s2, []))
	  end
    end (* local *)

	(* Return the number of items in the set *)
    fun numItems EMPTY = 0
      | numItems (SET{nobj,...}) = nobj

    fun listItems EMPTY = []
      | listItems (SET{root,...}) =
        let fun apply (SplayNil,l) = l
              | apply (SplayObj{value,left,right},l) =
                  apply(left, value::(apply (right,l)))
        in
          apply (!root,[])
        end

    fun split (value,s) =
          case splay(cmpf value, s) of
            (EQUAL,SplayObj{value,left,right}) => (SOME value, left, right)
          | (LESS,SplayObj{value,left,right}) => (NONE, SplayObj{value=value,left=left,right=SplayNil},right)
          | (GREATER,SplayObj{value,left,right}) => (NONE, left, SplayObj{value=value,right=right,left=SplayNil})
          | (_,SplayNil) => (NONE, SplayNil, SplayNil)

    fun intersection (EMPTY,_) = EMPTY
      | intersection (_,EMPTY) = EMPTY
      | intersection (SET{root,...},SET{root=root',...}) =
          let fun inter(SplayNil,_) = (SplayNil,0)
                | inter(_,SplayNil) = (SplayNil,0)
                | inter(s, SplayObj{value,left,right}) =
                    case split(value,s) of
                      (SOME v, l, r) =>
                        let val (l',lcnt) = inter(l,left)
                            val (r',rcnt) = inter(r,right)
                        in
                          (SplayObj{value=v,left=l',right=r'},lcnt+rcnt+1)
                        end
                    | (_,l,r) =>
                        let val (l',lcnt) = inter(l,left)
                            val (r',rcnt) = inter(r,right)
                        in
                          (join(l',r'),lcnt+rcnt)
                        end
          in
            case inter(!root,!root') of
              (_,0) => EMPTY
            | (root,cnt) => SET{root = ref root, nobj = cnt}
          end

    fun count st =
         let fun cnt(SplayNil,n) = n
               | cnt(SplayObj{left,right,...},n) = cnt(left,cnt(right,n+1))
         in
           cnt(st,0)
         end

    fun difference (EMPTY,_) = EMPTY
      | difference (s,EMPTY) = s
      | difference (SET{root,...}, SET{root=root',...}) =
          let fun diff(SplayNil,_) = (SplayNil,0)
                | diff(s,SplayNil) = (s, count s)
                | diff(s,SplayObj{value,right,left}) =
                    let val (_,l,r) = split(value,s)
                        val (l',lcnt) = diff(l,left)
                        val (r',rcnt) = diff(r,right)
                    in
                      (join(l',r'),lcnt+rcnt)
                    end
          in
            case diff(!root,!root') of
              (_,0) => EMPTY
            | (root,cnt) => SET{root = ref root, nobj = cnt}
          end

    fun union (EMPTY,s) = s
      | union (s,EMPTY) = s
      | union (SET{root,...}, SET{root=root',...}) =
          let fun uni(SplayNil,s) = (s,count s)
                | uni(s,SplayNil) = (s, count s)
                | uni(s,SplayObj{value,right,left}) =
                    let val (_,l,r) = split(value,s)
                        val (l',lcnt) = uni(l,left)
                        val (r',rcnt) = uni(r,right)
                    in
                      (SplayObj{value=value,right=r',left=l'},lcnt+rcnt+1)
                    end
              val (root,cnt) = uni(!root,!root')
          in
            SET{root = ref root, nobj = cnt}
          end

    fun map f EMPTY = EMPTY
      | map f (SET{root, ...}) = let
	  fun mapf (acc, SplayNil) = acc
	    | mapf (acc, SplayObj{value,left,right}) =
		mapf (add (mapf (acc, left), f value), right)
	  in
	    mapf (EMPTY, !root)
	  end

    fun app af EMPTY = ()
      | app af (SET{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) =
                    (apply left; af value; apply right)
          in apply (!root) end
(*
    fun revapp af (SET{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) = 
                    (apply right; af value; apply left)
          in apply (!root) end
*)
	(* Fold function *)
    fun foldr abf b EMPTY = b
      | foldr abf b (SET{root,...}) =
          let fun apply (SplayNil, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(left,abf(value,apply(right,b)))
        in
          apply (!root,b)
        end

    fun foldl abf b EMPTY = b
      | foldl abf b (SET{root,...}) =
          let fun apply (SplayNil, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(right,abf(value,apply(left,b)))
        in
          apply (!root,b)
        end

    fun filter p EMPTY = EMPTY
      | filter p (SET{root,...}) = let
          fun filt (SplayNil,tree) = tree
            | filt (SplayObj{value,left,right},tree) = let
                val t' = filt(right,filt(left,tree))
                in
                  if p value then insert(value,t')
                  else t'
                end
          in
            case filt(!root,(0,SplayNil)) of
              (0,_) => EMPTY
            | (cnt,t) => SET{nobj=cnt,root=ref t}
          end

    fun exists p EMPTY = false
      | exists p (SET{root,...}) = let
          fun ex SplayNil = false
            | ex (SplayObj{value=v,left=l,right=r}) =
                if p v then true
                else case ex l of
                       false => ex r
                     | _ => true 
          in
            ex (!root)
          end

    fun find p EMPTY = NONE
      | find p (SET{root,...}) = let
          fun ex SplayNil = NONE
            | ex (SplayObj{value=v,left=l,right=r}) =
                if p v then SOME v
                else case ex l of
                       NONE => ex r
                     | a => a 
          in
            ex (!root)
          end


  end (* SplaySetFn *)

(* splay-map-fn.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Functor implementing dictionaries using splay trees.
 *
 *)

functor SplayMapFn (K : ORD_KEY) : ORD_MAP =
  struct
    structure Key = K
    open SplayTree

    datatype 'a map = 
        EMPTY
      | MAP of {
        root : (K.ord_key * 'a) splay ref,
        nobj : int
      }

    fun cmpf k (k', _) = K.compare(k',k)

    val empty = EMPTY
    
	(* Insert an item.  
	 *)
    fun insert (EMPTY,key,v) =
          MAP{nobj=1,root=ref(SplayObj{value=(key,v),left=SplayNil,right=SplayNil})}
      | insert (MAP{root,nobj},key,v) =
          case splay (cmpf key, !root) of
            (EQUAL,SplayObj{value,left,right}) => 
              MAP{nobj=nobj,root=ref(SplayObj{value=(key,v),left=left,right=right})}
          | (LESS,SplayObj{value,left,right}) => 
              MAP{
                nobj=nobj+1,
                root=ref(SplayObj{value=(key,v),left=SplayObj{value=value,left=left,right=SplayNil},right=right})
              }
          | (GREATER,SplayObj{value,left,right}) => 
              MAP{
                nobj=nobj+1,
                root=ref(SplayObj{
                  value=(key,v),
                  left=left,
                  right=SplayObj{value=value,left=SplayNil,right=right}
                })
              }
          | (_,SplayNil) => raise LibBase.Impossible "SplayMapFn.insert SplayNil"
    fun insert' ((k, x), m) = insert(m, k, x)

  (* Look for an item, return NONE if the item doesn't exist *)
    fun find (EMPTY,_) = NONE
      | find (MAP{root,nobj},key) = (case splay (cmpf key, !root)
	   of (EQUAL, r as SplayObj{value,...}) => (root := r; SOME(#2 value))
	    | (_, r) => (root := r; NONE))

	(* Remove an item.
         * Raise LibBase.NotFound if not found
	 *)
    fun remove (EMPTY, _) = raise LibBase.NotFound
      | remove (MAP{root,nobj}, key) = (case (splay (cmpf key, !root))
	 of (EQUAL, SplayObj{value, left, right}) => 
	      if nobj = 1
		then (EMPTY, #2 value)
		else (MAP{root=ref(join(left,right)),nobj=nobj-1}, #2 value)
	    | (_,r) => (root := r; raise LibBase.NotFound)
	  (* end case *))

	(* Return the number of items in the table *)
    fun numItems EMPTY = 0
      | numItems (MAP{nobj,...}) = nobj

	(* Return a list of the items (and their keys) in the dictionary *)
    fun listItems EMPTY = []
      | listItems (MAP{root,...}) = let
	  fun apply (SplayNil, l) = l
            | apply (SplayObj{value=(_, v), left, right}, l) =
                apply(left, v::(apply (right,l)))
        in
          apply (!root, [])
        end
    fun listItemsi EMPTY = []
      | listItemsi (MAP{root,...}) = let
	  fun apply (SplayNil,l) = l
            | apply (SplayObj{value,left,right},l) =
                apply(left, value::(apply (right,l)))
        in
          apply (!root,[])
        end

    local
      fun next ((t as SplayObj{right, ...})::rest) = (t, left(right, rest))
	| next _ = (SplayNil, [])
      and left (SplayNil, rest) = rest
	| left (t as SplayObj{left=l, ...}, rest) = left(l, t::rest)
    in
    fun collate cmpRng (EMPTY, EMPTY) = EQUAL
      | collate cmpRng (EMPTY, _) = LESS
      | collate cmpRng (_, EMPTY) = GREATER
      | collate cmpRng (MAP{root=s1, ...}, MAP{root=s2, ...}) = let
	  fun cmp (t1, t2) = (case (next t1, next t2)
		 of ((SplayNil, _), (SplayNil, _)) => EQUAL
		  | ((SplayNil, _), _) => LESS
		  | (_, (SplayNil, _)) => GREATER
		  | ((SplayObj{value=(x1, y1), ...}, r1),
		     (SplayObj{value=(x2, y2), ...}, r2)
		    ) => (
		      case Key.compare(x1, x2)
		       of EQUAL => (case cmpRng (y1, y2)
			     of EQUAL => cmp (r1, r2)
			      | order => order
			    (* end case *))
			| order => order
		      (* end case *))
		(* end case *))
	  in
	    cmp (left(!s1, []), left(!s2, []))
	  end
    end (* local *)

	(* Apply a function to the entries of the dictionary *)
    fun appi af EMPTY = ()
      | appi af (MAP{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) = 
                    (apply left; af value; apply right)
        in
          apply (!root)
        end

    fun app af EMPTY = ()
      | app af (MAP{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value=(_,value),left,right}) = 
                    (apply left; af value; apply right)
        in
          apply (!root)
        end
(*
    fun revapp af (MAP{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) = 
                    (apply right; af value; apply left)
        in
          apply (!root)
        end
*)

	(* Fold function *)
    fun foldri (abf : K.ord_key * 'a * 'b -> 'b) b EMPTY = b
      | foldri (abf : K.ord_key * 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(left,abf(#1 value,#2 value,apply(right,b)))
        in
          apply (!root,b)
        end

    fun foldr (abf : 'a * 'b -> 'b) b EMPTY = b
      | foldr (abf : 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value=(_,value),left,right},b) =
                    apply(left,abf(value,apply(right,b)))
        in
          apply (!root,b)
        end

    fun foldli (abf : K.ord_key * 'a * 'b -> 'b) b EMPTY = b
      | foldli (abf : K.ord_key * 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(right,abf(#1 value,#2 value,apply(left,b)))
        in
          apply (!root,b)
        end

    fun foldl (abf : 'a * 'b -> 'b) b EMPTY = b
      | foldl (abf : 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value=(_,value),left,right},b) =
                    apply(right,abf(value,apply(left,b)))
        in
          apply (!root,b)
        end

	(* Map a table to a new table that has the same keys *)
    fun mapi (af : K.ord_key * 'a -> 'b) EMPTY = EMPTY
      | mapi (af : K.ord_key * 'a -> 'b) (MAP{root,nobj}) =
          let fun ap (SplayNil : (K.ord_key * 'a) splay) = SplayNil
                | ap (SplayObj{value,left,right}) = let
                    val left' = ap left
                    val value' = (#1 value, af value)
                    in
                      SplayObj{value = value', left = left', right = ap right}
                    end
        in
          MAP{root = ref(ap (!root)), nobj = nobj}
        end

    fun map (af : 'a -> 'b) EMPTY = EMPTY
      | map (af : 'a -> 'b) (MAP{root,nobj}) =
          let fun ap (SplayNil : (K.ord_key * 'a) splay) = SplayNil
                | ap (SplayObj{value,left,right}) = let
                    val left' = ap left
                    val value' = (#1 value, af (#2 value))
                    in
                      SplayObj{value = value', left = left', right = ap right}
                    end
        in
          MAP{root = ref(ap (!root)), nobj = nobj}
        end

(* the following are generic implementations of the unionWith and intersectWith
 * operations.  These should be specialized for the internal representations
 * at some point.
 *)
    fun unionWith f (m1, m2) = let
	  fun ins f (key, x, m) = (case find(m, key)
		 of NONE => insert(m, key, x)
		  | (SOME x') => insert(m, key, f(x, x'))
		(* end case *))
	  in
	    if (numItems m1 > numItems m2)
	      then foldli (ins (fn (a, b) => f(b, a))) m1 m2
	      else foldli (ins f) m2 m1
	  end
    fun unionWithi f (m1, m2) = let
	  fun ins f (key, x, m) = (case find(m, key)
		 of NONE => insert(m, key, x)
		  | (SOME x') => insert(m, key, f(key, x, x'))
		(* end case *))
	  in
	    if (numItems m1 > numItems m2)
	      then foldli (ins (fn (k, a, b) => f(k, b, a))) m1 m2
	      else foldli (ins f) m2 m1
	  end

    fun intersectWith f (m1, m2) = let
	(* iterate over the elements of m1, checking for membership in m2 *)
	  fun intersect f (m1, m2) = let
		fun ins (key, x, m) = (case find(m2, key)
		       of NONE => m
			| (SOME x') => insert(m, key, f(x, x'))
		      (* end case *))
		in
		  foldli ins empty m1
		end
	  in
	    if (numItems m1 > numItems m2)
	      then intersect f (m1, m2)
	      else intersect (fn (a, b) => f(b, a)) (m2, m1)
	  end

    fun intersectWithi f (m1, m2) = let
	(* iterate over the elements of m1, checking for membership in m2 *)
	  fun intersect f (m1, m2) = let
		fun ins (key, x, m) = (case find(m2, key)
		       of NONE => m
			| (SOME x') => insert(m, key, f(key, x, x'))
		      (* end case *))
		in
		  foldli ins empty m1
		end
	  in
	    if (numItems m1 > numItems m2)
	      then intersect f (m1, m2)
	      else intersect (fn (k, a, b) => f(k, b, a)) (m2, m1)
	  end

  (* this is a generic implementation of mapPartial.  It should
   * be specialized to the data-structure at some point.
   *)
    fun mapPartial f m = let
	  fun g (key, item, m) = (case f item
		 of NONE => m
		  | (SOME item') => insert(m, key, item')
		(* end case *))
	  in
	    foldli g empty m
	  end
    fun mapPartiali f m = let
	  fun g (key, item, m) = (case f(key, item)
		 of NONE => m
		  | (SOME item') => insert(m, key, item')
		(* end case *))
	  in
	    foldli g empty m
	  end

  (* this is a generic implementation of filter.  It should
   * be specialized to the data-structure at some point.
   *)
    fun filter predFn m = let
	  fun f (key, item, m) = if predFn item
		then insert(m, key, item)
		else m
	  in
	    foldli f empty m
	  end
    fun filteri predFn m = let
	  fun f (key, item, m) = if predFn(key, item)
		then insert(m, key, item)
		else m
	  in
	    foldli f empty m
	  end

  end (* SplayMapFn *)
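
(* ----------------

   Hypothetical usage sketch of SplayMapFn (IntMap and the value names
   are illustrative only). find may restructure the shared splay tree
   in place, but the contents of the map do not change.

structure IntMap = SplayMapFn(struct type ord_key = int; val compare = Int.compare end);

val m = IntMap.insert (IntMap.insert (IntMap.empty, 1, "one"), 2, "two");
val r = IntMap.find (m, 1);            (* SOME "one" *)
val (m', v) = IntMap.remove (m, 2);    (* v = "two" *)
val n = IntMap.numItems m';            (* 1 *)

   ---------------- *)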


(* ----------------

structure IntSet = SplaySetFn(struct type ord_key = int; val compare = Int.compare end);

val a = IntSet.singleton 1;
val b = IntSet.singleton 2;
val c = IntSet.singleton 1;

val s = IntSet.union (IntSet.union (a,b), c);
val _ = List.app (fn i => print (Int.toString i)) (IntSet.listItems s);

   ---------------- *)


(* IDs *)

(* The various root types share the id counter *)
val curRootId = ref 0;
fun getCurrentRootId () = !curRootId;
val _ = _export "getCurrentRootId" : (unit -> int) -> unit; getCurrentRootId;
fun getNextRoot () = ( curRootId := !curRootId+1 ; !curRootId);


(* For each combination of base type and compound constructor, we need a root type with operations: 
   get, reg, return, unreg *)
signature ROOT_TYPE = sig type t end

signature ROOT      = 
  sig 
    type t 
    exception RootNotFound;
    val get:    int ->  t      (* get root given root number *)
    val reg:    t   -> int     (* register and return root number *)
    val return: t   -> t       (* register and return root *)
    val unreg:  int -> unit 
  end


functor Root (R: ROOT_TYPE): ROOT =
struct
type t = R.t;
exception RootNotFound;

structure RootMap = SplayMapFn(struct 
                                 type ord_key = int; 
                                 val compare = Int.compare
                               end);

val roots = ref RootMap.empty;   

val reg = 
      fn r => ( roots := RootMap.insert (!roots, getNextRoot (), r)
              ; getCurrentRootId ());

val get =
      fn i => case RootMap.find (!roots, i)
              of NONE   => raise RootNotFound
	       | SOME r => r;

val return =
      fn r => ( roots := RootMap.insert (!roots, getNextRoot (), r)
              ; r);

(* TODO: Catch the exception (LibBase.NotFound) raised when unreg is
         passed an already unregistered root *)

val unreg =
      fn i => roots := (case RootMap.remove (!roots, i)
                        of (rs,_) => rs);
end
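
(* ----------------

   Hypothetical usage sketch of the Root functor (IntRoot and the value
   names are illustrative only). Both reg and return advance the shared
   counter curRootId, so root numbers are unique across all root types.

structure IntRoot = Root(struct type t = int end);

val id = IntRoot.reg 42;    (* register 42; id is its new root number *)
val x  = IntRoot.get id;    (* 42 *)
val _  = IntRoot.unreg id;  (* a later IntRoot.get id raises IntRoot.RootNotFound *)

   ---------------- *)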




-------------- next part --------------
(* lib-base-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *)

signature LIB_BASE =
  sig

    exception Unimplemented of string
	(* raised to report unimplemented features *)
    exception Impossible of string
	(* raised to report internal errors *)

    exception NotFound
	(* raised by searching operations *)

    val failure : {module : string, func : string, msg : string} -> 'a
	(* raise the exception Fail with a standard format message. *)

    val version : {date : string, system : string, version_id : int list}
    val banner : string

  end (* LIB_BASE *)


(* lib-base.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *)

structure LibBase : LIB_BASE =
  struct

  (* raised to report unimplemented features *)
    exception Unimplemented of string

  (* raised to report internal errors *)
    exception Impossible of string

  (* raised by searching operations *)
    exception NotFound

  (* raise the exception Fail with a standard format message. *)
    fun failure {module, func, msg} =
	  raise (Fail(concat[module, ".", func, ": ", msg]))

    val version = {
	    date = "June 1, 1996", 
	    system = "SML/NJ Library",
	    version_id = [1, 0]
	  }

    fun f ([], l) = l
      | f ([x : int], l) = (Int.toString x)::l
      | f (x::r, l) = (Int.toString x) :: "." :: f(r, l)

    val banner = concat (
	    #system version :: ", Version " ::
	    f (#version_id version, [", ", #date version]))

  end (* LibBase *)


(* splaytree-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Signature for a splay tree data structure.
 *
 *)

signature SPLAY_TREE = 
  sig
    datatype 'a splay = 
      SplayObj of {
        value : 'a,
        right : 'a splay,
        left : 'a splay
      }
    | SplayNil


    val splay : (('a -> order) * 'a splay) -> (order * 'a splay)
      (* (r,tree') = splay (cmp,tree) 
       * where tree' is tree adjusted using the comparison function cmp
       * and, if tree' = SplayObj{value,...}, r = cmp value.
       * tree' = SplayNil iff tree = SplayNil, in which case r is undefined.
       *)

    val join : 'a splay * 'a splay -> 'a splay
      (* join(t,t') returns a new splay tree formed of t and t';
       * every element of t must precede every element of t'.
       *)

  end (* SPLAY_TREE *)


(* ord-key-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Abstract linearly ordered keys.
 *
 *)

signature ORD_KEY =
  sig
    type ord_key

    val compare : ord_key * ord_key -> order

  end (* ORD_KEY *)


(* ord-map-sig.sml
 *
 * COPYRIGHT (c) 1996 by AT&T Research.  See COPYRIGHT file for details.
 *
 * Abstract signature of an applicative-style finite map (dictionary)
 * structure over ordered monomorphic keys.
 *)

signature ORD_MAP =
  sig

    structure Key : ORD_KEY

    type 'a map

    val empty : 'a map
	(* The empty map *)

    val insert  : 'a map * Key.ord_key * 'a -> 'a map
    val insert' : ((Key.ord_key * 'a) * 'a map) -> 'a map
	(* Insert an item. *)

    val find : 'a map * Key.ord_key -> 'a option
	(* Look for an item, return NONE if the item doesn't exist *)

    val remove : 'a map * Key.ord_key -> 'a map * 'a
	(* Remove an item, returning new map and value removed.
         * Raises LibBase.NotFound if not found.
	 *)

    val numItems : 'a map ->  int
	(* Return the number of items in the map *)

    val listItems  : 'a map -> 'a list
    val listItemsi : 'a map -> (Key.ord_key * 'a) list
	(* Return an ordered list of the items (and their keys) in the map.
         *)

    val collate : ('a * 'a -> order) -> ('a map * 'a map) -> order
	(* given an ordering on the map's range, return an ordering
	 * on the map.
	 *)

    val unionWith  : ('a * 'a -> 'a) -> ('a map * 'a map) -> 'a map
    val unionWithi : (Key.ord_key * 'a * 'a -> 'a) -> ('a map * 'a map) -> 'a map
	(* return a map whose domain is the union of the domains of the two input
	 * maps, using the supplied function to define the map on elements that
	 * are in both domains.
	 *)

    val intersectWith  : ('a * 'b -> 'c) -> ('a map * 'b map) -> 'c map
    val intersectWithi : (Key.ord_key * 'a * 'b -> 'c) -> ('a map * 'b map) -> 'c map
	(* return a map whose domain is the intersection of the domains of the
	 * two input maps, using the supplied function to define the range.
	 *)

    val app  : ('a -> unit) -> 'a map -> unit
    val appi : ((Key.ord_key * 'a) -> unit) -> 'a map -> unit
	(* Apply a function to the entries of the map in map order. *)

    val map  : ('a -> 'b) -> 'a map -> 'b map
    val mapi : (Key.ord_key * 'a -> 'b) -> 'a map -> 'b map
	(* Create a new map by applying a map function to the values (map)
         * or to the key/value pairs (mapi) in the map.
         *)

    val foldl  : ('a * 'b -> 'b) -> 'b -> 'a map -> 'b
    val foldli : (Key.ord_key * 'a * 'b -> 'b) -> 'b -> 'a map -> 'b
	(* Apply a folding function to the entries of the map
         * in increasing map order.
         *)

    val foldr  : ('a * 'b -> 'b) -> 'b -> 'a map -> 'b
    val foldri : (Key.ord_key * 'a * 'b -> 'b) -> 'b -> 'a map -> 'b
	(* Apply a folding function to the entries of the map
         * in decreasing map order.
         *)

    val filter  : ('a -> bool) -> 'a map -> 'a map
    val filteri : (Key.ord_key * 'a -> bool) -> 'a map -> 'a map
	(* Filter out those elements of the map that do not satisfy the
	 * predicate.  The filtering is done in increasing map order.
	 *)

    val mapPartial  : ('a -> 'b option) -> 'a map -> 'b map
    val mapPartiali : (Key.ord_key * 'a -> 'b option) -> 'a map -> 'b map
	(* map a partial function over the elements of a map in increasing
	 * map order.
	 *)

  end (* ORD_MAP *)


(* ordset-sig.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Signature for a set of values with an order relation.
 *)

signature ORD_SET =
  sig

    structure Key : ORD_KEY

    type item = Key.ord_key
    type set

    val empty : set
	(* The empty set *)

    val singleton : item -> set
	(* Create a singleton set *)

    val add  : set * item -> set
    val add' : (item * set) -> set
	(* Insert an item. *)

    val addList : set * item list -> set
	(* Insert items from list. *)

    val delete : set * item -> set
	(* Remove an item. Raise LibBase.NotFound if not found. *)

    val member : set * item -> bool
	(* Return true if and only if item is an element in the set *)

    val isEmpty : set -> bool
	(* Return true if and only if the set is empty *)

    val equal : (set * set) -> bool
	(* Return true if and only if the two sets are equal *)

    val compare : (set * set) -> order
	(* Lexicographic comparison of two sets, element by element in increasing order *)

    val isSubset : (set * set) -> bool
	(* Return true if and only if the first set is a subset of the second *)

    val numItems : set ->  int
	(* Return the number of items in the set *)

    val listItems : set -> item list
	(* Return an ordered list of the items in the set *)

    val union : set * set -> set
        (* Union *)

    val intersection : set * set -> set
        (* Intersection *)

    val difference : set * set -> set
        (* Difference *)

    val map : (item -> item) -> set -> set
	(* Create a new set by applying a map function to the elements
	 * of the set.
         *)
     
    val app : (item -> unit) -> set -> unit
	(* Apply a function to the entries of the set
         * in increasing order
         *)

    val foldl : (item * 'b -> 'b) -> 'b -> set -> 'b
	(* Apply a folding function to the entries of the set 
         * in increasing order
         *)

    val foldr : (item * 'b -> 'b) -> 'b -> set -> 'b
	(* Apply a folding function to the entries of the set 
         * in decreasing order
         *)

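    val filter : (item -> bool) -> set -> set
	(* Return the set of items that satisfy the predicate *)

    val exists : (item -> bool) -> set -> bool
	(* Return true if and only if some item satisfies the predicate *)

    val find : (item -> bool) -> set -> item option
	(* Return SOME item that satisfies the predicate, or NONE if there is none *)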

  end (* ORD_SET *)



(* splaytree.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Splay tree structure.
 *
 *)

structure SplayTree : SPLAY_TREE = 
  struct

    datatype 'a splay = 
      SplayObj of {
        value : 'a,
        right : 'a splay,
        left : 'a splay
      }
    | SplayNil

    datatype 'a ans_t = No | Eq of 'a | Lt of 'a | Gt of 'a

    fun splay (compf, root) = let
        fun adj SplayNil = (No,SplayNil,SplayNil)
          | adj (arg as SplayObj{value,left,right}) =
              (case compf value of
                EQUAL => (Eq value, left, right)
              | GREATER =>
                  (case left of
                    SplayNil => (Gt value,SplayNil,right)
                  | SplayObj{value=value',left=left',right=right'} =>
                      (case compf value' of
                        EQUAL => (Eq value',left',
                                    SplayObj{value=value,left=right',right=right})
                      | GREATER =>
                          (case left' of 
                            SplayNil => (Gt value',left',SplayObj{value=value,left=right',right=right})
                          | _ => 
                            let val (V,L,R) = adj left'
                                val rchild = SplayObj{value=value,left=right',right=right}
                            in
                              (V,L,SplayObj{value=value',left=R,right=rchild})
                            end
                          ) (* end case *)
                      | _ =>
                          (case right' of 
                            SplayNil => (Lt value',left',SplayObj{value=value,left=right',right=right})
                          | _ =>
                            let val (V,L,R) = adj right'
                                 val rchild = SplayObj{value=value,left=R,right=right}
                                 val lchild = SplayObj{value=value',left=left',right=L}
                            in
                              (V,lchild,rchild)
                            end
                          ) (* end case *)
                      ) (* end case *)
                  ) (* end case *)
              | _ =>
                 (case right of
                   SplayNil => (Lt value,left,SplayNil)
                 | SplayObj{value=value',left=left',right=right'} =>
                     (case compf value' of
                       EQUAL =>
                         (Eq value',SplayObj{value=value,left=left,right=left'},right')
                     | LESS =>
                         (case right' of
                           SplayNil => (Lt value',SplayObj{value=value,left=left,right=left'},right')
                         | _ =>
                           let val (V,L,R) = adj right'
                               val lchild = SplayObj{value=value,left=left,right=left'}
                           in
                             (V,SplayObj{value=value',left=lchild,right=L},R)
                           end
                         ) (* end case *)
                     | _ =>
                         (case left' of
                           SplayNil => (Gt value',SplayObj{value=value,left=left,right=left'},right')
                         | _ =>
                           let val (V,L,R) = adj left'
                               val rchild = SplayObj{value=value',left=R,right=right'}
                               val lchild = SplayObj{value=value,left=left,right=L}
                           in
                             (V,lchild,rchild)
                           end
                         ) (* end case *)
                     ) (* end case *)
                 ) (* end case *)
              ) (* end case *)
      in
        case adj root of
          (No,_,_) => (GREATER,SplayNil)
        | (Eq v,l,r) => (EQUAL,SplayObj{value=v,left=l,right=r})
        | (Lt v,l,r) => (LESS,SplayObj{value=v,left=l,right=r})
        | (Gt v,l,r) => (GREATER,SplayObj{value=v,left=l,right=r})
      end

    fun lrotate SplayNil = SplayNil
      | lrotate (arg as SplayObj{value,left,right=SplayNil}) = arg
      | lrotate (SplayObj{value,left,right=SplayObj{value=v,left=l,right=r}}) = 
          lrotate (SplayObj{value=v,left=SplayObj{value=value,left=left,right=l},right=r})

    fun join (SplayNil,SplayNil) = SplayNil
      | join (SplayNil,t) = t
      | join (t,SplayNil) = t
      | join (l,r) =
          case lrotate l of
            SplayNil => r      (* impossible as l is not SplayNil *)
          | SplayObj{value,left,right} => SplayObj{value=value,left=left,right=r}

  end (* SplayTree *)
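
(* ----------------

   Example (a sketch): what 'splay' returns.  compf is applied to a node's
   value and reports how that value compares with the key being looked up
   (as cmpf does in SplaySetFn below).  splay (compf, t) moves the node
   where the search ended to the root and returns
     (EQUAL, t')    if the key was found; it is now the root of t',
     (LESS, t')     if not found and the new root is smaller than the key,
     (GREATER, t')  if not found and the new root is larger than the key,
                    or (GREATER, SplayNil) if t was SplayNil.
   The helper name leaf is hypothetical.

local
  open SplayTree
  fun leaf v = SplayObj{value=v, left=SplayNil, right=SplayNil}
in
  val t = SplayObj{value=2, left=leaf 1, right=leaf 3}
  val (ord1, t1) = splay (fn v => Int.compare (v, 3), t)   (* EQUAL; 3 is the root of t1 *)
  val (ord2, t2) = splay (fn v => Int.compare (v, 4), t)   (* LESS; 3 is the root of t2 *)
end

   ---------------- *)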



(* splay-set-fn.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Functor implementing ordered sets using splay trees.
 *
 *)

functor SplaySetFn (K : ORD_KEY) : ORD_SET =
  struct
    structure Key = K
    open SplayTree

    type item = K.ord_key
  
    datatype set = 
        EMPTY
      | SET of {
        root : item splay ref,
        nobj : int
      }

    fun cmpf k = fn k' => K.compare(k',k)

    val empty = EMPTY

    fun singleton v = SET{root = ref(SplayObj{value=v,left=SplayNil,right=SplayNil}),nobj=1}
    
	(* Primitive insertion.
	 *)
    fun insert (v,(nobj,root)) =
          case splay (cmpf v, root) of
            (EQUAL,SplayObj{value,left,right}) => 
              (nobj,SplayObj{value=v,left=left,right=right})
          | (LESS,SplayObj{value,left,right}) => 
              (nobj+1,
               SplayObj{
                 value=v,
                 left=SplayObj{value=value,left=left,right=SplayNil},
                 right=right})
          | (GREATER,SplayObj{value,left,right}) => 
              (nobj+1,
               SplayObj{
                  value=v,
                  left=left,
                  right=SplayObj{value=value,left=SplayNil,right=right}})
          | (_,SplayNil) => (1,SplayObj{value=v,left=SplayNil,right=SplayNil})

	(* Add an item.  
	 *)
    fun add (EMPTY,v) = singleton v
      | add (SET{root,nobj},v) = let
          val (cnt,t) = insert(v,(nobj,!root))
          in
            SET{nobj=cnt,root=ref t}
          end
    fun add' (s, x) = add(x, s)

	(* Insert a list of items.
	 *)
    fun addList (set,[]) = set
      | addList (set,l) = let
          val arg = case set of EMPTY => (0,SplayNil) 
                              | SET{root,nobj} => (nobj,!root)
          val (cnt,t) = List.foldl insert arg l
          in
            SET{nobj=cnt,root=ref t}
          end

	(* Remove an item.
         * Raise LibBase.NotFound if not found
	 *)
    fun delete (EMPTY,_) = raise LibBase.NotFound
      | delete (SET{root,nobj},key) =
          case splay (cmpf key, !root) of
            (EQUAL,SplayObj{value,left,right}) => 
              if nobj = 1 then EMPTY
              else SET{root=ref(join(left,right)),nobj=nobj-1}
          | (_,r) => (root := r; raise LibBase.NotFound)

  (* return true if the item is in the set *)
    fun member (EMPTY, key) = false
      | member (SET{root,nobj}, key) = (case splay (cmpf key, !root)
           of (EQUAL, r) => (root := r; true)
            | (_, r) => (root := r; false)
	  (* end case *))

    fun isEmpty EMPTY = true
      | isEmpty _ = false

    local
      fun member (x,tree) = let
            fun mbr SplayNil = false
              | mbr (SplayObj{value,left,right}) =
                  case K.compare(x,value) of
                    LESS => mbr left
                  | GREATER => mbr right
                  | _ => true
          in mbr tree end

        (* true if every item in t is in t' *)
      fun treeIn (t,t') = let
            fun isIn SplayNil = true
              | isIn (SplayObj{value,left=SplayNil,right=SplayNil}) =
                  member(value, t')
              | isIn (SplayObj{value,left,right=SplayNil}) =
                  member(value, t') andalso isIn left
              | isIn (SplayObj{value,left=SplayNil,right}) =
                  member(value, t') andalso isIn right
              | isIn (SplayObj{value,left,right}) =
                  member(value, t') andalso isIn left andalso isIn right
            in
              isIn t
            end
    in
    fun equal (SET{root=rt,nobj=n},SET{root=rt',nobj=n'}) =
          (n=n') andalso treeIn (!rt,!rt')
      | equal (EMPTY, EMPTY) = true
      | equal _ = false

    fun isSubset (SET{root=rt,nobj=n},SET{root=rt',nobj=n'}) =
          (n<=n') andalso treeIn (!rt,!rt')
      | isSubset (EMPTY,_) = true
      | isSubset _ = false
    end

    local
      fun next ((t as SplayObj{right, ...})::rest) = (t, left(right, rest))
	| next _ = (SplayNil, [])
      and left (SplayNil, rest) = rest
	| left (t as SplayObj{left=l, ...}, rest) = left(l, t::rest)
    in
    fun compare (EMPTY, EMPTY) = EQUAL
      | compare (EMPTY, _) = LESS
      | compare (_, EMPTY) = GREATER
      | compare (SET{root=s1, ...}, SET{root=s2, ...}) = let
	  fun cmp (t1, t2) = (case (next t1, next t2)
		 of ((SplayNil, _), (SplayNil, _)) => EQUAL
		  | ((SplayNil, _), _) => LESS
		  | (_, (SplayNil, _)) => GREATER
		  | ((SplayObj{value=e1, ...}, r1), (SplayObj{value=e2, ...}, r2)) => (
		      case Key.compare(e1, e2)
		       of EQUAL => cmp (r1, r2)
			| order => order
		      (* end case *))
		(* end case *))
	  in
	    cmp (left(!s1, []), left(!s2, []))
	  end
    end (* local *)

	(* Return the number of items in the set *)
    fun numItems EMPTY = 0
      | numItems (SET{nobj,...}) = nobj

    fun listItems EMPTY = []
      | listItems (SET{root,...}) =
        let fun apply (SplayNil,l) = l
              | apply (SplayObj{value,left,right},l) =
                  apply(left, value::(apply (right,l)))
        in
          apply (!root,[])
        end

    fun split (value,s) =
          case splay(cmpf value, s) of
            (EQUAL,SplayObj{value,left,right}) => (SOME value, left, right)
          | (LESS,SplayObj{value,left,right}) => (NONE, SplayObj{value=value,left=left,right=SplayNil},right)
          | (GREATER,SplayObj{value,left,right}) => (NONE, left, SplayObj{value=value,right=right,left=SplayNil})
          | (_,SplayNil) => (NONE, SplayNil, SplayNil)

    fun intersection (EMPTY,_) = EMPTY
      | intersection (_,EMPTY) = EMPTY
      | intersection (SET{root,...},SET{root=root',...}) =
          let fun inter(SplayNil,_) = (SplayNil,0)
                | inter(_,SplayNil) = (SplayNil,0)
                | inter(s, SplayObj{value,left,right}) =
                    case split(value,s) of
                      (SOME v, l, r) =>
                        let val (l',lcnt) = inter(l,left)
                            val (r',rcnt) = inter(r,right)
                        in
                          (SplayObj{value=v,left=l',right=r'},lcnt+rcnt+1)
                        end
                    | (_,l,r) =>
                        let val (l',lcnt) = inter(l,left)
                            val (r',rcnt) = inter(r,right)
                        in
                          (join(l',r'),lcnt+rcnt)
                        end
          in
            case inter(!root,!root') of
              (_,0) => EMPTY
            | (root,cnt) => SET{root = ref root, nobj = cnt}
          end

    fun count st =
         let fun cnt(SplayNil,n) = n
               | cnt(SplayObj{left,right,...},n) = cnt(left,cnt(right,n+1))
         in
           cnt(st,0)
         end

    fun difference (EMPTY,_) = EMPTY
      | difference (s,EMPTY) = s
      | difference (SET{root,...}, SET{root=root',...}) =
          let fun diff(SplayNil,_) = (SplayNil,0)
                | diff(s,SplayNil) = (s, count s)
                | diff(s,SplayObj{value,right,left}) =
                    let val (_,l,r) = split(value,s)
                        val (l',lcnt) = diff(l,left)
                        val (r',rcnt) = diff(r,right)
                    in
                      (join(l',r'),lcnt+rcnt)
                    end
          in
            case diff(!root,!root') of
              (_,0) => EMPTY
            | (root,cnt) => SET{root = ref root, nobj = cnt}
          end

    fun union (EMPTY,s) = s
      | union (s,EMPTY) = s
      | union (SET{root,...}, SET{root=root',...}) =
          let fun uni(SplayNil,s) = (s,count s)
                | uni(s,SplayNil) = (s, count s)
                | uni(s,SplayObj{value,right,left}) =
                    let val (_,l,r) = split(value,s)
                        val (l',lcnt) = uni(l,left)
                        val (r',rcnt) = uni(r,right)
                    in
                      (SplayObj{value=value,right=r',left=l'},lcnt+rcnt+1)
                    end
              val (root,cnt) = uni(!root,!root')
          in
            SET{root = ref root, nobj = cnt}
          end

    fun map f EMPTY = EMPTY
      | map f (SET{root, ...}) = let
	  fun mapf (acc, SplayNil) = acc
	    | mapf (acc, SplayObj{value,left,right}) =
		mapf (add (mapf (acc, left), f value), right)
	  in
	    mapf (EMPTY, !root)
	  end

    fun app af EMPTY = ()
      | app af (SET{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) =
                    (apply left; af value; apply right)
          in apply (!root) end
(*
    fun revapp af (SET{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) = 
                    (apply right; af value; apply left)
          in apply (!root) end
*)
	(* Fold function *)
    fun foldr abf b EMPTY = b
      | foldr abf b (SET{root,...}) =
          let fun apply (SplayNil, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(left,abf(value,apply(right,b)))
        in
          apply (!root,b)
        end

    fun foldl abf b EMPTY = b
      | foldl abf b (SET{root,...}) =
          let fun apply (SplayNil, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(right,abf(value,apply(left,b)))
        in
          apply (!root,b)
        end

    fun filter p EMPTY = EMPTY
      | filter p (SET{root,...}) = let
          fun filt (SplayNil,tree) = tree
            | filt (SplayObj{value,left,right},tree) = let
                val t' = filt(right,filt(left,tree))
                in
                  if p value then insert(value,t')
                  else t'
                end
          in
            case filt(!root,(0,SplayNil)) of
              (0,_) => EMPTY
            | (cnt,t) => SET{nobj=cnt,root=ref t}
          end

    fun exists p EMPTY = false
      | exists p (SET{root,...}) = let
          fun ex SplayNil = false
            | ex (SplayObj{value=v,left=l,right=r}) =
                if p v then true
                else case ex l of
                       false => ex r
                     | _ => true 
          in
            ex (!root)
          end

    fun find p EMPTY = NONE
      | find p (SET{root,...}) = let
          fun ex SplayNil = NONE
            | ex (SplayObj{value=v,left=l,right=r}) =
                if p v then SOME v
                else case ex l of
                       NONE => ex r
                     | a => a 
          in
            ex (!root)
          end


  end (* SplaySetFn *)
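
(* ----------------

   Example (a sketch): the set operations of SplaySetFn on small integer
   sets.  The structure name IntSet is hypothetical.  listItems returns
   the elements in increasing order.

structure IntSet = SplaySetFn(struct type ord_key = int; val compare = Int.compare end)

val s1 = IntSet.addList (IntSet.empty, [3, 1, 2])
val s2 = IntSet.addList (IntSet.empty, [2, 3, 4])

val both  = IntSet.listItems (IntSet.intersection (s1, s2))   (* [2, 3] *)
val onlyA = IntSet.listItems (IntSet.difference (s1, s2))     (* [1] *)
val sub   = IntSet.isSubset (IntSet.singleton 2, s1)          (* true *)

   ---------------- *)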

(* splay-map-fn.sml
 *
 * COPYRIGHT (c) 1993 by AT&T Bell Laboratories.  See COPYRIGHT file for details.
 *
 * Functor implementing dictionaries using splay trees.
 *
 *)

functor SplayMapFn (K : ORD_KEY) : ORD_MAP =
  struct
    structure Key = K
    open SplayTree

    datatype 'a map = 
        EMPTY
      | MAP of {
        root : (K.ord_key * 'a) splay ref,
        nobj : int
      }

    fun cmpf k (k', _) = K.compare(k',k)

    val empty = EMPTY
    
	(* Insert an item.  
	 *)
    fun insert (EMPTY,key,v) =
          MAP{nobj=1,root=ref(SplayObj{value=(key,v),left=SplayNil,right=SplayNil})}
      | insert (MAP{root,nobj},key,v) =
          case splay (cmpf key, !root) of
            (EQUAL,SplayObj{value,left,right}) => 
              MAP{nobj=nobj,root=ref(SplayObj{value=(key,v),left=left,right=right})}
          | (LESS,SplayObj{value,left,right}) => 
              MAP{
                nobj=nobj+1,
                root=ref(SplayObj{value=(key,v),left=SplayObj{value=value,left=left,right=SplayNil},right=right})
              }
          | (GREATER,SplayObj{value,left,right}) => 
              MAP{
                nobj=nobj+1,
                root=ref(SplayObj{
                  value=(key,v),
                  left=left,
                  right=SplayObj{value=value,left=SplayNil,right=right}
                })
              }
          | (_,SplayNil) => raise LibBase.Impossible "SplayMapFn.insert SplayNil"
    fun insert' ((k, x), m) = insert(m, k, x)

  (* Look for an item, return NONE if the item doesn't exist *)
    fun find (EMPTY,_) = NONE
      | find (MAP{root,nobj},key) = (case splay (cmpf key, !root)
	   of (EQUAL, r as SplayObj{value,...}) => (root := r; SOME(#2 value))
	    | (_, r) => (root := r; NONE))

	(* Remove an item.
         * Raise LibBase.NotFound if not found
	 *)
    fun remove (EMPTY, _) = raise LibBase.NotFound
      | remove (MAP{root,nobj}, key) = (case (splay (cmpf key, !root))
	 of (EQUAL, SplayObj{value, left, right}) => 
	      if nobj = 1
		then (EMPTY, #2 value)
		else (MAP{root=ref(join(left,right)),nobj=nobj-1}, #2 value)
	    | (_,r) => (root := r; raise LibBase.NotFound)
	  (* end case *))

	(* Return the number of items in the table *)
    fun numItems EMPTY = 0
      | numItems (MAP{nobj,...}) = nobj

	(* Return a list of the items (and their keys) in the dictionary *)
    fun listItems EMPTY = []
      | listItems (MAP{root,...}) = let
	  fun apply (SplayNil, l) = l
            | apply (SplayObj{value=(_, v), left, right}, l) =
                apply(left, v::(apply (right,l)))
        in
          apply (!root, [])
        end
    fun listItemsi EMPTY = []
      | listItemsi (MAP{root,...}) = let
	  fun apply (SplayNil,l) = l
            | apply (SplayObj{value,left,right},l) =
                apply(left, value::(apply (right,l)))
        in
          apply (!root,[])
        end

    local
      fun next ((t as SplayObj{right, ...})::rest) = (t, left(right, rest))
	| next _ = (SplayNil, [])
      and left (SplayNil, rest) = rest
	| left (t as SplayObj{left=l, ...}, rest) = left(l, t::rest)
    in
    fun collate cmpRng (EMPTY, EMPTY) = EQUAL
      | collate cmpRng (EMPTY, _) = LESS
      | collate cmpRng (_, EMPTY) = GREATER
      | collate cmpRng (MAP{root=s1, ...}, MAP{root=s2, ...}) = let
	  fun cmp (t1, t2) = (case (next t1, next t2)
		 of ((SplayNil, _), (SplayNil, _)) => EQUAL
		  | ((SplayNil, _), _) => LESS
		  | (_, (SplayNil, _)) => GREATER
		  | ((SplayObj{value=(x1, y1), ...}, r1),
		     (SplayObj{value=(x2, y2), ...}, r2)
		    ) => (
		      case Key.compare(x1, x2)
		       of EQUAL => (case cmpRng (y1, y2)
			     of EQUAL => cmp (r1, r2)
			      | order => order
			    (* end case *))
			| order => order
		      (* end case *))
		(* end case *))
	  in
	    cmp (left(!s1, []), left(!s2, []))
	  end
    end (* local *)

	(* Apply a function to the entries of the dictionary *)
    fun appi af EMPTY = ()
      | appi af (MAP{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) = 
                    (apply left; af value; apply right)
        in
          apply (!root)
        end

    fun app af EMPTY = ()
      | app af (MAP{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value=(_,value),left,right}) = 
                    (apply left; af value; apply right)
        in
          apply (!root)
        end
(*
    fun revapp af (MAP{root,...}) =
          let fun apply SplayNil = ()
                | apply (SplayObj{value,left,right}) = 
                    (apply right; af value; apply left)
        in
          apply (!root)
        end
*)

	(* Fold function *)
    fun foldri (abf : K.ord_key * 'a * 'b -> 'b) b EMPTY = b
      | foldri (abf : K.ord_key * 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(left,abf(#1 value,#2 value,apply(right,b)))
        in
          apply (!root,b)
        end

    fun foldr (abf : 'a * 'b -> 'b) b EMPTY = b
      | foldr (abf : 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value=(_,value),left,right},b) =
                    apply(left,abf(value,apply(right,b)))
        in
          apply (!root,b)
        end

    fun foldli (abf : K.ord_key * 'a * 'b -> 'b) b EMPTY = b
      | foldli (abf : K.ord_key * 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value,left,right},b) =
                    apply(right,abf(#1 value,#2 value,apply(left,b)))
        in
          apply (!root,b)
        end

    fun foldl (abf : 'a * 'b -> 'b) b EMPTY = b
      | foldl (abf : 'a * 'b -> 'b) b (MAP{root,...}) =
          let fun apply (SplayNil : (K.ord_key * 'a) splay, b) = b
                | apply (SplayObj{value=(_,value),left,right},b) =
                    apply(right,abf(value,apply(left,b)))
        in
          apply (!root,b)
        end

	(* Map a table to a new table that has the same keys *)
    fun mapi (af : K.ord_key * 'a -> 'b) EMPTY = EMPTY
      | mapi (af : K.ord_key * 'a -> 'b) (MAP{root,nobj}) =
          let fun ap (SplayNil : (K.ord_key * 'a) splay) = SplayNil
                | ap (SplayObj{value,left,right}) = let
                    val left' = ap left
                    val value' = (#1 value, af value)
                    in
                      SplayObj{value = value', left = left', right = ap right}
                    end
        in
          MAP{root = ref(ap (!root)), nobj = nobj}
        end

    fun map (af : 'a -> 'b) EMPTY = EMPTY
      | map (af : 'a -> 'b) (MAP{root,nobj}) =
          let fun ap (SplayNil : (K.ord_key * 'a) splay) = SplayNil
                | ap (SplayObj{value,left,right}) = let
                    val left' = ap left
                    val value' = (#1 value, af (#2 value))
                    in
                      SplayObj{value = value', left = left', right = ap right}
                    end
        in
          MAP{root = ref(ap (!root)), nobj = nobj}
        end

(* the following are generic implementations of the unionWith and intersectWith
 * operations.  These should be specialized for the internal representations
 * at some point.
 *)
    fun unionWith f (m1, m2) = let
	  fun ins f (key, x, m) = (case find(m, key)
		 of NONE => insert(m, key, x)
		  | (SOME x') => insert(m, key, f(x, x'))
		(* end case *))
	  in
	    if (numItems m1 > numItems m2)
	      then foldli (ins (fn (a, b) => f(b, a))) m1 m2
	      else foldli (ins f) m2 m1
	  end
    fun unionWithi f (m1, m2) = let
	  fun ins f (key, x, m) = (case find(m, key)
		 of NONE => insert(m, key, x)
		  | (SOME x') => insert(m, key, f(key, x, x'))
		(* end case *))
	  in
	    if (numItems m1 > numItems m2)
	      then foldli (ins (fn (k, a, b) => f(k, b, a))) m1 m2
	      else foldli (ins f) m2 m1
	  end

    fun intersectWith f (m1, m2) = let
	(* iterate over the elements of m1, checking for membership in m2 *)
	  fun intersect f (m1, m2) = let
		fun ins (key, x, m) = (case find(m2, key)
		       of NONE => m
			| (SOME x') => insert(m, key, f(x, x'))
		      (* end case *))
		in
		  foldli ins empty m1
		end
	  in
	    if (numItems m1 > numItems m2)
	      then intersect f (m1, m2)
	      else intersect (fn (a, b) => f(b, a)) (m2, m1)
	  end

    fun intersectWithi f (m1, m2) = let
	(* iterate over the elements of m1, checking for membership in m2 *)
	  fun intersect f (m1, m2) = let
		fun ins (key, x, m) = (case find(m2, key)
		       of NONE => m
			| (SOME x') => insert(m, key, f(key, x, x'))
		      (* end case *))
		in
		  foldli ins empty m1
		end
	  in
	    if (numItems m1 > numItems m2)
	      then intersect f (m1, m2)
	      else intersect (fn (k, a, b) => f(k, b, a)) (m2, m1)
	  end

  (* this is a generic implementation of mapPartial.  It should
   * be specialized to the data-structure at some point.
   *)
    fun mapPartial f m = let
	  fun g (key, item, m) = (case f item
		 of NONE => m
		  | (SOME item') => insert(m, key, item')
		(* end case *))
	  in
	    foldli g empty m
	  end
    fun mapPartiali f m = let
	  fun g (key, item, m) = (case f(key, item)
		 of NONE => m
		  | (SOME item') => insert(m, key, item')
		(* end case *))
	  in
	    foldli g empty m
	  end

  (* this is a generic implementation of filter.  It should
   * be specialized to the data-structure at some point.
   *)
    fun filter predFn m = let
	  fun f (key, item, m) = if predFn item
		then insert(m, key, item)
		else m
	  in
	    foldli f empty m
	  end
    fun filteri predFn m = let
	  fun f (key, item, m) = if predFn(key, item)
		then insert(m, key, item)
		else m
	  in
	    foldli f empty m
	  end

  end (* SplayMapFn *)
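
(* ----------------

   Example (a sketch): basic use of SplayMapFn with integer keys.  The
   names IntMap and single are hypothetical, chosen for illustration.

structure IntMap = SplayMapFn(struct type ord_key = int; val compare = Int.compare end)

fun single (k, v) = IntMap.insert (IntMap.empty, k, v)

val m = IntMap.insert (single (2, "two"), 1, "one")
val two   = IntMap.find (m, 2)          (* SOME "two" *)
val three = IntMap.find (m, 3)          (* NONE *)
val (m', one) = IntMap.remove (m, 1)    (* one = "one"; m' holds only key 2 *)
val pairs = IntMap.listItemsi m         (* [(1, "one"), (2, "two")]: m itself is unchanged *)

(* unionWith combines the values of keys present in both maps;
   intersectWith keeps only the keys present in both. *)
val a = IntMap.insert (single (1, 10), 2, 20)
val b = IntMap.insert (single (2, 5), 3, 7)
val u = IntMap.listItemsi (IntMap.unionWith (fn (x, y) => x + y) (a, b))
                                        (* [(1, 10), (2, 25), (3, 7)] *)
val i = IntMap.listItemsi (IntMap.intersectWith (fn (x, y) => x * y) (a, b))
                                        (* [(2, 100)] *)

(* collate compares maps entry by entry in increasing key order,
   first by key and then by value using the supplied ordering. *)
val c = IntMap.collate Int.compare (single (1, 5), single (1, 6))   (* LESS *)

   ---------------- *)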


(* ----------------

structure IntSet = SplaySetFn(struct type ord_key = int; val compare = Int.compare end);

val a = IntSet.singleton 1;
val b = IntSet.singleton 2;
val c = IntSet.singleton 1;

val s = IntSet.union (IntSet.union (a,b), c);
val _ = List.app (fn i => print (Int.toString i)) (IntSet.listItems s);

   ---------------- *)


(* IDs *)

(* The various root types share the id counter *)
val curRootId = ref 0;
fun getCurrentRootId () = !curRootId;
val _ = _export "getCurrentRootId" : (unit -> int) -> unit; getCurrentRootId;
fun getNextRoot () = ( curRootId := !curRootId+1 ; !curRootId);


(* For each combination of base type and compound constructor, we need a root type with operations: 
   get, reg, return, unreg *)
signature ROOT_TYPE = sig type t end

signature ROOT      = 
  sig 
    type t 
    exception RootNotFound;
    val get:    int ->  t      (* get root given root number *)
    val reg:    t   -> int     (* register and return root number *)
    val return: t   -> t       (* register and return root *)
    val unreg:  int -> unit 
  end


functor Root (R: ROOT_TYPE): ROOT =
struct
type t = R.t;
exception RootNotFound;

structure RootMap = SplayMapFn(struct 
                                 type ord_key = int; 
                                 val compare = Int.compare
                               end);

val roots = ref RootMap.empty;   

val reg = 
      fn r => ( roots := RootMap.insert (!roots, getNextRoot (), r)
              ; getCurrentRootId ());

val get =
      fn i => case RootMap.find (!roots, i)
              of NONE   => raise RootNotFound
	       | SOME r => r;

val return =
      fn r => ( roots := RootMap.insert (!roots, getNextRoot (), r)
              ; r);

(* TODO: Catch the exception raised when unreg is passed an already
         unregistered root (RootMap.remove raises LibBase.NotFound) *)

val unreg =
      fn i => roots := (case RootMap.remove (!roots, i)
                        of (rs,_) => rs);
end

structure BoolArrayRoot = Root(struct type t = bool array end);



-------------- next part --------------
(* MZTON  --  Jens Axel Søgaard  --  2005 *)

(* ROOTS  
 * 
 * Values of the base types int, bool, ... can be returned from MLton to
 * Scheme directly.
 * Compound values such as arrays, vectors and references must be registered
 * as roots; otherwise the MLton garbage collector might reclaim a value
 * that is still in use on the Scheme side.
 *
 * The function 'reg' registers a value and returns an
 * id (an integer) that can later be used to unregister it with 'unreg'.
 * The function 'return' registers a value and returns the same value;
 * its id can be read with 'getCurrentRootId' right afterwards, before
 * any other value is registered.
 * The function 'get' is used by the Scheme implementation to retrieve
 * a pointer to an ML value given an id previously returned from 'reg'.
 * 'get' is necessary because MLton has a moving garbage collector: a
 * C pointer returned from ML can become stale after a collection, so
 * the Scheme side cannot simply hold on to it.
 *
 *
 * Example:
 *  
 *  fun testUnitToIntArray () = Array.array (5,0);
 *  val _ = _export "testUnitToIntArray" : (unit -> int array) -> unit; (IntArrayRoot.return o testUnitToIntArray);
 *
 *)
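
(* ----------------

   Example (a sketch): the reg/get/unreg life cycle driven from the ML side.
   In Mzton these calls come from Scheme through the exported C names.
   DemoRoot is a hypothetical root structure built with the Root functor.

structure DemoRoot = Root(struct type t = int array end)

val arr  = Array.array (3, 0)
val id   = DemoRoot.reg arr          (* the root table now keeps arr alive *)
val same = DemoRoot.get id           (* the same array, wherever the GC has moved it *)
val ()   = DemoRoot.unreg id         (* the root table no longer keeps arr alive *)
val gone = (ignore (DemoRoot.get id); false)
           handle DemoRoot.RootNotFound => true

(* 'return' registers a value and hands it back unchanged; its id must be
   read with getCurrentRootId before anything else is registered. *)
val arr2 = DemoRoot.return (Array.array (2, 1))
val id2  = getCurrentRootId ()

   ---------------- *)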

(*

(* IDs *)

(* The various root types share the id counter *)
val curRootId = ref 0;
fun getCurrentRootId () = !curRootId;
val _ = _export "getCurrentRootId" : (unit -> int) -> unit; getCurrentRootId;
fun getNextRoot () = ( curRootId := !curRootId+1 ; !curRootId);


(* For each combination of base type and compound constructor, we need a root type with operations: 
   get, reg, return, unreg *)
signature ROOT_TYPE = sig type t end

signature ROOT      = 
  sig 
    type t 
    exception RootNotFound;
    val get:    int ->  t      (* get root given root number *)
    val reg:    t   -> int     (* register and return root number *)
    val return: t   -> t       (* register and return root *)
    val unreg:  int -> unit 
  end


functor Root (R: ROOT_TYPE): ROOT =
struct
type t = R.t;
exception RootNotFound;
val roots = ref [];   (* association list of integer*root pairs *)

val get =
      fn i => let
                val rec loop = fn nil         => raise RootNotFound
				| (j,r)::more => if i=j then r else (loop more)
                in
                  loop (!roots)
                end;

val reg =
      fn r => ( roots := (getNextRoot (), r) :: (!roots)
              ; getCurrentRootId ());

val return = 
      fn r => ( roots := (getNextRoot (), r) :: (!roots)
              ; r );

val unreg = 
      fn i => let val rec loop = fn nil         => raise RootNotFound
                                  | (j,r)::more => if i=j then more else (j,r)::(loop more)
              in roots := loop (!roots)
              end;

end

*)

structure BoolArrayRoot = Root(struct type t = bool array end);
val _ = _export "getBoolArray":    (int -> bool array)   -> unit; BoolArrayRoot.get;
val _ = _export "unregBoolArray" : (int -> unit) -> unit; BoolArrayRoot.unreg;
val _ = _export "regBoolArray" :   (bool array  -> int)  -> unit; BoolArrayRoot.reg;
structure CharArrayRoot = Root(struct type t = char array end);
val _ = _export "getCharArray":    (int -> char array)   -> unit; CharArrayRoot.get;
val _ = _export "unregCharArray" : (int -> unit) -> unit; CharArrayRoot.unreg;
val _ = _export "regCharArray" :   (char array  -> int)  -> unit; CharArrayRoot.reg;
structure Int8ArrayRoot = Root(struct type t = Int8.int array end);
val _ = _export "getInt8Array":    (int -> Int8.int array)   -> unit; Int8ArrayRoot.get;
val _ = _export "unregInt8Array" : (int -> unit) -> unit; Int8ArrayRoot.unreg;
val _ = _export "regInt8Array" :   (Int8.int array  -> int)  -> unit; Int8ArrayRoot.reg;
structure Int16ArrayRoot = Root(struct type t = Int16.int array end);
val _ = _export "getInt16Array":    (int -> Int16.int array)   -> unit; Int16ArrayRoot.get;
val _ = _export "unregInt16Array" : (int -> unit) -> unit; Int16ArrayRoot.unreg;
val _ = _export "regInt16Array" :   (Int16.int array  -> int)  -> unit; Int16ArrayRoot.reg;
structure Int32ArrayRoot = Root(struct type t = Int32.int array end);
val _ = _export "getInt32Array":    (int -> Int32.int array)   -> unit; Int32ArrayRoot.get;
val _ = _export "unregInt32Array" : (int -> unit) -> unit; Int32ArrayRoot.unreg;
val _ = _export "regInt32Array" :   (Int32.int array  -> int)  -> unit; Int32ArrayRoot.reg;
structure Int64ArrayRoot = Root(struct type t = Int64.int array end);
val _ = _export "getInt64Array":    (int -> Int64.int array)   -> unit; Int64ArrayRoot.get;
val _ = _export "unregInt64Array" : (int -> unit) -> unit; Int64ArrayRoot.unreg;
val _ = _export "regInt64Array" :   (Int64.int array  -> int)  -> unit; Int64ArrayRoot.reg;
structure IntArrayRoot = Root(struct type t = int array end);
val _ = _export "getIntArray":    (int -> int array)   -> unit; IntArrayRoot.get;
val _ = _export "unregIntArray" : (int -> unit) -> unit; IntArrayRoot.unreg;
val _ = _export "regIntArray" :   (int array  -> int)  -> unit; IntArrayRoot.reg;
structure PointerArrayRoot = Root(struct type t = MLton.Pointer.t array end);
val _ = _export "getPointerArray":    (int -> MLton.Pointer.t array)   -> unit; PointerArrayRoot.get;
val _ = _export "unregPointerArray" : (int -> unit) -> unit; PointerArrayRoot.unreg;
val _ = _export "regPointerArray" :   (MLton.Pointer.t array  -> int)  -> unit; PointerArrayRoot.reg;
structure Real32ArrayRoot = Root(struct type t = Real32.real array end);
val _ = _export "getReal32Array":    (int -> Real32.real array)   -> unit; Real32ArrayRoot.get;
val _ = _export "unregReal32Array" : (int -> unit) -> unit; Real32ArrayRoot.unreg;
val _ = _export "regReal32Array" :   (Real32.real array  -> int)  -> unit; Real32ArrayRoot.reg;
structure Real64ArrayRoot = Root(struct type t = Real64.real array end);
val _ = _export "getReal64Array":    (int -> Real64.real array)   -> unit; Real64ArrayRoot.get;
val _ = _export "unregReal64Array" : (int -> unit) -> unit; Real64ArrayRoot.unreg;
val _ = _export "regReal64Array" :   (Real64.real array  -> int)  -> unit; Real64ArrayRoot.reg;
structure RealArrayRoot = Root(struct type t = real array end);
val _ = _export "getRealArray":    (int -> real array)   -> unit; RealArrayRoot.get;
val _ = _export "unregRealArray" : (int -> unit) -> unit; RealArrayRoot.unreg;
val _ = _export "regRealArray" :   (real array  -> int)  -> unit; RealArrayRoot.reg;
structure Word8ArrayRoot = Root(struct type t = Word8.word array end);
val _ = _export "getWord8Array":    (int -> Word8.word array)   -> unit; Word8ArrayRoot.get;
val _ = _export "unregWord8Array" : (int -> unit) -> unit; Word8ArrayRoot.unreg;
val _ = _export "regWord8Array" :   (Word8.word array  -> int)  -> unit; Word8ArrayRoot.reg;
structure Word16ArrayRoot = Root(struct type t = Word16.word array end);
val _ = _export "getWord16Array":    (int -> Word16.word array)   -> unit; Word16ArrayRoot.get;
val _ = _export "unregWord16Array" : (int -> unit) -> unit; Word16ArrayRoot.unreg;
val _ = _export "regWord16Array" :   (Word16.word array  -> int)  -> unit; Word16ArrayRoot.reg;
structure Word32ArrayRoot = Root(struct type t = Word32.word array end);
val _ = _export "getWord32Array":    (int -> Word32.word array)   -> unit; Word32ArrayRoot.get;
val _ = _export "unregWord32Array" : (int -> unit) -> unit; Word32ArrayRoot.unreg;
val _ = _export "regWord32Array" :   (Word32.word array  -> int)  -> unit; Word32ArrayRoot.reg;
structure Word64ArrayRoot = Root(struct type t = Word64.word array end);
val _ = _export "getWord64Array":    (int -> Word64.word array)   -> unit; Word64ArrayRoot.get;
val _ = _export "unregWord64Array" : (int -> unit) -> unit; Word64ArrayRoot.unreg;
val _ = _export "regWord64Array" :   (Word64.word array  -> int)  -> unit; Word64ArrayRoot.reg;
structure WordArrayRoot = Root(struct type t = word array end);
val _ = _export "getWordArray":    (int -> word array)   -> unit; WordArrayRoot.get;
val _ = _export "unregWordArray" : (int -> unit) -> unit; WordArrayRoot.unreg;
val _ = _export "regWordArray" :   (word array  -> int)  -> unit; WordArrayRoot.reg;
structure StringArrayRoot = Root(struct type t = string array end);
val _ = _export "getStringArray":    (int -> string array)   -> unit; StringArrayRoot.get;
val _ = _export "unregStringArray" : (int -> unit) -> unit; StringArrayRoot.unreg;
val _ = _export "regStringArray" :   (string array  -> int)  -> unit; StringArrayRoot.reg;
structure BoolRefRoot = Root(struct type t = bool ref end);
val _ = _export "getBoolRef":    (int -> bool ref)   -> unit; BoolRefRoot.get;
val _ = _export "unregBoolRef" : (int -> unit) -> unit; BoolRefRoot.unreg;
val _ = _export "regBoolRef" :   (bool ref  -> int)  -> unit; BoolRefRoot.reg;
structure CharRefRoot = Root(struct type t = char ref end);
val _ = _export "getCharRef":    (int -> char ref)   -> unit; CharRefRoot.get;
val _ = _export "unregCharRef" : (int -> unit) -> unit; CharRefRoot.unreg;
val _ = _export "regCharRef" :   (char ref  -> int)  -> unit; CharRefRoot.reg;
structure Int8RefRoot = Root(struct type t = Int8.int ref end);
val _ = _export "getInt8Ref":    (int -> Int8.int ref)   -> unit; Int8RefRoot.get;
val _ = _export "unregInt8Ref" : (int -> unit) -> unit; Int8RefRoot.unreg;
val _ = _export "regInt8Ref" :   (Int8.int ref  -> int)  -> unit; Int8RefRoot.reg;
structure Int16RefRoot = Root(struct type t = Int16.int ref end);
val _ = _export "getInt16Ref":    (int -> Int16.int ref)   -> unit; Int16RefRoot.get;
val _ = _export "unregInt16Ref" : (int -> unit) -> unit; Int16RefRoot.unreg;
val _ = _export "regInt16Ref" :   (Int16.int ref  -> int)  -> unit; Int16RefRoot.reg;
structure Int32RefRoot = Root(struct type t = Int32.int ref end);
val _ = _export "getInt32Ref":    (int -> Int32.int ref)   -> unit; Int32RefRoot.get;
val _ = _export "unregInt32Ref" : (int -> unit) -> unit; Int32RefRoot.unreg;
val _ = _export "regInt32Ref" :   (Int32.int ref  -> int)  -> unit; Int32RefRoot.reg;
structure Int64RefRoot = Root(struct type t = Int64.int ref end);
val _ = _export "getInt64Ref":    (int -> Int64.int ref)   -> unit; Int64RefRoot.get;
val _ = _export "unregInt64Ref" : (int -> unit) -> unit; Int64RefRoot.unreg;
val _ = _export "regInt64Ref" :   (Int64.int ref  -> int)  -> unit; Int64RefRoot.reg;
structure IntRefRoot = Root(struct type t = int ref end);
val _ = _export "getIntRef":    (int -> int ref)   -> unit; IntRefRoot.get;
val _ = _export "unregIntRef" : (int -> unit) -> unit; IntRefRoot.unreg;
val _ = _export "regIntRef" :   (int ref  -> int)  -> unit; IntRefRoot.reg;
structure PointerRefRoot = Root(struct type t = MLton.Pointer.t ref end);
val _ = _export "getPointerRef":    (int -> MLton.Pointer.t ref)   -> unit; PointerRefRoot.get;
val _ = _export "unregPointerRef" : (int -> unit) -> unit; PointerRefRoot.unreg;
val _ = _export "regPointerRef" :   (MLton.Pointer.t ref  -> int)  -> unit; PointerRefRoot.reg;
structure Real32RefRoot = Root(struct type t = Real32.real ref end);
val _ = _export "getReal32Ref":    (int -> Real32.real ref)   -> unit; Real32RefRoot.get;
val _ = _export "unregReal32Ref" : (int -> unit) -> unit; Real32RefRoot.unreg;
val _ = _export "regReal32Ref" :   (Real32.real ref  -> int)  -> unit; Real32RefRoot.reg;
structure Real64RefRoot = Root(struct type t = Real64.real ref end);
val _ = _export "getReal64Ref":    (int -> Real64.real ref)   -> unit; Real64RefRoot.get;
val _ = _export "unregReal64Ref" : (int -> unit) -> unit; Real64RefRoot.unreg;
val _ = _export "regReal64Ref" :   (Real64.real ref  -> int)  -> unit; Real64RefRoot.reg;
structure RealRefRoot = Root(struct type t = real ref end);
val _ = _export "getRealRef":    (int -> real ref)   -> unit; RealRefRoot.get;
val _ = _export "unregRealRef" : (int -> unit) -> unit; RealRefRoot.unreg;
val _ = _export "regRealRef" :   (real ref  -> int)  -> unit; RealRefRoot.reg;
structure Word8RefRoot = Root(struct type t = Word8.word ref end);
val _ = _export "getWord8Ref":    (int -> Word8.word ref)   -> unit; Word8RefRoot.get;
val _ = _export "unregWord8Ref" : (int -> unit) -> unit; Word8RefRoot.unreg;
val _ = _export "regWord8Ref" :   (Word8.word ref  -> int)  -> unit; Word8RefRoot.reg;
structure Word16RefRoot = Root(struct type t = Word16.word ref end);
val _ = _export "getWord16Ref":    (int -> Word16.word ref)   -> unit; Word16RefRoot.get;
val _ = _export "unregWord16Ref" : (int -> unit) -> unit; Word16RefRoot.unreg;
val _ = _export "regWord16Ref" :   (Word16.word ref  -> int)  -> unit; Word16RefRoot.reg;
structure Word32RefRoot = Root(struct type t = Word32.word ref end);
val _ = _export "getWord32Ref":    (int -> Word32.word ref)   -> unit; Word32RefRoot.get;
val _ = _export "unregWord32Ref" : (int -> unit) -> unit; Word32RefRoot.unreg;
val _ = _export "regWord32Ref" :   (Word32.word ref  -> int)  -> unit; Word32RefRoot.reg;
structure Word64RefRoot = Root(struct type t = Word64.word ref end);
val _ = _export "getWord64Ref":    (int -> Word64.word ref)   -> unit; Word64RefRoot.get;
val _ = _export "unregWord64Ref" : (int -> unit) -> unit; Word64RefRoot.unreg;
val _ = _export "regWord64Ref" :   (Word64.word ref  -> int)  -> unit; Word64RefRoot.reg;
structure WordRefRoot = Root(struct type t = word ref end);
val _ = _export "getWordRef":    (int -> word ref)   -> unit; WordRefRoot.get;
val _ = _export "unregWordRef" : (int -> unit) -> unit; WordRefRoot.unreg;
val _ = _export "regWordRef" :   (word ref  -> int)  -> unit; WordRefRoot.reg;
structure StringRefRoot = Root(struct type t = string ref end);
val _ = _export "getStringRef":    (int -> string ref)   -> unit; StringRefRoot.get;
val _ = _export "unregStringRef" : (int -> unit) -> unit; StringRefRoot.unreg;
val _ = _export "regStringRef" :   (string ref  -> int)  -> unit; StringRefRoot.reg;
structure BoolVectorRoot = Root(struct type t = bool vector end);
val _ = _export "getBoolVector":    (int -> bool vector)   -> unit; BoolVectorRoot.get;
val _ = _export "unregBoolVector" : (int -> unit) -> unit; BoolVectorRoot.unreg;
val _ = _export "regBoolVector" :   (bool vector  -> int)  -> unit; BoolVectorRoot.reg;
structure CharVectorRoot = Root(struct type t = char vector end);
val _ = _export "getCharVector":    (int -> char vector)   -> unit; CharVectorRoot.get;
val _ = _export "unregCharVector" : (int -> unit) -> unit; CharVectorRoot.unreg;
val _ = _export "regCharVector" :   (char vector  -> int)  -> unit; CharVectorRoot.reg;
structure Int8VectorRoot = Root(struct type t = Int8.int vector end);
val _ = _export "getInt8Vector":    (int -> Int8.int vector)   -> unit; Int8VectorRoot.get;
val _ = _export "unregInt8Vector" : (int -> unit) -> unit; Int8VectorRoot.unreg;
val _ = _export "regInt8Vector" :   (Int8.int vector  -> int)  -> unit; Int8VectorRoot.reg;
structure Int16VectorRoot = Root(struct type t = Int16.int vector end);
val _ = _export "getInt16Vector":    (int -> Int16.int vector)   -> unit; Int16VectorRoot.get;
val _ = _export "unregInt16Vector" : (int -> unit) -> unit; Int16VectorRoot.unreg;
val _ = _export "regInt16Vector" :   (Int16.int vector  -> int)  -> unit; Int16VectorRoot.reg;
structure Int32VectorRoot = Root(struct type t = Int32.int vector end);
val _ = _export "getInt32Vector":    (int -> Int32.int vector)   -> unit; Int32VectorRoot.get;
val _ = _export "unregInt32Vector" : (int -> unit) -> unit; Int32VectorRoot.unreg;
val _ = _export "regInt32Vector" :   (Int32.int vector  -> int)  -> unit; Int32VectorRoot.reg;
structure Int64VectorRoot = Root(struct type t = Int64.int vector end);
val _ = _export "getInt64Vector":    (int -> Int64.int vector)   -> unit; Int64VectorRoot.get;
val _ = _export "unregInt64Vector" : (int -> unit) -> unit; Int64VectorRoot.unreg;
val _ = _export "regInt64Vector" :   (Int64.int vector  -> int)  -> unit; Int64VectorRoot.reg;
structure IntVectorRoot = Root(struct type t = int vector end);
val _ = _export "getIntVector":    (int -> int vector)   -> unit; IntVectorRoot.get;
val _ = _export "unregIntVector" : (int -> unit) -> unit; IntVectorRoot.unreg;
val _ = _export "regIntVector" :   (int vector  -> int)  -> unit; IntVectorRoot.reg;
structure PointerVectorRoot = Root(struct type t = MLton.Pointer.t vector end);
val _ = _export "getPointerVector":    (int -> MLton.Pointer.t vector)   -> unit; PointerVectorRoot.get;
val _ = _export "unregPointerVector" : (int -> unit) -> unit; PointerVectorRoot.unreg;
val _ = _export "regPointerVector" :   (MLton.Pointer.t vector  -> int)  -> unit; PointerVectorRoot.reg;
structure Real32VectorRoot = Root(struct type t = Real32.real vector end);
val _ = _export "getReal32Vector":    (int -> Real32.real vector)   -> unit; Real32VectorRoot.get;
val _ = _export "unregReal32Vector" : (int -> unit) -> unit; Real32VectorRoot.unreg;
val _ = _export "regReal32Vector" :   (Real32.real vector  -> int)  -> unit; Real32VectorRoot.reg;
structure Real64VectorRoot = Root(struct type t = Real64.real vector end);
val _ = _export "getReal64Vector":    (int -> Real64.real vector)   -> unit; Real64VectorRoot.get;
val _ = _export "unregReal64Vector" : (int -> unit) -> unit; Real64VectorRoot.unreg;
val _ = _export "regReal64Vector" :   (Real64.real vector  -> int)  -> unit; Real64VectorRoot.reg;
structure RealVectorRoot = Root(struct type t = real vector end);
val _ = _export "getRealVector":    (int -> real vector)   -> unit; RealVectorRoot.get;
val _ = _export "unregRealVector" : (int -> unit) -> unit; RealVectorRoot.unreg;
val _ = _export "regRealVector" :   (real vector  -> int)  -> unit; RealVectorRoot.reg;
structure Word8VectorRoot = Root(struct type t = Word8.word vector end);
val _ = _export "getWord8Vector":    (int -> Word8.word vector)   -> unit; Word8VectorRoot.get;
val _ = _export "unregWord8Vector" : (int -> unit) -> unit; Word8VectorRoot.unreg;
val _ = _export "regWord8Vector" :   (Word8.word vector  -> int)  -> unit; Word8VectorRoot.reg;
structure Word16VectorRoot = Root(struct type t = Word16.word vector end);
val _ = _export "getWord16Vector":    (int -> Word16.word vector)   -> unit; Word16VectorRoot.get;
val _ = _export "unregWord16Vector" : (int -> unit) -> unit; Word16VectorRoot.unreg;
val _ = _export "regWord16Vector" :   (Word16.word vector  -> int)  -> unit; Word16VectorRoot.reg;
structure Word32VectorRoot = Root(struct type t = Word32.word vector end);
val _ = _export "getWord32Vector":    (int -> Word32.word vector)   -> unit; Word32VectorRoot.get;
val _ = _export "unregWord32Vector" : (int -> unit) -> unit; Word32VectorRoot.unreg;
val _ = _export "regWord32Vector" :   (Word32.word vector  -> int)  -> unit; Word32VectorRoot.reg;
structure Word64VectorRoot = Root(struct type t = Word64.word vector end);
val _ = _export "getWord64Vector":    (int -> Word64.word vector)   -> unit; Word64VectorRoot.get;
val _ = _export "unregWord64Vector" : (int -> unit) -> unit; Word64VectorRoot.unreg;
val _ = _export "regWord64Vector" :   (Word64.word vector  -> int)  -> unit; Word64VectorRoot.reg;
structure WordVectorRoot = Root(struct type t = word vector end);
val _ = _export "getWordVector":    (int -> word vector)   -> unit; WordVectorRoot.get;
val _ = _export "unregWordVector" : (int -> unit) -> unit; WordVectorRoot.unreg;
val _ = _export "regWordVector" :   (word vector  -> int)  -> unit; WordVectorRoot.reg;
structure StringVectorRoot = Root(struct type t = string vector end);
val _ = _export "getStringVector":    (int -> string vector)   -> unit; StringVectorRoot.get;
val _ = _export "unregStringVector" : (int -> unit) -> unit; StringVectorRoot.unreg;
val _ = _export "regStringVector" :   (string vector  -> int)  -> unit; StringVectorRoot.reg;

(* ALLOCATORS *)
fun makeBoolArray (n, fill) = Array.array (n, fill);
val _ = _export "makeBoolArray" : (int*bool -> bool Array.array) -> unit; (BoolArrayRoot.return o makeBoolArray);
fun makeBoolVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeBoolVector" : (int*bool -> bool Vector.vector) -> unit; (BoolVectorRoot.return o makeBoolVector);
fun makeBoolRef v = ref v;
val _ = _export "makeBoolRef" : (bool -> bool ref) -> unit; (BoolRefRoot.return o makeBoolRef);
fun makeCharArray (n, fill) = Array.array (n, fill);
val _ = _export "makeCharArray" : (int*char -> char Array.array) -> unit; (CharArrayRoot.return o makeCharArray);
fun makeCharVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeCharVector" : (int*char -> char Vector.vector) -> unit; (CharVectorRoot.return o makeCharVector);
fun makeCharRef v = ref v;
val _ = _export "makeCharRef" : (char -> char ref) -> unit; (CharRefRoot.return o makeCharRef);
fun makeInt8Array (n, fill) = Array.array (n, fill);
val _ = _export "makeInt8Array" : (int*Int8.int -> Int8.int Array.array) -> unit; (Int8ArrayRoot.return o makeInt8Array);
fun makeInt8Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeInt8Vector" : (int*Int8.int -> Int8.int Vector.vector) -> unit; (Int8VectorRoot.return o makeInt8Vector);
fun makeInt8Ref v = ref v;
val _ = _export "makeInt8Ref" : (Int8.int -> Int8.int ref) -> unit; (Int8RefRoot.return o makeInt8Ref);
fun makeInt16Array (n, fill) = Array.array (n, fill);
val _ = _export "makeInt16Array" : (int*Int16.int -> Int16.int Array.array) -> unit; (Int16ArrayRoot.return o makeInt16Array);
fun makeInt16Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeInt16Vector" : (int*Int16.int -> Int16.int Vector.vector) -> unit; (Int16VectorRoot.return o makeInt16Vector);
fun makeInt16Ref v = ref v;
val _ = _export "makeInt16Ref" : (Int16.int -> Int16.int ref) -> unit; (Int16RefRoot.return o makeInt16Ref);
fun makeInt32Array (n, fill) = Array.array (n, fill);
val _ = _export "makeInt32Array" : (int*Int32.int -> Int32.int Array.array) -> unit; (Int32ArrayRoot.return o makeInt32Array);
fun makeInt32Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeInt32Vector" : (int*Int32.int -> Int32.int Vector.vector) -> unit; (Int32VectorRoot.return o makeInt32Vector);
fun makeInt32Ref v = ref v;
val _ = _export "makeInt32Ref" : (Int32.int -> Int32.int ref) -> unit; (Int32RefRoot.return o makeInt32Ref);
fun makeInt64Array (n, fill) = Array.array (n, fill);
val _ = _export "makeInt64Array" : (int*Int64.int -> Int64.int Array.array) -> unit; (Int64ArrayRoot.return o makeInt64Array);
fun makeInt64Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeInt64Vector" : (int*Int64.int -> Int64.int Vector.vector) -> unit; (Int64VectorRoot.return o makeInt64Vector);
fun makeInt64Ref v = ref v;
val _ = _export "makeInt64Ref" : (Int64.int -> Int64.int ref) -> unit; (Int64RefRoot.return o makeInt64Ref);
fun makeIntArray (n, fill) = Array.array (n, fill);
val _ = _export "makeIntArray" : (int*int -> int Array.array) -> unit; (IntArrayRoot.return o makeIntArray);
fun makeIntVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeIntVector" : (int*int -> int Vector.vector) -> unit; (IntVectorRoot.return o makeIntVector);
fun makeIntRef v = ref v;
val _ = _export "makeIntRef" : (int -> int ref) -> unit; (IntRefRoot.return o makeIntRef);
fun makePointerArray (n, fill) = Array.array (n, fill);
val _ = _export "makePointerArray" : (int*MLton.Pointer.t -> MLton.Pointer.t Array.array) -> unit; (PointerArrayRoot.return o makePointerArray);
fun makePointerVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makePointerVector" : (int*MLton.Pointer.t -> MLton.Pointer.t Vector.vector) -> unit; (PointerVectorRoot.return o makePointerVector);
fun makePointerRef v = ref v;
val _ = _export "makePointerRef" : (MLton.Pointer.t -> MLton.Pointer.t ref) -> unit; (PointerRefRoot.return o makePointerRef);
fun makeReal32Array (n, fill) = Array.array (n, fill);
val _ = _export "makeReal32Array" : (int*Real32.real -> Real32.real Array.array) -> unit; (Real32ArrayRoot.return o makeReal32Array);
fun makeReal32Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeReal32Vector" : (int*Real32.real -> Real32.real Vector.vector) -> unit; (Real32VectorRoot.return o makeReal32Vector);
fun makeReal32Ref v = ref v;
val _ = _export "makeReal32Ref" : (Real32.real -> Real32.real ref) -> unit; (Real32RefRoot.return o makeReal32Ref);
fun makeReal64Array (n, fill) = Array.array (n, fill);
val _ = _export "makeReal64Array" : (int*Real64.real -> Real64.real Array.array) -> unit; (Real64ArrayRoot.return o makeReal64Array);
fun makeReal64Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeReal64Vector" : (int*Real64.real -> Real64.real Vector.vector) -> unit; (Real64VectorRoot.return o makeReal64Vector);
fun makeReal64Ref v = ref v;
val _ = _export "makeReal64Ref" : (Real64.real -> Real64.real ref) -> unit; (Real64RefRoot.return o makeReal64Ref);
fun makeRealArray (n, fill) = Array.array (n, fill);
val _ = _export "makeRealArray" : (int*real -> real Array.array) -> unit; (RealArrayRoot.return o makeRealArray);
fun makeRealVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeRealVector" : (int*real -> real Vector.vector) -> unit; (RealVectorRoot.return o makeRealVector);
fun makeRealRef v = ref v;
val _ = _export "makeRealRef" : (real -> real ref) -> unit; (RealRefRoot.return o makeRealRef);
fun makeWord8Array (n, fill) = Array.array (n, fill);
val _ = _export "makeWord8Array" : (int*Word8.word -> Word8.word Array.array) -> unit; (Word8ArrayRoot.return o makeWord8Array);
fun makeWord8Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeWord8Vector" : (int*Word8.word -> Word8.word Vector.vector) -> unit; (Word8VectorRoot.return o makeWord8Vector);
fun makeWord8Ref v = ref v;
val _ = _export "makeWord8Ref" : (Word8.word -> Word8.word ref) -> unit; (Word8RefRoot.return o makeWord8Ref);
fun makeWord16Array (n, fill) = Array.array (n, fill);
val _ = _export "makeWord16Array" : (int*Word16.word -> Word16.word Array.array) -> unit; (Word16ArrayRoot.return o makeWord16Array);
fun makeWord16Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeWord16Vector" : (int*Word16.word -> Word16.word Vector.vector) -> unit; (Word16VectorRoot.return o makeWord16Vector);
fun makeWord16Ref v = ref v;
val _ = _export "makeWord16Ref" : (Word16.word -> Word16.word ref) -> unit; (Word16RefRoot.return o makeWord16Ref);
fun makeWord32Array (n, fill) = Array.array (n, fill);
val _ = _export "makeWord32Array" : (int*Word32.word -> Word32.word Array.array) -> unit; (Word32ArrayRoot.return o makeWord32Array);
fun makeWord32Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeWord32Vector" : (int*Word32.word -> Word32.word Vector.vector) -> unit; (Word32VectorRoot.return o makeWord32Vector);
fun makeWord32Ref v = ref v;
val _ = _export "makeWord32Ref" : (Word32.word -> Word32.word ref) -> unit; (Word32RefRoot.return o makeWord32Ref);
fun makeWord64Array (n, fill) = Array.array (n, fill);
val _ = _export "makeWord64Array" : (int*Word64.word -> Word64.word Array.array) -> unit; (Word64ArrayRoot.return o makeWord64Array);
fun makeWord64Vector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeWord64Vector" : (int*Word64.word -> Word64.word Vector.vector) -> unit; (Word64VectorRoot.return o makeWord64Vector);
fun makeWord64Ref v = ref v;
val _ = _export "makeWord64Ref" : (Word64.word -> Word64.word ref) -> unit; (Word64RefRoot.return o makeWord64Ref);
fun makeWordArray (n, fill) = Array.array (n, fill);
val _ = _export "makeWordArray" : (int*word -> word Array.array) -> unit; (WordArrayRoot.return o makeWordArray);
fun makeWordVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeWordVector" : (int*word -> word Vector.vector) -> unit; (WordVectorRoot.return o makeWordVector);
fun makeWordRef v = ref v;
val _ = _export "makeWordRef" : (word -> word ref) -> unit; (WordRefRoot.return o makeWordRef);
fun makeStringArray (n, fill) = Array.array (n, fill);
val _ = _export "makeStringArray" : (int*string -> string Array.array) -> unit; (StringArrayRoot.return o makeStringArray);
fun makeStringVector (n, fill) = Array.vector (Array.array (n, fill));
val _ = _export "makeStringVector" : (int*string -> string Vector.vector) -> unit; (StringVectorRoot.return o makeStringVector);
fun makeStringRef v = ref v;
val _ = _export "makeStringRef" : (string -> string ref) -> unit; (StringRefRoot.return o makeStringRef);



fun testIntToInt x = 42;
val _ = _export "testIntToInt" : (int -> int) -> unit; testIntToInt;

fun testUnitToInt () = 43;
val _ = _export "testUnitToInt" : (unit -> int) -> unit; testUnitToInt;

fun testUnitToChar () = #"A";
val _ = _export "testUnitToChar" : (unit -> char) -> unit; testUnitToChar;

fun testUnitToBool () = false;
val _ = _export "testUnitToBool" : (unit -> bool) -> unit; testUnitToBool;


fun testUnitToIntArray () =
  let
    val A = Array.array (5, 0)
  in
    ( Array.update (A, 0, 0)
    ; Array.update (A, 1, 1)
    ; Array.update (A, 2, 2)
    ; Array.update (A, 3, 3)
    ; Array.update (A, 4, 4)
    ; A )
  end;
val _ = _export "testUnitToIntArray" : (unit -> int array) -> unit; (IntArrayRoot.return o testUnitToIntArray);

fun testUnitToIntRef () = ref 42;
val _ = _export "testUnitToIntRef" : (unit -> int ref) -> unit; (IntRefRoot.return o testUnitToIntRef);

(* MANDELBROT *)

fun sq (x,y) = (x*x-y*y, 2.0*x*y);
fun add (x,y) (u,v) = (x+u,y+v):real*real;
fun sub (x,y) (u,v) = (x-u,y-v):real*real;
fun abs2 (x,y) = (x*x+y*y):real;

val maxIterations = 150;

fun iterate (c,d) = 
  let fun loop (x,y,n) =
        let val (s,t) = add (c,d) (sq (x,y))
        in if (abs2(s,t) < 4.0) andalso n<maxIterations
           then loop (s,t,n+1)
           else n
        end
  in loop (c,d,0)
  end;

val _ = _export "iterate" : (real*real -> int) -> unit; iterate;


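fun iterateLine (x, yFrom, yTo, yDelta) =
  let val l = ceil ((yTo - yFrom) / yDelta)
  in
    (* Compute each y from its index instead of accumulating yDelta:
       repeated addition drifts (e.g. ten steps of 0.1 sum to slightly
       less than 1.0), which could run the loop once more than l times
       and raise Subscript on Array.update. *)
    Array.tabulate (l, fn i => iterate (x, yFrom + real i * yDelta))
  end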

val _ = _export "iterateLine" : (real*real*real*real -> int array) -> unit; (IntArrayRoot.return o iterateLine);


fun doubleArray (A, pointerToF) 
  = let val f = (_import * : MLton.Pointer.t -> real -> real;) pointerToF
    in  
      (Array.modify f A)
    end;

val _ = _export "doubleArray" : ( (real array)*MLton.Pointer.t -> unit) -> unit; doubleArray;


(*** DOCUMENTATION EXAMPLE 

 makeUnitVec  : int * int -> real array
 sumVec       : (real array) * (real array) -> real array
 scaleVec     : (real array) * real -> real array
 dotVec       : (real array) * (real array)   -> real
 clearVec     : real array -> unit

 modifyVec    : (real array) * (real -> real) -> unit

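 Example (illustrative only, not part of the original test code):
 building the vector e1 + 2*e2 and computing its squared length.

   val e1 = makeUnitVec (3, 1)                  (* [1.0, 0.0, 0.0] *)
   val e2 = makeUnitVec (3, 2)                  (* [0.0, 1.0, 0.0] *)
   val v  = sumVec (e1, scaleVec (e2, 2.0))     (* [1.0, 2.0, 0.0] *)
   val d  = dotVec (v, v)                       (* 5.0 *)
   val _  = clearVec v                          (* v is now [0.0, 0.0, 0.0] *)
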
***)


(* return the i'th basis vector of R^n (i counts from 1) *)
fun makeUnitVec (n,i) = Array.tabulate (n, fn j => if i=j+1 then 1.0 else 0.0);  
val _ = _export "makeUnitVec" : ( int*int -> real array ) -> unit; (RealArrayRoot.return o makeUnitVec);


(* return the vector sum of A and B *)
fun sumVec (A:(real array),B) = Array.tabulate (Array.length A, fn i => (Array.sub (A,i)) + (Array.sub (B,i)));
val _ = _export "sumVec" : ( (real array)*(real array) -> real array ) -> unit; (RealArrayRoot.return o sumVec);

(* multiply each entry of A with s *)
fun scaleVec (A, s:real) = Array.tabulate (Array.length A, fn i => s*(Array.sub(A,i)));
val _ = _export "scaleVec" : ( (real array)*real -> real array ) -> unit; (RealArrayRoot.return o scaleVec);

(* return the dot product of A and B *)
fun dotVec (A,B) = Array.foldli (fn (i, a, x) => (Array.sub(B,i)*a+x)) 0.0 A;
val _ = _export "dotVec" : ( (real array)*(real array) -> real ) -> unit; dotVec;


(* set each entry of A to 0.0 *)
fun clearVec A = Array.modify (fn x => 0.0) A;
val _ = _export "clearVec" : ( real array -> unit ) -> unit; clearVec;

(* replace x with f(x) for each entry x of A *)
fun modifyVec (A,f) = Array.modify f A;
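(* modifyVec can't be exported directly, because its argument f has the
   ML function type real -> real, which cannot be passed through the FFI.
   exportedModifyVec instead takes a C function pointer and imports it
   with _import *. *)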

fun exportedModifyVec (A,pointerToF)
  = let val f = (_import * : MLton.Pointer.t -> real -> real;) pointerToF
    in 
      modifyVec (A,f) 
    end;

val _ = _export "modifyVec" : ( (real array)*MLton.Pointer.t -> unit ) -> unit; exportedModifyVec;



(*
(define (iterate-line x y-from y-to y-delta)
  (let* ([l    (inexact->exact (ceiling (/ (- y-to y-from) y-delta)))]
         [line (make-vector l 0.0)])
    (do ([y y-from (+ y y-delta)]
         [i 0      (+ i 1)])
      [(>= y y-to) line]
      (vector-set! line i (iterate x y)))))

*)

-------------- next part --------------
;;; TODO: 
;;  - allow $unit at the left side of an _fun arrow
;;  - test string conversion
;;  - put this in a module
;;  - write support functions for ML vectors
;;  - think about what happens if a call to regRoot triggers a garbage collection
;;       Solution: demand that all returned compound values are registered on the ML side

;;; CONVENIENCES

; define x as short for (exit)
(define-syntax x (syntax-id-rules () (x (exit))))

(require (lib "list.ss"))


;;; OPEN FOREIGN LIBRARY

(require (lib "foreign.ss"))
(unsafe!)

(display "* Opening library\n")
(define lib (string->path "./test") #;<optional-version> )


;;; BASE TYPES

(define $unit _void)

; ((sml-abbrev sml-type ffi-type $name) ...), e.g. (int int _int32 $int)
(begin-for-syntax 
  (define base-types '()))

(define-syntax (define-base-types stx)
  (syntax-case stx ()
    [(_ ()) 
     #'(void)]
    [(_ ((ffi-name ffi-type sml-type sml-abbrev c-typedef c-type) clause ...))
     (begin
       (set! base-types (cons (syntax-object->datum #'(sml-abbrev sml-type ffi-type ffi-name)) base-types))
       #'(begin
           (define ffi-name ffi-type)
           (define-base-types (clause ...))))]
    [(_ ((ffi-name ffi-type sml-type sml-abbrev c-typedef c-type scheme->c c->scheme) clause ...))
     (begin
       (set! base-types (cons (syntax-object->datum #'(sml-abbrev sml-type ffi-type ffi-name)) base-types))
       #'(begin
           (define ffi-name (make-ctype ffi-type scheme->c c->scheme))
           (define-base-types (clause ...))))]))

(define (object->bool o)
  (case o
    [(#t #f) o]
    [else       (error 'object->bool "boolean expected, got: ~a" o)]))

(define (bool->object b)
  b)

(define-base-types
  (; ml-ffi   C-FFI    SML type        SML-abbrev   C typedef   C type         Scheme->C     C->Scheme
   
   ; Values of these types are simply copied between ML and Scheme
   ($int8     _int8     Int8.int         Int8        Int8      "char")
   ($int16    _int16    Int16.int        Int16       Int16     "short")
   ($int32    _int32    Int32.int        Int32       Int32     "long")
   ($int64    _int64    Int64.int        Int64       Int64     "long long")
   ($int      _int32    int              int         Int32     "long")
   ($real32   _float    Real32.real      Real32      Real32    "float")
   ($real64   _double   Real64.real      Real64      Real64    "double")
   ($real     _double   real             real        Real64    "double")
   ($word8    _uint8    Word8.word       Word8       Word8     "unsigned char")
   ($word16   _uint16   Word16.word      Word16      Word16    "unsigned short")
   ($word32   _uint32   Word32.word      Word32      Word32    "unsigned long")
   ($word64   _uint64   Word64.word      Word64      Word64    "unsigned long long")
   ($word     _uint32   word             word        Word32    "unsigned int")
   
   ; Booleans and characters need conversion to and from integers
   ($bool     _bool     bool             Bool        Int32     "long"       object->bool   bool->object)
   ($char     _int8     char             char        Int8      "char"       char->integer  integer->char)
   
   ; Values of these types are only valid until the next MLton garbage collection
   ; These are unsafe to use - hence rooted versions of these are defined below,
   ; in order to provide automatic GC of ML values.
   ($array    _pointer  array            array       Pointer   "char *")
   ($pointer  _pointer  MLton.Pointer.t  Pointer     Pointer   "char *")
   ($ref      _pointer  ref              ref         Pointer   "char *")
   ($string   _pointer  string           string      Pointer   "char *")           ; READ ONLY
   ($vector   _pointer  vector           vector      Pointer   "char *")           ; READ ONLY
   ))

; The registry base-types has the format ((sml-abbrev sml-type ffi-type $name) ...)
(begin-for-syntax 
  (require (lib "list.ss"))
  ; (define short-names (map first base-types))
  (define short-names 
    '(Bool Int8 Int16 Int32 Int64 int Real32 Real64 real Word8 Word16 Word32 Word64 word char string))
  (define (short-name->$name name)
    (let ((a (assoc name base-types)))
      (if a
          (fourth a)
          (error "short name not found in base-types")))))

;;; Each compound ML value is represented on the Scheme side by a root:
;;; a vector holding the root id, the operations get and unreg, and the
;;; base type. The id is assigned by ML when the value is registered on
;;; the ML side, just before it is returned to Scheme. Mzton's private
;;; routines pass the id to get each time the value is referred to (since
;;; ML moves values during garbage collection), and to unreg when the
;;; value is to be deallocated.

(define (make-root id get unreg base-type)
  (vector id get unreg base-type))
(define (root-id v)            (vector-ref v 0))
(define (root-get v)           (vector-ref v 1)) ; retrieve uptodate C-pointer
(define (root-unreg v)         (vector-ref v 2)) ; deallocate
(define (root-base-type v)     (vector-ref v 3)) ; e.g. the base type of "array int" is "int"
(define (root? o) (and (vector? o) (= (vector-length o) 4)))

(define (root->val r)  ((root-get r) (root-id r)))

;;; We want Scheme to automatically unregister the ML values when
;;; the roots become unreachable. This is done by registering the root
;;; values with a will executor.

(define ml-will-executor (make-will-executor))

;;; The unreachable roots need to be unregistered; a separate
;;; thread takes care of this.

; CAREFUL: Remember to suspend this thread, in regions, where an ML garbage
;          collection is unwanted.

; TODO:    Compare the efficiency of the suspend/resume implementation
;          with a semaphore version.

; start the separate thread
(define will-thread (thread (lambda ()
                              (let loop ()
                                (will-execute ml-will-executor)
                                (loop)))))

; TODO: this works fine, but using semaphores would be faster
;(define-syntax critical-region
;  (syntax-rules ()
;    ((critical-region body ...)
;     (let ()
;       (thread-suspend will-thread)
;       (begin0 
;         (begin body ...)
;         (thread-resume will-thread))))))


; This semaphore makes sure the main thread and the will
; thread don't call into ML at the same time
(define semaphore-for-critical-region (make-semaphore 1))

; This flag lets the main thread re-enter a critical region without
; waiting for the semaphore it already holds. NOTE: the flag is global,
; so this assumes that only the main thread uses critical-region.
(define inside-critical-region? #f)

(define-syntax will-critical-region
  (syntax-rules ()
    ((critical-region body ...)
     (begin
       (semaphore-wait semaphore-for-critical-region)
       (begin0 
         (begin body ...)
         (semaphore-post semaphore-for-critical-region))))))

(define-syntax critical-region
  (syntax-rules ()
    ((critical-region body ...)
     (if inside-critical-region?
         (begin body ...)
         (begin
           (semaphore-wait semaphore-for-critical-region)
           (set! inside-critical-region? #t)
           (begin0 
             (begin body ...)
             (set! inside-critical-region? #f)
             (semaphore-post semaphore-for-critical-region)))))))


(define (ml-will-register-root r)
  (will-register ml-will-executor r 
                 ; this procedure will eventually be called by the will thread
                 (lambda (r) 
                   (will-critical-region
                    (display "unregistering> ") (display r) (display " ")
                    ((root-unreg r) (root-id r))))))


;(define (ml-will-try-execute)
;  (will-try-execute ml-will-executor))

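(define (signal-error . args)
  ; also callable with no arguments, since get-ffi-obj calls its failure thunk with none
  (display "signal-error: ")
  (for-each display args)
  (newline))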

; registering a root on the ML side sets the current root id
(define get-current-root-id (get-ffi-obj "getCurrentRootId" lib (_fun -> $int) signal-error))


; (define-compound/base-type ...) imports getBaseCompound and unregBaseCompound
; from the ML side, and defines them as base-compound-get and base-compound-unreg.
; At the same time $BaseCompound is defined as a new C-type representing rooted
; BaseCompound values. $BaseCompound is to be used when importing functions
; from the ML side.
(define-syntax (define-compound/base-type stx)
  (syntax-case stx ()
    [(_ $BaseCompound compound base $base  
        base-compound-get   getBaseCompound
        base-compound-unreg unregBaseCompound
        make-base-compound makeBaseCompound)
     #`(begin
         (define base-compound-get   (get-ffi-obj getBaseCompound  lib (_fun $int -> _pointer) 
                                                  (lambda () (error getBaseCompound))))
         (define base-compound-unreg (get-ffi-obj unregBaseCompound lib (_fun $int -> _void) signal-error))
         ; (display '("defining " '$BaseCompound)) (newline)
         (define $BaseCompound (make-ctype _pointer 
                                           ; Scheme->C
                                           root->val
                                           ; C->Scheme
                                           ; (assumption: the returned value was just registered on the MLton side,
                                           ;  so get-current-root-id can fetch its id from the ML side)
                                           (lambda (a)
                                             (let ([root (make-root (get-current-root-id)
                                                                    base-compound-get
                                                                    base-compound-unreg
                                                                    $base)])
                                               ; make sure the Scheme garbage collector will unregister the root;
                                               ; the will thread must not call into ML while the main thread does
                                               (will-register ml-will-executor 
                                                              root
                                                              (lambda (r)
                                                                (will-critical-region
                                                                 ((root-unreg r) (root-id r)))))
                                               root))))
         (define make-base-compound 
           (case 'compound
             [(ref) (get-ffi-obj makeBaseCompound lib (_fun $base -> $BaseCompound) 
                                 (lambda () (signal-error (format "couldn't open ~a from lib" makeBaseCompound))))]
             [else  (get-ffi-obj makeBaseCompound lib (_fun $int $base -> $BaseCompound) 
                                 (lambda () (signal-error (format "couldn't open ~a from lib" makeBaseCompound))))])))]))


; (define-compound compound) builds, for each base type in short-names, the names
; $BaseCompound, getBaseCompound, unregBaseCompound, make-base-compound, etc.,
; and then uses define-compound/base-type to define them.
(define-syntax (define-compound stx)
  (syntax-case stx ()
    [(_ compound-stx)
     (begin
       (define (string->id s) (syntax-local-introduce (quasisyntax/loc stx #,(string->symbol s))))
       (quasisyntax/loc stx
         (begin #,@(map (lambda (base)
                          (let* ([compound           (syntax-object->datum #'compound-stx)]
                                 [Base               (string-titlecase (symbol->string base))]
                                 [base               (string-downcase (symbol->string base))]
                                 [Compound           (string-titlecase (symbol->string compound))]
                                 [BaseCompound       (format "~a~a" Base Compound)]
                                 [$BaseCompound      (format "$~a~a" Base Compound)]
                                 [getBaseCompound    (format "get~a~a" Base Compound)]
                                 [unregBaseCompound  (format "unreg~a~a" Base Compound)]
                                 [make-base-compound (format "make-~a-~a" base compound)])
                            (with-syntax ((base                base)
                                          ($base               (string->id (format "$~a" base)))
                                          (Compound            (string->id Compound))
                                          (BaseCompound        (string->id BaseCompound))
                                          ($BaseCompound       (string->id $BaseCompound))
                                          (getBaseCompound     getBaseCompound)
                                          (unregBaseCompound   unregBaseCompound)
                                          (base-compound-get   (string->id (string-downcase (format "~a-~a-get" base compound))))
                                          (base-compound-unreg (string->id (string-downcase (format "~a-~a-unreg" base compound))))
                                          (make-base-compound  (string->id make-base-compound))
                                          (makeBaseCompound    (format "make~a" BaseCompound)))
                              #'(begin
                                  (define-compound/base-type $BaseCompound compound-stx base $base
                                    base-compound-get   getBaseCompound
                                    base-compound-unreg unregBaseCompound
                                    make-base-compound makeBaseCompound
                                    )))))
                        short-names))))]))

; The compound constructors are array, ref and vector.
(define-compound array)
(define-compound ref)
(define-compound vector)

;;;
;;; WORKING WITH COMPOUND TYPES
;;;

;;; REF

; ml-raw-ref-ref : $base cpointer-to-base-ref -> base
(define (ml-raw-ref-ref $base cpointer-to-base-ref)
  (ptr-ref cpointer-to-base-ref $base))

; ml-ref-ref : rooted-base-ref -> base
(define (ml-ref-ref rooted-base-ref)
  (unless (root? rooted-base-ref)
    (error 'ml-ref-ref "rooted ref expected, got: ~s" rooted-base-ref))
  (critical-region
   (ml-raw-ref-ref (root-base-type rooted-base-ref) 
                   (root->val rooted-base-ref))))

; ml-raw-ref-set! : $base cpointer-to-base-ref base -> 
(define (ml-raw-ref-set! $base cpointer-to-base-ref new-val)
  (ptr-set! cpointer-to-base-ref $base new-val))

; ml-ref-set! : rooted-base-ref base -> 
(define (ml-ref-set! rooted-base-ref new-val)
  (unless (root? rooted-base-ref)
    (error 'ml-ref-set! "rooted ref expected, got: ~s" rooted-base-ref))
  (critical-region
   (ml-raw-ref-set! (root-base-type rooted-base-ref)
                    (root->val rooted-base-ref)
                    new-val)))


;;; ARRAY

; ml-raw-array-length : pointer -> integer
;   return the length of a raw (as opposed to rooted) array returned from MLton
(define (ml-raw-array-length cpointer-to-array)
  ; the element count is stored two words before the first element;
  ; see GC_arrayNumElementsp in gc.h
  (ptr-ref cpointer-to-array _uint -2))

; ml-array-length : rooted-array -> integer
(define (ml-array-length rooted-array)
  (unless (root? rooted-array)
    (error "rooted array expected, got " rooted-array))
  (critical-region
   (ml-raw-array-length (root->val rooted-array))))

; ml-raw-array-ref : $base pointer integer -> base
(define (ml-raw-array-ref $base cpointer-to-base-array index)
  (ptr-ref cpointer-to-base-array $base index))

; ml-array-ref : rooted-base-array integer -> base
(define (ml-array-ref rooted-base-array index)
  (unless (root? rooted-base-array)
    (error 'ml-array-ref "rooted array or vector expected, got: ~s" rooted-base-array))
  (unless (<= 0 index (sub1 (ml-array-length rooted-base-array)))
    (error "index out of range" index))
  (critical-region
   (ml-raw-array-ref (root-base-type rooted-base-array)
                     (root->val rooted-base-array) 
                     index)))

; ml-raw-array-set! : $base cpointer-to-base-array integer new-val -> 
(define (ml-raw-array-set! $base cpointer-to-base-array index new-val)
  (ptr-set! cpointer-to-base-array $base index new-val))

; ml-array-set! : rooted-base-array integer base -> 
(define (ml-array-set! rooted-base-array index new-val)
  (unless (root? rooted-base-array)
    (error "rooted array expected, got " rooted-base-array))
  (unless (<= 0 index (sub1 (ml-array-length rooted-base-array)))
    (error "index out of range" index))
  (critical-region
   (let ([$base (root-base-type rooted-base-array)]
         [c-pointer (root->val rooted-base-array)])
     (ml-raw-array-set! $base c-pointer index new-val))))

; ml-array->vector : (rooted-base-array alpha) -> (vector alpha)
(define (ml-array->vector rooted-base-array)
  (unless (root? rooted-base-array)
    (error 'ml-array->vector "rooted array expected, got: ~s" rooted-base-array))
  (critical-region
   (let* ((len   (ml-array-length rooted-base-array))
          (s     (make-vector len))
          (array (root->val rooted-base-array))
          ($base (root-base-type rooted-base-array)))
     (do ([i 0 (add1 i)]) [(= i len) s]
       (vector-set! s i (ml-raw-array-ref $base array i))))))


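; vector->ml-array : (vector alpha) $alpha (integer alpha -> rooted-array) -> rooted-array
;   v must be non-empty, since its first element initializes the new array
; TODO: Infer ml-make-alpha-array from $alpha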
(define (vector->ml-array v $alpha ml-make-alpha-array)
  (critical-region
   (let* ([len       (vector-length v)]
          [a         (ml-make-alpha-array len (vector-ref v 0))]
          [c-pointer (root->val a)])
     (do ([i 0 (add1 i)]) [(= i len) a]
       (ml-raw-array-set! $alpha c-pointer i (vector-ref v i))))))

;;; VECTOR


;; TODO TODO:  Write vector convenience functions
;; TODO TODO:  Write ml-array->ml-vector


; ml-vector-length : rooted-vector -> integer
(define (ml-vector-length v)
  (ml-array-length v))

; ml-raw-vector-ref : $base c-pointer-to-vector integer -> base
(define (ml-raw-vector-ref $base c-pointer index)
  (ml-raw-array-ref $base c-pointer index))

; ml-vector-ref : (ml-vector alpha) integer -> alpha
(define (ml-vector-ref rooted-vector index)
  (ml-array-ref rooted-vector index))

; ml-vector->vector : (rooted-base-vector alpha) -> (vector alpha)
(define (ml-vector->vector rooted-base-vector)
  (unless (root? rooted-base-vector)
    (error 'ml-vector->vector "rooted vector expected, got: ~s" rooted-base-vector))
  (critical-region
   (let* ((len   (ml-vector-length rooted-base-vector))
          (s     (make-vector len))
          (vector (root->val rooted-base-vector))
          ($base (root-base-type rooted-base-vector)))
     (do ([i 0 (add1 i)]) [(= i len) s]
       (vector-set! s i (ml-raw-vector-ref $base vector i))))))


;;;
;;; TEST
;;;

(display "* Opening library functions\n")
; defined above
;(define get-current-root-id       (get-ffi-obj "getCurrentRootId"   lib (_fun      -> $int)   signal-error))
(define test-int-to-int        (get-ffi-obj "testIntToInt"       lib (_fun $int -> $int)      signal-error))
(define test-unit-to-int       (get-ffi-obj "testUnitToInt"      lib (_fun      -> $int)      signal-error))
(define test-unit-to-char      (get-ffi-obj "testUnitToChar"     lib (_fun      -> $char)     signal-error))
(define test-unit-to-bool      (get-ffi-obj "testUnitToBool"     lib (_fun      -> $bool)     signal-error))
(define test-unit-to-int-array (get-ffi-obj "testUnitToIntArray" lib (_fun      -> $IntArray) signal-error))
(define test-unit-to-int-ref   (get-ffi-obj "testUnitToIntRef"   lib (_fun      -> $IntRef)   signal-error))

;;; MANDELBROT
(define ml-iterate             (get-ffi-obj "iterate"            lib (_fun $real $real -> $int)   signal-error))
(define ml-iterate-line        (get-ffi-obj "iterateLine"        lib (_fun $real $real $real $real -> $IntArray)   signal-error))

;;; CALLBACKS

(define ml-double-array        (get-ffi-obj "doubleArray"        lib (_fun $RealArray (_fun $real -> $real) -> $unit)   signal-error))


(display "* Done opening library objects\n")

(display (test-int-to-int 1))       (newline)
(display (test-unit-to-int))        (newline)
(display (test-unit-to-char))       (newline)
(define a (test-unit-to-int-array))
(display a)                         (newline)
(display (ml-array-length a))       (newline)
(display (ml-array-ref a 0))       (newline)
(display (ml-array-ref a 1))       (newline)
(display (ml-array-ref a 2))       (newline)
(display (ml-array-ref a 3))       (newline)
(display (ml-array-ref a 4))       (newline)
(display (ml-array->vector a))     (newline)
(newline)
(define r (test-unit-to-int-ref))
(display (ml-ref-ref r)) (newline)
(ml-ref-set! r 43)
(display (ml-ref-ref r))
(newline)
(display "double test\n")
(define b (make-real-array 4 0.0))
(ml-array-set! b 0 0.0)
(ml-array-set! b 1 1.0)
(ml-array-set! b 2 2.0)
(ml-array-set! b 3 3.0)
(ml-double-array b (lambda (x) (* 2.0 x)))
(display (ml-array->vector b)) (newline)


(define make-unit-vec (get-ffi-obj "makeUnitVec" lib (_fun $int $int -> $RealArray) signal-error))
(define dot-vec       (get-ffi-obj "dotVec"      lib (_fun $RealArray $RealArray -> $real) signal-error))
(define modify-vec!   (get-ffi-obj "modifyVec"   lib (_fun $RealArray (_fun $real -> $real) -> $unit) signal-error))
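The root mechanism described in the Scheme attachment (an id assigned by ML, a get operation that returns the up-to-date pointer, and unreg) is essentially a handle table. The C sketch below models that idea under stated assumptions: a fixed-size table, non-NULL values, and a hypothetical `move_root` call standing in for the MLton collector relocating a value. `reg_root`, `get_root`, `move_root`, `unreg_root` and `get_current_root_id` are illustrative names that only loosely mirror getCurrentRootId and the get/unreg pairs imported above; they are not the actual ML-side implementation.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_ROOTS 64

static void *root_table[MAX_ROOTS]; /* NULL marks a free slot */
static int current_root_id = -1;    /* analog of getCurrentRootId */

/* Register a value and return its id (-1 if the table is full). */
static int reg_root(void *value)
{
    assert(value != NULL); /* NULL is reserved for free slots */
    for (int id = 0; id < MAX_ROOTS; id++) {
        if (root_table[id] == NULL) {
            root_table[id] = value;
            current_root_id = id;
            return id;
        }
    }
    return -1;
}

static int get_current_root_id(void)
{
    return current_root_id;
}

/* Return the current address of a rooted value. */
static void *get_root(int id)
{
    assert(id >= 0 && id < MAX_ROOTS);
    return root_table[id];
}

/* Stand-in for a moving collector: the slot changes, the id does not. */
static void move_root(int id, void *new_address)
{
    assert(id >= 0 && id < MAX_ROOTS && root_table[id] != NULL);
    root_table[id] = new_address;
}

/* Release the slot so the value is no longer kept alive. */
static void unreg_root(int id)
{
    assert(id >= 0 && id < MAX_ROOTS);
    root_table[id] = NULL;
}
```

Because the holder keeps only the id, an address obtained from the table is valid only until the next collection, which is why the Scheme code calls root->val inside critical-region on every access.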





-------------- next part --------------
A non-text attachment was scrubbed...
Name: test-from-c
Type: application/octet-stream
Size: 13052 bytes
Desc: not available
Url : http://mlton.org/pipermail/mlton/attachments/20060810/995fe511/test-from-c-0001.obj
-------------- next part --------------
#include <stdlib.h>
#include <stdio.h>
#include <dlfcn.h>

int main (int argc, char **argv) {
  void *handle;
  long (*int42)(void);
  char *error;
  int i;

  /* open the shared library */
  handle = dlopen("/home/js/mzton/mzton/test.so", RTLD_LAZY);
  if (!handle) {
    fputs(dlerror(), stderr);
    exit(1);
  }

  /* import a function from the shared library; clear any old error
     first, since only dlerror() reliably reports a failed lookup */
  dlerror();
  *(void **) (&int42) = dlsym(handle, "int42");
  if ((error = dlerror()) != NULL) {
    fputs(error, stderr);
    exit(1);
  }

  /* int42 returns a long, so print it with %ld */
  for (i = 0; i < 6; i++)
    printf("%ld\n", (*int42)());

  dlclose(handle);
  return 0;
}
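ml-raw-array-length in the Scheme attachment reads an array's element count with `(ptr-ref p _uint -2)`, that is, from the word two words before the first element. The sketch below models that layout in C, assuming 32-bit words and a three-word prefix (counter, length, header) in front of the elements, following the GC_arrayNumElementsp reference; the exact layout is an assumption about the 2005 MLton runtime, and `fake_ml_array`, `raw_array_length` and `free_fake_ml_array` are hypothetical helpers.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

/* Assumed layout (32-bit words): [counter][length][header][elem 0][elem 1]... */
enum { ARRAY_PREFIX_WORDS = 3, LENGTH_OFFSET = -2 };

/* Allocate n elements behind a fake prefix and return a pointer to the
   first element, which is what the C side receives for an ML array. */
static uint32_t *fake_ml_array(uint32_t n)
{
    uint32_t *block = calloc(ARRAY_PREFIX_WORDS + (size_t)n, sizeof *block);
    assert(block != NULL);
    block[1] = n; /* the length word */
    return block + ARRAY_PREFIX_WORDS;
}

/* Same read as (ptr-ref p _uint -2): the word two slots before p. */
static uint32_t raw_array_length(const uint32_t *elems)
{
    return elems[LENGTH_OFFSET];
}

static void free_fake_ml_array(uint32_t *elems)
{
    free(elems - ARRAY_PREFIX_WORDS);
}
```

The Scheme code relies on vectors having the same prefix, which is why ml-vector-length simply calls ml-array-length.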
  
  
-------------- next part --------------
(define mlton-will-executor (make-will-executor))

; Start a separate thread that keeps unregistering unreachable roots
(thread (lambda ()
          (let loop ()
            (display "!")
            (sleep 1)
            (will-try-execute mlton-will-executor)
            (loop))))

(define a (make-vector 1000))
(will-register mlton-will-executor
               a
               (lambda (r)
                 (display "Heureka")
                 (newline)))
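The critical-region macro in the Scheme attachment pairs a semaphore with the global flag inside-critical-region?, so that nested regions in the main thread do not wait on a semaphore they already hold, while will-critical-region always waits. A C sketch of the same scheme follows, assuming POSIX threads, a mutex standing in for the binary semaphore, and a single main thread that uses `critical_region`; `region_body`, `critical_region`, `will_critical_region`, `count_call` and `nested_calls` are hypothetical names.

```c
#include <assert.h>
#include <pthread.h>
#include <stdbool.h>

/* Plays the role of semaphore-for-critical-region. */
static pthread_mutex_t ml_lock = PTHREAD_MUTEX_INITIALIZER;

/* Plays the role of inside-critical-region?. Only the main thread reads
   or writes it, so it needs no lock of its own (assumption). */
static bool inside_critical_region = false;

typedef void (*region_body)(void *);

/* Main-thread version: a nested call runs the body directly instead of
   waiting for a lock this thread already holds. */
static void critical_region(region_body body, void *arg)
{
    if (inside_critical_region) {
        body(arg);
        return;
    }
    pthread_mutex_lock(&ml_lock);
    inside_critical_region = true;
    body(arg);
    inside_critical_region = false;
    pthread_mutex_unlock(&ml_lock);
}

/* Will-thread version: always waits for the lock. */
static void will_critical_region(region_body body, void *arg)
{
    pthread_mutex_lock(&ml_lock);
    body(arg);
    pthread_mutex_unlock(&ml_lock);
}

/* Example bodies: count the calls that would go into ML. */
static void count_call(void *arg)
{
    ++*(int *)arg;
}

static void nested_calls(void *arg)
{
    critical_region(count_call, arg); /* re-entry: no second lock */
    critical_region(count_call, arg);
}
```

As in the Scheme version, nothing releases the lock or clears the flag if the body exits abnormally (an exception in Scheme, a longjmp in C); the Scheme macro would need dynamic-wind to handle that case.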

               

