[Torch] add aten.bilinear op decomposing #3931
base: main
Conversation
Hi @vivekkhandelwal1. I don't have permission to request review yet. Do you mind taking a look at this change? Thanks!
Hi @vivekkhandelwal1, a gentle reminder to review this pull request when you get a chance. Thank you!
@dixinzhou Is this PR still active and in need of a review?
Hi @vivekkhandelwal1, thanks for the reply. Yes, this PR still needs to be reviewed and upstreamed.
// Compute the broadcasted output shape for the element-wise multiply.
// If either dimension size is unknown, the result size is unknown;
// otherwise apply the broadcast rule: a size-1 dimension adopts the
// other operand's size.
int64_t size1 = inputSize1[i];
int64_t size2 = inputSize2[i];
if (size1 == kUnknownSize || size2 == kUnknownSize) {
  mulShape.push_back(kUnknownSize);
} else {
  mulShape.push_back(size1 == 1 ? size2 : size1);
}
Can you please explain this? I would suggest adding a comment.
Comment added in the code.
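For readers following along, here is a minimal standalone sketch of the broadcast rule the snippet applies. The function name broadcastShapes and the kDynamic sentinel are illustrative stand-ins (assumptions), not torch-mlir utilities.

// Minimal sketch (not the torch-mlir code itself) of the broadcast rule
// used above. kDynamic stands in for kUnknownSize.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // unknown/dynamic dimension size

// Broadcast two equal-rank shapes: an unknown input size yields an
// unknown output size; otherwise a size-1 dimension adopts the other size.
std::vector<int64_t> broadcastShapes(const std::vector<int64_t> &a,
                                     const std::vector<int64_t> &b) {
  assert(a.size() == b.size() && "ranks must already match");
  std::vector<int64_t> result;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == kDynamic || b[i] == kDynamic)
      result.push_back(kDynamic);
    else
      result.push_back(a[i] == 1 ? b[i] : a[i]);
  }
  return result;
}

// Example: {1, 4, kDynamic} broadcast with {3, 4, 5} gives {3, 4, kDynamic}.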
if (isa<Torch::NoneType>(bias.getType())) {
  rewriter.replaceOp(op, trilinear);
  return success();
} else {
Else is not required since the if ends with a return.
The else is removed, and feedback resolved.
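For clarity, the structure after dropping the redundant else would look roughly like this (a sketch, not the exact diff):

if (isa<Torch::NoneType>(bias.getType())) {
  // No bias: the trilinear result is the final value.
  rewriter.replaceOp(op, trilinear);
  return success();
}
// Bias present: fall through and add it to the trilinear result.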
// Generate the `aten._trilinear` op that implements the bilinear product.
// With n = rank(input1) - 1 (the index of the feature dimension):
//   - input1 [..., in1] is unsqueezed at dims [n, n+2]   -> [..., 1, in1, 1]
//   - weight [out, in1, in2] is unsqueezed at dims [0..n-1]
//                                                        -> [1, ..., out, in1, in2]
//   - input2 [..., in2] is unsqueezed at dims [n, n+1]   -> [..., 1, 1, in2]
// The broadcasted product [..., out, in1, in2] is then summed over dims
// [n+1, n+2], yielding the bilinear result of shape [..., out].
unsigned n = inputType1.getSizes().size() - 1;
Type listOfInt =
    rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
Value zero =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n));
Value one =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 1));
Value two =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 2));
// expand1 = [n, n+2]: unsqueeze positions for input1.
Value expand1 = rewriter.create<PrimListConstructOp>(
    loc, listOfInt, SmallVector<Value>{zero, two});
// expand2 = [n, n+1]: unsqueeze positions for input2.
Value expand2 = rewriter.create<PrimListConstructOp>(
    loc, listOfInt, SmallVector<Value>{zero, one});
// expandw = [0, 1, ..., n-1]: unsqueeze positions for the weight.
SmallVector<Value> expandWeightValue;
for (unsigned i = 0; i < n; i++) {
  Value value =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
  expandWeightValue.push_back(value);
}
Value expandw =
    rewriter.create<PrimListConstructOp>(loc, listOfInt, expandWeightValue);
// sumdim = [n+1, n+2]: reduce over both feature dimensions.
Value sumdim = rewriter.create<PrimListConstructOp>(
    loc, listOfInt, SmallVector<Value>{one, two});
Value constOne =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value trilinear = rewriter.create<Aten_TrilinearOp>(
    loc, op.getType(), input1, weight, input2, expand1, expandw, expand2,
    sumdim, constOne);
I did not understand this part. Can you please add a comment explaining this?
Comment added in the code.
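To make the construction above concrete, here is a small self-contained C++ sketch (a toy program of my own, not part of the patch) that prints the dimension lists the decomposition builds for a given input rank:

// Illustrative only: prints the dimension lists fed to aten._trilinear
// for an input1 of the given rank, where n = rank - 1.
#include <cstdio>
#include <vector>

int main() {
  const int rank = 3;     // e.g. input1 of shape [B, T, in1_features]
  const int n = rank - 1; // index of the feature dimension

  std::vector<int> expand1 = {n, n + 2}; // unsqueeze dims for input1
  std::vector<int> expand2 = {n, n + 1}; // unsqueeze dims for input2
  std::vector<int> expandw;              // unsqueeze dims for the weight
  for (int i = 0; i < n; ++i)
    expandw.push_back(i);
  std::vector<int> sumdim = {n + 1, n + 2}; // dims reduced after the product

  auto print = [](const char *name, const std::vector<int> &v) {
    std::printf("%s = [", name);
    for (std::size_t i = 0; i < v.size(); ++i)
      std::printf("%s%d", i ? ", " : "", v[i]);
    std::printf("]\n");
  };
  print("expand1", expand1);
  print("expandw", expandw);
  print("expand2", expand2);
  print("sumdim", sumdim);
  return 0;
}

For rank = 3 this prints expand1 = [2, 4], expandw = [0, 1], expand2 = [2, 3], and sumdim = [3, 4], matching the values the rewriter constructs above.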
Hi @vivekkhandelwal1. I have updated the pull request per your review feedback. Do you mind taking another look at this change? Thanks!
This PR adds support for the aten.bilinear op. The aten.bilinear op is decomposed into aten._trilinear and aten.add according to https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Linear.cpp#L712. Additionally, this PR fixes a bug in the aten._trilinear op decomposition related to a tensor shape mismatch.
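As an informal sanity check of the decomposition described above, the following self-contained C++ sketch (toy shapes and naming are my own assumptions, not the torch-mlir code) computes the same quantity: an aten._trilinear-style reduction over the two feature dimensions followed by an aten.add-style bias add.

// Illustrative only:
//   out[b][o] = sum_{i,j} x1[b][i] * W[o][i][j] * x2[b][j] + bias[o]
// i.e. a trilinear reduction followed by a bias add.
#include <cstdio>
#include <vector>

int main() {
  const int B = 2, I1 = 3, I2 = 2, O = 2;
  std::vector<double> x1(B * I1, 1.0);     // input1: [B, I1]
  std::vector<double> x2(B * I2, 2.0);     // input2: [B, I2]
  std::vector<double> w(O * I1 * I2, 0.5); // weight: [O, I1, I2]
  std::vector<double> bias = {0.1, -0.1};  // bias:   [O]

  std::vector<double> out(B * O, 0.0);
  for (int b = 0; b < B; ++b)
    for (int o = 0; o < O; ++o) {
      // Trilinear part: reduce over both feature dimensions.
      for (int i = 0; i < I1; ++i)
        for (int j = 0; j < I2; ++j)
          out[b * O + o] +=
              x1[b * I1 + i] * w[(o * I1 + i) * I2 + j] * x2[b * I2 + j];
      // Bias add (skipped when bias is None in the decomposition).
      out[b * O + o] += bias[o];
    }

  for (int b = 0; b < B; ++b)
    for (int o = 0; o < O; ++o)
      std::printf("out[%d][%d] = %.2f\n", b, o, out[b * O + o]);
  return 0;
}

Up to floating-point rounding, this reproduces what torch.nn.functional.bilinear computes for the same values.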