Skip to content

Commit cd44077

Browse files
authored
Merge pull request #402 from m-misiura/run_tokenier_in_text_generation_local
Added `run_tokenizer` method in `caikit_nlp/modules/text_generation/text_generation_local.py`
2 parents 56b7e18 + f629248 commit cd44077

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

caikit_nlp/modules/text_generation/text_generation_local.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555

5656
TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl"
5757

58+
5859
# pylint: disable=too-many-lines,too-many-instance-attributes
5960
@module(
6061
id="f9181353-4ccf-4572-bd1e-f12bcda26792",
@@ -590,7 +591,11 @@ def run_tokenizer(
590591
TokenizationResults
591592
The token count
592593
"""
593-
raise NotImplementedError("Tokenization not implemented for local")
594+
error.type_check("<NLP48137045E>", str, text=text)
595+
tokenized_output = self.model.tokenizer(text, return_attention_mask=True)
596+
return TokenizationResults(
597+
token_count=len(tokenized_output["input_ids"]),
598+
)
594599

595600
################################## Private Functions ######################################
596601

tests/modules/text_generation/test_text_generation_local.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for text-generation module
22
"""
3+
34
# Standard
45
import os
56
import platform
@@ -10,7 +11,7 @@
1011
import torch
1112

1213
# First Party
13-
from caikit.interfaces.nlp.data_model import GeneratedTextResult
14+
from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults
1415
import caikit
1516

1617
# Local
@@ -211,7 +212,26 @@ def test_zero_epoch_case(disable_wip):
211212
assert isinstance(model.model, HFAutoSeq2SeqLM)
212213

213214

214-
def test_run_tokenizer_not_implemented():
215-
with pytest.raises(NotImplementedError):
216-
model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
217-
model.run_tokenizer("This text doesn't matter")
215+
# ############################## Run Tokenizer ################################
216+
217+
218+
def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
219+
"""Test tokenizer on edge cases like empty strings and long input."""
220+
model = TextGeneration.bootstrap(CAUSAL_LM_MODEL)
221+
222+
# Edge case: Empty string
223+
empty_result = model.run_tokenizer("")
224+
assert isinstance(empty_result, TokenizationResults)
225+
assert empty_result.token_count == 0
226+
227+
# Normal case: short sentence
228+
short_text = "This is a test sentence."
229+
short_result = model.run_tokenizer(short_text)
230+
assert isinstance(short_result, TokenizationResults)
231+
assert short_result.token_count == len(model.model.tokenizer.encode(short_text))
232+
233+
# Edge case: Long input
234+
long_text = "This is a test sentence. " * 1000
235+
long_result = model.run_tokenizer(long_text)
236+
assert isinstance(long_result, TokenizationResults)
237+
assert long_result.token_count == len(model.model.tokenizer.encode(long_text))

0 commit comments

Comments
 (0)