
Fine tuning issues #65

Open
VRichardJP opened this issue Dec 27, 2024 · 0 comments

Hi,

I am trying to fine-tune one of the pretrained HTS-AT models for binary classification on a custom dataset. I have already managed to do exactly the same thing with a pretrained BEATs model, but somehow I can't make it work with HTS-AT.

Here is a summary of what I do:

  • I create the HTSAT_Swin_Transformer model with the same config as in your ESC-50 fine-tuning example, the only differences being num_classes=1 and loss_type = "clip_bce", since I do binary classification
  • I load one of the pretrained checkpoints (e.g. HTSAT_AudioSet_Saved_1.ckpt) and update all the model weights except sed_model.tscam_conv.weight and sed_model.tscam_conv.bias (I have verified that all weights are correctly loaded)
  • I freeze all parameters except the tscam_conv ones (4.6K trainable params left)
  • I feed the model batches of raw audio waveforms (sampled at 32000 Hz and zero-padded to the length of the longest clip in the batch) and compute the loss against the 0/1 targets with nn.BCELoss
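The setup steps above can be sketched in plain Python, with a checkpoint represented as a dict from parameter name to shape so the logic runs without PyTorch. This is a minimal sketch under stated assumptions, not HTS-AT's own loading code: the helpers `filter_checkpoint`, `count_trainable` and `pad_batch` are hypothetical, and the `sed_model.` prefix, the stand-in backbone layer, the 768 input channels and the (2, 3) kernel of `tscam_conv` are assumptions about the checkpoint layout and default config. Under those assumptions the trainable count is 768 × 2 × 3 + 1 = 4,609, in line with the ~4.6K reported above.

```python
from math import prod
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def filter_checkpoint(state: Dict[str, Shape], skip: Sequence[str]) -> Dict[str, Shape]:
    """Keep only the entries whose name contains none of the `skip` substrings,
    so the re-initialised head is not overwritten by the pretrained one."""
    return {name: shape for name, shape in state.items() if not any(s in name for s in skip)}


def count_trainable(shapes: Dict[str, Shape], trainable: Sequence[str]) -> int:
    """Total number of elements in the parameters whose name contains one of
    the `trainable` substrings; every other parameter counts as frozen."""
    return sum(prod(shape) for name, shape in shapes.items() if any(t in name for t in trainable))


def pad_batch(clips: List[List[float]]) -> Tuple[List[List[float]], List[int]]:
    """Zero-pad every clip to the length of the longest clip in the batch and
    also return the original lengths."""
    longest = max(len(clip) for clip in clips)
    padded = [clip + [0.0] * (longest - len(clip)) for clip in clips]
    return padded, [len(clip) for clip in clips]


# Hypothetical names and shapes: a "sed_model." prefix, one stand-in backbone
# layer, and a tscam_conv with 768 input channels and a (2, 3) kernel (assumed).
model_shapes = {
    "sed_model.some_backbone_layer.weight": (96, 1, 4, 4),
    "sed_model.tscam_conv.weight": (1, 768, 2, 3),  # num_classes=1
    "sed_model.tscam_conv.bias": (1,),
}
# Pretrained AudioSet checkpoint: same backbone, but a 527-class tscam_conv.
checkpoint = {
    **model_shapes,
    "sed_model.tscam_conv.weight": (527, 768, 2, 3),
    "sed_model.tscam_conv.bias": (527,),
}

to_load = filter_checkpoint(checkpoint, skip=["tscam_conv"])
print(sorted(to_load))  # ['sed_model.some_backbone_layer.weight']
print(count_trainable(model_shapes, trainable=["tscam_conv"]))  # 4609 = 768*2*3 + 1
padded, lengths = pad_batch([[0.1, 0.2], [0.3], [0.4, 0.5, 0.6]])
print(padded[1], lengths)  # [0.3, 0.0, 0.0] [2, 1, 3]
```

Matching parameters by substring keeps the sketch short; with a real model, comparing full parameter names is stricter and avoids accidentally freezing or unfreezing a similarly named layer.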

I follow exactly the same process with BEATs, the only differences being the layer names and the input sample rate (16000 Hz). Yet I can't get the HTS-AT model to learn anything. For example, here is the val_loss over a few epochs for several runs (blue is BEATs fine-tuning, for reference):

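[Figure: val_loss over epochs for several HTS-AT fine-tuning runs, with BEATs fine-tuning in blue for reference]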

I have tried different learning rates, pretrained checkpoints and optimizers, but none of these changes seems to have any effect.

Since roughly 10% of my dataset is positive, a dummy model that always outputs 0.10 would have a val_loss of approximately 0.27, which is what all my attempts seem to converge to. In other words, the model is not learning anything from the input.
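The dummy-model baseline is the binary cross-entropy of a constant prediction c on a dataset with positive rate p, that is -(p·ln c + (1 - p)·ln(1 - c)), which is minimised at c = p. The small sketch below (the helper name `constant_bce` is hypothetical) uses the natural log, as nn.BCELoss does. With p = c = 0.10 it gives about 0.325, while a positive rate closer to 7.5 to 8% gives about 0.27, so the exact plateau depends on the true class balance of the validation set.

```python
import math


def constant_bce(pos_rate: float, prediction: float) -> float:
    """Mean binary cross-entropy (natural log, as in nn.BCELoss) of a model that
    outputs `prediction` for every example, on data whose fraction of positive
    labels is `pos_rate`."""
    return -(pos_rate * math.log(prediction) + (1 - pos_rate) * math.log(1 - prediction))


for p in (0.10, 0.08, 0.075):
    # The best constant prediction equals the positive rate itself.
    print(f"positive rate {p:.3f}: baseline BCE {constant_bce(p, p):.3f}")
# positive rate 0.100: baseline BCE 0.325
# positive rate 0.080: baseline BCE 0.279
# positive rate 0.075: baseline BCE 0.266
```

A val_loss that settles on this value for the measured positive rate is a quick sign that the classifier output no longer depends on the input.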

The input data looks "normal". For example, here is what the sound of an ambulance looks like after HTS-AT preprocessing:

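[Figure: ambulance sound after HTS-AT preprocessing, with "spectrogram" (top) and "logmel" (bottom) panels]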

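    # Debug version of HTSAT_Swin_Transformer.forward. Needs, at module level:
    #   import librosa, librosa.display
    #   import matplotlib.pyplot as plt
    def forward(
        self, x: torch.Tensor, mixup_lambda=None, infer_mode=False
    ):  # out_feat_keys: List[str] = None):
        x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins), linear power

        # Axis settings for the plots; 32 kHz input and a hop of 320 samples are
        # assumed to match the config the model was built with.
        sr, hop = 32000, 320

        fig, axs = plt.subplots(2)
        # The STFT output is linear power, so convert it to dB before using a dB colorbar.
        spec_db = librosa.power_to_db(x[0][0].detach().cpu().numpy().T)
        img = librosa.display.specshow(
            spec_db, x_axis="time", y_axis="log", sr=sr, hop_length=hop, ax=axs[0]
        )
        fig.colorbar(img, ax=axs[0], format="%+2.0f dB")
        axs[0].set(title="spectrogram")

        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins), already in dB

        # fmin=50 / fmax=14000 are assumed mel-filterbank limits (HTS-AT defaults).
        img = librosa.display.specshow(
            x[0][0].detach().cpu().numpy().T, x_axis="time", y_axis="mel",
            sr=sr, hop_length=hop, fmin=50, fmax=14000, ax=axs[1]
        )
        fig.colorbar(img, ax=axs[1], format="%+2.0f dB")
        axs[1].set(title="logmel")
        plt.show()

        # ...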
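To sanity-check the axes of the two plots, the shapes of the plotted tensors can be worked out from the front-end settings. This is a sketch under assumptions: it assumes centred framing as in librosa/torchlibrosa (the signal is padded by window // 2 on each side, giving 1 + samples // hop frames) and HTS-AT's default values of a 1024-sample window, a hop of 320 samples and 64 mel bins at 32 kHz; `frontend_shapes` and `FrontendShapes` are hypothetical names.

```python
from typing import NamedTuple


class FrontendShapes(NamedTuple):
    time_steps: int  # STFT frames along the time axis
    freq_bins: int   # linear-frequency bins from spectrogram_extractor
    mel_bins: int    # mel bins from logmel_extractor
    seconds: float   # time_steps * hop / sr, roughly the clip duration


def frontend_shapes(num_samples: int, sr: int = 32000, window: int = 1024,
                    hop: int = 320, mel_bins: int = 64) -> FrontendShapes:
    """Per-clip shapes of the (time_steps, freq_bins) and (time_steps, mel_bins)
    planes, assuming centred framing: 1 + num_samples // hop frames."""
    time_steps = 1 + num_samples // hop
    return FrontendShapes(time_steps, window // 2 + 1, mel_bins, time_steps * hop / sr)


print(frontend_shapes(10 * 32000))  # a 10 s clip at 32 kHz
# FrontendShapes(time_steps=1001, freq_bins=513, mel_bins=64, seconds=10.01)
```

If a plotted clip shows a different number of frames or bins than this arithmetic predicts, the sample rate or config used for the plot does not match the one the model expects.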

Am I missing a key detail?
