Skip to content

Commit

Permalink
Merge pull request #76 from decile-team/krishnatejakk-patch-1
Browse files Browse the repository at this point in the history
Bug Fix in Non Adaptive CRAIG DataLoader
  • Loading branch information
krishnatejakk authored Jun 9, 2022
2 parents b84c334 + 0f59b1d commit 844f897
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions cords/utils/data/dataloader/SL/nonadaptive/craigdataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ def __init__(self, train_loader, val_loader, dss_args, logger, *args, **kwargs):
assert "optimizer" in dss_args.keys(), "'optimizer' is a compulsory argument for CRAIG. Include it as a key in dss_args"
assert "if_convex" in dss_args.keys(), "'if_convex' is a compulsory argument for CRAIG. Include it as a key in dss_args"

super(CRAIGDataLoader, self).__init__(train_loader, val_loader, dss_args,
logger, *args, **kwargs)

self.strategy = CRAIGStrategy(train_loader, val_loader, copy.deepcopy(dss_args.model), dss_args.num_classes,
dss_args.linear_layer, dss_args.loss, dss_args.device,
False, dss_args.selection_type, logger, dss_args.optimizer)

super(CRAIGDataLoader, self).__init__(train_loader, val_loader, dss_args,
logger, *args, **kwargs)


self.train_model = dss_args.model
self.eta = dss_args.eta
self.num_cls = dss_args.num_classes
Expand Down Expand Up @@ -67,4 +69,4 @@ def _init_subset_indices(self):
self.train_model.load_state_dict(cached_state_dict)
end = time.time()
self.logger.info('Epoch: {0:d}, CRAIG subset selection finished, takes {1:.4f}. '.format(self.cur_epoch, (end - start)))
return subset_indices, subset_weights
return subset_indices, subset_weights

0 comments on commit 844f897

Please sign in to comment.