-
Notifications
You must be signed in to change notification settings - Fork 226
/
Copy pathwatsonx.py
116 lines (94 loc) · 3.91 KB
/
watsonx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import asyncio
import json
import sys
import traceback
from pydantic import BaseModel, Field
from beeai_framework import ToolMessage
from beeai_framework.adapters.watsonx.backend.chat import WatsonxChatModel
from beeai_framework.backend.chat import ChatModel
from beeai_framework.backend.message import MessageToolResultContent, UserMessage
from beeai_framework.cancellation import AbortSignal
from beeai_framework.errors import AbortError, FrameworkError
from beeai_framework.tools.weather.openmeteo import OpenMeteoTool
# Settings can be passed here during initialization or pre-configured via
# environment variables (WATSONX_PROJECT_ID, WATSONX_API_KEY, WATSONX_URL).
# Module-level model shared by the example coroutines below.
llm = WatsonxChatModel(
    "ibm/granite-3-8b-instruct",
    # settings={
    #     "project_id": "WATSONX_PROJECT_ID",
    #     "api_key": "WATSONX_API_KEY",
    #     "api_base": "WATSONX_API_URL",
    # },
)
async def watsonx_from_name() -> None:
    """Build a watsonx chat model through the generic factory and print one reply.

    Demonstrates ``ChatModel.from_name`` with a provider-prefixed model id
    instead of instantiating ``WatsonxChatModel`` directly.
    """
    # Credentials may alternatively be supplied inline as a second argument:
    #     {
    #         "project_id": "WATSONX_PROJECT_ID",
    #         "api_key": "WATSONX_API_KEY",
    #         "api_base": "WATSONX_API_URL",
    #     }
    model = ChatModel.from_name("watsonx:ibm/granite-3-8b-instruct")
    reply = await model.create(messages=[UserMessage("what states are part of New England?")])
    print(reply.get_text_content())
async def watsonx_sync() -> None:
    """Send a single question to the module-level model and print the answer."""
    question = UserMessage("what is the capital of Massachusetts?")
    reply = await llm.create(messages=[question])
    print(reply.get_text_content())
async def watsonx_stream() -> None:
    """Request a streamed completion, then print the fully assembled text."""
    question = UserMessage("How many islands make up the country of Cape Verde?")
    # stream=True makes the backend deliver the reply incrementally; the awaited
    # result still exposes the complete text once streaming finishes.
    reply = await llm.create(messages=[question], stream=True)
    print(reply.get_text_content())
async def watsonx_stream_abort() -> None:
    """Stream a reply under a 0.5 s abort signal and report how it ended."""
    question = UserMessage("What is the smallest of the Cape Verde islands?")
    try:
        reply = await llm.create(
            messages=[question],
            stream=True,
            abort_signal=AbortSignal.timeout(0.5),
        )
    except AbortError as err:
        # The signal fired before the model finished.
        print(f"Aborted: {err}")
    else:
        if reply is None:
            print("No response returned.")
        else:
            print(reply.get_text_content())
async def watson_structure() -> None:
    """Ask for a schema-validated response and print the parsed object."""

    class TestSchema(BaseModel):
        # Shape the model's output must conform to.
        answer: str = Field(description="your final answer")

    question = UserMessage("How many islands make up the country of Cape Verde?")
    structured = await llm.create_structure(schema=TestSchema, messages=[question])
    print(structured.object)
async def watson_tool_calling() -> None:
    """Run one full tool-calling round trip.

    First pass: the model is offered a weather tool and is expected to emit a
    tool call. We execute the tool locally, wrap its output in a tool-result
    message, and feed it back so the model can produce a final prose answer.
    """
    model = ChatModel.from_name("watsonx:ibm/granite-3-8b-instruct")
    question = UserMessage("What is the current weather in Boston?")
    weather = OpenMeteoTool()

    # Pass 1: the model should answer with a tool call rather than text.
    first = await model.create(messages=[question], tools=[weather])
    call = first.get_tool_calls()[0]
    print(call.model_dump())

    # Execute the requested tool with the JSON arguments the model produced.
    tool_output = await weather.run(json.loads(call.args))
    result_msg = ToolMessage(
        MessageToolResultContent(
            result=tool_output.get_text_content(),
            tool_name=call.tool_name,
            tool_call_id=call.id,
        )
    )
    print(result_msg.to_plain())

    # Pass 2: supply the tool result (no tools offered) for the final answer.
    final = await model.create(messages=[question, result_msg], tools=[])
    print(final.get_text_content())
async def main() -> None:
    """Run every watsonx example in order, printing a banner before each."""
    examples = (
        ("watsonx_from_name", watsonx_from_name),
        ("watsonx_sync", watsonx_sync),
        ("watsonx_stream", watsonx_stream),
        ("watsonx_stream_abort", watsonx_stream_abort),
        ("watson_structure", watson_structure),
        ("watson_tool_calling", watson_tool_calling),
    )
    for label, example in examples:
        print("*" * 10, label)
        await example()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except FrameworkError as e:
        # Show the full traceback for debugging, then exit with the framework's
        # human-readable explanation as the process exit message.
        traceback.print_exc()
        sys.exit(e.explain())