Skip to content

Commit d892fe3

Browse files
Xmaster6yvmoens
authored andcommitted
trackio histograms and str
1 parent 51969d2 commit d892fe3

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

test/test_loggers.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,15 @@ def test_log_scalar(self, steps, trackio_logger):
481481
step=steps[i] if steps else None,
482482
)
483483

484+
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
485+
def test_log_str(self, steps, trackio_logger):
486+
for i in range(3):
487+
trackio_logger.log_str(
488+
name="foo",
489+
value="bar",
490+
step=steps[i] if steps else None,
491+
)
492+
484493
def test_log_video(self, trackio_logger):
485494
torch.manual_seed(0)
486495

@@ -509,10 +518,14 @@ def test_log_hparams(self, trackio_logger, config):
509518
for key, value in config.items():
510519
assert trackio_logger.experiment.config[key] == value
511520

512-
def test_log_histogram(self, trackio_logger):
513-
with pytest.raises(NotImplementedError):
514-
data = torch.randn(10)
515-
trackio_logger.log_histogram("hist", data, step=0, bins=2)
521+
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
522+
def test_log_histogram(self, steps, trackio_logger):
523+
torch.manual_seed(0)
524+
for i in range(3):
525+
data = torch.randn(100)
526+
trackio_logger.log_histogram(
527+
"hist", data, step=steps[i] if steps else None, bins=10
528+
)
516529

517530

518531
if __name__ == "__main__":

torchrl/record/loggers/trackio.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,24 @@ def __repr__(self) -> str:
128128
return f"TrackioLogger(experiment={self.experiment.__repr__()})"
129129

130130
def log_histogram(self, name: str, data: Sequence, **kwargs):
131-
raise NotImplementedError("Logging histograms in trackio is not permitted.")
131+
"""Add histogram to log.
132+
133+
Args:
134+
name (str): Data identifier
135+
data (torch.Tensor, numpy.ndarray): Values to build histogram
136+
137+
Keyword Args:
138+
step (int): Global step value to record
139+
bins (int): Number of bins to use for the histogram
140+
141+
"""
142+
import trackio
143+
144+
num_bins = kwargs.pop("bins", None)
145+
step = kwargs.pop("step", None)
146+
self.experiment.log(
147+
{name: trackio.Histogram(data, num_bins=num_bins)}, step=step
148+
)
132149

133150
def log_str(self, name: str, value: str, step: int | None = None) -> None:
134151
"""Logs a string value to trackio using a table format for better visualization.
@@ -143,8 +160,4 @@ def log_str(self, name: str, value: str, step: int | None = None) -> None:
143160

144161
# Create a table with a single row
145162
table = trackio.Table(columns=["text"], data=[[value]])
146-
147-
if step is not None:
148-
self.experiment.log({name: value}, step=step)
149-
else:
150-
self.experiment.log({name: table})
163+
self.experiment.log({name: table}, step=step)

0 commit comments

Comments
 (0)