Commit 432724b

ifedan authored and soumith committed
Fix torch.where to accept only tensors with same dtypes(CPU) (pytorch#29078)
* Make zeros argument of torch.where same dtype as other argument
* Added check for torch.where on CPU that both arguments have same dtype
* Changes based on PR comments
* Fix flake8
* Fixed test for CUDA
* Changes based on PR comments
* Changes based on PR review
1 parent cc98c93 commit 432724b
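
In user-facing terms, torch.where on CPU now requires its two value tensors to share a dtype instead of accepting mismatched ones. A minimal sketch of the new behavior (assuming a PyTorch build that includes this commit; tensors and values are illustrative only):

import torch

cond = torch.tensor([True, False, True])
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
b = torch.tensor([10, 20, 30], dtype=torch.int32)

try:
    torch.where(cond, a, b)  # dtypes differ -> RuntimeError on CPU
except RuntimeError as e:
    print(e)  # message starts with "expected scalar type"

# Casting both value tensors to a common dtype keeps the call working.
print(torch.where(cond, a, b.to(a.dtype)))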

File tree

4 files changed: +74 -2 lines changed


aten/src/ATen/native/Loss.cpp

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const T
  auto denom = (mag_square1 * mag_square2).sqrt_();
  auto cos = prod_sum / denom;

- auto zeros = at::zeros_like(target);
+ auto zeros = at::zeros_like(cos);
  auto pos = 1 - cos;
  auto neg = (cos - margin).clamp_min_(0);
  auto output_pos = at::where(target == 1, pos, zeros);

@@ -67,8 +67,8 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten
}

Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction) {
- auto zeros = at::zeros_like(target);
  auto output_pos = target * (at::log(target) - input);
+ auto zeros = at::zeros_like(output_pos);
  auto output = at::where(target > 0, output_pos, zeros);
  return apply_loss_reduction(output, reduction);
}
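
These two loss functions are internal callers of at::where, so their zeros tensor now has to match the dtype of the other value branch (cos and output_pos respectively) rather than the dtype of target, which callers may pass as an integer tensor. A rough Python analogue of the cosine_embedding_loss change (illustrative values; not the actual ATen code path):

import torch

target = torch.tensor([1, -1], dtype=torch.int32)    # integer target, as in the new test
cos = torch.tensor([0.9, 0.2], dtype=torch.float64)  # cosine similarity is floating point

# Old pattern: zeros = torch.zeros_like(target) would be int32 and would no
# longer match cos's dtype inside torch.where on CPU.
zeros = torch.zeros_like(cos)  # new pattern: same dtype as the other branch
pos = 1 - cos
output_pos = torch.where(target == 1, pos, zeros)
print(output_pos)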

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ std::vector<Tensor> where(const Tensor& condition) {
}

Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor ret = at::empty(self.sizes(), self.options());
  AT_DISPATCH_ALL_TYPES(ret.scalar_type(), "where_cpu", [&] {
    where_cpu<scalar_t>(ret, condition, self, other);
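
Note that the TORCH_CHECK sits at the top of _s_where_cpu, before the dispatch, so the mismatch is reported even when the condition would never select the offending branch. A small illustration (values chosen arbitrarily):

import torch

x = torch.ones(3, dtype=torch.float32)
y = torch.ones(3, dtype=torch.float64)
cond = torch.ones(3, dtype=torch.bool)  # always selects x, never y

try:
    torch.where(cond, x, y)
except RuntimeError as e:
    print(e)  # still raises; message starts with "expected scalar type"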

test/test_nn.py

Lines changed: 32 additions & 0 deletions
@@ -7223,6 +7223,38 @@ def test_cosine_embedding_loss_margin_no_reduce(self):
                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
                                                                   margin=0.5, reduction='none'))

+    def test_cosine_embedding_loss_with_diff_type(self):
+        for device in device_():
+            input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
+            input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
+            target = torch.tensor([1, -1], dtype=torch.int, device=device)
+            expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
+            for dt1 in torch.testing.get_all_math_dtypes(device):
+                for dt2 in torch.testing.get_all_math_dtypes(device):
+                    for dt3 in torch.testing.get_all_math_dtypes(device):
+                        # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
+                        if dt3 == torch.uint8:
+                            continue
+                        input1 = input1.to(dt1)
+                        input2 = input2.to(dt2)
+                        target = target.to(dt3)
+                        result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
+                        self.assertEqual(result.item(), expected.item(), 0.001)
+
+    def test_kl_div_with_diff_type(self):
+        for device in device_():
+            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
+            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
+            expected = torch.nn.functional.kl_div(input, target)
+            for input_dtype in torch.testing.get_all_math_dtypes(device):
+                for target_dtype in [torch.float32, torch.float64, torch.float16]:
+                    if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
+                        continue
+                    input = input.to(input_dtype)
+                    target = target.to(target_dtype)
+                    result = torch.nn.functional.kl_div(input, target)
+                    self.assertEqual(result.item(), expected.item(), 0.001)
+
     def test_margin_ranking_loss_no_reduce(self):
         input1 = torch.randn(15).mul_(10).requires_grad_()
         input2 = torch.randn(15).mul_(10).requires_grad_()

test/test_torch.py

Lines changed: 39 additions & 0 deletions
@@ -951,6 +951,45 @@ def test_where_bool_tensor(self):
         res = torch.where(a > 0)
         self.assertEqual(1, len(res))

+    def test_where_tensor(self):
+        def rand_tensor(size, dtype, device):
+            if dtype.is_floating_point:
+                return torch.rand(size=size, dtype=dtype, device=device)
+            elif dtype == torch.uint8:
+                return torch.randint(1, 5, size=size, dtype=dtype, device=device)
+            elif dtype == torch.bool:
+                return torch.randint(0, 1, size=size, dtype=dtype, device=device).bool()
+            else:
+                return torch.randint(-5, 5, size=size, dtype=dtype, device=device)
+
+        def get_tensor(size, dtype, device, contiguous):
+            if not contiguous and len(size) < 2:
+                raise RuntimeError("Unable to generate non contiguous tensor with size < 2")
+            t = rand_tensor(size, dtype, device)
+            if contiguous:
+                return t
+            else:
+                return t.transpose(0, 1)
+
+        height = 5
+        width = 5
+        for device in torch.testing.get_all_device_types():
+            for dt1 in torch.testing.get_all_math_dtypes(device):
+                for dt2 in torch.testing.get_all_math_dtypes(device):
+                    for contiguous in [True, False]:
+                        x1 = get_tensor((height, width), dt1, device, contiguous)
+                        x2 = get_tensor((height, width), dt2, device, contiguous)
+                        if dt1 != dt2:
+                            self.assertRaisesRegex(RuntimeError, "expected scalar type", lambda: torch.where(x1 == 1, x1, x2))
+                        else:
+                            if x1.is_floating_point():
+                                condition = (x1 < 0.5)
+                            else:
+                                condition = (x1 == 1)
+                            expected = condition.to(x1.dtype) * x1 + (~condition).to(x2.dtype) * x2
+                            result = torch.where(condition, x1, x2)
+                            self.assertEqual(expected, result)
+
     def test_all_any_with_dim(self):
         def test(x):
             r1 = x.prod(dim=0, keepdim=False).byte()
