diff --git a/src/main/java/com/thealgorithms/matrix/MatrixMultiplication.java b/src/main/java/com/thealgorithms/matrix/MatrixMultiplication.java new file mode 100644 index 000000000000..6467a438577b --- /dev/null +++ b/src/main/java/com/thealgorithms/matrix/MatrixMultiplication.java @@ -0,0 +1,69 @@ +package com.thealgorithms.matrix; + +/** + * This class provides a method to perform matrix multiplication. + * + *

Matrix multiplication takes two 2D arrays (matrices) as input and + * produces their product, following the mathematical definition of + * matrix multiplication. + * + *

For more details: + * https://www.geeksforgeeks.org/java/java-program-to-multiply-two-matrices-of-any-size/ + * https://en.wikipedia.org/wiki/Matrix_multiplication + * + *

Time Complexity: O(n^3) – where n is the dimension of the matrices + * (assuming square matrices for simplicity). + * + *

Space Complexity: O(n^2) – for storing the result matrix. + * + * + * @author Nishitha Wihala Pitigala + * + */ + +public final class MatrixMultiplication { + private MatrixMultiplication() { + } + + /** + * Multiplies two matrices. + * + * @param matrixA the first matrix rowsA x colsA + * @param matrixB the second matrix rowsB x colsB + * @return the product of the two matrices rowsA x colsB + * @throws IllegalArgumentException if the matrices cannot be multiplied + */ + public static double[][] multiply(double[][] matrixA, double[][] matrixB) { + // Check the input matrices are not null + if (matrixA == null || matrixB == null) { + throw new IllegalArgumentException("Input matrices cannot be null"); + } + + // Check for empty matrices + if (matrixA.length == 0 || matrixB.length == 0 || matrixA[0].length == 0 || matrixB[0].length == 0) { + throw new IllegalArgumentException("Input matrices must not be empty"); + } + + // Validate the matrix dimensions + if (matrixA[0].length != matrixB.length) { + throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions."); + } + + int rowsA = matrixA.length; + int colsA = matrixA[0].length; + int colsB = matrixB[0].length; + + // Initialize the result matrix with zeros + double[][] result = new double[rowsA][colsB]; + + // Perform matrix multiplication + for (int i = 0; i < rowsA; i++) { + for (int j = 0; j < colsB; j++) { + for (int k = 0; k < colsA; k++) { + result[i][j] += matrixA[i][k] * matrixB[k][j]; + } + } + } + return result; + } +} diff --git a/src/test/java/com/thealgorithms/matrix/MatrixMultiplicationTest.java b/src/test/java/com/thealgorithms/matrix/MatrixMultiplicationTest.java new file mode 100644 index 000000000000..9463d33a18cb --- /dev/null +++ b/src/test/java/com/thealgorithms/matrix/MatrixMultiplicationTest.java @@ -0,0 +1,101 @@ +package com.thealgorithms.matrix; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class MatrixMultiplicationTest { + + private static final double EPSILON = 1e-9; // for floating point comparison + + @Test + void testMultiply1by1() { + double[][] matrixA = {{1.0}}; + double[][] matrixB = {{2.0}}; + double[][] expected = {{2.0}}; + + double[][] result = MatrixMultiplication.multiply(matrixA, matrixB); + assertMatrixEquals(expected, result); + } + + @Test + void testMultiply2by2() { + double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}}; + double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}}; + double[][] expected = {{19.0, 22.0}, {43.0, 50.0}}; + + double[][] result = MatrixMultiplication.multiply(matrixA, matrixB); + assertMatrixEquals(expected, result); // Use custom method due to floating point issues + } + + @Test + void testMultiply3by2and2by1() { + double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}; + double[][] matrixB = {{7.0}, {8.0}}; + double[][] expected = {{23.0}, {53.0}, {83.0}}; + + double[][] result = MatrixMultiplication.multiply(matrixA, matrixB); + assertMatrixEquals(expected, result); + } + + @Test + void testMultiplyNonRectangularMatrices() { + double[][] matrixA = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}; + double[][] matrixB = {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}}; + double[][] expected = {{58.0, 64.0}, {139.0, 154.0}}; + + double[][] result = MatrixMultiplication.multiply(matrixA, matrixB); + assertMatrixEquals(expected, result); + } + + @Test + void testNullMatrixA() { + double[][] b = {{1, 2}, {3, 4}}; + assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b)); + } + + @Test + void testNullMatrixB() { + double[][] a = {{1, 2}, {3, 4}}; + assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null)); + } + + @Test + void testMultiplyNull() { + double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}}; + double[][] matrixB = null; + + Exception exception = assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB)); + + String expectedMessage = "Input matrices cannot be null"; + String actualMessage = exception.getMessage(); + + assertTrue(actualMessage.contains(expectedMessage)); + } + + @Test + void testIncompatibleDimensions() { + double[][] a = {{1.0, 2.0}}; + double[][] b = {{1.0, 2.0}}; + assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b)); + } + + @Test + void testEmptyMatrices() { + double[][] a = new double[0][0]; + double[][] b = new double[0][0]; + assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b)); + } + + private void assertMatrixEquals(double[][] expected, double[][] actual) { + assertEquals(expected.length, actual.length, "Row count mismatch"); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i); + for (int j = 0; j < expected[i].length; j++) { + assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")"); + } + } + } +}