diff --git a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py index 83178e924d26..850046317a91 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py @@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient from replicate.helpers import FileOutput +from backend.blocks.replicate._helper import run_replicate_with_retry from backend.data.block import ( Block, BlockCategory, @@ -183,9 +184,10 @@ async def run_model( if images: input_params["image_input"] = [str(img) for img in images] - output: FileOutput | str = await client.async_run( # type: ignore + output: FileOutput | str = await run_replicate_with_retry( # type: ignore + client, model_name, - input=input_params, + input_params, wait=False, ) diff --git a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py index 8c7b6e610219..c14357826acb 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py @@ -5,6 +5,7 @@ from replicate.client import Client as ReplicateClient from replicate.helpers import FileOutput +from backend.blocks.replicate._helper import run_replicate_with_retry from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput from backend.data.model import ( APIKeyCredentials, @@ -181,7 +182,9 @@ async def _run_client( client = ReplicateClient(api_token=credentials.api_key.get_secret_value()) # Run the model with input parameters - output = await client.async_run(model_name, input=input_params, wait=False) + output = await run_replicate_with_retry( + client, model_name, input_params, wait=False + ) # Process output if isinstance(output, list) and len(output) > 0: diff --git 
a/autogpt_platform/backend/backend/blocks/ai_music_generator.py b/autogpt_platform/backend/backend/blocks/ai_music_generator.py index 1ecb78f95e78..f44ef2e46cd3 100644 --- a/autogpt_platform/backend/backend/blocks/ai_music_generator.py +++ b/autogpt_platform/backend/backend/blocks/ai_music_generator.py @@ -1,11 +1,12 @@ -import asyncio import logging from enum import Enum from typing import Literal from pydantic import SecretStr from replicate.client import Client as ReplicateClient +from replicate.helpers import FileOutput +from backend.blocks.replicate._helper import run_replicate_with_retry from backend.data.block import ( Block, BlockCategory, @@ -43,12 +44,14 @@ class MusicGenModelVersion(str, Enum): STEREO_LARGE = "stereo-large" MELODY_LARGE = "melody-large" LARGE = "large" + MINIMAX_MUSIC_1_5 = "minimax/music-1.5" # Audio format enum class AudioFormat(str, Enum): WAV = "wav" MP3 = "mp3" + PCM = "pcm" # Normalization strategy enum @@ -72,6 +75,14 @@ class Input(BlockSchemaInput): placeholder="e.g., 'An upbeat electronic dance track with heavy bass'", title="Prompt", ) + lyrics: str | None = SchemaField( + description=( + "Lyrics for the song (required for Minimax Music 1.5). " + "Use \\n to separate lines. Supports tags like [intro], [verse], [chorus], etc." 
async def run(
    self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
    """Run the music model once and yield the outcome.

    Yields ``("result", url)`` when ``run_model`` returns a non-empty string
    that starts with ``"http"``. Any other return value yields
    ``("error", "Model returned empty or invalid response")``. An exception
    raised while preparing or running the call is logged and reported through
    the ``"error"`` output instead of propagating to the caller.
    """
    try:
        generate = self.run_model
        # Collect every model argument in one mapping; it is forwarded as
        # keyword arguments so the call matches run_model's parameter names.
        model_args = {
            "api_key": credentials.api_key,
            "music_gen_model_version": input_data.music_gen_model_version,
            "prompt": input_data.prompt,
            "lyrics": input_data.lyrics,
            "duration": input_data.duration,
            "temperature": input_data.temperature,
            "top_k": input_data.top_k,
            "top_p": input_data.top_p,
            "classifier_free_guidance": input_data.classifier_free_guidance,
            "output_format": input_data.output_format,
            "normalization_strategy": input_data.normalization_strategy,
        }
        audio_url = await generate(**model_args)

        # Guard clause: anything that is not a non-empty "http..." string is
        # reported as an invalid response.
        if (
            not audio_url
            or not isinstance(audio_url, str)
            or not audio_url.startswith("http")
        ):
            yield "error", "Model returned empty or invalid response"
        else:
            yield "result", audio_url

    except Exception as exc:
        logger.error(f"[AIMusicGeneratorBlock] - Error: {str(exc)}")
        yield "error", f"Failed to generate music: {str(exc)}"
prompt, + "lyrics": lyrics, + "audio_format": output_format.value, + } + model_name = "minimax/music-1.5" + else: + input_params = { "prompt": prompt, "music_gen_model_version": music_gen_model_version, "duration": duration, @@ -216,7 +229,15 @@ async def run_model( "classifier_free_guidance": classifier_free_guidance, "output_format": output_format, "normalization_strategy": normalization_strategy, - }, + } + model_name = "meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb" + + # Run the model with parameters + output = await run_replicate_with_retry( + client, + model_name, + input_params, + wait=True, ) # Handle the output @@ -224,6 +245,8 @@ async def run_model( result_url = output[0] # If output is a list, get the first element elif isinstance(output, str): result_url = output # If output is a string, use it directly + elif isinstance(output, FileOutput): + result_url = output.url else: result_url = ( "No output received" # Fallback message if output is not as expected diff --git a/autogpt_platform/backend/backend/blocks/flux_kontext.py b/autogpt_platform/backend/backend/blocks/flux_kontext.py index 908d0962ed95..3fa4e4e0b3d2 100644 --- a/autogpt_platform/backend/backend/blocks/flux_kontext.py +++ b/autogpt_platform/backend/backend/blocks/flux_kontext.py @@ -5,6 +5,7 @@ from replicate.client import Client as ReplicateClient from replicate.helpers import FileOutput +from backend.blocks.replicate._helper import run_replicate_with_retry from backend.data.block import ( Block, BlockCategory, @@ -173,9 +174,10 @@ async def run_model( **({"seed": seed} if seed is not None else {}), } - output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore + output: FileOutput | list[FileOutput] = await run_replicate_with_retry( # type: ignore + client, model_name, - input=input_params, + input_params=input_params, wait=False, ) diff --git a/autogpt_platform/backend/backend/blocks/replicate/_helper.py 
def _failed_prediction_message(output: Any) -> str | None:
    """Return an error message if *output* reports a failed prediction, else None."""
    is_mapping = isinstance(output, dict)
    status = output.get("status") if is_mapping else getattr(output, "status", None)
    if status != "failed":
        return None

    # Only look up the error detail once the prediction is known to have failed.
    error = output.get("error") if is_mapping else getattr(output, "error", None)
    message = "Replicate prediction failed"
    return f"{message}: {error}" if error else message


async def run_replicate_with_retry(
    client: ReplicateClient,
    model: str,
    input_params: dict[str, Any],
    wait: bool = False,
    max_retries: int = 3,
    **kwargs: Any,
) -> Any:
    """Call ``client.async_run`` and retry failed attempts with exponential backoff.

    An attempt counts as failed when ``async_run`` raises an ``Exception``, or
    when its output is a dict or object whose ``status`` equals ``"failed"``;
    the latter is converted into a ``RuntimeError`` that carries the output's
    ``error`` value when one is present.

    Args:
        client: Replicate client used to run the model.
        model: Model reference passed to ``async_run``.
        input_params: Forwarded to ``async_run`` as ``input=``.
        wait: Forwarded to ``async_run`` as ``wait=``.
        max_retries: Total number of attempts. Values below 1 are treated as 1
            so that the model is always called at least once.
        **kwargs: Extra keyword arguments forwarded to ``async_run``.

    Returns:
        The output of the first attempt that does not fail.

    Raises:
        Exception: The error from the final attempt once all attempts failed
            (``RuntimeError`` when the output reported a failed status).
    """
    # A non-positive max_retries would otherwise skip the loop entirely and
    # return without ever contacting Replicate.
    attempts = max(1, max_retries)
    retry_delay = 2  # seconds; doubled after every failed attempt

    for attempt in range(attempts):
        try:
            output = await client.async_run(
                model, input=input_params, wait=wait, **kwargs
            )

            failure = _failed_prediction_message(output)
            if failure is not None:
                raise RuntimeError(failure)

            return output

        except Exception as e:
            if attempt >= attempts - 1:
                logger.error("Replicate failed after %d attempts: %s", attempts, e)
                # Bare raise keeps the original traceback intact.
                raise

            wait_time = retry_delay * (2**attempt)
            logger.warning(
                "Replicate attempt %d failed: %s. Retrying in %ds...",
                attempt + 1,
                e,
                wait_time,
            )
            await asyncio.sleep(wait_time)
It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]` + client, f"{model_name}", - input={ + input_params={ "prompt": prompt, "seed": seed, "steps": steps, diff --git a/autogpt_platform/backend/backend/blocks/replicate/replicate_block.py b/autogpt_platform/backend/backend/blocks/replicate/replicate_block.py index 8cf104edc272..2adbd2f34272 100644 --- a/autogpt_platform/backend/backend/blocks/replicate/replicate_block.py +++ b/autogpt_platform/backend/backend/blocks/replicate/replicate_block.py @@ -9,7 +9,11 @@ TEST_CREDENTIALS_INPUT, ReplicateCredentialsInput, ) -from backend.blocks.replicate._helper import ReplicateOutputs, extract_result +from backend.blocks.replicate._helper import ( + ReplicateOutputs, + extract_result, + run_replicate_with_retry, +) from backend.data.block import ( Block, BlockCategory, @@ -129,8 +133,8 @@ async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr """ api_key_str = api_key.get_secret_value() client = ReplicateClient(api_token=api_key_str) - output: ReplicateOutputs = await client.async_run( - model_ref, input=model_inputs, wait=False + output: ReplicateOutputs = await run_replicate_with_retry( + client, model_ref, input_params=model_inputs, wait=False ) # type: ignore they suck at typing result = extract_result(output)