Skip to content

Commit 2dfad7e

Browse files
Add matrix multiplication with double[][] and unit tests (#6417)
* MatrixMultiplication.java created and updated. * Add necessary comment to MatrixMultiplication.java * Create MatrixMultiplicationTest.java * method for 2 by 2 matrix multiplication is created * Use assertMatrixEquals(), otherwise there can be error due to floating point arithmetic errors * assertMatrixEquals method created and updated * method created for 3by2 matrix multiply with 2by1 matrix * method created for null matrix multiplication * method for test matrix dimension error * method for test empty matrix input * testMultiply3by2and2by1 test case updated * Check for empty matrices part updated * Updated Unit test coverage * files updated * clean the code * clean the code * Updated files with google-java-format * Updated files * Updated files * Updated files * Updated files * Add reference links and complexities * Add test cases for 1by1 matrix and non-rectangular matrix * Add reference links and complexities --------- Co-authored-by: Deniz Altunkapan <[email protected]>
1 parent 2722b0e commit 2dfad7e

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package com.thealgorithms.matrix;
2+
3+
/**
4+
* This class provides a method to perform matrix multiplication.
5+
*
6+
* <p>Matrix multiplication takes two 2D arrays (matrices) as input and
7+
* produces their product, following the mathematical definition of
8+
* matrix multiplication.
9+
*
10+
* <p>For more details:
11+
* https://www.geeksforgeeks.org/java/java-program-to-multiply-two-matrices-of-any-size/
12+
* https://en.wikipedia.org/wiki/Matrix_multiplication
13+
*
14+
* <p>Time Complexity: O(n^3) – where n is the dimension of the matrices
15+
* (assuming square matrices for simplicity).
16+
*
17+
* <p>Space Complexity: O(n^2) – for storing the result matrix.
18+
*
19+
*
20+
* @author Nishitha Wihala Pitigala
21+
*
22+
*/
23+
24+
public final class MatrixMultiplication {
25+
private MatrixMultiplication() {
26+
}
27+
28+
/**
29+
* Multiplies two matrices.
30+
*
31+
* @param matrixA the first matrix rowsA x colsA
32+
* @param matrixB the second matrix rowsB x colsB
33+
* @return the product of the two matrices rowsA x colsB
34+
* @throws IllegalArgumentException if the matrices cannot be multiplied
35+
*/
36+
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
37+
// Check the input matrices are not null
38+
if (matrixA == null || matrixB == null) {
39+
throw new IllegalArgumentException("Input matrices cannot be null");
40+
}
41+
42+
// Check for empty matrices
43+
if (matrixA.length == 0 || matrixB.length == 0 || matrixA[0].length == 0 || matrixB[0].length == 0) {
44+
throw new IllegalArgumentException("Input matrices must not be empty");
45+
}
46+
47+
// Validate the matrix dimensions
48+
if (matrixA[0].length != matrixB.length) {
49+
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
50+
}
51+
52+
int rowsA = matrixA.length;
53+
int colsA = matrixA[0].length;
54+
int colsB = matrixB[0].length;
55+
56+
// Initialize the result matrix with zeros
57+
double[][] result = new double[rowsA][colsB];
58+
59+
// Perform matrix multiplication
60+
for (int i = 0; i < rowsA; i++) {
61+
for (int j = 0; j < colsB; j++) {
62+
for (int k = 0; k < colsA; k++) {
63+
result[i][j] += matrixA[i][k] * matrixB[k][j];
64+
}
65+
}
66+
}
67+
return result;
68+
}
69+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package com.thealgorithms.matrix;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
import static org.junit.jupiter.api.Assertions.assertTrue;
6+
7+
import org.junit.jupiter.api.Test;
8+
9+
public class MatrixMultiplicationTest {
10+
11+
private static final double EPSILON = 1e-9; // for floating point comparison
12+
13+
@Test
14+
void testMultiply1by1() {
15+
double[][] matrixA = {{1.0}};
16+
double[][] matrixB = {{2.0}};
17+
double[][] expected = {{2.0}};
18+
19+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
20+
assertMatrixEquals(expected, result);
21+
}
22+
23+
@Test
24+
void testMultiply2by2() {
25+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
26+
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
27+
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};
28+
29+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
30+
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
31+
}
32+
33+
@Test
34+
void testMultiply3by2and2by1() {
35+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
36+
double[][] matrixB = {{7.0}, {8.0}};
37+
double[][] expected = {{23.0}, {53.0}, {83.0}};
38+
39+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
40+
assertMatrixEquals(expected, result);
41+
}
42+
43+
@Test
44+
void testMultiplyNonRectangularMatrices() {
45+
double[][] matrixA = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
46+
double[][] matrixB = {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}};
47+
double[][] expected = {{58.0, 64.0}, {139.0, 154.0}};
48+
49+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
50+
assertMatrixEquals(expected, result);
51+
}
52+
53+
@Test
54+
void testNullMatrixA() {
55+
double[][] b = {{1, 2}, {3, 4}};
56+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b));
57+
}
58+
59+
@Test
60+
void testNullMatrixB() {
61+
double[][] a = {{1, 2}, {3, 4}};
62+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null));
63+
}
64+
65+
@Test
66+
void testMultiplyNull() {
67+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
68+
double[][] matrixB = null;
69+
70+
Exception exception = assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB));
71+
72+
String expectedMessage = "Input matrices cannot be null";
73+
String actualMessage = exception.getMessage();
74+
75+
assertTrue(actualMessage.contains(expectedMessage));
76+
}
77+
78+
@Test
79+
void testIncompatibleDimensions() {
80+
double[][] a = {{1.0, 2.0}};
81+
double[][] b = {{1.0, 2.0}};
82+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
83+
}
84+
85+
@Test
86+
void testEmptyMatrices() {
87+
double[][] a = new double[0][0];
88+
double[][] b = new double[0][0];
89+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
90+
}
91+
92+
private void assertMatrixEquals(double[][] expected, double[][] actual) {
93+
assertEquals(expected.length, actual.length, "Row count mismatch");
94+
for (int i = 0; i < expected.length; i++) {
95+
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
96+
for (int j = 0; j < expected[i].length; j++) {
97+
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
98+
}
99+
}
100+
}
101+
}

0 commit comments

Comments
 (0)