Skip to content

Commit

Permalink
fix hessian service replacement reused nodes mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Sep 16, 2024
1 parent ea7563d commit ecadf9c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,10 @@ def fetch_hessian(self,
f"{hessian_scores_request.target_nodes}.")

# Replace node in reused target nodes with a representing node from the 'reuse group'.
for n in hessian_scores_request.target_nodes:
if n.reuse_group:
rep_node = self._get_representing_of_reuse_group(n)
hessian_scores_request.target_nodes.remove(n)
if rep_node not in hessian_scores_request.target_nodes:
hessian_scores_request.target_nodes.append(rep_node)
hessian_scores_request.target_nodes = [
self._get_representing_of_reuse_group(node) if node.reuse else node
for node in hessian_scores_request.target_nodes
]

# Ensure the saved info has the required number of approximations
self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ def __init__(self):

def forward(self, inp):
x = self.conv1(inp)
x1 = self.bn1(x)
x1 = self.relu(x1)
x1 = self.relu(x)
x_split = torch.split(x1, split_size_or_sections=4, dim=-1)
x1 = self.conv1(x_split[0])
x2 = x_split[1]
Expand Down

0 comments on commit ecadf9c

Please sign in to comment.