|
12 | 12 | ModelInput, |
13 | 13 | Provider, |
14 | 14 | ProviderSpec, |
15 | | - ShieldInput, |
16 | 15 | ToolGroupInput, |
17 | 16 | ) |
18 | 17 | from llama_stack.distribution.utils.dynamic import instantiate_class_type |
|
32 | 31 | from llama_stack.providers.remote.inference.anthropic.models import ( |
33 | 32 | MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES, |
34 | 33 | ) |
35 | | -from llama_stack.providers.remote.inference.anthropic.models import ( |
36 | | - SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES, |
37 | | -) |
38 | 34 | from llama_stack.providers.remote.inference.bedrock.models import ( |
39 | 35 | MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES, |
40 | 36 | ) |
41 | | -from llama_stack.providers.remote.inference.bedrock.models import ( |
42 | | - SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES, |
43 | | -) |
44 | 37 | from llama_stack.providers.remote.inference.cerebras.models import ( |
45 | 38 | MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES, |
46 | 39 | ) |
47 | | -from llama_stack.providers.remote.inference.cerebras.models import ( |
48 | | - SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES, |
49 | | -) |
50 | 40 | from llama_stack.providers.remote.inference.databricks.databricks import ( |
51 | 41 | MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES, |
52 | 42 | ) |
53 | | -from llama_stack.providers.remote.inference.databricks.databricks import ( |
54 | | - SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES, |
55 | | -) |
56 | 43 | from llama_stack.providers.remote.inference.fireworks.models import ( |
57 | 44 | MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES, |
58 | 45 | ) |
59 | | -from llama_stack.providers.remote.inference.fireworks.models import ( |
60 | | - SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES, |
61 | | -) |
62 | 46 | from llama_stack.providers.remote.inference.gemini.models import ( |
63 | 47 | MODEL_ENTRIES as GEMINI_MODEL_ENTRIES, |
64 | 48 | ) |
65 | | -from llama_stack.providers.remote.inference.gemini.models import ( |
66 | | - SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES, |
67 | | -) |
68 | 49 | from llama_stack.providers.remote.inference.groq.models import ( |
69 | 50 | MODEL_ENTRIES as GROQ_MODEL_ENTRIES, |
70 | 51 | ) |
71 | | -from llama_stack.providers.remote.inference.groq.models import ( |
72 | | - SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES, |
73 | | -) |
74 | 52 | from llama_stack.providers.remote.inference.nvidia.models import ( |
75 | 53 | MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES, |
76 | 54 | ) |
77 | | -from llama_stack.providers.remote.inference.nvidia.models import ( |
78 | | - SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES, |
79 | | -) |
80 | 55 | from llama_stack.providers.remote.inference.openai.models import ( |
81 | 56 | MODEL_ENTRIES as OPENAI_MODEL_ENTRIES, |
82 | 57 | ) |
83 | | -from llama_stack.providers.remote.inference.openai.models import ( |
84 | | - SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES, |
85 | | -) |
86 | 58 | from llama_stack.providers.remote.inference.runpod.runpod import ( |
87 | 59 | MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES, |
88 | 60 | ) |
89 | | -from llama_stack.providers.remote.inference.runpod.runpod import ( |
90 | | - SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES, |
91 | | -) |
92 | 61 | from llama_stack.providers.remote.inference.sambanova.models import ( |
93 | 62 | MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES, |
94 | 63 | ) |
95 | | -from llama_stack.providers.remote.inference.sambanova.models import ( |
96 | | - SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES, |
97 | | -) |
98 | 64 | from llama_stack.providers.remote.inference.together.models import ( |
99 | 65 | MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES, |
100 | 66 | ) |
101 | | -from llama_stack.providers.remote.inference.together.models import ( |
102 | | - SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES, |
103 | | -) |
104 | 67 | from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig |
105 | 68 | from llama_stack.providers.remote.vector_io.pgvector.config import ( |
106 | 69 | PGVectorVectorIOConfig, |
|
111 | 74 | DistributionTemplate, |
112 | 75 | RunConfigSettings, |
113 | 76 | get_model_registry, |
| 77 | + get_shield_registry, |
114 | 78 | ) |
115 | 79 |
|
116 | 80 |
|
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt |
164 | 128 | def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]: |
165 | 129 | """Get model entries for a specific provider type.""" |
166 | 130 | safety_model_entries_map = { |
167 | | - "openai": OPENAI_SAFETY_MODELS_ENTRIES, |
168 | | - "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES, |
169 | | - "together": TOGETHER_SAFETY_MODELS_ENTRIES, |
170 | | - "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES, |
171 | | - "gemini": GEMINI_SAFETY_MODELS_ENTRIES, |
172 | | - "groq": GROQ_SAFETY_MODELS_ENTRIES, |
173 | | - "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES, |
174 | | - "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES, |
175 | | - "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES, |
176 | | - "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES, |
177 | | - "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES, |
178 | | - "runpod": RUNPOD_SAFETY_MODELS_ENTRIES, |
179 | | - } |
180 | | - |
181 | | - # Special handling for providers with dynamic model entries |
182 | | - if provider_type == "ollama": |
183 | | - return [ |
| 131 | + "ollama": [ |
184 | 132 | ProviderModelEntry( |
185 | | - provider_model_id="llama-guard3:1b", |
| 133 | + provider_model_id="${env.SAFETY_MODEL:=__disabled__}", |
186 | 134 | model_type=ModelType.llm, |
187 | 135 | ), |
188 | | - ] |
| 136 | + ], |
| 137 | + } |
189 | 138 |
|
190 | 139 | return safety_model_entries_map.get(provider_type, []) |
191 | 140 |
|
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro |
246 | 195 |
|
247 | 196 |
|
248 | 197 | # build a list of shields for all possible providers |
249 | | -def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]: |
250 | | - shields = [] |
| 198 | +def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]: |
| 199 | + available_models = {} |
251 | 200 | for provider in providers: |
252 | 201 | provider_type = provider.provider_type.split("::")[1] |
253 | 202 | safety_model_entries = _get_model_safety_entries_for_provider(provider_type) |
254 | 203 | if len(safety_model_entries) == 0: |
255 | 204 | continue |
256 | | - if provider.provider_id: |
257 | | - shield_id = provider.provider_id |
258 | | - else: |
259 | | - raise ValueError(f"Provider {provider.provider_type} has no provider_id") |
260 | | - for safety_model_entry in safety_model_entries: |
261 | | - print(f"provider.provider_id: {provider.provider_id}") |
262 | | - print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}") |
263 | | - shields.append( |
264 | | - ShieldInput( |
265 | | - provider_id="llama-guard", |
266 | | - shield_id=shield_id, |
267 | | - provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}", |
268 | | - ) |
269 | | - ) |
270 | | - return shields |
| 205 | + |
| 206 | + env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}" |
| 207 | + provider_id = f"${{env.{env_var}:=__disabled__}}" |
| 208 | + |
| 209 | + available_models[provider_id] = safety_model_entries |
| 210 | + |
| 211 | + return available_models |
271 | 212 |
|
272 | 213 |
|
273 | 214 | def get_distribution_template() -> DistributionTemplate: |
@@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate: |
307 | 248 | ), |
308 | 249 | ] |
309 | 250 |
|
310 | | - shields = get_shields_for_providers(remote_inference_providers) |
311 | | - |
312 | 251 | providers = { |
313 | 252 | "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]), |
314 | 253 | "vector_io": ([p.provider_type for p in vector_io_providers]), |
@@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate: |
361 | 300 | }, |
362 | 301 | ) |
363 | 302 |
|
364 | | - default_models = get_model_registry(available_models) |
| 303 | + default_models, ids_conflict_in_models = get_model_registry(available_models) |
| 304 | + |
| 305 | + available_safety_models = get_safety_models_for_providers(remote_inference_providers) |
| 306 | + shields = get_shield_registry(available_safety_models, ids_conflict_in_models) |
365 | 307 |
|
366 | 308 | return DistributionTemplate( |
367 | 309 | name=name, |
|
0 commit comments