Strassen's algorithm: Difference between revisions

From Rosetta Code
Content added Content deleted
(Created page with "In linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm for matrix multiplication. It is faster than the standard matrix multiplication algorit...")
 
Line 32: Line 32:


return [C11 C12; C21 C22]
return [C11 C12; C21 C22]
end
end</lang>

=={{header|Phix}}==
<lang Phix>function strassen(sequence a, b)
integer l = length(a)
if length(a[1])!=l
or length(b)!=l
or length(b[1])!=l then
crash("two equal square matrices only")
end if
if l=1 then return sq_mul(a,b) end if
if remainder(l,1) then
crash("2^n matrices only")
end if
integer h = l/2
sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)
for i=1 to h do
for j=1 to h do
a11[i][j] = a[i][j]
a12[i][j] = a[i][j+h]
a21[i][j] = a[i+h][j]
a22[i][j] = a[i+h][j+h]
b11[i][j] = b[i][j]
b12[i][j] = b[i][j+h]
b21[i][j] = b[i+h][j]
b22[i][j] = b[i+h][j+h]
end for
end for
sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
p4 = strassen(sq_add(a11,a12), b22),
p5 = strassen(a11, sq_sub(b12,b22)),
p6 = strassen(a22, sq_sub(b21,b11)),
p7 = strassen(sq_add(a21,a22), b11),
c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
c12 = sq_add(p4,p5),
c21 = sq_add(p6,p7),
c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
c = repeat(repeat(0,l),l)
for i=1 to h do
for j=1 to h do
c[i][j] = c11[i][j]
c[i][j+h] = c12[i][j]
c[i+h][j] = c21[i][j]
c[i+h][j+h] = c22[i][j]
end for
end for
return c
end function

ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})

constant A = {{1,2},
{3,4}},
B = {{5,6},
{7,8}}
pp(strassen(A,B))

constant C = { { 1, 1, 1, 1 },
{ 2, 4, 8, 16 },
{ 3, 9, 27, 81 },
{ 4, 16, 64, 256 }},
D = { { 4, -3, 4/3, -1/ 4 },
{-13/3, 19/4, -7/3, 11/24 },
{ 3/2, -2, 7/6, -1/ 4 },
{ -1/6, 1/4, -1/6, 1/24 }}
pp(strassen(C,D))

constant E = {{ 1, 2, 3, 4},
{ 5, 6, 7, 8},
{ 9,10,11,12},
{13,14,15,16}},
F = {{1, 0, 0, 0},
{0, 1, 0, 0},
{0, 0, 1, 0},
{0, 0, 0, 1}}
pp(strassen(E,F))

constant r = sqrt(2)/2,
R = {{ r,r},
{-r,r}}
pp(strassen(R,R))</lang>
{{out}}
<pre>
{{ 19, 22},
{ 43, 50}}
{{ 1, 0, 0, 0},
{ 0, 1, 0, 0},
{ 0, 0, 1, 0},
{ 0, 0, 0, 1}}
{{ 1, 2, 3, 4},
{ 5, 6, 7, 8},
{ 9, 10, 11, 12},
{ 13, 14, 15, 16}}
{{ 0, 1},
{ -1, 0}}
</pre>

Revision as of 18:14, 24 September 2020

In linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm for matrix multiplication. It is faster than the standard matrix multiplication algorithm and is useful in practice for large matrices, but would be slower than the fastest known algorithms for extremely large matrices.

Julia

The multiplication is denoted by * <lang Julia> function Strassen(A::Matrix, B::Matrix)

   n = size(A, 1)
   if n == 1
       return A * B
   end
   @views A11 = A[1:n/2, 1:n/2]
   @views A12 = A[1:n/2, n/2+1:n]
   @views A21 = A[n/2+1:n, 1:n/2]
   @views A11 = A[n/2+1:n, n/2+1:n]
   @views B11 = B[1:n/2, 1:n/2]
   @views B12 = B[1:n/2, n/2+1:n]
   @views B21 = B[n/2+1:n, 1:n/2]
   @views B11 = B[n/2+1:n, n/2+1:n]
   P1 = Strassen(A12 - A22, B21 + B22)
   P2 = Strassen(A11 + A22, B11 + B22)
   P3 = Strassen(A11 - A21, B11 + B12)
   P4 = Strassen(A11 + A12, B22)
   P5 = Strassen(A11, B12 - B22)
   P6 = Strassen(A22, B21 - B11)
   P7 = Strassen(A21 + A22, B11)
   C11 = P1 + P2 - P4 + P6
   C12 = P4 + P5
   C21 = P6 + P7
   C22 = P2 - P3 + P5 - P7
   return [C11 C12; C21 C22]

end</lang>

Phix

<lang Phix>function strassen(sequence a, b)

   integer l = length(a)
   if length(a[1])!=l
   or length(b)!=l
   or length(b[1])!=l then
       crash("two equal square matrices only")
   end if
   if l=1 then return sq_mul(a,b) end if
   if remainder(l,1) then
       crash("2^n matrices only")
   end if
   integer h = l/2
   sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)
   for i=1 to h do
       for j=1 to h do
           a11[i][j] = a[i][j]
           a12[i][j] = a[i][j+h]
           a21[i][j] = a[i+h][j]
           a22[i][j] = a[i+h][j+h]
           b11[i][j] = b[i][j]
           b12[i][j] = b[i][j+h]
           b21[i][j] = b[i+h][j]
           b22[i][j] = b[i+h][j+h]
       end for
   end for
   sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
            p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
            p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
            p4 = strassen(sq_add(a11,a12), b22),
            p5 = strassen(a11, sq_sub(b12,b22)),
            p6 = strassen(a22, sq_sub(b21,b11)),
            p7 = strassen(sq_add(a21,a22), b11),

            c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
            c12 = sq_add(p4,p5),
            c21 = sq_add(p6,p7),
            c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
            c = repeat(repeat(0,l),l)
   for i=1 to h do
       for j=1 to h do
           c[i][j] = c11[i][j]
           c[i][j+h] = c12[i][j]
           c[i+h][j] = c21[i][j]
           c[i+h][j+h] = c22[i][j]
       end for
   end for
   return c

end function

ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})

constant A = {{1,2},

             {3,4}},
        B = {{5,6},
             {7,8}}

pp(strassen(A,B))

constant C = { { 1, 1, 1, 1 },

              { 2,  4,  8,  16 },
              { 3,  9, 27,  81 },
              { 4, 16, 64, 256 }},
        D = { {    4,   -3,  4/3, -1/ 4 },
              {-13/3, 19/4, -7/3, 11/24 },
              {  3/2,   -2,  7/6, -1/ 4 },
              { -1/6,  1/4, -1/6,  1/24 }}

pp(strassen(C,D))

constant E = {{ 1, 2, 3, 4},

             { 5, 6, 7, 8},
             { 9,10,11,12},
             {13,14,15,16}},
        F = {{1, 0, 0, 0},
             {0, 1, 0, 0},
             {0, 0, 1, 0},
             {0, 0, 0, 1}}

pp(strassen(E,F))

constant r = sqrt(2)/2,

        R = {{ r,r},
             {-r,r}}

pp(strassen(R,R))</lang>

Output:
{{ 19, 22},
 { 43, 50}}
{{  1,  0,  0,  0},
 {  0,  1,  0,  0},
 {  0,  0,  1,  0},
 {  0,  0,  0,  1}}
{{  1,  2,  3,  4},
 {  5,  6,  7,  8},
 {  9, 10, 11, 12},
 { 13, 14, 15, 16}}
{{  0,  1},
 { -1,  0}}