From d73ea7c8086875a31c27abf0ece1c4aa080be154 Mon Sep 17 00:00:00 2001 From: Sourcery AI Date: Tue, 31 Oct 2023 19:49:43 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- data/dataset.py | 6 +++--- data/preprocess.py | 4 +--- models.py | 8 +++----- train.py | 2 +- utils/routes.py | 6 +++--- utils/trainer.py | 42 ++++++++++++++++++++++-------------------- utils/vis.py | 12 +++++++++--- 7 files changed, 42 insertions(+), 38 deletions(-) diff --git a/data/dataset.py b/data/dataset.py index 3480c91..0187ef6 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -57,9 +57,9 @@ def __getitem__(self, idx): def split_dataset(phase: str = 'train', train_ratio: float = 0.8, **kwargs): full_dataset = TrackingDataset(**kwargs) - if phase == 'train': + if phase == 'test': + return full_dataset + elif phase == 'train': train_size = int(len(full_dataset) * train_ratio) val_size = len(full_dataset) - train_size return random_split(full_dataset, [train_size, val_size]) - elif phase == 'test': - return full_dataset diff --git a/data/preprocess.py b/data/preprocess.py index 82382bb..bd2c655 100644 --- a/data/preprocess.py +++ b/data/preprocess.py @@ -4,9 +4,7 @@ def sub_mean(frames: Tensor) -> Tensor: mean_frame = frames.mean(axis=0, keepdim=True) - frames_sub_mean = frames.sub(mean_frame) - - return frames_sub_mean + return frames.sub(mean_frame) def diff(frames: Tensor) -> Tensor: diff --git a/models.py b/models.py index 4e9284f..2d1973e 100644 --- a/models.py +++ b/models.py @@ -12,7 +12,7 @@ def __init__(self, model_name: str, pretrained: bool = True, rnn_hdim: int = 128): super(PAC_Cell, self).__init__() - assert model_name in ['PAC_Net', 'P_Net', 'C_Net', 'baseline'] + assert model_name in {'PAC_Net', 'P_Net', 'C_Net', 'baseline'} self.rnn_hdim = rnn_hdim self.backbone_builder = { @@ -66,7 +66,7 @@ def __init__(self, model_name: str, pretrained: bool, self.rnn_hdim = rnn_hdim self.v_loss = v_loss - assert model_name in ['PAC_Net', 'P_Net', 'C_Net', 'baseline'] + assert model_name in {'PAC_Net', 'P_Net', 'C_Net', 'baseline'} # CNN self.backbone_builder = { 'PAC_Net': tvmodels.resnet18, @@ -366,9 +366,7 @@ def warm_up(self, I: Tensor): fx = self.warmup_encoder(rearrange(I, 'b c t h w -> (b t) c h w')) fx = rearrange(fx, '(b t) d -> b t d', b=B) - hx = self.warmup_rnn(fx)[1] # (2, B, D) - - return hx + return self.warmup_rnn(fx)[1] class NLOS_baseline(PAC_Net_Base): diff --git a/train.py b/train.py index 20555ee..376e164 100644 --- a/train.py +++ b/train.py @@ -14,7 +14,7 @@ def main(cfg): # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' world_size = len(dist_cfgs['device_ids'].split(',')) - dist_cfgs['distributed'] = True if world_size > 1 else False + dist_cfgs['distributed'] = world_size > 1 dist_cfgs['world_size'] = world_size cfg['loader_kwargs']['batch_size'] = cfg['train_configs']['batch_size'] // world_size diff --git a/utils/routes.py b/utils/routes.py index 4ab30ce..7e49226 100644 --- a/utils/routes.py +++ b/utils/routes.py @@ -48,7 +48,7 @@ def generate_route(self, self.route_length = route_length self._init_pv() - for step in range(route_length): + for _ in range(route_length): # print(self.velocity) self.next_step(turn_rate=turn_rate) self.e_route.append(self.e_position.copy()) @@ -116,8 +116,8 @@ def load_route(self, mat_path = os.path.join(save_dir, mat_name) save_dict = loadmat(mat_path) - self.e_route = [p for p in save_dict['route']] - self.velocities = [v for v in save_dict['velocities']] + self.e_route = list(save_dict['route']) + self.velocities = list(save_dict['velocities']) print(f'Load data from {mat_path} successfully!') diff --git a/utils/trainer.py b/utils/trainer.py index 75415e5..215583a 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -192,20 +192,20 @@ class Trainer_tracking(Trainer_Base): def __init__(self, cfg): super(Trainer_tracking, self).__init__(cfg=cfg) - log_train_cfg = { - "model_name": self.model_name, - **self.model_cfgs, - "batch_size": self.train_cfgs['batch_size'], - "v_loss_alpha": self.train_cfgs['v_loss_alpha'], - "loss_total_alpha": self.train_cfgs['loss_total_alpha'], - "resume": self.train_cfgs['resume'], - "route_len": self.dataset_cfgs['route_len'], - "noise_factor": self.dataset_cfgs['noise_factor'], - **self.optim_kwargs, - "epochs": self.schedule_cfgs['max_epoch'], - } - if self.dist_cfgs['local_rank'] == 0: + log_train_cfg = { + "model_name": self.model_name, + **self.model_cfgs, + "batch_size": self.train_cfgs['batch_size'], + "v_loss_alpha": self.train_cfgs['v_loss_alpha'], + "loss_total_alpha": self.train_cfgs['loss_total_alpha'], + "resume": self.train_cfgs['resume'], + "route_len": self.dataset_cfgs['route_len'], + "noise_factor": self.dataset_cfgs['noise_factor'], + **self.optim_kwargs, + "epochs": self.schedule_cfgs['max_epoch'], + } + self._init_recorder(log_train_cfg) self.val_metrics = {'x_loss': 0.0, @@ -278,14 +278,14 @@ def run(self): 'Metric/val/dtw': val_metric[2], }, step=epoch + 1) if self.epoch % 5 == 0: - logger.info(f'Logging images...') + logger.info('Logging images...') self.test_plot(epoch=self.epoch, phase='train') self.test_plot(epoch=self.epoch, phase='val') self.scheduler.step() if ((epoch + 1) % self.log_cfgs['save_epoch_interval'] == 0) \ - or (epoch + 1) == self.schedule_cfgs['max_epoch']: + or (epoch + 1) == self.schedule_cfgs['max_epoch']: checkpoint_path = os.path.join(self.ckpt_dir, f"epoch_{(epoch + 1)}.pth") self.save_checkpoint(checkpoint_path) @@ -315,7 +315,7 @@ def train(self, epoch): dynamic_ncols=True, ascii=(platform.version() == 'Windows')) - for step in range(len_loader): + for _ in range(len_loader): try: inputs, labels, map_sizes = next(iter_loader) except Exception as e: @@ -387,7 +387,7 @@ def train(self, epoch): pbar.close() return (x_loss_recorder.avg, v_loss_recorder.avg), \ - (pcm_recorder.avg, area_recorder.avg, dtw_recorder.avg) + (pcm_recorder.avg, area_recorder.avg, dtw_recorder.avg) def val(self, epoch): self.model.eval() @@ -482,13 +482,15 @@ def val(self, epoch): metrics = [self.val_metrics[name] for name in names] res_table.add_row([f"{m:.4}" if type(m) is float else m for m in metrics[:-1]] + [metrics[-1]]) - logger.info(f'Performance on validation set at epoch: {epoch + 1}\n' + res_table.get_string()) + logger.info( + f'Performance on validation set at epoch: {epoch + 1}\n{res_table.get_string()}' + ) return (self.val_metrics['x_loss'], self.val_metrics['v_loss']), \ - (self.val_metrics['pcm'], self.val_metrics['area'], self.val_metrics['dtw']) + (self.val_metrics['pcm'], self.val_metrics['area'], self.val_metrics['dtw']) def test_plot(self, epoch, phase: str): - assert phase in ['train', 'val'] + assert phase in {'train', 'val'} self.model.eval() iter_loader = iter(self.val_loader) if phase == 'val' else iter(self.train_loader) frames, gt_routes, map_sizes = next(iter_loader) diff --git a/utils/vis.py b/utils/vis.py index 8291784..5fde9b5 100644 --- a/utils/vis.py +++ b/utils/vis.py @@ -23,7 +23,7 @@ def draw_route(map_size: ndarray, route: ndarray, lc.set_array(idxs) lc.set_linewidth(3) line = ax.add_collection(lc) - fig.colorbar(line, ax=ax, ticks=idxs[::int(len(idxs) / 10)], label='step') + fig.colorbar(line, ax=ax, ticks=idxs[::len(idxs) // 10], label='step') ax.set_xlim(0, map_size[0]) ax.set_xlabel('x') @@ -40,7 +40,7 @@ def draw_route(map_size: ndarray, route: ndarray, def draw_routes(routes: tuple[ndarray, ndarray], return_mode: str = None): - assert return_mode in ['plt_fig', 'fig_array', None] + assert return_mode in {'plt_fig', 'fig_array', None} titles = ('GT', 'pred') cmaps = ('viridis', 'plasma') @@ -59,7 +59,13 @@ def draw_routes(routes: tuple[ndarray, ndarray], return_mode: str = None): lc.set_array(idxs) lc.set_linewidth(3) line = axes[i].add_collection(lc) - fig.colorbar(line, ax=axes[i], ticks=idxs[::int(len(idxs) / 10)], label='step', fraction=0.05) + fig.colorbar( + line, + ax=axes[i], + ticks=idxs[:: len(idxs) // 10], + label='step', + fraction=0.05, + ) axes[i].set_title(titles[i]) axes[i].set_xlim(0, 1)