Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ jina grep serve stop # stop when done

## Local mode

`jina embed`, `jina rerank`, and `jina dedup` support `--local` to run on Apple Silicon via the jina-grep embedding server instead of the Jina API. No API key needed.
`jina embed`, `jina rerank`, `jina classify`, and `jina dedup` support `--local` to run on Apple Silicon via the jina-grep embedding server instead of the Jina API. No API key needed.

```bash
# Start the local server first
Expand All @@ -217,6 +217,9 @@ cat texts.txt | jina embed --local --json
# Local reranking (cosine similarity on local embeddings)
cat docs.txt | jina rerank --local "machine learning"

# Local classification (cosine similarity on local embeddings)
jina classify --local "this is great" --labels positive,negative

# Local deduplication
cat items.txt | jina dedup --local
```
Expand Down
32 changes: 32 additions & 0 deletions jina_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,38 @@ def local_embed(
return data.get("data", [])


def local_classify(
    texts: list[str],
    labels: list[str],
    model: str = "jina-embeddings-v5-nano",
    task: str = "text-matching",
) -> list[dict]:
    """Classify texts into labels using local embeddings and cosine similarity.

    Texts and labels are embedded together in one ``local_embed`` call. Each
    text is then assigned the label whose embedding has the highest cosine
    similarity to the text's embedding; ties go to the label listed first.

    Args:
        texts: Texts to classify.
        labels: Candidate labels. Must be non-empty when ``texts`` is non-empty.
        model: Embedding model name forwarded to ``local_embed``.
        task: Embedding task forwarded to ``local_embed``.

    Returns:
        One dict per text, in input order, with the keys ``index`` (position
        of the text in ``texts``), ``prediction`` (the chosen label) and
        ``score`` (the cosine similarity of the chosen label).

    Raises:
        ValueError: If ``labels`` is empty, or if the number of embeddings
            returned by ``local_embed`` does not match the number of inputs.
    """
    # Nothing to classify: skip the embedding request entirely.
    if not texts:
        return []
    if not labels:
        raise ValueError("At least one label is required for classification")

    # Embed texts and labels in one batch so a single request covers both.
    all_texts = [*texts, *labels]
    embeddings_data = local_embed(all_texts, model=model, task=task)
    embeddings = [item["embedding"] for item in embeddings_data]

    # A short or long response would silently misalign the slices below.
    if len(embeddings) != len(all_texts):
        raise ValueError(
            f"Expected {len(all_texts)} embeddings, got {len(embeddings)}"
        )

    text_embs = embeddings[:len(texts)]
    label_embs = embeddings[len(texts):]

    results = []
    for i, text_emb in enumerate(text_embs):
        scores = [
            _cosine_similarity(text_emb, label_emb) for label_emb in label_embs
        ]
        # max() returns the first maximal index, so earlier labels win ties
        # and the reported score is always one that was actually computed.
        best = max(range(len(scores)), key=scores.__getitem__)
        results.append({
            "index": i,
            "prediction": labels[best],
            "score": scores[best],
        })

    return results


def local_rerank(
query: str,
documents: list[str],
Expand Down
15 changes: 10 additions & 5 deletions jina_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,12 @@ def dedup(ctx, k, local, as_json, api_key):
@click.argument("text", nargs=-1)
@click.option("--labels", required=True, multiple=True,
help="Labels for classification (comma-separated or repeated --labels)")
@click.option("--model", default=None, help="Model name (default: jina-embeddings-v5-text-small)")
@click.option("--model", default=None, help="Model name (default: jina-embeddings-v5-text-small, or v5-nano with --local)")
@click.option("--local", is_flag=True, help="Use local MLX server (requires: jina-grep serve start)")
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
@click.option("--api-key", default=None, help="Jina API key")
@click.pass_context
def classify(ctx, text, labels, model, as_json, api_key):
def classify(ctx, text, labels, model, local, as_json, api_key):
"""Classify text into labels.

Input from arguments or stdin (one text per line).
Expand All @@ -424,6 +425,7 @@ def classify(ctx, text, labels, model, as_json, api_key):
jina classify "this is great" --labels positive,negative
echo "stock price rose" | jina classify --labels business,sports,tech
jina classify "text1" "text2" --labels cat1 --labels cat2 --labels cat3
jina classify --local "this is great" --labels positive,negative
"""
key = api_key or ctx.obj.get("api_key")

Expand All @@ -450,10 +452,13 @@ def classify(ctx, text, labels, model, as_json, api_key):
"Fix: --labels positive,negative", err=True)
sys.exit(EXIT_USER_ERROR)

_model = model or "jina-embeddings-v5-text-small"

try:
result = api.classify(texts, parsed_labels, api_key=key, model=_model)
if local:
_model = model or "jina-embeddings-v5-nano"
result = api.local_classify(texts, parsed_labels, model=_model)
else:
_model = model or "jina-embeddings-v5-text-small"
result = api.classify(texts, parsed_labels, api_key=key, model=_model)
click.echo(utils.format_classify_results(result, as_json=as_json))
except Exception as e:
utils.handle_http_error(e)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Unit tests for local mode functions (no API key or server needed)."""

from unittest.mock import patch

from jina_cli.api import local_classify, _cosine_similarity


class TestCosineSimlarity:
    """Sanity checks for the ``_cosine_similarity`` helper."""

    def test_identical_vectors(self):
        """A vector compared with itself scores (approximately) 1."""
        vector = [1, 0, 0]
        similarity = _cosine_similarity(vector, [1, 0, 0])
        assert abs(similarity - 1.0) < 1e-6

    def test_orthogonal_vectors(self):
        """Perpendicular vectors score (approximately) 0."""
        similarity = _cosine_similarity([1, 0], [0, 1])
        assert abs(similarity) < 1e-6

    def test_zero_vector(self):
        """An all-zero vector yields exactly 0.0."""
        similarity = _cosine_similarity([0, 0], [1, 1])
        assert similarity == 0.0


class TestLocalClassify:
    """Exercise ``local_classify`` with ``local_embed`` patched to canned vectors."""

    def test_single_text(self):
        """One text is assigned the label whose embedding is nearest."""
        canned = [
            {"embedding": [0.9, 0.1, 0.0]},  # text: "I love this"
            {"embedding": [0.8, 0.2, 0.0]},  # label "positive" (close)
            {"embedding": [0.0, 0.1, 0.9]},  # label "negative" (far)
        ]

        with patch("jina_cli.api.local_embed", return_value=canned):
            outcome = local_classify(
                texts=["I love this"],
                labels=["positive", "negative"],
            )

        assert len(outcome) == 1
        first = outcome[0]
        assert first["prediction"] == "positive"
        assert first["score"] > 0.5
        assert first["index"] == 0

    def test_multiple_texts(self):
        """Each text in a batch gets its own nearest label."""
        canned = [
            {"embedding": [0.9, 0.1]},  # text 1, nearer to "sports"
            {"embedding": [0.1, 0.9]},  # text 2, nearer to "politics"
            {"embedding": [0.8, 0.2]},  # label "sports"
            {"embedding": [0.2, 0.8]},  # label "politics"
        ]

        with patch("jina_cli.api.local_embed", return_value=canned):
            outcome = local_classify(
                texts=["goal scored", "election results"],
                labels=["sports", "politics"],
            )

        assert len(outcome) == 2
        assert outcome[0]["prediction"] == "sports"
        assert outcome[1]["prediction"] == "politics"

    def test_result_format(self):
        """Results should have index, prediction, score keys."""
        canned = [
            {"embedding": [1.0, 0.0]},  # the text
            {"embedding": [0.9, 0.1]},  # the only label
        ]

        with patch("jina_cli.api.local_embed", return_value=canned):
            outcome = local_classify(
                texts=["test"],
                labels=["label1"],
            )

        entry = outcome[0]
        for key in ("index", "prediction", "score"):
            assert key in entry
        assert entry["prediction"] == "label1"
Loading