-
Notifications
You must be signed in to change notification settings - Fork 474
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: added range onnx import #1834
Conversation
62ae485
to
dfd0290
Compare
dfd0290
to
99f0f56
Compare
a8787df
to
d1f883a
Compare
d1f883a
to
9dcc11a
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1834 +/- ##
==========================================
+ Coverage 86.42% 86.44% +0.02%
==========================================
Files 761 762 +1
Lines 87987 88111 +124
==========================================
+ Hits 76041 76166 +125
+ Misses 11946 11945 -1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good job! 🙂
I have one minor request over form when checking the types, otherwise looks good to me.
let start = match &self.start { | ||
Type::Scalar(s) => { | ||
let name = s.name.clone(); | ||
quote! { #name } | ||
} | ||
_ => panic!("Start must be a scalar"), | ||
}; | ||
|
||
let end = match &self.end { | ||
Type::Scalar(s) => { | ||
let name = s.name.clone(); | ||
quote! { #name } | ||
} | ||
_ => panic!("End must be a scalar"), | ||
}; | ||
|
||
let step = match &self.step { | ||
Type::Scalar(s) => { | ||
let name = s.name.clone(); | ||
quote! { #name } | ||
} | ||
_ => panic!("Step must be a scalar"), | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the three fields are all supposed to be scalar, I suggest we change the node fields to:
pub struct RangeNode {
pub start: ScalarType,
pub end: ScalarType,
pub step: ScalarType,
pub output: TensorType,
}
and the logic to check the extract the scalar types can go in the range_conversion
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well yes, scalars are simply represented as 0-dim tensors. Basically what I am suggesting is just to move the type handling outside of the forward node method to the range_conversion
function instead, a bit like this:
fn range_conversion(node: Node) -> RangeNode {
let output = node.outputs.first().unwrap().to_tensor_type();
let start = node.inputs.first().unwrap().to_type();
let end = node.inputs.get(1).unwrap().to_type();
let step = node.inputs.get(2).unwrap().to_type();
// Check/convert start, end & step to ScalarType just as it was done in the forward method
// ...
RangeNode::new(start, end, step, output)
}
Should be good |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! 🚀
00f70a9
to
0a2530c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thank you for your contribution!
* feat: added range onnx import * fix: range input types
Checklist
run-checks all
script has been executed.Changes
Add the range import to onnx
Testing
cargo test range
Related Issue
#1714