
Commit 9a09820

WIP - improve start and end of speech detection
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 30e3c47 commit 9a09820

File tree

1 file changed: +127 -120 lines changed

core/http/endpoints/openai/realtime.go

Lines changed: 127 additions & 120 deletions
@@ -497,156 +497,163 @@ type VADState int
 const (
     StateSilence VADState = iota
     StateSpeaking
-    StateTrailingSilence
 )

-// handle VAD (Voice Activity Detection)
-func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
+const (
+    // tune these thresholds to taste
+    SpeechFramesThreshold = 3 // must see X consecutive speech results to confirm "start"
+    SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
+)

+// handleVAD is a goroutine that listens for audio data from the client,
+// runs VAD on the audio data, and commits utterances to the conversation
+func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
     vadContext, cancel := context.WithCancel(context.Background())
-    //var startListening time.Time
-
     go func() {
         <-done
         cancel()
     }()

-    vadState := VADState(StateSilence)
-    segments := []*proto.VADSegment{}
-    timeListening := time.Now()
+    ticker := time.NewTicker(300 * time.Millisecond)
+    defer ticker.Stop()
+
+    var (
+        lastSegmentCount int
+        timeOfLastNewSeg time.Time
+        speaking bool
+    )

-    // Implement VAD logic here
-    // For brevity, this is a placeholder
-    // When VAD detects end of speech, generate a response
-    // TODO: use session.ModelInterface to handle VAD and cut audio and detect when to process that
     for {
         select {
         case <-done:
             return
-        default:
-            // Check if there's audio data to process
+        case <-ticker.C:
+            // 1) Copy the entire buffer
             session.AudioBufferLock.Lock()
+            allAudio := make([]byte, len(session.InputAudioBuffer))
+            copy(allAudio, session.InputAudioBuffer)
+            session.AudioBufferLock.Unlock()

-            if len(session.InputAudioBuffer) > 0 {
-
-                if vadState == StateTrailingSilence {
-                    log.Debug().Msgf("VAD detected speech that we can process")
-
-                    // Commit the audio buffer as a conversation item
-                    item := &Item{
-                        ID: generateItemID(),
-                        Object: "realtime.item",
-                        Type: "message",
-                        Status: "completed",
-                        Role: "user",
-                        Content: []ConversationContent{
-                            {
-                                Type: "input_audio",
-                                Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
-                            },
-                        },
-                    }
+            // 2) If there's no audio at all, just continue
+            if len(allAudio) == 0 {
+                continue
+            }

-                    // Add item to conversation
-                    conversation.Lock.Lock()
-                    conversation.Items = append(conversation.Items, item)
-                    conversation.Lock.Unlock()
-
-                    // Reset InputAudioBuffer
-                    session.InputAudioBuffer = nil
-                    session.AudioBufferLock.Unlock()
-
-                    // Send item.created event
-                    sendEvent(c, OutgoingMessage{
-                        Type: "conversation.item.created",
-                        Item: item,
-                    })
-
-                    vadState = StateSilence
-                    segments = []*proto.VADSegment{}
-                    // Generate a response
-                    generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
-                    continue
-                }
+            // 3) Run VAD on the entire audio so far
+            segments, err := runVAD(vadContext, session, allAudio)
+            if err != nil {
+                log.Error().Msgf("failed to process audio: %s", err.Error())
+                sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+                // handle or log error, continue
+                continue
+            }

-                adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
+            segCount := len(segments)

-                // Resample from 24kHz to 16kHz
-                // adata = sound.ResampleInt16(adata, 24000, 16000)
+            if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+                // no speech detected, and we haven't seen a new segment in > 1s
+                // clean up input
+                session.AudioBufferLock.Lock()
+                session.InputAudioBuffer = nil
+                session.AudioBufferLock.Unlock()
+                log.Debug().Msgf("Detected silence for a while, clearing audio buffer")
+                continue
+            }

-                soundIntBuffer := &audio.IntBuffer{
-                    Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
-                }
-                soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
+            // 4) If we see more segments than before => "new speech"
+            if segCount > lastSegmentCount {
+                speaking = true
+                lastSegmentCount = segCount
+                timeOfLastNewSeg = time.Now()
+                log.Debug().Msgf("Detected new speech segment")
+            }

-                /* if len(adata) < 16000 {
-                    log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
-                    session.AudioBufferLock.Unlock()
-                    continue
-                } */
-                float32Data := soundIntBuffer.AsFloat32Buffer().Data
-
-                // TODO: testing wav decoding
-                // dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
-                // buf, err := dec.FullPCMBuffer()
-                // if err != nil {
-                //     //log.Error().Msgf("failed to process audio: %s", err.Error())
-                //     sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-                //     session.AudioBufferLock.Unlock()
-                //     continue
-                // }
-
-                //float32Data = buf.AsFloat32Buffer().Data
-
-                resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
-                    Audio: float32Data,
-                })
-                if err != nil {
-                    log.Error().Msgf("failed to process audio: %s", err.Error())
-                    sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-                    session.AudioBufferLock.Unlock()
-                    continue
-                }
+            // 5) If speaking, but we haven't seen a new segment in > 1s => finalize
+            if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+                log.Debug().Msgf("Detected end of speech segment")
+                // user has presumably stopped talking
+                commitUtterance(allAudio, cfg, evaluator, session, conv, c)
+                // reset state
+                speaking = false
+                lastSegmentCount = 0
+            }
+        }
+    }
+}

-                if len(resp.Segments) == 0 {
-                    log.Debug().Msg("VAD detected no speech activity")
-                    log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
-                    if len(session.InputAudioBuffer) > 16000 {
-                        session.InputAudioBuffer = nil
-                        segments = []*proto.VADSegment{}
-                    }
+func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
+    if len(utt) == 0 {
+        return
+    }
+    // Commit logic: create item, broadcast item.created, etc.
+    item := &Item{
+        ID: generateItemID(),
+        Object: "realtime.item",
+        Type: "message",
+        Status: "completed",
+        Role: "user",
+        Content: []ConversationContent{
+            {
+                Type: "input_audio",
+                Audio: base64.StdEncoding.EncodeToString(utt),
+            },
+        },
+    }
+    conv.Lock.Lock()
+    conv.Items = append(conv.Items, item)
+    conv.Lock.Unlock()

-                    log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
-                } else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
-                    // We have new segments, but we are still speaking
-                    // We need to wait for the trailing silence
+    sendEvent(c, OutgoingMessage{
+        Type: "conversation.item.created",
+        Item: item,
+    })

-                    segments = resp.Segments
+    // Optionally trigger the response generation
+    generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
+}

-                } else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
-                    // We have the same number of segments, but we are still speaking
-                    // We need to check if we are in this state for long enough, update the timer
+// runVAD is a helper that calls your model's VAD method, returning
+// true if it detects speech, false if it detects silence
+func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {

-                    // Check if we have been listening for too long
-                    if time.Since(timeListening) > sendToVADDelay {
-                        vadState = StateTrailingSilence
-                    } else {
+    adata := sound.BytesToInt16sLE(chunk)

-                        timeListening = timeListening.Add(time.Since(timeListening))
-                    }
-                } else {
-                    log.Debug().Msg("VAD detected speech activity")
-                    vadState = StateSpeaking
-                    segments = resp.Segments
-                }
+    // Resample from 24kHz to 16kHz
+    // adata = sound.ResampleInt16(adata, 24000, 16000)

-                session.AudioBufferLock.Unlock()
-            } else {
-                session.AudioBufferLock.Unlock()
-            }
+    soundIntBuffer := &audio.IntBuffer{
+        Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
+    }
+    soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)

-        }
+    /* if len(adata) < 16000 {
+        log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
+        session.AudioBufferLock.Unlock()
+        continue
+    } */
+    float32Data := soundIntBuffer.AsFloat32Buffer().Data
+
+    resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
+        Audio: float32Data,
+    })
+    if err != nil {
+        return nil, err
     }
+
+    // TODO: testing wav decoding
+    // dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
+    // buf, err := dec.FullPCMBuffer()
+    // if err != nil {
+    //     //log.Error().Msgf("failed to process audio: %s", err.Error())
+    //     sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+    //     session.AudioBufferLock.Unlock()
+    //     continue
+    // }
+
+    //float32Data = buf.AsFloat32Buffer().Data
+
+    // If resp.Segments is empty => no speech
+    return resp.Segments, nil
 }

 // Function to generate a response based on the conversation

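Note on the new thresholds: SpeechFramesThreshold and SilenceFramesThreshold are introduced by this commit but not yet referenced by the new polling loop, which currently confirms the start and end of speech from the VAD segment count plus a 1-second timeout. As a rough, hypothetical sketch only (not code from this commit; the speechGate type and Observe method do not exist in LocalAI), the following shows how such consecutive-frame hysteresis could gate start/end-of-speech decisions:

package main

import "fmt"

const (
    speechFramesThreshold = 3  // consecutive speech frames needed to confirm a start
    silenceFramesThreshold = 5 // consecutive silence frames needed to confirm an end
)

// speechGate tracks consecutive per-frame VAD results and only reports a
// start or end of speech once the corresponding threshold is reached.
type speechGate struct {
    speaking      bool
    speechFrames  int
    silenceFrames int
}

// Observe consumes one per-frame VAD result (true = speech detected) and
// reports whether this frame confirms the start or the end of an utterance.
func (g *speechGate) Observe(speech bool) (started, ended bool) {
    if speech {
        g.speechFrames++
        g.silenceFrames = 0
        if !g.speaking && g.speechFrames >= speechFramesThreshold {
            g.speaking = true
            started = true
        }
    } else {
        g.silenceFrames++
        g.speechFrames = 0
        if g.speaking && g.silenceFrames >= silenceFramesThreshold {
            g.speaking = false
            ended = true
        }
    }
    return started, ended
}

func main() {
    // Synthetic per-frame VAD results: a burst of speech followed by silence.
    frames := []bool{false, true, true, true, true, false, false, false, false, false}

    var g speechGate
    for i, f := range frames {
        started, ended := g.Observe(f)
        if started {
            fmt.Printf("frame %d: start of speech confirmed\n", i)
        }
        if ended {
            fmt.Printf("frame %d: end of speech confirmed\n", i)
        }
    }
}

With these values, the start is confirmed on the third consecutive speech frame and the end on the fifth consecutive silence frame, which is the kind of debouncing the comments on the new constants appear to be aiming for.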