Skip to content

Commit 115eb8d

Browse files
authored
feat(dipu): add schema: ones/zeros/zero_ (#742)
* add ones/zeros/zero_ * add tests * add schema * update * opt * update diopi main * for python black format
1 parent c2d8a3d commit 115eb8d

File tree

5 files changed

+69
-1
lines changed

5 files changed

+69
-1
lines changed

dipu/SupportedDiopiFunctions.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ diopiNormalInp
192192
diopiNormalScalarTensor
193193
diopiNormalTensor
194194
diopiNormalTensorScalar
195+
diopiOnes
195196
diopiPolar
196197
diopiPow
197198
diopiPowInp
@@ -250,3 +251,5 @@ diopiUpsampleLinearBackward
250251
diopiUpsampleNearest
251252
diopiUpsampleNearestBackward
252253
diopiWhere
254+
diopiZeroInp
255+
diopiZeros

dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,6 +2279,29 @@
22792279
- schema: "atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
22802280
interface: diopiAtan(ctx, out, self)
22812281

2282+
- schema: "ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)"
2283+
interface: diopiOnes(ctx, out, size)
2284+
2285+
- schema: "ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"
2286+
custom_code_at_the_beginning: |
2287+
c10::TensorOptions option;
2288+
auto shape = c10::asIntArrayRefUnchecked(size);
2289+
auto out = nodispatch::empty(shape, option.dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
2290+
interface: diopiOnes(ctx, out, size)
2291+
2292+
- schema: "zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"
2293+
custom_code_at_the_beginning: |
2294+
c10::TensorOptions option;
2295+
auto shape = c10::asIntArrayRefUnchecked(size);
2296+
auto out = nodispatch::empty(shape, option.dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
2297+
interface: diopiZeroInp(ctx, out)
2298+
2299+
- schema: "zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)"
2300+
interface: diopiZeros(ctx, out, size)
2301+
2302+
- schema: "zero_(Tensor(a!) self) -> Tensor(a!)"
2303+
interface: diopiZeroInp(ctx, self)
2304+
22822305
- schema: "im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor"
22832306
size_attr: [kernel_size, stride, padding, dilation]
22842307
custom_code_at_the_beginning: |
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) 2024, DeepLink.
2+
import torch
3+
import torch_dipu
4+
from torch_dipu.testing._internal.common_utils import TestCase, run_tests
5+
6+
7+
class TestOnes(TestCase):
8+
def test_ones(self):
9+
device = torch.device("dipu")
10+
size = [5, 6]
11+
x = torch.ones(size=size)
12+
y = torch.ones(size=size, device=device)
13+
self.assertEqual(x, y.cpu(), exact_dtype=True)
14+
15+
16+
if __name__ == "__main__":
17+
run_tests()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) 2024, DeepLink.
2+
import torch
3+
import torch_dipu
4+
from torch_dipu.testing._internal.common_utils import TestCase, run_tests
5+
6+
7+
class TestZeros(TestCase):
8+
def test_zeros(self):
9+
device = torch.device("dipu")
10+
size = [5, 6]
11+
x = torch.zeros(size=size)
12+
y = torch.zeros(size=size, device=device)
13+
self.assertEqual(x, y.cpu(), exact_dtype=True)
14+
15+
def test_zero_(self):
16+
size = [3, 5]
17+
y = torch.randn(size=size).cuda()
18+
x = y.cpu()
19+
x.zero_()
20+
y.zero_()
21+
self.assertEqual(x, y.cpu(), exact_dtype=True)
22+
23+
24+
if __name__ == "__main__":
25+
run_tests()

0 commit comments

Comments
 (0)