Skip to content

Commit 6e9a9a1

Browse files
committed
[Feature] SAC Trainer
ghstack-source-id: 2db05ee Pull-Request: #3191
1 parent d1edaab commit 6e9a9a1

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

torchrl/trainers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
CountFramesLog,
1010
LogReward,
1111
LogScalar,
12+
LogTiming,
1213
LogValidationReward,
1314
mask_batch,
1415
OptimizerHook,
@@ -29,6 +30,7 @@
2930
"CountFramesLog",
3031
"LogReward",
3132
"LogScalar",
33+
"LogTiming",
3234
"LogValidationReward",
3335
"mask_batch",
3436
"OptimizerHook",

torchrl/trainers/algorithms/sac.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
log_observations: bool = False,
126126
target_net_updater: TargetNetUpdater | None = None,
127127
async_collection: bool = False,
128+
log_timings: bool = False,
128129
) -> None:
129130
warnings.warn(
130131
"SACTrainer is an experimental/prototype feature. The API may change in future versions. "
@@ -151,6 +152,7 @@ def __init__(
151152
log_interval=log_interval,
152153
save_trainer_file=save_trainer_file,
153154
async_collection=async_collection,
155+
log_timings=log_timings,
154156
)
155157
self.replay_buffer = replay_buffer
156158
self.async_collection = async_collection

0 commit comments

Comments
 (0)