|
1 | 1 | """Tests for text-generation module |
2 | 2 | """ |
| 3 | + |
3 | 4 | # Standard |
4 | 5 | import os |
5 | 6 | import platform |
|
10 | 11 | import torch |
11 | 12 |
|
12 | 13 | # First Party |
13 | | -from caikit.interfaces.nlp.data_model import GeneratedTextResult |
| 14 | +from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults |
14 | 15 | import caikit |
15 | 16 |
|
16 | 17 | # Local |
@@ -211,7 +212,26 @@ def test_zero_epoch_case(disable_wip): |
211 | 212 | assert isinstance(model.model, HFAutoSeq2SeqLM) |
212 | 213 |
|
213 | 214 |
|
214 | | -def test_run_tokenizer_not_implemented(): |
215 | | - with pytest.raises(NotImplementedError): |
216 | | - model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) |
217 | | - model.run_tokenizer("This text doesn't matter") |
| 215 | +# ############################## Run Tokenizer ################################ |
| 216 | + |
| 217 | + |
def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
    """Exercise run_tokenizer on empty, typical, and very long inputs."""
    model = TextGeneration.bootstrap(CAUSAL_LM_MODEL)

    # An empty string should produce a result with zero tokens.
    result = model.run_tokenizer("")
    assert isinstance(result, TokenizationResults)
    assert result.token_count == 0

    # Both a normal short sentence and a very long repeated input should
    # report a token count that matches encoding the same text directly
    # with the model's own tokenizer.
    for text in (
        "This is a test sentence.",
        "This is a test sentence. " * 1000,
    ):
        result = model.run_tokenizer(text)
        assert isinstance(result, TokenizationResults)
        assert result.token_count == len(model.model.tokenizer.encode(text))
0 commit comments