Skip to content

Commit

Permalink
Add user_chat capacity to proactive notification (#1441)
Browse files Browse the repository at this point in the history
Issue: #1327

## TODO
- [x] Add user_chat capacity to proactive notification
- [ ] Try with internal logs first
https://us5.datadoghq.com/dashboard/5p3-sb2-q2c/omi-apps-devtool-log?fromUser=false&refresh_mode=sliding&from_ts=1732750755142&to_ts=1732923555142&live=true

## Deploy steps
- [ ] Deploy backend
- [ ] Deploy pusher
- [ ] Deploy plugin
  • Loading branch information
beastoin authored Nov 29, 2024
2 parents b725732 + 02aedca commit 349615a
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 7 deletions.
45 changes: 45 additions & 0 deletions backend/database/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from google.cloud import firestore
from google.cloud.firestore_v1 import FieldFilter

from models.chat import Message
from utils.other.endpoints import timeit
Expand Down Expand Up @@ -47,6 +48,50 @@ def add_summary_message(text: str, uid: str) -> Message:
return ai_message


def get_plugin_messages(uid: str, plugin_id: str, limit: int = 20, offset: int = 0, include_memories: bool = False):
user_ref = db.collection('users').document(uid)
messages_ref = (
user_ref.collection('messages')
.where(filter=FieldFilter('plugin_id', '==', plugin_id))
.order_by('created_at', direction=firestore.Query.DESCENDING)
.limit(limit)
.offset(offset)
)
messages = []
memories_id = set()

# Fetch messages and collect memory IDs
for doc in messages_ref.stream():
message = doc.to_dict()

if message.get('deleted') is True:
continue

messages.append(message)
memories_id.update(message.get('memories_id', []))

if not include_memories:
return messages

# Fetch all memories at once
memories = {}
memories_ref = user_ref.collection('memories')
doc_refs = [memories_ref.document(str(memory_id)) for memory_id in memories_id]
docs = db.get_all(doc_refs)
for doc in docs:
if doc.exists:
memory = doc.to_dict()
memories[memory['id']] = memory

# Attach memories to messages
for message in messages:
message['memories'] = [
memories[memory_id] for memory_id in message.get('memories_id', []) if memory_id in memories
]

return messages


@timeit
def get_messages(uid: str, limit: int = 20, offset: int = 0, include_memories: bool = False):
user_ref = db.collection('users').document(uid)
Expand Down
6 changes: 4 additions & 2 deletions backend/routers/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def get_notification_scopes():
return [
{'title': 'User Name', 'id': 'user_name'},
{'title': 'User Facts', 'id': 'user_facts'},
{'title': 'User Memories', 'id': 'user_context'}
{'title': 'User Memories', 'id': 'user_context'},
{'title': 'User Chat', 'id': 'user_chat'}
]


Expand All @@ -243,7 +244,8 @@ def get_plugin_capabilities():
{'title': 'Notification', 'id': 'proactive_notification', 'scopes': [
{'title': 'User Name', 'id': 'user_name'},
{'title': 'User Facts', 'id': 'user_facts'},
{'title': 'User Memories', 'id': 'user_context'}
{'title': 'User Memories', 'id': 'user_context'},
{'title': 'User Chat', 'id': 'user_chat'}
]}
]

Expand Down
2 changes: 1 addition & 1 deletion backend/utils/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
get_public_approved_apps_db, get_app_by_id_db, get_app_usage_history_db, set_app_review_in_db
from database.redis_db import get_enabled_plugins, get_plugin_installs_count, get_plugin_reviews, get_generic_cache, \
set_generic_cache, set_app_usage_history_cache, get_app_usage_history_cache, get_app_money_made_cache, \
set_app_money_made_cache, set_plugin_review, get_plugins_installs_count, get_plugins_reviews
set_app_money_made_cache, set_plugin_review, get_plugins_installs_count, get_plugins_reviews, get_app_cache_by_id, set_app_cache_by_id
from models.app import App, UsageHistoryItem, UsageHistoryType


Expand Down
5 changes: 4 additions & 1 deletion backend/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def provide_advice_message(uid: str, segments: List[TranscriptSegment], context:
# ************* PROACTIVE NOTIFICATION PLUGIN **************
# **************************************************

def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context: str) -> str:
def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context: str, chat_messages: List[Message]) -> str:
user_name, facts_str = get_prompt_facts(uid)

prompt = plugin_prompt
Expand All @@ -871,6 +871,9 @@ def get_proactive_message(uid: str, plugin_prompt: str, params: [str], context:
if param == "user_context":
prompt = prompt.replace("{{user_context}}", context if context else "")
continue
if param == "user_chat":
prompt = prompt.replace("{{user_chat}}", Message.get_messages_as_string(chat_messages) if chat_messages else "")
continue
prompt = prompt.replace(' ', '').strip()
# print(prompt)

Expand Down
10 changes: 8 additions & 2 deletions backend/utils/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from database import mem_db
from database import redis_db
from database.apps import get_private_apps_db, get_public_apps_db
from database.chat import add_plugin_message
from database.chat import add_plugin_message, get_plugin_messages
from database.plugins import record_plugin_usage
from database.redis_db import get_enabled_plugins, get_plugin_reviews, get_plugin_installs_count, get_generic_cache, \
set_generic_cache, get_plugins_reviews, get_plugins_installs_count
from models.app import App
from models.chat import Message
from models.memory import Memory, MemorySource
from models.notification_message import NotificationMessage
from models.plugin import Plugin, UsageHistoryType
Expand Down Expand Up @@ -303,10 +304,15 @@ def _process_proactive_notification(uid: str, token: str, plugin: App, data):
if len(memories) > 0:
context = Memory.memories_to_string(memories, True)

# messages
messages = []
if 'user_chat' in filter_scopes:
messages = list(reversed([Message(**msg) for msg in get_plugin_messages(uid, plugin.id, limit=10)]))

# print(f'_process_proactive_notification context {context[:100] if context else "empty"}')

# retrive message
message = get_proactive_message(uid, prompt, filter_scopes, context)
message = get_proactive_message(uid, prompt, filter_scopes, context, messages)
if not message or len(message) < min_message_char_limit:
print(f"Plugins {plugin.id}, message too short", uid)
return None
Expand Down
8 changes: 7 additions & 1 deletion plugins/example/basic/mentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def normalize(text):
user_name = "{{user_name}}"
user_facts = "{{user_facts}}"
user_context = "{{user_context}}"
user_chat = "{{user_chat}}"

prompt = f"""
You are an experienced mentor, that helps people achieve their goals during the meeting.
Expand Down Expand Up @@ -74,6 +75,11 @@ def normalize(text):
${transcript}
```
Converstation History:
```
{user_chat}
```
Context:
```
{user_context}
Expand All @@ -86,7 +92,7 @@ def normalize(text):
'session_id': data.session_id,
'notification': {
'prompt': prompt,
'params': ['user_name', 'user_facts', 'user_context'],
'params': ['user_name', 'user_facts', 'user_context', 'user_chat'],
}
}

Expand Down

0 comments on commit 349615a

Please sign in to comment.