Commit

wip
nherment committed Jan 13, 2025
1 parent f8c330c commit 310f5c2
Showing 19 changed files with 341 additions and 88 deletions.
1 change: 0 additions & 1 deletion examples/custom_llm.py
@@ -39,7 +39,6 @@ def completion(self, messages: List[Dict[str, Any]], tools: Optional[List[Tool]]


def ask_holmes():
console = Console()

prompt = "what issues do I have in my cluster"

44 changes: 29 additions & 15 deletions holmes/config.py
@@ -8,7 +8,6 @@

from pydantic import FilePath, SecretStr, Field
from pydash.arrays import concat
from rich.console import Console


from holmes.core.runbooks import RunbookManager
@@ -35,6 +34,7 @@
from holmes.core.tools import YAMLToolset
from holmes.common.env_vars import ROBUSTA_CONFIG_PATH
from holmes.utils.definitions import RobustaConfig
from holmes.core.perf_timing import PerfTiming

DEFAULT_CONFIG_LOCATION = os.path.expanduser("~/.holmes/config.yaml")

@@ -133,20 +133,18 @@ def __get_cluster_name() -> Optional[str]:
return None

def create_console_tool_executor(
self, console: Console, allowed_toolsets: ToolsetPattern, dal:Optional[SupabaseDal]
self, allowed_toolsets: ToolsetPattern, dal:Optional[SupabaseDal]
) -> ToolExecutor:
"""
Creates ToolExecutor for the cli
"""
default_toolsets = [toolset for toolset in load_builtin_toolsets(dal, grafana_config=self.grafana) if any(tag in (ToolsetTag.CORE, ToolsetTag.CLI) for tag in toolset.tags)]

if allowed_toolsets == "*":
matching_toolsets = default_toolsets
else:
matching_toolsets = get_matching_toolsets(
default_toolsets, allowed_toolsets.split(",")
)

# Enable all matching toolsets that have CORE or CLI tag
for toolset in matching_toolsets:
toolset.enabled = True
@@ -184,17 +182,20 @@ def create_console_tool_executor(
return ToolExecutor(enabled_toolsets)

def create_tool_executor(
self, console: Console, dal:Optional[SupabaseDal]
self, dal:Optional[SupabaseDal]
) -> ToolExecutor:
t = PerfTiming("create_tool_executor")
"""
Creates ToolExecutor for the server endpoints
"""

all_toolsets = load_builtin_toolsets(dal=dal, grafana_config=self.grafana)
t.measure("load_builtin_toolsets")

if os.path.isfile(CUSTOM_TOOLSET_LOCATION):
try:
all_toolsets.extend(load_toolsets_from_file(CUSTOM_TOOLSET_LOCATION, silent_fail=True))
t.measure(f"load_toolsets_from_file {CUSTOM_TOOLSET_LOCATION}")
except Exception as error:
logging.error(f"An error happened while trying to use custom toolset: {error}")

@@ -203,22 +204,29 @@ def create_tool_executor(
logging.debug(
f"Starting AI session with tools: {[t.name for t in enabled_tools]}"
)
return ToolExecutor(enabled_toolsets)
t.measure("merge toolsets")
tool_executor = ToolExecutor(enabled_toolsets)
t.measure("instantiate ToolExecutor")
t.end()
return tool_executor

def create_console_toolcalling_llm(
self, console: Console, allowed_toolsets: ToolsetPattern, dal:Optional[SupabaseDal] = None
self, allowed_toolsets: ToolsetPattern, dal:Optional[SupabaseDal] = None
) -> ToolCallingLLM:
tool_executor = self.create_console_tool_executor(console, allowed_toolsets, dal)
tool_executor = self.create_console_tool_executor(allowed_toolsets, dal)
return ToolCallingLLM(
tool_executor,
self.max_steps,
self._get_llm()
)

def create_toolcalling_llm(
self, console: Console, dal:Optional[SupabaseDal] = None
self, dal:Optional[SupabaseDal] = None
) -> ToolCallingLLM:
tool_executor = self.create_tool_executor(console, dal)
t = PerfTiming("create_toolcalling_llm")
tool_executor = self.create_tool_executor(dal)
t.measure("create_tool_executor")
t.end()
return ToolCallingLLM(
tool_executor,
self.max_steps,
@@ -227,25 +235,31 @@ def create_toolcalling_llm(

def create_issue_investigator(
self,
console: Console,
dal: Optional[SupabaseDal] = None
) -> IssueInvestigator:
t = PerfTiming("create_issue_investigator")
all_runbooks = load_builtin_runbooks()
t.measure("load_builtin_runbooks")
for runbook_path in self.custom_runbooks:
all_runbooks.extend(load_runbooks_from_file(runbook_path))
t.measure("custom_runbooks -> load_runbooks_from_file")

runbook_manager = RunbookManager(all_runbooks)
tool_executor = self.create_tool_executor(console, dal)
return IssueInvestigator(
t.measure("RunbookManager()")
tool_executor = self.create_tool_executor(dal)
t.measure("create_tool_executor")
issue_investigator = IssueInvestigator(
tool_executor,
runbook_manager,
self.max_steps,
self._get_llm()
)
t.measure("IssueInvestigator()")
t.end()
return issue_investigator

def create_console_issue_investigator(
self,
console: Console,
allowed_toolsets: ToolsetPattern,
dal: Optional[SupabaseDal] = None
) -> IssueInvestigator:
@@ -254,7 +268,7 @@ def create_console_issue_investigator(
all_runbooks.extend(load_runbooks_from_file(runbook_path))

runbook_manager = RunbookManager(all_runbooks)
tool_executor = self.create_console_tool_executor(console, allowed_toolsets, dal)
tool_executor = self.create_console_tool_executor(allowed_toolsets, dal)
return IssueInvestigator(
tool_executor,
runbook_manager,
15 changes: 12 additions & 3 deletions holmes/core/investigation.py
@@ -1,31 +1,39 @@

from typing import Optional
from rich.console import Console
from holmes.common.env_vars import HOLMES_POST_PROCESSING_PROMPT
from holmes.config import Config
from holmes.core.issue import Issue
from holmes.core.models import InvestigateRequest, InvestigationResult
from holmes.core.supabase_dal import SupabaseDal
from holmes.utils.robusta import load_robusta_api_key
from holmes.core.perf_timing import PerfTiming


def investigate_issues(investigate_request: InvestigateRequest, dal: SupabaseDal, config: Config, console:Console):
def investigate_issues(investigate_request: InvestigateRequest, dal: SupabaseDal, config: Config, console:Optional[Console] = None):
t = PerfTiming("investigate_issues")
load_robusta_api_key(dal=dal, config=config)
context = dal.get_issue_data(
investigate_request.context.get("robusta_issue_id")
)
t.measure("get_issue_data")

resource_instructions = dal.get_resource_instructions(
"alert", investigate_request.context.get("issue_type")
)
t.measure("dal.get_resource_instructions")
global_instructions = dal.get_global_instructions_for_account()
t.measure("dal.get_global_instructions_for_account")

raw_data = investigate_request.model_dump()
t.measure("investigate_request.model_dump")
if context:
raw_data["extra_context"] = context

ai = config.create_issue_investigator(
console, dal=dal
dal=dal
)
t.measure("config.create_issue_investigator")
issue = Issue(
id=context["id"] if context else "",
name=investigate_request.title,
@@ -42,7 +50,8 @@ def investigate_issues(investigate_request: InvestigateRequest, dal: SupabaseDal
instructions=resource_instructions,
global_instructions=global_instructions
)

t.measure("ai.investigate")
t.end()
return InvestigationResult(
analysis=investigation.result,
tool_calls=investigation.tool_calls or [],
63 changes: 60 additions & 3 deletions holmes/core/llm.py
@@ -10,8 +10,17 @@
from pydantic import BaseModel
import litellm
import os
import sys
import json
import sys
from types import ModuleType, FunctionType
from gc import get_referents
from holmes.common.env_vars import ROBUSTA_AI, ROBUSTA_API_ENDPOINT

from types import ModuleType, FunctionType
from gc import get_referents
from holmes.core.perf_timing import PerfTiming, log_function_timing


def environ_get_safe_int(env_var, default="0"):
try:
@@ -22,6 +31,10 @@ def environ_get_safe_int(env_var, default="0"):
OVERRIDE_MAX_OUTPUT_TOKEN = environ_get_safe_int("OVERRIDE_MAX_OUTPUT_TOKEN")
OVERRIDE_MAX_CONTENT_SIZE = environ_get_safe_int("OVERRIDE_MAX_CONTENT_SIZE")

cache = dict()
cache_hit = 0
cache_miss = 0

class LLM:

@abstractmethod
@@ -40,6 +53,32 @@ def count_tokens_for_message(self, messages: list[dict]) -> int:
def completion(self, messages: List[Dict[str, Any]], tools: Optional[List[Tool]] = [], tool_choice: Optional[Union[str, dict]] = None, response_format: Optional[Union[dict, Type[BaseModel]]] = None, temperature:Optional[float] = None, drop_params: Optional[bool] = None) -> ModelResponse:
pass

def hash_messages(messages:Any) -> int:
return hash(json.dumps(messages, sort_keys=True))


# Custom objects know their class.
# Function objects seem to know way too much, including modules.
# Exclude modules as well.
BLACKLIST = type, ModuleType, FunctionType


def getsize(obj):
"""sum size of object & members."""
if isinstance(obj, BLACKLIST):
raise TypeError('getsize() does not take argument of type: '+ str(type(obj)))
seen_ids = set()
size = 0
objects = [obj]
while objects:
need_referents = []
for obj in objects:
if not isinstance(obj, BLACKLIST) and id(obj) not in seen_ids:
seen_ids.add(id(obj))
size += sys.getsizeof(obj)
need_referents.append(obj)
objects = get_referents(*need_referents)
return size

class DefaultLLM(LLM):

@@ -100,12 +139,12 @@ def check_llm(self, model:str, api_key:Optional[str]):
"https://docs.litellm.ai/docs/providers/watsonx#usage---models-in-deployment-spaces"
)
else:
#
#
api_key_env_var = f"{provider.upper()}_API_KEY"
if api_key:
os.environ[api_key_env_var] = api_key
model_requirements = litellm.validate_environment(model=model)

if not model_requirements["keys_in_environment"]:
raise Exception(f"model {model} requires the following environment variables: {model_requirements['missing_keys']}")

@@ -146,7 +185,24 @@ def count_tokens_for_message(self, messages: list[dict]) -> int:
return litellm.token_counter(model=self.model,
messages=messages)

@log_function_timing
def completion(self, messages: List[Dict[str, Any]], tools: Optional[List[Tool]] = [], tool_choice: Optional[Union[str, dict]] = None, response_format: Optional[Union[dict, Type[BaseModel]]] = None, temperature:Optional[float] = None, drop_params: Optional[bool] = None) -> ModelResponse:
# hash_val = hash_messages(messages)
# global cache
# global cache_hit
# global cache_miss
# cache_value = None
# if hash_val in cache:
# cache_hit = cache_hit + 1
# cache_value = cache.get(hash_val)
# else:
# cache_miss = cache_miss + 1

# print(f"(*)(*) cache hit rate = {round(cache_hit/(cache_hit+cache_miss)*100)}%. cvache size = {round(getsize(cache)/1024)}MB")

# if cache_value:
# return cache_value
t = PerfTiming("llm.completion")
result = litellm.completion(
model=self.model,
api_key=self.api_key,
@@ -158,7 +214,8 @@ def completion(self, messages: List[Dict[str, Any]], tools: Optional[List[Tool]]
response_format=response_format,
drop_params=drop_params
)

t.end()
# cache[hash_val] = result
if isinstance(result, ModelResponse):
return result
else:
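The completion cache wired into completion() above is left commented out in this wip commit. Below is a minimal standalone sketch of the idea the commented code outlines, reusing its hash_messages approach and hit/miss counters; fake_completion and cached_completion are hypothetical stand-ins, not part of the diff, and the real call would go through litellm.completion.

```python
import json

cache: dict = {}
cache_hit = 0
cache_miss = 0


def hash_messages(messages) -> int:
    # Same idea as the helper above: a stable hash of the message list.
    return hash(json.dumps(messages, sort_keys=True))


def fake_completion(messages) -> str:
    # Hypothetical stand-in for the real LLM call.
    return f"echo: {messages[-1]['content']}"


def cached_completion(messages) -> str:
    global cache_hit, cache_miss
    key = hash_messages(messages)
    if key in cache:
        cache_hit += 1
        return cache[key]
    cache_miss += 1
    result = fake_completion(messages)
    cache[key] = result
    return result


if __name__ == "__main__":
    msgs = [{"role": "user", "content": "what issues do I have in my cluster"}]
    cached_completion(msgs)  # miss: calls the stand-in and stores the result
    cached_completion(msgs)  # hit: served from the in-memory cache
    print(f"hit rate: {cache_hit / (cache_hit + cache_miss):.0%}")
```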
48 changes: 48 additions & 0 deletions holmes/core/perf_timing.py
@@ -0,0 +1,48 @@
import time
import logging
from contextlib import contextmanager

from functools import wraps

class PerfTiming:
def __init__(self, name):
self.ended = False

self.name = name
self.start_time = time.time()
self.last_measure_time = self.start_time
self.last_measure_label = "Start"
self.timings = []

def measure(self, label):
if self.ended:
raise Exception("cannot measure a perf timing that is already ended")
current_time = time.time()

time_since_start = int((current_time - self.start_time) * 1000)
time_since_last = int((current_time - self.last_measure_time) * 1000)

self.timings.append((label, time_since_last, time_since_start))

self.last_measure_time = current_time
self.last_measure_label = label

def end(self):
self.ended = True
current_time = time.time()
time_since_start = int((current_time - self.start_time) * 1000)
message = f'{self.name}(TOTAL) {time_since_start}ms'
logging.info(message)
for label, time_since_last, time_since_start in self.timings:
logging.info(f' {self.name}({label}) +{time_since_last}ms {time_since_start}ms')

def log_function_timing(func):
@wraps(func)
def function_timing_wrapper(*args, **kwargs):
start_time = time.perf_counter()
result = func(*args, **kwargs)
end_time = time.perf_counter()
total_time = int((end_time - start_time) * 1000)
logging.info(f'Function "{func.__name__}()" took {total_time}ms')
return result
return function_timing_wrapper
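
A minimal usage sketch of the new PerfTiming helper and the log_function_timing decorator, following the pattern this commit applies in config.create_tool_executor; build_pipeline and load_data are illustrative names only, not part of the diff.

```python
import logging
import time

from holmes.core.perf_timing import PerfTiming, log_function_timing

logging.basicConfig(level=logging.INFO)


@log_function_timing
def load_data():
    # Logs 'Function "load_data()" took ...ms' when it returns.
    time.sleep(0.05)  # stand-in for real work


def build_pipeline():
    t = PerfTiming("build_pipeline")
    load_data()
    t.measure("load_data")
    time.sleep(0.02)  # stand-in for another step
    t.measure("post_processing")
    t.end()  # logs the total plus each measured step, in milliseconds


build_pipeline()
```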