Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scope refactor #2111

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move _tag_var to Scope
This moves _tag_var to Scope, so that we have a single ContextVar rather than
multiple.
philandstuff committed Jan 16, 2025
commit ebd390327e188367817f2869066ce46beec5df9e
15 changes: 13 additions & 2 deletions python/cog/server/scope.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Callable, Generator, Optional, Union
from typing import Any, Callable, Generator, Optional, Union

from attrs import frozen
from attrs import evolve, frozen

from ..types import ExperimentalFeatureWarning


@frozen
class Scope:
record_metric: Callable[[str, Union[float, int]], None]
_tag: Optional[str] = None


_current_scope: ContextVar[Optional[Scope]] = ContextVar("scope", default=None)
@@ -37,6 +38,16 @@ def scope(sc: Scope) -> Generator[None, None, None]:
_current_scope.reset(s)


@contextmanager
def evolve_scope(**kwargs: Any) -> Generator[None, None, None]:
new_scope = evolve(current_scope(), **kwargs)
s = _current_scope.set(new_scope)
try:
yield
finally:
_current_scope.reset(s)


def emit_metric(name: str, value: Union[float, int]) -> None:
"""
DEPRECATED: This function will be removed in a future version of cog.
18 changes: 5 additions & 13 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import contextlib
import contextvars
import inspect
import multiprocessing
import os
@@ -58,15 +57,12 @@
InvalidStateException,
)
from .helpers import SimpleStreamRedirector, StreamRedirector
from .scope import Scope, scope
from .scope import Scope, current_scope, evolve_scope, scope

if PYDANTIC_V2:
from .helpers import unwrap_pydantic_serialization_iterators

_spawn = multiprocessing.get_context("spawn")
_tag_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"tag", default=None
)

_PublicEventType = Union[Done, Log, PredictionOutput, PredictionOutputType]

@@ -407,7 +403,7 @@ def __init__(
self._cancelable = False
self._max_concurrency = max_concurrency

# for synchronous predictors only! async predictors use _tag_var instead
# for synchronous predictors only! async predictors use current_scope()._tag instead
self._sync_tag: Optional[str] = None
self._has_async_predictor = is_async

@@ -483,10 +479,8 @@ def record_metric(self, name: str, value: Union[float, int]) -> None:

@property
def _current_tag(self) -> Optional[str]:
# if _tag_var is set, use that (only applies within _apredict())
tag = _tag_var.get()
if tag:
return tag
if self._has_async_predictor:
return current_scope()._tag
return self._sync_tag

def _load_predictor(self) -> Optional[BasePredictor]:
@@ -687,9 +681,7 @@ async def _apredict(
predict: Callable[..., Any],
redirector: SimpleStreamRedirector,
) -> None:
_tag_var.set(tag)

with self._handle_predict_error(redirector, tag=tag):
with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag):
future_result = predict(**payload)

if future_result: