Skip to content

Commit 048459c

Browse files
committed
Test + Use *WITH_LOCATION macro for propagating the error.
1 parent 08c5ecd commit 048459c

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

test/test_operations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,6 +2445,12 @@ def test_isneginf_no_fallback(self):
24452445
t = t.to(torch.float16)
24462446
self._test_no_fallback(torch.isneginf, (t,))
24472447

2448+
def test_construct_large_tensor_raises_error(self):
2449+
a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
2450+
2451+
with self.assertRaisesRegex(RuntimeError, r"Out of memory allocating \d* bytes"):
2452+
a.cpu()
2453+
24482454

24492455
class MNISTComparator(nn.Module):
24502456

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -515,17 +515,14 @@ PjRtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
515515
ABSL_CHECK(pjrt_data->buffer != nullptr)
516516
<< "PjRt buffer is null in " << __FUNCTION__;
517517

518-
xla::Literal& literal =
519-
literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
518+
xla::Literal& literal = literals.emplace_back(
519+
xla::Literal(host_output_shape(pjrt_data->buffer.get()),
520+
/* allocate_arrays= */ false));
520521
futures.push_back(pjrt_data->buffer->ToLiteral(&literal));
521522

522523
total_size += literal.size_bytes();
523524
}
524-
auto joined = xla::JoinFutures(futures);
525-
XLA_RETURN_IF_ERROR(
526-
joined.Await(),
527-
absl::StrCat(__FUNCTION__,
528-
": failed to await future from buffer to literal."));
525+
XLA_RETURN_IF_ERROR_WITH_LOCATION(xla::JoinFutures(futures).Await());
529526
InboundDataMetric()->AddSample(total_size);
530527

531528
return literals;

0 commit comments

Comments
 (0)