Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def generate_image(
sexually_explicit_threshold: str,
dangerous_content_threshold: str,
system_instruction: str,
image1: torch.Tensor,
image1: Optional[torch.Tensor] = None,
image2: Optional[torch.Tensor] = None,
image3: Optional[torch.Tensor] = None,
) -> List[Image.Image]:
Expand All @@ -89,7 +89,7 @@ def generate_image(
content.
dangerous_content_threshold: Safety threshold for dangerous content.
system_instruction: System-level instructions for the model.
image1: The primary input image tensor for image-to-image tasks.
image1: An optional primary input image tensor for image-to-image tasks.
image2: An optional second input image tensor. Defaults to None.
image3: An optional third input image tensor. Defaults to None.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
"default": "A vivid landscape painting of a futuristic city",
},
),
"image1": ("IMAGE",),
"aspect_ratio": (
[
"1:1",
Expand All @@ -87,6 +86,7 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
"top_k": ("INT", {"default": 32, "min": 1, "max": 64}),
},
"optional": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"image3": ("IMAGE",),
# Safety Settings
Expand Down Expand Up @@ -145,12 +145,12 @@ def generate_and_return_image(
temperature: float,
top_p: float,
top_k: int,
image1: torch.Tensor,
hate_speech_threshold: str,
harassment_threshold: str,
sexually_explicit_threshold: str,
dangerous_content_threshold: str,
system_instruction: str,
image1: Optional[torch.Tensor] = None,
image2: Optional[torch.Tensor] = None,
image3: Optional[torch.Tensor] = None,
gcp_project_id: Optional[str] = None,
Expand All @@ -175,7 +175,7 @@ def generate_and_return_image(
content.
dangerous_content_threshold: Safety threshold for dangerous content.
system_instruction: System-level instructions for the model.
image1: The primary input image tensor for image editing tasks.
image1: An optional primary input image tensor for image editing tasks.
image2: An optional second input image tensor. Defaults to None.
image3: An optional third input image tensor. Defaults to None.
gcp_project_id: The GCP project ID.
Expand Down
84 changes: 15 additions & 69 deletions modules/python/src/custom_nodes/google_genmedia/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,13 +483,15 @@ def generate_video_from_image(


@api_error_retry
def generate_video_from_gcs_references(
def generate_video_from_references(
client: genai.Client,
model: str,
prompt: str,
gcs_uris: List[str],
image_format: str,
aspect_ratio: str,
image1: torch.Tensor,
image2: Optional[torch.Tensor],
image3: Optional[torch.Tensor],
output_resolution: Optional[str],
compression_quality: Optional[str],
person_generation: str,
Expand Down Expand Up @@ -530,16 +532,18 @@ def generate_video_from_gcs_references(
"generate_audio": generate_audio if generate_audio is not None else False,
}

if output_gcs_uri:
temp_config["output_gcs_uri"] = output_gcs_uri

reference_images = []
for uri in gcs_uris:
image_part = Image(gcs_uri=uri, mime_type=mime_type)
reference_image = types.VideoGenerationReferenceImage(
image=image_part, reference_type="asset"
)
reference_images.append(reference_image)

for image_tensor in [image1, image2, image3]:
if image_tensor is not None:
image_part = Image(
imageBytes=tensor_to_pil_to_base64(image_tensor, image_format),
mime_type=mime_type,
)
reference_image = types.VideoGenerationReferenceImage(
image=image_part, reference_type="asset"
)
reference_images.append(reference_image)

temp_config["reference_images"] = reference_images

Expand Down Expand Up @@ -966,64 +970,6 @@ def process_video_response(operation: Any) -> List[str]:
return video_paths


def upload_images_to_gcs(
images: List[Optional[torch.Tensor]], bucket_name: str, image_format: str
) -> List[str]:
"""
Uploads a list of image tensors to a GCS bucket.

Args:
images: A list of torch.Tensor images, which can contain None.
bucket_name: The name of the GCS bucket.
image_format: The format of the images (e.g., "PNG", "JPEG").

Returns:
A list of GCS URIs for the uploaded images.
"""
prefix = "gs://"
if bucket_name.startswith(prefix):
bucket_name = bucket_name[len(prefix) :]

gcs_uris = []
storage_client = storage.Client(
client_info=ClientInfo(user_agent=STORAGE_USER_AGENT)
)
bucket = storage_client.bucket(bucket_name)

if not bucket.exists():
raise APIInputError(
f"GCS bucket '{bucket_name}' does not exist or is inaccessible."
)

for i, image_tensor in enumerate(images):
if image_tensor is not None:
try:
timestamp = int(time.time())
unique_id = random.randint(1000, 9999)
object_name = f"temporary-reference-images/ref_{timestamp}_{i+1}_{unique_id}.{image_format.lower()}"
blob = bucket.blob(object_name)

# VEO expects single images, not batches. Take the first from any potential batch.
single_image_tensor = image_tensor[0].unsqueeze(0)
image_bytes = tensor_to_pil_to_bytes(
single_image_tensor, format=image_format.upper()
)

blob.upload_from_string(
image_bytes, content_type=f"image/{image_format.lower()}"
)

gcs_uri = f"gs://{bucket_name}/{object_name}"
gcs_uris.append(gcs_uri)
logger.info(f"Successfully uploaded reference image {i+1} to {gcs_uri}")
except Exception as e:
raise APIExecutionError(
f"Failed to upload image {i+1} to GCS: {e}"
) from e

return gcs_uris


def validate_gcs_uri_and_image(
gcs_uri: str, check_object: bool = True
) -> Tuple[bool, str]:
Expand Down
24 changes: 5 additions & 19 deletions modules/python/src/custom_nodes/google_genmedia/veo3_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def generate_video_from_references(
self,
model: str,
prompt: str,
bucket_name: str,
image1: torch.Tensor,
image_format: str,
aspect_ratio: str,
Expand All @@ -249,12 +248,11 @@ def generate_video_from_references(
seed: Optional[int],
) -> List[str]:
"""
Uploads reference images to GCS and then generates a video.
Generates a video from the references.

Args:
model: Veo3 model.
prompt: The text prompt for video generation.
bucket_name: The GCS bucket to upload reference images to.
image1: The first reference image as a torch.Tensor.
image_format: The format of the input images.
aspect_ratio: The desired aspect ratio of the video.
Expand All @@ -278,27 +276,15 @@ def generate_video_from_references(
raise APIInputError(
"Image1 is required. At least reference image must be provided."
)
if not bucket_name:
raise APIInputError(
"bucket_name is required for uploading reference images."
)

gcs_uris = utils.upload_images_to_gcs(
images=[image1, image2, image3],
bucket_name=bucket_name,
image_format=image_format,
)

if not gcs_uris:
raise APIExecutionError("Failed to upload any reference images to GCS.")

model_enum = Veo3Model[model]

return utils.generate_video_from_gcs_references(
return utils.generate_video_from_references(
client=self.client,
model=model_enum,
prompt=prompt,
gcs_uris=gcs_uris,
image1=image1,
image2=image2,
image3=image3,
image_format=image_format,
aspect_ratio=aspect_ratio,
output_resolution=output_resolution,
Expand Down
9 changes: 0 additions & 9 deletions modules/python/src/custom_nodes/google_genmedia/veo3_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,13 +522,6 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
[model.name for model in Veo3Model],
{"default": Veo3Model.VEO_3_1_PREVIEW.name},
),
"bucket_name": (
"STRING",
{
"default": "",
"tooltip": "GCS bucket name to temporarily store reference images.",
},
),
"image1": ("IMAGE",),
"image_format": (
["PNG", "JPEG"],
Expand Down Expand Up @@ -591,7 +584,6 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
def generate_from_references(
self,
model: str,
bucket_name: str,
image1: torch.Tensor,
image_format: str,
prompt: str,
Expand Down Expand Up @@ -622,7 +614,6 @@ def generate_from_references(
video_paths = api.generate_video_from_references(
model=model,
prompt=prompt,
bucket_name=bucket_name,
image1=image1,
image2=image2,
image3=image3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,9 @@ resource "google_storage_bucket_object" "workflow_veo3_itv" {
}

resource "google_storage_bucket_object" "workflow_veo3_r2v" {
bucket = google_storage_bucket.comfyui_workflow.name
name = "veo3-reference-to-video.json"
source = "src/comfyui-workflows/veo3-reference-to-video.json"
depends_on = [local_file.workflow_veo2_ttv]
bucket = google_storage_bucket.comfyui_workflow.name
name = "veo3-reference-to-video.json"
source = "src/comfyui-workflows/veo3-reference-to-video.json"
}

resource "google_storage_bucket_object" "workflow_veo3_ttv" {
Expand Down Expand Up @@ -270,13 +269,3 @@ resource "local_file" "workflow_veo3_ttv" {
)
filename = "${path.module}/src/comfyui-workflows/veo3-text-to-video.json"
}

resource "local_file" "workflow_veo3_r2v" {
content = templatefile(
"${path.module}/src/comfyui-workflows/veo3-reference-to-video.tftpl.json",
{
output_bucket_uri = google_storage_bucket.comfyui_output.url
}
)
filename = "${path.module}/src/comfyui-workflows/veo3-reference-to-video.json"
}
Loading
Loading