forked from GMvandeVen/continual-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path replayer.py
32 lines (23 loc) · 1.01 KB
/
replayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import abc
from torch import nn
class Replayer(nn.Module, metaclass=abc.ABCMeta):
    '''Abstract module for a classifier/generator that can be trained with replay.

    Initiates ability to reset state of optimizer between tasks: the optimizer
    is kept in ``self.optimizer``; if it should be reset after each task, the
    subclass sets ``self.optim_type`` and provides the list of required dicts
    in ``self.optim_list``.

    Subclasses must implement :meth:`forward`.'''

    def __init__(self):
        super().__init__()

        # Optimizer (and whether it needs to be reset)
        self.optimizer = None
        self.optim_type = "adam"
        # --> self.[optim_type]  <str> name of optimizer, relevant if optimizer should be reset for every task
        self.optim_list = []
        # --> self.[optim_list]  <list>, if optimizer should be reset after each task, provide list of required <dicts>

        # Replay: temperature for distillation loss (and whether it should be used)
        self.replay_targets = "hard"  # hard|soft
        self.KD_temp = 2.

    def _device(self):
        '''Return the device of this module's first parameter.

        Modules without parameters fall back to their first registered buffer.
        The explicit loops (instead of ``next(self.parameters())``) keep a bare,
        message-less ``StopIteration`` from escaping for parameter-free modules;
        inside a generator that exception would otherwise be silently converted
        into a ``RuntimeError`` (PEP 479).

        Raises:
            RuntimeError: if the module has neither parameters nor buffers.'''
        for param in self.parameters():
            return param.device
        for buf in self.buffers():
            return buf.device
        raise RuntimeError(
            f"{type(self).__name__} has no parameters or buffers; "
            "cannot determine its device"
        )

    def _is_on_cuda(self):
        '''Return True if the device reported by ``_device()`` is a CUDA device.

        Raises:
            RuntimeError: if the module has neither parameters nor buffers.'''
        return self._device().type == "cuda"

    @abc.abstractmethod
    def forward(self, x):
        '''Compute the module's output for input ``x``; must be implemented by subclasses.'''
        pass