
Commit eb60722

style: apply ruff formatting and restore PR yichuan-w#154 version handling
- Remove duplicate truncate_to_token_limit and get_model_token_limit functions
- Restore version handling logic (model:latest -> model) from PR yichuan-w#154
- Restore partial matching fallback for model name variations
- Apply ruff formatting to all modified files
- All 11 token truncation tests passing
1 parent ef3b889 commit eb60722

File tree

4 files changed, +123 -176 lines changed


packages/leann-core/src/leann/chunking_utils.py

Lines changed: 3 additions & 12 deletions
@@ -265,10 +265,7 @@ def create_ast_chunks(
             # Merge document metadata + astchunk metadata
             combined_metadata = {**doc_metadata, **astchunk_metadata}
 
-            all_chunks.append({
-                "text": chunk_text.strip(),
-                "metadata": combined_metadata
-            })
+            all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
 
         logger.info(
             f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -320,18 +317,12 @@ def create_traditional_chunks(
             nodes = node_parser.get_nodes_from_documents([doc])
             if nodes:
                 for node in nodes:
-                    result.append({
-                        "text": node.get_content(),
-                        "metadata": doc_metadata
-                    })
+                    result.append({"text": node.get_content(), "metadata": doc_metadata})
         except Exception as e:
             logger.error(f"Traditional chunking failed for document: {e}")
             content = doc.get_content()
            if content and content.strip():
-                result.append({
-                    "text": content.strip(),
-                    "metadata": doc_metadata
-                })
+                result.append({"text": content.strip(), "metadata": doc_metadata})
 
     return result
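The change above is formatting only: the collapsed single-line append builds exactly the same dict as the old multi-line literal. A small sketch to illustrate, with sample values made up for the example:

chunk_text = "def add(a, b):\n    return a + b\n"
combined_metadata = {"file_name": "example.py", "language": "python"}

# Old multi-line literal
before = {
    "text": chunk_text.strip(),
    "metadata": combined_metadata,
}
# New ruff-collapsed form from this commit
after = {"text": chunk_text.strip(), "metadata": combined_metadata}

assert before == after  # identical dict; only the source formatting changed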

packages/leann-core/src/leann/embedding_compute.py

Lines changed: 20 additions & 85 deletions
@@ -15,89 +15,6 @@
 
 from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
 
-
-def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
-    """
-    Truncate texts to token limit using tiktoken or conservative character truncation.
-
-    Args:
-        texts: List of texts to truncate
-        max_tokens: Maximum tokens allowed per text
-
-    Returns:
-        List of truncated texts that should fit within token limit
-    """
-    try:
-        import tiktoken
-
-        encoder = tiktoken.get_encoding("cl100k_base")
-        truncated = []
-
-        for text in texts:
-            tokens = encoder.encode(text)
-            if len(tokens) > max_tokens:
-                # Truncate to max_tokens and decode back to text
-                truncated_tokens = tokens[:max_tokens]
-                truncated_text = encoder.decode(truncated_tokens)
-                truncated.append(truncated_text)
-                logger.warning(
-                    f"Truncated text from {len(tokens)} to {max_tokens} tokens "
-                    f"(from {len(text)} to {len(truncated_text)} characters)"
-                )
-            else:
-                truncated.append(text)
-        return truncated
-
-    except ImportError:
-        # Fallback: Conservative character truncation
-        # Assume worst case: 1.5 tokens per character for code content
-        char_limit = int(max_tokens / 1.5)
-        truncated = []
-
-        for text in texts:
-            if len(text) > char_limit:
-                truncated_text = text[:char_limit]
-                truncated.append(truncated_text)
-                logger.warning(
-                    f"Truncated text from {len(text)} to {char_limit} characters "
-                    f"(conservative estimate for {max_tokens} tokens)"
-                )
-            else:
-                truncated.append(text)
-        return truncated
-
-
-def get_model_token_limit(model_name: str) -> int:
-    """
-    Get token limit for a given embedding model.
-
-    Args:
-        model_name: Name of the embedding model
-
-    Returns:
-        Token limit for the model, defaults to 512 if unknown
-    """
-    # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
-    base_model_name = model_name.split(":")[0]
-
-    # Check exact match first
-    if model_name in EMBEDDING_MODEL_LIMITS:
-        return EMBEDDING_MODEL_LIMITS[model_name]
-
-    # Check base name match
-    if base_model_name in EMBEDDING_MODEL_LIMITS:
-        return EMBEDDING_MODEL_LIMITS[base_model_name]
-
-    # Check partial matches for common patterns
-    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
-        if known_model in base_model_name or base_model_name in known_model:
-            return limit
-
-    # Default to conservative 512 token limit
-    logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
-    return 512
-
-
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
 LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
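The block removed here duplicated helpers defined elsewhere in the module. For reference, the tiktoken-based truncation idea it implemented can be summarized in a few lines; this is a minimal, self-contained sketch with local names, not the module's API:

# Sketch of the truncation idea from the removed duplicate: encode with tiktoken
# and cut at max_tokens, falling back to a conservative character cut if
# tiktoken is unavailable. Names here are local to the sketch.
def truncate_texts(texts: list[str], max_tokens: int = 512) -> list[str]:
    try:
        import tiktoken

        encoder = tiktoken.get_encoding("cl100k_base")
        return [encoder.decode(encoder.encode(text)[:max_tokens]) for text in texts]
    except ImportError:
        # Worst-case assumption of ~1.5 tokens per character for code-heavy text
        char_limit = int(max_tokens / 1.5)
        return [text[:char_limit] for text in texts]


print(truncate_texts(["chunk " * 1000], max_tokens=8)[0])  # prints a short prefix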
@@ -149,8 +66,26 @@ def get_model_token_limit(
     if limit:
         return limit
 
-    # Fallback to known model registry
-    return EMBEDDING_MODEL_LIMITS.get(model_name, default)
+    # Fallback to known model registry with version handling (from PR #154)
+    # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
+    base_model_name = model_name.split(":")[0]
+
+    # Check exact match first
+    if model_name in EMBEDDING_MODEL_LIMITS:
+        return EMBEDDING_MODEL_LIMITS[model_name]
+
+    # Check base name match
+    if base_model_name in EMBEDDING_MODEL_LIMITS:
+        return EMBEDDING_MODEL_LIMITS[base_model_name]
+
+    # Check partial matches for common patterns
+    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
+        if known_model in base_model_name or base_model_name in known_model:
+            return limit
+
+    # Default fallback
+    logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
+    return default
 
 
 def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
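A hedged usage sketch of the restored lookup order: exact name, then the name with any ":tag" suffix stripped, then substring matching, then the default. The registry values below are placeholders chosen for illustration; only the resolution logic mirrors the diff:

# Placeholder registry; the real EMBEDDING_MODEL_LIMITS lives in the module.
EMBEDDING_MODEL_LIMITS = {"nomic-embed-text": 2048, "mxbai-embed-large": 512}


def resolve_limit(model_name: str, default: int = 512) -> int:
    base_model_name = model_name.split(":")[0]  # "nomic-embed-text:latest" -> "nomic-embed-text"
    if model_name in EMBEDDING_MODEL_LIMITS:  # exact match first
        return EMBEDDING_MODEL_LIMITS[model_name]
    if base_model_name in EMBEDDING_MODEL_LIMITS:  # versioned-name match
        return EMBEDDING_MODEL_LIMITS[base_model_name]
    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
        if known_model in base_model_name or base_model_name in known_model:
            return limit  # partial-match fallback
    return default  # conservative default


print(resolve_limit("nomic-embed-text:latest"))  # 2048 via the stripped base name
print(resolve_limit("some-unknown-model"))       # 512 via the default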
