Commit fa53229

Add to_markdown() to InstanceScores to pretty print output (#1846)

* Add wrapped to_markdown() method to InstanceScores to pretty print the results
* Updated examples
* Updated example to run with existing models
* Fixed max_tokens renaming in the WatsonX SDK chat engine
* Improved summary printout of InstanceScores
* Used the new summary feature
* Changed the LiteLLM test to use a large model for better and more consistent results
* Updated the InstanceScores summary to ignore columns that do not exist

Signed-off-by: Yoav Katz <[email protected]>

1 parent c5acd23 · commit fa53229
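
For context, the new pretty-printing API reads like this (a minimal sketch; results stands for the EvaluationResults object returned by evaluate() in the examples below, and the column names are illustrative):

    # Render instance-level results as a markdown table; long cells are
    # wrapped to max_col_width characters (30 by default).
    print(
        results.instance_scores.to_markdown(
            columns=["source", "prediction", "processed_prediction", "score"],
            max_col_width=30,
        )
    )

    # The summary property now renders the first few rows the same way.
    print(results.instance_scores.summary)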

File tree

6 files changed: +35 -65 lines changed

examples/evaluate_batched_multiclass_classification.py

Lines changed: 4 additions & 22 deletions

@@ -5,7 +5,6 @@
 from unitxt import get_logger
 from unitxt.api import evaluate, load_dataset
 from unitxt.artifact import fetch_artifact
-from unitxt.formats import SystemFormat
 from unitxt.operators import CollateInstances, Copy, FieldOperator, Rename
 from unitxt.processors import PostProcess
 from unitxt.serializers import MultiTypeSerializer, SingleTypeSerializer
@@ -82,13 +81,11 @@ def serialize(self, value: EnumeratedList, instance: Dict[str, Any]) -> str:

 for provider in [
     "watsonx",
-    "bam",
 ]:
     for model_name in [
-        "granite-3-8b-instruct",
-        "llama-3-8b-instruct",
+        "granite-3-3-8b-instruct",
     ]:
-        batch_sizes = [30, 20, 10, 5, 1]
+        batch_sizes = [100, 50, 10, 5, 1]

         for batch_size in batch_sizes:
             card, _ = fetch_artifact("cards.banking77")
@@ -104,21 +101,6 @@ def serialize(self, value: EnumeratedList, instance: Dict[str, Any]) -> str:
             card.task = task
             card.templates = [template]
             format = "formats.chat_api"
-            if provider == "bam" and model_name.startswith("llama"):
-                format = "formats.llama3_instruct"
-            if provider == "bam" and model_name.startswith("granite"):
-                format = SystemFormat(
-                    demo_format=(
-                        "{instruction}\\N{source}\\N<|end_of_text|>\n"
-                        "<|start_of_role|>assistant<|end_of_role|>{target}\\N<|end_of_text|>\n"
-                        "<|start_of_role|>user<|end_of_role|>"
-                    ),
-                    model_input_format=(
-                        "<|start_of_role|>system<|end_of_role|>{system_prompt}<|end_of_text|>\n"
-                        "<|start_of_role|>user<|end_of_role|>{demos}{instruction}\\N{source}\\N<|end_of_text|>\n"
-                        "<|start_of_role|>assistant<|end_of_role|>"
-                    ),
-                )

             dataset = load_dataset(
                 card=card,
@@ -138,7 +120,7 @@ def serialize(self, value: EnumeratedList, instance: Dict[str, Any]) -> str:
             )
             """
             We are using a CrossProviderInferenceEngine inference engine that supply api access to provider such as:
-            watsonx, bam, openai, azure, aws and more.
+            watsonx, openai, azure, aws and more.

             For the arguments these inference engines can receive, please refer to the classes documentation or read
             about the the open ai api arguments the CrossProviderInferenceEngine follows.
@@ -148,7 +130,7 @@ def serialize(self, value: EnumeratedList, instance: Dict[str, Any]) -> str:
             results = evaluate(predictions=predictions, data=test_dataset)

             print(
-                results.instance_scores.to_df(
+                results.instance_scores.to_markdown(
                     columns=[
                         "source",
                         "prediction",
examples/evaluate_same_datasets_and_models_with_multiple_providers.py

Lines changed: 3 additions & 31 deletions

@@ -1,7 +1,6 @@
 import pandas as pd
 from unitxt.api import evaluate, load_dataset
 from unitxt.artifact import fetch_artifact
-from unitxt.formats import SystemFormat

 df = pd.DataFrame(
     columns=[
@@ -22,29 +21,11 @@
 ]:
     for model_name in [
         "granite-3-8b-instruct",
-        "llama-3-8b-instruct",
+        "llama-3-3-70b-instruct",
     ]:
-        for format_as_chat_api in [True, False]:
-            if format_as_chat_api and provider == "watsonx-sdk":
-                continue
+        for format_as_chat_api in [True]:
             if format_as_chat_api:
                 format = "formats.chat_api"
-            else:
-                if model_name.startswith("llama"):
-                    format = "formats.llama3_instruct"
-                if model_name.startswith("granite"):
-                    format = SystemFormat(
-                        demo_format=(
-                            "{instruction}\\N{source}\\N<|end_of_text|>\n"
-                            "<|start_of_role|>assistant<|end_of_role|>{target}\\N<|end_of_text|>\n"
-                            "<|start_of_role|>user<|end_of_role|>"
-                        ),
-                        model_input_format=(
-                            "<|start_of_role|>system<|end_of_role|>{system_prompt}<|end_of_text|>\n"
-                            "<|start_of_role|>user<|end_of_role|>{demos}{instruction}\\N{source}\\N<|end_of_text|>\n"
-                            "<|start_of_role|>assistant<|end_of_role|>"
-                        ),
-                    )
             card, _ = fetch_artifact("cards.sst2")

             dataset = load_dataset(
@@ -71,16 +52,7 @@
             # result_df = pd.json_normalize(evaluated_dataset)
             # result_df.to_csv(f"output.csv")
             # Print results
-            print(
-                results.instance_scores.to_df(
-                    columns=[
-                        "source",
-                        "prediction",
-                        "processed_prediction",
-                        "processed_references",
-                    ],
-                )
-            )
+            print(results.instance_scores.summary)

             global_scores = results.global_scores
             df.loc[len(df)] = [

examples/ner_evaluation.py

Lines changed: 3 additions & 3 deletions

@@ -41,7 +41,7 @@
 # )
 # Change to this to infer with external APIs:

-model = CrossProviderInferenceEngine(model="llama-3-8b-instruct", provider="watsonx")
+model = CrossProviderInferenceEngine(model="llama-3-3-70b-instruct", provider="watsonx")
 # The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]

@@ -57,7 +57,7 @@

 print("Instance Results:")
 print(
-    results.instance_scores.to_df(
+    results.instance_scores.to_markdown(
         columns=[
             "text",
             "prediction",
@@ -66,5 +66,5 @@
             "score",
             "score_name",
         ]
-    ).to_markdown()
+    )
 )

src/unitxt/inference.py

Lines changed: 1 addition & 1 deletion

@@ -3677,7 +3677,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):

     _provider_param_renaming = {
         "bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
-        "watsonx-sdk": {"max_tokens": "max_new_tokens", "model": "model_name"},
+        "watsonx-sdk": {"model": "model_name"},
         "rits": {"model": "model_name"},
     }
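
For context, _provider_param_renaming maps the engine's generic, OpenAI-style argument names to each provider's native ones; with this fix, max_tokens is forwarded to watsonx-sdk unchanged instead of being renamed to max_new_tokens. A standalone sketch of how such a map is applied (rename_params is an illustrative helper, not the class's actual method):

    def rename_params(provider, params, renaming):
        # Rename only the keys the provider's map mentions; pass the rest through.
        mapping = renaming.get(provider, {})
        return {mapping.get(key, key): value for key, value in params.items()}

    renaming = {
        "bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
        "watsonx-sdk": {"model": "model_name"},  # max_tokens now passes through as-is
        "rits": {"model": "model_name"},
    }

    print(rename_params("watsonx-sdk", {"model": "granite-3-3-8b-instruct", "max_tokens": 256}, renaming))
    # {'model_name': 'granite-3-3-8b-instruct', 'max_tokens': 256}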

src/unitxt/metric_utils.py

Lines changed: 21 additions & 5 deletions

@@ -1,5 +1,6 @@
 import json
 import re
+import textwrap
 from collections import defaultdict
 from functools import lru_cache
 from statistics import mean
@@ -683,21 +684,36 @@ def to_df(self, flatten=True, columns=None):
             return df[columns]
         return df

+    def _to_markdown(self, df, max_col_width=30, **kwargs):
+        def wrap_column(series, max_width=30):
+            """Wraps string values in a Pandas Series to a maximum width."""
+            return series.apply(lambda x: textwrap.fill(str(x), width=max_width))
+
+        wrapped_df = df.copy()
+        for col in wrapped_df.columns:
+            wrapped_df[col] = wrap_column(wrapped_df[col], max_col_width)
+        return wrapped_df.to_markdown(**kwargs)
+
+    def to_markdown(self, flatten=True, columns=None, max_col_width=30, **kwargs):
+        return self._to_markdown(self.to_df(flatten, columns), max_col_width, **kwargs)
+
     @property
     def summary(self):
-        return to_pretty_string(
+        return self._to_markdown(
             self.to_df()
             .head()
             .drop(
                 columns=[
                     "metadata",
                     "media",
-                    "data_classification_policy",
                     "groups",
                     "subset",
-                ]
-            ),
-            float_format=".2g",
+                    "demos",
+                    "metrics",
+                    "postprocessors",
+                ],
+                errors="ignore",
+            )
         )

     def __repr__(self):
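
The new wrapping logic can be exercised in isolation (a minimal sketch; needs pandas plus the tabulate package that backs DataFrame.to_markdown, and the sample strings are illustrative):

    import textwrap

    import pandas as pd

    df = pd.DataFrame(
        {
            "source": ["Classify the sentiment of the following sentence: the movie was surprisingly good."],
            "prediction": ["positive"],
        }
    )

    # Same approach as _to_markdown: wrap every cell to max_col_width
    # characters before delegating to pandas' markdown renderer.
    max_col_width = 30
    wrapped = df.copy()
    for col in wrapped.columns:
        wrapped[col] = wrapped[col].apply(lambda x: textwrap.fill(str(x), width=max_col_width))
    print(wrapped.to_markdown())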

tests/inference/test_inference_engine.py

Lines changed: 3 additions & 3 deletions

@@ -369,7 +369,7 @@ def test_watsonx_inference_with_images(self):

     def test_lite_llm_inference_engine(self):
         model = LiteLLMInferenceEngine(
-            model="watsonx/meta-llama/llama-3-2-1b-instruct",
+            model="watsonx/meta-llama/llama-3-3-70b-instruct",
             max_tokens=2,
             temperature=0,
             top_p=1,
@@ -379,11 +379,11 @@ def test_lite_llm_inference_engine(self):
         dataset = get_text_dataset(format="formats.chat_api")
         predictions = model(dataset)

-        self.assertListEqual(predictions, ["100", "```\n"])
+        self.assertListEqual(predictions, ["7", "2"])

     def test_lite_llm_inference_engine_without_task_data_not_failing(self):
         LiteLLMInferenceEngine(
-            model="watsonx/meta-llama/llama-3-2-1b-instruct",
+            model="watsonx/meta-llama/llama-3-3-70b-instruct",
             max_tokens=2,
             temperature=0,
             top_p=1,