Skip to content

Commit f629248

Browse files
committed
✅ Based on the PR comments: changed the test case to check for an exact expected token count instead of only checking that the count is non-zero; added return_attention_mask=True to the tokenizer call in the run_tokenizer method.
Signed-off-by: m-misiura <[email protected]>
1 parent 261e1a3 commit f629248

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

caikit_nlp/modules/text_generation/text_generation_local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def run_tokenizer(
592592
The token count
593593
"""
594594
error.type_check("<NLP48137045E>", str, text=text)
595-
tokenized_output = self.model.tokenizer(text)
595+
tokenized_output = self.model.tokenizer(text, return_attention_mask=True)
596596
return TokenizationResults(
597597
token_count=len(tokenized_output["input_ids"]),
598598
)

tests/modules/text_generation/test_text_generation_local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,10 @@ def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
228228
short_text = "This is a test sentence."
229229
short_result = model.run_tokenizer(short_text)
230230
assert isinstance(short_result, TokenizationResults)
231-
assert short_result.token_count > 0
231+
assert short_result.token_count == len(model.model.tokenizer.encode(short_text))
232232

233233
# Edge case: Long input
234234
long_text = "This is a test sentence. " * 1000
235235
long_result = model.run_tokenizer(long_text)
236236
assert isinstance(long_result, TokenizationResults)
237-
assert long_result.token_count > 0
237+
assert long_result.token_count == len(model.model.tokenizer.encode(long_text))

0 commit comments

Comments (0)