Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix automatic conversion of constants to float32 and add support for index constant representation #1074

Merged
merged 2 commits into from
May 19, 2024

Conversation

lapid92
Copy link
Contributor

@lapid92 lapid92 commented May 19, 2024

Fix automatic conversion of constants to float32 and add support for index constant representation

Pull Request Description:

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All function and files are well documented.
  • All function and classes have type hints.
  • There is a licenses in all file.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unittest (if necessary).

@@ -38,11 +38,13 @@ def set_model(model: torch.nn.Module, train_mode: bool = False):
model.to(device)


def to_torch_tensor(tensor):
def to_torch_tensor(tensor,
torch_dtype=np.float32):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like it's numpy_type, no torch_dtype. Anyway, tou can just call it dtype

@lapid92 lapid92 merged commit 1f3ff26 into sony:main May 19, 2024
27 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants