Skip to content

Commit

Permalink
backend: Update make first-run to read existing configuration variabl…
Browse files Browse the repository at this point in the history
…es (#866)

* backend: Read existing .env file when running make first-run

* Read values from existing yaml config file if it exists

* backend: Prompt user to overwrite config file in make first-run if it is invalid

* backend: Ensure yaml secrets are read in make first-run

* backend: Refactor reading of existing yaml config in make first-run
  • Loading branch information
scottmx81 authored Dec 5, 2024
1 parent 827e3ec commit b638c9e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 43 deletions.
36 changes: 28 additions & 8 deletions src/backend/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import argparse

from backend.cli.constants import COMMUNITY_TOOLS, TOOLS
from dotenv import dotenv_values

from backend.cli.constants import (
COMMUNITY_TOOLS,
CONFIG_FILE_PATH,
SECRETS_FILE_PATH,
TOOLS,
)
from backend.cli.prompts import (
PROMPTS,
community_tools_prompt,
deployment_prompt,
overwrite_config_prompt,
overwrite_secrets_prompt,
review_variables_prompt,
select_deployments_prompt,
tool_prompt,
Expand All @@ -16,12 +25,11 @@
write_env_file,
write_template_config_files,
)
from backend.cli.utils import show_examples, show_welcome_message, wrap_up
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as MANAGED_DEPLOYMENTS_SETUP,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
from backend.cli.utils import (
process_existing_yaml_config,
show_examples,
show_welcome_message,
wrap_up,
)


Expand All @@ -33,7 +41,9 @@ def start():

show_welcome_message()

secrets = {}
secrets = dotenv_values()
process_existing_yaml_config(secrets, CONFIG_FILE_PATH, overwrite_config_prompt)
process_existing_yaml_config(secrets, SECRETS_FILE_PATH, overwrite_secrets_prompt)

# SET UP ENVIRONMENT
for _, prompt in PROMPTS.items():
Expand All @@ -49,6 +59,16 @@ def start():
tool_prompt(secrets, name, configs)

# SET UP ENVIRONMENT FOR DEPLOYMENTS

# These imports run code that uses settings. Local imports are used so that we can
# validate any existing config file above before using settings.
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as MANAGED_DEPLOYMENTS_SETUP,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)

all_deployments = MANAGED_DEPLOYMENTS_SETUP.copy()
if use_community_features:
all_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
Expand Down
106 changes: 72 additions & 34 deletions src/backend/cli/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,51 @@
from backend.cli.utils import print_styled


def core_env_var_prompt(secrets):
print_styled("💾 Let's set up your database URL.")
database_url = inquirer.text(
"Enter your database URL or press enter for default [recommended]",
default=DATABASE_URL_DEFAULT,
def overwrite_config_prompt():
return inquirer.confirm(
"Your existing configuration file is invalid. Overwrite?"
)

print_styled("💾 Now, let's set up need to set up your Redis URL.")
redis_url = inquirer.text(
"Enter your Redis URL or press enter for default [recommended]",
default=REDIS_URL_DEFAULT,
)

print_styled("💾 Now, let's set up your public backend API hostname.")
next_public_api_hostname = inquirer.text(
"Enter your public API Hostname or press enter for default [recommended]",
default=NEXT_PUBLIC_API_HOSTNAME_DEFAULT,
def overwrite_secrets_prompt():
return inquirer.confirm(
"Your existing secrets file is invalid. Overwrite?"
)

print_styled("💾 Finally, the frontend client hostname.")
frontend_hostname = inquirer.text(
"Enter your frontend hostname or press enter for default [recommended]",
default=FRONTEND_HOSTNAME_DEFAULT,
)

def core_env_var_prompt(secrets):
database_url = secrets.get("DATABASE_URL")
redis_url = secrets.get("REDIS_URL")
next_public_api_hostname = secrets.get("NEXT_PUBLIC_API_HOSTNAME")
frontend_hostname = secrets.get("FRONTEND_HOSTNAME")

if not database_url:
print_styled("💾 Let's set up your database URL.")
database_url = inquirer.text(
"Enter your database URL or press enter for default [recommended]",
default=DATABASE_URL_DEFAULT,
)

if not redis_url:
print_styled("💾 Now, let's set up need to set up your Redis URL.")
redis_url = inquirer.text(
"Enter your Redis URL or press enter for default [recommended]",
default=REDIS_URL_DEFAULT,
)

if not next_public_api_hostname:
print_styled("💾 Now, let's set up your public backend API hostname.")
next_public_api_hostname = inquirer.text(
"Enter your public API Hostname or press enter for default [recommended]",
default=NEXT_PUBLIC_API_HOSTNAME_DEFAULT,
)

if not frontend_hostname:
print_styled("💾 Finally, the frontend client hostname.")
frontend_hostname = inquirer.text(
"Enter your frontend hostname or press enter for default [recommended]",
default=FRONTEND_HOSTNAME_DEFAULT,
)

secrets["DATABASE_URL"] = database_url
secrets["REDIS_URL"] = redis_url
Expand All @@ -45,20 +66,29 @@ def core_env_var_prompt(secrets):

def deployment_prompt(secrets, configs):
for secret in configs.env_vars:
value = inquirer.text(
f"Enter the value for {secret}", validate=lambda _, x: len(x) > 0
)
value = secrets.get(secret)

if not value:
value = inquirer.text(
f"Enter the value for {secret}", validate=lambda _, x: len(x) > 0
)

secrets[secret] = value


def community_tools_prompt(secrets):
print_styled(
"🏘️ We have some community tools that you can set up. These tools are not required for the Cohere Toolkit to run."
)
use_community_features = inquirer.confirm(
"Do you want to set up community features (tools and model deployments)?"
)
use_community_features = secrets.get("USE_COMMUNITY_FEATURES")

if not use_community_features:
print_styled(
"🏘️ We have some community tools that you can set up. These tools are not required for the Cohere Toolkit to run."
)
use_community_features = inquirer.confirm(
"Do you want to set up community features (tools and model deployments)?"
)

secrets["USE_COMMUNITY_FEATURES"] = use_community_features

return use_community_features


Expand All @@ -68,16 +98,24 @@ def tool_prompt(secrets, name, configs):
)

for key, default_value in configs["secrets"].items():
value = inquirer.text(f"Enter the value for {key}", default=default_value)
value = secrets.get(key)

if not value:
value = inquirer.text(f"Enter the value for {key}", default=default_value)

secrets[key] = value


def build_target_prompt(secrets):
build_target = inquirer.list_input(
"Select the build target",
choices=[BuildTarget.DEV, BuildTarget.PROD],
default=BuildTarget.DEV,
)
build_target = secrets.get("BUILD_TARGET")

if not build_target:
build_target = inquirer.list_input(
"Select the build target",
choices=[BuildTarget.DEV, BuildTarget.PROD],
default=BuildTarget.DEV,
)

secrets["BUILD_TARGET"] = build_target


Expand Down
50 changes: 49 additions & 1 deletion src/backend/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from backend.cli.constants import WELCOME_MESSAGE, DeploymentName, bcolors
import sys
from pathlib import Path

import yaml

from backend.cli.constants import (
ENV_YAML_CONFIG_MAPPING,
WELCOME_MESSAGE,
DeploymentName,
bcolors,
)
from backend.cli.setters import read_yaml


def print_styled(text: str, color: str = bcolors.ENDC):
Expand All @@ -12,6 +23,43 @@ def show_welcome_message():
)


def process_existing_yaml_config(secrets, path, prompt):
_path = Path(path)

if _path.is_file():
try:
yaml_config = read_yaml(path)
except yaml.scanner.ScannerError:
if prompt():
yaml_config = {}
_path.unlink()
else:
sys.exit(1)

secrets.update(convert_yaml_to_secrets(yaml_config))


def convert_yaml_to_secrets(yaml_dict: dict):
def get_nested_value(d, path):
keys = path.split(".")
value = d
for key in keys:
if not isinstance(value, dict) or key not in value:
return None # Return None if the path does not exist
value = value[key]
return value

secrets = {}
for env_var, mapping in ENV_YAML_CONFIG_MAPPING.items():
path = mapping.get("path")
if path: # Only process mappings with a defined path
value = get_nested_value(yaml_dict, path)
if value is not None: # Add only if the value exists
secrets[env_var] = value

return secrets


def wrap_up(deployments):
print_styled("✅ Your configuration file has been set up.", bcolors.OKGREEN)

Expand Down

0 comments on commit b638c9e

Please sign in to comment.