Skip to content

Commit

Permalink
Add shape checking to all ntk measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 16, 2024
1 parent e387616 commit 22c6cb6
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions papyrus/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray:
np.ndarray
The trace of the NTK
"""
if ntk.shape[0] != ntk.shape[1]:
raise ValueError(
"To compute the trace of the NTK, the NTK matrix must"
f" be a square matrix, but got a matrix of shape {ntk.shape}."
)
if len(ntk.shape) != 2:
raise ValueError(
"To compute the trace of the NTK, the NTK matrix must"
f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}."
)
return compute_trace(ntk, normalize=self.normalise)


Expand Down Expand Up @@ -339,11 +349,14 @@ def apply(self, ntk: np.ndarray) -> np.ndarray:
"""
# Assert that the NTK is a square matrix
if ntk.shape[0] != ntk.shape[1]:
raise ValueError("The NTK matrix must be a square matrix.")
raise ValueError(
"To compute the entropy of the NTK, the NTK matrix must"
f" be a square matrix, but got a matrix of shape {ntk.shape}."
)
if len(ntk.shape) != 2:
raise ValueError(
"The NTK matrix must be a tensor of rank 2, but got a tensor of rank"
f" {len(ntk.shape)}."
"To compute the entropy of the NTK, the NTK matrix must"
f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}."
)
# Compute the von Neumann entropy of the NTK
return compute_von_neumann_entropy(
Expand Down Expand Up @@ -408,6 +421,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray:
np.ndarray
Self-entropy of the NTK.
"""
if ntk.shape[0] != ntk.shape[1]:
raise ValueError(
"To compute the self-entropy of the NTK, the NTK matrix must"
f" be a square matrix, but got a matrix of shape {ntk.shape}."
)
if len(ntk.shape) != 2:
raise ValueError(
"To compute the self-entropy of the NTK, the NTK matrix must"
f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}."
)
distribution = compute_grammian_diagonal_distribution(gram_matrix=ntk)
return compute_shannon_entropy(distribution, effective=self.effective)

Expand Down Expand Up @@ -470,6 +493,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray:
np.ndarray
The magnitude distribution of the NTK
"""
if ntk.shape[0] != ntk.shape[1]:
raise ValueError(
"To compute the magnitude distribution of the NTK, the NTK matrix must"
f" be a square matrix, but got a matrix of shape {ntk.shape}."
)
if len(ntk.shape) != 2:
raise ValueError(
"To compute the magnitude distribution of the NTK, the NTK matrix must"
f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}."
)
return compute_grammian_diagonal_distribution(gram_matrix=ntk)


Expand Down Expand Up @@ -535,6 +568,16 @@ def apply(self, ntk: np.ndarray) -> np.ndarray:
np.ndarray
The eigenvalues of the NTK
"""
if ntk.shape[0] != ntk.shape[1]:
raise ValueError(
"To compute the eigenvalues of the NTK, the NTK matrix must"
f" be a square matrix, but got a matrix of shape {ntk.shape}."
)
if len(ntk.shape) != 2:
raise ValueError(
"To compute the eigenvalues of the NTK, the NTK matrix must"
f" be a tensor of rank 2, but got a tensor of rank {len(ntk.shape)}."
)
return compute_hermitian_eigensystem(ntk, normalize=self.normalize)[0]


Expand Down

0 comments on commit 22c6cb6

Please sign in to comment.