Hi, this is an interesting project!
However, there are some small issues in the code. Although the CausCell class accepts device='cpu' at initialization, later parts of the CausCell class and the Trainer class call .cuda() directly, which raises errors when running in a non-CUDA environment.
Replacing these .cuda() calls with .to(self.device) resolves the issue.
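
For illustration, here is a minimal sketch of the pattern I mean. It is not the actual CausCell code: TinyModel and its layer are hypothetical stand-ins, and the only assumption is that the class stores the device it was given as self.device.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a class that accepts a `device` argument."""

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = torch.device(device)
        self.linear = nn.Linear(4, 2)
        # Instead of self.linear.cuda(), which fails on machines without CUDA:
        self.to(self.device)

    def forward(self, x):
        # Instead of x.cuda(): move the input to the configured device.
        return self.linear(x.to(self.device))


# Fall back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyModel(device=device)
out = model(torch.randn(3, 4))
print(tuple(out.shape), out.device.type)
```

With this pattern the same code runs on both CPU-only and GPU machines, since every tensor and module follows whatever device was passed in.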
🙂