Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] audio integration #324

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions client/src/nv_ingest_client/primitives/tasks/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"svg": "image",
"tiff": "image",
"xml": "lxml",
"mp3": "audio",
"wav": "audio",
}

_Type_Extract_Method_PDF = Literal[
Expand All @@ -63,6 +65,8 @@

_Type_Extract_Method_Image = Literal["image"]

_Type_Extract_Method_Audio = Literal["audio"]

_Type_Extract_Method_Map = {
"docx": get_args(_Type_Extract_Method_DOCX),
"jpeg": get_args(_Type_Extract_Method_Image),
Expand All @@ -72,6 +76,8 @@
"pptx": get_args(_Type_Extract_Method_PPTX),
"svg": get_args(_Type_Extract_Method_Image),
"tiff": get_args(_Type_Extract_Method_Image),
"mp3": get_args(_Type_Extract_Method_Audio),
"wav": get_args(_Type_Extract_Method_Audio),
}

_Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium"]
Expand Down
5 changes: 4 additions & 1 deletion client/src/nv_ingest_client/util/file_processing/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class DocumentTypeEnum(str, Enum):
svg = "svg"
tiff = "tiff"
txt = "text"

mp3 = "mp3"
wav = "wav"

# Maps MIME types to DocumentTypeEnum
MIME_TO_DOCUMENT_TYPE = {
Expand Down Expand Up @@ -64,6 +65,8 @@ class DocumentTypeEnum(str, Enum):
"svg": DocumentTypeEnum.svg,
"tiff": DocumentTypeEnum.tiff,
"txt": DocumentTypeEnum.txt,
"mp3": DocumentTypeEnum.mp3,
"wav": DocumentTypeEnum.wav,
# Add more as needed
}

Expand Down
26 changes: 26 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,29 @@ services:
capabilities: [gpu]
runtime: nvidia

audio:
image: nvcr.io/nvidian/audio_retrieval:latest
shm_size: 2gb
ports:
- "8015:8000"
user: root
environment:
- NIM_HTTP_API_PORT=8000
- NIM_TRITON_LOG_VERBOSE=1
- NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}}
- CUDA_VISIBLE_DEVICES=0
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["1"]
capabilities: [gpu]
runtime: nvidia
working_dir: /app/audio_retrieval/src



nv-ingest-ms-runtime:
image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.10.1
build:
Expand All @@ -141,6 +164,9 @@ services:
cap_add:
- sys_nice
environment:
# Self-hosted audio endpoints.
- AUDIO_HTTP_ENDPOINT=http://audio:8000/v1/transcribe
- AUDIO_INFER_PROTOCOL=http
# Self-hosted cached endpoints.
- CACHED_GRPC_ENDPOINT=cached:8001
- CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer
Expand Down
11 changes: 10 additions & 1 deletion src/nv_ingest/modules/injectors/metadata_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import logging
import traceback

import mrc
import pandas as pd
Expand Down Expand Up @@ -46,6 +47,9 @@ def on_data(message: ControlMessage):
"type": content_type.name.lower(),
},
"error_metadata": None,
"audio_metadata": (
None if content_type != ContentTypeEnum.AUDIO else {"audio_type": row["document_type"]}
),
"image_metadata": (
None if content_type != ContentTypeEnum.IMAGE else {"image_type": row["document_type"]}
),
Expand Down Expand Up @@ -78,7 +82,12 @@ def _metadata_injection(builder: mrc.Builder):
raise_on_failure=validated_config.raise_on_failure,
)
def _on_data(message: ControlMessage):
return on_data(message)
try:
return on_data(message)
except Exception as e:
logger.error(f"Unhandled exception in metadata_injector: {e}")
traceback.print_exc()
raise

node = builder.make_node("metadata_injector", _on_data)

Expand Down
127 changes: 127 additions & 0 deletions src/nv_ingest/schemas/audio_extractor_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import logging
from typing import Optional
from typing import Tuple

from pydantic import BaseModel
from pydantic import root_validator

logger = logging.getLogger(__name__)


class AudioConfigSchema(BaseModel):
    """
    Configuration schema for audio extraction endpoints and options.

    Parameters
    ----------
    auth_token : Optional[str], default=None
        Authentication token required for secure services.

    audio_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
        A ``(gRPC, HTTP)`` pair of services for the audio endpoint.
        Either the gRPC or HTTP service can be empty, but not both.

    audio_infer_protocol : Optional[str], default=None
        Protocol used to reach the audio service. When not supplied it is
        derived from ``audio_endpoints`` ("http" if an HTTP service is set,
        otherwise "grpc"). The stored value is always lower-cased.

    Methods
    -------
    validate_endpoints(values)
        Validates that at least one of the gRPC or HTTP services is provided
        and normalizes ``audio_infer_protocol``.

    Raises
    ------
    ValueError
        If both gRPC and HTTP services are empty, or if ``audio_endpoints``
        is not a two-element pair.

    Config
    ------
    extra : str
        Pydantic config option to forbid extra fields.
    """

    auth_token: Optional[str] = None
    audio_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
    audio_infer_protocol: Optional[str] = None

    @root_validator(pre=True)
    def validate_endpoints(cls, values):
        """
        Validates the gRPC and HTTP services and resolves the inference protocol.

        Parameters
        ----------
        values : dict
            Dictionary containing the raw (pre-validation) values for the class.

        Returns
        -------
        dict
            The same dictionary with ``audio_endpoints`` cleaned and
            ``audio_infer_protocol`` filled in and lower-cased.

        Raises
        ------
        ValueError
            If ``audio_endpoints`` is not a two-element pair, or if both the
            gRPC and HTTP services are empty.
        """

        def clean_service(service):
            """Set service to None if it's an empty string or contains only spaces or quotes."""
            if service is None:
                return None
            if not isinstance(service, str):
                # Leave non-string entries untouched so the field's own type
                # validation reports them, rather than failing on .strip() here.
                return service
            if not service.strip() or service.strip(" \"'") == "":
                return None
            return service

        endpoint_name = "audio_endpoints"

        # This validator runs with pre=True, i.e. on the raw input, so the key
        # may be absent when the caller relies on the field default. Treat a
        # missing/None value as "no services" instead of failing on unpacking.
        endpoints = values.get(endpoint_name)
        if endpoints is None:
            endpoints = (None, None)

        try:
            grpc_service, http_service = endpoints
        except (TypeError, ValueError) as err:
            raise ValueError(f"{endpoint_name} must be a (gRPC, HTTP) pair of services.") from err

        grpc_service = clean_service(grpc_service)
        http_service = clean_service(http_service)

        if not grpc_service and not http_service:
            raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")

        values[endpoint_name] = (grpc_service, http_service)

        protocol_name = "audio_infer_protocol"
        protocol_value = values.get(protocol_name)

        if not protocol_value:
            # At least one service is set at this point; HTTP wins when both are.
            protocol_value = "http" if http_service else "grpc"

        protocol_value = protocol_value.lower()
        values[protocol_name] = protocol_value

        return values

    class Config:
        extra = "forbid"


class AudioExtractorSchema(BaseModel):
    """
    Configuration schema for the audio extractor settings.

    Parameters
    ----------
    max_queue_size : int, default=1
        Upper bound on the number of items held in the processing queue.

    n_workers : int, default=16
        How many worker threads are used for processing.

    raise_on_failure : bool, default=False
        Whether a processing failure should raise an exception.

    audio_extraction_config : Optional[AudioConfigSchema], default=None
        Endpoint and protocol settings for the audio extraction stage.

    Config
    ------
    extra : str
        Pydantic config option to forbid extra fields.
    """

    # Queueing / concurrency settings.
    max_queue_size: int = 1
    n_workers: int = 16

    # Failure handling.
    raise_on_failure: bool = False

    # Audio service configuration; left unset by default.
    audio_extraction_config: Optional[AudioConfigSchema] = None

    class Config:
        extra = "forbid"
2 changes: 2 additions & 0 deletions src/nv_ingest/schemas/ingest_job_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class DocumentTypeEnum(str, Enum):
svg = "svg"
tiff = "tiff"
txt = "text"
mp3 = "mp3"
wav = "wav"


class TaskTypeEnum(str, Enum):
Expand Down
1 change: 1 addition & 0 deletions src/nv_ingest/schemas/ingest_pipeline_config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@


class PipelineConfigSchema(BaseModel):
# TODO(Devin): Audio
chart_extractor_module: ChartExtractorSchema = ChartExtractorSchema()
document_splitter_module: DocumentSplitterSchema = DocumentSplitterSchema()
embedding_storage_module: EmbeddingStorageModuleSchema = EmbeddingStorageModuleSchema()
Expand Down
12 changes: 10 additions & 2 deletions src/nv_ingest/schemas/metadata_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ class ChartMetadataSchema(BaseModelNoExt):
uploaded_image_uri: str = ""


class AudioMetadataSchema(BaseModelNoExt):
audio_transcript: str = ""
audio_type: str = ""


# TODO consider deprecating this in favor of info msg...
class ErrorMetadataSchema(BaseModelNoExt):
task: TaskTypeEnum
Expand All @@ -321,6 +326,7 @@ class MetadataSchema(BaseModelNoExt):
embedding: Optional[List[float]] = None
source_metadata: Optional[SourceMetadataSchema] = None
content_metadata: Optional[ContentMetadataSchema] = None
audio_metadata: Optional[AudioMetadataSchema] = None
text_metadata: Optional[TextMetadataSchema] = None
image_metadata: Optional[ImageMetadataSchema] = None
table_metadata: Optional[TableMetadataSchema] = None
Expand All @@ -334,10 +340,12 @@ class MetadataSchema(BaseModelNoExt):
@classmethod
def check_metadata_type(cls, values):
content_type = values.get("content_metadata", {}).get("type", None)
if content_type != ContentTypeEnum.TEXT:
values["text_metadata"] = None
if content_type != ContentTypeEnum.AUDIO:
values["audio_metadata"] = None
if content_type != ContentTypeEnum.IMAGE:
values["image_metadata"] = None
if content_type != ContentTypeEnum.TEXT:
values["text_metadata"] = None
if content_type != ContentTypeEnum.STRUCTURED:
values["table_metadata"] = None
return values
Expand Down
Loading
Loading