47 changes: 47 additions & 0 deletions examples/other/elevenlab_scribe_v2.py
@@ -0,0 +1,47 @@
import logging

from dotenv import load_dotenv

from livekit.agents import Agent, AgentServer, AgentSession, JobContext, JobProcess, cli, stt
from livekit.plugins import elevenlabs, silero

logger = logging.getLogger("realtime-scribe-v2")
logger.setLevel(logging.INFO)

load_dotenv()

server = AgentServer()


@server.rtc_session()
async def entrypoint(ctx: JobContext):
session = AgentSession(
vad=ctx.proc.userdata["vad"],
stt=stt.StreamAdapter(
stt=elevenlabs.STT(
use_realtime=True,
server_vad=None, # disable server-side VAD
language_code="en",
),
vad=ctx.proc.userdata["vad"],
use_streaming=True,
),
llm="openai/gpt-4.1-mini",
tts="elevenlabs",
)
await session.start(
agent=Agent(instructions="You are a somewhat helpful assistant."), room=ctx.room
)

await session.say("Hello, how can I help you?")


def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()


server.setup_fnc = prewarm


if __name__ == "__main__":
cli.run_app(server)
42 changes: 0 additions & 42 deletions examples/other/realtime_scribe_v2.py

This file was deleted.

55 changes: 55 additions & 0 deletions examples/voice_agents/stream_stt_with_vad.py
@@ -0,0 +1,55 @@
import logging

from dotenv import load_dotenv

from livekit.agents import (
Agent,
AgentServer,
AgentSession,
JobContext,
JobProcess,
cli,
stt,
)
from livekit.plugins import deepgram, silero

logger = logging.getLogger("stream-stt-with-vad")

# This example shows how to use a streaming STT together with a VAD.
# Only audio frames that the VAD detects as speech are sent to the STT.
# This requires an STT that supports streaming and flush (e.g. deepgram, cartesia);
# check `STT.capabilities` for details.

load_dotenv()

server = AgentServer()


@server.rtc_session()
async def entrypoint(ctx: JobContext):
session = AgentSession(
vad=ctx.proc.userdata["vad"],
stt=stt.StreamAdapter(
stt=deepgram.STT(),
vad=ctx.proc.userdata["vad"],
use_streaming=True, # use streaming mode of the wrapped STT with VAD
),
llm="openai/gpt-4.1-mini",
tts="elevenlabs",
)
await session.start(
agent=Agent(instructions="You are a somewhat helpful assistant."), room=ctx.room
)

await session.say("Hello, how can I help you?")


def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()


server.setup_fnc = prewarm


if __name__ == "__main__":
cli.run_app(server)
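
For reference, a minimal sketch of the capability check the comments in this example refer to; the Deepgram plugin is just an illustrative choice, and StreamAdapter performs equivalent validation itself (see stream_adapter.py below):

from livekit.agents import stt
from livekit.plugins import deepgram, silero

inner_stt = deepgram.STT()
# use_streaming needs a wrapped STT that supports streaming; missing flush support
# only risks incomplete transcripts at segment boundaries (the adapter warns about it)
if not inner_stt.capabilities.streaming:
    raise RuntimeError(f"{inner_stt.label} cannot be wrapped with use_streaming=True")
adapter = stt.StreamAdapter(stt=inner_stt, vad=silero.VAD.load(), use_streaming=True)
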
157 changes: 118 additions & 39 deletions livekit-agents/livekit/agents/stt/stream_adapter.py
@@ -2,11 +2,13 @@

import asyncio
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any

from .. import utils
from ..log import logger
from ..types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectOptions, NotGivenOr
from ..vad import VAD, VADEventType
from ..vad import VAD, VADEventType, VADStream
from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities

# already a retry mechanism in STT.recognize, don't retry in stream adapter
@@ -15,17 +17,46 @@
)


@dataclass
class StreamAdapterOptions:
use_streaming: bool = False


class StreamAdapter(STT):
def __init__(self, *, stt: STT, vad: VAD) -> None:
def __init__(
self,
*,
stt: STT,
vad: VAD,
use_streaming: bool = False,
) -> None:
"""
Create a new instance of StreamAdapter.

Args:
stt: The STT to wrap.
vad: The VAD to use.
use_streaming: Whether to use streaming mode of the wrapped STT. Default is False.
"""
super().__init__(
capabilities=STTCapabilities(
streaming=True,
interim_results=False,
diarization=False, # diarization requires streaming STT
interim_results=use_streaming,
diarization=stt.capabilities.diarization and use_streaming,
)
)
self._vad = vad
self._stt = stt
self._opts = StreamAdapterOptions(use_streaming=use_streaming)
if use_streaming and not stt.capabilities.streaming:
raise ValueError(
f"STT {stt.label} does not support streaming while use_streaming is enabled"
)
if use_streaming and not stt.capabilities.flush:
logger.warning(
f"STT {stt.label} does not support flush while use_streaming is enabled, "
"this may cause incomplete transcriptions."
)

# TODO(theomonnom): The segment_id needs to be populated!
self._stt.on("metrics_collected", self._on_metrics_collected)
@@ -65,6 +96,7 @@ def stream(
wrapped_stt=self._stt,
language=language,
conn_options=conn_options,
opts=self._opts,
)

def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
@@ -83,12 +115,14 @@ def __init__(
wrapped_stt: STT,
language: NotGivenOr[str],
conn_options: APIConnectOptions,
opts: StreamAdapterOptions,
) -> None:
super().__init__(stt=stt, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
self._vad = vad
self._wrapped_stt = wrapped_stt
self._wrapped_stt_conn_options = conn_options
self._language = language
self._opts = opts

async def _metrics_monitor_task(self, event_aiter: AsyncIterable[SpeechEvent]) -> None:
pass # do nothing
@@ -106,43 +140,88 @@ async def _forward_input() -> None:

vad_stream.end_input()

async def _recognize() -> None:
"""recognize speech from vad"""
async for event in vad_stream:
if event.type == VADEventType.START_OF_SPEECH:
self._event_ch.send_nowait(SpeechEvent(SpeechEventType.START_OF_SPEECH))
elif event.type == VADEventType.END_OF_SPEECH:
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.END_OF_SPEECH,
)
)
async def _forward_stream_output(stream: RecognizeStream) -> None:
async for event in stream:
self._event_ch.send_nowait(event)

stt_stream: RecognizeStream | None = None
forward_input_task = asyncio.create_task(_forward_input(), name="forward_input")
tasks = []
if not self._opts.use_streaming:
tasks.append(
asyncio.create_task(
self._recognize_non_streaming(vad_stream), name="recognize_non_streaming"
),
)
else:
stt_stream = self._wrapped_stt.stream(
language=self._language, conn_options=self._wrapped_stt_conn_options
)
tasks += [
asyncio.create_task(
_forward_stream_output(stt_stream), name="forward_stream_output"
),
asyncio.create_task(
self._recognize_streaming(vad_stream, stt_stream),
name="recognize_streaming",
),
]

merged_frames = utils.merge_frames(event.frames)
t_event = await self._wrapped_stt.recognize(
buffer=merged_frames,
language=self._language,
conn_options=self._wrapped_stt_conn_options,
try:
await asyncio.gather(*tasks, forward_input_task)
finally:
await utils.aio.cancel_and_wait(forward_input_task)
await vad_stream.aclose()
if stt_stream is not None:
stt_stream.end_input()
await stt_stream.aclose()
await utils.aio.cancel_and_wait(*tasks)

async def _recognize_streaming(
self, vad_stream: VADStream, stt_stream: RecognizeStream
) -> None:
"""Push VAD-gated speech frames to the wrapped streaming STT, flushing at end of speech."""
speaking = False
async for event in vad_stream:
frames = []
if event.type == VADEventType.START_OF_SPEECH:
speaking = True
frames = event.frames
elif event.type == VADEventType.INFERENCE_DONE and speaking:
frames = event.frames
elif event.type == VADEventType.END_OF_SPEECH:
speaking = False
stt_stream.flush()

for f in frames:
stt_stream.push_frame(f)

async def _recognize_non_streaming(self, vad_stream: VADStream) -> None:
"""recognize speech from vad"""
async for event in vad_stream:
if event.type == VADEventType.START_OF_SPEECH:
self._event_ch.send_nowait(SpeechEvent(SpeechEventType.START_OF_SPEECH))
elif event.type == VADEventType.END_OF_SPEECH:
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.END_OF_SPEECH,
)
)

if len(t_event.alternatives) == 0:
continue
elif not t_event.alternatives[0].text:
continue
merged_frames = utils.merge_frames(event.frames)
t_event = await self._wrapped_stt.recognize(
buffer=merged_frames,
language=self._language,
conn_options=self._wrapped_stt_conn_options,
)

self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[t_event.alternatives[0]],
)
)
if len(t_event.alternatives) == 0:
continue
elif not t_event.alternatives[0].text:
continue

tasks = [
asyncio.create_task(_forward_input(), name="forward_input"),
asyncio.create_task(_recognize(), name="recognize"),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.cancel_and_wait(*tasks)
await vad_stream.aclose()
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[t_event.alternatives[0]],
)
)
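
To recap the two code paths above, a minimal construction sketch (plugin choices are illustrative):

from livekit.agents import stt
from livekit.plugins import deepgram, silero

vad = silero.VAD.load()

# buffered mode (previous behavior): each VAD speech segment is merged into one
# buffer and passed to the wrapped STT's recognize() after END_OF_SPEECH
buffered = stt.StreamAdapter(stt=deepgram.STT(), vad=vad)

# streaming mode (new): speech frames are pushed to the wrapped STT's stream as they
# arrive, and flush() is called at END_OF_SPEECH to finalize the segment
streaming = stt.StreamAdapter(stt=deepgram.STT(), vad=vad, use_streaming=True)
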
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/stt/stt.py
@@ -72,6 +72,7 @@ class STTCapabilities:
streaming: bool
interim_results: bool
diarization: bool = False
flush: bool = False


class STTError(BaseModel):
@@ -76,7 +76,7 @@ def __init__(
buffer_size_seconds: float = 0.05,
):
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=False),
capabilities=stt.STTCapabilities(streaming=True, interim_results=False, flush=True),
)
assemblyai_api_key = api_key if is_given(api_key) else os.environ.get("ASSEMBLYAI_API_KEY")
if assemblyai_api_key is None:
@@ -171,6 +171,7 @@ class SpeechStream(stt.SpeechStream):
class SpeechStream(stt.SpeechStream):
# Used to close websocket
_CLOSE_MSG: str = json.dumps({"type": "Terminate"})
_FLUSH_MSG: str = json.dumps({"type": "ForceEndpoint"})

def __init__(
self,
@@ -241,6 +242,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
self._speech_duration += frame.duration
await ws.send_bytes(frame.data.tobytes())

# a flush sentinel maps to AssemblyAI's ForceEndpoint message, finalizing the current transcript
if isinstance(data, self._FlushSentinel):
await ws.send_str(SpeechStream._FLUSH_MSG)

closing_ws = True
await ws.send_str(SpeechStream._CLOSE_MSG)
