Skip to content

Commit 9340183

Browse files
authored
feat: llama stack run --providers (llamastack#3989)
# What does this PR do? `llama stack run --providers` takes a list of providers in the format `api1=provider1,api2=provider2`. This allows users to run a stack from a simple list of providers. Given the architecture of `create_app`, the generated run config needs to be written to disk; it is written to ~/.llama/distribution/providers-run/run.yaml each time for consistency. Resolves llamastack#3956 ## Test Plan New unit tests to ensure the run config generated for `--providers` is valid. Signed-off-by: Charlie Doern <[email protected]>
1 parent b2a5428 commit 9340183

File tree

2 files changed

+143
-1
lines changed

2 files changed

+143
-1
lines changed

src/llama_stack/cli/stack/run.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,28 @@
88
import os
99
import ssl
1010
import subprocess
11+
import sys
1112
from pathlib import Path
1213

1314
import uvicorn
1415
import yaml
16+
from termcolor import cprint
1517

1618
from llama_stack.cli.stack.utils import ImageType
1719
from llama_stack.cli.subcommand import Subcommand
18-
from llama_stack.core.datatypes import StackRunConfig
20+
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
21+
from llama_stack.core.distribution import get_provider_registry
1922
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
23+
from llama_stack.core.storage.datatypes import (
24+
InferenceStoreReference,
25+
KVStoreReference,
26+
ServerStoresConfig,
27+
SqliteKVStoreConfig,
28+
SqliteSqlStoreConfig,
29+
SqlStoreReference,
30+
StorageConfig,
31+
)
32+
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
2033
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
2134
from llama_stack.log import LoggingConfig, get_logger
2235

@@ -68,6 +81,12 @@ def _add_arguments(self):
6881
action="store_true",
6982
help="Start the UI server",
7083
)
84+
self.parser.add_argument(
85+
"--providers",
86+
type=str,
87+
default=None,
88+
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
89+
)
7190

7291
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
7392
import yaml
@@ -93,6 +112,49 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
93112
config_file = resolve_config_or_distro(args.config, Mode.RUN)
94113
except ValueError as e:
95114
self.parser.error(str(e))
115+
elif args.providers:
116+
provider_list: dict[str, list[Provider]] = dict()
117+
for api_provider in args.providers.split(","):
118+
if "=" not in api_provider:
119+
cprint(
120+
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
121+
color="red",
122+
file=sys.stderr,
123+
)
124+
sys.exit(1)
125+
api, provider_type = api_provider.split("=")
126+
providers_for_api = get_provider_registry().get(Api(api), None)
127+
if providers_for_api is None:
128+
cprint(
129+
f"{api} is not a valid API.",
130+
color="red",
131+
file=sys.stderr,
132+
)
133+
sys.exit(1)
134+
if provider_type in providers_for_api:
135+
provider = Provider(
136+
provider_type=provider_type,
137+
provider_id=provider_type.split("::")[1],
138+
)
139+
provider_list.setdefault(api, []).append(provider)
140+
else:
141+
cprint(
142+
f"{provider} is not a valid provider for the {api} API.",
143+
color="red",
144+
file=sys.stderr,
145+
)
146+
sys.exit(1)
147+
run_config = self._generate_run_config_from_providers(providers=provider_list)
148+
config_dict = run_config.model_dump(mode="json")
149+
150+
# Write config to disk in providers-run directory
151+
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
152+
config_file = distro_dir / "run.yaml"
153+
154+
logger.info(f"Writing generated config to: {config_file}")
155+
with open(config_file, "w") as f:
156+
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
157+
96158
else:
97159
config_file = None
98160

@@ -214,3 +276,44 @@ def _start_ui_development_server(self, stack_server_port: int):
214276
)
215277
except Exception as e:
216278
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
279+
280+
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
    """Assemble a StackRunConfig from a bare mapping of API name -> providers.

    The enabled APIs are exactly the keys of ``providers``. Storage is wired
    to two SQLite backends (one key-value, one SQL) whose default location is
    the ``providers-run`` directory under ``DISTRIBS_BASE_DIR``; that
    directory is created here if it does not already exist.

    Args:
        providers: Providers to enable, grouped by the API they serve.

    Returns:
        A StackRunConfig with ``image_name`` set to ``"providers-run"``.
    """
    run_dir = DISTRIBS_BASE_DIR / "providers-run"
    # The storage backends below default to this directory, so make sure it exists.
    os.makedirs(run_dir, exist_ok=True)

    # NOTE(review): the ``${env.SQLITE_STORE_DIR:=...}`` placeholder is emitted
    # verbatim here; presumably it is expanded when the config is loaded.
    sqlite_backends = {
        "kv_default": SqliteKVStoreConfig(
            db_path=f"${{env.SQLITE_STORE_DIR:={run_dir}}}/kvstore.db",
        ),
        "sql_default": SqliteSqlStoreConfig(
            db_path=f"${{env.SQLITE_STORE_DIR:={run_dir}}}/sql_store.db",
        ),
    }

    # Each server-side store points at one of the two backends declared above.
    server_stores = ServerStoresConfig(
        metadata=KVStoreReference(backend="kv_default", namespace="registry"),
        inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
        conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
        prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
    )

    return StackRunConfig(
        image_name="providers-run",
        apis=list(providers),
        providers=providers,
        storage=StorageConfig(backends=sqlite_backends, stores=server_stores),
    )

tests/unit/cli/test_stack_config.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,42 @@ def test_parse_and_maybe_upgrade_config_preserves_custom_external_providers_dir(
229229

230230
# Verify the custom value was preserved
231231
assert str(result.external_providers_dir) == custom_dir
232+
233+
234+
def test_generate_run_config_from_providers():
    """A config generated from a provider mapping dumps to JSON and re-parses."""
    import argparse

    from llama_stack.cli.stack.run import StackRun
    from llama_stack.core.datatypes import Provider

    # StackRun is constructed from a subparsers action, so build a throwaway parser.
    root_parser = argparse.ArgumentParser()
    run_cmd = StackRun(root_parser.add_subparsers())

    inference_provider = Provider(
        provider_type="inline::meta-reference",
        provider_id="meta-reference",
    )
    generated = run_cmd._generate_run_config_from_providers(
        providers={"inference": [inference_provider]},
    )
    dumped = generated.model_dump(mode="json")

    # Basic structure: fixed image name, and the API shows up in both sections.
    assert dumped["image_name"] == "providers-run"
    for section in ("apis", "providers"):
        assert "inference" in dumped[section]

    # Storage must carry every required store, including prompts.
    assert "storage" in dumped
    prompts_store = dumped["storage"]["stores"]
    assert "prompts" in prompts_store
    assert prompts_store["prompts"]["namespace"] == "prompts"

    # The dumped dict must survive a round trip through the config parser.
    reparsed = parse_and_maybe_upgrade_config(dumped)
    assert reparsed.image_name == "providers-run"

0 commit comments

Comments
 (0)