Skip to content

Commit f84d3f8

Browse files
authored
fix(dipu,python): print device index for cuda tensor (#755)
* fix(dipu,python): print device index for cuda tensor Example: ```python import torch import torch_dipu torch.cuda.set_device(0) a = torch.tensor([1, 2, 3]).cuda() print(a) ``` Expected output: ``` tensor([1, 2, 3], device='cuda:0') ``` Current output (before this commit): ``` tensor([1, 2, 3], device='cuda') ``` * test(dipu,python): enhance test_python_device.py
1 parent badbe46 commit f84d3f8

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

dipu/.clang-format

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ IncludeCategories:
1313
- Regex: '^("|<)Python\.h'
1414
Priority: 50
1515
CaseSensitive: false
16-
- Regex: '^("|<)(frameobject|structmember)\.h'
16+
- Regex: '^("|<)(descrobject|frameobject|object|structmember)\.h'
1717
Priority: 50
1818
SortPriority: 51
1919
CaseSensitive: false
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2024, DeepLink.
2+
import torch
3+
import torch_dipu
4+
from torch_dipu.testing._internal.common_utils import TestCase, run_tests
5+
6+
7+
class TestPythonDevice(TestCase):
8+
def test_cpu(self):
9+
a = torch.tensor([1, 2, 3])
10+
self.assertEqual(str(a.device), "cpu")
11+
self.assertEqual(repr(a.device), "device(type='cpu')")
12+
self.assertEqual(str(a), "tensor([1, 2, 3])")
13+
self.assertEqual(repr(a), "tensor([1, 2, 3])")
14+
15+
def test_cuda(self):
16+
device_index = 0 # NOTE: maybe 0 is not available, fix me if this happens
17+
torch.cuda.set_device(device_index)
18+
a = torch.tensor([1, 2, 3]).cuda()
19+
self.assertEqual(str(a.device), f"cuda:{device_index}")
20+
self.assertEqual(repr(a.device), f"device(type='cuda', index={device_index})")
21+
self.assertEqual(str(a), f"tensor([1, 2, 3], device='cuda:{device_index}')")
22+
self.assertEqual(repr(a), f"tensor([1, 2, 3], device='cuda:{device_index}')")
23+
24+
25+
if __name__ == "__main__":
26+
run_tests()

dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
// Copyright (c) 2023, DeepLink.
22

33
#include <array>
4-
#include <cstring>
5-
#include <limits>
4+
#include <cstdint>
65
#include <sstream>
76

8-
#include <ATen/Device.h>
9-
#include <c10/util/Exception.h>
7+
#include <c10/core/Device.h>
8+
#include <c10/core/DeviceType.h>
109
#include <c10/util/Optional.h>
1110
#include <torch/csrc/Device.h>
1211
#include <torch/csrc/Exceptions.h>
13-
#include <torch/csrc/Export.h>
14-
#include <torch/csrc/python_headers.h>
15-
#include <torch/csrc/utils/object_ptr.h>
16-
#include <torch/csrc/utils/pybind.h>
17-
#include <torch/csrc/utils/python_arg_parser.h>
1812
#include <torch/csrc/utils/python_numbers.h>
1913
#include <torch/csrc/utils/python_strings.h>
2014

21-
#include <structmember.h>
15+
#include <Python.h>
16+
#include <descrobject.h>
17+
#include <object.h>
18+
#include <pybind11/pybind11.h>
19+
#include <pybind11/pytypes.h>
2220

2321
#include <csrc_dipu/base/basedef.h>
2422

@@ -72,7 +70,7 @@ PyObject* DIPU_THPDevice_repr(THPDevice* self) {
7270

7371
PyObject* DIPU_THPDevice_str(THPDevice* self) {
7472
std::ostringstream oss;
75-
oss << _get_dipu_python_type(self->device);
73+
oss << at::Device(_get_dipu_python_type(self->device), self->device.index());
7674
return THPUtils_packString(oss.str().c_str());
7775
}
7876

0 commit comments

Comments
 (0)