IntInf.{div,mod}

Stephen Weeks sweeks@wasabi.epr.com
Wed, 20 Oct 1999 10:32:51 -0700 (PDT)


I implemented them yesterday.  Here is the new int-inf.sml.

--------------------------------------------------------------------------------

(* Copyright (C) 1997-1999 NEC Research Institute.
 * Please see the file LICENSE for license information.
 *)
(*
 * IntInf.int's either have a bottom bit of 1, in which case the top 31
 * bits are the signed integer, or else the bottom bit is 0, in which case
 * they point to an array of Word.word's.  The first word is either 0,
 * indicating that the number is positive, or 1, indicating that it is
 * negative.  The rest of the array contains the `limbs' (big digits) of
 * the absolute value of the number, from least to most significant.
 *)
(*
 * Note: all the arrays should be changed to vectors.
 * This requires the magic cast from 'a array to 'a vector.
 *)
structure IntInf: INT_INF =
   struct
      local
	 structure Prim = Primitive.IntInf
	 type bigInt = Prim.int
	 (* Int is opened so that the unqualified integer operations below
	  * (including max and compare, used further down) are Int's. *)
	 open Int
	 type smallInt = int

	 (*
	  * bigSize b: the limb count of the bignum b.  Its vector holds one
	  * sign word followed by the limbs, hence the subtraction of 1.
	  * Only meaningful when b is not a fixnum.
	  *)
	 fun bigSize (arg: bigInt): smallInt =
		let val words = Prim.toVector arg
		in Vector.length words - 1
		end

	 (*
	  * size b: the limb count of b, counting any fixnum as one limb.
	  * For a bignum with x = size b, |b| lies in [ 2^(32 (x-1)), 2^(32 x) );
	  * a fixnum lies in [ - 2^30, 2^30 ).
	  *)
	 fun size (arg: bigInt): smallInt =
		if not (Prim.isSmall arg)
		   then bigSize arg
		   else 1

	 (*
	  * allocate n: a fresh array with room for n limbs plus the leading
	  * sign word.
	  *)
	 fun allocate (size: smallInt) =
		let val words = size + 1
		in Primitive.Array.array words
		end

	 (*
	  * Fixnum tagging helpers.  A fixnum holding v is stored as the word
	  * 2v+1: the value sits in the top 31 bits and the bottom (tag) bit
	  * is 1.
	  *)

	 (* The value carried by a fixnum: an arithmetic right shift drops the
	  * tag and keeps the sign.  It is an ERROR to call this on a bignum. *)
	 fun stripTag (n: bigInt): Word.word =
		let val tagged = Prim.toWord n
		in Word.~>> (tagged, 0w1)
		end

	 (* Inverse of stripTag: shift left by one and set the tag bit. *)
	 fun addTag (w: Word.word): Word.word =
		Word.orb (0w1, Word.<< (w, 0w1))

	 (* A fixnum's word with the tag bit cleared (2v).  It is an ERROR to
	  * call this on a bignum. *)
	 fun zeroTag (n: bigInt): Word.word =
		Word.- (Prim.toWord n, 0w1)

	 (* Set the tag bit of a word back to 1. *)
	 fun incTag (w: Word.word): Word.word =
		Word.orb (0w1, w)

	 (*
	  * ~2^30 is the only fixnum whose negation and absolute value (2^30)
	  * are not fixnums.  badw is its tagged word, badv is that word with
	  * the tag stripped, and negBad is the bigInt 2^30.
	  *)
	 val badw: Word.word = Prim.toWord ~0x40000000
	 val badv: Word.word = Word.~>> (badw, 0w1)
	 val negBad: bigInt = 0x40000000

	 (*
	  * sameSign (a, b): true when the top (sign) bits of the two words
	  * agree, i.e. when their exclusive-or is non-negative as a signed int.
	  *)
	 fun sameSign (lhs: Word.word, rhs: Word.word): bool =
		let val differing = Word.xorb (lhs, rhs)
		in not (Word.toIntX differing < 0)
		end

	 (*
	  * bigIsNeg b: whether the bignum b is strictly negative, read from
	  * its leading sign word (non-zero means negative).
	  * It is an ERROR to call this on a fixnum.
	  *)
	 fun bigIsNeg (arg: bigInt): bool =
		let val signWord = Primitive.Vector.unsafeSub (Prim.toVector arg, 0)
		in not (signWord = 0w0)
		end

	 (*
	  * bigFromInt i: i as a bigInt.  A fixnum is used when tagging keeps
	  * the sign (i fits in 31 bits).  Otherwise a one-limb bignum is built
	  * that holds the sign word and the magnitude.
	  *)
	 fun bigFromInt (arg: smallInt): bigInt =
		let
		   val raw = Word.fromInt arg
		   val tagged = addTag raw
		   fun boxed () =
		      let
			 val space = allocate 1
			 val (signWord, magnitude) =
			    if arg >= 0
			       then (0w0, raw)
			       else (0w1, Word.- (0w0, raw))
		      in
			 Primitive.Array.unsafeUpdate (space, 0, signWord)
			 ; Primitive.Array.unsafeUpdate (space, 1, magnitude)
			 ; Prim.fromArray space
		      end
		in
		   if not (sameSign (raw, tagged))
		      then boxed ()
		      else Prim.fromWord tagged
		end

	 (*
	  * bigToInt b: b as a smallInt, raising Overflow if it does not fit.
	  * A fixnum always fits.  A bignum fits only if it has exactly one
	  * limb, and that limb's magnitude is at most 0wx80000000 (for a
	  * negative value) or below 0wx80000000 (for a non-negative one).
	  *)
	 fun bigToInt (arg: bigInt): smallInt =
		if Prim.isSmall arg
		   then Word.toIntX (stripTag arg)
		else if bigSize arg = 1
		   then let
			   val words = Prim.toVector arg
			   val mag = Primitive.Vector.unsafeSub (words, 1)
			   val isNeg = Primitive.Vector.unsafeSub (words, 0) <> 0w0
			   val limit = 0wx80000000
			in
			   if isNeg andalso Word.<= (mag, limit)
			      then Word.toIntX (Word.- (0w0, mag))
			   else if not isNeg andalso Word.< (mag, limit)
			      then Word.toIntX mag
			   else raise Overflow
			end
		else raise Overflow

	 (*
	  * bigInt negation.  A fixnum v is stored as 2v+1, so the word for -v
	  * is 2 - (2v+1).  The single exception is ~2^30 (badw), whose negation
	  * is the bignum negBad.  The runtime negates bignums into a buffer
	  * of bigSize + 1 limbs.
	  *)
	 fun bigNegate (arg: bigInt): bigInt =
		if not (Prim.isSmall arg)
		   then Prim.~ (arg, allocate (bigSize arg + 1))
		else let val w = Prim.toWord arg
		     in if w <> badw
			   then Prim.fromWord (Word.- (0w2, w))
			   else negBad
		     end

	 (*
	  * bigInt multiplication.  For two fixnums, the untagged lhs is
	  * multiplied by the tag-cleared rhs (2 * v2), so the low word is
	  * already 2 * v1 * v2.  The fast path applies when the word written
	  * to carry equals the sign extension of that low word, i.e. the
	  * product fits; otherwise the runtime computes a bignum.
	  *)
	 local
	    (* General case: a result buffer of size lhs + size rhs limbs. *)
	    fun expensive (lhs: bigInt, rhs: bigInt): bigInt =
		   Prim.* (lhs, rhs, allocate (size lhs + size rhs))
	    (* Filled in by Prim.smallMul (presumably the high word of the
	     * product -- TODO confirm against the primitive). *)
	    val carry: Word.word ref = ref 0w0
	 in
	    fun bigMul (lhs: bigInt, rhs: bigInt): bigInt =
		   if not (Prim.areSmall (lhs, rhs))
		      then expensive (lhs, rhs)
		   else let
			   val low = Prim.smallMul (stripTag lhs, zeroTag rhs, carry)
			in
			   if !carry <> Word.~>> (low, 0w31)
			      then expensive (lhs, rhs)
			      else Prim.fromWord (incTag low)
			end
	 end

	 (*
	  * bigInt quot: truncating division (rounds toward 0); bigRem gives
	  * the matching remainder.  Raises Div on a zero denominator.
	  * Fixnum case: the only quotient of two fixnums that is not a fixnum
	  * is ~2^30 quot ~1 = 2^30, which is returned as negBad.
	  * General case: if size num < size den, the quotient is 0.  The only
	  * borderline case is a fixnum num equal to ~den; den may then be a
	  * bignum, but its size is still 1 (and den cannot be 0).
	  * Otherwise the runtime works in 2*nsize + 2 limbs:
	  *   shifted numerator   <= nsize + 1 limbs
	  *   shifted denominator <= dsize limbs
	  *   quotient            <= 1 + nsize - dsize limbs
	  * plus one extra word for the isNeg flag.
	  *)
	 fun bigQuot (num: bigInt, den: bigInt): bigInt =
		if not (Prim.areSmall (num, den))
		   then let val nsize = size num
			in if nsize < size den
			      then 0
			   else if den = 0
			      then raise Div
			   else Prim.quot (num, den, allocate (2*nsize + 2))
			end
		else let val numv = stripTag num
			 val denv = stripTag den
		     in if numv = badv andalso denv = Word.fromInt ~1
			   then negBad
			   else Prim.fromWord
				   (addTag (Word.fromInt
					       (Int.quot (Word.toIntX numv,
							  Word.toIntX denv))))
		     end

	 (*
	  * bigInt rem: the remainder of truncating division, so its sign
	  * follows the numerator (bigQuot gives the matching quotient).
	  * Raises Div on a zero denominator.
	  * Fixnum case: Int.rem on the untagged values; |result| < |den|, so
	  * it is always a fixnum.
	  * General case: if size num < size den, the remainder is num itself
	  * (the borderline fixnum num = ~den case still has size den = 1).
	  * Otherwise the runtime works in 2*nsize + 2 limbs (shifted numerator
	  * <= nsize + 1, shifted denominator <= dsize, quotient
	  * <= 1 + nsize - dsize), plus one extra word for the isNeg flag.
	  *)
	 fun bigRem (num: bigInt, den: bigInt): bigInt =
		if Prim.areSmall (num, den)
		   then let fun untag x = Word.toIntX (stripTag x)
			    val r = Int.rem (untag num, untag den)
			in Prim.fromWord (addTag (Word.fromInt r))
			end
		else let val nsize = size num
		     in if nsize < size den then num
			else if den = 0 then raise Div
			else Prim.rem (num, den, allocate (2*nsize + 2))
		     end

	 (*
	  * bigInt addition.  Untagged fixnum values lie in [-2^30, 2^30), so
	  * their sum fits in a word.  It is a fixnum exactly when re-tagging
	  * preserves the sign bit.  Otherwise the runtime adds into a buffer
	  * one limb larger than the larger operand.
	  *)
	 local
	    fun expensive (lhs: bigInt, rhs: bigInt): bigInt =
		   Prim.+ (lhs, rhs, allocate (1 + max (size lhs, size rhs)))
	 in
	    fun bigPlus (lhs: bigInt, rhs: bigInt): bigInt =
		   let
		      fun fast () =
			 let val sum = Word.+ (stripTag lhs, stripTag rhs)
			     val tagged = addTag sum
			 in if sameSign (tagged, sum)
			       then Prim.fromWord tagged
			       else expensive (lhs, rhs)
			 end
		   in
		      if Prim.areSmall (lhs, rhs) then fast () else expensive (lhs, rhs)
		   end
	 end

	 (*
	  * bigInt subtraction.  Same scheme as addition: the word difference of
	  * two untagged fixnums is exact, and it is a fixnum when re-tagging
	  * keeps the sign.  Otherwise the runtime subtracts into a buffer one
	  * limb larger than the larger operand.
	  *)
	 local
	    fun expensive (lhs: bigInt, rhs: bigInt): bigInt =
		   let val limbs = 1 + max (size lhs, size rhs)
		   in Prim.- (lhs, rhs, allocate limbs)
		   end
	 in
	    fun bigMinus (lhs: bigInt, rhs: bigInt): bigInt =
		   if not (Prim.areSmall (lhs, rhs))
		      then expensive (lhs, rhs)
		   else let val diff = Word.- (stripTag lhs, stripTag rhs)
			    val tagged = addTag diff
			in if not (sameSign (tagged, diff))
			      then expensive (lhs, rhs)
			      else Prim.fromWord tagged
			end
	 end

	 (*
	  * bigInt compare.  Two fixnums can be compared through their tagged
	  * words, since 2v+1 preserves order.  Otherwise the sign of the
	  * runtime's comparison result is used.
	  *)
	 fun bigCompare (lhs: bigInt, rhs: bigInt): order =
		case Prim.areSmall (lhs, rhs) of
		   true => let fun raw x = Word.toIntX (Prim.toWord x)
			   in Int.compare (raw lhs, raw rhs)
			   end
		 | false => Int.compare (Prim.compare (lhs, rhs), 0)


	 (*
	  * bigInt comparisons.  makeTest lifts an int predicate to bigInts:
	  * fixnums are compared through their order-preserving tagged words,
	  * and anything else through the sign of Prim.compare against 0.
	  *)
	 local
	    fun makeTest (smallTest: smallInt * smallInt -> bool)
			 (lhs: bigInt, rhs: bigInt)
			 : bool =
		   let val operands =
			  if Prim.areSmall (lhs, rhs)
			     then (Word.toIntX (Prim.toWord lhs),
				   Word.toIntX (Prim.toWord rhs))
			     else (Prim.compare (lhs, rhs), 0)
		   in smallTest operands
		   end
	 in
	    val bigLT = makeTest (op <)
	    val bigLE = makeTest (op <=)
	    val bigGE = makeTest (op >=)
	    val bigGT = makeTest (op >)
	 end

	 (*
	  * bigInt abs.  Non-negative values are returned unchanged.  A negative
	  * fixnum 2v+1 becomes 2 - (2v+1), except ~2^30 (badw), whose absolute
	  * value is the bignum negBad.  A negative bignum is negated by the
	  * runtime.
	  *)
	 fun bigAbs (arg: bigInt): bigInt =
		if not (Prim.isSmall arg)
		   then if bigIsNeg arg
			   then Prim.~ (arg, allocate (1 + bigSize arg))
			   else arg
		else let val w = Prim.toWord arg
		     in if w = badw then negBad
			else if Word.toIntX w >= 0 then arg
			else Prim.fromWord (Word.- (0w2, w))
		     end

	 (* bigInt min: lhs when lhs <= rhs, else rhs. *)
	 fun bigMin (lhs: bigInt, rhs: bigInt): bigInt =
		case bigLE (lhs, rhs) of
		   true => lhs
		 | false => rhs

	 (* bigInt max: rhs when lhs <= rhs, else lhs. *)
	 fun bigMax (lhs: bigInt, rhs: bigInt): bigInt =
		case bigLE (lhs, rhs) of
		   true => rhs
		 | false => lhs

	 (*
	  * bigInt sign: ~1, 0 or 1.
	  * A fixnum must have its tag stripped first.  Zero is stored as the
	  * word 0w1, whose value as an int is 1, so taking the sign of the
	  * tagged word wrongly reported sign 0 = 1.
	  * Bignums are presumably never zero (bigFromInt only boxes values
	  * outside the fixnum range), so their sign comes from the sign word.
	  *)
	 fun bigSign (arg: bigInt): smallInt =
		if Prim.isSmall arg
		   then Int.sign (Word.toIntX (stripTag arg))
		else if bigIsNeg arg
		   then ~1
		else 1

	 (*
	  * bigInt sameSign: true iff bigSign agrees on both arguments, where
	  * 0 has its own sign.
	  *)
	 fun bigSameSign (lhs: bigInt, rhs: bigInt): bool =
		bigSign lhs = bigSign rhs

	 (*
	  * bigInt toString and fmt.
	  * cvt renders a fixnum with smallCvt.  For a bignum, the runtime
	  * writes into a char buffer of dpc * limbs + 2, where dpc is the
	  * maximum number of digits a 32-bit limb needs in that base.
	  * NOTE(review): the + 2 presumably leaves room for a sign; confirm
	  * against Prim.toString.
	  *)
	 local
	    open StringCvt

	    fun cvt {base: smallInt,
		     dpc: smallInt,
		     smallCvt: smallInt -> string}
		    (arg: bigInt)
		    : string =
		   if not (Prim.isSmall arg)
		      then let val buf = Primitive.Array.array (2 + dpc * bigSize arg)
			   in Prim.toString (arg, base, buf)
			      ; Primitive.String.fromCharVector
				   (Primitive.Vector.fromArray buf)
			   end
		      else smallCvt (Word.toIntX (stripTag arg))

	    val hexCvt = cvt {base = 16, dpc = 8, smallCvt = Int.fmt HEX}
	    val octCvt = cvt {base = 8, dpc = 11, smallCvt = Int.fmt OCT}
	    val binCvt = cvt {base = 2, dpc = 32, smallCvt = Int.fmt BIN}
	 in
	    val bigToString = cvt {base = 10, dpc = 10, smallCvt = Int.toString}

	    fun bigFmt BIN = binCvt
	      | bigFmt OCT = octCvt
	      | bigFmt DEC = bigToString
	      | bigFmt HEX = hexCvt
	 end

	 (*
	  * bigInt scan and fromString.
	  *)
	 local
	    open StringCvt

	    (*
	     * Digits are accumulated in Word.word chunks.  smallToBig tags such
	     * a word as a fixnum bigInt, so it is only valid for values in
	     * [- 2^30, 2^30).
	     *)
	    fun smallToBig (w: Word.word): bigInt =
		   Prim.fromWord (Word.orb (0w1, Word.<< (w, 0w1)))
	 
	 
	    (*
	     * Digit classifiers: map a char to its value as a digit of the
	     * corresponding base, or NONE if it is not such a digit.
	     * Both a-f and A-F are accepted as hexadecimal digits.
	     *)
	    local
	       (* SOME (ord ch - ord lo + offset) when lo <= ch <= hi, else NONE. *)
	       fun inRange (lo: char, hi: char, offset: int) (ch: char)
			   : Word.word option =
		      if Char.<= (lo, ch) andalso Char.<= (ch, hi)
			 then SOME (Word.fromInt (ord ch - ord lo + offset))
			 else NONE
	    in
	       fun binDig (ch: char): Word.word option = inRange (#"0", #"1", 0) ch
	       fun octDig (ch: char): Word.word option = inRange (#"0", #"7", 0) ch
	       fun decDig (ch: char): Word.word option = inRange (#"0", #"9", 0) ch
	       fun hexDig (ch: char): Word.word option =
		      case decDig ch of
			 NONE =>
			    (case inRange (#"a", #"f", 10) ch of
				NONE => inRange (#"A", #"F", 10) ch
			      | lower => lower)
		       | digit => digit
	    end
	 
	    (*
	     * toDigR (charToDig, cread): a reader of digit values built from a
	     * char reader.  It yields NONE at end of input, or when the next
	     * char is rejected by charToDig.
	     *)
	    fun toDigR (charToDig: char -> Word.word option,
			cread: (char, 'a) reader)
		       (state: 'a)
		       : (Word.word * 'a) option =
		   Option.mapPartial
		      (fn (ch, rest) =>
			  Option.map (fn dig => (dig, rest)) (charToDig ch))
		      (cread state)
	 
	    (*
	     * A chunk is the result of reading a run of up to dpc digits:
	     *   more  - true if dpc digits were read (more may follow), false
	     *           if the digit reader ran out first;
	     *   shift - base raised to the number of digits read;
	     *   chunk - the value of those digits.
	     *)
	    type chunk = {
		    more: bool,
		    shift: Word.word,
		    chunk: Word.word
	    }

	    (*
	     * toChunkR (base, dpc, cread, dread): a reader of chunks of at
	     * most dpc digits, drawn from the digit reader dread.  It yields
	     * NONE unless at least one digit is present.  cread is accepted
	     * for uniformity with the other builders but is not used.
	     *)
	    fun toChunkR (base: Word.word,
			  dpc: smallInt,
			  cread: (char, 'a) reader,
			  dread: (Word.word, 'a) reader)
			 : (chunk, 'a) reader =
		   let
		      fun finish (more, shift, value, st) =
			 ({more = more, shift = shift, chunk = value}, st)
		      (* gather (n, shift, value, st): read up to n more digits. *)
		      fun gather (0, shift, value, st) = finish (true, shift, value, st)
			| gather (n, shift, value, st) =
			     case dread st of
				NONE => finish (false, shift, value, st)
			      | SOME (d, st') =>
				   gather (n - 1,
					   Word.* (base, shift),
					   Word.+ (Word.* (base, value), d),
					   st')
		   in
		      fn st =>
			 case dread st of
			    NONE => NONE
			  | SOME (d, st') => SOME (gather (dpc - 1, base, d, st'))
		   end
	 
	    (*
	     * toUnsR ckread: an unsigned bigInt reader.  The first chunk seeds
	     * the accumulator; each later chunk is folded in as
	     * acc * shift + chunk.  Reading stops once a chunk reports
	     * more = false, or when the chunk reader fails.
	     *)
	    fun toUnsR (ckread: (chunk, 'a) reader): (bigInt, 'a) reader =
		   let
		      fun absorb (acc: bigInt, {shift, chunk, ...}: chunk): bigInt =
			 bigPlus (bigMul (smallToBig shift, acc), smallToBig chunk)
		      fun drain (acc, st, false) = (acc, st)
			| drain (acc, st, true) =
			     (case ckread st of
				 NONE => (acc, st)
			       | SOME (ck as {more, ...}: chunk, st') =>
				    drain (absorb (acc, ck), st', more))
		   in
		      fn st =>
			 Option.map
			    (fn ({more, chunk, ...}: chunk, st') =>
				drain (smallToBig chunk, st', more))
			    (ckread st)
		   end
	 
	    (*
	     * toSign (cread, uread): a signed reader.  It skips leading white
	     * space, then accepts an optional sign: "+" reads a plain unsigned
	     * value, while "-" or "~" reads one and negates it.  If no sign is
	     * present, the unsigned reader starts at the current char.
	     *)
	    fun toSign (cread: (char, 'a) reader, uread: (bigInt, 'a) reader)
		       : (bigInt, 'a) reader =
		   let
		      fun negated st =
			 Option.map (fn (mag, st') => (bigNegate mag, st')) (uread st)
		      fun reader (st: 'a): (bigInt * 'a) option =
			 case cread st of
			    NONE => NONE
			  | SOME (ch, st') =>
			       if Char.isSpace ch then reader st'
			       else if ch = #"-" orelse ch = #"~" then negated st'
			       else if ch = #"+" then uread st'
			       else uread st
		   in
		      reader
		   end
	 
	    (*
	     * toX (cread, uread): wraps uread to allow an optional leading 0x
	     * or 0X.  If the prefix has no digits after it, only the "0" is
	     * consumed and 0 is returned.
	     *)
	    fun toX (cread: (char, 'a) reader, uread: (bigInt, 'a) reader)
		    (state: 'a)
		    : (bigInt * 'a) option =
		   let
		      fun isX c = c = #"x" orelse c = #"X"
		      (* st is positioned just after a leading "0". *)
		      fun afterZero st =
			 case cread st of
			    NONE => SOME (0, st)
			  | SOME (c, st') =>
			       if not (isX c) then uread state
			       else (case uread st' of
					NONE => SOME (0, st)
				      | found => found)
		   in
		      case cread state of
			 NONE => NONE
		       | SOME (c, st) => if c = #"0" then afterZero st else uread state
		   end
	 
	    (*
	     * Per-base bigInt readers built from a char reader.  Each
	     * pipeline goes: digit reader -> chunk reader -> unsigned reader
	     * -> sign handling.  Digits per chunk are chosen so that base^dpc
	     * stays below 2^30, as smallToBig requires.  Only the hexadecimal
	     * reader also handles a 0x/0X prefix.
	     *)
	    local
	       fun unsigned (dig, base, dpc) cread =
		      toUnsR (toChunkR (base, dpc, cread, toDigR (dig, cread)))
	    in
	       fun binReader (cread: (char, 'a) reader): (bigInt, 'a) reader =
		      toSign (cread, unsigned (binDig, 0w2, 29) cread)

	       fun octReader (cread: (char, 'a) reader): (bigInt, 'a) reader =
		      toSign (cread, unsigned (octDig, 0w8, 9) cread)

	       fun decReader (cread: (char, 'a) reader): (bigInt, 'a) reader =
		      toSign (cread, unsigned (decDig, 0w10, 9) cread)

	       fun hexReader (cread: (char, 'a) reader): (bigInt, 'a) reader =
		      toSign (cread, toX (cread, unsigned (hexDig, 0wx10, 7) cread))
	    end
	 
	 in
	 
	    local
	       (* A char reader over (position, string) pairs. *)
	       fun stringReader (pos, str) =
		      if pos <> String.size str
			 then SOME (String.sub (str, pos), (pos + 1, str))
			 else NONE
	       val reader = decReader stringReader
	    in
	       (* Parse a decimal bigInt from the start of str; whatever
		* follows the number is ignored. *)
	       fun bigFromString str =
		      Option.map (fn (n, _) => n) (reader (0, str))
	    end
	 
	    (*
	     * bigScan radix: turns a char reader into a bigInt reader for that
	     * radix.  Every reader skips leading white space and accepts an
	     * optional sign (+, - or ~); the HEX reader also accepts an
	     * optional 0x or 0X prefix after the sign.
	     * (The unused HenryIsLazy / unimplemented stub that followed here
	     * was removed: nothing references it and it is not exported.)
	     *)
	    fun bigScan radix =
		   case radix of
		      BIN => binReader
		    | OCT => octReader
		    | DEC => decReader
		    | HEX => hexReader

	 end

       local
	  (*
	   * Parity test.  Int.rem (n, 2) is 0 exactly for even n of either
	   * sign and cannot overflow.  Int.abs is deliberately avoided because
	   * it raises Overflow on the most negative int, which would make
	   * pow (~1, valOf Int.minInt) raise instead of returning 1.
	   *)
	  fun isEven (n: int) = Int.rem (n, 2) = 0
       in
	  (*
	   * pow (i, j) = i^j.  pow (i, 0) = 1, including pow (0, 0).
	   * For j < 0: raises Div if i = 0; returns 1 if i = 1; returns 1 or ~1
	   * if i = ~1, depending on the parity of j; returns 0 otherwise.
	   * For j > 0: binary exponentiation (O(log j) multiplications).
	   *)
	  fun pow (i: bigInt, j: int): bigInt =
	     if j < 0
		then if i = 0 then raise Div
		     else if i = 1 then 1
		     else if i = ~1 then (if isEven j then 1 else ~1)
		     else 0
	     else if j = 0
		then 1
	     else
		let
		   fun square (n: bigInt): bigInt = bigMul (n, n)
		   (* pow j returns i^j, for j >= 0. *)
		   fun pow (j: int): bigInt =
		      if j = 0 then 1
		      else if isEven j then evenPow j
		      else bigMul (i, evenPow (j - 1))
		   (* evenPow j returns i^j, assuming j is even. *)
		   and evenPow (j: int): bigInt =
		      square (pow (Int.quot (j, 2)))
		in
		   pow j
		end

	  (* From here on, these operators and names denote the bigInt
	   * versions; div and mod below are written in terms of them. *)
	  val op + = bigPlus
	  val op - = bigMinus
	  val compare = bigCompare
	  val op > = bigGT
	  val op >= = bigGE
	  val op < = bigLT
	  val op <= = bigLE
	  val quot = bigQuot
	  val rem = bigRem

	  (*
	   * x div y: quotient rounded toward negative infinity; raises Div
	   * if y = 0.  When x and y do not have opposite signs, this is quot.
	   * Otherwise x is moved one step toward 0 before truncating, then 1
	   * is subtracted, so exact multiples come out right:
	   * ~6 div 2 = quot (~5, 2) - 1 = ~3, and ~7 div 2 = quot (~6, 2) - 1 = ~4.
	   *)
	  fun x div y =
	     if x >= 0
		then if y > 0
			then quot(x, y)
		     else if y < 0
			     then if x = 0
				     then 0
				  else quot(x - 1, y) - 1
			  else raise Div
	     else if y < 0
		     then quot(x, y)
		  else if y > 0
			  then quot(x + 1, y) - 1
		       else raise Div

	  (*
	   * x mod y: remainder taking the sign of y, consistent with div; raises
	   * Div if y = 0.  In the mixed-sign case, rem is applied to the
	   * operand moved one step toward 0, and the result is shifted back:
	   * ~7 mod 2 = rem (~6, 2) - 1 + 2 = 1, and 7 mod ~2 = rem (6, ~2) + 1 + ~2 = ~1.
	   *)
	  fun x mod y =
	     if x >= 0
		then if y > 0
			then rem(x, y)
		     else if y < 0
			     then if x = 0
				     then 0
				  else rem(x - 1, y) + 1 + y
			  else raise Div
	     else if y < 0
		     then rem(x, y)
		  else if y > 0
			  then rem(x + 1, y) - 1 + y
		       else raise Div
       end

      in
	 type int = bigInt

	 (* Conversions.  IntInf is also LargeInt, so the LargeInt
	  * conversions are the identity. *)
	 val toLarge = fn x => x
	 val fromLarge = fn x => x
	 val toInt = bigToInt
	 val fromInt = bigFromInt

	 (* Arbitrary precision: no fixed width and no bounds. *)
	 val precision = NONE
	 val minInt = NONE
	 val maxInt = NONE

	 (* Arithmetic. *)
	 val ~ = bigNegate
	 val op + = bigPlus
	 val op - = bigMinus
	 val op * = bigMul
	 val op div = op div
	 val op mod = op mod
	 val quot = bigQuot
	 val rem = bigRem
	 val abs = bigAbs
	 val pow = pow

	 (* Ordering and sign. *)
	 val compare = bigCompare
	 val op < = bigLT
	 val op <= = bigLE
	 val op > = bigGT
	 val op >= = bigGE
	 val min = bigMin
	 val max = bigMax
	 val sign = bigSign
	 val sameSign = bigSameSign

	 (* Text conversion. *)
	 val toString = bigToString
	 val fmt = bigFmt
	 val fromString = bigFromString
	 val scan = bigScan
      end
   end

(* LargeInt, the largest integer type, is the arbitrary-precision IntInf. *)
structure LargeInt = IntInf