[Torch] add aten.bilinear op decomposing #3931


Open
wants to merge 6 commits into base: main

Conversation

dixinzhou
Contributor

This PR adds support for the aten.bilinear op. The aten.bilinear op is decomposed into aten._trilinear and aten.add, following https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Linear.cpp#L712.

Additionally, this PR fixes a tensor-shape-mismatch bug in the aten._trilinear op decomposition.
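For intuition, here is a minimal pure-Python sketch of the semantics being decomposed (the function name and loop structure are illustrative only, not torch-mlir or PyTorch APIs): aten.bilinear computes `out[b][o] = sum_{i,j} x1[b][i] * weight[o][i][j] * x2[b][j]` plus an optional bias, which is exactly an aten._trilinear contraction followed by an aten.add.

```python
# Illustrative sketch only: aten.bilinear(x1, x2, weight, bias) computes a
# per-batch trilinear contraction plus an optional bias add, matching the
# _trilinear + add decomposition this PR generates.
def bilinear(x1, x2, weight, bias=None):
    """x1: [B, N1], x2: [B, N2], weight: [O, N1, N2], bias: [O] or None."""
    B, n1, n2, out = len(x1), len(x1[0]), len(x2[0]), len(weight)
    result = []
    for b in range(B):
        row = []
        for o in range(out):
            # This triple product is the aten._trilinear part ...
            acc = sum(x1[b][i] * weight[o][i][j] * x2[b][j]
                      for i in range(n1) for j in range(n2))
            # ... and this is the trailing aten.add of the bias.
            row.append(acc + (bias[o] if bias is not None else 0.0))
        result.append(row)
    return result


print(bilinear([[1.0, 2.0]], [[3.0, 4.0]], [[[1.0, 0.0], [0.0, 1.0]]], [5.0]))
# → [[16.0]]  (identity-like weight: 1*3 + 2*4 = 11, plus bias 5)
```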

@dixinzhou
Contributor Author

Hi @vivekkhandelwal1. I don't have permission to request review yet. Do you mind taking a look at this change? Thanks!

@dixinzhou
Contributor Author

Hi @vivekkhandelwal1, a gentle reminder to review this pull request when you get a chance. Thank you!

@vivekkhandelwal1
Collaborator

@dixinzhou Is this PR still active and in need of a review?

@dixinzhou
Contributor Author

Hi @vivekkhandelwal1, thanks for the reply. Yes, this PR still needs to be reviewed and upstreamed.

Comment on lines 2120 to 2126
int64_t size1 = inputSize1[i];
int64_t size2 = inputSize2[i];
if (size1 == kUnknownSize || size2 == kUnknownSize) {
  mulShape.push_back(kUnknownSize);
} else {
  mulShape.push_back(size1 == 1 ? size2 : size1);
}
Collaborator

Can you please explain this? I would suggest adding a comment.

Contributor Author

Comment added in the code.
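For readers of the diff, the fixed shape logic amounts to elementwise broadcasting of the two operand shapes, with unknown dimensions propagated. A minimal Python sketch, where `K_UNKNOWN` and the function name are illustrative stand-ins (not torch-mlir APIs) for the `kUnknownSize` sentinel and the surrounding helper:

```python
K_UNKNOWN = -1  # illustrative stand-in for torch-mlir's kUnknownSize sentinel

def broadcast_dim_shape(shape1, shape2):
    """Elementwise broadcast of two equal-rank shapes, mirroring the fixed
    mulShape computation: unknown dims stay unknown, and a size-1 dim yields
    the other operand's size."""
    out = []
    for size1, size2 in zip(shape1, shape2):
        if size1 == K_UNKNOWN or size2 == K_UNKNOWN:
            out.append(K_UNKNOWN)
        else:
            out.append(size2 if size1 == 1 else size1)
    return out


print(broadcast_dim_shape([1, 3, K_UNKNOWN], [5, 3, 7]))  # → [5, 3, -1]
```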

if (isa<Torch::NoneType>(bias.getType())) {
  rewriter.replaceOp(op, trilinear);
  return success();
} else {
Collaborator

The else is not required, since the if branch ends with a return.

Contributor Author

The else has been removed, and the feedback is resolved.

Comment on lines 7707 to 7735
// generate `aten._trilinear` op
unsigned n = inputType1.getSizes().size() - 1;
Type listOfInt =
    rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
Value zero =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n));
Value one =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 1));
Value two =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 2));
Value expand1 = rewriter.create<PrimListConstructOp>(
    loc, listOfInt, SmallVector<Value>{zero, two});
Value expand2 = rewriter.create<PrimListConstructOp>(
    loc, listOfInt, SmallVector<Value>{zero, one});
SmallVector<Value> expandWeightValue;
for (unsigned i = 0; i < n; i++) {
  Value value =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
  expandWeightValue.push_back(value);
}
Value expandw =
    rewriter.create<PrimListConstructOp>(loc, listOfInt, expandWeightValue);
Value sumdim = rewriter.create<PrimListConstructOp>(
    loc, listOfInt, SmallVector<Value>{one, two});
Value constOne =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value trilinear = rewriter.create<Aten_TrilinearOp>(
    loc, op.getType(), input1, weight, input2, expand1, expandw, expand2,
    sumdim, constOne);
Collaborator

I did not understand this part. Can you please add a comment explaining this?

Contributor Author

Comment added in the code.
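To summarize the snippet for readers of this thread: the rewriter builds the index lists that aten._trilinear expects. For inputs of rank n + 1, input1 is unsqueezed at dims [n, n + 2], the weight at dims [0 .. n - 1], input2 at dims [n, n + 1], and the broadcast product is summed over dims [n + 1, n + 2]. A small Python sketch of just this index bookkeeping (the function name is hypothetical):

```python
def trilinear_index_lists(n):
    """Mirror the expand/sum dimension lists the rewriter constructs for
    inputs of rank n + 1 (n leading dims plus one trailing feature dim)."""
    expand1 = [n, n + 2]      # unsqueeze dims for input1
    expandw = list(range(n))  # unsqueeze dims for weight
    expand2 = [n, n + 1]      # unsqueeze dims for input2
    sumdim = [n + 1, n + 2]   # dims reduced after the broadcast multiply
    return expand1, expandw, expand2, sumdim


# For 2-D inputs (n = 1) this reproduces the reference decomposition in
# PyTorch's Linear.cpp: _trilinear(i1, w, i2, [1, 3], [0], [1, 2], [2, 3]).
print(trilinear_index_lists(1))  # → ([1, 3], [0], [1, 2], [2, 3])
```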

@dixinzhou
Contributor Author

Hi @vivekkhandelwal1. I have updated the pull request per your review feedback. Do you mind taking another look at this change? Thanks!
