slow matrix multiply

Stephen Weeks MLton@sourcelight.com
Tue, 10 Jul 2001 18:02:49 -0700


> One thing I was surprised by was that not only loop nests was slow: so was
> matrix multiply.  I didn't look at it, but it was scary that it was slow
> as well as the nested loop.

I looked into matrix.{gcc, mlton, ocaml}.  Here are the running times I see.

gcc	1.29
ocaml	1.40
mlton	4.84

Here are the source and the annotated assembly of the hot inner loop for each
of the three compilers.  OCaml and gcc do better for several reasons:

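* they keep the loop index and the running sum in registers, whereas MLton
  reloads k and writes k - 1 and sum back to memory (offsets from %edi) on
  every iteration
* MLton completely recomputes the array offset (row * 30 + column, with an
  imull) for each subscript
* MLton executes extra instructions: a cltd before every imull (dead, since
  the one-operand imull overwrites %edx anyway), plus register shuffling with
  movl/xchgl (partly because %esp, which MLton uses as an ordinary register,
  cannot serve as an index register in an address)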

--------------------------------------------------------------------------------
MLton
--------------------------------------------------------------------------------

fun loop (k, sum) =
   if k < 0
     then sum
   else loop (k - 1, sum + sub (m1, i, k) * sub (m2, k, j))

# MLton's code for `loop`.  All loop state lives in memory at fixed offsets
# from %edi and is reloaded/stored every iteration (offsets per the comments
# below): m1 at 144, m2 at 160, i at 188, j at 192, sum at 196, k at 200.
# %esp is used here as an ordinary general-purpose register.
# Each sub (m, r, c) is computed as m[r * 30 + c] with 4-byte elements, and
# the row offset is recomputed with a fresh imull for every subscript.
loop_51:
	movl (200*1)(%edi),%esp		# %esp = k (reloaded from memory every iteration)
	cmpl $0,%esp			# if k < 0
	jl L_229			#   leave the loop (L_229 not shown; presumably returns sum)
	movl %esp,%ebp			# %ebp = k
	decl %ebp			# %ebp = k - 1
	movl (188*1)(%edi),%edx		# %edx = i
	movl %edx,%ecx			# %ecx = i   (shuffle; %ecx is overwritten two lines below)
	movl %ecx,%eax			# %eax = i   (imull's implicit operand)
	movl $30,%ecx			# %ecx = 30 (row length)
	cltd				# sign-extend %eax into %edx; dead, one-operand imull overwrites %edx
	imull %ecx			# %eax = i * 30   (%edx:%eax = %eax * %ecx)
	addl %esp,%eax			# %eax = i * 30 + k
	movl %ebp,(200*1)(%edi)		# store k - 1 back to memory
	xchgl %esp,%eax			# %eax = k  %esp = i * 30 + k
	movl $30,%ebp			# %ebp = 30 (row length)
	cltd				# dead sign-extension again
	imull %ebp			# %eax = k * 30
	addl (192*1)(%edi),%eax		# %eax = k * 30 + j
	movl (144*1)(%edi),%ebp		# %ebp = m1
	movl %esp,%edx			# %edx = i * 30 + k  (%esp cannot be an index register)
	movl (%ebp,%edx,4),%esp		# %esp = sub (m1, i, k)
	xchgl %esp,%eax			# %eax = sub(m1, i, k)  %esp = k * 30 + j
	movl (160*1)(%edi),%ebp		# %ebp = m2
	movl %esp,%ecx			# %ecx = k * 30 + j  (%esp cannot be an index register)
	cltd				# dead sign-extension again
	imull (%ebp,%ecx,4)		# %eax = sub (m1, i, k) * sub (m2, k, j)  (low 32 bits)
	addl %eax,(196*1)(%edi)		# sum = sum + ...  (sum stays in memory)
	jmp loop_51

--------------------------------------------------------------------------------
ocaml
--------------------------------------------------------------------------------

let rec inner_loop k v m1i m2 j =
  if k < 0 then v
  else inner_loop (k - 1) (v + m1i.(k) * m2.(k).(j)) m1i m2 j

# OCaml's code for `inner_loop`.  Integers appear to be in OCaml's tagged
# form (n stored as 2n+1): note `cmpl $1` for `k < 0`, `addl $-2` for
# `k - 1`, and the -2 displacement with scale 2, which turns a tagged index
# 2k+1 into the byte offset 4k.
.L107:				# %eax = k  %ebx = v  %ecx = m1i  %edx = m2  %esi = j  (all tagged)
	cmpl	$1, %eax			# tagged k >= 1  <=>  k >= 0
	jge	.L106
	movl	%ebx, %eax			# k < 0: return v (already tagged)
	ret
	.align	16
.L106:
	movl	-2(%edx, %eax, 2), %edi		# %edi = m2.(k)
	movl	-2(%edi, %esi, 2), %ebp		# %ebp = m2.(k).(j)  (tagged)
	sarl	$1, %ebp			# untag: %ebp = m2.(k).(j) as a plain int
	movl	-2(%ecx, %eax, 2), %edi		# %edi = m1i.(k)  (tagged)
	decl	%edi				# %edi = 2 * m1i.(k)  (drop the tag bit)
	imull	%ebp, %edi			# %edi = 2 * (m1i.(k) * m2.(k).(j))
	addl	%edi, %ebx			# v + m1i.(k) * m2.(k).(j), result still tagged
	addl	$-2, %eax			# %eax = k - 1  (tagged)
	jmp	.L107

--------------------------------------------------------------------------------
gcc
--------------------------------------------------------------------------------

for (k=0; k<cols; k++)
  val += m1[i][k] * m2[k][j];

# gcc's code for the inner `for` loop; the bound `cols` has been folded to 30.
# NOTE(review): the earlier annotations said %ebx = i and 16(%esp) = m1.
# %edx is the loop counter k (incl / cmpl $30), so the first load pair
# computes X[k][%ebx].  In `m1[i][k] * m2[k][j]` only m2 is indexed by k
# first, so this presumably means 16(%esp) = m2, %ebx = j, and
# 12(%esp) = m1[i] (row pointer hoisted out of the loop).  Confirm against
# the full gcc output.
.L125:					# %ebx = j (see NOTE)  %ecx = sum  %edx = k
	movl	16(%esp), %edi		# %edi = m2 (see NOTE)
	movl	(%edi,%edx,4), %eax	# %eax = m2[k]
	movl	(%eax,%ebx,4), %eax	# %eax = m2[k][j]
	movl	12(%esp), %edi		# %edi = m1[i] (hoisted row pointer, see NOTE)
	imull	(%edi,%edx,4), %eax	# %eax = m2[k][j] * m1[i][k]
	incl	%edx			# %edx = k + 1
	addl	%eax, %ecx		# %ecx = sum + m1[i][k] * m2[k][j]
	cmpl	$30, %edx		# if k < 30 (cols folded to the constant 30)
	jl	.L125