From 74b9f4abcbf01ed3d8aee20cd97d56617cd1314f Mon Sep 17 00:00:00 2001 From: Julia Gusak Date: Tue, 16 Feb 2021 22:48:19 +0300 Subject: [PATCH] Update svd_layer.py --- musco/pytorch/compressor/decompositions/svd_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/musco/pytorch/compressor/decompositions/svd_layer.py b/musco/pytorch/compressor/decompositions/svd_layer.py index 0b1648a..7c40c88 100644 --- a/musco/pytorch/compressor/decompositions/svd_layer.py +++ b/musco/pytorch/compressor/decompositions/svd_layer.py @@ -163,11 +163,11 @@ def __init__(self, layer, layer_name, self.rank = rank elif rank_selection == 'param_reduction': if isinstance(self.layer, nn.Sequential): - prev_rank = self.layer[0].out_features + prev_rank = self.layer[0].out_channels else: prev_rank = None - self.rank = estimate_rank_for_compression_rate((self.out_features, self.in_features), + self.rank = estimate_rank_for_compression_rate((self.out_channels, self.in_channels), rate = param_reduction_rate, key = 'svd', prev_rank = prev_rank,