From c374ed5a48ed90ae56315754309f8d120e2c54f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 30 Nov 2024 05:49:50 +0700 Subject: [PATCH 1/2] Add user_chat capacity to proactive notification --- backend/routers/apps.py | 6 ++++-- backend/utils/llm.py | 5 ++++- backend/utils/plugins.py | 10 ++++++++-- plugins/example/basic/mentor.py | 8 +++++++- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/backend/routers/apps.py b/backend/routers/apps.py index a5b1b518d..0cd52f1dc 100644 --- a/backend/routers/apps.py +++ b/backend/routers/apps.py @@ -227,7 +227,8 @@ def get_notification_scopes(): return [ {'title': 'User Name', 'id': 'user_name'}, {'title': 'User Facts', 'id': 'user_facts'}, - {'title': 'User Memories', 'id': 'user_context'} + {'title': 'User Memories', 'id': 'user_context'}, + {'title': 'User Chat', 'id': 'user_chat'} ] @@ -243,7 +244,8 @@ def get_plugin_capabilities(): {'title': 'Notification', 'id': 'proactive_notification', 'scopes': [ {'title': 'User Name', 'id': 'user_name'}, {'title': 'User Facts', 'id': 'user_facts'}, - {'title': 'User Memories', 'id': 'user_context'} + {'title': 'User Memories', 'id': 'user_context'}, + {'title': 'User Chat', 'id': 'user_chat'} ]} ] diff --git a/backend/utils/llm.py b/backend/utils/llm.py index 0bb50bdd5..a71ce3a7a 100644 --- a/backend/utils/llm.py +++ b/backend/utils/llm.py @@ -857,7 +857,7 @@ def provide_advice_message(uid: str, segments: List[TranscriptSegment], context: # ************* PROACTIVE NOTIFICATION PLUGIN ************** # ************************************************** -def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context: str) -> str: +def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context: str, chat_messages: List[Message]) -> str: user_name, facts_str = get_prompt_facts(uid) prompt = plugin_prompt @@ -871,6 +871,9 @@ def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context: if param == "user_context": prompt = prompt.replace("{{user_context}}", context if context else "") continue + if param == "user_chat": + prompt = prompt.replace("{{user_chat}}", Message.get_messages_as_string(chat_messages) if chat_messages else "") + continue prompt = prompt.replace(' ', '').strip() # print(prompt) diff --git a/backend/utils/plugins.py b/backend/utils/plugins.py index 7ee560c54..b91340d1e 100644 --- a/backend/utils/plugins.py +++ b/backend/utils/plugins.py @@ -8,11 +8,12 @@ from database import mem_db from database import redis_db from database.apps import get_private_apps_db, get_public_apps_db -from database.chat import add_plugin_message +from database.chat import add_plugin_message, get_messages from database.plugins import record_plugin_usage from database.redis_db import get_enabled_plugins, get_plugin_reviews, get_plugin_installs_count, get_generic_cache, \ set_generic_cache, get_plugins_reviews, get_plugins_installs_count from models.app import App +from models.chat import Message from models.memory import Memory, MemorySource from models.notification_message import NotificationMessage from models.plugin import Plugin, UsageHistoryType @@ -303,10 +304,15 @@ def _process_proactive_notification(uid: str, token: str, plugin: App, data): if len(memories) > 0: context = Memory.memories_to_string(memories, True) + # messages + messages = [] + if 'user_chat' in filter_scopes: + messages = list(reversed([Message(**msg) for msg in get_messages(uid, limit=10)])) + # print(f'_process_proactive_notification context {context[:100] if context else "empty"}') # retrive message - message = get_proactive_message(uid, prompt, filter_scopes, context) + message = get_proactive_message(uid, prompt, filter_scopes, context, messages) if not message or len(message) < min_message_char_limit: print(f"Plugins {plugin.id}, message too short", uid) return None diff --git a/plugins/example/basic/mentor.py b/plugins/example/basic/mentor.py index 0ebd1509d..c5e5ec429 100644 --- a/plugins/example/basic/mentor.py +++ b/plugins/example/basic/mentor.py @@ -44,6 +44,7 @@ def normalize(text): user_name = "{{user_name}}" user_facts = "{{user_facts}}" user_context = "{{user_context}}" + user_chat = "{{user_chat}}" prompt = f""" You are an experienced mentor, that helps people achieve their goals during the meeting. @@ -74,6 +75,11 @@ def normalize(text): ${transcript} ``` + Converstation History: + ``` + {user_chat} + ``` + Context: ``` {user_context} @@ -86,7 +92,7 @@ def normalize(text): 'session_id': data.session_id, 'notification': { 'prompt': prompt, - 'params': ['user_name', 'user_facts', 'user_context'], + 'params': ['user_name', 'user_facts', 'user_context', 'user_chat'], } } From 02aedcaa60c77943188f6b623e69c4cb587a6583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=E1=BB=8Bnh?= Date: Sat, 30 Nov 2024 06:34:51 +0700 Subject: [PATCH 2/2] Get plugin messages instead of all messages --- backend/database/chat.py | 45 ++++++++++++++++++++++++++++++++++++++++ backend/utils/apps.py | 2 +- backend/utils/plugins.py | 4 ++-- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index 1404ce2eb..48da4daf5 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -3,6 +3,7 @@ from typing import Optional from google.cloud import firestore +from google.cloud.firestore_v1 import FieldFilter from models.chat import Message from utils.other.endpoints import timeit @@ -47,6 +48,50 @@ def add_summary_message(text: str, uid: str) -> Message: return ai_message +def get_plugin_messages(uid: str, plugin_id: str, limit: int = 20, offset: int = 0, include_memories: bool = False): + user_ref = db.collection('users').document(uid) + messages_ref = ( + user_ref.collection('messages') + .where(filter=FieldFilter('plugin_id', '==', plugin_id)) + .order_by('created_at', direction=firestore.Query.DESCENDING) + .limit(limit) + .offset(offset) + ) + messages = [] + memories_id = set() + + # Fetch messages and collect memory IDs + for doc in messages_ref.stream(): + message = doc.to_dict() + + if message.get('deleted') is True: + continue + + messages.append(message) + memories_id.update(message.get('memories_id', [])) + + if not include_memories: + return messages + + # Fetch all memories at once + memories = {} + memories_ref = user_ref.collection('memories') + doc_refs = [memories_ref.document(str(memory_id)) for memory_id in memories_id] + docs = db.get_all(doc_refs) + for doc in docs: + if doc.exists: + memory = doc.to_dict() + memories[memory['id']] = memory + + # Attach memories to messages + for message in messages: + message['memories'] = [ + memories[memory_id] for memory_id in message.get('memories_id', []) if memory_id in memories + ] + + return messages + + @timeit def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: bool = False): user_ref = db.collection('users').document(uid) diff --git a/backend/utils/apps.py b/backend/utils/apps.py index cfd5f1d09..7050c060e 100644 --- a/backend/utils/apps.py +++ b/backend/utils/apps.py @@ -6,7 +6,7 @@ get_public_approved_apps_db, get_app_by_id_db, get_app_usage_history_db, set_app_review_in_db from database.redis_db import get_enabled_plugins, get_plugin_installs_count, get_plugin_reviews, get_generic_cache, \ set_generic_cache, set_app_usage_history_cache, get_app_usage_history_cache, get_app_money_made_cache, \ - set_app_money_made_cache, set_plugin_review, get_plugins_installs_count, get_plugins_reviews + set_app_money_made_cache, set_plugin_review, get_plugins_installs_count, get_plugins_reviews, get_app_cache_by_id, set_app_cache_by_id from models.app import App, UsageHistoryItem, UsageHistoryType diff --git a/backend/utils/plugins.py b/backend/utils/plugins.py index b91340d1e..be107e4ad 100644 --- a/backend/utils/plugins.py +++ b/backend/utils/plugins.py @@ -8,7 +8,7 @@ from database import mem_db from database import redis_db from database.apps import get_private_apps_db, get_public_apps_db -from database.chat import add_plugin_message, get_messages +from database.chat import add_plugin_message, get_plugin_messages from database.plugins import record_plugin_usage from database.redis_db import get_enabled_plugins, get_plugin_reviews, get_plugin_installs_count, get_generic_cache, \ set_generic_cache, get_plugins_reviews, get_plugins_installs_count @@ -307,7 +307,7 @@ def _process_proactive_notification(uid: str, token: str, plugin: App, data): # messages messages = [] if 'user_chat' in filter_scopes: - messages = list(reversed([Message(**msg) for msg in get_messages(uid, limit=10)])) + messages = list(reversed([Message(**msg) for msg in get_plugin_messages(uid, plugin.id, limit=10)])) # print(f'_process_proactive_notification context {context[:100] if context else "empty"}')