Strassen's algorithm
In linear algebra, the Strassen algorithm (named after Volker Strassen), is an algorithm for matrix multiplication.

You are encouraged to solve this task according to the task description, using any language you may know.
- Description
It is faster than the standard matrix multiplication algorithm and is useful in practice for large matrices, but would be slower than the fastest known algorithms for extremely large matrices.
- Task
Write a routine, function, procedure etc. in your language to implement the Strassen algorithm for matrix multiplication.
While practical implementations of Strassen's algorithm usually switch to standard methods of matrix multiplication for small enough sub-matrices (currently anything less than 512×512 according to Wikipedia), for the purposes of this task you should not switch until reaching a size of 1 or 2.
- Related task
- See also
C#
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
class Matrix
{
public List<List<double>> data;
public int rows;
public int cols;
public Matrix(List<List<double>> data)
{
this.data = data;
rows = data.Count;
cols = (rows > 0) ? data[0].Count : 0;
}
public int GetRows()
{
return rows;
}
public int GetCols()
{
return cols;
}
public void ValidateDimensions(Matrix other)
{
if (GetRows() != other.GetRows() || GetCols() != other.GetCols())
{
throw new InvalidOperationException("Matrices must have the same dimensions.");
}
}
public void ValidateMultiplication(Matrix other)
{
if (GetCols() != other.GetRows())
{
throw new InvalidOperationException("Cannot multiply these matrices.");
}
}
public void ValidateSquarePowerOfTwo()
{
if (GetRows() != GetCols())
{
throw new InvalidOperationException("Matrix must be square.");
}
if (GetRows() == 0 || (GetRows() & (GetRows() - 1)) != 0)
{
throw new InvalidOperationException("Size of matrix must be a power of two.");
}
}
public static Matrix operator +(Matrix a, Matrix b)
{
a.ValidateDimensions(b);
List<List<double>> resultData = new List<List<double>>();
for (int i = 0; i < a.rows; ++i)
{
List<double> row = new List<double>();
for (int j = 0; j < a.cols; ++j)
{
row.Add(a.data[i][j] + b.data[i][j]);
}
resultData.Add(row);
}
return new Matrix(resultData);
}
public static Matrix operator -(Matrix a, Matrix b)
{
a.ValidateDimensions(b);
List<List<double>> resultData = new List<List<double>>();
for (int i = 0; i < a.rows; ++i)
{
List<double> row = new List<double>();
for (int j = 0; j < a.cols; ++j)
{
row.Add(a.data[i][j] - b.data[i][j]);
}
resultData.Add(row);
}
return new Matrix(resultData);
}
public static Matrix operator *(Matrix a, Matrix b)
{
a.ValidateMultiplication(b);
List<List<double>> resultData = new List<List<double>>();
for (int i = 0; i < a.rows; ++i)
{
List<double> row = new List<double>();
for (int j = 0; j < b.cols; ++j)
{
double sum = 0.0;
for (int k = 0; k < b.rows; ++k)
{
sum += a.data[i][k] * b.data[k][j];
}
row.Add(sum);
}
resultData.Add(row);
}
return new Matrix(resultData);
}
public override string ToString()
{
StringBuilder sb = new StringBuilder();
foreach (var row in data)
{
sb.Append("[");
for (int i = 0; i < row.Count; ++i)
{
sb.Append(row[i]);
if (i < row.Count - 1)
{
sb.Append(", ");
}
}
sb.AppendLine("]");
}
return sb.ToString();
}
public string ToStringWithPrecision(int p)
{
StringBuilder sb = new StringBuilder();
double pow = Math.Pow(10.0, p);
foreach (var row in data)
{
sb.Append("[");
for (int i = 0; i < row.Count; ++i)
{
double r = Math.Round(row[i] * pow) / pow;
string formatted = r.ToString($"F{p}");
if (formatted == "-0" + (p > 0 ? "." + new string('0', p) : ""))
{
formatted = "0" + (p > 0 ? "." + new string('0', p) : "");
}
sb.Append(formatted);
if (i < row.Count - 1)
{
sb.Append(", ");
}
}
sb.AppendLine("]");
}
return sb.ToString();
}
private static int[,] GetParams(int r, int c)
{
return new int[,]
{
{0, r, 0, c, 0, 0},
{0, r, c, 2 * c, 0, c},
{r, 2 * r, 0, c, r, 0},
{r, 2 * r, c, 2 * c, r, c}
};
}
public Matrix[] ToQuarters()
{
int r = GetRows() / 2;
int c = GetCols() / 2;
int[,] p = GetParams(r, c);
Matrix[] quarters = new Matrix[4];
for (int k = 0; k < 4; ++k)
{
List<List<double>> qData = new List<List<double>>();
for (int i = 0; i < r; i++)
{
List<double> row = new List<double>();
for (int j = 0; j < c; j++)
{
row.Add(0.0);
}
qData.Add(row);
}
for (int i = p[k, 0]; i < p[k, 1]; ++i)
{
for (int j = p[k, 2]; j < p[k, 3]; ++j)
{
qData[i - p[k, 4]][j - p[k, 5]] = data[i][j];
}
}
quarters[k] = new Matrix(qData);
}
return quarters;
}
public static Matrix FromQuarters(Matrix[] q)
{
int r = q[0].GetRows();
int c = q[0].GetCols();
int[,] p = GetParams(r, c);
int rows = r * 2;
int cols = c * 2;
List<List<double>> mData = new List<List<double>>();
for (int i = 0; i < rows; i++)
{
List<double> row = new List<double>();
for (int j = 0; j < cols; j++)
{
row.Add(0.0);
}
mData.Add(row);
}
for (int k = 0; k < 4; ++k)
{
for (int i = p[k, 0]; i < p[k, 1]; ++i)
{
for (int j = p[k, 2]; j < p[k, 3]; ++j)
{
mData[i][j] = q[k].data[i - p[k, 4]][j - p[k, 5]];
}
}
}
return new Matrix(mData);
}
public Matrix Strassen(Matrix other)
{
ValidateSquarePowerOfTwo();
other.ValidateSquarePowerOfTwo();
if (GetRows() != other.GetRows() || GetCols() != other.GetCols())
{
throw new InvalidOperationException("Matrices must be square and of equal size for Strassen multiplication.");
}
if (GetRows() == 1)
{
return this * other;
}
Matrix[] qa = ToQuarters();
Matrix[] qb = other.ToQuarters();
Matrix p1 = (qa[1] - qa[3]).Strassen(qb[2] + qb[3]);
Matrix p2 = (qa[0] + qa[3]).Strassen(qb[0] + qb[3]);
Matrix p3 = (qa[0] - qa[2]).Strassen(qb[0] + qb[1]);
Matrix p4 = (qa[0] + qa[1]).Strassen(qb[3]);
Matrix p5 = qa[0].Strassen(qb[1] - qb[3]);
Matrix p6 = qa[3].Strassen(qb[2] - qb[0]);
Matrix p7 = (qa[2] + qa[3]).Strassen(qb[0]);
Matrix[] q = new Matrix[4];
q[0] = p1 + p2 - p4 + p6;
q[1] = p4 + p5;
q[2] = p6 + p7;
q[3] = p2 - p3 + p5 - p7;
return FromQuarters(q);
}
}
class Program
{
static void Main(string[] args)
{
Matrix a = new Matrix(new List<List<double>> { new List<double> { 1.0, 2.0 }, new List<double> { 3.0, 4.0 } });
Matrix b = new Matrix(new List<List<double>> { new List<double> { 5.0, 6.0 }, new List<double> { 7.0, 8.0 } });
Matrix c = new Matrix(new List<List<double>>
{
new List<double> { 1.0, 1.0, 1.0, 1.0 },
new List<double> { 2.0, 4.0, 8.0, 16.0 },
new List<double> { 3.0, 9.0, 27.0, 81.0 },
new List<double> { 4.0, 16.0, 64.0, 256.0 }
});
Matrix d = new Matrix(new List<List<double>>
{
new List<double> { 4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0 },
new List<double> { -13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0 },
new List<double> { 3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0 },
new List<double> { -1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0 }
});
Matrix e = new Matrix(new List<List<double>>
{
new List<double> { 1.0, 2.0, 3.0, 4.0 },
new List<double> { 5.0, 6.0, 7.0, 8.0 },
new List<double> { 9.0, 10.0, 11.0, 12.0 },
new List<double> { 13.0, 14.0, 15.0, 16.0 }
});
Matrix f = new Matrix(new List<List<double>>
{
new List<double> { 1.0, 0.0, 0.0, 0.0 },
new List<double> { 0.0, 1.0, 0.0, 0.0 },
new List<double> { 0.0, 0.0, 1.0, 0.0 },
new List<double> { 0.0, 0.0, 0.0, 1.0 }
});
Console.WriteLine("Using 'normal' matrix multiplication:");
Console.WriteLine($" a * b = {a * b}");
Console.WriteLine($" c * d = {(c * d).ToStringWithPrecision(6)}");
Console.WriteLine($" e * f = {e * f}");
Console.WriteLine("\nUsing 'Strassen' matrix multiplication:");
Console.WriteLine($" a * b = {a.Strassen(b)}");
Console.WriteLine($" c * d = {c.Strassen(d).ToStringWithPrecision(6)}");
Console.WriteLine($" e * f = {e.Strassen(f)}");
}
}
- Output:
Using 'normal' matrix multiplication: a * b = [19, 22] [43, 50] c * d = [1.000000, 0.000000, 0.000000, 0.000000] [0.000000, 1.000000, 0.000000, 0.000000] [0.000000, 0.000000, 1.000000, 0.000000] [0.000000, 0.000000, 0.000000, 1.000000] e * f = [1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16] Using 'Strassen' matrix multiplication: a * b = [19, 22] [43, 50] c * d = [1.000000, 0.000000, 0.000000, 0.000000] [0.000000, 1.000000, 0.000000, 0.000000] [0.000000, 0.000000, 1.000000, 0.000000] [0.000000, 0.000000, 0.000000, 1.000000] e * f = [1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]
C++
#include <iostream>
#include <vector>
#include <iomanip>
#include <cmath>
#include <sstream>
#include <stdexcept>
class Matrix {
public:
std::vector<std::vector<double>> data;
size_t rows;
size_t cols;
Matrix(const std::vector<std::vector<double>>& data) : data(data) {
rows = data.size();
cols = (rows > 0) ? data[0].size() : 0;
}
size_t getRows() const {
return rows;
}
size_t getCols() const {
return cols;
}
void validateDimensions(const Matrix& other) const {
if (getRows() != other.getRows() || getCols() != other.getCols()) {
throw std::runtime_error("Matrices must have the same dimensions.");
}
}
void validateMultiplication(const Matrix& other) const {
if (getCols() != other.getRows()) {
throw std::runtime_error("Cannot multiply these matrices.");
}
}
void validateSquarePowerOfTwo() const {
if (getRows() != getCols()) {
throw std::runtime_error("Matrix must be square.");
}
if (getRows() == 0 || (getRows() & (getRows() - 1)) != 0) {
throw std::runtime_error("Size of matrix must be a power of two.");
}
}
Matrix operator+(const Matrix& other) const {
validateDimensions(other);
std::vector<std::vector<double>> result_data(rows, std::vector<double>(cols));
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
result_data[i][j] = data[i][j] + other.data[i][j];
}
}
return Matrix(result_data);
}
Matrix operator-(const Matrix& other) const {
validateDimensions(other);
std::vector<std::vector<double>> result_data(rows, std::vector<double>(cols));
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
result_data[i][j] = data[i][j] - other.data[i][j];
}
}
return Matrix(result_data);
}
Matrix operator*(const Matrix& other) const {
validateMultiplication(other);
std::vector<std::vector<double>> result_data(rows, std::vector<double>(other.cols));
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < other.cols; ++j) {
double sum = 0.0;
for (size_t k = 0; k < other.rows; ++k) {
sum += data[i][k] * other.data[k][j];
}
result_data[i][j] = sum;
}
}
return Matrix(result_data);
}
friend std::ostream& operator<<(std::ostream& os, const Matrix& matrix) {
for (const auto& row : matrix.data) {
os << "[";
for (size_t i = 0; i < row.size(); ++i) {
os << row[i];
if (i < row.size() - 1) {
os << ", ";
}
}
os << "]" << std::endl;
}
return os;
}
std::string toStringWithPrecision(size_t p) const {
std::stringstream ss;
ss << std::fixed << std::setprecision(p);
double pow = std::pow(10.0, p);
for (const auto& row : data) {
ss << "[";
for (size_t i = 0; i < row.size(); ++i) {
double r = std::round(row[i] * pow) / pow;
std::string formatted = ss.str();
ss.str(std::string());
ss << r;
formatted = ss.str();
if (formatted == "-0") {
ss.str(std::string());
ss << "0";
formatted = ss.str();
}
ss.str(std::string());
ss << formatted;
if (i < row.size() - 1) {
ss << ", ";
}
}
ss << "]" << std::endl;
}
return ss.str();
}
static std::array<std::array<size_t, 6>, 4> params(size_t r, size_t c) {
return {
{{{0, r, 0, c, 0, 0}},
{{0, r, c, 2 * c, 0, c}},
{{r, 2 * r, 0, c, r, 0}},
{{r, 2 * r, c, 2 * c, r, c}}}
};
}
std::array<Matrix, 4> toQuarters() const {
size_t r = getRows() / 2;
size_t c = getCols() / 2;
auto p = Matrix::params(r, c);
std::array<Matrix, 4> quarters = {
Matrix(std::vector<std::vector<double>>(r, std::vector<double>(c, 0.0))),
Matrix(std::vector<std::vector<double>>(r, std::vector<double>(c, 0.0))),
Matrix(std::vector<std::vector<double>>(r, std::vector<double>(c, 0.0))),
Matrix(std::vector<std::vector<double>>(r, std::vector<double>(c, 0.0)))
};
for (size_t k = 0; k < 4; ++k) {
std::vector<std::vector<double>> q_data(r, std::vector<double>(c));
for (size_t i = p[k][0]; i < p[k][1]; ++i) {
for (size_t j = p[k][2]; j < p[k][3]; ++j) {
q_data[i - p[k][4]][j - p[k][5]] = data[i][j];
}
}
quarters[k] = Matrix(q_data);
}
return quarters;
}
static Matrix fromQuarters(const std::array<Matrix, 4>& q) {
size_t r = q[0].getRows();
size_t c = q[0].getCols();
auto p = Matrix::params(r, c);
size_t rows = r * 2;
size_t cols = c * 2;
std::vector<std::vector<double>> m_data(rows, std::vector<double>(cols, 0.0));
for (size_t k = 0; k < 4; ++k) {
for (size_t i = p[k][0]; i < p[k][1]; ++i) {
for (size_t j = p[k][2]; j < p[k][3]; ++j) {
m_data[i][j] = q[k].data[i - p[k][4]][j - p[k][5]];
}
}
}
return Matrix(m_data);
}
Matrix strassen(const Matrix& other) const {
validateSquarePowerOfTwo();
other.validateSquarePowerOfTwo();
if (getRows() != other.getRows() || getCols() != other.getCols()) {
throw std::runtime_error("Matrices must be square and of equal size for Strassen multiplication.");
}
if (getRows() == 1) {
return *this * other;
}
auto qa = toQuarters();
auto qb = other.toQuarters();
Matrix p1 = (qa[1] - qa[3]).strassen(qb[2] + qb[3]);
Matrix p2 = (qa[0] + qa[3]).strassen(qb[0] + qb[3]);
Matrix p3 = (qa[0] - qa[2]).strassen(qb[0] + qb[1]);
Matrix p4 = (qa[0] + qa[1]).strassen(qb[3]);
Matrix p5 = qa[0].strassen(qb[1] - qb[3]);
Matrix p6 = qa[3].strassen(qb[2] - qb[0]);
Matrix p7 = (qa[2] + qa[3]).strassen(qb[0]);
std::array<Matrix, 4> q = {
Matrix(std::vector<std::vector<double>>(qa[0].getRows(), std::vector<double>(qa[0].getCols(), 0.0))),
Matrix(std::vector<std::vector<double>>(qa[0].getRows(), std::vector<double>(qa[0].getCols(), 0.0))),
Matrix(std::vector<std::vector<double>>(qa[0].getRows(), std::vector<double>(qa[0].getCols(), 0.0))),
Matrix(std::vector<std::vector<double>>(qa[0].getRows(), std::vector<double>(qa[0].getCols(), 0.0)))
};
q[0] = p1 + p2 - p4 + p6;
q[1] = p4 + p5;
q[2] = p6 + p7;
q[3] = p2 - p3 + p5 - p7;
return Matrix::fromQuarters(q);
}
};
int main() {
Matrix a({ {1.0, 2.0}, {3.0, 4.0} });
Matrix b({ {5.0, 6.0}, {7.0, 8.0} });
Matrix c({ {1.0, 1.0, 1.0, 1.0}, {2.0, 4.0, 8.0, 16.0}, {3.0, 9.0, 27.0, 81.0}, {4.0, 16.0, 64.0, 256.0} });
Matrix d({ {4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0}, {-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0}, {3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0}, {-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0} });
Matrix e({ {1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}, {9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0} });
Matrix f({ {1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0} });
std::cout << "Using 'normal' matrix multiplication:" << std::endl;
std::cout << " a * b = " << a * b << std::endl;
std::cout << " c * d = " << (c * d).toStringWithPrecision(6) << std::endl;
std::cout << " e * f = " << e * f << std::endl;
std::cout << "\nUsing 'Strassen' matrix multiplication:" << std::endl;
std::cout << " a * b = " << a.strassen(b) << std::endl;
std::cout << " c * d = " << c.strassen(d).toStringWithPrecision(6) << std::endl;
std::cout << " e * f = " << e.strassen(f) << std::endl;
return 0;
}
- Output:
Using 'normal' matrix multiplication: a * b = [19, 22] [43, 50] c * d = 1.000000] e * f = [1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16] Using 'Strassen' matrix multiplication: a * b = [19, 22] [43, 50] c * d = 1.000000] e * f = [1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]
FreeBASIC
Type Matrix
As Integer rows
As Integer cols
As Double dato(Any)
End Type
Function makeMatrix(rows As Integer, cols As Integer) As Matrix
Dim As Matrix m
m.rows = rows
m.cols = cols
Redim m.dato(rows * cols - 1)
Return m
End Function
Function matrixAdd(m1 As Matrix, m2 As Matrix) As Matrix
If m1.rows <> m2.rows Or m1.cols <> m2.cols Then
Print "Matrices must have the same dimensions."
End
End If
Dim As Matrix result = makeMatrix(m1.rows, m1.cols)
For i As Integer = 0 To m1.rows * m1.cols - 1
result.dato(i) = m1.dato(i) + m2.dato(i)
Next
Return result
End Function
Function matrixSub(m1 As Matrix, m2 As Matrix) As Matrix
If m1.rows <> m2.rows Or m1.cols <> m2.cols Then
Print "Matrices must have the same dimensions."
End
End If
Dim As Matrix result = makeMatrix(m1.rows, m1.cols)
For i As Integer = 0 To m1.rows * m1.cols - 1
result.dato(i) = m1.dato(i) - m2.dato(i)
Next
Return result
End Function
Function matrixMul(m1 As Matrix, m2 As Matrix) As Matrix
If m1.cols <> m2.rows Then
Print "Cannot multiply these matrices."
End
End If
Dim As Integer i, j, k
Dim As Matrix result = makeMatrix(m1.rows, m2.cols)
For i = 0 To m1.rows - 1
For j = 0 To m2.cols - 1
Dim As Double sum = 0
For k = 0 To m1.cols - 1
sum += m1.dato(i * m1.cols + k) * m2.dato(k * m2.cols + j)
Next
result.dato(i * result.cols + j) = sum
Next
Next
Return result
End Function
Sub printMatrix(m As Matrix, precision As Integer = 6)
Dim As Integer i, j
Print "[[";
For i = 0 To m.rows - 1
If i > 0 Then Print " [";
For j = 0 To m.cols - 1
Dim As Double valor = m.dato(i * m.cols + j)
valor = Int(valor * (10 ^ precision) + 0.5) / (10 ^ precision)
If Abs(valor) < 1e-10 Then valor = 0
Print Using "&"; valor;
If j < m.cols - 1 Then Print " ";
Next
If i < m.rows - 1 Then Print "]";
Next
Print "]]"
End Sub
Function getQuarter(m As Matrix, quarter As Integer) As Matrix
Dim As Integer i, j
Dim As Integer halfRows, halfCols
halfRows = m.rows \ 2
halfCols = m.cols \ 2
Dim As Matrix result = makeMatrix(halfRows, halfCols)
Dim As Integer rowOffset, colOffset
rowOffset = (quarter \ 2) * halfRows
colOffset = (quarter Mod 2) * halfCols
For i = 0 To halfRows - 1
For j = 0 To halfCols - 1
result.dato(i * halfCols + j) = m.dato((i + rowOffset) * m.cols + j + colOffset)
Next
Next
Return result
End Function
Function combineQuarters(q1 As Matrix, q2 As Matrix, q3 As Matrix, q4 As Matrix) As Matrix
Dim As Integer i, j, n
n = q1.rows
Dim As Matrix result = makeMatrix(n * 2, n * 2)
For i = 0 To n - 1
For j = 0 To n - 1
result.dato(i * result.cols + j) = q1.dato(i * n + j)
result.dato(i * result.cols + j + n) = q2.dato(i * n + j)
result.dato((i + n) * result.cols + j) = q3.dato(i * n + j)
result.dato((i + n) * result.cols + j + n) = q4.dato(i * n + j)
Next
Next
Return result
End Function
Function strassen(a As Matrix, b As Matrix) As Matrix
If a.rows <> a.cols Or b.rows <> b.cols Or a.rows <> b.rows Then
Print "Matrices must be square and of equal size."
End
End If
If a.rows = 1 Then Return matrixMul(a, b)
Dim As Matrix a11 = getQuarter(a, 0)
Dim As Matrix a12 = getQuarter(a, 1)
Dim As Matrix a21 = getQuarter(a, 2)
Dim As Matrix a22 = getQuarter(a, 3)
Dim As Matrix b11 = getQuarter(b, 0)
Dim As Matrix b12 = getQuarter(b, 1)
Dim As Matrix b21 = getQuarter(b, 2)
Dim As Matrix b22 = getQuarter(b, 3)
Dim As Matrix p1 = strassen(matrixSub(a12, a22), matrixAdd(b21, b22))
Dim As Matrix p2 = strassen(matrixAdd(a11, a22), matrixAdd(b11, b22))
Dim As Matrix p3 = strassen(matrixSub(a11, a21), matrixAdd(b11, b12))
Dim As Matrix p4 = strassen(matrixAdd(a11, a12), b22)
Dim As Matrix p5 = strassen(a11, matrixSub(b12, b22))
Dim As Matrix p6 = strassen(a22, matrixSub(b21, b11))
Dim As Matrix p7 = strassen(matrixAdd(a21, a22), b11)
Dim As Matrix c11 = matrixAdd(matrixSub(matrixAdd(p1, p2), p4), p6)
Dim As Matrix c12 = matrixAdd(p4, p5)
Dim As Matrix c21 = matrixAdd(p6, p7)
Dim As Matrix c22 = matrixSub(matrixAdd(matrixSub(p2, p3), p5), p7)
Return combineQuarters(c11, c12, c21, c22)
End Function
Sub main()
' Matrix A (2x2)
Dim As Matrix a = makeMatrix(2, 2)
a.dato(0) = 1: a.dato(1) = 2
a.dato(2) = 3: a.dato(3) = 4
' Matrix B (2x2)
Dim As Matrix b = makeMatrix(2, 2)
b.dato(0) = 5: b.dato(1) = 6
b.dato(2) = 7: b.dato(3) = 8
' Matrix C (4x4)
Dim As Matrix c = makeMatrix(4, 4)
c.dato(0) = 1: c.dato(1) = 1: c.dato(2) = 1: c.dato(3) = 1
c.dato(4) = 2: c.dato(5) = 4: c.dato(6) = 8: c.dato(7) = 16
c.dato(8) = 3: c.dato(9) = 9: c.dato(10) = 27: c.dato(11) = 81
c.dato(12) = 4: c.dato(13) = 16: c.dato(14) = 64: c.dato(15) = 256
' Matrix D (4x4)
Dim As Matrix d = makeMatrix(4, 4)
d.dato(0) = 4: d.dato(1) = -3: d.dato(2) = 4/3: d.dato(3) = -1/4
d.dato(4) = -13/3: d.dato(5) = 19/4: d.dato(6) = -7/3: d.dato(7) = 11/24
d.dato(8) = 3/2: d.dato(9) = -2: d.dato(10) = 7/6: d.dato(11) = -1/4
d.dato(12) = -1/6: d.dato(13) = 1/4: d.dato(14) = -1/6: d.dato(15) = 1/24
' Matrix E (4x4)
Dim As Matrix e = makeMatrix(4, 4)
e.dato(0) = 1: e.dato(1) = 2: e.dato(2) = 3: e.dato(3) = 4
e.dato(4) = 5: e.dato(5) = 6: e.dato(6) = 7: e.dato(7) = 8
e.dato(8) = 9: e.dato(9) = 10: e.dato(10) = 11: e.dato(11) = 12
e.dato(12) = 13: e.dato(13) = 14: e.dato(14) = 15: e.dato(15) = 16
' Matrix F (Identity 4x4)
Dim As Matrix f = makeMatrix(4, 4)
f.dato(0) = 1: f.dato(1) = 0: f.dato(2) = 0: f.dato(3) = 0
f.dato(4) = 0: f.dato(5) = 1: f.dato(6) = 0: f.dato(7) = 0
f.dato(8) = 0: f.dato(9) = 0: f.dato(10) = 1: f.dato(11) = 0
f.dato(12) = 0: f.dato(13) = 0: f.dato(14) = 0: f.dato(15) = 1
Print "Using 'normal' matrix multiplication:"
Print " a * b = ";
printMatrix(matrixMul(a, b))
Print " c * d = ";
printMatrix(matrixMul(c, d), 6)
Print " e * f = ";
printMatrix(matrixMul(e, f))
Print !"\nUsing 'Strassen' matrix multiplication:"
Print " a * b = ";
printMatrix(strassen(a, b))
Print " c * d = ";
printMatrix(strassen(c, d), 6)
Print " e * f = ";
printMatrix(strassen(e, f))
End Sub
main()
Sleep
- Output:
Using 'normal' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]] Using 'Strassen' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]
Go
Rather than use a library such as gonum, we create a simple Matrix type which is adequate for this task.
package main
import (
"fmt"
"log"
"math"
)
type Matrix [][]float64
func (m Matrix) rows() int { return len(m) }
func (m Matrix) cols() int { return len(m[0]) }
func (m Matrix) add(m2 Matrix) Matrix {
if m.rows() != m2.rows() || m.cols() != m2.cols() {
log.Fatal("Matrices must have the same dimensions.")
}
c := make(Matrix, m.rows())
for i := 0; i < m.rows(); i++ {
c[i] = make([]float64, m.cols())
for j := 0; j < m.cols(); j++ {
c[i][j] = m[i][j] + m2[i][j]
}
}
return c
}
func (m Matrix) sub(m2 Matrix) Matrix {
if m.rows() != m2.rows() || m.cols() != m2.cols() {
log.Fatal("Matrices must have the same dimensions.")
}
c := make(Matrix, m.rows())
for i := 0; i < m.rows(); i++ {
c[i] = make([]float64, m.cols())
for j := 0; j < m.cols(); j++ {
c[i][j] = m[i][j] - m2[i][j]
}
}
return c
}
func (m Matrix) mul(m2 Matrix) Matrix {
if m.cols() != m2.rows() {
log.Fatal("Cannot multiply these matrices.")
}
c := make(Matrix, m.rows())
for i := 0; i < m.rows(); i++ {
c[i] = make([]float64, m2.cols())
for j := 0; j < m2.cols(); j++ {
for k := 0; k < m2.rows(); k++ {
c[i][j] += m[i][k] * m2[k][j]
}
}
}
return c
}
func (m Matrix) toString(p int) string {
s := make([]string, m.rows())
pow := math.Pow(10, float64(p))
for i := 0; i < m.rows(); i++ {
t := make([]string, m.cols())
for j := 0; j < m.cols(); j++ {
r := math.Round(m[i][j]*pow) / pow
t[j] = fmt.Sprintf("%g", r)
if t[j] == "-0" {
t[j] = "0"
}
}
s[i] = fmt.Sprintf("%v", t)
}
return fmt.Sprintf("%v", s)
}
func params(r, c int) [4][6]int {
return [4][6]int{
{0, r, 0, c, 0, 0},
{0, r, c, 2 * c, 0, c},
{r, 2 * r, 0, c, r, 0},
{r, 2 * r, c, 2 * c, r, c},
}
}
func toQuarters(m Matrix) [4]Matrix {
r := m.rows() / 2
c := m.cols() / 2
p := params(r, c)
var quarters [4]Matrix
for k := 0; k < 4; k++ {
q := make(Matrix, r)
for i := p[k][0]; i < p[k][1]; i++ {
q[i-p[k][4]] = make([]float64, c)
for j := p[k][2]; j < p[k][3]; j++ {
q[i-p[k][4]][j-p[k][5]] = m[i][j]
}
}
quarters[k] = q
}
return quarters
}
func fromQuarters(q [4]Matrix) Matrix {
r := q[0].rows()
c := q[0].cols()
p := params(r, c)
r *= 2
c *= 2
m := make(Matrix, r)
for i := 0; i < c; i++ {
m[i] = make([]float64, c)
}
for k := 0; k < 4; k++ {
for i := p[k][0]; i < p[k][1]; i++ {
for j := p[k][2]; j < p[k][3]; j++ {
m[i][j] = q[k][i-p[k][4]][j-p[k][5]]
}
}
}
return m
}
func strassen(a, b Matrix) Matrix {
if a.rows() != a.cols() || b.rows() != b.cols() || a.rows() != b.rows() {
log.Fatal("Matrices must be square and of equal size.")
}
if a.rows() == 0 || (a.rows()&(a.rows()-1)) != 0 {
log.Fatal("Size of matrices must be a power of two.")
}
if a.rows() == 1 {
return a.mul(b)
}
qa := toQuarters(a)
qb := toQuarters(b)
p1 := strassen(qa[1].sub(qa[3]), qb[2].add(qb[3]))
p2 := strassen(qa[0].add(qa[3]), qb[0].add(qb[3]))
p3 := strassen(qa[0].sub(qa[2]), qb[0].add(qb[1]))
p4 := strassen(qa[0].add(qa[1]), qb[3])
p5 := strassen(qa[0], qb[1].sub(qb[3]))
p6 := strassen(qa[3], qb[2].sub(qb[0]))
p7 := strassen(qa[2].add(qa[3]), qb[0])
var q [4]Matrix
q[0] = p1.add(p2).sub(p4).add(p6)
q[1] = p4.add(p5)
q[2] = p6.add(p7)
q[3] = p2.sub(p3).add(p5).sub(p7)
return fromQuarters(q)
}
func main() {
a := Matrix{{1, 2}, {3, 4}}
b := Matrix{{5, 6}, {7, 8}}
c := Matrix{{1, 1, 1, 1}, {2, 4, 8, 16}, {3, 9, 27, 81}, {4, 16, 64, 256}}
d := Matrix{{4, -3, 4.0 / 3, -1.0 / 4}, {-13.0 / 3, 19.0 / 4, -7.0 / 3, 11.0 / 24},
{3.0 / 2, -2, 7.0 / 6, -1.0 / 4}, {-1.0 / 6, 1.0 / 4, -1.0 / 6, 1.0 / 24}}
e := Matrix{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}
f := Matrix{{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}}
fmt.Println("Using 'normal' matrix multiplication:")
fmt.Printf(" a * b = %v\n", a.mul(b))
fmt.Printf(" c * d = %v\n", c.mul(d).toString(6))
fmt.Printf(" e * f = %v\n", e.mul(f))
fmt.Println("\nUsing 'Strassen' matrix multiplication:")
fmt.Printf(" a * b = %v\n", strassen(a, b))
fmt.Printf(" c * d = %v\n", strassen(c, d).toString(6))
fmt.Printf(" e * f = %v\n", strassen(e, f))
}
- Output:
Using 'normal' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]] Using 'Strassen' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]
Java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
class Matrix {
public List<List<Double>> data;
public int rows;
public int cols;
public Matrix(List<List<Double>> data) {
this.data = data;
rows = data.size();
cols = (rows > 0) ? data.get(0).size() : 0;
}
public int getRows() {
return rows;
}
public int getCols() {
return cols;
}
public void validateDimensions(Matrix other) {
if (getRows() != other.getRows() || getCols() != other.getCols()) {
throw new RuntimeException("Matrices must have the same dimensions.");
}
}
public void validateMultiplication(Matrix other) {
if (getCols() != other.getRows()) {
throw new RuntimeException("Cannot multiply these matrices.");
}
}
public void validateSquarePowerOfTwo() {
if (getRows() != getCols()) {
throw new RuntimeException("Matrix must be square.");
}
if (getRows() == 0 || (getRows() & (getRows() - 1)) != 0) {
throw new RuntimeException("Size of matrix must be a power of two.");
}
}
public Matrix add(Matrix other) {
validateDimensions(other);
List<List<Double>> resultData = new ArrayList<>();
for (int i = 0; i < rows; ++i) {
List<Double> row = new ArrayList<>();
for (int j = 0; j < cols; ++j) {
row.add(data.get(i).get(j) + other.data.get(i).get(j));
}
resultData.add(row);
}
return new Matrix(resultData);
}
public Matrix subtract(Matrix other) {
validateDimensions(other);
List<List<Double>> resultData = new ArrayList<>();
for (int i = 0; i < rows; ++i) {
List<Double> row = new ArrayList<>();
for (int j = 0; j < cols; ++j) {
row.add(data.get(i).get(j) - other.data.get(i).get(j));
}
resultData.add(row);
}
return new Matrix(resultData);
}
public Matrix multiply(Matrix other) {
validateMultiplication(other);
List<List<Double>> resultData = new ArrayList<>();
for (int i = 0; i < rows; ++i) {
List<Double> row = new ArrayList<>();
for (int j = 0; j < other.cols; ++j) {
double sum = 0.0;
for (int k = 0; k < other.rows; ++k) {
sum += data.get(i).get(k) * other.data.get(k).get(j);
}
row.add(sum);
}
resultData.add(row);
}
return new Matrix(resultData);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (List<Double> row : data) {
sb.append("[");
for (int i = 0; i < row.size(); ++i) {
sb.append(row.get(i));
if (i < row.size() - 1) {
sb.append(", ");
}
}
sb.append("]\n");
}
return sb.toString();
}
public String toStringWithPrecision(int p) {
StringBuilder sb = new StringBuilder();
double pow = Math.pow(10.0, p);
for (List<Double> row : data) {
sb.append("[");
for (int i = 0; i < row.size(); ++i) {
double r = Math.round(row.get(i) * pow) / pow;
String formatted = String.format("%." + p + "f", r);
if (formatted.equals("-0" + (p > 0 ? "." + "0".repeat(p) : ""))) {
formatted = "0" + (p > 0 ? "." + "0".repeat(p) : "");
}
sb.append(formatted);
if (i < row.size() - 1) {
sb.append(", ");
}
}
sb.append("]\n");
}
return sb.toString();
}
private static int[][] getParams(int r, int c) {
return new int[][] {
{0, r, 0, c, 0, 0},
{0, r, c, 2 * c, 0, c},
{r, 2 * r, 0, c, r, 0},
{r, 2 * r, c, 2 * c, r, c}
};
}
public Matrix[] toQuarters() {
int r = getRows() / 2;
int c = getCols() / 2;
int[][] p = getParams(r, c);
Matrix[] quarters = new Matrix[4];
for (int k = 0; k < 4; ++k) {
List<List<Double>> qData = new ArrayList<>();
for (int i = 0; i < r; i++) {
List<Double> row = new ArrayList<>();
for (int j = 0; j < c; j++) {
row.add(0.0);
}
qData.add(row);
}
for (int i = p[k][0]; i < p[k][1]; ++i) {
for (int j = p[k][2]; j < p[k][3]; ++j) {
qData.get(i - p[k][4]).set(j - p[k][5], data.get(i).get(j));
}
}
quarters[k] = new Matrix(qData);
}
return quarters;
}
public static Matrix fromQuarters(Matrix[] q) {
int r = q[0].getRows();
int c = q[0].getCols();
int[][] p = getParams(r, c);
int rows = r * 2;
int cols = c * 2;
List<List<Double>> mData = new ArrayList<>();
for (int i = 0; i < rows; i++) {
List<Double> row = new ArrayList<>();
for (int j = 0; j < cols; j++) {
row.add(0.0);
}
mData.add(row);
}
for (int k = 0; k < 4; ++k) {
for (int i = p[k][0]; i < p[k][1]; ++i) {
for (int j = p[k][2]; j < p[k][3]; ++j) {
mData.get(i).set(j, q[k].data.get(i - p[k][4]).get(j - p[k][5]));
}
}
}
return new Matrix(mData);
}
public Matrix strassen(Matrix other) {
validateSquarePowerOfTwo();
other.validateSquarePowerOfTwo();
if (getRows() != other.getRows() || getCols() != other.getCols()) {
throw new RuntimeException("Matrices must be square and of equal size for Strassen multiplication.");
}
if (getRows() == 1) {
return this.multiply(other);
}
Matrix[] qa = toQuarters();
Matrix[] qb = other.toQuarters();
Matrix p1 = qa[1].subtract(qa[3]).strassen(qb[2].add(qb[3]));
Matrix p2 = qa[0].add(qa[3]).strassen(qb[0].add(qb[3]));
Matrix p3 = qa[0].subtract(qa[2]).strassen(qb[0].add(qb[1]));
Matrix p4 = qa[0].add(qa[1]).strassen(qb[3]);
Matrix p5 = qa[0].strassen(qb[1].subtract(qb[3]));
Matrix p6 = qa[3].strassen(qb[2].subtract(qb[0]));
Matrix p7 = qa[2].add(qa[3]).strassen(qb[0]);
Matrix[] q = new Matrix[4];
q[0] = p1.add(p2).subtract(p4).add(p6);
q[1] = p4.add(p5);
q[2] = p6.add(p7);
q[3] = p2.subtract(p3).add(p5).subtract(p7);
return fromQuarters(q);
}
}
public class Main{
public static void main(String[] args) {
List<List<Double>> aData = new ArrayList<>();
aData.add(Arrays.asList(1.0, 2.0));
aData.add(Arrays.asList(3.0, 4.0));
Matrix a = new Matrix(aData);
List<List<Double>> bData = new ArrayList<>();
bData.add(Arrays.asList(5.0, 6.0));
bData.add(Arrays.asList(7.0, 8.0));
Matrix b = new Matrix(bData);
List<List<Double>> cData = new ArrayList<>();
cData.add(Arrays.asList(1.0, 1.0, 1.0, 1.0));
cData.add(Arrays.asList(2.0, 4.0, 8.0, 16.0));
cData.add(Arrays.asList(3.0, 9.0, 27.0, 81.0));
cData.add(Arrays.asList(4.0, 16.0, 64.0, 256.0));
Matrix c = new Matrix(cData);
List<List<Double>> dData = new ArrayList<>();
dData.add(Arrays.asList(4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0));
dData.add(Arrays.asList(-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0));
dData.add(Arrays.asList(3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0));
dData.add(Arrays.asList(-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0));
Matrix d = new Matrix(dData);
List<List<Double>> eData = new ArrayList<>();
eData.add(Arrays.asList(1.0, 2.0, 3.0, 4.0));
eData.add(Arrays.asList(5.0, 6.0, 7.0, 8.0));
eData.add(Arrays.asList(9.0, 10.0, 11.0, 12.0));
eData.add(Arrays.asList(13.0, 14.0, 15.0, 16.0));
Matrix e = new Matrix(eData);
List<List<Double>> fData = new ArrayList<>();
fData.add(Arrays.asList(1.0, 0.0, 0.0, 0.0));
fData.add(Arrays.asList(0.0, 1.0, 0.0, 0.0));
fData.add(Arrays.asList(0.0, 0.0, 1.0, 0.0));
fData.add(Arrays.asList(0.0, 0.0, 0.0, 1.0));
Matrix f = new Matrix(fData);
System.out.println("Using 'normal' matrix multiplication:");
System.out.println(" a * b = " + a.multiply(b));
System.out.println(" c * d = " + c.multiply(d).toStringWithPrecision(6));
System.out.println(" e * f = " + e.multiply(f));
System.out.println("\nUsing 'Strassen' matrix multiplication:");
System.out.println(" a * b = " + a.strassen(b));
System.out.println(" c * d = " + c.strassen(d).toStringWithPrecision(6));
System.out.println(" e * f = " + e.strassen(f));
}
}
- Output:
Using 'normal' matrix multiplication: a * b = [19.0, 22.0] [43.0, 50.0] c * d = [1.000000, 0.000000, 0.000000, 0.000000] [0.000000, 1.000000, 0.000000, 0.000000] [0.000000, 0.000000, 1.000000, 0.000000] [0.000000, 0.000000, 0.000000, 1.000000] e * f = [1.0, 2.0, 3.0, 4.0] [5.0, 6.0, 7.0, 8.0] [9.0, 10.0, 11.0, 12.0] [13.0, 14.0, 15.0, 16.0] Using 'Strassen' matrix multiplication: a * b = [19.0, 22.0] [43.0, 50.0] c * d = [1.000000, 0.000000, 0.000000, 0.000000] [0.000000, 1.000000, 0.000000, 0.000000] [0.000000, 0.000000, 1.000000, 0.000000] [0.000000, 0.000000, 0.000000, 1.000000] e * f = [1.0, 2.0, 3.0, 4.0] [5.0, 6.0, 7.0, 8.0] [9.0, 10.0, 11.0, 12.0] [13.0, 14.0, 15.0, 16.0]
JavaScript
Version 1
/**
* Represents the dimensions of a matrix.
* @typedef {object} Shape
* @property {number} rows - Number of rows.
* @property {number} cols - Number of columns.
*/
/**
* A matrix implemented as a wrapper around a 2D array.
*/
class Matrix {
/**
* Creates a Matrix instance.
* @param {number[][]} data - A 2D array representing the matrix data.
*/
constructor(data = []) {
if (!Array.isArray(data) || (data.length > 0 && !Array.isArray(data[0]))) {
throw new Error("Matrix data must be a 2D array.");
}
// Basic check for consistent row lengths
if (data.length > 1) {
const firstLen = data[0].length;
if (!data.every(row => row.length === firstLen)) {
throw new Error("Matrix rows must have consistent lengths.");
}
}
this.data = data;
}
/**
* Gets the dimensions (shape) of the matrix.
* @returns {Shape} An object with rows and cols properties.
*/
get shape() {
const rows = this.data.length;
const cols = rows > 0 ? this.data[0].length : 0;
return { rows, cols };
}
/**
* Creates a new Matrix assembled from nested blocks of matrices.
* @param {Matrix[][]} blocks - A 2D array of Matrix objects.
* @returns {Matrix} A new Matrix assembled from the blocks.
* @static
*/
static block(blocks) {
const newMatrixData = [];
for (const hblock of blocks) {
if (!hblock || hblock.length === 0) continue;
const numRowsInBlock = hblock[0].shape.rows; // Assume consistent rows within a hblock
for (let i = 0; i < numRowsInBlock; i++) {
let newRow = [];
for (const matrix of hblock) {
if (matrix.data[i]) { // Check if row exists
newRow = newRow.concat(matrix.data[i]);
} else {
// Handle potential inconsistencies if needed, maybe throw error or fill?
console.warn("Inconsistent row count during block assembly");
}
}
newMatrixData.push(newRow);
}
}
return new Matrix(newMatrixData);
}
/**
* Performs naive matrix multiplication (dot product).
* @param {Matrix} b - The matrix to multiply with.
* @returns {Matrix} The resulting matrix product.
*/
dot(b) {
if (!(b instanceof Matrix)) {
throw new Error("Argument must be a Matrix instance.");
}
const aShape = this.shape;
const bShape = b.shape;
if (aShape.cols !== bShape.rows) {
throw new Error(`Matrices incompatible for multiplication: ${aShape.cols} cols != ${bShape.rows} rows`);
}
const resultData = [];
for (let i = 0; i < aShape.rows; i++) {
resultData[i] = [];
for (let j = 0; j < bShape.cols; j++) {
let sum = 0;
for (let k = 0; k < aShape.cols; k++) {
sum += this.data[i][k] * b.data[k][j];
}
resultData[i][j] = sum;
}
}
return new Matrix(resultData);
}
/**
* Multiplies this matrix by another matrix (using naive multiplication).
* Equivalent to Python's __matmul__.
* @param {Matrix} b - The matrix to multiply with.
* @returns {Matrix} The resulting matrix product.
*/
multiply(b) {
return this.dot(b);
}
/**
* Adds another matrix to this matrix.
* Equivalent to Python's __add__.
* @param {Matrix} b - The matrix to add.
* @returns {Matrix} The resulting matrix sum.
*/
add(b) {
if (!(b instanceof Matrix)) {
throw new Error("Argument must be a Matrix instance.");
}
const aShape = this.shape;
const bShape = b.shape;
if (aShape.rows !== bShape.rows || aShape.cols !== bShape.cols) {
throw new Error("Matrices must have the same shape for addition.");
}
const resultData = this.data.map((row, i) =>
row.map((val, j) => val + b.data[i][j])
);
return new Matrix(resultData);
}
/**
* Subtracts another matrix from this matrix.
* Equivalent to Python's __sub__.
* @param {Matrix} b - The matrix to subtract.
* @returns {Matrix} The resulting matrix difference.
*/
subtract(b) {
if (!(b instanceof Matrix)) {
throw new Error("Argument must be a Matrix instance.");
}
const aShape = this.shape;
const bShape = b.shape;
if (aShape.rows !== bShape.rows || aShape.cols !== bShape.cols) {
throw new Error("Matrices must have the same shape for subtraction.");
}
const resultData = this.data.map((row, i) =>
row.map((val, j) => val - b.data[i][j])
);
return new Matrix(resultData);
}
/**
* Helper function to slice the matrix data.
* @param {number} rowStart - Starting row index (inclusive).
* @param {number} rowEnd - Ending row index (exclusive).
* @param {number} colStart - Starting column index (inclusive).
* @param {number} colEnd - Ending column index (exclusive).
* @returns {Matrix} A new Matrix containing the sliced data.
* @private // Indicates intended internal use
*/
_slice(rowStart, rowEnd, colStart, colEnd) {
const slicedData = this.data.slice(rowStart, rowEnd)
.map(row => row.slice(colStart, colEnd));
return new Matrix(slicedData);
}
/**
* Performs matrix multiplication using Strassen's algorithm.
* Requires square matrices whose dimensions are powers of 2.
* @param {Matrix} b - The matrix to multiply with.
* @returns {Matrix} The resulting matrix product.
*/
strassen(b) {
if (!(b instanceof Matrix)) {
throw new Error("Argument must be a Matrix instance.");
}
const aShape = this.shape;
const bShape = b.shape;
if (aShape.rows !== aShape.cols) {
throw new Error("Matrix must be square for Strassen's algorithm.");
}
if (aShape.rows !== bShape.rows || aShape.cols !== bShape.cols) {
throw new Error("Matrices must have the same shape for Strassen's algorithm.");
}
// Check if dimension is a power of 2
if (aShape.rows === 0 || (aShape.rows & (aShape.rows - 1)) !== 0) {
throw new Error("Matrix dimension must be a power of 2 for Strassen's algorithm.");
}
if (aShape.rows === 1) {
return this.dot(b); // Base case
}
const n = aShape.rows;
const p = n / 2; // Partition size
// Partition matrices
const a11 = this._slice(0, p, 0, p);
const a12 = this._slice(0, p, p, n);
const a21 = this._slice(p, n, 0, p);
const a22 = this._slice(p, n, p, n);
const b11 = b._slice(0, p, 0, p);
const b12 = b._slice(0, p, p, n);
const b21 = b._slice(p, n, 0, p);
const b22 = b._slice(p, n, p, n);
// Recursive calls (Strassen's 7 multiplications)
const m1 = (a11.add(a22)).strassen(b11.add(b22));
const m2 = (a21.add(a22)).strassen(b11);
const m3 = a11.strassen(b12.subtract(b22));
const m4 = a22.strassen(b21.subtract(b11));
const m5 = (a11.add(a12)).strassen(b22);
const m6 = (a21.subtract(a11)).strassen(b11.add(b12));
const m7 = (a12.subtract(a22)).strassen(b21.add(b22));
// Combine results
const c11 = m1.add(m4).subtract(m5).add(m7);
const c12 = m3.add(m5);
const c21 = m2.add(m4);
const c22 = m1.subtract(m2).add(m3).add(m6);
// Assemble the final matrix from blocks
return Matrix.block([[c11, c12], [c21, c22]]);
}
/**
* Rounds the elements of the matrix to a specified number of decimal places.
* @param {number} [ndigits=0] - Number of decimal places to round to. If undefined or 0, rounds to the nearest integer.
* @returns {Matrix} A new Matrix with rounded elements.
*/
round(ndigits = 0) {
const factor = Math.pow(10, ndigits);
const roundFn = ndigits > 0
? (num) => Math.round((num + Number.EPSILON) * factor) / factor
: (num) => Math.round(num);
const roundedData = this.data.map(row =>
row.map(val => roundFn(val))
);
return new Matrix(roundedData);
}
/**
* Provides a string representation of the matrix.
* @returns {string} The string representation.
*/
toString() {
const rowsStr = this.data.map(row => ` [${row.join(', ')}]`);
return `Matrix([\n${rowsStr.join(',\n')}\n])`;
}
}
// --- Examples ---
function examples() {
const a = new Matrix([
[1, 2],
[3, 4],
]);
const b = new Matrix([
[5, 6],
[7, 8],
]);
const c = new Matrix([
[1, 1, 1, 1],
[2, 4, 8, 16],
[3, 9, 27, 81],
[4, 16, 64, 256],
]);
const d = new Matrix([
[4, -3, 4 / 3, -1 / 4],
[-13 / 3, 19 / 4, -7 / 3, 11 / 24],
[3 / 2, -2, 7 / 6, -1 / 4],
[-1 / 6, 1 / 4, -1 / 6, 1 / 24],
]);
const e = new Matrix([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
]);
const f = new Matrix([ // Identity matrix
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]);
console.log("Naive matrix multiplication:");
console.log(` a * b = ${a.multiply(b)}`); // Uses toString implicitly
console.log(` c * d = ${c.multiply(d).round(2)}`); // Round near-zero elements
console.log(` e * f = ${e.multiply(f)}`);
console.log("\nStrassen's matrix multiplication:");
console.log(` a * b = ${a.strassen(b)}`);
console.log(` c * d = ${c.strassen(d).round(2)}`); // Round near-zero elements
console.log(` e * f = ${e.strassen(f)}`);
// Example of addition/subtraction
console.log("\nAddition/Subtraction:");
const sum_ab = a.add(b);
console.log(` a + b = ${sum_ab}`);
const diff_ba = b.subtract(a);
console.log(` b - a = ${diff_ba}`);
// Example of block creation (creates a 4x4 matrix from four 2x2 matrices)
console.log("\nBlock Creation:");
const blocked = Matrix.block([[a, b], [b, a]]);
console.log(` Blocked [a,b],[b,a] = ${blocked}`);
}
// Run examples
examples();
- Output:
Naive matrix multiplication: a * b = Matrix([ [19, 22], [43, 50] ]) c * d = Matrix([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ]) e * f = Matrix([ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16] ]) Strassen's matrix multiplication: a * b = Matrix([ [19, 22], [43, 50] ]) c * d = Matrix([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ]) e * f = Matrix([ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16] ]) Addition/Subtraction: a + b = Matrix([ [6, 8], [10, 12] ]) b - a = Matrix([ [4, 4], [4, 4] ]) Block Creation: Blocked [a,b],[b,a] = Matrix([ [1, 2, 5, 6], [3, 4, 7, 8], [5, 6, 1, 2], [7, 8, 3, 4] ])
Version 2
class Matrix {
/** @type {number[][]} */
data;
/** @type {number} */
rows;
/** @type {number} */
cols;
/**
* @param {number[][]} data The matrix data as a 2D array.
*/
constructor(data) {
if (!Array.isArray(data) || (data.length > 0 && !Array.isArray(data[0]))) {
throw new Error("Input data must be a 2D array.");
}
// Optional: Deep copy to prevent external modifications
this.data = data.map(row => [...row]);
this.rows = data.length;
this.cols = (this.rows > 0) ? (data[0]?.length ?? 0) : 0; // Handle empty rows gracefully
// Optional: Validate that all rows have the same length
if (this.rows > 0) {
const firstRowLength = this.cols;
for (let i = 1; i < this.rows; i++) {
if (data[i].length !== firstRowLength) {
throw new Error("All rows in the matrix must have the same length.");
}
}
}
}
/** @returns {number} */
getRows() {
return this.rows;
}
/** @returns {number} */
getCols() {
return this.cols;
}
/** @param {Matrix} other */
validateDimensions(other) {
if (this.getRows() !== other.getRows() || this.getCols() !== other.getCols()) {
throw new Error("Matrices must have the same dimensions.");
}
}
/** @param {Matrix} other */
validateMultiplication(other) {
if (this.getCols() !== other.getRows()) {
throw new Error(`Cannot multiply matrices: (${this.getRows()}x${this.getCols()}) * (${other.getRows()}x${other.getCols()})`);
}
}
validateSquarePowerOfTwo() {
if (this.getRows() !== this.getCols()) {
throw new Error("Matrix must be square for this operation.");
}
const n = this.getRows();
// Check if n is 0 or not a power of two
// (n & (n - 1)) === 0 checks if n is a power of two (or 0)
if (n === 0 || (n & (n - 1)) !== 0) {
throw new Error("Size of matrix must be a power of two for Strassen.");
}
}
/**
* Adds another matrix to this matrix.
* @param {Matrix} other The matrix to add.
* @returns {Matrix} A new matrix representing the sum.
*/
add(other) {
this.validateDimensions(other);
const result_data = Array.from({ length: this.rows }, () => Array(this.cols).fill(0.0));
for (let i = 0; i < this.rows; ++i) {
for (let j = 0; j < this.cols; ++j) {
result_data[i][j] = this.data[i][j] + other.data[i][j];
}
}
return new Matrix(result_data);
}
/**
* Subtracts another matrix from this matrix.
* @param {Matrix} other The matrix to subtract.
* @returns {Matrix} A new matrix representing the difference.
*/
subtract(other) {
this.validateDimensions(other);
const result_data = Array.from({ length: this.rows }, () => Array(this.cols).fill(0.0));
for (let i = 0; i < this.rows; ++i) {
for (let j = 0; j < this.cols; ++j) {
result_data[i][j] = this.data[i][j] - other.data[i][j];
}
}
return new Matrix(result_data);
}
/**
* Multiplies this matrix by another matrix (standard algorithm).
* @param {Matrix} other The matrix to multiply by.
* @returns {Matrix} A new matrix representing the product.
*/
multiply(other) {
this.validateMultiplication(other);
const result_data = Array.from({ length: this.rows }, () => Array(other.cols).fill(0.0));
for (let i = 0; i < this.rows; ++i) {
for (let j = 0; j < other.cols; ++j) {
let sum = 0.0;
// K loops through columns of 'this' and rows of 'other'
for (let k = 0; k < this.cols; ++k) {
sum += this.data[i][k] * other.data[k][j];
}
result_data[i][j] = sum;
}
}
return new Matrix(result_data);
}
/**
* Returns a string representation of the matrix.
* @returns {string}
*/
toString() {
return this.data.map(row => `[${row.join(', ')}]`).join('\n');
}
/**
* Returns a string representation with specified precision, handling rounding and "-0".
* @param {number} p Precision (number of decimal places).
* @returns {string}
*/
toStringWithPrecision(p) {
let resultString = "";
const pow = Math.pow(10, p);
const zeroString = (0).toFixed(p);
const negZeroString = `-${zeroString}`;
for (const row of this.data) {
resultString += "[";
for (let i = 0; i < row.length; ++i) {
let val = row[i];
// Round like C++: round(val * 10^p) / 10^p
let roundedVal = Math.round(val * pow) / pow;
// Format to fixed precision
let formattedVal = roundedVal.toFixed(p);
// Handle the "-0.00..." case that toFixed might produce after rounding
if (formattedVal === negZeroString) {
formattedVal = zeroString;
}
resultString += formattedVal;
if (i < row.length - 1) {
resultString += ", ";
}
}
resultString += "]\n"; // Add newline after each row like C++ example
}
return resultString.trimEnd(); // Remove trailing newline
}
/**
* Helper function to get quadrant slicing parameters.
* @param {number} r Half rows
* @param {number} c Half columns
* @returns {number[][]} Array of [startRow, endRow, startCol, endCol, offsetRow, offsetCol]
*/
static params(r, c) {
// [startRow, endRow, startCol, endCol, resultOffsetRow, resultOffsetCol]
return [
[0, r, 0, c, 0, 0], // Top-left quadrant (0)
[0, r, c, 2 * c, 0, c], // Top-right quadrant (1)
[r, 2 * r, 0, c, r, 0], // Bottom-left quadrant (2)
[r, 2 * r, c, 2 * c, r, c] // Bottom-right quadrant (3)
];
}
/**
* Splits the matrix into four equally sized quadrants.
* Assumes matrix dimensions are even.
* @returns {Matrix[]} An array of four matrices [TopLeft, TopRight, BottomLeft, BottomRight].
*/
toQuarters() {
const r = this.getRows() / 2;
const c = this.getCols() / 2;
if (!Number.isInteger(r) || !Number.isInteger(c)) {
throw new Error("Matrix dimensions must be even for splitting into quarters.");
}
const p = Matrix.params(r, c);
const quarters = Array(4); // Will hold 4 Matrix objects
for (let k = 0; k < 4; ++k) {
const q_data = Array.from({ length: r }, () => Array(c));
const [startRow, endRow, startCol, endCol, offsetRow, offsetCol] = p[k];
for (let i = startRow; i < endRow; ++i) {
for (let j = startCol; j < endCol; ++j) {
// Adjust indices for the smaller quarter matrix
q_data[i - offsetRow][j - offsetCol] = this.data[i][j];
}
}
quarters[k] = new Matrix(q_data);
}
return quarters; // [TopLeft, TopRight, BottomLeft, BottomRight]
}
/**
* Creates a new matrix by combining four quadrant matrices.
* @param {Matrix[]} q An array of four matrices [TopLeft, TopRight, BottomLeft, BottomRight].
* @returns {Matrix} The combined matrix.
*/
static fromQuarters(q) {
if (q.length !== 4) throw new Error("Requires exactly four quadrant matrices.");
// Basic validation: Ensure quadrants have compatible dimensions
const r = q[0].getRows();
const c = q[0].getCols();
if (q[1].getRows() !== r || q[1].getCols() !== c ||
q[2].getRows() !== r || q[2].getCols() !== c ||
q[3].getRows() !== r || q[3].getCols() !== c) {
throw new Error("Quadrant matrices must have the same dimensions.");
}
const p = Matrix.params(r, c);
const rows = r * 2;
const cols = c * 2;
const m_data = Array.from({ length: rows }, () => Array(cols));
for (let k = 0; k < 4; ++k) {
const [startRow, endRow, startCol, endCol, offsetRow, offsetCol] = p[k];
for (let i = startRow; i < endRow; ++i) {
for (let j = startCol; j < endCol; ++j) {
// Adjust indices to read from the correct quadrant
m_data[i][j] = q[k].data[i - offsetRow][j - offsetCol];
}
}
}
return new Matrix(m_data);
}
/**
* Multiplies this matrix by another using Strassen's algorithm.
* Assumes both matrices are square and their size is a power of two.
* @param {Matrix} other The matrix to multiply by.
* @returns {Matrix} The resulting matrix product.
*/
strassen(other) {
this.validateSquarePowerOfTwo();
other.validateSquarePowerOfTwo();
if (this.getRows() !== other.getRows()) { // Columns already checked by validateSquarePowerOfTwo
throw new Error("Matrices must be square and of equal size for Strassen multiplication.");
}
// Base case: If the matrix is 1x1
if (this.getRows() === 1) {
// Use standard multiplication for the 1x1 case
return this.multiply(other);
}
// Split matrices into quarters
const qa = this.toQuarters(); // [a11, a12, a21, a22]
const qb = other.toQuarters(); // [b11, b12, b21, b22]
// Calculate the 7 products recursively (P1 to P7)
const p1 = (qa[1].subtract(qa[3])).strassen(qb[2].add(qb[3])); // p1 = (a12 - a22) * (b21 + b22)
const p2 = (qa[0].add(qa[3])).strassen(qb[0].add(qb[3])); // p2 = (a11 + a22) * (b11 + b22)
const p3 = (qa[0].subtract(qa[2])).strassen(qb[0].add(qb[1])); // p3 = (a11 - a21) * (b11 + b12)
const p4 = (qa[0].add(qa[1])).strassen(qb[3]); // p4 = (a11 + a12) * b22
const p5 = qa[0].strassen(qb[1].subtract(qb[3])); // p5 = a11 * (b12 - b22)
const p6 = qa[3].strassen(qb[2].subtract(qb[0])); // p6 = a22 * (b21 - b11)
const p7 = (qa[2].add(qa[3])).strassen(qb[0]); // p7 = (a21 + a22) * b11
// Calculate the result quarters (C11, C12, C21, C22)
const c11 = p1.add(p2).subtract(p4).add(p6);
const c12 = p4.add(p5);
const c21 = p6.add(p7);
const c22 = p2.subtract(p3).add(p5).subtract(p7);
// Combine the quarters into the result matrix
return Matrix.fromQuarters([c11, c12, c21, c22]);
}
}
// --- Main execution (equivalent to C++ main) ---
function main() {
const a = new Matrix([[1.0, 2.0], [3.0, 4.0]]);
const b = new Matrix([[5.0, 6.0], [7.0, 8.0]]);
const c = new Matrix([[1.0, 1.0, 1.0, 1.0], [2.0, 4.0, 8.0, 16.0], [3.0, 9.0, 27.0, 81.0], [4.0, 16.0, 64.0, 256.0]]);
const d = new Matrix([[4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0], [-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0], [3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0], [-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0]]);
const e = new Matrix([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]);
const f = new Matrix([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]); // Identity Matrix
console.log("Using 'normal' matrix multiplication:");
console.log(` a * b = \n${a.multiply(b).toString()}`);
console.log(`\n c * d = \n${c.multiply(d).toStringWithPrecision(6)}`); // Should be close to identity
console.log(`\n e * f = \n${e.multiply(f).toString()}`); // Should be e
console.log("\nUsing 'Strassen' matrix multiplication:");
try {
console.log(` a * b = \n${a.strassen(b).toString()}`);
console.log(`\n c * d = \n${c.strassen(d).toStringWithPrecision(6)}`); // Should be close to identity
console.log(`\n e * f = \n${e.strassen(f).toString()}`); // Should be e
} catch (error) {
console.error("Strassen multiplication failed:", error.message);
}
}
// Run the main function
main();
- Output:
Using 'normal' matrix multiplication: a * b = [19, 22] [43, 50] c * d = [1.000000, 0.000000, 0.000000, 0.000000] [0.000000, 1.000000, 0.000000, 0.000000] [0.000000, 0.000000, 1.000000, 0.000000] [0.000000, 0.000000, 0.000000, 1.000000] e * f = [1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16] Using 'Strassen' matrix multiplication: a * b = [19, 22] [43, 50] c * d = [1.000000, 0.000000, 0.000000, 0.000000] [0.000000, 1.000000, 0.000000, 0.000000] [0.000000, 0.000000, 1.000000, 0.000000] [0.000000, 0.000000, 0.000000, 1.000000] e * f = [1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]
jq
Works with gojq, the Go implementation of jq
### Generic utilities
def sigma(stream): reduce stream as $x (0; . + $x);
# Hilbert-Schmidt norm:
def frobenius:
sigma( flatten[] | .*. ) | sqrt;
### Matrices
# Create an m x n matrix
def matrix(m; n; init):
if m == 0 then []
elif m == 1 then [[range(0;n) | init]]
elif m > 0 then
matrix(1;n;init) as $row
| [range(0;m) | $row ]
else error("matrix\(m);_;_) invalid")
end;
# A and B must be (multi-dimensional) vectors of the same shape
def vector_add($A;$B):
if ($A|type) == "array"
then reduce range(0; $A|length) as $i ([];
. + [vector_add($A[$i]; $B[$i])] )
else $A + $B
end;
def vector_negate:
if type == "array"
then map(vector_negate)
else - .
end;
def vector_subtract($A;$B):
vector_add($A; $B|vector_negate);
# A should be m by n; and B n by p
# Pre-allocating the resultant matrix results in a very small net speedup.
def multiply($A; $B):
($A|length) as $m
| ($A[0]|length) as $n
| ($B[0]|length) as $p
| reduce range(0; $m) as $i
(matrix($m; $p; 0); # initialize to avoid resizing
reduce range(0;$p) as $j (.;
.[$i][$j] = reduce range(0;$n) as $k
(0;
# Cij = innerproduct of row i, column j
. + $A[$i][$k] * $B[$k][$j] ))) ;
def submatrix($m1; $m2; $n1; $n2):
.[$m1:$m2] | map( .[$n1:$n2]);
def submatrix($A; $m1; $m2; $n1; $n2):
$A | submatrix($m1; $m2; $n1; $n2);
def rowwise_extend($A;$B):
reduce range(0; $A|length) as $i ([]; . + [$A[$i] + $B[$i]]);
### Strassen multiplication of n*n square matrices where n is a power of 2.
def Strassen($A; $B):
($A[0]|length) as $n
| if $n == 1 then multiply($A; $B)
else
submatrix($A; 0; $n/2; 0; $n/2) as $A11
| submatrix($A; 0; $n/2; $n/2; $n) as $A12
| submatrix($A; $n/2; $n; 0; $n/2) as $A21
| submatrix($A; $n/2; $n; $n/2; $n) as $A22
| submatrix($B; 0; $n/2; 0; $n/2) as $B11
| submatrix($B; 0; $n/2; $n/2; $n) as $B12
| submatrix($B; $n/2; $n; 0; $n/2) as $B21
| submatrix($B; $n/2; $n; $n/2+0; $n) as $B22
| Strassen( vector_subtract($A12; $A22); vector_add($B21; $B22)) as $P1
| Strassen( vector_add($A11; $A22); vector_add($B11; $B22)) as $P2
| Strassen( vector_subtract($A11; $A21); vector_add($B11; $B12)) as $P3
| Strassen( vector_add($A11; $A12); $B22) as $P4
| Strassen( $A11; vector_subtract($B12; $B22)) as $P5
| Strassen( $A22; vector_subtract($B21; $B11) ) as $P6
| Strassen( vector_add($A21; $A22); $B11) as $P7
| vector_add(vector_subtract(vector_add($P1; $P2); $P4); $P6) as $C11
| vector_add($P4; $P5) as $C12
| vector_add($P6; $P7) as $C21
| vector_add(vector_subtract($P2; $P3); vector_subtract($P5;$P7)) as $C22
# [C11 C12; C21 C22]
| rowwise_extend($C11; $C12) + rowwise_extend($C21; $C22)
end
;
# ## Examples
def A: [[1, 2], [3, 4]];
def B: [[5, 6], [7, 8]];
def C: [[1, 1, 1, 1],
[2, 4, 8, 16],
[3, 9, 27, 81],
[4, 16, 64, 256]];
def D: [[4, -3, 4/3, -1/4],
[-13/3, 19/4, -7/3, 11/24],
[3/2, -2, 7/6, -1/4],
[-1/6, 1/4, -1/6, 1/24]];
def E: [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]];
def F: [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [ 0, 0, 0, 1]];
def r: (2|sqrt)/2;
def R: [[r, r], [-r, r]];
"A*B == Strassen(A;B): \((multiply(A; B) == Strassen(A; B)))",
"Frobenius norm for C*D - Strassen(C;D): \(vector_subtract(multiply(C; D); Strassen(C; D)) | frobenius)",
"E*F == Strassen(E;F): \(multiply(E; F) == Strassen( E; F))",
"R*R == Strassen(R;R): \(multiply(R ; R) == Strassen(R; R))"
- Output:
A*B == Strassen(A;B): true Frobenius norm for C*D - Strassen(C;D): 3.4360172596923834e-13 E*F == Strassen(E;F): true R*R == Strassen(R;R): true
Julia
With dynamic padding
Because Julia uses column major in matrices, sometimes the code uses the adjoint of a matrix in order to match examples as written.
"""
Strassen's matrix multiplication algorithm.
Use dynamic padding in order to reduce required auxiliary memory.
"""
function strassen(x::Matrix, y::Matrix)
# Check that the sizes of these matrices match.
(r1, c1) = size(x)
(r2, c2) = size(y)
if c1 != r2
error("Multiplying $r1 x $c1 and $r2 x $c2 matrix: dimensions do not match.")
end
# Put a matrix into the top left of a matrix of zeros.
# `rows` and `cols` are the dimensions of the output matrix.
function embed(mat, rows, cols)
# If the matrix is already of the right dimensions, don't allocate new memory.
(r, c) = size(mat)
if (r, c) == (rows, cols)
return mat
end
# Pad the matrix with zeros to be the right size.
out = zeros(Int, rows, cols)
out[1:r, 1:c] = mat
out
end
# Make sure both matrices are the same size.
# This is exclusively for simplicity:
# this algorithm can be implemented with matrices of different sizes.
r = max(r1, r2); c = max(c1, c2)
x = embed(x, r, c)
y = embed(y, r, c)
# Our recursive multiplication function.
function block_mult(a, b, rows, cols)
# For small matrices, resort to naive multiplication.
# if rows <= 128 || cols <= 128
if rows == 1 && cols == 1
# if rows == 2 && cols == 2
return a * b
end
# Apply dynamic padding.
if rows % 2 == 1 && cols % 2 == 1
a = embed(a, rows + 1, cols + 1)
b = embed(b, rows + 1, cols + 1)
elseif rows % 2 == 1
a = embed(a, rows + 1, cols)
b = embed(b, rows + 1, cols)
elseif cols % 2 == 1
a = embed(a, rows, cols + 1)
b = embed(b, rows, cols + 1)
end
half_rows = Int(size(a, 1) / 2)
half_cols = Int(size(a, 2) / 2)
# Subdivide input matrices.
a11 = a[1:half_rows, 1:half_cols]
b11 = b[1:half_rows, 1:half_cols]
a12 = a[1:half_rows, half_cols+1:size(a, 2)]
b12 = b[1:half_rows, half_cols+1:size(a, 2)]
a21 = a[half_rows+1:size(a, 1), 1:half_cols]
b21 = b[half_rows+1:size(a, 1), 1:half_cols]
a22 = a[half_rows+1:size(a, 1), half_cols+1:size(a, 2)]
b22 = b[half_rows+1:size(a, 1), half_cols+1:size(a, 2)]
# Compute intermediate values.
multip(x, y) = block_mult(x, y, half_rows, half_cols)
m1 = multip(a11 + a22, b11 + b22)
m2 = multip(a21 + a22, b11)
m3 = multip(a11, b12 - b22)
m4 = multip(a22, b21 - b11)
m5 = multip(a11 + a12, b22)
m6 = multip(a21 - a11, b11 + b12)
m7 = multip(a12 - a22, b21 + b22)
# Combine intermediate values into the output.
c11 = m1 + m4 - m5 + m7
c12 = m3 + m5
c21 = m2 + m4
c22 = m1 - m2 + m3 + m6
# Crop output to the desired size (undo dynamic padding).
out = [c11 c12; c21 c22]
out[1:rows, 1:cols]
end
block_mult(x, y, r, c)
end
const A = [[1, 2] [3, 4]]
const B = [[5, 6] [7, 8]]
const C = [[1, 1, 1, 1] [2, 4, 8, 16] [3, 9, 27, 81] [4, 16, 64, 256]]
const D = [[4, -3, 4/3, -1/4] [-13/3, 19/4, -7/3, 11/24] [3/2, -2, 7/6, -1/4] [-1/6, 1/4, -1/6, 1/24]]
const E = [[1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]]
const F = [[1, 0, 0, 0] [0, 1, 0, 0] [0, 0, 1, 0] [0, 0, 0, 1]]
""" For pretty printing: change matrix to integer if it is within 0.00000001 of an integer """
intprint(s, mat) = println(s, map(x -> Int(round(x, digits=8)), mat)')
intprint("Regular multiply: ", A' * B')
intprint("Strassen multiply: ", strassen(Matrix(A'), Matrix(B')))
intprint("Regular multiply: ", C * D)
intprint("Strassen multiply: ", strassen(C, D))
intprint("Regular multiply: ", E * F)
intprint("Strassen multiply: ", strassen(E, F))
const r = sqrt(2)/2
const R = [[r, r] [-r, r]]
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", strassen(R,R))
- Output:
Regular multiply: [19 43; 22 50] Strassen multiply: [19 43; 22 50] Regular multiply: [1 0 0 0; 0 1 0 0; 0 0 1 0; 0 0 0 1] Strassen multiply: [1 0 0 0; 0 1 0 0; 0 0 1 0; 0 0 0 1] Regular multiply: [1 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16] Strassen multiply: [1 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16] Regular multiply: [0 1; -1 0] Strassen multiply: [0 1; -1 0]
Recursive
Output is the same as the dynamically padded version.
function Strassen(A, B)
n = size(A, 1)
if n == 1
return A * B
end
@views A11 = A[1:n÷2, 1:n÷2]
@views A12 = A[1:n÷2, n÷2+1:n]
@views A21 = A[n÷2+1:n, 1:n÷2]
@views A22 = A[n÷2+1:n, n÷2+1:n]
@views B11 = B[1:n÷2, 1:n÷2]
@views B12 = B[1:n÷2, n÷2+1:n]
@views B21 = B[n÷2+1:n, 1:n÷2]
@views B22 = B[n÷2+1:n, n÷2+1:n]
P1 = Strassen(A12 - A22, B21 + B22)
P2 = Strassen(A11 + A22, B11 + B22)
P3 = Strassen(A11 - A21, B11 + B12)
P4 = Strassen(A11 + A12, B22)
P5 = Strassen(A11, B12 - B22)
P6 = Strassen(A22, B21 - B11)
P7 = Strassen(A21 + A22, B11)
C11 = P1 + P2 - P4 + P6
C12 = P4 + P5
C21 = P6 + P7
C22 = P2 - P3 + P5 - P7
return [C11 C12; C21 C22]
end
const A = [[1, 2] [3, 4]]
const B = [[5, 6] [7, 8]]
const C = [[1, 1, 1, 1] [2, 4, 8, 16] [3, 9, 27, 81] [4, 16, 64, 256]]
const D = [[4, -3, 4/3, -1/4] [-13/3, 19/4, -7/3, 11/24] [3/2, -2, 7/6, -1/4] [-1/6, 1/4, -1/6, 1/24]]
const E = [[1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]]
const F = [[1, 0, 0, 0] [0, 1, 0, 0] [0, 0, 1, 0] [0, 0, 0, 1]]
intprint(s, mat) = println(s, map(x -> Int(round(x, digits=8)), mat)')
intprint("Regular multiply: ", A' * B')
intprint("Strassen multiply: ", Strassen(Matrix(A'), Matrix(B')))
intprint("Regular multiply: ", C * D)
intprint("Strassen multiply: ", Strassen(C, D))
intprint("Regular multiply: ", E * F)
intprint("Strassen multiply: ", Strassen(E, F))
const r = sqrt(2)/2
const R = [[r, r] [-r, r]]
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", Strassen(R,R))
MATLAB
clear all;close all;clc;
A = [1, 2; 3, 4];
B = [5, 6; 7, 8];
C = [1, 1, 1, 1; 2, 4, 8, 16; 3, 9, 27, 81; 4, 16, 64, 256];
D = [4, -3, 4/3, -1/4; -13/3, 19/4, -7/3, 11/24; 3/2, -2, 7/6, -1/4; -1/6, 1/4, -1/6, 1/24];
E = [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12; 13, 14, 15, 16];
F = eye(4);
disp('Regular multiply: ');
disp(A' * B');
disp('Strassen multiply: ');
disp(Strassen(A', B'));
disp('Regular multiply: ');
disp(C * D);
disp('Strassen multiply: ');
disp(Strassen(C, D));
disp('Regular multiply: ');
disp(E * F);
disp('Strassen multiply: ');
disp(Strassen(E, F));
r = sqrt(2)/2;
R = [r, r; -r, r];
disp('Regular multiply: ');
disp(R * R);
disp('Strassen multiply: ');
disp(Strassen(R, R));
function C = Strassen(A, B)
n = size(A, 1);
if n == 1
C = A * B;
return
end
A11 = A(1:n/2, 1:n/2);
A12 = A(1:n/2, n/2+1:n);
A21 = A(n/2+1:n, 1:n/2);
A22 = A(n/2+1:n, n/2+1:n);
B11 = B(1:n/2, 1:n/2);
B12 = B(1:n/2, n/2+1:n);
B21 = B(n/2+1:n, 1:n/2);
B22 = B(n/2+1:n, n/2+1:n);
P1 = Strassen(A12 - A22, B21 + B22);
P2 = Strassen(A11 + A22, B11 + B22);
P3 = Strassen(A11 - A21, B11 + B12);
P4 = Strassen(A11 + A12, B22);
P5 = Strassen(A11, B12 - B22);
P6 = Strassen(A22, B21 - B11);
P7 = Strassen(A21 + A22, B11);
C11 = P1 + P2 - P4 + P6;
C12 = P4 + P5;
C21 = P6 + P7;
C22 = P2 - P3 + P5 - P7;
C = [C11 C12; C21 C22];
end
- Output:
Regular multiply: 23 31 34 46 Strassen multiply: 23 31 34 46 Regular multiply: 1.0000 0 -0.0000 -0.0000 0.0000 1.0000 -0.0000 -0.0000 0 0 1.0000 0 0.0000 0 0.0000 1.0000 Strassen multiply: 1.0000 0.0000 -0.0000 -0.0000 -0.0000 1.0000 -0.0000 0.0000 0 0 1.0000 0.0000 0 0 -0.0000 1.0000 Regular multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Strassen multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Regular multiply: 0 1.0000 -1.0000 0 Strassen multiply: 0 1.0000 -1.0000 0
Nim
import math, sequtils, strutils
type Matrix = seq[seq[float]]
template rows(m: Matrix): Positive = m.len
template cols(m: Matrix): Positive = m[0].len
func `+`(m1, m2: Matrix): Matrix =
doAssert m1.rows == m2.rows and m1.cols == m2.cols, "Matrices must have the same dimensions."
result = newSeqWith(m1.rows, newSeq[float](m1.cols))
for i in 0..<m1.rows:
for j in 0..<m1.cols:
result[i][j] = m1[i][j] + m2[i][j]
func `-`(m1, m2: Matrix): Matrix =
doAssert m1.rows == m2.rows and m1.cols == m2.cols, "Matrices must have the same dimensions."
result = newSeqWith(m1.rows, newSeq[float](m1.cols))
for i in 0..<m1.rows:
for j in 0..<m1.cols:
result[i][j] = m1[i][j] - m2[i][j]
func `*`(m1, m2: Matrix): Matrix =
doAssert m1.cols == m2.rows, "Cannot multiply these matrices."
result = newSeqWith(m1.rows, newSeq[float](m2.cols))
for i in 0..<m1.rows:
for j in 0..<m2.cols:
for k in 0..<m2.rows:
result[i][j] += m1[i][k] * m2[k][j]
func toString(m: Matrix; p: Natural): string =
## Round all elements to 'p' places.
var res: seq[string]
let pow = 10.0^p
for row in m:
var line: seq[string]
for val in row:
let r = round(val * pow) / pow
var s = r.formatFloat(precision = -1)
if s == "-0": s = "0"
line.add s
res.add '[' & line.join(" ") & ']'
result = '[' & res.join(" ") & ']'
func params(r, c: int): array[4, array[6, int]] =
[[0, r, 0, c, 0, 0],
[0, r, c, 2 * c, 0, c],
[r, 2 * r, 0, c, r, 0],
[r, 2 * r, c, 2 * c, r, c]]
func toQuarters(m: Matrix): array[4, Matrix] =
let
r = m.rows() div 2
c = m.cols() div 2
p = params(r, c)
for k in 0..3:
var q = newSeqWith(r, newSeq[float](c))
for i in p[k][0]..<p[k][1]:
for j in p[k][2]..<p[k][3]:
q[i-p[k][4]][j-p[k][5]] = m[i][j]
result[k] = move(q)
func fromQuarters(q: array[4, Matrix]): Matrix =
var
r = q[0].rows
c = q[0].cols
let p = params(r, c)
r *= 2
c *= 2
result = newSeqWith(r, newSeq[float](c))
for k in 0..3:
for i in p[k][0]..<p[k][1]:
for j in p[k][2]..<p[k][3]:
result[i][j] = q[k][i-p[k][4]][j-p[k][5]]
func strassen(a, b: Matrix): Matrix =
doAssert a.rows == a.cols() and b.rows == b.cols and a.rows == b.rows,
"Matrices must be square and of equal size."
doAssert a.rows != 0 and (a.rows and (a.rows-1)) == 0,
"Size of matrices must be a power of two."
if a.rows == 1: return a * b
let
qa = a.toQuarters()
qb = b.toQuarters()
p1 = strassen(qa[1] - qa[3], qb[2] + qb[3])
p2 = strassen(qa[0] + qa[3], qb[0] + qb[3])
p3 = strassen(qa[0] - qa[2], qb[0] + qb[1])
p4 = strassen(qa[0] + qa[1], qb[3])
p5 = strassen(qa[0], qb[1] - qb[3])
p6 = strassen(qa[3], qb[2] - qb[0])
p7 = strassen(qa[2] + qa[3], qb[0])
var q: array[4, Matrix]
q[0] = p1 + p2 - p4 + p6
q[1] = p4 + p5
q[2] = p6 + p7
q[3] = p2 - p3 + p5 - p7
result = fromQuarters(q)
when isMainModule:
let
a = @[@[float 1, 2],
@[float 3, 4]]
b = @[@[float 5, 6],
@[float 7, 8]]
c = @[@[float 1, 1, 1, 1],
@[float 2, 4, 8, 16],
@[float 3, 9, 27, 81],
@[float 4, 16, 64, 256]]
d = @[@[4.0, -3, 4/3, -1/4],
@[-13/3, 19/4, -7/3, 11/24],
@[3/2, -2, 7/6, -1/4],
@[-1/6, 1/4, -1/6, 1/24]]
e = @[@[float 1, 2, 3, 4],
@[float 5, 6, 7, 8],
@[float 9, 10, 11, 12],
@[float 13, 14, 15, 16]]
f = @[@[float 1, 0, 0, 0],
@[float 0, 1, 0, 0],
@[float 0, 0, 1, 0],
@[float 0, 0, 0, 1]]
echo "Using 'normal' matrix multiplication:"
echo " a * b = ", (a * b).toString(10)
echo " c * d = ", (c * d).toString(6)
echo " e * f = ", (e * f).toString(10)
echo "\nUsing 'Strassen' matrix multiplication:"
echo " a * b = ", strassen(a, b).toString(10)
echo " c * d = ", strassen(c, d).toString(6)
echo " e * f = ", strassen(e, f).toString(10)
- Output:
Using 'normal' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]] Using 'Strassen' matrix multiplication: a * b = [[19 22] [43 50]] c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]
Phix
As noted on wp, you could pad with zeroes, and strip them on exit, instead of crashing for non-square 2n matrices.
with javascript_semantics function strassen(sequence a, b) integer l = length(a) if length(a[1])!=l or length(b)!=l or length(b[1])!=l then crash("two equal square matrices only") end if if l=1 then return sq_mul(a,b) end if if remainder(l,1) then crash("2^n matrices only") end if integer h = l/2 sequence {a11,a12,a21,a22,b11,b12,b21,b22} = repeat(repeat(repeat(0,h),h),8) for i=1 to h do for j=1 to h do a11[i][j] = a[i][j] a12[i][j] = a[i][j+h] a21[i][j] = a[i+h][j] a22[i][j] = a[i+h][j+h] b11[i][j] = b[i][j] b12[i][j] = b[i][j+h] b21[i][j] = b[i+h][j] b22[i][j] = b[i+h][j+h] end for end for sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)), p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)), p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)), p4 = strassen(sq_add(a11,a12), b22), p5 = strassen(a11, sq_sub(b12,b22)), p6 = strassen(a22, sq_sub(b21,b11)), p7 = strassen(sq_add(a21,a22), b11), c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6), c12 = sq_add(p4,p5), c21 = sq_add(p6,p7), c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7), c = repeat(repeat(0,l),l) for i=1 to h do for j=1 to h do c[i][j] = c11[i][j] c[i][j+h] = c12[i][j] c[i+h][j] = c21[i][j] c[i+h][j+h] = c22[i][j] end for end for return c end function ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false}) constant A = {{1,2}, {3,4}}, B = {{5,6}, {7,8}} pp(strassen(A,B)) constant C = { { 1, 1, 1, 1 }, { 2, 4, 8, 16 }, { 3, 9, 27, 81 }, { 4, 16, 64, 256 }}, D = { { 4, -3, 4/3, -1/ 4 }, {-13/3, 19/4, -7/3, 11/24 }, { 3/2, -2, 7/6, -1/ 4 }, { -1/6, 1/4, -1/6, 1/24 }} pp(strassen(C,D)) constant F = {{ 1, 2, 3, 4}, { 5, 6, 7, 8}, { 9,10,11,12}, {13,14,15,16}}, G = {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}} pp(strassen(F,G)) constant r = sqrt(2)/2, R = {{ r,r}, {-r,r}} pp(strassen(R,R))
- Output:
Matches that of Matrix_multiplication#Phix, when given the same inputs. Note that a few "-0" show up in the second one (the identity matrix) under pwa/p2js.
{{ 19, 22}, { 43, 50}} {{ 1, 0, 0, 0}, { 0, 1, 0, 0}, { 0, 0, 1, 0}, { 0, 0, 0, 1}} {{ 1, 2, 3, 4}, { 5, 6, 7, 8}, { 9, 10, 11, 12}, { 13, 14, 15, 16}} {{ 0, 1}, { -1, 0}}
Python
"""Matrix multiplication using Strassen's algorithm. Requires Python >= 3.7."""
from __future__ import annotations
from itertools import chain
from typing import List
from typing import NamedTuple
from typing import Optional
class Shape(NamedTuple):
rows: int
cols: int
class Matrix(List):
"""A matrix implemented as a two-dimensional list."""
@classmethod
def block(cls, blocks) -> Matrix:
"""Return a new Matrix assembled from nested blocks."""
m = Matrix()
for hblock in blocks:
for row in zip(*hblock):
m.append(list(chain.from_iterable(row)))
return m
def dot(self, b: Matrix) -> Matrix:
"""Return a new Matrix that is the product of this matrix and matrix `b`.
Uses 'simple' or 'naive' matrix multiplication."""
assert self.shape.cols == b.shape.rows
m = Matrix()
for row in self:
new_row = []
for c in range(len(b[0])):
col = [b[r][c] for r in range(len(b))]
new_row.append(sum(x * y for x, y in zip(row, col)))
m.append(new_row)
return m
def __matmul__(self, b: Matrix) -> Matrix:
return self.dot(b)
def __add__(self, b: Matrix) -> Matrix:
"""Matrix addition."""
assert self.shape == b.shape
rows, cols = self.shape
return Matrix(
[[self[i][j] + b[i][j] for j in range(cols)] for i in range(rows)]
)
def __sub__(self, b: Matrix) -> Matrix:
"""Matrix subtraction."""
assert self.shape == b.shape
rows, cols = self.shape
return Matrix(
[[self[i][j] - b[i][j] for j in range(cols)] for i in range(rows)]
)
def strassen(self, b: Matrix) -> Matrix:
"""Return a new Matrix that is the product of this matrix and matrix `b`.
Uses strassen algorithm."""
rows, cols = self.shape
assert rows == cols, "matrices must be square"
assert self.shape == b.shape, "matrices must be the same shape"
assert rows and (rows & rows - 1) == 0, "shape must be a power of 2"
if rows == 1:
return self.dot(b)
p = rows // 2 # partition
a11 = Matrix([n[:p] for n in self[:p]])
a12 = Matrix([n[p:] for n in self[:p]])
a21 = Matrix([n[:p] for n in self[p:]])
a22 = Matrix([n[p:] for n in self[p:]])
b11 = Matrix([n[:p] for n in b[:p]])
b12 = Matrix([n[p:] for n in b[:p]])
b21 = Matrix([n[:p] for n in b[p:]])
b22 = Matrix([n[p:] for n in b[p:]])
m1 = (a11 + a22).strassen(b11 + b22)
m2 = (a21 + a22).strassen(b11)
m3 = a11.strassen(b12 - b22)
m4 = a22.strassen(b21 - b11)
m5 = (a11 + a12).strassen(b22)
m6 = (a21 - a11).strassen(b11 + b12)
m7 = (a12 - a22).strassen(b21 + b22)
c11 = m1 + m4 - m5 + m7
c12 = m3 + m5
c21 = m2 + m4
c22 = m1 - m2 + m3 + m6
return Matrix.block([[c11, c12], [c21, c22]])
def round(self, ndigits: Optional[int] = None) -> Matrix:
return Matrix([[round(i, ndigits) for i in row] for row in self])
@property
def shape(self) -> Shape:
cols = len(self[0]) if self else 0
return Shape(len(self), cols)
def examples():
a = Matrix(
[
[1, 2],
[3, 4],
]
)
b = Matrix(
[
[5, 6],
[7, 8],
]
)
c = Matrix(
[
[1, 1, 1, 1],
[2, 4, 8, 16],
[3, 9, 27, 81],
[4, 16, 64, 256],
]
)
d = Matrix(
[
[4, -3, 4 / 3, -1 / 4],
[-13 / 3, 19 / 4, -7 / 3, 11 / 24],
[3 / 2, -2, 7 / 6, -1 / 4],
[-1 / 6, 1 / 4, -1 / 6, 1 / 24],
]
)
e = Matrix(
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
]
)
f = Matrix(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
print("Naive matrix multiplication:")
print(f" a * b = {a @ b}")
print(f" c * d = {(c @ d).round()}")
print(f" e * f = {e @ f}")
print("Strassen's matrix multiplication:")
print(f" a * b = {a.strassen(b)}")
print(f" c * d = {c.strassen(d).round()}")
print(f" e * f = {e.strassen(f)}")
if __name__ == "__main__":
examples()
- Output:
Naive matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] Strassen's matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
Raku
Special thanks go to the module author, Fernando Santagata, on showing how to deal with a pass-by-value case.
# 20210126 Raku programming solution
use Math::Libgsl::Constants;
use Math::Libgsl::Matrix;
use Math::Libgsl::BLAS;
my @M;
sub SQM (\in) { # create custom sq matrix from CSV
die "Not a ■" if (my \L = in.split(/\,/)).sqrt != (my \size = L.sqrt.Int);
my Math::Libgsl::Matrix \M .= new: size, size;
for ^size Z L.rotor(size) -> ($i, @row) { M.set-row: $i, @row }
M
}
sub infix:<⊗>(\x,\y) { # custom multiplication
my Math::Libgsl::Matrix \z .= new: x.size1, x.size2;
dgemm(CblasNoTrans, CblasNoTrans, 1, x, y, 1, z);
z
}
sub infix:<⊕>(\x,\y) { # custom addition
my Math::Libgsl::Matrix \z .= new: x.size1, x.size2;
z.copy(x).add(y)
}
sub infix:<⊖>(\x,\y) { # custom subtraction
my Math::Libgsl::Matrix \z .= new: x.size1, x.size2;
z.copy(x).sub(y)
}
sub Strassen($A, $B) {
{ return $A ⊗ $B } if (my \n = $A.size1) == 1;
my Math::Libgsl::Matrix ($A11,$A12,$A21,$A22,$B11,$B12,$B21,$B22);
my Math::Libgsl::Matrix ($P1,$P2,$P3,$P4,$P5,$P6,$P7);
my Math::Libgsl::Matrix::View ($mv1,$mv2,$mv3,$mv4,$mv5,$mv6,$mv7,$mv8);
($mv1,$mv2,$mv3,$mv4,$mv5,$mv6,$mv7,$mv8)».=new ;
my \half = n div 2; # dimension of quarter submatrices
$A11 = $mv1.submatrix($A, 0,0, half,half); #
$A12 = $mv2.submatrix($A, 0,half, half,half); # create quarter views
$A21 = $mv3.submatrix($A, half,0, half,half); # of operand matrices
$A22 = $mv4.submatrix($A, half,half, half,half); #
$B11 = $mv5.submatrix($B, 0,0, half,half); # 11 12
$B12 = $mv6.submatrix($B, 0,half, half,half); #
$B21 = $mv7.submatrix($B, half,0, half,half); # 21 22
$B22 = $mv8.submatrix($B, half,half, half,half); #
$P1 = Strassen($A12 ⊖ $A22, $B21 ⊕ $B22);
$P2 = Strassen($A11 ⊕ $A22, $B11 ⊕ $B22);
$P3 = Strassen($A11 ⊖ $A21, $B11 ⊕ $B12);
$P4 = Strassen($A11 ⊕ $A12, $B22 );
$P5 = Strassen($A11, $B12 ⊖ $B22);
$P6 = Strassen($A22, $B21 ⊖ $B11);
$P7 = Strassen($A21 ⊕ $A22, $B11 );
my Math::Libgsl::Matrix $C .= new: n, n; # Build C from
my Math::Libgsl::Matrix::View ($mvC11,$mvC12,$mvC21,$mvC22); # C11 C12
($mvC11,$mvC12,$mvC21,$mvC22)».=new ; # C21 C22
given $mvC11.submatrix($C, 0,0, half,half) { .add: (($P1 ⊕ $P2) ⊖ $P4) ⊕ $P6 };
given $mvC12.submatrix($C, 0,half, half,half) { .add: $P4 ⊕ $P5 };
given $mvC21.submatrix($C, half,0, half,half) { .add: $P6 ⊕ $P7 };
given $mvC22.submatrix($C, half,half, half,half) { .add: (($P2 ⊖ $P3) ⊕ $P5) ⊖ $P7 };
$C
}
for $=pod[0].contents { next if /^\n$/ ; @M.append: SQM $_ }
for @M.rotor(2) {
my $product = @_[0] ⊗ @_[1];
# $product.get-row($_)».round(1).fmt('%2d').put for ^$product.size1;
say "Regular multiply:";
$product.get-row($_)».fmt('%.10g').put for ^$product.size1;
$product = Strassen @_[0], @_[1];
say "Strassen multiply:";
$product.get-row($_)».fmt('%.10g').put for ^$product.size1;
}
=begin code
1,2,3,4
5,6,7,8
1,1,1,1,2,4,8,16,3,9,27,81,4,16,64,256
4,-3,4/3,-1/4,-13/3,19/4,-7/3,11/24,3/2,-2,7/6,-1/4,-1/6,1/4,-1/6,1/24
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16
1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1
=end code
- Output:
Regular multiply: 19 22 43 50 Strassen multiply: 19 22 43 50 Regular multiply: 1 0 -1.387778781e-16 -2.081668171e-17 1.33226763e-15 1 -4.440892099e-16 -1.110223025e-16 0 0 1 0 7.105427358e-15 0 7.105427358e-15 1 Strassen multiply: 1 5.684341886e-14 -2.664535259e-15 -1.110223025e-16 -1.136868377e-13 1 -7.105427358e-15 2.220446049e-15 0 0 1 5.684341886e-14 0 0 -2.273736754e-13 1 Regular multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Strassen multiply: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
Rust
use std::fmt;
use std::ops::{Add, Mul, Sub};
#[derive(Debug, Clone)]
struct Matrix {
data: Vec<Vec<f64>>,
rows: usize,
cols: usize,
}
impl Matrix {
fn new(data: Vec<Vec<f64>>) -> Self {
let rows = data.len();
let cols = if rows > 0 { data[0].len() } else { 0 };
Matrix { data, rows, cols }
}
fn rows(&self) -> usize {
self.rows
}
fn cols(&self) -> usize {
self.cols
}
fn validate_dimensions(&self, other: &Matrix) {
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!("Matrices must have the same dimensions.");
}
}
fn validate_multiplication(&self, other: &Matrix) {
if self.cols() != other.rows() {
panic!("Cannot multiply these matrices.");
}
}
fn validate_square_power_of_two(&self) {
if self.rows() != self.cols() {
panic!("Matrix must be square.");
}
if self.rows() == 0 || (self.rows() & (self.rows() - 1)) != 0 {
panic!("Size of matrix must be a power of two.");
}
}
}
impl Add for Matrix {
type Output = Self;
fn add(self, other: Self) -> Self {
self.validate_dimensions(&other);
let mut result_data = Vec::with_capacity(self.rows());
for i in 0..self.rows() {
let mut row = Vec::with_capacity(self.cols());
for j in 0..self.cols() {
row.push(self.data[i][j] + other.data[i][j]);
}
result_data.push(row);
}
Matrix::new(result_data)
}
}
impl Sub for Matrix {
type Output = Self;
fn sub(self, other: Self) -> Self {
self.validate_dimensions(&other);
let mut result_data = Vec::with_capacity(self.rows());
for i in 0..self.rows() {
let mut row = Vec::with_capacity(self.cols());
for j in 0..self.cols() {
row.push(self.data[i][j] - other.data[i][j]);
}
result_data.push(row);
}
Matrix::new(result_data)
}
}
impl Mul for Matrix {
type Output = Self;
fn mul(self, other: Self) -> Self {
self.validate_multiplication(&other);
let mut result_data = Vec::with_capacity(self.rows());
for i in 0..self.rows() {
let mut row = Vec::with_capacity(other.cols());
for j in 0..other.cols() {
let mut sum = 0.0;
for k in 0..other.rows() {
sum += self.data[i][k] * other.data[k][j];
}
row.push(sum);
}
result_data.push(row);
}
Matrix::new(result_data)
}
}
impl fmt::Display for Matrix {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut s = String::new();
for row in &self.data {
s.push_str(&format!("{:?}\n", row));
}
write!(f, "{}", s)
}
}
impl Matrix {
fn to_string_with_precision(&self, p: usize) -> String {
let mut s = String::new();
let pow = 10.0_f64.powi(p as i32);
for row in &self.data {
let mut t = Vec::new();
for &val in row {
let r = (val * pow).round() / pow;
let formatted = format!("{}", r);
if formatted == "-0" {
t.push("0".to_string());
} else {
t.push(formatted);
}
}
s.push_str(&format!("{:?}\n", t));
}
s
}
fn params(r: usize, c: usize) -> [[usize; 6]; 4] {
[
[0, r, 0, c, 0, 0],
[0, r, c, 2 * c, 0, c],
[r, 2 * r, 0, c, r, 0],
[r, 2 * r, c, 2 * c, r, c],
]
}
fn to_quarters(&self) -> [Matrix; 4] {
let r = self.rows() / 2;
let c = self.cols() / 2;
let p = Matrix::params(r, c);
let mut quarters: [Matrix; 4] = [
Matrix::new(vec![vec![0.0; c]; r]),
Matrix::new(vec![vec![0.0; c]; r]),
Matrix::new(vec![vec![0.0; c]; r]),
Matrix::new(vec![vec![0.0; c]; r]),
];
for k in 0..4 {
let mut q_data = Vec::with_capacity(r);
for i in p[k][0]..p[k][1] {
let mut row = Vec::with_capacity(c);
for j in p[k][2]..p[k][3] {
row.push(self.data[i][j]);
}
q_data.push(row);
}
quarters[k] = Matrix::new(q_data);
}
quarters
}
fn from_quarters(q: [Matrix; 4]) -> Matrix {
let r = q[0].rows();
let c = q[0].cols();
let p = Matrix::params(r, c);
let rows = r * 2;
let cols = c * 2;
let mut m_data = vec![vec![0.0; cols]; rows];
for k in 0..4 {
for i in p[k][0]..p[k][1] {
for j in p[k][2]..p[k][3] {
m_data[i][j] = q[k].data[i - p[k][4]][j - p[k][5]];
}
}
}
Matrix::new(m_data)
}
fn strassen(&self, other: Matrix) -> Matrix {
self.validate_square_power_of_two();
other.validate_square_power_of_two();
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!("Matrices must be square and of equal size for Strassen multiplication.");
}
if self.rows() == 1 {
return self.clone() * other;
}
let qa = self.to_quarters();
let qb = other.to_quarters();
let p1 = (qa[1].clone() - qa[3].clone()).strassen(qb[2].clone() + qb[3].clone());
let p2 = (qa[0].clone() + qa[3].clone()).strassen(qb[0].clone() + qb[3].clone());
let p3 = (qa[0].clone() - qa[2].clone()).strassen(qb[0].clone() + qb[1].clone());
let p4 = (qa[0].clone() + qa[1].clone()).strassen(qb[3].clone());
let p5 = qa[0].clone().strassen(qb[1].clone() - qb[3].clone());
let p6 = qa[3].clone().strassen(qb[2].clone() - qb[0].clone());
let p7 = (qa[2].clone() + qa[3].clone()).strassen(qb[0].clone());
let mut q: [Matrix; 4] = [
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
];
q[0] = p1.clone() + p2.clone() - p4.clone() + p6.clone();
q[1] = p4 + p5.clone();
q[2] = p6 + p7.clone();
q[3] = p2 - p3.clone() + p5 - p7;
Matrix::from_quarters(q)
}
}
fn main() {
let a = Matrix::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
let b = Matrix::new(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
let c = Matrix::new(vec![
vec![1.0, 1.0, 1.0, 1.0],
vec![2.0, 4.0, 8.0, 16.0],
vec![3.0, 9.0, 27.0, 81.0],
vec![4.0, 16.0, 64.0, 256.0],
]);
let d = Matrix::new(vec![
vec![4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0],
vec![-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0],
vec![3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0],
vec![-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0],
]);
let e = Matrix::new(vec![
vec![1.0, 2.0, 3.0, 4.0],
vec![5.0, 6.0, 7.0, 8.0],
vec![9.0, 10.0, 11.0, 12.0],
vec![13.0, 14.0, 15.0, 16.0],
]);
let f = Matrix::new(vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
]);
println!("Using 'normal' matrix multiplication:");
println!(" a * b = {}", a.clone() * b.clone());
println!(" c * d = {}", (c.clone() * d.clone()).to_string_with_precision(6));
println!(" e * f = {}", e.clone() * f.clone());
println!("\nUsing 'Strassen' matrix multiplication:");
println!(" a * b = {}", a.strassen(b));
println!(" c * d = {}", c.strassen(d).to_string_with_precision(6));
println!(" e * f = {}", e.strassen(f));
}
- Output:
Using 'normal' matrix multiplication: a * b = [19.0, 22.0] [43.0, 50.0] c * d = ["1", "0", "0", "0"] ["0", "1", "0", "0"] ["0", "0", "1", "0"] ["0", "0", "0", "1"] e * f = [1.0, 2.0, 3.0, 4.0] [5.0, 6.0, 7.0, 8.0] [9.0, 10.0, 11.0, 12.0] [13.0, 14.0, 15.0, 16.0] Using 'Strassen' matrix multiplication: a * b = [19.0, 22.0] [43.0, 50.0] c * d = ["1", "0", "0", "0"] ["0", "1", "0", "0"] ["0", "0", "1", "0"] ["0", "0", "0", "1"] e * f = [1.0, 2.0, 3.0, 4.0] [5.0, 6.0, 7.0, 8.0] [9.0, 10.0, 11.0, 12.0] [13.0, 14.0, 15.0, 16.0]
Scala
import scala.math
object MatrixOperations {
type Matrix = Array[Array[Double]]
implicit class RichMatrix(val m: Matrix) {
def rows: Int = m.length
def cols: Int = m(0).length
def add(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) + m2(i)(j))
}
def sub(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) - m2(i)(j))
}
def mul(m2: Matrix): Matrix = {
require(m.cols == m2.rows, "Cannot multiply these matrices.")
Array.tabulate(m.rows, m2.cols)((i, j) =>
(0 until m.cols).map(k => m(i)(k) * m2(k)(j)).sum
)
}
def toString(p: Int): String = {
val pow = math.pow(10, p)
m.map(row =>
row
.map(value => (math.round(value * pow) / pow).toString)
.mkString("[", ", ", "]")
).mkString("[", ",\n ", "]")
}
}
def toQuarters(m: Matrix): Array[Matrix] = {
val r = m.rows / 2
val c = m.cols / 2
val p = params(r, c)
(0 until 4).map { k =>
Array.tabulate(r, c)((i, j) => m(p(k)(0) + i)(p(k)(2) + j))
}.toArray
}
def fromQuarters(q: Array[Matrix]): Matrix = {
val r = q(0).rows
val c = q(0).cols
val p = params(r, c)
Array.tabulate(r * 2, c * 2)((i, j) => q((i / r) * 2 + j / c)(i % r)(j % c))
}
def strassen(a: Matrix, b: Matrix): Matrix = {
require(
a.rows == a.cols && b.rows == b.cols && a.rows == b.rows,
"Matrices must be square and of equal size."
)
require(
a.rows != 0 && (a.rows & (a.rows - 1)) == 0,
"Size of matrices must be a power of two."
)
if (a.rows == 1) {
return a.mul(b)
}
val qa = toQuarters(a)
val qb = toQuarters(b)
val p1 = strassen(qa(1).sub(qa(3)), qb(2).add(qb(3)))
val p2 = strassen(qa(0).add(qa(3)), qb(0).add(qb(3)))
val p3 = strassen(qa(0).sub(qa(2)), qb(0).add(qb(1)))
val p4 = strassen(qa(0).add(qa(1)), qb(3))
val p5 = strassen(qa(0), qb(1).sub(qb(3)))
val p6 = strassen(qa(3), qb(2).sub(qb(0)))
val p7 = strassen(qa(2).add(qa(3)), qb(0))
val q = Array(
p1.add(p2).sub(p4).add(p6),
p4.add(p5),
p6.add(p7),
p2.sub(p3).add(p5).sub(p7)
)
fromQuarters(q)
}
private def params(r: Int, c: Int): Array[Array[Int]] = {
Array(
Array(0, r, 0, c, 0, 0),
Array(0, r, c, 2 * c, 0, c),
Array(r, 2 * r, 0, c, r, 0),
Array(r, 2 * r, c, 2 * c, r, c)
)
}
def main(args: Array[String]): Unit = {
val a: Matrix = Array(Array(1.0, 2.0), Array(3.0, 4.0))
val b: Matrix = Array(Array(5.0, 6.0), Array(7.0, 8.0))
val c: Matrix = Array(
Array(1.0, 1.0, 1.0, 1.0),
Array(2.0, 4.0, 8.0, 16.0),
Array(3.0, 9.0, 27.0, 81.0),
Array(4.0, 16.0, 64.0, 256.0)
)
val d: Matrix = Array(
Array(4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0),
Array(-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0),
Array(3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0),
Array(-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0)
)
val e: Matrix = Array(
Array(1.0, 2.0, 3.0, 4.0),
Array(5.0, 6.0, 7.0, 8.0),
Array(9.0, 10.0, 11.0, 12.0),
Array(13.0, 14.0, 15.0, 16.0)
)
val f: Matrix = Array(
Array(1.0, 0.0, 0.0, 0.0),
Array(0.0, 1.0, 0.0, 0.0),
Array(0.0, 0.0, 1.0, 0.0),
Array(0.0, 0.0, 0.0, 1.0)
)
println("Using 'normal' matrix multiplication:")
println(
s" a * b = ${a.mul(b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${c.mul(d).toString(6)}")
println(
s" e * f = ${e.mul(f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println("\nUsing 'Strassen' matrix multiplication:")
println(
s" a * b = ${strassen(a, b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${strassen(c, d).toString(6)}")
println(
s" e * f = ${strassen(e, f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
}
}
- Output:
Using 'normal' matrix multiplication: a * b = [[19.0, 22.0], [43.0, 50.0]] c * d = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]] Using 'Strassen' matrix multiplication: a * b = [[19.0, 22.0], [43.0, 50.0]] c * d = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
Swift
// Matrix Strassen Multiplication
func strassenMultiply(matrix1: Matrix, matrix2: Matrix) -> Matrix {
precondition(matrix1.columns == matrix2.columns,
"Two matrices can only be matrix multiplied if one has dimensions mxn & the other has dimensions nxp where m, n, p are in R")
// Transform to square matrix
let maxColumns = Swift.max(matrix1.rows, matrix1.columns, matrix2.rows, matrix2.columns)
let pwr2 = nextPowerOfTwo(num: maxColumns)
var sqrMatrix1 = Matrix(rows: pwr2, columns: pwr2)
var sqrMatrix2 = Matrix(rows: pwr2, columns: pwr2)
// fill square matrix 1 with values
for i in 0..<matrix1.rows {
for j in 0..<matrix1.columns{
sqrMatrix1[i, j] = matrix1[i, j]
}
}
// fill square matrix 2 with values
for i in 0..<matrix2.rows {
for j in 0..<matrix2.columns{
sqrMatrix2[i, j] = matrix2[i, j]
}
}
// Get strassen result and transfer to array with proper size
let formulaResult = strassenFormula(matrix1: sqrMatrix1, matrix2: sqrMatrix2)
var finalResult = Matrix(rows: matrix1.rows, columns: matrix2.columns)
for i in 0..<finalResult.rows{
for j in 0..<finalResult.columns {
finalResult[i, j] = formulaResult[i, j]
}
}
return finalResult
}
// Calculate next power of 2
func nextPowerOfTwo(num: Int) -> Int {
// formula for next power of 2
return Int(pow(2,(ceil(log2(Double(num))))))
}
// Multiply Matrices Using Strassen Formula
func strassenFormula(matrix1: Matrix, matrix2: Matrix) -> Matrix {
precondition(matrix1.rows == matrix1.columns && matrix2.rows == matrix2.columns, "Matrices need to be square")
guard matrix1.rows > 1 && matrix2.rows > 1 else { return matrix1 * matrix2 }
let rowHalf = matrix1.rows / 2
// Strassen Formula https://www.geeksforgeeks.org/easy-way-remember-strassens-matrix-equation/
// p1 = a(f-h) p2 = (a+b)h
// p2 = (c+d)e p4 = d(g-e)
// p5 = (a+d)(e+h) p6 = (b-d)(g+h)
// p7 = (a-c)(e+f)
|a b| x |e f| = |(p5+p4-p2+p6) (p1+p2)|
|c d| |g h| |(p3+p4) (p1+p5-p3-p7)|
Matrix 1 Matrix 2 Result
// create empty matrices for a, b, c, d, e, f, g, h
var a = Matrix(rows: rowHalf, columns: rowHalf)
var b = Matrix(rows: rowHalf, columns: rowHalf)
var c = Matrix(rows: rowHalf, columns: rowHalf)
var d = Matrix(rows: rowHalf, columns: rowHalf)
var e = Matrix(rows: rowHalf, columns: rowHalf)
var f = Matrix(rows: rowHalf, columns: rowHalf)
var g = Matrix(rows: rowHalf, columns: rowHalf)
var h = Matrix(rows: rowHalf, columns: rowHalf)
// fill the matrices with values
for i in 0..<rowHalf {
for j in 0..<rowHalf {
a[i, j] = matrix1[i, j]
b[i, j] = matrix1[i, j+rowHalf]
c[i, j] = matrix1[i+rowHalf, j]
d[i, j] = matrix1[i+rowHalf, j+rowHalf]
e[i, j] = matrix2[i, j]
f[i, j] = matrix2[i, j+rowHalf]
g[i, j] = matrix2[i+rowHalf, j]
h[i, j] = matrix2[i+rowHalf, j+rowHalf]
}
}
// a * (f - h)
let p1 = strassenFormula(matrix1: a, matrix2: (f - h))
// (a + b) * h
let p2 = strassenFormula(matrix1: (a + b), matrix2: h)
// (c + d) * e
let p3 = strassenFormula(matrix1: (c + d), matrix2: e)
// d * (g - e)
let p4 = strassenFormula(matrix1: d, matrix2: (g - e))
// (a + d) * (e + h)
let p5 = strassenFormula(matrix1: (a + d), matrix2: (e + h))
// (b - d) * (g + h)
let p6 = strassenFormula(matrix1: (b - d), matrix2: (g + h))
// (a - c) * (e + f)
let p7 = strassenFormula(matrix1: (a - c), matrix2: (e + f))
// p5 + p4 - p2 + p6
let result11 = p5 + p4 - p2 + p6
// p1 + p2
let result12 = p1 + p2
// p3 + p4
let result21 = p3 + p4
// p1 + p5 - p3 - p7
let result22 = p1 + p5 - p3 - p7
// create an empty matrix for result and fill with values
var result = Matrix(rows: matrix1.rows, columns: matrix1.rows)
for i in 0..<rowHalf {
for j in 0..<rowHalf {
result[i, j] = result11[i, j]
result[i, j+rowHalf] = result12[i, j]
result[i+rowHalf, j] = result21[i, j]
result[i+rowHalf, j+rowHalf] = result22[i, j]
}
}
return result
}
func main(){
// Matrix Class https://github.com/hollance/Matrix/blob/master/Matrix.swift
var a = Matrix(rows: 2, columns: 2)
a[row: 0] = [1, 2]
a[row: 1] = [3, 4]
var b = Matrix(rows: 2, columns: 2)
b[row: 0] = [5, 6]
b[row: 1] = [7, 8]
var c = Matrix(rows: 4, columns: 4)
c[row: 0] = [1, 1, 1,1]
c[row: 1] = [2, 4, 8, 16]
c[row: 2] = [3, 9, 27, 81]
c[row: 3] = [4, 16, 64, 256]
var d = Matrix(rows: 4, columns: 4)
d[row: 0] = [4, -3, Double(4/3), Double(-1/4)]
d[row: 1] = [Double(-13/3), Double(19/4), Double(-7/3), Double(11/24)]
d[row: 2] = [Double(3/2), Double(-2), Double(7/6), Double(-1/4)]
d[row: 3] = [Double(-1/6), Double(1/4), Double(-1/6), Double(1/24)]
var e = Matrix(rows: 4, columns: 4)
e[row: 0] = [1, 2, 3, 4]
e[row: 1] = [5, 6, 7, 8]
e[row: 2] = [9, 10, 11, 12]
e[row: 3] = [13, 14, 15, 16]
var f = Matrix(rows: 4, columns: 4)
f[row: 0] = [1, 0, 0, 0]
f[row: 1] = [0, 1, 0 ,0]
f[row: 2] = [0 ,0 ,1, 0]
f[row: 3] = [0, 0 ,0 ,1]
let result1 = strassenMultiply(matrix1: a, matrix2: b)
print("AxB")
print(result1.description)
let result2 = strassenMultiply(matrix1: c, matrix2: d)
print("CxD")
print(result2.description)
let result3 = strassenMultiply(matrix1: e, matrix2: f)
print("ExF")
print(result3.description)
}
main()
- Output:
AxB 19 22 43 50 CxD 1 -1 0 0 0 -6 2 0 3 -27 12 0 16 -76 36 0 ExF 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
Wren
Wren doesn't currently have a matrix module so I've written a rudimentary Matrix class with sufficient functionality to complete this task.
I've used the Phix entry's examples to test the Strassen algorithm implementation.
class Matrix {
construct new(a) {
if (a.type != List || a.count == 0 || a[0].type != List || a[0].count == 0 || a[0][0].type != Num) {
Fiber.abort("Argument must be a non-empty two dimensional list of numbers.")
}
_a = a
}
rows { _a.count }
cols { _a[0].count }
+(b) {
if (b.type != Matrix) Fiber.abort("Argument must be a matrix.")
if ((this.rows != b.rows) || (this.cols != b.cols)) {
Fiber.abort("Matrices must have the same dimensions.")
}
var c = List.filled(rows, null)
for (i in 0...rows) {
c[i] = List.filled(cols, 0)
for (j in 0...cols) c[i][j] = _a[i][j] + b[i, j]
}
return Matrix.new(c)
}
- { this * -1 }
-(b) { this + (-b) }
*(b) {
var c = List.filled(rows, null)
if (b is Num) {
for (i in 0...rows) {
c[i] = List.filled(cols, 0)
for (j in 0...cols) c[i][j] = _a[i][j] * b
}
} else if (b is Matrix) {
if (this.cols != b.rows) Fiber.abort("Cannot multiply these matrices.")
for (i in 0...rows) {
c[i] = List.filled(b.cols, 0)
for (j in 0...b.cols) {
for (k in 0...b.rows) c[i][j] = c[i][j] + _a[i][k] * b[k, j]
}
}
} else {
Fiber.abort("Argument must be a matrix or a number.")
}
return Matrix.new(c)
}
[i] { _a[i].toList }
[i, j] { _a[i][j] }
toString { _a.toString }
// rounds all elements to 'p' places
toString(p) {
var s = List.filled(rows, "")
var pow = 10.pow(p)
for (i in 0...rows) {
var t = List.filled(cols, "")
for (j in 0...cols) {
var r = (_a[i][j]*pow).round / pow
t[j] = r.toString
if (t[j] == "-0") t[j] = "0"
}
s[i] = t.toString
}
return s
}
}
var params = Fn.new { |r, c|
return [
[0...r, 0...c, 0, 0],
[0...r, c...2*c, 0, c],
[r...2*r, 0...c, r, 0],
[r...2*r, c...2*c, r, c]
]
}
var toQuarters = Fn.new { |m|
var r = (m.rows/2).floor
var c = (m.cols/2).floor
var p = params.call(r, c)
var quarters = []
for (k in 0..3) {
var q = List.filled(r, null)
for (i in p[k][0]) {
q[i - p[k][2]] = List.filled(c, 0)
for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
}
quarters.add(Matrix.new(q))
}
return quarters
}
var fromQuarters = Fn.new { |q|
var r = q[0].rows
var c = q[0].cols
var p = params.call(r, c)
r = r * 2
c = c * 2
var m = List.filled(r, null)
for (i in 0...c) m[i] = List.filled(c, 0)
for (k in 0..3) {
for (i in p[k][0]) {
for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
}
}
return Matrix.new(m)
}
var strassen // recursive
strassen = Fn.new { |a, b|
if (a.rows != a.cols || b.rows != b.cols || a.rows != b.rows) {
Fiber.abort("Matrices must be square and of equal size.")
}
if (a.rows == 0 || (a.rows & (a.rows - 1)) != 0) {
Fiber.abort("Size of matrices must be a power of two.")
}
if (a.rows == 1) return a * b
var qa = toQuarters.call(a)
var qb = toQuarters.call(b)
var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
var p4 = strassen.call(qa[0] + qa[1], qb[3])
var p5 = strassen.call(qa[0], qb[1] - qb[3])
var p6 = strassen.call(qa[3], qb[2] - qb[0])
var p7 = strassen.call(qa[2] + qa[3], qb[0])
var q = List.filled(4, null)
q[0] = p1 + p2 - p4 + p6
q[1] = p4 + p5
q[2] = p6 + p7
q[3] = p2 - p3 + p5 - p7
return fromQuarters.call(q)
}
var a = Matrix.new([ [1,2], [3, 4] ])
var b = Matrix.new([ [5,6], [7, 8] ])
var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ])
var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],
[3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])
var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ])
var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
System.print("Using 'normal' matrix multiplication:")
System.print(" a * b = %(a * b)")
System.print(" c * d = %((c * d).toString(6))")
System.print(" e * f = %(e * f)")
System.print("\nUsing 'Strassen' matrix multiplication:")
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")
- Output:
Using 'normal' matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] Using 'Strassen' matrix multiplication: a * b = [[19, 22], [43, 50]] c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
Since the above version was written, a Matrix module has been added and the following version uses it. The output is exactly the same as before.
import "./matrix" for Matrix
var params = Fn.new { |r, c|
return [
[0...r, 0...c, 0, 0],
[0...r, c...2*c, 0, c],
[r...2*r, 0...c, r, 0],
[r...2*r, c...2*c, r, c]
]
}
var toQuarters = Fn.new { |m|
var r = (m.numRows/2).floor
var c = (m.numCols/2).floor
var p = params.call(r, c)
var quarters = []
for (k in 0..3) {
var q = List.filled(r, null)
for (i in p[k][0]) {
q[i - p[k][2]] = List.filled(c, 0)
for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
}
quarters.add(Matrix.new(q))
}
return quarters
}
var fromQuarters = Fn.new { |q|
var r = q[0].numRows
var c = q[0].numCols
var p = params.call(r, c)
r = r * 2
c = c * 2
var m = List.filled(r, null)
for (i in 0...c) m[i] = List.filled(c, 0)
for (k in 0..3) {
for (i in p[k][0]) {
for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
}
}
return Matrix.new(m)
}
var strassen // recursive
strassen = Fn.new { |a, b|
if (!a.isSquare || !b.isSquare || !a.sameSize(b)) {
Fiber.abort("Matrices must be square and of equal size.")
}
if (a.numRows == 0 || (a.numRows & (a.numRows - 1)) != 0) {
Fiber.abort("Size of matrices must be a power of two.")
}
if (a.numRows == 1) return a * b
var qa = toQuarters.call(a)
var qb = toQuarters.call(b)
var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
var p4 = strassen.call(qa[0] + qa[1], qb[3])
var p5 = strassen.call(qa[0], qb[1] - qb[3])
var p6 = strassen.call(qa[3], qb[2] - qb[0])
var p7 = strassen.call(qa[2] + qa[3], qb[0])
var q = List.filled(4, null)
q[0] = p1 + p2 - p4 + p6
q[1] = p4 + p5
q[2] = p6 + p7
q[3] = p2 - p3 + p5 - p7
return fromQuarters.call(q)
}
var a = Matrix.new([ [1,2], [3, 4] ])
var b = Matrix.new([ [5,6], [7, 8] ])
var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ])
var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],
[3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])
var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ])
var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
System.print("Using 'normal' matrix multiplication:")
System.print(" a * b = %(a * b)")
System.print(" c * d = %((c * d).toString(6))")
System.print(" e * f = %(e * f)")
System.print("\nUsing 'Strassen' matrix multiplication:")
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")
Zig
const std = @import("std");
const fmt = std.fmt;
const ArrayList = std.ArrayList;
const Allocator = std.mem.Allocator;
const Matrix = struct {
data: ArrayList(ArrayList(f64)),
rows: usize,
cols: usize,
allocator: Allocator,
pub fn init(allocator: Allocator, data: ArrayList(ArrayList(f64))) !Matrix {
const rows = data.items.len;
const cols = if (rows > 0) data.items[0].items.len else 0;
return Matrix{
.data = data,
.rows = rows,
.cols = cols,
.allocator = allocator,
};
}
pub fn deinit(self: *Matrix) void {
for (self.data.items) |*row| {
row.deinit();
}
self.data.deinit();
}
pub fn clone(self: Matrix) !Matrix {
var new_data = ArrayList(ArrayList(f64)).init(self.allocator);
try new_data.ensureTotalCapacity(self.rows);
for (self.data.items) |row| {
var new_row = ArrayList(f64).init(self.allocator);
try new_row.ensureTotalCapacity(self.cols);
try new_row.appendSlice(row.items);
try new_data.append(new_row);
}
return Matrix{
.data = new_data,
.rows = self.rows,
.cols = self.cols,
.allocator = self.allocator,
};
}
pub fn validateDimensions(self: Matrix, other: Matrix) !void {
if (self.rows != other.rows or self.cols != other.cols) {
return error.DimensionMismatch;
}
}
pub fn validateMultiplication(self: Matrix, other: Matrix) !void {
if (self.cols != other.rows) {
return error.CannotMultiply;
}
}
pub fn validateSquarePowerOfTwo(self: Matrix) !void {
if (self.rows != self.cols) {
return error.NotSquare;
}
if (self.rows == 0 or (self.rows & (self.rows - 1)) != 0) {
return error.NotPowerOfTwo;
}
}
pub fn add(self: Matrix, other: Matrix) !Matrix {
try self.validateDimensions(other);
var result_data = ArrayList(ArrayList(f64)).init(self.allocator);
try result_data.ensureTotalCapacity(self.rows);
for (0..self.rows) |i| {
var row = ArrayList(f64).init(self.allocator);
try row.ensureTotalCapacity(self.cols);
for (0..self.cols) |j| {
try row.append(self.data.items[i].items[j] + other.data.items[i].items[j]);
}
try result_data.append(row);
}
return try Matrix.init(self.allocator, result_data);
}
pub fn sub(self: Matrix, other: Matrix) !Matrix {
try self.validateDimensions(other);
var result_data = ArrayList(ArrayList(f64)).init(self.allocator);
try result_data.ensureTotalCapacity(self.rows);
for (0..self.rows) |i| {
var row = ArrayList(f64).init(self.allocator);
try row.ensureTotalCapacity(self.cols);
for (0..self.cols) |j| {
try row.append(self.data.items[i].items[j] - other.data.items[i].items[j]);
}
try result_data.append(row);
}
return try Matrix.init(self.allocator, result_data);
}
pub fn mul(self: Matrix, other: Matrix) !Matrix {
try self.validateMultiplication(other);
var result_data = ArrayList(ArrayList(f64)).init(self.allocator);
try result_data.ensureTotalCapacity(self.rows);
for (0..self.rows) |i| {
var row = ArrayList(f64).init(self.allocator);
try row.ensureTotalCapacity(other.cols);
for (0..other.cols) |j| {
var sum: f64 = 0.0;
for (0..self.cols) |k| {
sum += self.data.items[i].items[k] * other.data.items[k].items[j];
}
try row.append(sum);
}
try result_data.append(row);
}
return try Matrix.init(self.allocator, result_data);
}
pub fn format(self: Matrix, comptime _: []const u8, _: fmt.FormatOptions, writer: anytype) !void {
for (self.data.items) |row| {
try writer.print("{any}\n", .{row.items});
}
}
pub fn toStringWithPrecision(self: Matrix, p: usize, allocator: Allocator) ![]u8 {
var output = ArrayList(u8).init(allocator);
defer output.deinit();
const pow = std.math.pow(f64, 10.0, @as(f64, @floatFromInt(p)));
for (self.data.items) |row| {
var formatted_row = ArrayList([]const u8).init(allocator);
defer {
for (formatted_row.items) |item| {
allocator.free(item);
}
formatted_row.deinit();
}
for (row.items) |val| {
const r = @round(val * pow) / pow;
const formatted = try fmt.allocPrint(allocator, "{d}", .{r});
if (std.mem.eql(u8, formatted, "-0")) {
allocator.free(formatted);
try formatted_row.append(try allocator.dupe(u8, "0"));
} else {
try formatted_row.append(formatted);
}
}
std.debug.print("{any}\n", .{formatted_row.items});
}
return output.toOwnedSlice();
}
fn params(r: usize, c: usize) [4][6]usize {
return [4][6]usize{
[_]usize{ 0, r, 0, c, 0, 0 },
[_]usize{ 0, r, c, 2 * c, 0, c },
[_]usize{ r, 2 * r, 0, c, r, 0 },
[_]usize{ r, 2 * r, c, 2 * c, r, c },
};
}
pub fn toQuarters(self: Matrix) ![4]Matrix {
const r = self.rows / 2;
const c = self.cols / 2;
const p = Matrix.params(r, c);
var quarters: [4]Matrix = undefined;
for (0..4) |k| {
var q_data = ArrayList(ArrayList(f64)).init(self.allocator);
try q_data.ensureTotalCapacity(r);
for (p[k][0]..p[k][1]) |i| {
var row = ArrayList(f64).init(self.allocator);
try row.ensureTotalCapacity(c);
for (p[k][2]..p[k][3]) |j| {
try row.append(self.data.items[i].items[j]);
}
try q_data.append(row);
}
quarters[k] = try Matrix.init(self.allocator, q_data);
}
return quarters;
}
pub fn fromQuarters(q: [4]Matrix, allocator: Allocator) !Matrix {
const r = q[0].rows;
const c = q[0].cols;
const p = Matrix.params(r, c);
const rows = r * 2;
const cols = c * 2;
var m_data = ArrayList(ArrayList(f64)).init(allocator);
try m_data.ensureTotalCapacity(rows);
for (0..rows) |_| {
var row = ArrayList(f64).init(allocator);
try row.ensureTotalCapacity(cols);
for (0..cols) |_| {
try row.append(0.0);
}
try m_data.append(row);
}
for (0..4) |k| {
for (p[k][0]..p[k][1]) |i| {
for (p[k][2]..p[k][3]) |j| {
m_data.items[i].items[j] = q[k].data.items[i - p[k][4]].items[j - p[k][5]];
}
}
}
return try Matrix.init(allocator, m_data);
}
pub fn strassen(self: Matrix, other: Matrix) !Matrix {
try self.validateSquarePowerOfTwo();
try other.validateSquarePowerOfTwo();
if (self.rows != other.rows or self.cols != other.cols) {
return error.InvalidDimensions;
}
if (self.rows == 1) {
return self.mul(other);
}
var qa = try self.toQuarters();
defer for (&qa) |*q| q.deinit();
var qb = try other.toQuarters();
defer for (&qb) |*q| q.deinit();
var t1 = try qa[1].sub(qa[3]);
defer t1.deinit();
var t2 = try qb[2].add(qb[3]);
defer t2.deinit();
var p1 = try t1.strassen(t2);
defer p1.deinit();
var t3 = try qa[0].add(qa[3]);
defer t3.deinit();
var t4 = try qb[0].add(qb[3]);
defer t4.deinit();
var p2 = try t3.strassen(t4);
defer p2.deinit();
var t5 = try qa[0].sub(qa[2]);
defer t5.deinit();
var t6 = try qb[0].add(qb[1]);
defer t6.deinit();
var p3 = try t5.strassen(t6);
defer p3.deinit();
var t7 = try qa[0].add(qa[1]);
defer t7.deinit();
var p4 = try t7.strassen(qb[3]);
defer p4.deinit();
var t8 = try qb[1].sub(qb[3]);
defer t8.deinit();
var p5 = try qa[0].strassen(t8);
defer p5.deinit();
var t9 = try qb[2].sub(qb[0]);
defer t9.deinit();
var p6 = try qa[3].strassen(t9);
defer p6.deinit();
var t10 = try qa[2].add(qa[3]);
defer t10.deinit();
var p7 = try t10.strassen(qb[0]);
defer p7.deinit();
var q: [4]Matrix = undefined;
// q[0] = p1 + p2 - p4 + p6
var ta = try p1.add(p2);
defer ta.deinit();
var tb = try ta.sub(p4);
defer tb.deinit();
q[0] = try tb.add(p6);
// q[1] = p4 + p5
q[1] = try p4.add(p5);
// q[2] = p6 + p7
q[2] = try p6.add(p7);
// q[3] = p2 - p3 + p5 - p7
var tc = try p2.sub(p3);
defer tc.deinit();
var td = try tc.add(p5);
defer td.deinit();
q[3] = try td.sub(p7);
defer for (&q) |*quarter| quarter.deinit();
return Matrix.fromQuarters(q, self.allocator);
}
};
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();
// Matrix A - [1 2; 3 4]
var a_data = ArrayList(ArrayList(f64)).init(allocator);
var a_row1 = ArrayList(f64).init(allocator);
try a_row1.appendSlice(&[_]f64{ 1.0, 2.0 });
var a_row2 = ArrayList(f64).init(allocator);
try a_row2.appendSlice(&[_]f64{ 3.0, 4.0 });
try a_data.append(a_row1);
try a_data.append(a_row2);
var a = try Matrix.init(allocator, a_data);
defer a.deinit();
// Matrix B - [5 6; 7 8]
var b_data = ArrayList(ArrayList(f64)).init(allocator);
var b_row1 = ArrayList(f64).init(allocator);
try b_row1.appendSlice(&[_]f64{ 5.0, 6.0 });
var b_row2 = ArrayList(f64).init(allocator);
try b_row2.appendSlice(&[_]f64{ 7.0, 8.0 });
try b_data.append(b_row1);
try b_data.append(b_row2);
var b = try Matrix.init(allocator, b_data);
defer b.deinit();
// Matrix C - 4x4
var c_data = ArrayList(ArrayList(f64)).init(allocator);
var c_row1 = ArrayList(f64).init(allocator);
try c_row1.appendSlice(&[_]f64{ 1.0, 1.0, 1.0, 1.0 });
var c_row2 = ArrayList(f64).init(allocator);
try c_row2.appendSlice(&[_]f64{ 2.0, 4.0, 8.0, 16.0 });
var c_row3 = ArrayList(f64).init(allocator);
try c_row3.appendSlice(&[_]f64{ 3.0, 9.0, 27.0, 81.0 });
var c_row4 = ArrayList(f64).init(allocator);
try c_row4.appendSlice(&[_]f64{ 4.0, 16.0, 64.0, 256.0 });
try c_data.append(c_row1);
try c_data.append(c_row2);
try c_data.append(c_row3);
try c_data.append(c_row4);
var c = try Matrix.init(allocator, c_data);
defer c.deinit();
// Matrix D - 4x4
var d_data = ArrayList(ArrayList(f64)).init(allocator);
var d_row1 = ArrayList(f64).init(allocator);
try d_row1.appendSlice(&[_]f64{ 4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0 });
var d_row2 = ArrayList(f64).init(allocator);
try d_row2.appendSlice(&[_]f64{ -13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0 });
var d_row3 = ArrayList(f64).init(allocator);
try d_row3.appendSlice(&[_]f64{ 3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0 });
var d_row4 = ArrayList(f64).init(allocator);
try d_row4.appendSlice(&[_]f64{ -1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0 });
try d_data.append(d_row1);
try d_data.append(d_row2);
try d_data.append(d_row3);
try d_data.append(d_row4);
var d = try Matrix.init(allocator, d_data);
defer d.deinit();
// Matrix E - 4x4
var e_data = ArrayList(ArrayList(f64)).init(allocator);
var e_row1 = ArrayList(f64).init(allocator);
try e_row1.appendSlice(&[_]f64{ 1.0, 2.0, 3.0, 4.0 });
var e_row2 = ArrayList(f64).init(allocator);
try e_row2.appendSlice(&[_]f64{ 5.0, 6.0, 7.0, 8.0 });
var e_row3 = ArrayList(f64).init(allocator);
try e_row3.appendSlice(&[_]f64{ 9.0, 10.0, 11.0, 12.0 });
var e_row4 = ArrayList(f64).init(allocator);
try e_row4.appendSlice(&[_]f64{ 13.0, 14.0, 15.0, 16.0 });
try e_data.append(e_row1);
try e_data.append(e_row2);
try e_data.append(e_row3);
try e_data.append(e_row4);
var e = try Matrix.init(allocator, e_data);
defer e.deinit();
// Matrix F - Identity 4x4
var f_data = ArrayList(ArrayList(f64)).init(allocator);
var f_row1 = ArrayList(f64).init(allocator);
try f_row1.appendSlice(&[_]f64{ 1.0, 0.0, 0.0, 0.0 });
var f_row2 = ArrayList(f64).init(allocator);
try f_row2.appendSlice(&[_]f64{ 0.0, 1.0, 0.0, 0.0 });
var f_row3 = ArrayList(f64).init(allocator);
try f_row3.appendSlice(&[_]f64{ 0.0, 0.0, 1.0, 0.0 });
var f_row4 = ArrayList(f64).init(allocator);
try f_row4.appendSlice(&[_]f64{ 0.0, 0.0, 0.0, 1.0 });
try f_data.append(f_row1);
try f_data.append(f_row2);
try f_data.append(f_row3);
try f_data.append(f_row4);
var f = try Matrix.init(allocator, f_data);
defer f.deinit();
const stdout = std.io.getStdOut().writer();
try stdout.print("Using 'normal' matrix multiplication:\n", .{});
var a_clone = try a.clone();
defer a_clone.deinit();
var b_clone = try b.clone();
defer b_clone.deinit();
var ab = try a_clone.mul(b_clone);
defer ab.deinit();
try stdout.print(" a * b = {}\n", .{ab});
var c_clone = try c.clone();
defer c_clone.deinit();
var d_clone = try d.clone();
defer d_clone.deinit();
var cd = try c_clone.mul(d_clone);
defer cd.deinit();
const cd_str = try cd.toStringWithPrecision(6, allocator);
defer allocator.free(cd_str);
try stdout.print(" c * d = {s}\n", .{cd_str});
var e_clone = try e.clone();
defer e_clone.deinit();
var f_clone = try f.clone();
defer f_clone.deinit();
var ef = try e_clone.mul(f_clone);
defer ef.deinit();
try stdout.print(" e * f = {}\n", .{ef});
try stdout.print("\nUsing 'Strassen' matrix multiplication:\n", .{});
var a_clone2 = try a.clone();
defer a_clone2.deinit();
var b_clone2 = try b.clone();
defer b_clone2.deinit();
var ab_s = try a_clone2.strassen(b_clone2);
defer ab_s.deinit();
try stdout.print(" a * b = {}\n", .{ab_s});
var c_clone2 = try c.clone();
defer c_clone2.deinit();
var d_clone2 = try d.clone();
defer d_clone2.deinit();
var cd_s = try c_clone2.strassen(d_clone2);
defer cd_s.deinit();
const cd_s_str = try cd_s.toStringWithPrecision(6, allocator);
defer allocator.free(cd_s_str);
try stdout.print(" c * d = {s}\n", .{cd_s_str});
var e_clone2 = try e.clone();
defer e_clone2.deinit();
var f_clone2 = try f.clone();
defer f_clone2.deinit();
var ef_s = try e_clone2.strassen(f_clone2);
defer ef_s.deinit();
try stdout.print(" e * f = {}\n", .{ef_s});
}
- Output:
Using 'normal' matrix multiplication: a * b = { 1.9e1, 2.2e1 } { 4.3e1, 5e1 } { { 49 }, { 48 }, { 48 }, { 48 } } { { 48 }, { 49 }, { 48 }, { 48 } } { { 48 }, { 48 }, { 49 }, { 48 } } { { 48 }, { 48 }, { 48 }, { 49 } } c * d = e * f = { 1e0, 2e0, 3e0, 4e0 } { 5e0, 6e0, 7e0, 8e0 } { 9e0, 1e1, 1.1e1, 1.2e1 } { 1.3e1, 1.4e1, 1.5e1, 1.6e1 } Using 'Strassen' matrix multiplication: a * b = { 1.9e1, 2.2e1 } { 4.3e1, 5e1 } { { 49 }, { 48 }, { 48 }, { 48 } } { { 48 }, { 49 }, { 48 }, { 48 } } { { 48 }, { 48 }, { 49 }, { 48 } } { { 48 }, { 48 }, { 48 }, { 49 } } c * d = e * f = { 1e0, 2e0, 3e0, 4e0 } { 5e0, 6e0, 7e0, 8e0 } { 9e0, 1e1, 1.1e1, 1.2e1 } { 1.3e1, 1.4e1, 1.5e1, 1.6e1 }