Skip to content

Commit 1e08ef0

Browse files
author
Chen188
committed
* adapting code with TEN
* pass transcribe and polly init param when invoking start api; * update transcribe_asr graph to display chat in playground; * other code improvements.
1 parent a1090a3 commit 1e08ef0

File tree

9 files changed

+237
-69
lines changed

9 files changed

+237
-69
lines changed

agents/property.json

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@
441441
"region": "us-east-1",
442442
"access_key": "$AWS_ACCESS_KEY_ID",
443443
"secret_key": "$AWS_SECRET_ACCESS_KEY",
444-
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
444+
"model": "$AWS_BEDROCK_MODEL",
445445
"max_tokens": 512,
446446
"prompt": "",
447447
"greeting": "ASTRA agent connected. How can i help you today?",
@@ -1008,7 +1008,7 @@
10081008
"region": "us-east-1",
10091009
"access_key": "$AWS_ACCESS_KEY_ID",
10101010
"secret_key": "$AWS_SECRET_ACCESS_KEY",
1011-
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
1011+
"model": "$AWS_BEDROCK_MODEL",
10121012
"max_tokens": 512,
10131013
"prompt": "",
10141014
"greeting": "ASTRA agent connected. How can i help you today?",
@@ -1036,6 +1036,12 @@
10361036
"addon": "interrupt_detector_python",
10371037
"name": "interrupt_detector"
10381038
},
1039+
{
1040+
"type": "extension",
1041+
"extension_group": "transcriber",
1042+
"addon": "message_collector",
1043+
"name": "message_collector"
1044+
},
10391045
{
10401046
"type": "extension_group",
10411047
"addon": "default_extension_group",
@@ -1067,6 +1073,10 @@
10671073
{
10681074
"extension_group": "bedrock",
10691075
"extension": "bedrock_llm"
1076+
},
1077+
{
1078+
"extension_group": "transcriber",
1079+
"extension": "message_collector"
10701080
}
10711081
]
10721082
}
@@ -1082,6 +1092,30 @@
10821092
{
10831093
"extension_group": "tts",
10841094
"extension": "polly_tts"
1095+
},
1096+
{
1097+
"extension_group": "transcriber",
1098+
"extension": "message_collector",
1099+
"cmd_conversions": [
1100+
{
1101+
"cmd": {
1102+
"type": "per_property",
1103+
"keep_original": true,
1104+
"rules": [
1105+
{
1106+
"path": "is_final",
1107+
"type": "fixed_value",
1108+
"value": "bool(true)"
1109+
},
1110+
{
1111+
"path": "stream_id",
1112+
"type": "fixed_value",
1113+
"value": "uint32(999)"
1114+
}
1115+
]
1116+
}
1117+
}
1118+
]
10851119
}
10861120
]
10871121
}
@@ -1124,6 +1158,21 @@
11241158
}
11251159
]
11261160
},
1161+
{
1162+
"extension_group": "transcriber",
1163+
"extension": "message_collector",
1164+
"data": [
1165+
{
1166+
"name": "data",
1167+
"dest": [
1168+
{
1169+
"extension_group": "default",
1170+
"extension": "agora_rtc"
1171+
}
1172+
]
1173+
}
1174+
]
1175+
},
11271176
{
11281177
"extension_group": "default",
11291178
"extension": "interrupt_detector",
@@ -1158,7 +1207,7 @@
11581207
"remote_stream_id": 123,
11591208
"subscribe_audio": true,
11601209
"publish_audio": true,
1161-
"publish_data": false,
1210+
"publish_data": true,
11621211
"enable_agora_asr": false,
11631212
"agora_asr_vendor_name": "microsoft",
11641213
"agora_asr_language": "en-US",
@@ -1189,7 +1238,7 @@
11891238
"region": "us-east-1",
11901239
"access_key": "$AWS_ACCESS_KEY_ID",
11911240
"secret_key": "$AWS_SECRET_ACCESS_KEY",
1192-
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
1241+
"model": "$AWS_BEDROCK_MODEL",
11931242
"max_tokens": 512,
11941243
"prompt": "",
11951244
"greeting": "ASTRA agent connected. How can i help you today?",
@@ -1217,6 +1266,12 @@
12171266
"addon": "interrupt_detector_python",
12181267
"name": "interrupt_detector"
12191268
},
1269+
{
1270+
"type": "extension",
1271+
"extension_group": "transcriber",
1272+
"addon": "message_collector",
1273+
"name": "message_collector"
1274+
},
12201275
{
12211276
"type": "extension_group",
12221277
"addon": "default_extension_group",
@@ -1297,6 +1352,10 @@
12971352
{
12981353
"extension_group": "bedrock",
12991354
"extension": "bedrock_llm"
1355+
},
1356+
{
1357+
"extension_group": "transcriber",
1358+
"extension": "message_collector"
13001359
}
13011360
]
13021361
}
@@ -1312,6 +1371,30 @@
13121371
{
13131372
"extension_group": "tts",
13141373
"extension": "polly_tts"
1374+
},
1375+
{
1376+
"extension_group": "transcriber",
1377+
"extension": "message_collector",
1378+
"cmd_conversions": [
1379+
{
1380+
"cmd": {
1381+
"type": "per_property",
1382+
"keep_original": true,
1383+
"rules": [
1384+
{
1385+
"path": "is_final",
1386+
"type": "fixed_value",
1387+
"value": "bool(true)"
1388+
},
1389+
{
1390+
"path": "stream_id",
1391+
"type": "fixed_value",
1392+
"value": "uint32(999)"
1393+
}
1394+
]
1395+
}
1396+
}
1397+
]
13151398
}
13161399
]
13171400
}
@@ -1354,6 +1437,21 @@
13541437
}
13551438
]
13561439
},
1440+
{
1441+
"extension_group": "transcriber",
1442+
"extension": "message_collector",
1443+
"data": [
1444+
{
1445+
"name": "data",
1446+
"dest": [
1447+
{
1448+
"extension_group": "default",
1449+
"extension": "agora_rtc"
1450+
}
1451+
]
1452+
}
1453+
]
1454+
},
13571455
{
13581456
"extension_group": "default",
13591457
"extension": "interrupt_detector",
@@ -2161,4 +2259,4 @@
21612259
}
21622260
]
21632261
}
2164-
}
2262+
}

agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,27 @@ def on_start(self, ten: TenEnv) -> None:
136136

137137
# Send greeting if available
138138
if greeting:
139-
try:
140-
output_data = Data.create("text_data")
141-
output_data.set_property_string(
142-
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting
143-
)
144-
output_data.set_property_bool(
145-
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True
146-
)
147-
ten.send_data(output_data)
148-
logger.info(f"greeting [{greeting}] sent")
149-
except Exception as err:
150-
logger.info(f"greeting [{greeting}] send failed, err: {err}")
139+
logger.info(f'sending greeting: [{greeting}]')
140+
self.send_data(ten=ten, sentence=greeting, end_of_segment=True, input_text='greeting')
141+
151142
ten.on_start_done()
152143

144+
def send_data(self, ten, sentence: str, end_of_segment: bool, input_text: str) -> bool:
    """Emit one ``text_data`` message carrying *sentence*.

    Args:
        ten: Environment object whose ``send_data`` method emits the message.
        sentence: Text stored under DATA_OUT_TEXT_DATA_PROPERTY_TEXT.
        end_of_segment: Flag stored under
            DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT.
        input_text: Text that triggered this output; used only in log messages.

    Returns:
        True when the message was handed to ``ten.send_data`` without error,
        False when building or sending it raised. Failures are logged with a
        traceback and never propagated, so callers that need to stop on
        failure must check the return value.
    """
    try:
        output_data = Data.create("text_data")
        output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
        output_data.set_property_bool(
            DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment
        )
        ten.send_data(output_data)
    except Exception as err:
        # Best-effort send: report the failure instead of raising.
        logger.exception(
            f"for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}"
        )
        return False
    else:
        logger.info(
            f"for input text: [{input_text}] {'end of segment ' if end_of_segment else ''}sent sentence [{sentence}]"
        )
        return True
159+
153160
def on_stop(self, ten: TenEnv) -> None:
154161
logger.info("BedrockLLMExtension on_stop")
155162
ten.on_stop_done()
@@ -294,24 +301,12 @@ def converse_stream_worker(start_time, input_text, memory):
294301
)
295302

296303
# send sentence
297-
try:
298-
output_data = Data.create("text_data")
299-
output_data.set_property_string(
300-
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence
301-
)
302-
output_data.set_property_bool(
303-
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False
304-
)
305-
ten.send_data(output_data)
306-
logger.info(
307-
f"GetConverseStream recv for input text: [{input_text}] sent sentence [{sentence}]"
308-
)
309-
except Exception as err:
310-
logger.info(
311-
f"GetConverseStream recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}"
312-
)
313-
break
314-
304+
self.send_data(
305+
ten=ten,
306+
sentence=sentence,
307+
end_of_segment=False,
308+
input_text=input_text,
309+
)
315310
sentence = ""
316311
if not first_sentence_sent:
317312
first_sentence_sent = True
@@ -335,22 +330,7 @@ def converse_stream_worker(start_time, input_text, memory):
335330
return
336331

337332
# send end of segment
338-
try:
339-
output_data = Data.create("text_data")
340-
output_data.set_property_string(
341-
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence
342-
)
343-
output_data.set_property_bool(
344-
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True
345-
)
346-
ten.send_data(output_data)
347-
logger.info(
348-
f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] sent"
349-
)
350-
except Exception as err:
351-
logger.info(
352-
f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {err}"
353-
)
333+
self.send_data(ten=ten, sentence=sentence, end_of_segment=True, input_text=input_text)
354334

355335
except Exception as e:
356336
logger.info(

agents/ten_packages/extension/polly_tts/polly_tts_extension.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def on_start(self, ten: TenEnv) -> None:
6363
f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {polly_config.__getattribute__(optional_param)}"
6464
)
6565

66+
polly_config.validate()
67+
6668
self.polly = PollyWrapper(polly_config)
6769
self.frame_size = int(
6870
int(polly_config.sample_rate)

agents/ten_packages/extension/polly_tts/polly_wrapper.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,29 @@
77

88
from .log import logger
99

10+
# Engine identifiers referenced by the voice tables below.
ENGINE_STANDARD = "standard"
ENGINE_NEURAL = "neural"
ENGINE_GENERATIVE = "generative"
ENGINE_LONG_FORM = "long-form"

# Engines accepted for each known voice. PollyConfig.validate() falls back to
# the first entry when the configured engine is not listed.
VOICE_ENGINE_MAP = {
    "Zhiyu": [ENGINE_STANDARD, ENGINE_NEURAL],
    "Matthew": [ENGINE_GENERATIVE, ENGINE_NEURAL],
    "Ruth": [ENGINE_GENERATIVE, ENGINE_NEURAL, ENGINE_LONG_FORM],
}

# Language codes accepted for each known voice; the first entry is the fallback.
VOICE_LANG_MAP = {
    "Zhiyu": ["cmn-CN"],
    "Matthew": ["en-US"],
    "Ruth": ["en-US"],
}

# Maps accepted spellings of a language code to the code used in
# VOICE_LANG_MAP (e.g. 'zh-CN' is treated as 'cmn-CN').
LANGCODE_MAP = {
    "cmn-CN": "cmn-CN",
    "zh-CN": "cmn-CN",
    "en-US": "en-US",
}
32+
1033
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/polly/client/synthesize_speech.html
1134
class PollyConfig:
1235
def __init__(self,
@@ -30,6 +53,23 @@ def __init__(self,
3053
self.audio_format = 'pcm' # 'json'|'mp3'|'ogg_vorbis'|'pcm'
3154
self.include_visemes = False
3255

56+
def validate(self):
    """Check the voice/engine/language settings and normalise them in place.

    The voice must be a key of VOICE_ENGINE_MAP. An engine or language code
    that is not listed for that voice is replaced by the first entry listed
    for it, and a warning is logged. A non-empty language code is first
    translated through LANGCODE_MAP (unknown codes are left as they are).

    Raises:
        ValueError: If ``self.voice`` is not a key of VOICE_ENGINE_MAP.
    """
    if self.voice not in VOICE_ENGINE_MAP:
        err_msg = f"Invalid voice '{self.voice}'. Must be one of {list(VOICE_ENGINE_MAP.keys())}."
        logger.error(err_msg)
        raise ValueError(err_msg)

    supported_engines = VOICE_ENGINE_MAP[self.voice]
    if self.engine not in supported_engines:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning(f"Invalid engine '{self.engine}' for voice '{self.voice}'. Must be one of {supported_engines}. Fallback to {supported_engines[0]}")
        self.engine = supported_engines[0]

    if self.lang_code:
        self.lang_code = LANGCODE_MAP.get(self.lang_code, self.lang_code)

    supported_langs = VOICE_LANG_MAP[self.voice]
    if self.lang_code not in supported_langs:
        # An unset (falsy) lang_code also lands here and gets the voice's first language.
        logger.warning(f"Invalid language code '{self.lang_code}' for voice '{self.voice}'. Must be one of {supported_langs}. Fallback to {supported_langs[0]}")
        self.lang_code = supported_langs[0]
72+
3373
@classmethod
3474
def default_config(cls):
3575
return cls(
@@ -172,4 +212,4 @@ def get_voices(self, engine, language_code):
172212
vo["Name"]: vo["Id"]
173213
for vo in self.voice_metadata
174214
if engine in vo["SupportedEngines"] and language_code == vo["LanguageCode"]
175-
}
215+
}

0 commit comments

Comments
 (0)