diff --git a/kittentts/get_model.py b/kittentts/get_model.py index f91c28c..48a6ca1 100644 --- a/kittentts/get_model.py +++ b/kittentts/get_model.py @@ -75,19 +75,28 @@ def download_from_huggingface(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir= with open(config_path, 'r') as f: config = json.load(f) + # Validate config if config.get("type") != "ONNX1": - raise ValueError("Unsupported model type.") + raise ValueError(f"Unsupported model type in config.json: {config.get('type')}. Only 'ONNX1' is supported.") + + model_filename = config.get("model_file") + if not model_filename: + raise ValueError("The 'model_file' key is missing from config.json.") + + voices_filename = config.get("voices") + if not voices_filename: + raise ValueError("The 'voices' key is missing from config.json.") # Download model and voices files based on config model_path = hf_hub_download( repo_id=repo_id, - filename=config["model_file"], + filename=model_filename, cache_dir=cache_dir ) voices_path = hf_hub_download( repo_id=repo_id, - filename=config["voices"], + filename=voices_filename, cache_dir=cache_dir )