|
15 | 15 |
16 | 16 | from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url |
17 | 17 |
18 | | - |
19 | | -def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]: |
20 | | - """ |
21 | | - Truncate texts to token limit using tiktoken or conservative character truncation. |
22 | | -
23 | | - Args: |
24 | | - texts: List of texts to truncate |
25 | | - max_tokens: Maximum tokens allowed per text |
26 | | -
27 | | - Returns: |
28 | | - List of truncated texts that should fit within token limit |
29 | | - """ |
30 | | - try: |
31 | | - import tiktoken |
32 | | - |
33 | | - encoder = tiktoken.get_encoding("cl100k_base") |
34 | | - truncated = [] |
35 | | - |
36 | | - for text in texts: |
37 | | - tokens = encoder.encode(text) |
38 | | - if len(tokens) > max_tokens: |
39 | | - # Truncate to max_tokens and decode back to text |
40 | | - truncated_tokens = tokens[:max_tokens] |
41 | | - truncated_text = encoder.decode(truncated_tokens) |
42 | | - truncated.append(truncated_text) |
43 | | - logger.warning( |
44 | | - f"Truncated text from {len(tokens)} to {max_tokens} tokens " |
45 | | - f"(from {len(text)} to {len(truncated_text)} characters)" |
46 | | - ) |
47 | | - else: |
48 | | - truncated.append(text) |
49 | | - return truncated |
50 | | - |
51 | | - except ImportError: |
52 | | - # Fallback: Conservative character truncation |
53 | | - # Assume worst case: 1.5 tokens per character for code content |
54 | | - char_limit = int(max_tokens / 1.5) |
55 | | - truncated = [] |
56 | | - |
57 | | - for text in texts: |
58 | | - if len(text) > char_limit: |
59 | | - truncated_text = text[:char_limit] |
60 | | - truncated.append(truncated_text) |
61 | | - logger.warning( |
62 | | - f"Truncated text from {len(text)} to {char_limit} characters " |
63 | | - f"(conservative estimate for {max_tokens} tokens)" |
64 | | - ) |
65 | | - else: |
66 | | - truncated.append(text) |
67 | | - return truncated |
68 | | - |
69 | | - |
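For reference, the tiktoken path of the removed helper is an encode/slice/decode round-trip. A minimal standalone sketch using the same `cl100k_base` encoding the deleted code requests (the sample text is illustrative):

```python
import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")
tokens = encoder.encode("def add(a, b):\n    return a + b")
# Keep at most 5 tokens, then decode back to a shorter string.
clipped = encoder.decode(tokens[:5])
```

The `ImportError` fallback is deliberately conservative: at the assumed worst case of 1.5 tokens per character, a 512-token budget becomes `int(512 / 1.5) == 341` characters, far fewer than typical tokenizers would require for that many tokens.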
70 | | -def get_model_token_limit(model_name: str) -> int: |
71 | | - """ |
72 | | - Get token limit for a given embedding model. |
73 | | -
74 | | - Args: |
75 | | - model_name: Name of the embedding model |
76 | | -
77 | | - Returns: |
78 | | - Token limit for the model, defaults to 512 if unknown |
79 | | - """ |
80 | | - # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text") |
81 | | - base_model_name = model_name.split(":")[0] |
82 | | - |
83 | | - # Check exact match first |
84 | | - if model_name in EMBEDDING_MODEL_LIMITS: |
85 | | - return EMBEDDING_MODEL_LIMITS[model_name] |
86 | | - |
87 | | - # Check base name match |
88 | | - if base_model_name in EMBEDDING_MODEL_LIMITS: |
89 | | - return EMBEDDING_MODEL_LIMITS[base_model_name] |
90 | | - |
91 | | - # Check partial matches for common patterns |
92 | | - for known_model, limit in EMBEDDING_MODEL_LIMITS.items(): |
93 | | - if known_model in base_model_name or base_model_name in known_model: |
94 | | - return limit |
95 | | - |
96 | | - # Default to conservative 512 token limit |
97 | | - logger.warning(f"Unknown model '{model_name}', using default 512 token limit") |
98 | | - return 512 |
99 | | - |
100 | | - |
101 | 18 | # Set up logger with proper level |
102 | 19 | logger = logging.getLogger(__name__) |
103 | 20 | LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() |
@@ -149,8 +66,26 @@ def get_model_token_limit( |
149 | 66 | if limit: |
150 | 67 | return limit |
151 | 68 |
152 | | - # Fallback to known model registry |
153 | | - return EMBEDDING_MODEL_LIMITS.get(model_name, default) |
| 69 | + # Fallback to known model registry with version handling (from PR #154) |
| 70 | + # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text") |
| 71 | + base_model_name = model_name.split(":")[0] |
| 72 | + |
| 73 | + # Check exact match first |
| 74 | + if model_name in EMBEDDING_MODEL_LIMITS: |
| 75 | + return EMBEDDING_MODEL_LIMITS[model_name] |
| 76 | + |
| 77 | + # Check base name match |
| 78 | + if base_model_name in EMBEDDING_MODEL_LIMITS: |
| 79 | + return EMBEDDING_MODEL_LIMITS[base_model_name] |
| 80 | + |
| 81 | + # Check partial matches for common patterns |
| 82 | + for known_model, limit in EMBEDDING_MODEL_LIMITS.items(): |
| 83 | + if known_model in base_model_name or base_model_name in known_model: |
| 84 | + return limit |
| 85 | + |
| 86 | + # Default fallback |
| 87 | + logger.warning(f"Unknown model '{model_name}', using default {default} token limit") |
| 88 | + return default |
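The relocated lookup falls through three tiers: exact name, the name with any `:tag` suffix stripped, then a bidirectional substring match. A toy walk-through against a hypothetical registry (the real `EMBEDDING_MODEL_LIMITS` entries are not shown in this diff):

```python
EMBEDDING_MODEL_LIMITS = {"nomic-embed-text": 2048}  # hypothetical entry for illustration

# 1. Exact match:      "nomic-embed-text"        -> 2048
# 2. Base-name match:  "nomic-embed-text:latest" -> ":latest" stripped -> 2048
# 3. Substring match:  "my-nomic-embed-text-ft"  -> contains a known key -> 2048
# 4. No match:         "mystery-model"           -> warning, returns `default`
```

Because the substring check runs both ways (`known_model in base_model_name or base_model_name in known_model`), a short registry key would match many unrelated model names; keeping registry entries specific avoids surprising limits.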
154 | 89 |
155 | 90 |
156 | 91 | def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]: |
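Putting the two helpers together, a caller would resolve the model's limit once and clamp each text before embedding. A hedged usage sketch (the model name and texts are placeholders, not taken from this PR):

```python
limit = get_model_token_limit("nomic-embed-text:latest")
safe_texts = truncate_to_token_limit(["first long chunk ...", "second ..."], limit)
```

Note the signature change: the new `truncate_to_token_limit` takes `token_limit` as a required positional argument instead of the old `max_tokens: int = 512` default, so callers must resolve a model-specific limit explicitly.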
|