Strassen算法不是最快的?
我从某处复制了施特拉森的算法,然后执行它。以下是输出
n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms
,其中 strassen1
是动态方法,strassen2
用于缓存,classical
是旧的矩阵乘法。这说明我们古朴的古典是最好的。这是真的还是我在某个地方错了?这是 Java 中的代码。
import java.util.Random;
class TestIntMatrixMultiplication {
public static void main (String...args) throws Exception {
final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
final Random random = new Random(seed);
int[][] a, b, c;
a = new int[n][n];
b = new int[n][n];
c = new int[n][n];
for(int i=0; i<n; i++) {
for(int j=0; j<n; j++) {
a[i][j] = random.nextInt(100);
b[i][j] = random.nextInt(100);
}
}
System.out.println("n = " + n);
if (a.length < 64) {
System.out.println("A");
dumpMatrix(a);
System.out.println("B");
dumpMatrix(b);
System.out.println("classic");
Classical.mult(c, a, b);
dumpMatrix(c);
System.out.println("strassen");
strassen2.mult(c, a, b);
dumpMatrix(c);
return;
}
for (int i = 0; i <3; ++i) {
timeMultiplies1(a, b, c);
if (n <= 256)
timeMultiplies2( a, b, c);
timeMultiplies3( a, b, c);
}
}
static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
Classical.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("classical took " + (finish - start) + "ms");
}
static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
strassen1.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("strassen 1 took " + (finish - start) + "ms");
}
static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
strassen2.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("strassen2 took " + (finish - start) + "ms");
}
static void dumpMatrix (int[][] m) {
for (int[] row : m) {
System.out.print("[\t");
for (int val : row) {
System.out.print(val);
System.out.print('\t');
}
System.out.println(']');
}
}
}
class strassen1{
public String getName () {
return "Strassen(dynamic)";
}
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
return strassenMatrixMultiplication(a, b);
}
public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
if(n == 1) {
result[0][0] = A[0][0] * B[0][0];
} else {
int [][] A11 = new int[n/2][n/2];
int [][] A12 = new int[n/2][n/2];
int [][] A21 = new int[n/2][n/2];
int [][] A22 = new int[n/2][n/2];
int [][] B11 = new int[n/2][n/2];
int [][] B12 = new int[n/2][n/2];
int [][] B21 = new int[n/2][n/2];
int [][] B22 = new int[n/2][n/2];
divideArray(A, A11, 0 , 0);
divideArray(A, A12, 0 , n/2);
divideArray(A, A21, n/2, 0);
divideArray(A, A22, n/2, n/2);
divideArray(B, B11, 0 , 0);
divideArray(B, B12, 0 , n/2);
divideArray(B, B21, n/2, 0);
divideArray(B, B22, n/2, n/2);
int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));
int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
int [][] C12 = addMatrices(P3, P5);
int [][] C21 = addMatrices(P2, P4);
int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);
copySubArray(C11, result, 0 , 0);
copySubArray(C12, result, 0 , n/2);
copySubArray(C21, result, n/2, 0);
copySubArray(C22, result, n/2, n/2);
}
return result;
}
public static int [][] addMatrices(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
for(int i=0; i<n; i++)
for(int j=0; j<n; j++)
result[i][j] = A[i][j] + B[i][j];
return result;
}
public static int [][] subtractMatrices(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
for(int i=0; i<n; i++)
for(int j=0; j<n; j++)
result[i][j] = A[i][j] - B[i][j];
return result;
}
public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
child[i1][j1] = parent[i2][j2];
}
public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
parent[i2][j2] = child[i1][j1];
}
}
class strassen2{
public String getName () {
return "Strassen(cached)";
}
static int [][] p1;
static int [][] p2;
static int [][] p3;
static int [][] p4;
static int [][] p5;
static int [][] p6;
static int [][] p7;
static int [][] t0;
static int [][] t1;
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
final int n = c.length;
if (p1 == null || p1.length < n) {
p1 = new int[n/2][n-1];
p2 = new int[n/2][n-1];
p3 = new int[n/2][n-1];
p4 = new int[n/2][n-1];
p5 = new int[n/2][n-1];
p6 = new int[n/2][n-1];
p7 = new int[n/2][n-1];
t0 = new int[n/2][n-1];
t1 = new int[n/2][n-1];
}
mult(c, a, b, 0, 0, n, 0);
return c;
}
public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
if(n == 1) {
c[i0][j0] = a[i0][j0] * b[i0][j0];
} else {
final int nBy2 = n/2;
final int i1 = i0 + nBy2;
final int j1 = j0 + nBy2;
// offset applied to 'p' j index so recursive calls don't overwrite data
final int jp0 = offs;
final int jp1 = nBy2 + offs;
// P1 <- (A11 + A22)(B11 + B22)
// T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
}
}
mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P2 <- (A21 + A22)B11
// T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i0][j + j0];
}
}
mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P3 <- A11(B12 - B22)
// T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0];
t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
}
}
mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P4 <- A22(B21 - B11)
// T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
}
}
mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P5 <- (A11 + A12) B22
// T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j1];
}
}
mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P6 <- (A21 - A11)(B11 - B12)
// T0 <- (A21 - A11), T1 <- (B11 - B12), P6 <- T0 * T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
t1[i + i0][j + jp0] = b[i + i0][j + j0] - b[i + i0][j + j1];
}
}
mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P7 <- (A12 - A22)(B21 + B22)
// T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
}
}
mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);
// combine
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
// C11 = P1 + P4 - P5 + P7;
c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
// C12 = P3 + P5;
c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
// C21 = P2 + P4;
c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
// C22 = P1 + P3 - P2 + P6;
c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
}
}
}
}
void dumpInternal () {
System.out.println("P1");
TestIntMatrixMultiplication.dumpMatrix(p1);
System.out.println("P2");
TestIntMatrixMultiplication.dumpMatrix(p2);
System.out.println("P3");
TestIntMatrixMultiplication.dumpMatrix(p3);
System.out.println("P4");
TestIntMatrixMultiplication.dumpMatrix(p4);
System.out.println("P5");
TestIntMatrixMultiplication.dumpMatrix(p5);
System.out.println("P6");
TestIntMatrixMultiplication.dumpMatrix(p6);
System.out.println("P7");
TestIntMatrixMultiplication.dumpMatrix(p7);
System.out.println("T0");
TestIntMatrixMultiplication.dumpMatrix(t0);
System.out.println("T1");
TestIntMatrixMultiplication.dumpMatrix(t1);
}
}
class Classical{
public String getName () {
return "classic";
}
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
int n = a.length;
for(int i=0; i<n; i++) {
final int[] a_i = a[i];
final int[] c_i = c[i];
for(int j=0; j<n; j++) {
int sum = 0;
for(int k=0; k<n; k++) {
sum += a_i[k] * b[k][j];
}
c_i[j] = sum;
}
}
return c;
}
}
I copied strassen's algorithm from somewhere and then executed it. Here is the output
n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms
where strassen1
is a dynamic approach, strassen2
for cache and classical
is the old matrix multiplication. This means that our old and easy classical one is the best. Is this true or i am wrong somewhere? Here's the code in Java.
import java.util.Random;
class TestIntMatrixMultiplication {
public static void main (String...args) throws Exception {
final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
final Random random = new Random(seed);
int[][] a, b, c;
a = new int[n][n];
b = new int[n][n];
c = new int[n][n];
for(int i=0; i<n; i++) {
for(int j=0; j<n; j++) {
a[i][j] = random.nextInt(100);
b[i][j] = random.nextInt(100);
}
}
System.out.println("n = " + n);
if (a.length < 64) {
System.out.println("A");
dumpMatrix(a);
System.out.println("B");
dumpMatrix(b);
System.out.println("classic");
Classical.mult(c, a, b);
dumpMatrix(c);
System.out.println("strassen");
strassen2.mult(c, a, b);
dumpMatrix(c);
return;
}
for (int i = 0; i <3; ++i) {
timeMultiplies1(a, b, c);
if (n <= 256)
timeMultiplies2( a, b, c);
timeMultiplies3( a, b, c);
}
}
static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
Classical.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("classical took " + (finish - start) + "ms");
}
static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
strassen1.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("strassen 1 took " + (finish - start) + "ms");
}
static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
strassen2.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("strassen2 took " + (finish - start) + "ms");
}
static void dumpMatrix (int[][] m) {
for (int[] row : m) {
System.out.print("[\t");
for (int val : row) {
System.out.print(val);
System.out.print('\t');
}
System.out.println(']');
}
}
}
class strassen1{
public String getName () {
return "Strassen(dynamic)";
}
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
return strassenMatrixMultiplication(a, b);
}
public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
if(n == 1) {
result[0][0] = A[0][0] * B[0][0];
} else {
int [][] A11 = new int[n/2][n/2];
int [][] A12 = new int[n/2][n/2];
int [][] A21 = new int[n/2][n/2];
int [][] A22 = new int[n/2][n/2];
int [][] B11 = new int[n/2][n/2];
int [][] B12 = new int[n/2][n/2];
int [][] B21 = new int[n/2][n/2];
int [][] B22 = new int[n/2][n/2];
divideArray(A, A11, 0 , 0);
divideArray(A, A12, 0 , n/2);
divideArray(A, A21, n/2, 0);
divideArray(A, A22, n/2, n/2);
divideArray(B, B11, 0 , 0);
divideArray(B, B12, 0 , n/2);
divideArray(B, B21, n/2, 0);
divideArray(B, B22, n/2, n/2);
int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));
int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
int [][] C12 = addMatrices(P3, P5);
int [][] C21 = addMatrices(P2, P4);
int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);
copySubArray(C11, result, 0 , 0);
copySubArray(C12, result, 0 , n/2);
copySubArray(C21, result, n/2, 0);
copySubArray(C22, result, n/2, n/2);
}
return result;
}
public static int [][] addMatrices(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
for(int i=0; i<n; i++)
for(int j=0; j<n; j++)
result[i][j] = A[i][j] + B[i][j];
return result;
}
public static int [][] subtractMatrices(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
for(int i=0; i<n; i++)
for(int j=0; j<n; j++)
result[i][j] = A[i][j] - B[i][j];
return result;
}
public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
child[i1][j1] = parent[i2][j2];
}
public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
parent[i2][j2] = child[i1][j1];
}
}
class strassen2{
public String getName () {
return "Strassen(cached)";
}
static int [][] p1;
static int [][] p2;
static int [][] p3;
static int [][] p4;
static int [][] p5;
static int [][] p6;
static int [][] p7;
static int [][] t0;
static int [][] t1;
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
final int n = c.length;
if (p1 == null || p1.length < n) {
p1 = new int[n/2][n-1];
p2 = new int[n/2][n-1];
p3 = new int[n/2][n-1];
p4 = new int[n/2][n-1];
p5 = new int[n/2][n-1];
p6 = new int[n/2][n-1];
p7 = new int[n/2][n-1];
t0 = new int[n/2][n-1];
t1 = new int[n/2][n-1];
}
mult(c, a, b, 0, 0, n, 0);
return c;
}
public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
if(n == 1) {
c[i0][j0] = a[i0][j0] * b[i0][j0];
} else {
final int nBy2 = n/2;
final int i1 = i0 + nBy2;
final int j1 = j0 + nBy2;
// offset applied to 'p' j index so recursive calls don't overwrite data
final int jp0 = offs;
final int jp1 = nBy2 + offs;
// P1 <- (A11 + A22)(B11 + B22)
// T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
}
}
mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P2 <- (A21 + A22)B11
// T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i0][j + j0];
}
}
mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P3 <- A11(B12 - B22)
// T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0];
t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
}
}
mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P4 <- A22(B21 - B11)
// T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
}
}
mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P5 <- (A11 + A12) B22
// T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j1];
}
}
mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P6 <- (A21 - A11)(B11 - B12)
// T0 <- (A21 - A11), T1 <- (B11 - B12), P6 <- T0 * T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
t1[i + i0][j + jp0] = b[i + i0][j + j0] - b[i + i0][j + j1];
}
}
mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P7 <- (A12 - A22)(B21 + B22)
// T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
}
}
mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);
// combine
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
// C11 = P1 + P4 - P5 + P7;
c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
// C12 = P3 + P5;
c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
// C21 = P2 + P4;
c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
// C22 = P1 + P3 - P2 + P6;
c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
}
}
}
}
void dumpInternal () {
System.out.println("P1");
TestIntMatrixMultiplication.dumpMatrix(p1);
System.out.println("P2");
TestIntMatrixMultiplication.dumpMatrix(p2);
System.out.println("P3");
TestIntMatrixMultiplication.dumpMatrix(p3);
System.out.println("P4");
TestIntMatrixMultiplication.dumpMatrix(p4);
System.out.println("P5");
TestIntMatrixMultiplication.dumpMatrix(p5);
System.out.println("P6");
TestIntMatrixMultiplication.dumpMatrix(p6);
System.out.println("P7");
TestIntMatrixMultiplication.dumpMatrix(p7);
System.out.println("T0");
TestIntMatrixMultiplication.dumpMatrix(t0);
System.out.println("T1");
TestIntMatrixMultiplication.dumpMatrix(t1);
}
}
class Classical{
public String getName () {
return "classic";
}
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
int n = a.length;
for(int i=0; i<n; i++) {
final int[] a_i = a[i];
final int[] c_i = c[i];
for(int j=0; j<n; j++) {
int sum = 0;
for(int k=0; k<n; k++) {
sum += a_i[k] * b[k][j];
}
c_i[j] = sum;
}
}
return c;
}
}
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
data:image/s3,"s3://crabby-images/d5906/d59060df4059a6cc364216c4d63ceec29ef7fe66" alt="扫码二维码加入Web技术交流群"
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(3)
我看到的问题:
1)你的施特拉森乘法始终动态分配内存。这会影响性能。
2)您的施特拉森乘法应该切换到小尺寸的传统乘法,而不是一直递归(尽管这种优化会使您的测试无效)。
3)您的矩阵大小可能太小而无法看到差异。
您应该与几种不同尺寸进行比较。也许是 256、512、1024、2048、4096、8192...然后绘制时间并查看趋势。如果矩阵大小是 2 的所有幂,您可能需要对数尺度的矩阵大小。Strassen
仅对于较大的 N 才更快。大小很大程度上取决于实现。您对经典所做的只是一个基本实现,在现代机器上也不是最佳的。
Issues I see:
1)Your Strassen multiply is dynamically allocating memory all the time. This is going to kill performance.
2)Your Strassen multiply should switch over to conventional multiply for small sizes rather than being recursive all the way down (though this optimization sort of invalidates your test).
3)You're matrix size may simply be too small to see the difference.
You should do comparisons with several different sizes. Perhaps 256, 512, 1024, 2048, 4096, 8192... Then plot the times and look at the trends. You will probably want matrix size on a log scale if it's all powers of 2.
Strassen is only faster for large N. How large will depend a lot on the implementation. What you have done for classical is only a basic implementation and is not optimal on a modern machine either.
除了实现问题之外,我认为您误解了算法的性能。正如 Phkahler 所说,您对算法性能的期望有点偏差。分而治之算法适用于大型输入,因为它们递归地将问题分解为可以更快解决的子问题。
然而,对于小型甚至中等大小的输入,与此拆分操作相关的开销可能会导致算法运行速度(有时慢得多)。通常,像 Strassen 这样的算法的理论分析将包括所谓的“断点”计算。这是分割开销比简单技术更可取的输入大小。
您的代码需要包含对在断点处切换到简单技术的输入大小的检查。
Implementation questions aside, I think you're misunderstanding the algorithm's performance. Like phkahler said, your expectations are a little off for the performance of the algorithm. Divide-and-conquer algorithms work well for large inputs because they recursively break the problem into sub-problems which can be solved more quickly.
However, the overhead associated with this splitting action can cause the algorithm to run (sometimes much) slower for small or even medium-sized inputs. Typically, the theoretical analysis of an algorithm like Strassen will include a so-called "breakpoint" calculation. This is the input size where the overhead of splitting becomes preferable to a naive technique.
Your code needs to include a check on the size of the input that switches to the naive technique at the breakpoint.
写下 Strassen 算法对 2 x 2 矩阵的作用。计算操作数。这个数字绝对是荒谬的。对 2x2 矩阵使用 Strassen 方法是愚蠢的。对于 3 x 3 或 4 x 4 矩阵来说也是如此,而且可能要高得多。
Write down what the Strassen algorithm does for a 2 x 2 matrix. Count the operations. The number is absolutely ridiculous. It's stupid to use Strassen's method for a 2x2 matrix. Same for a 3 x 3, or 4 x 4, matrix and probably quite a way up.