diff --git a/papyrus/measurements/measurements.py b/papyrus/measurements/measurements.py index 72fee55..b1d5654 100644 --- a/papyrus/measurements/measurements.py +++ b/papyrus/measurements/measurements.py @@ -272,6 +272,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray: np.ndarray The trace of the NTK """ + if ntk.shape[0] != ntk.shape[1]: + raise ValueError( + "To compute the trace of the NTK, the NTK matrix must" + f" be a square matrix, but got a matrix of shape {ntk.shape}." + ) + if len(ntk.shape) != 2: + raise ValueError( + "To compute the trace of the NTK, the NTK matrix must" + f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}." + ) return compute_trace(ntk, normalize=self.normalise) @@ -339,11 +349,14 @@ def apply(self, ntk: np.ndarray) -> np.ndarray: """ # Assert that the NTK is a square matrix if ntk.shape[0] != ntk.shape[1]: - raise ValueError("The NTK matrix must be a square matrix.") + raise ValueError( + "To compute the entropy of the NTK, the NTK matrix must" + f" be a square matrix, but got a matrix of shape {ntk.shape}." + ) if len(ntk.shape) != 2: raise ValueError( - "The NTK matrix must be a tensor of rank 2, but got a tensor of rank" - f" {len(ntk.shape)}." + "To compute the entropy of the NTK, the NTK matrix must" + f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}." ) # Compute the von Neumann entropy of the NTK return compute_von_neumann_entropy( @@ -408,6 +421,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray: np.ndarray Self-entropy of the NTK. """ + if ntk.shape[0] != ntk.shape[1]: + raise ValueError( + "To compute the self-entropy of the NTK, the NTK matrix must" + f" be a square matrix, but got a matrix of shape {ntk.shape}." + ) + if len(ntk.shape) != 2: + raise ValueError( + "To compute the self-entropy of the NTK, the NTK matrix must" + f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}." + ) distribution = compute_grammian_diagonal_distribution(gram_matrix=ntk) return compute_shannon_entropy(distribution, effective=self.effective) @@ -470,6 +493,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray: np.ndarray The magnitude distribution of the NTK """ + if ntk.shape[0] != ntk.shape[1]: + raise ValueError( + "To compute the magnitude distribution of the NTK, the NTK matrix must" + f" be a square matrix, but got a matrix of shape {ntk.shape}." + ) + if len(ntk.shape) != 2: + raise ValueError( + "To compute the magnitude distribution of the NTK, the NTK matrix must" + f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}." + ) return compute_grammian_diagonal_distribution(gram_matrix=ntk) @@ -535,6 +568,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray: np.ndarray The eigenvalues of the NTK """ + if ntk.shape[0] != ntk.shape[1]: + raise ValueError( + "To compute the eigenvalues of the NTK, the NTK matrix must" + f" be a square matrix, but got a matrix of shape {ntk.shape}." + ) + if len(ntk.shape) != 2: + raise ValueError( + "To compute the eigenvalues of the NTK, the NTK matrix must" + f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}." + ) return compute_hermitian_eigensystem(ntk, normalize=self.normalize)[0]