MSE Temporal Loss Bugfix + Optimization#361
MSE Temporal Loss Bugfix + Optimization#361erenaydoslu wants to merge 5 commits intojeshraghian:masterfrom
Conversation
|
Thanks for working through this. Training worked when specifying In your update, it raised the following error:
whereas in the current release of snnTorch, I'd receive this error:
I'll attempt to debug this and will add a few tests. |
|
Hi @jeshraghian, Thank you for taking the time to review. I've pushed two new commits to address your feedback.
I've run the new versions, and the new loss now handles both scalars and multi-spike targets without errors. Please let me know if there's anything else I can adjust. |
Training a model with
mse_temporal_losswould not work at all. Turns out the functionspikegen.targets_convertused for converting target indices to spike times generates a tensor full of zeros, which would end up to be the target in MSE calculations. As a result, the model learns to spike all the outputs all the time.Furthermore, training with
mse_temporal_lossis also very slow. I changed the Python loops underFirstSpiketo PyTorch functions without using loops. On my device, this leads to ~30x improvement in speed.