
Commit 6b8a8c1

fix: Safety in starter (llamastack#2731)
- fireworks and together no longer support the Llama Guard 3 8B model
- we need to default to ollama for safety
- the current safety-shield logic was incorrect, since the shield_id was the provider id (which had duplicates)
- shields now follow the same registry logic as models

Note: this may seem a bit over-engineered, but it can now be extended to other providers and fits the overall mechanism of how env vars are used to manage starter.

### How to test

```
ENABLE_OLLAMA=ollama ENABLE_FIREWORKS=fireworks SAFETY_MODEL=llama-guard3:1b \
  pytest -s -v tests/integration/ --stack-config starter \
  -k 'not(supervised_fine_tune or builtin_tool_code or safety_with_image or code_interpreter_for or rag_and_code or truncation or register_and_unregister)' \
  --text-model fireworks/meta-llama/Llama-3.3-70B-Instruct \
  --vision-model fireworks/meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --safety-shield llama-guard3:1b \
  --embedding-model all-MiniLM-L6-v2
```

### Related but not obvious in this PR

In the llama-stack-ops repo, we run tests before publishing packages and docker containers. The actions in that repo were using the fireworks / together distros (which no longer exist), so they need to be updated to run with `starter` and use `ollama` specifically for safety.
1 parent 6ad22c2 commit 6b8a8c1
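As context for the diffs below: starter wires everything through `${env.VAR:=default}` placeholders. The following is a minimal, hypothetical sketch of that resolution rule (the helper `resolve_env_refs` and its regex are illustrative, not llama-stack's actual resolver), showing how the new shield entry behaves with and without the env vars set.

```python
import os
import re

# Hypothetical resolver for the `${env.NAME:=default}` placeholders seen in
# run.yaml; the real implementation lives inside llama-stack and may differ.
_ENV_REF = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+):=(?P<default>[^}]*)\}")

def resolve_env_refs(value: str) -> str:
    """Substitute each placeholder with os.environ[NAME], else its default."""
    return _ENV_REF.sub(lambda m: os.environ.get(m["name"], m["default"]), value)

# With nothing set, the new shield stays disabled:
print(resolve_env_refs("${env.SAFETY_MODEL:=__disabled__}"))  # __disabled__

# With ollama enabled and a safety model chosen, it resolves to a usable id:
os.environ.update({"ENABLE_OLLAMA": "ollama", "SAFETY_MODEL": "llama-guard3:1b"})
print(resolve_env_refs(
    "${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}"
))  # ollama/llama-guard3:1b
```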

File tree

9 files changed (+104 / -195 lines)

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -89,7 +89,7 @@ jobs:
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="ollama/llama3.2:3b-instruct-fp16" \
             --embedding-model=all-MiniLM-L6-v2 \
-            --safety-shield=ollama \
+            --safety-shield=$SAFETY_MODEL \
             --color=yes \
             --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
```

llama_stack/providers/remote/inference/ollama/models.py

Lines changed: 14 additions & 11 deletions

```diff
@@ -12,6 +12,19 @@
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES
```

llama_stack/templates/nvidia/nvidia.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
```

llama_stack/templates/open-benchmark/open_benchmark.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",
```

llama_stack/templates/starter/run.yaml

Lines changed: 2 additions & 18 deletions

```diff
@@ -1171,24 +1171,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
```
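The duplicate-id problem called out in the commit message is visible in the removed block above: fireworks and together each register two shields under the same shield_id. A toy check (illustrative only, not project code):

```python
from collections import Counter

# shield_id values from the removed shields block, copied verbatim
old_shield_ids = [
    "${env.ENABLE_OLLAMA:=__disabled__}",
    "${env.ENABLE_FIREWORKS:=__disabled__}",
    "${env.ENABLE_FIREWORKS:=__disabled__}",
    "${env.ENABLE_TOGETHER:=__disabled__}",
    "${env.ENABLE_TOGETHER:=__disabled__}",
    "${env.ENABLE_SAMBANOVA:=__disabled__}",
]
duplicates = [sid for sid, n in Counter(old_shield_ids).items() if n > 1]
print(duplicates)
# ['${env.ENABLE_FIREWORKS:=__disabled__}', '${env.ENABLE_TOGETHER:=__disabled__}']
```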

llama_stack/templates/starter/starter.py

Lines changed: 18 additions & 76 deletions

```diff
@@ -12,7 +12,6 @@
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
 
 
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
 
     return safety_model_entries_map.get(provider_type, [])
 
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
 
 
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
 
 
 def get_distribution_template() -> DistributionTemplate:
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +300,10 @@
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
 
     return DistributionTemplate(
         name=name,
```
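To make the new mapping concrete, here is a standalone re-statement of the `get_safety_models_for_providers` logic above with a stubbed provider. The dataclass and the simplified return value are illustrative stand-ins; the real function returns `ProviderModelEntry` lists.

```python
from dataclasses import dataclass

@dataclass
class FakeProvider:  # stand-in for llama_stack's Provider
    provider_type: str

def safety_models_for(providers: list[FakeProvider]) -> dict[str, list[str]]:
    # Mirrors the diff above: key each provider's safety entries by an
    # env-var-controlled provider id, so a provider that is not enabled
    # resolves to "__disabled__". (The real function also skips providers
    # that have no safety model entries at all.)
    available: dict[str, list[str]] = {}
    for provider in providers:
        provider_type = provider.provider_type.split("::")[1]
        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
        provider_id = f"${{env.{env_var}:=__disabled__}}"
        available[provider_id] = [f"<safety entries for {provider_type}>"]
    return available

print(safety_models_for([FakeProvider("remote::ollama")]))
# {'${env.ENABLE_OLLAMA:=__disabled__}': ['<safety entries for ollama>']}
```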

llama_stack/templates/template.py

Lines changed: 45 additions & 2 deletions

```diff
@@ -37,7 +37,7 @@
 
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields
 
 
 class DefaultModel(BaseModel):
```
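The interesting rule in the new `get_shield_registry` is the prefixing: once any two providers expose the same shield id, every shield id gets a provider-id prefix (unless the provider id already appears in the model id). A toy run of that rule, with illustrative stand-in names:

```python
from dataclasses import dataclass, field

@dataclass
class Entry:  # stand-in for ProviderModelEntry
    provider_model_id: str
    aliases: list[str] = field(default_factory=list)

def shield_ids(available: dict[str, list[Entry]], ids_conflict: bool) -> list[str]:
    # Same identifier rule as the diff above: prefix with the provider id
    # when ids conflict, unless the provider id is already part of the id.
    out = []
    for provider_id, entries in available.items():
        for entry in entries:
            for model_id in [entry.provider_model_id] + entry.aliases:
                out.append(
                    f"{provider_id}/{model_id}"
                    if ids_conflict and provider_id not in model_id
                    else model_id
                )
    return out

# Two providers registering the same guard model triggers the prefix:
available = {
    "ollama": [Entry("llama-guard3:8b")],
    "fireworks": [Entry("llama-guard3:8b")],
}
print(shield_ids(available, ids_conflict=True))
# ['ollama/llama-guard3:8b', 'fireworks/llama-guard3:8b']
```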

llama_stack/templates/watsonx/watsonx.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",
```
