1. 개요
행렬의 곱연산을 수행함에 있어 행렬이 커지면 커질수록 연산에 속도는 기하급수적으로 증가할 수밖에 없는 구조이다. 행렬의 곱은 각 원소를 곱한 후에 나온 결과를 더해 최종 행렬이 생성이 되며 행렬의 크기가 커질수록 곱하기 연산은 증가 할 수밖에 없다. cpu에 구조상 더하기 연산이 빠르기 때문에 스트라센 알고리즘에서는 곱하기 연산을 더하기 연산으로 치환하여 계산하도록 알고리즘을 보안했다.
2. 알고리즘 방법
아래 A, B 행렬과 두 행의 곱의 결과 C가 있다고 했을 때
일반적인 행렬의 곱은 다음과 같으며, 총 8번의 곱셉과 네번의 덧셈으로 연산된다.
스트라센에서 행렬의 곱셉을 더하기 연산으로 풀어 각 원소를 구할 수 있는 M이라는 연산 행렬로 표현한다.
이러한 M행렬은 일곱번의 곱셈과 10번의 덧셈으로 연산으로 나타낼 수 있으며 아래 와 같이 표현한다.
3. 결론
스트라센의 경우 행렬의 곱셉을 하기 위해서는 정사각행렬에 대해서만 처리가 가능하다. 그렇지 않을 경우에는 행렬을 정사각 행렬로 변경하는 작업이 필요하다. 또한, 특정 단계에서는 행렬의 곱이 더 빠른 구간이 있으며 스트라센 행렬에서는 최적의 행렬 크기에서는 일반곱으로 행렬을 풀어나가는 방법이 있다. 스트라센 알고리즘에서 또 눈여겨 볼 부분은 연산으로 사용하는 M행렬을 구하는 부분에서도 행렬의 곱이 쓰인다는 점이다. 행렬의 곱은 스트라센으로 풀어나가는 알고리즘이기 때문에 M1을 예로 들면 M1 := (A + A) strassen (A + B) 이런식으로 풀어 쓸수 있다. 결국에는 재귀적인 호출을 통해 스트라센을 구해 나가는 방식을 이용하는 알고리즘인것 이다. 분할 정복알고리즘과 동일하며, M에서는 각 행렬을 작은 단위로 분할하며 최종 C행렬을 구하기 위해서는 분할된 원소를 재조립하는 과정으로 최종 행렬을 얻어낼 수 있다.
4. 소스코드
public class Strassen {
public static void main(String[] args) {
int n = 1024;
int[][] x = initMetrix(n); int[][] y = initMetrix(n);
int[][] nomalResult = Strassen.metrixMul(n, x, y);
Strassen strassen = new Strassen();
int[][] strassenReslut = strassen.excuteStrassen(x, y);
boolean checkMetrix = true; for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (nomalResult[i][j] != strassenReslut[i][j]) { checkMetrix = false; } }
}
System.out.println("결과 : " + checkMetrix);
}
public static int[][] initMetrix(int n) {
Random r = new Random();
int[][] resultMetrix = new int [n][n];
for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { resultMetrix[i][j] = r.nextInt(30); } } return resultMetrix; }
public int[][] excuteStrassen(int[][] metrixX, int[][] metrixY) {
// 스트라센의 경우 n*n 행렬로 연산 int n = metrixX.length;
// 임계 차원 보다 작을 경우 기존 메트릭스 곱으로 풀이 if (n <= 2) { return metrixMul(n, metrixX, metrixY); }
// 4 등분 int rank = n / 2;
// 배열 분해 int[][] a11 = subMetrix(rank, 0, 0, metrixX); int[][] a12 = subMetrix(rank, 0, rank, metrixX); int[][] a21 = subMetrix(rank, rank, 0, metrixX); int[][] a22 = subMetrix(rank, rank, rank, metrixX); int[][] b11 = subMetrix(rank, 0, 0, metrixY); int[][] b12 = subMetrix(rank, 0, rank, metrixY); int[][] b21 = subMetrix(rank, rank, 0, metrixY); int[][] b22 = subMetrix(rank, rank, rank, metrixY);
int[][] m1 = excuteStrassen(metrixSum(a11, a22), metrixSum(b11, b22)); // m1=(a11+a11)(b11+b22) int[][] m2 = excuteStrassen(metrixSum(a21, a22), b11); // m2=(a21+a22)b11 int[][] m3 = excuteStrassen(a11, metrixSub(b12, b22)); // m3=a11(b12-b22) int[][] m4 = excuteStrassen(a22, metrixSub(b21, b11)); // m4=a22(b21-b11) int[][] m5 = excuteStrassen(metrixSum(a11, a12), b22); // m5=(a11+a12)b22 int[][] m6 = excuteStrassen(metrixSub(a21, a11), metrixSum(b11, b12)); // m6=(a21-a11)(b11+b12) int[][] m7 = excuteStrassen(metrixSub(a12, a22), metrixSum(b21, b22)); // m7=(a12-a22)(a21+b22)
// 결과 생성 int[][] c11 = metrixSum(metrixSub(metrixSum(m1, m4), m5), m7); // c11 = m1 + m4 - m5 + m7 int[][] c12 = metrixSum(m3, m5); // c12 = m3 + m5 int[][] c21 = metrixSum(m2, m4); // c21 = m2 + m4 int[][] c22 = metrixSum(metrixSub(metrixSum(m1, m3), m2), m6); // c22 = m1 + m3 - m2 + m6
// 결합 return combin(c11, c12, c21, c22); }
private int[][] combin(int[][] c11, int[][] c12, int[][] c21, int[][] c22) { int n = c11.length;
int[][] resultMetrix = new int [n*2][n*2];
for (int i = 0; i < n; i ++) { for (int j = 0; j < n; j++) { resultMetrix[i][j] = c11[i][j]; // 11 resultMetrix[i][j + n] = c12[i][j]; // 12 resultMetrix[i + n][j] = c21[i][j]; // 21 resultMetrix[i + n][j + n] = c22[i][j]; // 22 } } return resultMetrix; }
private int[][] subMetrix(int n, int startX, int startY, int[][] metrix) {
int[][] subMetirx = new int[n][n];
for (int i = 0, x = startX; i < n; i++, x++) { for (int j = 0, y = startY; j < n; j++, y++) { subMetirx[i][j] = metrix[x][y]; } } return subMetirx; }
private int[][] metrixSum(int[][] metrixX, int[][] metrixY) { int n = metrixX.length; int[][] metrixResult = new int[n][n];
for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { metrixResult[i][j] = metrixX[i][j] + metrixY[i][j]; } }
return metrixResult; }
private int[][] metrixSub(int[][] metrixX, int[][] metrixY) { int n = metrixX.length; int[][] metrixResult = new int[n][n];
for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { metrixResult[i][j] = metrixX[i][j] - metrixY[i][j]; } } return metrixResult; }
public static int[][] metrixMul(int n, int[][] metrixX, int[][] metrixY) {
int [][] result = new int[n][n];
for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { for (int k = 0; k < n; k++) { result[i][j] += metrixX[i][k] * metrixY[k][j]; } }
}
return result; } } |
5. 최종 결론
아마 제공된 코드를 수행하더라도 유익한 수행 시간을 얻어 낼 수 는 없을 것이다. 결국엔 자바에서나 C에서도 행렬을 구하는 과정에 산술연산이외에 추가로 필요한 작업들이 들어갔기 때문이다. 좀더 복잡한 과정을 통해 메모리 초기화 과정을 제거 할 수는 있으나 소스코드가 복잡해지기 때문에 이 쯤에서 멈추도록 하겠다.
'ALGORITHM' 카테고리의 다른 글
최장 증가 부분 수열(LIS) (0) | 2021.07.11 |
---|---|
[정렬] 머지정렬 (merge sort) (0) | 2016.11.02 |
[정렬] 퀵정렬 (quick sort) (0) | 2016.11.02 |
[정렬] 버블정렬 (Bubble sort) (0) | 2016.11.01 |
[정렬] 선택정렬 (selection Sort) (0) | 2016.11.01 |