Commit 30e3c47

Improve audio detection
Signed-off-by: Ettore Di Giacinto <[email protected]>
Parent: 01aace3

File tree: 1 file changed (+135 −105 lines)


core/http/endpoints/openai/realtime.go

Lines changed: 135 additions & 105 deletions
@@ -13,7 +13,6 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/websocket/v2"
 	"github.com/mudler/LocalAI/core/application"
-	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
@@ -138,6 +137,8 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
 		model = "gpt-4o"
 	}
 
+	log.Info().Msgf("New session with model: %s", model)
+
 	sessionID := generateSessionID()
 	session := &Session{
 		ID: sessionID,
@@ -487,9 +488,16 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
 }
 
 const (
-	minMicVolume              = 450
-	sendToVADDelay            = time.Second
-	maxWhisperSegmentDuration = time.Second * 15
+	minMicVolume   = 450
+	sendToVADDelay = time.Second
+)
+
+type VADState int
+
+const (
+	StateSilence VADState = iota
+	StateSpeaking
+	StateTrailingSilence
 )
 
 // handle VAD (Voice Activity Detection)
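
This commit replaces the single `audioDetected` flag (and the fixed 15-second `maxWhisperSegmentDuration` cutoff) with a small state machine. A minimal sketch of the three states and how the handler below uses them, using only names that appear in the diff:

```go
package main

import "fmt"

// VADState mirrors the type introduced by this commit; the comments
// describe the role each state plays in the capture loop.
type VADState int

const (
	StateSilence         VADState = iota // nothing to transcribe yet
	StateSpeaking                        // the VAD is currently returning speech segments
	StateTrailingSilence                 // speech has settled; the buffer is ready to commit
)

func main() {
	fmt.Println(StateSilence, StateSpeaking, StateTrailingSilence) // 0 1 2
}
```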
@@ -503,7 +511,8 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 		cancel()
 	}()
 
-	audioDetected := false
+	vadState := VADState(StateSilence)
+	segments := []*proto.VADSegment{}
 	timeListening := time.Now()
 
 	// Implement VAD logic here
@@ -520,15 +529,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 
 			if len(session.InputAudioBuffer) > 0 {
 
-				if audioDetected && time.Since(timeListening) < maxWhisperSegmentDuration {
-					log.Debug().Msgf("VAD detected speech, but still listening")
-					// audioDetected = false
-					// keep listening
-					session.AudioBufferLock.Unlock()
-					continue
-				}
-
-				if audioDetected {
+				if vadState == StateTrailingSilence {
 					log.Debug().Msgf("VAD detected speech that we can process")
 
 					// Commit the audio buffer as a conversation item
@@ -561,7 +562,8 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 					Item: item,
 				})
 
-				audioDetected = false
+				vadState = StateSilence
+				segments = []*proto.VADSegment{}
 				// Generate a response
 				generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
 				continue
@@ -570,7 +572,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 				adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
 
 				// Resample from 24kHz to 16kHz
-				adata = sound.ResampleInt16(adata, 24000, 16000)
+				// adata = sound.ResampleInt16(adata, 24000, 16000)
 
 				soundIntBuffer := &audio.IntBuffer{
 					Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
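
Note that the 24kHz→16kHz resampling step is disabled here, so the buffer is handed to the VAD at whatever rate the client sent. For reference, a naive linear-interpolation resampler (a hypothetical stand-in, not LocalAI's `sound.ResampleInt16`) does roughly this:

```go
package main

import "fmt"

// resampleInt16 is an illustrative linear-interpolation resampler; it is
// NOT the implementation of sound.ResampleInt16, just a sketch of the
// 24kHz -> 16kHz step the commit comments out.
func resampleInt16(in []int16, fromRate, toRate int) []int16 {
	if fromRate == toRate || len(in) == 0 {
		return in
	}
	out := make([]int16, len(in)*toRate/fromRate)
	for i := range out {
		pos := float64(i) * float64(fromRate) / float64(toRate) // position in the source
		j := int(pos)
		if j >= len(in)-1 {
			out[i] = in[len(in)-1]
			continue
		}
		frac := pos - float64(j)
		// Interpolate between the two neighbouring source samples.
		out[i] = int16(float64(in[j])*(1-frac) + float64(in[j+1])*frac)
	}
	return out
}

func main() {
	second := make([]int16, 24000)                        // one second at 24kHz
	fmt.Println(len(resampleInt16(second, 24000, 16000))) // 16000
}
```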
@@ -582,9 +584,20 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 					session.AudioBufferLock.Unlock()
 					continue
 				} */
-
 				float32Data := soundIntBuffer.AsFloat32Buffer().Data
 
+				// TODO: testing wav decoding
+				// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
+				// buf, err := dec.FullPCMBuffer()
+				// if err != nil {
+				// 	//log.Error().Msgf("failed to process audio: %s", err.Error())
+				// 	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+				// 	session.AudioBufferLock.Unlock()
+				// 	continue
+				// }
+
+				//float32Data = buf.AsFloat32Buffer().Data
+
 				resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
 					Audio: float32Data,
 				})
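
The VAD backend consumes `float32` samples, produced above via go-audio's `AsFloat32Buffer`. A hand-rolled equivalent of that per-sample widening cast is sketched below; whether the backend additionally expects samples normalised to [-1, 1] is not visible in this diff, so no rescaling is applied here either:

```go
package main

import "fmt"

// int16ToFloat32 mirrors the AsFloat32Buffer step before the VAD call:
// a plain per-sample cast, with no rescaling (matching what the diff shows).
func int16ToFloat32(in []int16) []float32 {
	out := make([]float32, len(in))
	for i, s := range in {
		out[i] = float32(s)
	}
	return out
}

func main() {
	fmt.Println(int16ToFloat32([]int16{0, -32768, 32767}))
}
```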
@@ -598,20 +611,34 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 				if len(resp.Segments) == 0 {
 					log.Debug().Msg("VAD detected no speech activity")
 					log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
-
-					if !audioDetected {
+					if len(session.InputAudioBuffer) > 16000 {
 						session.InputAudioBuffer = nil
+						segments = []*proto.VADSegment{}
 					}
+
 					log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
+				} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
+					// We have new segments, but we are still speaking
+					// We need to wait for the trailing silence
 
-					session.AudioBufferLock.Unlock()
-					continue
-				}
+					segments = resp.Segments
+
+				} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
+					// We have the same number of segments, but we are still speaking
+					// We need to check if we are in this state for long enough, update the timer
 
-				if !audioDetected {
-					timeListening = time.Now()
+					// Check if we have been listening for too long
+					if time.Since(timeListening) > sendToVADDelay {
+						vadState = StateTrailingSilence
+					} else {
+
+						timeListening = timeListening.Add(time.Since(timeListening))
+					}
+				} else {
+					log.Debug().Msg("VAD detected speech activity")
+					vadState = StateSpeaking
+					segments = resp.Segments
 				}
-				audioDetected = true
 
 				session.AudioBufferLock.Unlock()
 			} else {
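
Taken together, the new branches debounce the detector: the segment count must hold stable for longer than `sendToVADDelay` while speaking before the buffer is treated as finished. A condensed, hypothetical restatement of the branch logic as a pure transition function (`segCount` stands for `len(resp.Segments)`, `prevCount` for the cached `len(segments)`; buffer clearing and the commit path are elided):

```go
package main

import (
	"fmt"
	"time"
)

type VADState int

const (
	StateSilence VADState = iota
	StateSpeaking
	StateTrailingSilence
)

const sendToVADDelay = time.Second

// nextState restates the if/else chain above; stable is how long the
// segment count has been unchanged. The real handler also commits the
// buffer and returns to StateSilence once StateTrailingSilence is seen.
func nextState(state VADState, segCount, prevCount int, stable time.Duration) VADState {
	switch {
	case segCount == 0:
		// No speech in this window; the handler may drop an oversized buffer.
		return state
	case state == StateSpeaking && segCount != prevCount:
		// New segments arrived: the speaker is still going, keep waiting.
		return StateSpeaking
	case state == StateSpeaking && segCount == prevCount:
		// Segment count stable: long enough means trailing silence.
		if stable > sendToVADDelay {
			return StateTrailingSilence
		}
		return StateSpeaking
	default:
		// First speech activity detected.
		return StateSpeaking
	}
}

func main() {
	fmt.Println(nextState(StateSpeaking, 2, 2, 1500*time.Millisecond) == StateTrailingSilence) // true
}
```

One detail worth flagging: in the `else` arm above, `timeListening.Add(time.Since(timeListening))` evaluates to (approximately) `time.Now()`, so that branch restarts the stability window rather than accumulating it.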
@@ -843,101 +870,104 @@ func processTextResponse(config *config.BackendConfig, session *Session, prompt
 	// Replace this with actual model inference logic using session.Model and prompt
 	// For example, the model might return a special token or JSON indicating a function call
 
-	predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+	/*
+		predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
 
-	result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
-		if !shouldUseFn {
-			// no function is called, just reply and use stop as finish reason
-			*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
-			return
-		}
-
-		textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
-		s = functions.CleanupLLMResult(s, config.FunctionsConfig)
-		results := functions.ParseFunctionCall(s, config.FunctionsConfig)
-		log.Debug().Msgf("Text content to return: %s", textContentToReturn)
-		noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
-
-		switch {
-		case noActionsToRun:
-			result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
-			if err != nil {
-				log.Error().Err(err).Msg("error handling question")
+		result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+			if !shouldUseFn {
+				// no function is called, just reply and use stop as finish reason
+				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
 				return
 			}
-			*c = append(*c, schema.Choice{
-				Message: &schema.Message{Role: "assistant", Content: &result}})
-		default:
-			toolChoice := schema.Choice{
-				Message: &schema.Message{
-					Role: "assistant",
-				},
-			}
 
-			if len(input.Tools) > 0 {
-				toolChoice.FinishReason = "tool_calls"
-			}
+			textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
+			s = functions.CleanupLLMResult(s, config.FunctionsConfig)
+			results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+			log.Debug().Msgf("Text content to return: %s", textContentToReturn)
+			noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
+
+			switch {
+			case noActionsToRun:
+				result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+				if err != nil {
+					log.Error().Err(err).Msg("error handling question")
+					return
+				}
+				*c = append(*c, schema.Choice{
+					Message: &schema.Message{Role: "assistant", Content: &result}})
+			default:
+				toolChoice := schema.Choice{
+					Message: &schema.Message{
+						Role: "assistant",
+					},
+				}
 
-			for _, ss := range results {
-				name, args := ss.Name, ss.Arguments
 				if len(input.Tools) > 0 {
-					// If we are using tools, we condense the function calls into
-					// a single response choice with all the tools
-					toolChoice.Message.Content = textContentToReturn
-					toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
-						schema.ToolCall{
-							ID: id,
-							Type: "function",
-							FunctionCall: schema.FunctionCall{
-								Name: name,
-								Arguments: args,
+					toolChoice.FinishReason = "tool_calls"
+				}
+
+				for _, ss := range results {
+					name, args := ss.Name, ss.Arguments
+					if len(input.Tools) > 0 {
+						// If we are using tools, we condense the function calls into
+						// a single response choice with all the tools
+						toolChoice.Message.Content = textContentToReturn
+						toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
+							schema.ToolCall{
+								ID: id,
+								Type: "function",
+								FunctionCall: schema.FunctionCall{
+									Name: name,
+									Arguments: args,
+								},
 							},
-						},
-					)
-				} else {
-					// otherwise we return more choices directly
-					*c = append(*c, schema.Choice{
-						FinishReason: "function_call",
-						Message: &schema.Message{
-							Role: "assistant",
-							Content: &textContentToReturn,
-							FunctionCall: map[string]interface{}{
-								"name": name,
-								"arguments": args,
+						)
+					} else {
+						// otherwise we return more choices directly
+						*c = append(*c, schema.Choice{
+							FinishReason: "function_call",
+							Message: &schema.Message{
+								Role: "assistant",
+								Content: &textContentToReturn,
+								FunctionCall: map[string]interface{}{
+									"name": name,
+									"arguments": args,
+								},
 							},
-						},
-					})
+						})
+					}
 				}
-			}
 
-			if len(input.Tools) > 0 {
-				// we need to append our result if we are using tools
-				*c = append(*c, toolChoice)
+				if len(input.Tools) > 0 {
+					// we need to append our result if we are using tools
+					*c = append(*c, toolChoice)
+				}
 			}
-		}
+
+		}, nil)
+		if err != nil {
+			return err
 		}
 
-	}, nil)
-	if err != nil {
-		return err
-	}
+		resp := &schema.OpenAIResponse{
+			ID: id,
+			Created: created,
+			Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+			Choices: result,
+			Object: "chat.completion",
+			Usage: schema.OpenAIUsage{
+				PromptTokens: tokenUsage.Prompt,
+				CompletionTokens: tokenUsage.Completion,
+				TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
+			},
+		}
+		respData, _ := json.Marshal(resp)
+		log.Debug().Msgf("Response: %s", respData)
 
-	resp := &schema.OpenAIResponse{
-		ID: id,
-		Created: created,
-		Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
-		Choices: result,
-		Object: "chat.completion",
-		Usage: schema.OpenAIUsage{
-			PromptTokens: tokenUsage.Prompt,
-			CompletionTokens: tokenUsage.Completion,
-			TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
-		},
-	}
-	respData, _ := json.Marshal(resp)
-	log.Debug().Msgf("Response: %s", respData)
+		// Return the prediction in the response body
+		return c.JSON(resp)
 
-	// Return the prediction in the response body
-	return c.JSON(resp)
+	*/
 
 	// TODO: use session.ModelInterface...
 	// Simulate a function call
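
With the inference path fenced off in the `/* ... */` block, the `core/backend` import removed in the first hunk is no longer needed, and the endpoint falls through to the simulated function call. For reference, the commented-out tail assembled an OpenAI `chat.completion` payload; a hypothetical, self-contained sketch of that shape (field names follow the OpenAI spec the code cites; the types are illustrative stand-ins, not LocalAI's `schema` package, and all values are placeholders):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type choice struct {
	FinishReason string  `json:"finish_reason"`
	Index        int     `json:"index"`
	Message      message `json:"message"`
}

type response struct {
	ID      string   `json:"id"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Object  string   `json:"object"`
	Choices []choice `json:"choices"`
	Usage   usage    `json:"usage"`
}

func main() {
	resp := response{
		ID:      "chatcmpl-123",
		Created: 1700000000,
		Model:   "gpt-4o", // echo the model the client requested, per the spec comment
		Object:  "chat.completion",
		Choices: []choice{{
			FinishReason: "stop",
			Message:      message{Role: "assistant", Content: "..."},
		}},
		Usage: usage{PromptTokens: 1, CompletionTokens: 1, TotalTokens: 2},
	}
	b, _ := json.MarshalIndent(resp, "", "  ")
	fmt.Println(string(b))
}
```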
