Skip to content

Commit 39f5305

Browse files
committed
ensure numpy
1 parent 96578f1 commit 39f5305

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

test/test_loggers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def test_log_video(self, trackio_logger):
488488
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
489489
# the first 64 frames are black and the next 64 are white
490490
video = torch.cat(
491-
(torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255))
491+
(torch.zeros(128, 3, 32, 32), torch.full((128, 3, 32, 32), 255))
492492
)
493493
video = video[None, :]
494494
trackio_logger.log_video(

torchrl/record/loggers/trackio.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from collections.abc import Sequence
1010

11+
import numpy as np
12+
1113
from torch import Tensor
1214

1315
from .common import Logger
@@ -96,7 +98,11 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
9698
fps = kwargs.pop("fps", self.video_fps)
9799
format = kwargs.pop("format", "mp4")
98100
self.experiment.log(
99-
{name: trackio.Video(video, fps=fps, format=format)},
101+
{
102+
name: trackio.Video(
103+
video.numpy().astype(np.uint8), fps=fps, format=format
104+
)
105+
},
100106
**kwargs,
101107
)
102108

0 commit comments

Comments
 (0)