This repository has been archived by the owner on Mar 1, 2022. It is now read-only.

Fix PyTorch device error when loading custom model #18

Merged
merged 2 commits into from
May 21, 2020

Conversation

JasonObeid
Contributor

Environment: Python 3.8, fitbert 0.7.0, transformers 2.9.1, torch 1.5.0

When loading a custom Transformers model as described in the README via:
BertForMaskedLM.from_pretrained('path to pretrained')

a runtime error occurs:
Expected object of device type cuda but got device type cpu for argument #1 'self' in call to th_index_select

The error is raised at line 151, tens = tens.to(self.device): the input tensor is moved to the GPU, but the model's weights are still on the CPU.

Adding self.bert.to(self.device) at line 41 fixes the issue.
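The mismatch can be reproduced with a minimal sketch (using a plain nn.Embedding as a hypothetical stand-in for the BERT model, since Embedding is what calls index_select under the hood): the input tensor is moved to the target device, but the model is not, so the lookup fails on CUDA machines unless the model is moved too.

```python
import torch

# Pick CUDA when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for self.bert; an embedding lookup is the op
# that raises the index_select device error in the real model
model = torch.nn.Embedding(10, 4)

# The fix from this PR, generalized: move the model's weights to the
# same device the input tensors are sent to (self.bert.to(self.device))
model.to(device)

tens = torch.tensor([1, 2, 3]).to(device)  # corresponds to tens.to(self.device)
out = model(tens)  # both tensor and weights now live on `device`
```

Without the model.to(device) call, this raises the same "Expected object of device type cuda but got device type cpu" error on a GPU machine, because the embedding weights stay on the CPU while the indices are on CUDA.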

Contributor

@sam-writer sam-writer left a comment


Looks great! Thanks for contributing!

@sam-writer sam-writer merged commit a2e751e into writer:master May 21, 2020
@sam-writer
Contributor

@JasonObeid I released this change; it should be available in version 0.9.0.

@JasonObeid
Contributor Author

No problem Sam, thanks for this great package!
