Skip to content

Commit 32bffe8

Browse files
authored
Merge pull request #88 from kajyuuen/fix-bert-feature
Fix bert features
2 parents c451cdd + 8d0e12e commit 32bffe8

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

nyaggle/feature/nlp/bert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def _process_text(self, text: str) -> np.ndarray:
8080

8181
self.model.eval()
8282
with torch.no_grad():
83-
all_encoder_layers, _ = self.model(tokens_tensor)
83+
outputs = self.model(tokens_tensor)
8484

85-
embedding = all_encoder_layers.cpu().numpy()[0]
85+
embedding = outputs.last_hidden_state.cpu().numpy()[0]
8686
if self.pooling_strategy == 'reduce_mean':
8787
return np.mean(embedding, axis=0)
8888
elif self.pooling_strategy == 'reduce_max':

tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)