Skip to content

Commit 8d0e12e

Browse files
committed
Using transforms' new API.
1 parent c451cdd commit 8d0e12e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

nyaggle/feature/nlp/bert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def _process_text(self, text: str) -> np.ndarray:
8080

8181
self.model.eval()
8282
with torch.no_grad():
83-
all_encoder_layers, _ = self.model(tokens_tensor)
83+
outputs = self.model(tokens_tensor)
8484

85-
embedding = all_encoder_layers.cpu().numpy()[0]
85+
embedding = outputs.last_hidden_state.cpu().numpy()[0]
8686
if self.pooling_strategy == 'reduce_mean':
8787
return np.mean(embedding, axis=0)
8888
elif self.pooling_strategy == 'reduce_max':

tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)