
Enable constant values as inputs to linear layers in PyTorch #1076

Merged
3 commits merged into sony:main on May 19, 2024

Conversation

lapid92 (Contributor) commented on May 19, 2024

Enable constant values as inputs to linear layers in PyTorch

Pull Request Description:

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All function and files are well documented.
  • All function and classes have type hints.
  • There is a licenses in all file.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unittest (if necessary).

@@ -42,7 +42,8 @@ def __init__(self,
                 reuse_group: str = None,
                 quantization_attr: Dict[str, Any] = None,
                 has_activation: bool = True,
-                is_custom: bool = False
+                is_custom: bool = False,
+                has_positional_weights: bool = False
Collaborator:
This is not needed. See below.

@@ -96,6 +99,15 @@ def get_has_activation(self):
        """
        return self.has_activation

+    def get_has_positional_weights(self):
Collaborator:

Change to has_positional_weights as a @property that checks whether there are positional weights (i.e., checks for integer keys in the weights dictionary).
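
A minimal sketch of the suggested property, assuming weights is a dict whose keys are either attribute names (str) or input positions (int), per this discussion:

    from typing import Any, Dict


    class BaseNode:
        def __init__(self, weights: Dict[Any, Any]):
            # Keys are assumed to be attribute names (str) or input
            # positions (int), per the review discussion above.
            self.weights = weights

        @property
        def has_positional_weights(self) -> bool:
            # Positional (constant-input) weights live under integer keys;
            # named weights such as 'weight' and 'bias' use string keys.
            return any(isinstance(k, int) for k in self.weights)

With this in place, call sites read n.has_positional_weights (as in the hunk below) instead of calling a getter.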

                                                  graph.get_in_stats_collector(n),
                                                  fw_impl=fw_impl)
+        if n.has_positional_weights:
+            for candidate_qc in n.candidates_quantization_cfg:
Collaborator:

Add comment
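
One possible comment for the new block (a sketch; the loop body is truncated in the hunk above, and the final wording is the author's call):

    if n.has_positional_weights:
        # Constant (positional) inputs need weight quantization too, so
        # visit every quantization candidate of this node and configure it.
        for candidate_qc in n.candidates_quantization_cfg:
            ...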

@@ -67,7 +67,8 @@ def substitute(self,
            return graph

        # Check if convolution and residual satisfy the collapsing conditions, otherwise skip substitution
-        if len(graph.get_next_nodes(first_node)) > 1 or len(graph.get_prev_nodes(second_node)) != 2:
+        if (len(graph.get_next_nodes(first_node)) > 1 or len(graph.get_prev_nodes(first_node)) < 1 or
Collaborator:

Isn't it clearer to write (not blabla == 1)?
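
A generic illustration of the point (hypothetical names, not the actual code above):

    def should_skip(nodes: list) -> bool:
        # Paired bound checks are easy to misread:
        #   return len(nodes) > 1 or len(nodes) < 1
        # The same intent reads more directly as a negated equality:
        return not len(nodes) == 1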

@@ -33,7 +33,7 @@ def node_builder(n: BaseNode) -> Module:

    framework_attr = copy.copy(n.framework_attr)
    node_instance = n.type(**framework_attr)
-    node_instance.load_state_dict({k: torch.tensor(v) for k, v in n.weights.items()}, strict=False)
+    node_instance.load_state_dict({k: torch.tensor(v) for k, v in n.weights.items() if isinstance(k, str)}, strict=False)
Collaborator:

Add a comment to explain why only str keys are used.
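
One way the requested comment might read (a sketch; the final wording is the author's call):

    # Only string keys name real module parameters/buffers that
    # load_state_dict can consume; integer keys hold positional (constant)
    # inputs, which are not part of the module's state dict.
    node_instance.load_state_dict(
        {k: torch.tensor(v) for k, v in n.weights.items() if isinstance(k, str)},
        strict=False)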

lapid92 merged commit ec5ef87 into sony:main on May 19, 2024
27 checks passed