import itertools
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import ClassVar, Dict, List, Optional, Type, TypeVar, Union
import gym
import numpy as np
import torch
from continuum.datasets import (
CIFAR10,
CIFAR100,
EMNIST,
KMNIST,
MNIST,
QMNIST,
CIFARFellowship,
FashionMNIST,
ImageNet100,
ImageNet1000,
MNISTFellowship,
Synbols,
_ContinuumDataset,
)
from continuum.scenarios import ClassIncremental, _BaseScenario
from continuum.tasks import TaskSet, concat, split_train_val
from gym import spaces
from simple_parsing import choice, field, list_field
from torch import Tensor
from torch.utils.data import ConcatDataset, Dataset, Subset
import wandb
from sequoia.common.config import Config
from sequoia.common.gym_wrappers import RenderEnvWrapper, TransformObservation
from sequoia.common.gym_wrappers.convert_tensors import add_tensor_support
from sequoia.common.spaces import Sparse
from sequoia.common.transforms import Compose, Transforms
from sequoia.settings.assumptions.continual import ContinualAssumption
from sequoia.settings.base import Method
from sequoia.settings.sl.setting import SLSetting
from sequoia.settings.sl.wrappers import MeasureSLPerformanceWrapper
from sequoia.utils.generic_functions import concatenate
from sequoia.utils.logging_utils import get_logger
from sequoia.utils.utils import flag
from .environment import ContinualSLEnvironment, ContinualSLTestEnvironment
from .envs import (
CTRL_INSTALLED,
CTRL_STREAMS,
base_action_spaces,
base_observation_spaces,
base_reward_spaces,
get_action_space,
get_observation_space,
get_reward_space,
)
from .objects import Actions, ActionSpace, Observations, ObservationSpace, Rewards, RewardSpace
from .results import ContinualSLResults
from .wrappers import relabel
logger = get_logger(__name__)
EnvironmentType = TypeVar("EnvironmentType", bound=ContinualSLEnvironment)
available_datasets = {
c.__name__.lower(): c
for c in [
CIFARFellowship,
MNISTFellowship,
ImageNet100,
ImageNet1000,
CIFAR10,
CIFAR100,
EMNIST,
KMNIST,
MNIST,
QMNIST,
FashionMNIST,
Synbols,
]
# "synbols": Synbols,
# "synbols_font": partial(Synbols, task="fonts"),
}
if CTRL_INSTALLED:
available_datasets.update(dict(zip(CTRL_STREAMS, CTRL_STREAMS)))
@dataclass
class ContinualSLSetting(SLSetting, ContinualAssumption):
"""Continuous, Task-Agnostic, Continual Supervised Learning.
This is *currently* the most "general" Supervised Continual Learning setting in
Sequoia.
- Data distribution changes smoothly over time.
- Smooth transitions between "tasks"
- No information about task boundaries or task identity (no task IDs)
- Maximum of one 'epoch' through the environment.
"""
# Class variables that hold the 'base' observation/action/reward spaces for the
# available datasets.
base_observation_spaces: ClassVar[Dict[str, gym.Space]] = base_observation_spaces
base_action_spaces: ClassVar[Dict[str, gym.Space]] = base_action_spaces
base_reward_spaces: ClassVar[Dict[str, gym.Space]] = base_reward_spaces
    # (NOTE: commenting out SLSetting.Observations as it is the same class
    # as Setting.Observations, and we want a consistent method resolution order.)
Observations: ClassVar[Type[Observations]] = Observations
Actions: ClassVar[Type[Actions]] = Actions
Rewards: ClassVar[Type[Rewards]] = Rewards
ObservationSpace: ClassVar[Type[ObservationSpace]] = ObservationSpace
Environment: ClassVar[Type[SLSetting.Environment]] = ContinualSLEnvironment[
Observations, Actions, Rewards
]
Results: ClassVar[Type[ContinualSLResults]] = ContinualSLResults
# Class variable holding a dict of the names and types of all available
# datasets.
# TODO: Issue #43: Support other datasets than just classification
available_datasets: ClassVar[Dict[str, Type[_ContinuumDataset]]] = available_datasets
# A continual dataset to use. (Should be taken from the continuum package).
dataset: str = choice(available_datasets.keys(), default="mnist")
# Transformations to use. See the Transforms enum for the available values.
transforms: List[Transforms] = list_field(
Transforms.to_tensor,
        # BUG: The input_shape given to the Model doesn't have the right number
        # of channels, even if we 'fixed' them here. However, the images themselves
        # are fine afterwards.
Transforms.three_channels,
Transforms.channels_first_if_needed,
)
    # Either the number of classes per task, or a list specifying the number of new
    # classes for each task.
increment: Union[int, List[int]] = list_field(
2, type=int, nargs="*", alias="n_classes_per_task"
)
    # The number of tasks in the scenario.
    # If zero, defaults to the number of classes divided by the increment.
nb_tasks: int = 0
# A different task size applied only for the first task.
    # Deactivated if `increment` is a list.
initial_increment: int = 0
# An optional custom class order, used for NC.
class_order: Optional[List[int]] = None
    # Either the number of classes per task, or a list specifying the number of new
    # classes for each task (defaults to the value of `increment`).
test_increment: Optional[Union[List[int], int]] = None
# A different task size applied only for the first test task.
    # Deactivated if `test_increment` is a list. Defaults to the
# value of `initial_increment`.
test_initial_increment: Optional[int] = None
# An optional custom class order for testing, used for NC.
# Defaults to the value of `class_order`.
test_class_order: Optional[List[int]] = None
    # Whether task boundaries are smooth or not.
    smooth_task_boundaries: bool = flag(True)
    # Whether the context (task) variable is stationary or not.
    stationary_context: bool = flag(False)
    # Whether tasks share the same action space or not.
# TODO: This will probably be moved into a different assumption.
shared_action_space: Optional[bool] = None
# TODO: Need to put num_workers in only one place.
batch_size: int = field(default=32, cmd=False)
num_workers: int = field(default=4, cmd=False)
    # When True, a Monitor-like wrapper will be applied to the training environment
    # to monitor the 'online' performance during training. Note that in SL, this will
# also cause the Rewards (y) to be withheld until actions are passed to the `send`
# method of the Environment.
monitor_training_performance: bool = flag(False)
train_datasets: List[Dataset] = field(
default_factory=list, cmd=False, repr=False, to_dict=False
)
val_datasets: List[Dataset] = field(default_factory=list, cmd=False, repr=False, to_dict=False)
test_datasets: List[Dataset] = field(default_factory=list, cmd=False, repr=False, to_dict=False)
def __post_init__(self):
super().__post_init__()
# assert not self.has_setup_fit
# Test values default to the same as train.
self.test_increment = self.test_increment or self.increment
self.test_initial_increment = self.test_initial_increment or self.initial_increment
self.test_class_order = self.test_class_order or self.class_order
# TODO: For now we assume a fixed, equal number of classes per task, for
        # the sake of simplicity. We could take out this assumption, but it might
# make things a bit more complicated.
if isinstance(self.increment, list) and len(self.increment) == 1:
self.increment = self.increment[0]
if isinstance(self.test_increment, list) and len(self.test_increment) == 1:
self.test_increment = self.test_increment[0]
assert isinstance(self.increment, int)
assert isinstance(self.test_increment, int)
# The 'scenarios' for train and test from continuum. (ClassIncremental for now).
self.train_cl_loader: Optional[_BaseScenario] = None
self.test_cl_loader: Optional[_BaseScenario] = None
self.train_cl_dataset: Optional[_ContinuumDataset] = None
self.test_cl_dataset: Optional[_ContinuumDataset] = None
# This will be set by the Experiment, or passed to the `apply` method.
# TODO: This could be a bit cleaner.
self.config: Config
# Default path to which the datasets will be downloaded.
self.data_dir: Optional[Path] = None
self.train_env: ContinualSLEnvironment = None # type: ignore
self.val_env: ContinualSLEnvironment = None # type: ignore
self.test_env: ContinualSLEnvironment = None # type: ignore
# BUG: These `has_setup_fit`, `has_setup_test`, `has_prepared_data` properties
# aren't working correctly: they get set before the call to the function has
# been executed, making it impossible to check those values from inside those
# functions.
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_test = False
if CTRL_INSTALLED and self.dataset in CTRL_STREAMS:
import ctrl
from ctrl.tasks.task_generator import TaskGenerator
from .envs import CTRL_NB_TASKS
self.nb_tasks = self.nb_tasks or CTRL_NB_TASKS[self.dataset]
if self.dataset == "s_long" and not self.nb_tasks:
warnings.warn(
RuntimeWarning(
f"Limiting the scenario to 100 tasks for now when using 's_long' stream."
)
)
self.nb_tasks = 100
task_generator: TaskGenerator = ctrl.get_stream(self.dataset, seed=42)
# Get the train/val/test splits from the tasks.
for task_dataset in itertools.islice(task_generator, self.nb_tasks):
train_dataset = task_dataset.datasets[task_dataset.split_names.index("Train")]
val_dataset = task_dataset.datasets[task_dataset.split_names.index("Val")]
test_dataset = task_dataset.datasets[task_dataset.split_names.index("Test")]
self.train_datasets.append(train_dataset)
self.val_datasets.append(val_dataset)
self.test_datasets.append(test_dataset)
## NOTE: Not sure this is a good idea, because we might easily mix the train/val
## and test splits between different runs! Actually, now that I think about it,
## I need to make sure that this isn't happening already with Avalanche!
# if self.datasets:
# if any(self.train_datasets, self.val_datasets, self.test_datasets):
# raise RuntimeError(
# f"When passing your own datasets to the setting, you have to pass "
# f"either `datasets` or all three of `train_datasets`, "
# f"`val_datasets` and `test_datasets`."
# )
# self.train_datasets = []
# self.val_datasets = []
# self.test_datasets = []
# rng = np.random.default_rng(self.config.seed if self.config else 123)
# for dataset in datasets:
# n = len(dataset)
# n_train_val = int(n * 0.8)
# n_test = n - n_train_val
# n_train = int(n_train_val * 0.8)
# n_valid = n_train_val - n_train
# train_val_dataset, test_dataset = random_split(
# dataset, [n_train_val, n_test], generator=rng,
# )
# train_dataset, val_dataset = random_split(
# train_val_dataset, [n_train, n_valid], generator=rng,
# )
# self.train_datasets.append(train_dataset)
# self.val_datasets.append(val_dataset)
# self.test_datasets.append(test_dataset)
if any([self.train_datasets, self.val_datasets, self.test_datasets]):
if not all([self.train_datasets, self.val_datasets, self.test_datasets]):
raise RuntimeError(
f"When passing your own datasets to the setting, you have to pass "
f"`train_datasets`, `val_datasets` and `test_datasets`."
)
self.nb_tasks = len(self.train_datasets)
if not (len(self.val_datasets) == len(self.test_datasets) == self.nb_tasks):
raise RuntimeError(
f"When passing your own datasets to the setting, you need to pass "
f"The same number of train/valid and test datasets for now."
)
# FIXME: For now, setting `self.dataset` to None, because it has a default
# of 'mnist'. Should probably make it a required argument instead.
self.dataset = None
# x_shape = self.train_datasets[0][0][0].shape
# self.observation_space.x.shape = x_shape
# assert False, (x_shape, self.observation_space)
# Note: Using the same name as in the RL Setting for now, since that's where
# this feature of passing the "envs" for each task was first added.
self._using_custom_envs_foreach_task: bool = bool(self.train_datasets)
# TODO: Remove this
if self.dataset in self.base_action_spaces:
if isinstance(self.action_space, spaces.Discrete):
base_action_space = self.base_action_spaces[self.dataset]
n_classes = base_action_space.n
self.class_order = self.class_order or list(range(n_classes))
if self.nb_tasks:
self.increment = n_classes // self.nb_tasks
if not self.nb_tasks:
base_action_space = self.base_action_spaces[self.dataset]
if isinstance(base_action_space, spaces.Discrete):
self.nb_tasks = base_action_space.n // self.increment
assert self.nb_tasks != 0, self.nb_tasks
def apply(
self, method: Method["ContinualSLSetting"], config: Config = None
) -> ContinualSLResults:
"""Apply the given method on this setting to producing some results."""
# TODO: It still isn't super clear what should be in charge of creating
# the config, and how to create it, when it isn't passed explicitly.
self.config = config or self._setup_config(method)
assert self.config is not None
method.configure(setting=self)
# Run the main loop (defined in ContinualAssumption).
# Basically does the following:
# 1. Call method.fit(train_env, valid_env)
# 2. Test the method on test_env.
# Return the results, as reported by the test environment.
results: ContinualSLResults = super().main_loop(method)
method.receive_results(self, results=results)
return results
def train_dataloader(
self, batch_size: int = 32, num_workers: Optional[int] = 4
) -> EnvironmentType:
if not self.has_prepared_data:
self.prepare_data()
if not self.has_setup_fit:
self.setup("fit")
if self.train_env:
self.train_env.close()
batch_size = batch_size if batch_size is not None else self.batch_size
num_workers = num_workers if num_workers is not None else self.num_workers
        # NOTE: At the moment, the dataset here doesn't have any transforms. We add the
        # transforms after the dataloader below, using the TransformObservation wrapper.
        # This isn't ideal.
dataset = self._make_train_dataset()
# TODO: Add some kind of Wrapper around the dataset to make it
# semi-supervised?
env = self.Environment(
dataset,
hide_task_labels=(not self.task_labels_at_train_time),
observation_space=self.observation_space,
action_space=self.action_space,
reward_space=self.reward_space,
Observations=self.Observations,
Actions=self.Actions,
Rewards=self.Rewards,
pin_memory=True,
batch_size=batch_size,
num_workers=num_workers,
drop_last=self.drop_last,
shuffle=False,
one_epoch_only=(not self.known_task_boundaries_at_train_time),
)
if self.config.render:
# Add a wrapper that calls 'env.render' at each step?
env = RenderEnvWrapper(env)
train_transforms = Compose(self.transforms + self.train_transforms)
if train_transforms:
env = TransformObservation(env, f=train_transforms)
if self.config.device:
# TODO: Put this before or after the image transforms?
from sequoia.common.gym_wrappers.convert_tensors import ConvertToFromTensors
env = ConvertToFromTensors(env, device=self.config.device)
# env = TransformObservation(env, f=partial(move, device=self.config.device))
# env = TransformReward(env, f=partial(move, device=self.config.device))
if self.monitor_training_performance:
env = MeasureSLPerformanceWrapper(
env,
first_epoch_only=True,
wandb_prefix=f"Train/",
)
# NOTE: Quickfix for the 'dtype' of the TypedDictSpace perhaps getting lost
# when transforms don't propagate the 'dtype' field.
env.observation_space.dtype = self.Observations
self.train_env = env
return self.train_env
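    # Rough illustration of how a Method might consume the environment returned above.
    # (Assumption-laden sketch, not code from this module; see ContinualSLEnvironment
    # for the actual interaction API.)
    #
    #     env = setting.train_dataloader(batch_size=32)
    #     for observations in env:
    #         actions = env.action_space.sample()  # or a model's predictions
    #         rewards = env.send(actions)  # labels (y) come back only after acting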
def val_dataloader(
self, batch_size: int = 32, num_workers: Optional[int] = 4
) -> EnvironmentType:
if not self.has_prepared_data:
self.prepare_data()
if not self.has_setup_validate:
self.setup("validate")
if self.val_env:
self.val_env.close()
batch_size = batch_size if batch_size is not None else self.batch_size
num_workers = num_workers if num_workers is not None else self.num_workers
dataset = self._make_val_dataset()
# TODO: Add some kind of Wrapper around the dataset to make it
# semi-supervised?
# TODO: Change the reward and action spaces to also use objects.
env = self.Environment(
dataset,
hide_task_labels=(not self.task_labels_at_train_time),
observation_space=self.observation_space,
action_space=self.action_space,
reward_space=self.reward_space,
Observations=self.Observations,
Actions=self.Actions,
Rewards=self.Rewards,
pin_memory=True,
drop_last=self.drop_last,
batch_size=batch_size,
num_workers=num_workers,
one_epoch_only=(not self.known_task_boundaries_at_train_time),
)
        # TODO: If wandb is enabled, then add a customized Monitor wrapper (with
# IterableWrapper as an additional subclass). There would then be a lot of
# overlap between such a Monitor and the current TestEnvironment.
if self.config.render:
# Add a wrapper that calls 'env.render' at each step?
env = RenderEnvWrapper(env)
# NOTE: The transforms from `self.transforms` (the 'base' transforms) were
# already added when creating the datasets and the CL scenario.
val_transforms = self.transforms + self.val_transforms
if val_transforms:
env = TransformObservation(env, f=val_transforms)
if self.config.device:
# TODO: Put this before or after the image transforms?
from sequoia.common.gym_wrappers.convert_tensors import ConvertToFromTensors
env = ConvertToFromTensors(env, device=self.config.device)
# env = TransformObservation(env, f=partial(move, device=self.config.device))
# env = TransformReward(env, f=partial(move, device=self.config.device))
# NOTE: We don't measure online performance on the validation set.
# if self.monitor_training_performance:
# env = MeasureSLPerformanceWrapper(
# env,
# first_epoch_only=True,
# wandb_prefix=f"Train/Task {self.current_task_id}",
# )
# NOTE: Quickfix for the 'dtype' of the TypedDictSpace perhaps getting lost
# when transforms don't propagate the 'dtype' field.
env.observation_space.dtype = self.Observations
self.val_env = env
return self.val_env
def test_dataloader(
self, batch_size: int = None, num_workers: int = None
) -> ContinualSLEnvironment[Observations, Actions, Rewards]:
"""Returns a Continual SL Test environment."""
if not self.has_prepared_data:
self.prepare_data()
if not self.has_setup_test:
self.setup("test")
batch_size = batch_size if batch_size is not None else self.batch_size
num_workers = num_workers if num_workers is not None else self.num_workers
dataset = self._make_test_dataset()
env = self.Environment(
dataset,
batch_size=batch_size,
num_workers=num_workers,
hide_task_labels=(not self.task_labels_at_test_time),
observation_space=self.observation_space,
action_space=self.action_space,
reward_space=self.reward_space,
Observations=self.Observations,
Actions=self.Actions,
Rewards=self.Rewards,
pretend_to_be_active=True,
drop_last=self.drop_last,
shuffle=False,
one_epoch_only=True,
)
# NOTE: The transforms from `self.transforms` (the 'base' transforms) were
# already added when creating the datasets and the CL scenario.
test_transforms = self.transforms + self.test_transforms
if test_transforms:
env = TransformObservation(env, f=test_transforms)
if self.config.device:
# TODO: Put this before or after the image transforms?
from sequoia.common.gym_wrappers.convert_tensors import ConvertToFromTensors
env = ConvertToFromTensors(env, device=self.config.device)
# env = TransformObservation(env, f=partial(move, device=self.config.device))
# env = TransformReward(env, f=partial(move, device=self.config.device))
# FIXME: Instead of trying to create a 'fake' task schedule for the test
# environment, instead let the test environment see the task ids, (and then hide
# them if necessary) so that it can compile the stats for each task based on the
# task IDs of the observations.
# TODO: Configure the 'monitoring' dir properly.
if wandb.run:
test_dir = wandb.run.dir
else:
test_dir = self.config.log_dir
test_loop_max_steps = len(dataset) // (env.batch_size or 1)
test_env = ContinualSLTestEnvironment(
env,
directory=test_dir,
step_limit=test_loop_max_steps,
force=True,
config=self.config,
video_callable=None if (wandb.run or self.config.render) else False,
)
# NOTE: Quickfix for the 'dtype' of the TypedDictSpace perhaps getting lost
# when transforms don't propagate the 'dtype' field.
env.observation_space.dtype = self.Observations
if self.test_env:
self.test_env.close()
self.test_env = test_env
return self.test_env
def prepare_data(self, data_dir: Path = None) -> None:
# TODO: Pass the transformations to the CL scenario, or to the dataset?
if data_dir is None:
if self.config:
data_dir = self.config.data_dir
else:
data_dir = Path("data")
logger.info(f"Downloading datasets to directory {data_dir}")
self._using_custom_envs_foreach_task = bool(self.train_datasets)
if not self._using_custom_envs_foreach_task:
self.train_cl_dataset = self.make_dataset(data_dir, download=True, train=True)
self.test_cl_dataset = self.make_dataset(data_dir, download=True, train=False)
return super().prepare_data()
def setup(self, stage: str = None):
if not self.has_prepared_data:
self.prepare_data()
super().setup(stage=stage)
if stage not in (None, "fit", "test", "validate"):
raise RuntimeError(f"`stage` should be 'fit', 'test', 'validate' or None.")
if stage in (None, "fit", "validate"):
if not self._using_custom_envs_foreach_task:
self.train_cl_dataset = self.train_cl_dataset or self.make_dataset(
self.config.data_dir, download=False, train=True
)
nb_tasks_kwarg = {}
if self.nb_tasks is not None:
nb_tasks_kwarg.update(nb_tasks=self.nb_tasks)
else:
nb_tasks_kwarg.update(increment=self.increment)
if not self._using_custom_envs_foreach_task:
self.train_cl_loader = self.train_cl_loader or ClassIncremental(
cl_dataset=self.train_cl_dataset,
**nb_tasks_kwarg,
initial_increment=self.initial_increment,
transformations=[], # NOTE: Changing this: The transforms will get added after.
class_order=self.class_order,
)
if not self.train_datasets and not self.val_datasets:
for task_id, train_taskset in enumerate(self.train_cl_loader):
train_taskset, valid_taskset = split_train_val(train_taskset, val_split=0.1)
self.train_datasets.append(train_taskset)
self.val_datasets.append(valid_taskset)
# IDEA: We could do the remapping here instead of adding a wrapper later.
if self.shared_action_space and isinstance(self.action_space, spaces.Discrete):
                # With a shared output space, the class labels within each task are remapped to [0, n_per_task).
self.train_datasets = list(map(relabel, self.train_datasets))
self.val_datasets = list(map(relabel, self.val_datasets))
if stage in (None, "test"):
if not self._using_custom_envs_foreach_task:
self.test_cl_dataset = self.test_cl_dataset or self.make_dataset(
self.config.data_dir, download=False, train=False
)
self.test_class_order = self.test_class_order or self.class_order
self.test_cl_loader = self.test_cl_loader or ClassIncremental(
cl_dataset=self.test_cl_dataset,
nb_tasks=self.nb_tasks,
increment=self.test_increment,
initial_increment=self.test_initial_increment,
transformations=[], # note: not passing transforms here, they get added later
class_order=self.test_class_order,
)
if not self.test_datasets:
# TODO: If we decide to 'shuffle' the test tasks, then store the sequence of
# task ids in a new property, probably here.
# self.test_task_order = list(range(len(self.test_datasets)))
self.test_datasets = list(self.test_cl_loader)
# IDEA: We could do the remapping here instead of adding a wrapper later.
if self.shared_action_space and isinstance(self.action_space, spaces.Discrete):
                # With a shared output space, the class labels within each task are remapped to [0, n_per_task).
self.test_datasets = list(map(relabel, self.test_datasets))
def _make_train_dataset(self) -> Union[TaskSet, Dataset]:
# NOTE: Passing the same seed to `train`/`valid`/`test` is fine, because it's
# only used for the shuffling used to make the task boundaries smooth.
if self.smooth_task_boundaries:
return smooth_task_boundaries_concat(
self.train_datasets, seed=self.config.seed if self.config else None
)
if self.stationary_context:
joined_dataset = concat(self.train_datasets)
return shuffle(joined_dataset, seed=self.config.seed)
if self.known_task_boundaries_at_train_time:
return self.train_datasets[self.current_task_id]
else:
return concatenate(self.train_datasets)
def _make_val_dataset(self) -> Dataset:
if self.smooth_task_boundaries:
return smooth_task_boundaries_concat(self.val_datasets, seed=self.config.seed)
if self.stationary_context:
joined_dataset = concat(self.val_datasets)
return shuffle(joined_dataset, seed=self.config.seed)
if self.known_task_boundaries_at_train_time:
return self.val_datasets[self.current_task_id]
return concatenate(self.val_datasets)
def _make_test_dataset(self) -> Dataset:
if self.smooth_task_boundaries:
return smooth_task_boundaries_concat(self.test_datasets, seed=self.config.seed)
else:
return concatenate(self.test_datasets)
def make_dataset(
self, data_dir: Path, download: bool = True, train: bool = True, **kwargs
) -> _ContinuumDataset:
# TODO: #7 Use this method here to fix the errors that happen when
# trying to create every single dataset from continuum.
data_dir = Path(data_dir)
if not data_dir.exists():
data_dir.mkdir(parents=True, exist_ok=True)
if self.dataset in self.available_datasets:
dataset_class = self.available_datasets[self.dataset]
return dataset_class(data_path=data_dir, download=download, train=train, **kwargs)
elif self.dataset in self.available_datasets.values():
dataset_class = self.dataset
return dataset_class(data_path=data_dir, download=download, train=train, **kwargs)
elif isinstance(self.dataset, Dataset):
logger.info(f"Using a custom dataset {self.dataset}")
return self.dataset
else:
raise NotImplementedError(self.dataset)
@property
def observation_space(self) -> ObservationSpace[Observations]:
"""The un-batched observation space, based on the choice of dataset and
the transforms at `self.transforms` (which apply to the train/valid/test
environments).
The returned space is a TypedDictSpace, with the following properties:
- `x`: observation space (e.g. `Image` space)
- `task_labels`: Union[Discrete, Sparse[Discrete]]
The task labels for each sample. When task labels are not available,
the task labels space is Sparse, and entries will be `None`.
"""
# TODO: Need to clean this up a bit:
if self._using_custom_envs_foreach_task:
x_space = get_observation_space(self.train_datasets[0])
else:
x_space = get_observation_space(self.dataset)
if not self.transforms:
# NOTE: When we don't pass any transforms, continuum scenarios still
# at least use 'to_tensor'.
x_space = Transforms.to_tensor(x_space)
# apply the transforms to the observation space.
for transform in self.transforms:
x_space = transform(x_space)
x_space = add_tensor_support(x_space)
task_label_space = spaces.Discrete(self.nb_tasks)
if not self.task_labels_at_train_time:
task_label_space = Sparse(task_label_space, 1.0)
task_label_space = add_tensor_support(task_label_space)
self._observation_space = self.ObservationSpace(
x=x_space,
task_labels=task_label_space,
dtype=self.Observations,
)
return self._observation_space
# TODO: Add a `train_observation_space`, `train_action_space`, `train_reward_space`?
@property
def action_space(self) -> spaces.Discrete:
"""Action space for this setting."""
if self._action_space:
return self._action_space
# Determine the action space using the right dataset.
# (NOTE: same across train/val/test for now.)
dataset = self.dataset
if self._using_custom_envs_foreach_task:
dataset = self.train_datasets[0]
action_space = get_action_space(dataset)
# TODO: Remove this
if isinstance(action_space, spaces.Discrete) and self.dataset in self.base_action_spaces:
if self.shared_action_space:
assert isinstance(self.increment, int), (
"Need to have same number of classes in each task when "
"`shared_action_space` is true."
)
action_space = spaces.Discrete(self.increment)
self._action_space = action_space
return self._action_space
# TODO: IDEA: Have the action space only reflect the number of 'current' classes
# in order to create a "true" class-incremental learning setting.
# n_classes_seen_so_far = 0
# for task_id in range(self.current_task_id):
# n_classes_seen_so_far += self.num_classes_in_task(task_id)
# return spaces.Discrete(n_classes_seen_so_far)
@property
def reward_space(self) -> spaces.Discrete:
if self._reward_space:
return self._reward_space
# Determine the reward space using the right dataset.
# (NOTE: same across train/val/test for now.)
dataset = self.dataset
if self._using_custom_envs_foreach_task:
dataset = self.train_datasets
reward_space = get_reward_space(dataset)
# TODO: Remove this
if isinstance(reward_space, spaces.Discrete) and self.dataset in self.base_reward_spaces:
if self.shared_action_space:
assert isinstance(self.increment, int), (
"Need to have same number of classes in each task when "
"`shared_action_space` is true."
)
reward_space = spaces.Discrete(self.increment)
self._reward_space = reward_space
return self._reward_space
def smooth_task_boundaries_concat(
datasets: List[Dataset], seed: int = None, window_length: float = 0.03
) -> ConcatDataset:
"""TODO: Use a smarter way of mixing from one to the other?"""
lengths = [len(dataset) for dataset in datasets]
total_length = sum(lengths)
n_tasks = len(datasets)
if not isinstance(window_length, int):
window_length = int(total_length * window_length)
assert (
window_length > 1
), f"Window length should be positive or a fraction of the dataset length. ({window_length})"
rng = np.random.default_rng(seed)
def option1():
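        # Local sliding-window shuffle: shuffle overlapping windows of `window_length`
        # indices (stride of half a window), so samples tend to stay close to their
        # original, task-ordered position.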
shuffled_indices = np.arange(total_length)
for start_index in range(0, total_length - window_length + 1, window_length // 2):
rng.shuffle(shuffled_indices[start_index : start_index + window_length])
return shuffled_indices
# Maybe do the same but backwards?
# IDEA #2: Sample based on how close to the 'center' of the task we are.
def option2():
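        # For each output position, sample the source task with probability inversely
        # proportional to (1 + squared distance) from that task's "center", so the
        # stream drifts gradually from one task to the next.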
boundaries = np.array(list(itertools.accumulate(lengths, initial=0)))
middles = [(start + end) / 2 for start, end in zip(boundaries[0:], boundaries[1:])]
samples_left: Dict[int, int] = {i: length for i, length in enumerate(lengths)}
indices_left: Dict[int, List[int]] = {
i: list(range(boundaries[i], boundaries[i] + length))
for i, length in enumerate(lengths)
}
out_indices: List[int] = []
last_dataset_index = n_tasks - 1
for step in range(total_length):
if step < middles[0] and samples_left[0]:
                # Prevent sampling items from later tasks at the very beginning of task 0.
eligible_dataset_ids = [0]
elif step > middles[-1] and samples_left[last_dataset_index]:
                # Prevent sampling items from earlier tasks at the very end of the last task.
eligible_dataset_ids = [last_dataset_index]
else:
                # 'Smooth' region: near a boundary, samples can come from two or three
                # datasets at once, possibly even from future tasks.
eligible_dataset_ids = list(k for k, v in samples_left.items() if v > 0)
# if len(eligible_dataset_ids) > 2:
# # Prevent sampling from future tasks (past the next task) when at a
# # boundary.
# left_dataset_index = min(eligible_dataset_ids)
# right_dataset_index = min(
# v for v in eligible_dataset_ids if v > left_dataset_index
# )
# eligible_dataset_ids = [left_dataset_index, right_dataset_index]
options = np.array(eligible_dataset_ids, dtype=int)
# Calculate the 'distance' to the center of the task's dataset.
distances = np.abs([step - middles[dataset_index] for dataset_index in options])
            # NOTE: This exponent is somewhat arbitrary; it's set to this value because
            # it works reasonably well for MNIST so far.
probs = 1 / (1 + np.abs(distances) ** 2)
probs /= sum(probs)
chosen_dataset = rng.choice(options, p=probs)
chosen_index = indices_left[chosen_dataset].pop()
samples_left[chosen_dataset] -= 1
out_indices.append(chosen_index)
shuffled_indices = np.array(out_indices)
return shuffled_indices
def option3():
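        # Same sliding-window shuffle as option1, applied once forward and once
        # backward over the index array, to smooth the transitions in both directions.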
shuffled_indices = np.arange(total_length)
for start_index in range(0, total_length - window_length + 1, window_length // 2):
rng.shuffle(shuffled_indices[start_index : start_index + window_length])
for start_index in reversed(range(0, total_length - window_length + 1, window_length // 2)):
rng.shuffle(shuffled_indices[start_index : start_index + window_length])
return shuffled_indices
shuffled_indices = option3()
if all(isinstance(dataset, TaskSet) for dataset in datasets):
        # Use `concat` from continuum, to preserve the fields/methods of a TaskSet.
joined_taskset = concat(datasets)
return subset(joined_taskset, shuffled_indices)
else:
joined_dataset = ConcatDataset(datasets)
return Subset(joined_dataset, shuffled_indices)
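# Hypothetical usage sketch of `smooth_task_boundaries_concat` (the dataset names below
# are made up; per-task datasets normally come from the ClassIncremental scenario above):
#
#     mixed = smooth_task_boundaries_concat([task_0_taskset, task_1_taskset], seed=123)
#     # `mixed` keeps the overall task ordering, but samples near a boundary are
#     # interleaved between the neighbouring tasks.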
from functools import singledispatch
from typing import Sequence, overload
from .wrappers import replace_taskset_attributes
DatasetType = TypeVar("DatasetType", bound=Dataset)
@overload
def subset(dataset: TaskSet, indices: Sequence[int]) -> TaskSet:
...
@singledispatch
def subset(dataset: DatasetType, indices: Sequence[int]) -> Union[Subset, DatasetType]:
    # Fallback for plain Datasets: wrap in a torch `Subset`. (A TaskSet-specific
    # version is registered below.)
    return Subset(dataset, indices)
@subset.register
def taskset_subset(taskset: TaskSet, indices: np.ndarray) -> TaskSet:
    x, y, t = taskset.get_raw_samples(indices)
# TODO: Not sure if/how to handle the `bounding_boxes` attribute here.
bounding_boxes = taskset.bounding_boxes
if bounding_boxes is not None:
bounding_boxes = bounding_boxes[indices]
return replace_taskset_attributes(taskset, x=x, y=y, t=t, bounding_boxes=bounding_boxes)
def random_subset(
taskset: TaskSet, n_samples: int, seed: int = None, ordered: bool = True
) -> TaskSet:
"""Returns a random (ordered) subset of the given TaskSet."""
rng = np.random.default_rng(seed)
dataset_length = len(taskset)
if n_samples > dataset_length:
raise RuntimeError(f"Dataset has {dataset_length}, asked for {n_samples} samples.")
indices = rng.permutation(range(dataset_length))[:n_samples]
# indices = rng.choice(len(taskset), size=n_samples, replace=False)
if ordered:
indices = sorted(indices)
assert len(indices) == n_samples
return subset(taskset, indices)
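# For example, `random_subset(taskset, n_samples=100, seed=0)` would return a TaskSet of
# 100 samples drawn without replacement and (by default) kept in their original order.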
DatasetType = TypeVar("DatasetType", bound=Dataset)
def shuffle(dataset: DatasetType, seed: int = None) -> DatasetType:
length = len(dataset)
rng = np.random.default_rng(seed)
perm = rng.permutation(range(length))
return subset(dataset, perm)
def smart_class_prediction(
logits: Tensor, task_labels: Tensor, setting: SLSetting, train: bool
) -> Tensor:
"""Predicts classes which are available, given the task labels."""
unique_task_ids = set(task_labels.unique().cpu().tolist())
classes_in_each_task = {
task_id: setting.task_classes(task_id, train=train) for task_id in unique_task_ids
}
y_pred = limit_to_available_classes(logits, task_labels, classes_in_each_task)
return y_pred
def limit_to_available_classes(
logits: Tensor, task_labels: Tensor, classes_in_each_present_task: Dict[int, List[int]]
) -> Tensor:
B = logits.shape[0]
C = logits.shape[-1]
assert logits.shape[0] == task_labels.shape[0] == B
indices = torch.arange(C, dtype=torch.long, device=logits.device)
elligible_masks = {
task_id: sum(
[indices == label for label in labels],
start=torch.zeros([C], dtype=bool, device=logits.device),
)
for task_id, labels in classes_in_each_present_task.items()
}
y_preds = []
# TODO: Also return the logits, so we can get a loss for the selected indices?
# logits = []
for logit, task_label in zip(logits, task_labels):
t = task_label.item()
eligible_classes_list = classes_in_each_present_task[t]
eligible_classes = torch.as_tensor(eligible_classes_list, dtype=int, device=logits.device)
is_eligible = elligible_masks[t]
if not is_eligible.any():
# Return a random prediction from the set of possible classes, since
# the network has fewer outputs than there are classes.
# NOTE: This can occur for instance when testing on future tasks
# when using a MultiTask module.
y_pred = eligible_classes[torch.randint(len(eligible_classes), (1,))]
else:
masked_logit = logit[is_eligible]
y_pred_without_offset = masked_logit.argmax(-1)
y_pred = eligible_classes[y_pred_without_offset]
assert y_pred.item() in eligible_classes_list
y_preds.append(y_pred.reshape(())) # Just to make sure they all have the same shape.
return torch.stack(y_preds)
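# Toy illustration (made-up tensors) of `limit_to_available_classes`: with 4 logits per
# sample and task 0 covering classes {0, 1}, predictions for task-0 samples are argmaxed
# over columns 0 and 1 only:
#
#     logits = torch.randn(2, 4)
#     task_labels = torch.tensor([0, 0])
#     preds = limit_to_available_classes(logits, task_labels, {0: [0, 1]})
#     assert all(p in (0, 1) for p in preds.tolist())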
from sequoia.common.transforms.channels import has_channels_last, has_channels_first
@has_channels_last.register(ContinualSLSetting.Observations)
def _has_channels_last(obs: ContinualSLSetting.Observations) -> bool:
return has_channels_last(obs.x)