Strassen's algorithm

You are encouraged to solve this task according to the task description, using any language you may know.
- Description
In linear algebra, the Strassen algorithm (named after Volker Strassen) is an algorithm for matrix multiplication.
It is faster than the standard matrix multiplication algorithm and is useful in practice for large matrices, but would be slower than the fastest known algorithms for extremely large matrices.
- Task
Write a routine, function, procedure etc. in your language to implement the Strassen algorithm for matrix multiplication.
While practical implementations of Strassen's algorithm usually switch to standard methods of matrix multiplication for small enough sub-matrices (currently anything less than 512×512 according to Wikipedia), for the purposes of this task you should not switch until reaching a size of 1 or 2.
- Related task
- See also
' Matrix record: its dimensions plus a dynamic array holding the elements.
' The routines in this file address element (i, j) as dato(i * cols + j).
Type Matrix
As Integer rows ' number of rows
As Integer cols ' number of columns
As Double dato(Any) ' dynamic element storage, sized by makeMatrix via Redim
End Type
' Build a rows x cols Matrix and size its element array.
' The backing array is dimensioned with indices 0 .. rows * cols - 1.
Function makeMatrix(rows As Integer, cols As Integer) As Matrix
    Dim As Matrix fresh
    Dim As Integer lastIndex = rows * cols - 1
    Redim fresh.dato(lastIndex)
    fresh.cols = cols
    fresh.rows = rows
    Return fresh
End Function
' Element-wise sum of two matrices with identical dimensions.
' Prints a message and terminates the program when the dimensions differ.
Function matrixAdd(m1 As Matrix, m2 As Matrix) As Matrix
    If m1.rows <> m2.rows Or m1.cols <> m2.cols Then
        Print "Matrices must have the same dimensions."
        End 1 ' stop here: continuing would index past the smaller array
    End If

    Dim As Matrix result = makeMatrix(m1.rows, m1.cols)
    ' Both operands share one flat layout, so a single index suffices.
    For i As Integer = 0 To m1.rows * m1.cols - 1
        result.dato(i) = m1.dato(i) + m2.dato(i)
    Next
    Return result
End Function
' Element-wise difference m1 - m2 of two matrices with identical dimensions.
' Prints a message and terminates the program when the dimensions differ.
Function matrixSub(m1 As Matrix, m2 As Matrix) As Matrix
    If m1.rows <> m2.rows Or m1.cols <> m2.cols Then
        Print "Matrices must have the same dimensions."
        End 1 ' stop here: continuing would index past the smaller array
    End If

    Dim As Matrix result = makeMatrix(m1.rows, m1.cols)
    ' Both operands share one flat layout, so a single index suffices.
    For i As Integer = 0 To m1.rows * m1.cols - 1
        result.dato(i) = m1.dato(i) - m2.dato(i)
    Next
    Return result
End Function
' Standard (triple-loop) matrix product m1 * m2.
' Requires m1.cols = m2.rows; otherwise prints a message and terminates.
' The result has m1.rows rows and m2.cols columns.
Function matrixMul(m1 As Matrix, m2 As Matrix) As Matrix
    If m1.cols <> m2.rows Then
        Print "Cannot multiply these matrices."
        End 1 ' stop here: the inner loop would read outside m2's storage
    End If

    Dim As Integer i, j, k
    Dim As Matrix result = makeMatrix(m1.rows, m2.cols)
    For i = 0 To m1.rows - 1
        For j = 0 To m2.cols - 1
            ' Inner product of row i of m1 with column j of m2.
            Dim As Double sum = 0
            For k = 0 To m1.cols - 1
                sum += m1.dato(i * m1.cols + k) * m2.dato(k * m2.cols + j)
            Next
            result.dato(i * result.cols + j) = sum
        Next
    Next
    Return result
End Function
' Print a matrix on one line in the form [[a b] [c d]].
' Each value is rounded to `precision` decimal places before printing, and
' values whose magnitude is below 1e-10 are printed as 0.
Sub printMatrix(m As Matrix, precision As Integer = 6)
    Dim As Integer i, j
    Print "[[";
    For i = 0 To m.rows - 1
        If i > 0 Then Print " [";
        For j = 0 To m.cols - 1
            Dim As Double valor = m.dato(i * m.cols + j)
            ' Round to the requested number of decimals.
            valor = Int(valor * (10 ^ precision) + 0.5) / (10 ^ precision)
            If Abs(valor) < 1e-10 Then valor = 0
            Print Using "&"; valor;
            If j < m.cols - 1 Then Print " ";
        Next
        ' The last row is closed by the final "]]" below.
        If i < m.rows - 1 Then Print "]";
    Next
    Print "]]"
End Sub
' Copy one quadrant of m into a new matrix of half the rows and columns.
' quarter: 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
' Dimensions are halved with integer division.
Function getQuarter(m As Matrix, quarter As Integer) As Matrix
    Dim As Integer i, j
    Dim As Integer halfRows, halfCols
    halfRows = m.rows \ 2
    halfCols = m.cols \ 2

    Dim As Matrix result = makeMatrix(halfRows, halfCols)
    ' quarter \ 2 selects the row band, quarter Mod 2 the column band.
    Dim As Integer rowOffset, colOffset
    rowOffset = (quarter \ 2) * halfRows
    colOffset = (quarter Mod 2) * halfCols

    For i = 0 To halfRows - 1
        For j = 0 To halfCols - 1
            result.dato(i * halfCols + j) = m.dato((i + rowOffset) * m.cols + j + colOffset)
        Next
    Next
    Return result
End Function
' Assemble a 2n x 2n matrix from four n x n quadrants, where n = q1.rows:
'   q1 = top-left, q2 = top-right, q3 = bottom-left, q4 = bottom-right.
' All four quadrants are indexed as n x n.
Function combineQuarters(q1 As Matrix, q2 As Matrix, q3 As Matrix, q4 As Matrix) As Matrix
    Dim As Integer i, j, n
    n = q1.rows
    Dim As Matrix result = makeMatrix(n * 2, n * 2)
    For i = 0 To n - 1
        For j = 0 To n - 1
            result.dato(i * result.cols + j) = q1.dato(i * n + j)
            result.dato(i * result.cols + j + n) = q2.dato(i * n + j)
            result.dato((i + n) * result.cols + j) = q3.dato(i * n + j)
            result.dato((i + n) * result.cols + j + n) = q4.dato(i * n + j)
        Next
    Next
    Return result
End Function
' Multiply two square matrices of equal size with Strassen's algorithm.
' The size must be a power of two; recursion stops at 1 x 1 blocks.
' Prints a message and terminates the program on invalid input.
Function strassen(a As Matrix, b As Matrix) As Matrix
    If a.rows <> a.cols Or b.rows <> b.cols Or a.rows <> b.rows Then
        Print "Matrices must be square and of equal size."
        End 1
    End If
    ' Halving an odd size would silently drop a row and a column, and a size
    ' of 0 would recurse forever, so reject both up front.
    If a.rows < 1 OrElse (a.rows And (a.rows - 1)) <> 0 Then
        Print "Size of matrices must be a power of two."
        End 1
    End If
    If a.rows = 1 Then Return matrixMul(a, b)

    ' Split both operands into quadrants.
    Dim As Matrix a11 = getQuarter(a, 0)
    Dim As Matrix a12 = getQuarter(a, 1)
    Dim As Matrix a21 = getQuarter(a, 2)
    Dim As Matrix a22 = getQuarter(a, 3)
    Dim As Matrix b11 = getQuarter(b, 0)
    Dim As Matrix b12 = getQuarter(b, 1)
    Dim As Matrix b21 = getQuarter(b, 2)
    Dim As Matrix b22 = getQuarter(b, 3)

    ' The seven recursive products.
    Dim As Matrix p1 = strassen(matrixSub(a12, a22), matrixAdd(b21, b22))
    Dim As Matrix p2 = strassen(matrixAdd(a11, a22), matrixAdd(b11, b22))
    Dim As Matrix p3 = strassen(matrixSub(a11, a21), matrixAdd(b11, b12))
    Dim As Matrix p4 = strassen(matrixAdd(a11, a12), b22)
    Dim As Matrix p5 = strassen(a11, matrixSub(b12, b22))
    Dim As Matrix p6 = strassen(a22, matrixSub(b21, b11))
    Dim As Matrix p7 = strassen(matrixAdd(a21, a22), b11)

    ' Recombine the products into the quadrants of the result.
    Dim As Matrix c11 = matrixAdd(matrixSub(matrixAdd(p1, p2), p4), p6)
    Dim As Matrix c12 = matrixAdd(p4, p5)
    Dim As Matrix c21 = matrixAdd(p6, p7)
    Dim As Matrix c22 = matrixSub(matrixAdd(matrixSub(p2, p3), p5), p7)

    Return combineQuarters(c11, c12, c21, c22)
End Function
' Demonstration: multiply three pairs of matrices with both the standard
' routine and strassen, printing each product.
Sub main()
    ' Matrix A (2x2)
    Dim As Matrix a = makeMatrix(2, 2)
    a.dato(0) = 1: a.dato(1) = 2
    a.dato(2) = 3: a.dato(3) = 4

    ' Matrix B (2x2)
    Dim As Matrix b = makeMatrix(2, 2)
    b.dato(0) = 5: b.dato(1) = 6
    b.dato(2) = 7: b.dato(3) = 8

    ' Matrix C (4x4)
    Dim As Matrix c = makeMatrix(4, 4)
    c.dato(0) = 1: c.dato(1) = 1: c.dato(2) = 1: c.dato(3) = 1
    c.dato(4) = 2: c.dato(5) = 4: c.dato(6) = 8: c.dato(7) = 16
    c.dato(8) = 3: c.dato(9) = 9: c.dato(10) = 27: c.dato(11) = 81
    c.dato(12) = 4: c.dato(13) = 16: c.dato(14) = 64: c.dato(15) = 256

    ' Matrix D (4x4)
    Dim As Matrix d = makeMatrix(4, 4)
    d.dato(0) = 4: d.dato(1) = -3: d.dato(2) = 4/3: d.dato(3) = -1/4
    d.dato(4) = -13/3: d.dato(5) = 19/4: d.dato(6) = -7/3: d.dato(7) = 11/24
    d.dato(8) = 3/2: d.dato(9) = -2: d.dato(10) = 7/6: d.dato(11) = -1/4
    d.dato(12) = -1/6: d.dato(13) = 1/4: d.dato(14) = -1/6: d.dato(15) = 1/24

    ' Matrix E (4x4)
    Dim As Matrix e = makeMatrix(4, 4)
    e.dato(0) = 1: e.dato(1) = 2: e.dato(2) = 3: e.dato(3) = 4
    e.dato(4) = 5: e.dato(5) = 6: e.dato(6) = 7: e.dato(7) = 8
    e.dato(8) = 9: e.dato(9) = 10: e.dato(10) = 11: e.dato(11) = 12
    e.dato(12) = 13: e.dato(13) = 14: e.dato(14) = 15: e.dato(15) = 16

    ' Matrix F (Identity 4x4)
    Dim As Matrix f = makeMatrix(4, 4)
    f.dato(0) = 1: f.dato(1) = 0: f.dato(2) = 0: f.dato(3) = 0
    f.dato(4) = 0: f.dato(5) = 1: f.dato(6) = 0: f.dato(7) = 0
    f.dato(8) = 0: f.dato(9) = 0: f.dato(10) = 1: f.dato(11) = 0
    f.dato(12) = 0: f.dato(13) = 0: f.dato(14) = 0: f.dato(15) = 1

    Print "Using 'normal' matrix multiplication:"
    Print " a * b = ";
    printMatrix(matrixMul(a, b))
    Print " c * d = ";
    printMatrix(matrixMul(c, d), 6)
    Print " e * f = ";
    printMatrix(matrixMul(e, f))

    Print !"\nUsing 'Strassen' matrix multiplication:"
    Print " a * b = ";
    printMatrix(strassen(a, b))
    Print " c * d = ";
    printMatrix(strassen(c, d), 6)
    Print " e * f = ";
    printMatrix(strassen(e, f))
End Sub

' Run the demonstration; without this call the program produces no output.
main()
- Output:
Using 'normal' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]] Using 'Strassen' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]
Rather than use a library such as gonum, we create a simple Matrix type which is adequate for this task.
package main
import (
	"fmt"
	"log"
	"math"
)
// Matrix is a dense matrix stored as a slice of rows.
type Matrix [][]float64

// rows reports the number of rows.
func (m Matrix) rows() int { return len(m) }

// cols reports the number of columns, taken from the length of the first row.
// An empty matrix reports zero columns instead of panicking on m[0].
func (m Matrix) cols() int {
	if len(m) == 0 {
		return 0
	}
	return len(m[0])
}
func (m Matrix) add(m2 Matrix) Matrix {
if m.rows() != m2.rows() || m.cols() != m2.cols() {
log.Fatal("Matrices must have the same dimensions.")
c := make(Matrix, m.rows())
for i := 0; i < m.rows(); i++ {
c[i] = make([]float64, m.cols())
for j := 0; j < m.cols(); j++ {
c[i][j] = m[i][j] + m2[i][j]
return c
func (m Matrix) sub(m2 Matrix) Matrix {
if m.rows() != m2.rows() || m.cols() != m2.cols() {
log.Fatal("Matrices must have the same dimensions.")
c := make(Matrix, m.rows())
for i := 0; i < m.rows(); i++ {
c[i] = make([]float64, m.cols())
for j := 0; j < m.cols(); j++ {
c[i][j] = m[i][j] - m2[i][j]
return c
func (m Matrix) mul(m2 Matrix) Matrix {
if m.cols() != m2.rows() {
log.Fatal("Cannot multiply these matrices.")
c := make(Matrix, m.rows())
for i := 0; i < m.rows(); i++ {
c[i] = make([]float64, m2.cols())
for j := 0; j < m2.cols(); j++ {
for k := 0; k < m2.rows(); k++ {
c[i][j] += m[i][k] * m2[k][j]
return c
func (m Matrix) toString(p int) string {
s := make([]string, m.rows())
pow := math.Pow(10, float64(p))
for i := 0; i < m.rows(); i++ {
t := make([]string, m.cols())
for j := 0; j < m.cols(); j++ {
r := math.Round(m[i][j]*pow) / pow
t[j] = fmt.Sprintf("%g", r)
if t[j] == "-0" {
t[j] = "0"
s[i] = fmt.Sprintf("%v", t)
return fmt.Sprintf("%v", s)
// params returns, for each of the four quadrants of a 2r x 2c matrix, the
// tuple {rowStart, rowEnd, colStart, colEnd, rowOffset, colOffset}.
// Quadrant order: top-left, top-right, bottom-left, bottom-right.
// End bounds are exclusive; the offsets map full-matrix indices to
// quadrant-local indices.
func params(r, c int) [4][6]int {
	return [4][6]int{
		{0, r, 0, c, 0, 0},
		{0, r, c, 2 * c, 0, c},
		{r, 2 * r, 0, c, r, 0},
		{r, 2 * r, c, 2 * c, r, c},
	}
}
func toQuarters(m Matrix) [4]Matrix {
r := m.rows() / 2
c := m.cols() / 2
p := params(r, c)
var quarters [4]Matrix
for k := 0; k < 4; k++ {
q := make(Matrix, r)
for i := p[k][0]; i < p[k][1]; i++ {
q[i-p[k][4]] = make([]float64, c)
for j := p[k][2]; j < p[k][3]; j++ {
q[i-p[k][4]][j-p[k][5]] = m[i][j]
quarters[k] = q
return quarters
func fromQuarters(q [4]Matrix) Matrix {
r := q[0].rows()
c := q[0].cols()
p := params(r, c)
r *= 2
c *= 2
m := make(Matrix, r)
for i := 0; i < c; i++ {
m[i] = make([]float64, c)
for k := 0; k < 4; k++ {
for i := p[k][0]; i < p[k][1]; i++ {
for j := p[k][2]; j < p[k][3]; j++ {
m[i][j] = q[k][i-p[k][4]][j-p[k][5]]
return m
func strassen(a, b Matrix) Matrix {
if a.rows() != a.cols() || b.rows() != b.cols() || a.rows() != b.rows() {
log.Fatal("Matrices must be square and of equal size.")
if a.rows() == 0 || (a.rows()&(a.rows()-1)) != 0 {
log.Fatal("Size of matrices must be a power of two.")
if a.rows() == 1 {
return a.mul(b)
qa := toQuarters(a)
qb := toQuarters(b)
p1 := strassen(qa[1].sub(qa[3]), qb[2].add(qb[3]))
p2 := strassen(qa[0].add(qa[3]), qb[0].add(qb[3]))
p3 := strassen(qa[0].sub(qa[2]), qb[0].add(qb[1]))
p4 := strassen(qa[0].add(qa[1]), qb[3])
p5 := strassen(qa[0], qb[1].sub(qb[3]))
p6 := strassen(qa[3], qb[2].sub(qb[0]))
p7 := strassen(qa[2].add(qa[3]), qb[0])
var q [4]Matrix
q[0] = p1.add(p2).sub(p4).add(p6)
q[1] = p4.add(p5)
q[2] = p6.add(p7)
q[3] = p2.sub(p3).add(p5).sub(p7)
return fromQuarters(q)
func main() {
a := Matrix{{1, 2}, {3, 4}}
b := Matrix{{5, 6}, {7, 8}}
c := Matrix{{1, 1, 1, 1}, {2, 4, 8, 16}, {3, 9, 27, 81}, {4, 16, 64, 256}}
d := Matrix{{4, -3, 4.0 / 3, -1.0 / 4}, {-13.0 / 3, 19.0 / 4, -7.0 / 3, 11.0 / 24},
{3.0 / 2, -2, 7.0 / 6, -1.0 / 4}, {-1.0 / 6, 1.0 / 4, -1.0 / 6, 1.0 / 24}}
e := Matrix{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}
f := Matrix{{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}}
fmt.Println("Using 'normal' matrix multiplication:")
fmt.Printf(" a * b = %v\n", a.mul(b))
fmt.Printf(" c * d = %v\n", c.mul(d).toString(6))
fmt.Printf(" e * f = %v\n", e.mul(f))
fmt.Println("\nUsing 'Strassen' matrix multiplication:")
fmt.Printf(" a * b = %v\n", strassen(a, b))
fmt.Printf(" c * d = %v\n", strassen(c, d).toString(6))
fmt.Printf(" e * f = %v\n", strassen(e, f))
- Output:
Using 'normal' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]] Using 'Strassen' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]
Works with gojq, the Go implementation of jq
### Generic utilities

# Sum of the items emitted by `stream`; an empty stream sums to 0.
def sigma(stream):
  reduce stream as $term (0; . + $term);
# Hilbert-Schmidt (Frobenius) norm: the square root of the sum of the
# squares of all entries of the flattened input.
def frobenius:
  flatten
  | sigma(.[] | . * .)
  | sqrt;
### Matrices
# Create an m x n matrix in which every cell holds the value of `init`.
# m == 0 yields the empty array; a negative m raises an error.
def matrix(m; n; init):
  if m == 0 then []
  elif m == 1 then [[range(0;n) | init]]
  elif m > 0 then
    matrix(1;n;init) as $row
    | [range(0;m) | $row ]
  else error("matrix(\(m);_;_) invalid")
  end;
# A and B must be (multi-dimensional) vectors of the same shape.
# Arrays are added element by element, recursively; anything that is not an
# array is added with `+`.
def vector_add($A;$B):
  if ($A|type) == "array"
  then reduce range(0; $A|length) as $i ([];
         . + [vector_add($A[$i]; $B[$i])] )
  else $A + $B
  end;
# Negate every scalar of a (possibly nested) array; a non-array input is
# negated directly.
def vector_negate:
  if type == "array"
  then map(vector_negate)
  else - .
  end;
# $A - $B, computed as $A plus the negation of $B.
def vector_subtract($A;$B):
  ($B | vector_negate) as $minusB
  | vector_add($A; $minusB);
# A should be m by n; and B n by p
# Pre-allocating the resultant matrix results in a very small net speedup.
def multiply($A; $B):
  ($A|length) as $m
  | ($A[0]|length) as $n
  | ($B[0]|length) as $p
  | reduce range(0; $m) as $i
      (matrix($m; $p; 0);   # initialize to avoid resizing
       reduce range(0;$p) as $j (.;
         .[$i][$j] = reduce range(0;$n) as $k
           (0;
            # Cij = innerproduct of row i, column j
            . + $A[$i][$k] * $B[$k][$j] ))) ;
# Rows $m1 up to (not including) $m2 and columns $n1 up to (not including)
# $n2 of the input matrix.
def submatrix($m1; $m2; $n1; $n2):
  [ .[$m1:$m2][] | .[$n1:$n2] ];

# The same selection, with the matrix passed as the first argument.
def submatrix($A; $m1; $m2; $n1; $n2):
  $A
  | submatrix($m1; $m2; $n1; $n2);
# Place $B to the right of $A: row i of the result is row i of $A followed
# by row i of $B. The number of rows is taken from $A.
def rowwise_extend($A;$B):
  [ range(0; $A|length) as $i | $A[$i] + $B[$i] ];
### Strassen multiplication of n*n square matrices where n is a power of 2.
# Recursion stops at 1 x 1, where the ordinary product is used.
def Strassen($A; $B):
  ($A[0]|length) as $n
  | if $n == 1 then multiply($A; $B)
    else
      # Split both operands into quadrants.
      submatrix($A; 0; $n/2; 0; $n/2) as $A11
      | submatrix($A; 0; $n/2; $n/2; $n) as $A12
      | submatrix($A; $n/2; $n; 0; $n/2) as $A21
      | submatrix($A; $n/2; $n; $n/2; $n) as $A22
      | submatrix($B; 0; $n/2; 0; $n/2) as $B11
      | submatrix($B; 0; $n/2; $n/2; $n) as $B12
      | submatrix($B; $n/2; $n; 0; $n/2) as $B21
      | submatrix($B; $n/2; $n; $n/2; $n) as $B22
      # The seven recursive products.
      | Strassen( vector_subtract($A12; $A22); vector_add($B21; $B22)) as $P1
      | Strassen( vector_add($A11; $A22); vector_add($B11; $B22)) as $P2
      | Strassen( vector_subtract($A11; $A21); vector_add($B11; $B12)) as $P3
      | Strassen( vector_add($A11; $A12); $B22) as $P4
      | Strassen( $A11; vector_subtract($B12; $B22)) as $P5
      | Strassen( $A22; vector_subtract($B21; $B11) ) as $P6
      | Strassen( vector_add($A21; $A22); $B11) as $P7
      # Recombine into the quadrants of the result.
      | vector_add(vector_subtract(vector_add($P1; $P2); $P4); $P6) as $C11
      | vector_add($P4; $P5) as $C12
      | vector_add($P6; $P7) as $C21
      | vector_add(vector_subtract($P2; $P3); vector_subtract($P5;$P7)) as $C22
      # [C11 C12; C21 C22]
      | rowwise_extend($C11; $C12) + rowwise_extend($C21; $C22)
    end ;
# ## Examples
# Sample operands: A, B are 2x2; C, D, E, F are 4x4 (F is the identity).
def A: [[1, 2], [3, 4]];
def B: [[5, 6], [7, 8]];
def C: [[1, 1, 1, 1],
[2, 4, 8, 16],
[3, 9, 27, 81],
[4, 16, 64, 256]];
def D: [[4, -3, 4/3, -1/4],
[-13/3, 19/4, -7/3, 11/24],
[3/2, -2, 7/6, -1/4],
[-1/6, 1/4, -1/6, 1/24]];
def E: [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]];
def F: [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [ 0, 0, 0, 1]];
# r is sqrt(2)/2; R is the 2x2 matrix built from it.
def r: (2|sqrt)/2;
def R: [[r, r], [-r, r]];
# Each output string compares multiply with Strassen on the same operands;
# for C*D the Frobenius norm of the difference is reported instead of equality.
"A*B == Strassen(A;B): \((multiply(A; B) == Strassen(A; B)))",
"Frobenius norm for C*D - Strassen(C;D): \(vector_subtract(multiply(C; D); Strassen(C; D)) | frobenius)",
"E*F == Strassen(E;F): \(multiply(E; F) == Strassen( E; F))",
"R*R == Strassen(R;R): \(multiply(R ; R) == Strassen(R; R))"
- Output:
A*B == Strassen(A;B): true Frobenius norm for C*D - Strassen(C;D): 3.4360172596923834e-13 E*F == Strassen(E;F): true R*R == Strassen(R;R): true
With dynamic padding
Because Julia stores matrices in column-major order, the code sometimes uses the adjoint of a matrix in order to match the examples as written.
"""
Strassen's matrix multiplication algorithm.
Use dynamic padding in order to reduce required auxiliary memory.
"""
function strassen(x::Matrix, y::Matrix)
    # Check that the sizes of these matrices match.
    (r1, c1) = size(x)
    (r2, c2) = size(y)
    if c1 != r2
        error("Multiplying $r1 x $c1 and $r2 x $c2 matrix: dimensions do not match.")
    end

    # Put a matrix into the top left of a matrix of zeros.
    # `rows` and `cols` are the dimensions of the output matrix.
    function embed(mat, rows, cols)
        # If the matrix is already of the right dimensions, don't allocate new memory.
        (r, c) = size(mat)
        if (r, c) == (rows, cols)
            return mat
        end

        # Pad the matrix with zeros to be the right size. The padding uses the
        # element type of `mat`, so non-integer values can be copied in without
        # raising an InexactError.
        out = zeros(eltype(mat), rows, cols)
        out[1:r, 1:c] = mat
        out
    end

    # Make sure both matrices are the same size.
    # This is exclusively for simplicity:
    # this algorithm can be implemented with matrices of different sizes.
    r = max(r1, r2); c = max(c1, c2)
    x = embed(x, r, c)
    y = embed(y, r, c)

    # Our recursive multiplication function.
    function block_mult(a, b, rows, cols)
        # For small matrices, resort to naive multiplication.
        # if rows <= 128 || cols <= 128
        if rows == 1 && cols == 1
        # if rows == 2 && cols == 2
            return a * b
        end

        # Apply dynamic padding.
        if rows % 2 == 1 && cols % 2 == 1
            a = embed(a, rows + 1, cols + 1)
            b = embed(b, rows + 1, cols + 1)
        elseif rows % 2 == 1
            a = embed(a, rows + 1, cols)
            b = embed(b, rows + 1, cols)
        elseif cols % 2 == 1
            a = embed(a, rows, cols + 1)
            b = embed(b, rows, cols + 1)
        end

        # Both dimensions are even after padding.
        half_rows = size(a, 1) ÷ 2
        half_cols = size(a, 2) ÷ 2

        # Subdivide input matrices.
        a11 = a[1:half_rows, 1:half_cols]
        b11 = b[1:half_rows, 1:half_cols]

        a12 = a[1:half_rows, half_cols+1:size(a, 2)]
        b12 = b[1:half_rows, half_cols+1:size(a, 2)]

        a21 = a[half_rows+1:size(a, 1), 1:half_cols]
        b21 = b[half_rows+1:size(a, 1), 1:half_cols]

        a22 = a[half_rows+1:size(a, 1), half_cols+1:size(a, 2)]
        b22 = b[half_rows+1:size(a, 1), half_cols+1:size(a, 2)]

        # Compute intermediate values.
        multip(x, y) = block_mult(x, y, half_rows, half_cols)
        m1 = multip(a11 + a22, b11 + b22)
        m2 = multip(a21 + a22, b11)
        m3 = multip(a11, b12 - b22)
        m4 = multip(a22, b21 - b11)
        m5 = multip(a11 + a12, b22)
        m6 = multip(a21 - a11, b11 + b12)
        m7 = multip(a12 - a22, b21 + b22)

        # Combine intermediate values into the output.
        c11 = m1 + m4 - m5 + m7
        c12 = m3 + m5
        c21 = m2 + m4
        c22 = m1 - m2 + m3 + m6

        # Crop output to the desired size (undo dynamic padding).
        out = [c11 c12; c21 c22]
        out[1:rows, 1:cols]
    end

    block_mult(x, y, r, c)
end
# Sample operands. Each literal is a space-separated row of column vectors,
# so the inner vectors become the columns of the matrix; the A/B calls below
# pass the adjoint so the operands match the examples as written.
const A = [[1, 2] [3, 4]]
const B = [[5, 6] [7, 8]]
const C = [[1, 1, 1, 1] [2, 4, 8, 16] [3, 9, 27, 81] [4, 16, 64, 256]]
const D = [[4, -3, 4/3, -1/4] [-13/3, 19/4, -7/3, 11/24] [3/2, -2, 7/6, -1/4] [-1/6, 1/4, -1/6, 1/24]]
const E = [[1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]]
const F = [[1, 0, 0, 0] [0, 1, 0, 0] [0, 0, 1, 0] [0, 0, 0, 1]]
""" For pretty printing: change matrix to integer if it is within 0.00000001 of an integer """
# Prints the label `s` followed by the adjoint of `mat`, with every entry
# rounded to 8 digits and converted to Int.
intprint(s, mat) = println(s, map(x -> Int(round(x, digits=8)), mat)')
# Compare the built-in product with strassen on each operand pair.
intprint("Regular multiply: ", A' * B')
intprint("Strassen multiply: ", strassen(Matrix(A'), Matrix(B')))
intprint("Regular multiply: ", C * D)
intprint("Strassen multiply: ", strassen(C, D))
intprint("Regular multiply: ", E * F)
intprint("Strassen multiply: ", strassen(E, F))
# R is the 2x2 matrix built from r = sqrt(2)/2.
const r = sqrt(2)/2
const R = [[r, r] [-r, r]]
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", strassen(R,R))
- Output:
Regular multiply: [19 43; 22 50] Strassen multiply: [19 43; 22 50] Regular multiply: [1 0 0 0; 0 1 0 0; 0 0 1 0; 0 0 0 1] Strassen multiply: [1 0 0 0; 0 1 0 0; 0 0 1 0; 0 0 0 1] Regular multiply: [1 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16] Strassen multiply: [1 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16] Regular multiply: [0 1; -1 0] Strassen multiply: [0 1; -1 0]
Output is the same as the dynamically padded version.
# Strassen multiplication of two n x n matrices, recursing down to 1 x 1.
# The quadrants are taken with n ÷ 2, so n is expected to be a power of two.
function Strassen(A, B)
    n = size(A, 1)
    if n == 1
        return A * B
    end

    # Quadrant views: no copies are made when slicing the operands.
    @views A11 = A[1:n÷2, 1:n÷2]
    @views A12 = A[1:n÷2, n÷2+1:n]
    @views A21 = A[n÷2+1:n, 1:n÷2]
    @views A22 = A[n÷2+1:n, n÷2+1:n]
    @views B11 = B[1:n÷2, 1:n÷2]
    @views B12 = B[1:n÷2, n÷2+1:n]
    @views B21 = B[n÷2+1:n, 1:n÷2]
    @views B22 = B[n÷2+1:n, n÷2+1:n]

    # The seven recursive products.
    P1 = Strassen(A12 - A22, B21 + B22)
    P2 = Strassen(A11 + A22, B11 + B22)
    P3 = Strassen(A11 - A21, B11 + B12)
    P4 = Strassen(A11 + A12, B22)
    P5 = Strassen(A11, B12 - B22)
    P6 = Strassen(A22, B21 - B11)
    P7 = Strassen(A21 + A22, B11)

    # Recombine into the quadrants of the result.
    C11 = P1 + P2 - P4 + P6
    C12 = P4 + P5
    C21 = P6 + P7
    C22 = P2 - P3 + P5 - P7

    return [C11 C12; C21 C22]
end
# Sample operands. Each literal is a space-separated row of column vectors,
# so the inner vectors become the columns of the matrix; the A/B calls below
# pass the adjoint so the operands match the examples as written.
const A = [[1, 2] [3, 4]]
const B = [[5, 6] [7, 8]]
const C = [[1, 1, 1, 1] [2, 4, 8, 16] [3, 9, 27, 81] [4, 16, 64, 256]]
const D = [[4, -3, 4/3, -1/4] [-13/3, 19/4, -7/3, 11/24] [3/2, -2, 7/6, -1/4] [-1/6, 1/4, -1/6, 1/24]]
const E = [[1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]]
const F = [[1, 0, 0, 0] [0, 1, 0, 0] [0, 0, 1, 0] [0, 0, 0, 1]]
# Prints the label `s` followed by the adjoint of `mat`, with every entry
# rounded to 8 digits and converted to Int.
intprint(s, mat) = println(s, map(x -> Int(round(x, digits=8)), mat)')
# Compare the built-in product with Strassen on each operand pair.
intprint("Regular multiply: ", A' * B')
intprint("Strassen multiply: ", Strassen(Matrix(A'), Matrix(B')))
intprint("Regular multiply: ", C * D)
intprint("Strassen multiply: ", Strassen(C, D))
intprint("Regular multiply: ", E * F)
intprint("Strassen multiply: ", Strassen(E, F))
# R is the 2x2 matrix built from r = sqrt(2)/2.
const r = sqrt(2)/2
const R = [[r, r] [-r, r]]
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", Strassen(R,R))
% Demonstration script: compares the built-in product with Strassen on
% four operand pairs and displays both results.
clear all;close all;clc;
% Sample operands; F is the 4x4 identity.
A = [1, 2; 3, 4];
B = [5, 6; 7, 8];
C = [1, 1, 1, 1; 2, 4, 8, 16; 3, 9, 27, 81; 4, 16, 64, 256];
D = [4, -3, 4/3, -1/4; -13/3, 19/4, -7/3, 11/24; 3/2, -2, 7/6, -1/4; -1/6, 1/4, -1/6, 1/24];
E = [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12; 13, 14, 15, 16];
F = eye(4);
% Note: this pair multiplies the transposes of A and B.
disp('Regular multiply: ');
disp(A' * B');
disp('Strassen multiply: ');
disp(Strassen(A', B'));
disp('Regular multiply: ');
disp(C * D);
disp('Strassen multiply: ');
disp(Strassen(C, D));
disp('Regular multiply: ');
disp(E * F);
disp('Strassen multiply: ');
disp(Strassen(E, F));
% R is the 2x2 matrix built from r = sqrt(2)/2.
r = sqrt(2)/2;
R = [r, r; -r, r];
disp('Regular multiply: ');
disp(R * R);
disp('Strassen multiply: ');
disp(Strassen(R, R));
function C = Strassen(A, B)
    % STRASSEN Multiply two n-by-n matrices with Strassen's algorithm.
    %   C = Strassen(A, B) recurses on n/2-by-n/2 quadrants down to 1-by-1,
    %   so n is expected to be a power of two.
    n = size(A, 1);
    if n == 1
        C = A * B;
        return % base case: do not fall through into the quadrant split
    end

    % Split both operands into quadrants.
    A11 = A(1:n/2, 1:n/2);
    A12 = A(1:n/2, n/2+1:n);
    A21 = A(n/2+1:n, 1:n/2);
    A22 = A(n/2+1:n, n/2+1:n);
    B11 = B(1:n/2, 1:n/2);
    B12 = B(1:n/2, n/2+1:n);
    B21 = B(n/2+1:n, 1:n/2);
    B22 = B(n/2+1:n, n/2+1:n);

    % The seven recursive products.
    P1 = Strassen(A12 - A22, B21 + B22);
    P2 = Strassen(A11 + A22, B11 + B22);
    P3 = Strassen(A11 - A21, B11 + B12);
    P4 = Strassen(A11 + A12, B22);
    P5 = Strassen(A11, B12 - B22);
    P6 = Strassen(A22, B21 - B11);
    P7 = Strassen(A21 + A22, B11);

    % Recombine into the quadrants of the result.
    C11 = P1 + P2 - P4 + P6;
    C12 = P4 + P5;
    C21 = P6 + P7;
    C22 = P2 - P3 + P5 - P7;

    C = [C11 C12; C21 C22];
end
- Output:
Regular multiply: 23 31 34 46 Strassen multiply: 23 31 34 46 Regular multiply: 1.0000 0 -0.0000 -0.0000 0.0000 1.0000 -0.0000 -0.0000 0 0 1.0000 0 0.0000 0 0.0000 1.0000 Strassen multiply: 1.0000 0.0000 -0.0000 -0.0000 -0.0000 1.0000 -0.0000 0.0000 0 0 1.0000 0.0000 0 0 -0.0000 1.0000 Regular multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Strassen multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Regular multiply: 0 1.0000 -1.0000 0 Strassen multiply: 0 1.0000 -1.0000 0
import math, sequtils, strutils
# Dense matrix of floats stored as a sequence of rows.
type Matrix = seq[seq[float]]
# Number of rows: the length of the outer sequence.
template rows(m: Matrix): Positive = m.len
# Number of columns: the length of the first row.
template cols(m: Matrix): Positive = m[0].len
func `+`(m1, m2: Matrix): Matrix =
  ## Element-wise sum of two matrices; asserts that their dimensions match.
  doAssert m1.rows == m2.rows and m1.cols == m2.cols, "Matrices must have the same dimensions."
  let
    height = m1.rows
    width = m1.cols
  result = newSeqWith(height, newSeq[float](width))
  for r in 0..<height:
    for c in 0..<width:
      result[r][c] = m1[r][c] + m2[r][c]
func `-`(m1, m2: Matrix): Matrix =
  ## Element-wise difference m1 - m2; asserts that the dimensions match.
  doAssert m1.rows == m2.rows and m1.cols == m2.cols, "Matrices must have the same dimensions."
  let
    height = m1.rows
    width = m1.cols
  result = newSeqWith(height, newSeq[float](width))
  for r in 0..<height:
    for c in 0..<width:
      result[r][c] = m1[r][c] - m2[r][c]
func `*`(m1, m2: Matrix): Matrix =
  ## Standard matrix product; asserts that m1.cols equals m2.rows.
  doAssert m1.cols == m2.rows, "Cannot multiply these matrices."
  let
    outRows = m1.rows
    outCols = m2.cols
    inner = m2.rows
  result = newSeqWith(outRows, newSeq[float](outCols))
  for r in 0..<outRows:
    for c in 0..<outCols:
      # Accumulate row r of m1 times column c of m2 into the zeroed cell.
      for k in 0..<inner:
        result[r][c] += m1[r][k] * m2[k][c]
func toString(m: Matrix; p: Natural): string =
  ## Format the matrix as "[[a b] [c d]]" with every element rounded to
  ## 'p' places. A rounded value that formats as "-0" is shown as "0".
  let scale = 10.0^p
  var rowTexts: seq[string]
  for row in m:
    var cells: seq[string]
    for val in row:
      var text = formatFloat(round(val * scale) / scale, precision = -1)
      if text == "-0":
        text = "0"
      cells.add text
    rowTexts.add "[" & cells.join(" ") & "]"
  result = "[" & rowTexts.join(" ") & "]"
func params(r, c: int): array[4, array[6, int]] =
  ## For each quadrant of a 2r x 2c matrix (top-left, top-right, bottom-left,
  ## bottom-right): [rowStart, rowEnd, colStart, colEnd, rowOffset, colOffset].
  ## End bounds are exclusive.
  let
    rowEnd = 2 * r
    colEnd = 2 * c
  result[0] = [0, r, 0, c, 0, 0]
  result[1] = [0, r, c, colEnd, 0, c]
  result[2] = [r, rowEnd, 0, c, r, 0]
  result[3] = [r, rowEnd, c, colEnd, r, c]
func toQuarters(m: Matrix): array[4, Matrix] =
  ## Split 'm' into four quadrants of half the rows and columns, in the order
  ## top-left, top-right, bottom-left, bottom-right.
  let
    r = m.rows() div 2
    c = m.cols() div 2
    p = params(r, c)
  for k in 0..3:
    var q = newSeqWith(r, newSeq[float](c))
    for i in p[k][0]..<p[k][1]:
      for j in p[k][2]..<p[k][3]:
        # p[k][4] and p[k][5] shift indices into quadrant-local range.
        q[i-p[k][4]][j-p[k][5]] = m[i][j]
    result[k] = move(q)
func fromQuarters(q: array[4, Matrix]): Matrix =
  ## Assemble a matrix from four equally sized quadrants given in the order
  ## top-left, top-right, bottom-left, bottom-right.
  var
    r = q[0].rows
    c = q[0].cols
  let p = params(r, c)
  # The result has twice the rows and twice the columns of one quadrant.
  r *= 2
  c *= 2
  result = newSeqWith(r, newSeq[float](c))
  for k in 0..3:
    for i in p[k][0]..<p[k][1]:
      for j in p[k][2]..<p[k][3]:
        result[i][j] = q[k][i-p[k][4]][j-p[k][5]]
func strassen(a, b: Matrix): Matrix =
  ## Multiply two square matrices of equal power-of-two size with Strassen's
  ## algorithm, recursing down to 1 x 1 blocks.
  doAssert a.rows == a.cols() and b.rows == b.cols and a.rows == b.rows,
           "Matrices must be square and of equal size."
  # n and (n-1) is zero exactly when n has a single bit set.
  doAssert a.rows != 0 and (a.rows and (a.rows-1)) == 0,
           "Size of matrices must be a power of two."

  if a.rows == 1: return a * b

  let
    qa = a.toQuarters()
    qb = b.toQuarters()

    # The seven recursive products.
    p1 = strassen(qa[1] - qa[3], qb[2] + qb[3])
    p2 = strassen(qa[0] + qa[3], qb[0] + qb[3])
    p3 = strassen(qa[0] - qa[2], qb[0] + qb[1])
    p4 = strassen(qa[0] + qa[1], qb[3])
    p5 = strassen(qa[0], qb[1] - qb[3])
    p6 = strassen(qa[3], qb[2] - qb[0])
    p7 = strassen(qa[2] + qa[3], qb[0])

  # Recombine into the four quadrants of the result.
  var q: array[4, Matrix]
  q[0] = p1 + p2 - p4 + p6
  q[1] = p4 + p5
  q[2] = p6 + p7
  q[3] = p2 - p3 + p5 - p7

  result = fromQuarters(q)
when isMainModule:
  # Demonstration: multiply three operand pairs with both `*` and strassen.

  let
    a = @[@[float 1, 2],
          @[float 3, 4]]

    b = @[@[float 5, 6],
          @[float 7, 8]]

    c = @[@[float 1, 1, 1, 1],
          @[float 2, 4, 8, 16],
          @[float 3, 9, 27, 81],
          @[float 4, 16, 64, 256]]

    d = @[@[4.0, -3, 4/3, -1/4],
          @[-13/3, 19/4, -7/3, 11/24],
          @[3/2, -2, 7/6, -1/4],
          @[-1/6, 1/4, -1/6, 1/24]]

    e = @[@[float 1, 2, 3, 4],
          @[float 5, 6, 7, 8],
          @[float 9, 10, 11, 12],
          @[float 13, 14, 15, 16]]

    f = @[@[float 1, 0, 0, 0],
          @[float 0, 1, 0, 0],
          @[float 0, 0, 1, 0],
          @[float 0, 0, 0, 1]]

  echo "Using 'normal' matrix multiplication:"
  echo " a * b = ", (a * b).toString(10)
  echo " c * d = ", (c * d).toString(6)
  echo " e * f = ", (e * f).toString(10)

  echo "\nUsing 'Strassen' matrix multiplication:"
  echo " a * b = ", strassen(a, b).toString(10)
  echo " c * d = ", strassen(c, d).toString(6)
  echo " e * f = ", strassen(e, f).toString(10)
- Output:
Using 'normal' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]] Using 'Strassen' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]
As noted on wp, you could pad with zeroes, and strip them on exit, instead of crashing for matrices that are not square with a size of 2^n.
with javascript_semantics
-- Multiply two equal-sized square matrices (sequences of rows) with
-- Strassen's algorithm, recursing down to 1x1. Crashes with a message when
-- the inputs are not two equal square matrices or the size is odd (and so
-- cannot be halved).
function strassen(sequence a, b)
    integer l = length(a)
    if length(a[1])!=l
    or length(b)!=l
    or length(b[1])!=l then
        crash("two equal square matrices only")
    end if
    if l=1 then return sq_mul(a,b) end if
    -- an odd size cannot be split into four equal quadrants
    if remainder(l,2) then
        crash("2^n matrices only")
    end if
    integer h = l/2
    sequence {a11,a12,a21,a22,b11,b12,b21,b22} = repeat(repeat(repeat(0,h),h),8)
    -- copy the quadrants of both operands
    for i=1 to h do
        for j=1 to h do
            a11[i][j] = a[i][j]
            a12[i][j] = a[i][j+h]
            a21[i][j] = a[i+h][j]
            a22[i][j] = a[i+h][j+h]
            b11[i][j] = b[i][j]
            b12[i][j] = b[i][j+h]
            b21[i][j] = b[i+h][j]
            b22[i][j] = b[i+h][j+h]
        end for
    end for

    -- the seven recursive products, then the quadrants of the result
    sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
             p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
             p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
             p4 = strassen(sq_add(a11,a12), b22),
             p5 = strassen(a11, sq_sub(b12,b22)),
             p6 = strassen(a22, sq_sub(b21,b11)),
             p7 = strassen(sq_add(a21,a22), b11),

             c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
             c12 = sq_add(p4,p5),
             c21 = sq_add(p6,p7),
             c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
             c = repeat(repeat(0,l),l)

    -- assemble the result from its quadrants
    for i=1 to h do
        for j=1 to h do
            c[i][j] = c11[i][j]
            c[i][j+h] = c12[i][j]
            c[i+h][j] = c21[i][j]
            c[i+h][j+h] = c22[i][j]
        end for
    end for
    return c
end function

ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})

constant A = {{1,2},
              {3,4}},
         B = {{5,6},
              {7,8}}
pp(strassen(A,B))

constant C = { { 1,  1,  1,   1 },
               { 2,  4,  8,  16 },
               { 3,  9, 27,  81 },
               { 4, 16, 64, 256 }},
         D = { {    4,   -3,  4/3, -1/ 4 },
               {-13/3, 19/4, -7/3, 11/24 },
               {  3/2,   -2,  7/6, -1/ 4 },
               { -1/6,  1/4, -1/6,  1/24 }}
pp(strassen(C,D))

constant F = {{ 1, 2, 3, 4},
              { 5, 6, 7, 8},
              { 9,10,11,12},
              {13,14,15,16}},
         G = {{1, 0, 0, 0},
              {0, 1, 0, 0},
              {0, 0, 1, 0},
              {0, 0, 0, 1}}
pp(strassen(F,G))

constant r = sqrt(2)/2,
         R = {{ r,r},
              {-r,r}}
pp(strassen(R,R))
- Output:
Matches that of Matrix_multiplication#Phix, when given the same inputs. Note that a few "-0" show up in the second one (the identity matrix) under pwa/p2js.
{{ 19, 22}, { 43, 50}} {{ 1, 0, 0, 0}, { 0, 1, 0, 0}, { 0, 0, 1, 0}, { 0, 0, 0, 1}} {{ 1, 2, 3, 4}, { 5, 6, 7, 8}, { 9, 10, 11, 12}, { 13, 14, 15, 16}} {{ 0, 1}, { -1, 0}}
"""Matrix multiplication using Strassen's algorithm. Requires Python >= 3.7."""
from __future__ import annotations
from itertools import chain
from typing import List
from typing import NamedTuple
from typing import Optional
class Shape(NamedTuple):
    # Dimensions of a matrix.
    rows: int
    cols: int


class Matrix(List):
    """A matrix implemented as a two-dimensional list."""

    @classmethod
    def block(cls, blocks) -> Matrix:
        """Return a new Matrix assembled from nested blocks.

        `blocks` is a list of horizontal bands; each band is a list of
        matrices that are placed side by side.
        """
        m = Matrix()
        for hblock in blocks:
            # zip pairs up the i-th row of every matrix in the band;
            # chaining them yields one full-width row.
            for row in zip(*hblock):
                m.append(list(chain(*row)))

        return m

    def dot(self, b: Matrix) -> Matrix:
        """Return a new Matrix that is the product of this matrix and matrix `b`.
        Uses 'simple' or 'naive' matrix multiplication."""
        assert self.shape.cols == b.shape.rows
        m = Matrix()
        for row in self:
            new_row = []
            for c in range(len(b[0])):
                col = [b[r][c] for r in range(len(b))]
                new_row.append(sum(x * y for x, y in zip(row, col)))
            m.append(new_row)
        return m

    def __matmul__(self, b: Matrix) -> Matrix:
        """Support `a @ b` using naive multiplication."""
        return self.dot(b)

    def __add__(self, b: Matrix) -> Matrix:
        """Matrix addition."""
        assert self.shape == b.shape
        rows, cols = self.shape
        return Matrix(
            [[self[i][j] + b[i][j] for j in range(cols)] for i in range(rows)]
        )

    def __sub__(self, b: Matrix) -> Matrix:
        """Matrix subtraction."""
        assert self.shape == b.shape
        rows, cols = self.shape
        return Matrix(
            [[self[i][j] - b[i][j] for j in range(cols)] for i in range(rows)]
        )

    def strassen(self, b: Matrix) -> Matrix:
        """Return a new Matrix that is the product of this matrix and matrix `b`.
        Uses strassen algorithm."""
        rows, cols = self.shape

        assert rows == cols, "matrices must be square"
        assert self.shape == b.shape, "matrices must be the same shape"
        assert rows and (rows & rows - 1) == 0, "shape must be a power of 2"

        # Base case: a 1x1 product ends the recursion.
        if rows == 1:
            return self.dot(b)

        p = rows // 2  # partition

        a11 = Matrix([n[:p] for n in self[:p]])
        a12 = Matrix([n[p:] for n in self[:p]])
        a21 = Matrix([n[:p] for n in self[p:]])
        a22 = Matrix([n[p:] for n in self[p:]])

        b11 = Matrix([n[:p] for n in b[:p]])
        b12 = Matrix([n[p:] for n in b[:p]])
        b21 = Matrix([n[:p] for n in b[p:]])
        b22 = Matrix([n[p:] for n in b[p:]])

        # The seven recursive products.
        m1 = (a11 + a22).strassen(b11 + b22)
        m2 = (a21 + a22).strassen(b11)
        m3 = a11.strassen(b12 - b22)
        m4 = a22.strassen(b21 - b11)
        m5 = (a11 + a12).strassen(b22)
        m6 = (a21 - a11).strassen(b11 + b12)
        m7 = (a12 - a22).strassen(b21 + b22)

        # Recombine into the quadrants of the result.
        c11 = m1 + m4 - m5 + m7
        c12 = m3 + m5
        c21 = m2 + m4
        c22 = m1 - m2 + m3 + m6

        return Matrix.block([[c11, c12], [c21, c22]])

    def round(self, ndigits: Optional[int] = None) -> Matrix:
        """Return a new Matrix with every element rounded to `ndigits`."""
        return Matrix([[round(i, ndigits) for i in row] for row in self])

    @property
    def shape(self) -> Shape:
        """The (rows, cols) of this matrix; cols is 0 for an empty matrix."""
        cols = len(self[0]) if self else 0
        return Shape(len(self), cols)
def examples():
    """Print naive and Strassen products for three pairs of sample matrices."""
    a = Matrix(
        [
            [1, 2],
            [3, 4],
        ]
    )
    b = Matrix(
        [
            [5, 6],
            [7, 8],
        ]
    )
    c = Matrix(
        [
            [1, 1, 1, 1],
            [2, 4, 8, 16],
            [3, 9, 27, 81],
            [4, 16, 64, 256],
        ]
    )
    d = Matrix(
        [
            [4, -3, 4 / 3, -1 / 4],
            [-13 / 3, 19 / 4, -7 / 3, 11 / 24],
            [3 / 2, -2, 7 / 6, -1 / 4],
            [-1 / 6, 1 / 4, -1 / 6, 1 / 24],
        ]
    )
    e = Matrix(
        [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ]
    )
    f = Matrix(
        [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]
    )

    print("Naive matrix multiplication:")
    print(f"  a * b = {a @ b}")
    # c * d is rounded before printing.
    print(f"  c * d = {(c @ d).round()}")
    print(f"  e * f = {e @ f}")

    print("Strassen's matrix multiplication:")
    print(f"  a * b = {a.strassen(b)}")
    print(f"  c * d = {c.strassen(d).round()}")
    print(f"  e * f = {e.strassen(f)}")
# Run the demonstration when executed as a script.
if __name__ == "__main__":
    examples()
- Output:
Naive matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] Strassen's matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
Special thanks go to the module author, Fernando Santagata, on showing how to deal with a pass-by-value case.
# 20210126 Raku programming solution
use Math::Libgsl::Constants;
use Math::Libgsl::Matrix;
use Math::Libgsl::BLAS;
my @M;
# Build a square Math::Libgsl::Matrix from one comma-separated line of values.
sub SQM (\in) { # create custom sq matrix from CSV
# Split on commas; die unless the number of fields is a perfect square
# (its square root must equal the truncated integer root, kept as `size`).
die "Not a ■" if (my \L = in.split(/\,/)).sqrt != (my \size = L.sqrt.Int);
my Math::Libgsl::Matrix \M .= new: size, size;
# Pair each row index with the next `size` fields and store them as that row.
for ^size Z L.rotor(size) -> ($i, @row) { M.set-row: $i, @row }
# NOTE(review): no final expression returning M and no closing brace are present in this listing - restore before use.
# Matrix product operator, delegated to the BLAS dgemm routine.
sub infix:<⊗>(\x,\y) { # custom multiplication
# The result buffer takes x's dimensions.
my Math::Libgsl::Matrix \z .= new: x.size1, x.size2;
# No transposition of either operand; both scalar factors are 1.
# Presumably this computes z = 1*x*y + 1*z per the usual gemm convention, with z starting zeroed - TODO confirm.
dgemm(CblasNoTrans, CblasNoTrans, 1, x, y, 1, z);
# NOTE(review): no final expression returning z and no closing brace are present in this listing.
# Matrix addition operator producing a new matrix.
sub infix:<⊕>(\x,\y) { # custom addition
# The result buffer takes x's dimensions.
my Math::Libgsl::Matrix \z .= new: x.size1, x.size2;
# NOTE(review): the statement that performs the addition and the closing brace are not present in this listing - restore before use.
# Matrix subtraction operator producing a new matrix.
sub infix:<⊖>(\x,\y) { # custom subtraction
# The result buffer takes x's dimensions.
my Math::Libgsl::Matrix \z .= new: x.size1, x.size2;
# NOTE(review): the statement that performs the subtraction and the closing brace are not present in this listing - restore before use.
# Recursive Strassen product of $A and $B.
# Each level halves the size with `n div 2`, so the size is presumably expected
# to be a power of two - nothing here checks it.
sub Strassen($A, $B) {
# Base case: a 1x1 operand is multiplied directly.
{ return $A ⊗ $B } if (my \n = $A.size1) == 1;
my Math::Libgsl::Matrix ($A11,$A12,$A21,$A22,$B11,$B12,$B21,$B22);
my Math::Libgsl::Matrix ($P1,$P2,$P3,$P4,$P5,$P6,$P7);
# One view object per quadrant (four for $A, four for $B).
my Math::Libgsl::Matrix::View ($mv1,$mv2,$mv3,$mv4,$mv5,$mv6,$mv7,$mv8);
($mv1,$mv2,$mv3,$mv4,$mv5,$mv6,$mv7,$mv8)».=new ;
my \half = n div 2; # dimension of quarter submatrices
# submatrix arguments: source, row offset, column offset, row count, column count.
$A11 = $mv1.submatrix($A, 0,0, half,half); #
$A12 = $mv2.submatrix($A, 0,half, half,half); # create quarter views
$A21 = $mv3.submatrix($A, half,0, half,half); # of operand matrices
$A22 = $mv4.submatrix($A, half,half, half,half); #
$B11 = $mv5.submatrix($B, 0,0, half,half); # 11 12
$B12 = $mv6.submatrix($B, 0,half, half,half); #
$B21 = $mv7.submatrix($B, half,0, half,half); # 21 22
$B22 = $mv8.submatrix($B, half,half, half,half); #
# The seven recursive half-size products.
$P1 = Strassen($A12 ⊖ $A22, $B21 ⊕ $B22);
$P2 = Strassen($A11 ⊕ $A22, $B11 ⊕ $B22);
$P3 = Strassen($A11 ⊖ $A21, $B11 ⊕ $B12);
$P4 = Strassen($A11 ⊕ $A12, $B22 );
$P5 = Strassen($A11, $B12 ⊖ $B22);
$P6 = Strassen($A22, $B21 ⊖ $B11);
$P7 = Strassen($A21 ⊕ $A22, $B11 );
my Math::Libgsl::Matrix $C .= new: n, n; # Build C from
my Math::Libgsl::Matrix::View ($mvC11,$mvC12,$mvC21,$mvC22); # C11 C12
($mvC11,$mvC12,$mvC21,$mvC22)».=new ; # C21 C22
# Each quadrant of $C is filled by adding its combination of products to a
# view of $C; presumably the fresh $C starts zeroed, making .add an assignment - TODO confirm.
given $mvC11.submatrix($C, 0,0, half,half) { .add: (($P1 ⊕ $P2) ⊖ $P4) ⊕ $P6 };
given $mvC12.submatrix($C, 0,half, half,half) { .add: $P4 ⊕ $P5 };
given $mvC21.submatrix($C, half,0, half,half) { .add: $P6 ⊕ $P7 };
given $mvC22.submatrix($C, half,half, half,half) { .add: (($P2 ⊖ $P3) ⊕ $P5) ⊖ $P7 };
# NOTE(review): no final expression returning $C and no closing brace are present in this listing.
# Read the test matrices from the first Pod block of this file: every entry that
# is not a lone newline is parsed by SQM and appended to @M.
for $=pod[0].contents { next if /^\n$/ ; @M.append: SQM $_ }
# Take the matrices two at a time and print their product computed both ways.
for @M.rotor(2) {
my $product = @_[0] ⊗ @_[1];
# $product.get-row($_)».round(1).fmt('%2d').put for ^$product.size1;
say "Regular multiply:";
# One output line per row, each element formatted with '%.10g'.
$product.get-row($_)».fmt('%.10g').put for ^$product.size1;
$product = Strassen @_[0], @_[1];
say "Strassen multiply:";
$product.get-row($_)».fmt('%.10g').put for ^$product.size1;
# NOTE(review): the closing brace of this loop is not present in this listing.
=begin code
=end code
- Output:
Regular multiply: 19 22 43 50 Strassen multiply: 19 22 43 50 Regular multiply: 1 0 -1.387778781e-16 -2.081668171e-17 1.33226763e-15 1 -4.440892099e-16 -1.110223025e-16 0 0 1 0 7.105427358e-15 0 7.105427358e-15 1 Strassen multiply: 1 5.684341886e-14 -2.664535259e-15 -1.110223025e-16 -1.136868377e-13 1 -7.105427358e-15 2.220446049e-15 0 0 1 5.684341886e-14 0 0 -2.273736754e-13 1 Regular multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Strassen multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
import scala.math
/** Naive and Strassen matrix multiplication on `Array[Array[Double]]`. */
object MatrixOperations {

  /** A dense matrix stored as an array of rows. */
  type Matrix = Array[Array[Double]]

  /** Operations added to [[Matrix]] through implicit conversion. */
  implicit class RichMatrix(val m: Matrix) {

    /** Number of rows. */
    def rows: Int = m.length

    /** Number of columns, taken from the first row; 0 for a matrix without rows. */
    def cols: Int = if (m.isEmpty) 0 else m(0).length

    /** Element-wise sum; both matrices must have the same dimensions. */
    def add(m2: Matrix): Matrix = {
      require(
        rows == m2.rows && cols == m2.cols,
        "Matrices must have the same dimensions."
      )
      Array.tabulate(rows, cols)((i, j) => m(i)(j) + m2(i)(j))
    }

    /** Element-wise difference; both matrices must have the same dimensions. */
    def sub(m2: Matrix): Matrix = {
      require(
        rows == m2.rows && cols == m2.cols,
        "Matrices must have the same dimensions."
      )
      Array.tabulate(rows, cols)((i, j) => m(i)(j) - m2(i)(j))
    }

    /** Standard row-by-column product; requires `cols == m2.rows`. */
    def mul(m2: Matrix): Matrix = {
      require(cols == m2.rows, "Cannot multiply these matrices.")
      Array.tabulate(rows, m2.cols)((i, j) =>
        (0 until cols).map(k => m(i)(k) * m2(k)(j)).sum
      )
    }

    /** Renders the matrix with every element rounded to `p` decimal places. */
    def toString(p: Int): String = {
      val pow = math.pow(10, p)
      m.map(row =>
        row
          .map(value => (math.round(value * pow) / pow).toString)
          .mkString("[", ", ", "]")
      ).mkString("[", ",\n ", "]")
    }
  }

  /** Splits `m` into its four quadrants, ordered top-left, top-right, bottom-left, bottom-right. */
  def toQuarters(m: Matrix): Array[Matrix] = {
    val r = m.rows / 2
    val c = m.cols / 2
    val p = params(r, c)
    (0 until 4).map { k =>
      // p(k)(0) is the first source row and p(k)(2) the first source column of quadrant k.
      Array.tabulate(r, c)((i, j) => m(p(k)(0) + i)(p(k)(2) + j))
    }.toArray
  }

  /** Reassembles four equally sized quadrants (same order as `toQuarters`) into one matrix. */
  def fromQuarters(q: Array[Matrix]): Matrix = {
    val r = q(0).rows
    val c = q(0).cols
    // (i / r) * 2 + j / c selects the quadrant; i % r and j % c index inside it.
    Array.tabulate(r * 2, c * 2)((i, j) => q((i / r) * 2 + j / c)(i % r)(j % c))
  }

  /** Strassen product of two square matrices whose common size is a power of two.
    *
    * Recursion continues down to 1x1 blocks, which are multiplied directly.
    *
    * @throws IllegalArgumentException if the operands are not square, differ in
    *                                  size, or their size is not a power of two
    */
  def strassen(a: Matrix, b: Matrix): Matrix = {
    require(
      a.rows == a.cols && b.rows == b.cols && a.rows == b.rows,
      "Matrices must be square and of equal size."
    )
    require(
      a.rows != 0 && (a.rows & (a.rows - 1)) == 0,
      "Size of matrices must be a power of two."
    )
    if (a.rows == 1) {
      return a.mul(b)
    }
    val qa = toQuarters(a)
    val qb = toQuarters(b)
    val p1 = strassen(qa(1).sub(qa(3)), qb(2).add(qb(3)))
    val p2 = strassen(qa(0).add(qa(3)), qb(0).add(qb(3)))
    val p3 = strassen(qa(0).sub(qa(2)), qb(0).add(qb(1)))
    val p4 = strassen(qa(0).add(qa(1)), qb(3))
    val p5 = strassen(qa(0), qb(1).sub(qb(3)))
    val p6 = strassen(qa(3), qb(2).sub(qb(0)))
    val p7 = strassen(qa(2).add(qa(3)), qb(0))
    val q = Array(
      p1.add(p2).sub(p4).add(p6),
      p4.add(p5),
      p6.add(p7),
      p2.sub(p3).add(p5).sub(p7)
    )
    fromQuarters(q)
  }

  /** Per-quadrant bounds: row start, row end, column start, column end, row offset, column offset. */
  private def params(r: Int, c: Int): Array[Array[Int]] = {
    Array(
      Array(0, r, 0, c, 0, 0),
      Array(0, r, c, 2 * c, 0, c),
      Array(r, 2 * r, 0, c, r, 0),
      Array(r, 2 * r, c, 2 * c, r, c)
    )
  }

  /** Single-line rendering used for the exact-valued demo results. */
  private def show(m: Matrix): String =
    m.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")

  /** Prints three sample products computed with `mul` and with `strassen`. */
  def main(args: Array[String]): Unit = {
    val a: Matrix = Array(Array(1.0, 2.0), Array(3.0, 4.0))
    val b: Matrix = Array(Array(5.0, 6.0), Array(7.0, 8.0))
    val c: Matrix = Array(
      Array(1.0, 1.0, 1.0, 1.0),
      Array(2.0, 4.0, 8.0, 16.0),
      Array(3.0, 9.0, 27.0, 81.0),
      Array(4.0, 16.0, 64.0, 256.0)
    )
    val d: Matrix = Array(
      Array(4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0),
      Array(-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0),
      Array(3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0),
      Array(-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0)
    )
    val e: Matrix = Array(
      Array(1.0, 2.0, 3.0, 4.0),
      Array(5.0, 6.0, 7.0, 8.0),
      Array(9.0, 10.0, 11.0, 12.0),
      Array(13.0, 14.0, 15.0, 16.0)
    )
    val f: Matrix = Array(
      Array(1.0, 0.0, 0.0, 0.0),
      Array(0.0, 1.0, 0.0, 0.0),
      Array(0.0, 0.0, 1.0, 0.0),
      Array(0.0, 0.0, 0.0, 1.0)
    )

    println("Using 'normal' matrix multiplication:")
    println(s" a * b = ${show(a.mul(b))}")
    // c * d is rounded to 6 places because d holds inexact fractions.
    println(s" c * d = ${c.mul(d).toString(6)}")
    println(s" e * f = ${show(e.mul(f))}")

    println("\nUsing 'Strassen' matrix multiplication:")
    println(s" a * b = ${show(strassen(a, b))}")
    println(s" c * d = ${strassen(c, d).toString(6)}")
    println(s" e * f = ${show(strassen(e, f))}")
  }
}
- Output:
Using 'normal' matrix multiplication: a * b = [[19.0, 22.0], [43.0, 50.0]] c * d = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]] Using 'Strassen' matrix multiplication: a * b = [[19.0, 22.0], [43.0, 50.0]] c * d = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
// Matrix Strassen Multiplication
/// Multiplies an m×n matrix by an n×p matrix using Strassen's algorithm.
///
/// Both operands are copied into the top-left corner of square matrices whose
/// side is the next power of two at or above the largest input dimension; the
/// product is then cropped back to m×p.
///
/// - Parameters:
///   - matrix1: left operand (m×n).
///   - matrix2: right operand (n×p).
/// - Returns: the m×p product.
/// - Precondition: `matrix1.columns == matrix2.rows`.
func strassenMultiply(matrix1: Matrix, matrix2: Matrix) -> Matrix {
    // The inner dimensions must agree: columns of the left operand against ROWS of the right one.
    precondition(matrix1.columns == matrix2.rows,
                 "Two matrices can only be matrix multiplied if one has dimensions mxn & the other has dimensions nxp where m, n, p are in R")

    // Transform to square matrix
    let maxDimension = Swift.max(matrix1.rows, matrix1.columns, matrix2.rows, matrix2.columns)
    let pwr2 = nextPowerOfTwo(num: maxDimension)

    // NOTE(review): the padding cells are never written here, so this relies on
    // Matrix(rows:columns:) producing zero-filled storage - confirm against the Matrix type.
    var sqrMatrix1 = Matrix(rows: pwr2, columns: pwr2)
    var sqrMatrix2 = Matrix(rows: pwr2, columns: pwr2)

    // fill square matrix 1 with values
    for i in 0..<matrix1.rows {
        for j in 0..<matrix1.columns {
            sqrMatrix1[i, j] = matrix1[i, j]
        }
    }

    // fill square matrix 2 with values
    for i in 0..<matrix2.rows {
        for j in 0..<matrix2.columns {
            sqrMatrix2[i, j] = matrix2[i, j]
        }
    }

    // Get strassen result and transfer to array with proper size
    let formulaResult = strassenFormula(matrix1: sqrMatrix1, matrix2: sqrMatrix2)
    var finalResult = Matrix(rows: matrix1.rows, columns: matrix2.columns)
    for i in 0..<finalResult.rows {
        for j in 0..<finalResult.columns {
            finalResult[i, j] = formulaResult[i, j]
        }
    }

    return finalResult
}
// Calculate next power of 2
/// Returns the smallest power of two that is greater than or equal to `num`.
///
/// Uses integer bit arithmetic, so the result does not depend on floating-point
/// rounding of `log2`/`pow`.
///
/// - Parameter num: the value to round up.
/// - Returns: `num` rounded up to a power of two; 0 when `num` is zero or negative
///   (zero matches the previous floating-point formula, which trapped on negatives).
func nextPowerOfTwo(num: Int) -> Int {
    guard num > 0 else { return 0 }
    guard num > 1 else { return 1 }
    // For 2^(k-1) < num <= 2^k the highest set bit of (num - 1) is bit k-1,
    // so the count of significant bits in (num - 1) is exactly k.
    let significantBits = Int.bitWidth - (num - 1).leadingZeroBitCount
    return 1 << significantBits
}
// Multiply Matrices Using Strassen Formula
/// Recursive Strassen product of two square matrices of the same size.
///
/// Each level splits the operands at `rows / 2`, so the size is expected to be a
/// power of two (an odd size would drop the last row and column). Matrices with
/// one row or fewer are multiplied with the regular `*` operator.
///
/// - Parameters:
///   - matrix1: left operand.
///   - matrix2: right operand.
/// - Returns: the product matrix, `matrix1.rows` × `matrix1.rows`.
/// - Precondition: both operands are square and have the same size.
func strassenFormula(matrix1: Matrix, matrix2: Matrix) -> Matrix {
    precondition(matrix1.rows == matrix1.columns && matrix2.rows == matrix2.columns, "Matrices need to be square")
    // The quadrant copies below index both operands with the same bounds.
    precondition(matrix1.rows == matrix2.rows, "Matrices need to be the same size")
    guard matrix1.rows > 1 && matrix2.rows > 1 else { return matrix1 * matrix2 }

    let rowHalf = matrix1.rows / 2

    /*
     Strassen Formula
     p1 = a(f-h)          p2 = (a+b)h
     p3 = (c+d)e          p4 = d(g-e)
     p5 = (a+d)(e+h)      p6 = (b-d)(g+h)
     p7 = (a-c)(e+f)

     |a b| x |e f| = |(p5+p4-p2+p6)  (p1+p2)      |
     |c d|   |g h|   |(p3+p4)        (p1+p5-p3-p7)|
     Matrix 1  Matrix 2   Result
     */

    // create empty matrices for a, b, c, d, e, f, g, h
    var a = Matrix(rows: rowHalf, columns: rowHalf)
    var b = Matrix(rows: rowHalf, columns: rowHalf)
    var c = Matrix(rows: rowHalf, columns: rowHalf)
    var d = Matrix(rows: rowHalf, columns: rowHalf)
    var e = Matrix(rows: rowHalf, columns: rowHalf)
    var f = Matrix(rows: rowHalf, columns: rowHalf)
    var g = Matrix(rows: rowHalf, columns: rowHalf)
    var h = Matrix(rows: rowHalf, columns: rowHalf)

    // fill the matrices with values
    for i in 0..<rowHalf {
        for j in 0..<rowHalf {
            a[i, j] = matrix1[i, j]
            b[i, j] = matrix1[i, j+rowHalf]
            c[i, j] = matrix1[i+rowHalf, j]
            d[i, j] = matrix1[i+rowHalf, j+rowHalf]
            e[i, j] = matrix2[i, j]
            f[i, j] = matrix2[i, j+rowHalf]
            g[i, j] = matrix2[i+rowHalf, j]
            h[i, j] = matrix2[i+rowHalf, j+rowHalf]
        }
    }

    // a * (f - h)
    let p1 = strassenFormula(matrix1: a, matrix2: (f - h))
    // (a + b) * h
    let p2 = strassenFormula(matrix1: (a + b), matrix2: h)
    // (c + d) * e
    let p3 = strassenFormula(matrix1: (c + d), matrix2: e)
    // d * (g - e)
    let p4 = strassenFormula(matrix1: d, matrix2: (g - e))
    // (a + d) * (e + h)
    let p5 = strassenFormula(matrix1: (a + d), matrix2: (e + h))
    // (b - d) * (g + h)
    let p6 = strassenFormula(matrix1: (b - d), matrix2: (g + h))
    // (a - c) * (e + f)
    let p7 = strassenFormula(matrix1: (a - c), matrix2: (e + f))

    // p5 + p4 - p2 + p6
    let result11 = p5 + p4 - p2 + p6
    // p1 + p2
    let result12 = p1 + p2
    // p3 + p4
    let result21 = p3 + p4
    // p1 + p5 - p3 - p7
    let result22 = p1 + p5 - p3 - p7

    // create an empty matrix for result and fill with values
    var result = Matrix(rows: matrix1.rows, columns: matrix1.rows)
    for i in 0..<rowHalf {
        for j in 0..<rowHalf {
            result[i, j] = result11[i, j]
            result[i, j+rowHalf] = result12[i, j]
            result[i+rowHalf, j] = result21[i, j]
            result[i+rowHalf, j+rowHalf] = result22[i, j]
        }
    }

    return result
}
/// Builds the three sample matrix pairs, multiplies each with `strassenMultiply`
/// and prints the products under the labels AxB, CxD and ExF.
func main() {
    // Matrix Class
    var a = Matrix(rows: 2, columns: 2)
    a[row: 0] = [1, 2]
    a[row: 1] = [3, 4]

    var b = Matrix(rows: 2, columns: 2)
    b[row: 0] = [5, 6]
    b[row: 1] = [7, 8]

    var c = Matrix(rows: 4, columns: 4)
    c[row: 0] = [1, 1, 1, 1]
    c[row: 1] = [2, 4, 8, 16]
    c[row: 2] = [3, 9, 27, 81]
    c[row: 3] = [4, 16, 64, 256]

    // The fractions are written as floating-point divisions. Spelled as
    // `Double(4/3)` the quotient is computed on integers first (4/3 == 1), which
    // truncates every entry and is why a CxD listing of `1 -1 0 0 ...` is not
    // the identity matrix.
    var d = Matrix(rows: 4, columns: 4)
    d[row: 0] = [4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0]
    d[row: 1] = [-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0]
    d[row: 2] = [3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0]
    d[row: 3] = [-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0]

    var e = Matrix(rows: 4, columns: 4)
    e[row: 0] = [1, 2, 3, 4]
    e[row: 1] = [5, 6, 7, 8]
    e[row: 2] = [9, 10, 11, 12]
    e[row: 3] = [13, 14, 15, 16]

    // f is the 4x4 identity, so ExF must reproduce e.
    var f = Matrix(rows: 4, columns: 4)
    f[row: 0] = [1, 0, 0, 0]
    f[row: 1] = [0, 1, 0, 0]
    f[row: 2] = [0, 0, 1, 0]
    f[row: 3] = [0, 0, 0, 1]

    let result1 = strassenMultiply(matrix1: a, matrix2: b)
    let result2 = strassenMultiply(matrix1: c, matrix2: d)
    let result3 = strassenMultiply(matrix1: e, matrix2: f)

    // NOTE(review): relies on the generic `print`; the layout depends on how
    // Matrix describes itself.
    print("AxB")
    print(result1)
    print("CxD")
    print(result2)
    print("ExF")
    print(result3)
}
- Output:
AxB 19 22 43 50 CxD 1 -1 0 0 0 -6 2 0 3 -27 12 0 16 -76 36 0 ExF 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
Wren doesn't currently have a matrix module so I've written a rudimentary Matrix class with sufficient functionality to complete this task.
I've used the Phix entry's examples to test the Strassen algorithm implementation.
// A minimal matrix backed by a list of rows (each row a list of numbers).
class Matrix {
    // Wraps 'a', which must be a non-empty list whose first element is a
    // non-empty list starting with a number; aborts the fiber otherwise.
    construct new(a) {
        if (a.type != List || a.count == 0 || a[0].type != List || a[0].count == 0 || a[0][0].type != Num) {
            Fiber.abort("Argument must be a non-empty two dimensional list of numbers.")
        }
        _a = a
    }

    // Number of rows and number of columns (taken from the first row).
    rows { _a.count }
    cols { _a[0].count }

    // Element-wise sum with another matrix of the same dimensions.
    +(b) {
        if (b.type != Matrix) Fiber.abort("Argument must be a matrix.")
        if ((this.rows != b.rows) || (this.cols != b.cols)) {
            Fiber.abort("Matrices must have the same dimensions.")
        }
        var c = List.filled(rows, null)
        for (i in 0...rows) {
            c[i] = List.filled(cols, 0)
            for (j in 0...cols) c[i][j] = _a[i][j] + b[i, j]
        }
        return Matrix.new(c)
    }

    // Negation, expressed as multiplication by -1.
    - { this * -1 }

    // Difference, expressed as addition of the negated operand.
    -(b) { this + (-b) }

    // Product with a number (scales every element) or with a matrix
    // (row-by-column product; requires this.cols == b.rows).
    *(b) {
        var c = List.filled(rows, null)
        if (b is Num) {
            for (i in 0...rows) {
                c[i] = List.filled(cols, 0)
                for (j in 0...cols) c[i][j] = _a[i][j] * b
            }
        } else if (b is Matrix) {
            if (this.cols != b.rows) Fiber.abort("Cannot multiply these matrices.")
            for (i in 0...rows) {
                c[i] = List.filled(b.cols, 0)
                for (j in 0...b.cols) {
                    for (k in 0...b.rows) c[i][j] = c[i][j] + _a[i][k] * b[k, j]
                }
            }
        } else {
            Fiber.abort("Argument must be a matrix or a number.")
        }
        return Matrix.new(c)
    }

    // Copy of row 'i'.
    [i] { _a[i].toList }

    // Element at row 'i', column 'j'.
    [i, j] { _a[i][j] }

    toString { _a.toString }

    // rounds all elements to 'p' places
    // Returns a list holding one formatted string per row.
    toString(p) {
        var s = List.filled(rows, "")
        var pow = 10.pow(p)
        for (i in 0...rows) {
            var t = List.filled(cols, "")
            for (j in 0...cols) {
                var r = (_a[i][j]*pow).round / pow
                t[j] = r.toString
                // A tiny negative value rounds to negative zero; show it as plain 0.
                if (t[j] == "-0") t[j] = "0"
            }
            s[i] = t.toString
        }
        return s
    }
}
// Describes the four r-by-c quadrants of a 2r-by-2c matrix. Each entry holds
// the quadrant's row range, column range, row offset and column offset, in the
// order top-left, top-right, bottom-left, bottom-right.
var params = Fn.new { |r, c|
    return [
        [0...r, 0...c, 0, 0],
        [0...r, c...2*c, 0, c],
        [r...2*r, 0...c, r, 0],
        [r...2*r, c...2*c, r, c]
    ]
}
// Splits matrix 'm' into its four quadrants and returns them as a list of
// matrices ordered top-left, top-right, bottom-left, bottom-right.
var toQuarters = Fn.new { |m|
    var r = (m.rows/2).floor
    var c = (m.cols/2).floor
    var p = params.call(r, c)
    var quarters = []
    for (k in 0..3) {
        var q = List.filled(r, null)
        for (i in p[k][0]) {
            // Subtracting the quadrant offsets maps source indices to 0-based ones.
            q[i - p[k][2]] = List.filled(c, 0)
            for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
        }
        quarters.add(Matrix.new(q))
    }
    return quarters
}
// Reassembles four equally sized quadrants (ordered top-left, top-right,
// bottom-left, bottom-right) into a single matrix.
var fromQuarters = Fn.new { |q|
    var r = q[0].rows
    var c = q[0].cols
    var p = params.call(r, c)
    r = r * 2
    c = c * 2
    var m = List.filled(r, null)
    // One list of 'c' zeros per ROW: iterate over the row count, not the column count.
    for (i in 0...r) m[i] = List.filled(c, 0)
    for (k in 0..3) {
        for (i in p[k][0]) {
            for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
        }
    }
    return Matrix.new(m)
}
// Strassen product of two square matrices whose common size is a power of two.
// Recursion continues down to 1 x 1 blocks, which are multiplied directly.
var strassen // recursive
strassen = Fn.new { |a, b|
    if (a.rows != a.cols || b.rows != b.cols || a.rows != b.rows) {
        Fiber.abort("Matrices must be square and of equal size.")
    }
    // A power of two has a single bit set, so n & (n - 1) is zero.
    if (a.rows == 0 || (a.rows & (a.rows - 1)) != 0) {
        Fiber.abort("Size of matrices must be a power of two.")
    }
    if (a.rows == 1) return a * b
    var qa = toQuarters.call(a)
    var qb = toQuarters.call(b)
    // The seven half-size products.
    var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
    var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
    var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
    var p4 = strassen.call(qa[0] + qa[1], qb[3])
    var p5 = strassen.call(qa[0], qb[1] - qb[3])
    var p6 = strassen.call(qa[3], qb[2] - qb[0])
    var p7 = strassen.call(qa[2] + qa[3], qb[0])
    // Quadrants of the result: top-left, top-right, bottom-left, bottom-right.
    var q = List.filled(4, null)
    q[0] = p1 + p2 - p4 + p6
    q[1] = p4 + p5
    q[2] = p6 + p7
    q[3] = p2 - p3 + p5 - p7
    return fromQuarters.call(q)
}
// Sample operands: three pairs multiplied both ways below.
var a = Matrix.new([ [1,2], [3, 4] ])
var b = Matrix.new([ [5,6], [7, 8] ])
var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ])
var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],
                     [3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])
var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ])
// f is the 4 x 4 identity, so e * f must reproduce e.
var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
System.print("Using 'normal' matrix multiplication:")
System.print(" a * b = %(a * b)")
// c * d is rounded to 6 places because d holds inexact fractions.
System.print(" c * d = %((c * d).toString(6))")
System.print(" e * f = %(e * f)")
System.print("\nUsing 'Strassen' matrix multiplication:")
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")
- Output:
Using 'normal' matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] Using 'Strassen' matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
Since the above version was written, a Matrix module has been added and the following version uses it. The output is exactly the same as before.
import "./matrix" for Matrix
// Describes the four r-by-c quadrants of a 2r-by-2c matrix. Each entry holds
// the quadrant's row range, column range, row offset and column offset, in the
// order top-left, top-right, bottom-left, bottom-right.
var params = Fn.new { |r, c|
    return [
        [0...r, 0...c, 0, 0],
        [0...r, c...2*c, 0, c],
        [r...2*r, 0...c, r, 0],
        [r...2*r, c...2*c, r, c]
    ]
}
// Splits matrix 'm' into its four quadrants and returns them as a list of
// matrices ordered top-left, top-right, bottom-left, bottom-right.
var toQuarters = Fn.new { |m|
    var r = (m.numRows/2).floor
    var c = (m.numCols/2).floor
    var p = params.call(r, c)
    var quarters = []
    for (k in 0..3) {
        var q = List.filled(r, null)
        for (i in p[k][0]) {
            // Subtracting the quadrant offsets maps source indices to 0-based ones.
            q[i - p[k][2]] = List.filled(c, 0)
            for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
        }
        quarters.add(Matrix.new(q))
    }
    return quarters
}
// Reassembles four equally sized quadrants (ordered top-left, top-right,
// bottom-left, bottom-right) into a single matrix.
var fromQuarters = Fn.new { |q|
    var r = q[0].numRows
    var c = q[0].numCols
    var p = params.call(r, c)
    r = r * 2
    c = c * 2
    var m = List.filled(r, null)
    // One list of 'c' zeros per ROW: iterate over the row count, not the column count.
    for (i in 0...r) m[i] = List.filled(c, 0)
    for (k in 0..3) {
        for (i in p[k][0]) {
            for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
        }
    }
    return Matrix.new(m)
}
// Strassen product of two square matrices whose common size is a power of two.
// Recursion continues down to 1 x 1 blocks, which are multiplied directly.
var strassen // recursive
strassen = Fn.new { |a, b|
    if (!a.isSquare || !b.isSquare || !a.sameSize(b)) {
        Fiber.abort("Matrices must be square and of equal size.")
    }
    // A power of two has a single bit set, so n & (n - 1) is zero.
    if (a.numRows == 0 || (a.numRows & (a.numRows - 1)) != 0) {
        Fiber.abort("Size of matrices must be a power of two.")
    }
    if (a.numRows == 1) return a * b
    var qa = toQuarters.call(a)
    var qb = toQuarters.call(b)
    // The seven half-size products.
    var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
    var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
    var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
    var p4 = strassen.call(qa[0] + qa[1], qb[3])
    var p5 = strassen.call(qa[0], qb[1] - qb[3])
    var p6 = strassen.call(qa[3], qb[2] - qb[0])
    var p7 = strassen.call(qa[2] + qa[3], qb[0])
    // Quadrants of the result: top-left, top-right, bottom-left, bottom-right.
    var q = List.filled(4, null)
    q[0] = p1 + p2 - p4 + p6
    q[1] = p4 + p5
    q[2] = p6 + p7
    q[3] = p2 - p3 + p5 - p7
    return fromQuarters.call(q)
}
// Sample operands: three pairs multiplied both ways below.
var a = Matrix.new([ [1,2], [3, 4] ])
var b = Matrix.new([ [5,6], [7, 8] ])
var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ])
var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],
                     [3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])
var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ])
// f is the 4 x 4 identity, so e * f must reproduce e.
var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
System.print("Using 'normal' matrix multiplication:")
System.print(" a * b = %(a * b)")
// c * d is rounded to 6 places because d holds inexact fractions.
System.print(" c * d = %((c * d).toString(6))")
System.print(" e * f = %(e * f)")
System.print("\nUsing 'Strassen' matrix multiplication:")
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")