Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scope refactor #2111

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions python/cog/server/scope.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Callable, Generator, Optional, Union
from typing import Any, Callable, Generator, Optional, Union

from attrs import evolve, frozen

from ..types import ExperimentalFeatureWarning


@frozen
class Scope:
def __init__(
self,
*,
record_metric: Callable[[str, Union[float, int]], None],
) -> None:
self.record_metric = record_metric
record_metric: Callable[[str, Union[float, int]], None]
_tag: Optional[str] = None


_current_scope: ContextVar[Optional[Scope]] = ContextVar("scope", default=None)
Expand All @@ -24,6 +23,10 @@ def current_scope() -> Scope:
category=ExperimentalFeatureWarning,
stacklevel=1,
)
return _get_current_scope()


def _get_current_scope() -> Scope:
s = _current_scope.get()
if s is None:
raise RuntimeError("No scope available")
Expand All @@ -39,6 +42,16 @@ def scope(sc: Scope) -> Generator[None, None, None]:
_current_scope.reset(s)


@contextmanager
def evolve_scope(**kwargs: Any) -> Generator[None, None, None]:
new_scope = evolve(_get_current_scope(), **kwargs)
s = _current_scope.set(new_scope)
try:
yield
finally:
_current_scope.reset(s)


def emit_metric(name: str, value: Union[float, int]) -> None:
"""
DEPRECATED: This function will be removed in a future version of cog.
Expand Down
18 changes: 5 additions & 13 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import contextlib
import contextvars
import inspect
import multiprocessing
import os
Expand Down Expand Up @@ -58,15 +57,12 @@
InvalidStateException,
)
from .helpers import SimpleStreamRedirector, StreamRedirector
from .scope import Scope, scope
from .scope import Scope, _get_current_scope, evolve_scope, scope

if PYDANTIC_V2:
from .helpers import unwrap_pydantic_serialization_iterators

_spawn = multiprocessing.get_context("spawn")
_tag_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"tag", default=None
)

_PublicEventType = Union[Done, Log, PredictionOutput, PredictionOutputType]

Expand Down Expand Up @@ -407,7 +403,7 @@ def __init__(
self._cancelable = False
self._max_concurrency = max_concurrency

# for synchronous predictors only! async predictors use _tag_var instead
# for synchronous predictors only! async predictors use current_scope()._tag instead
self._sync_tag: Optional[str] = None
self._has_async_predictor = is_async

Expand Down Expand Up @@ -483,10 +479,8 @@ def record_metric(self, name: str, value: Union[float, int]) -> None:

@property
def _current_tag(self) -> Optional[str]:
# if _tag_var is set, use that (only applies within _apredict())
tag = _tag_var.get()
if tag:
return tag
if self._has_async_predictor:
return _get_current_scope()._tag
return self._sync_tag

def _load_predictor(self) -> Optional[BasePredictor]:
Expand Down Expand Up @@ -687,9 +681,7 @@ async def _apredict(
predict: Callable[..., Any],
redirector: SimpleStreamRedirector,
) -> None:
_tag_var.set(tag)

with self._handle_predict_error(redirector, tag=tag):
with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag):
future_result = predict(**payload)

if future_result:
Expand Down
48 changes: 24 additions & 24 deletions test-integration/test_integration/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def assert_versions_match(semver_version: str, pep440_version: str):
)

# Check base release version
assert (
semver_release == pep440_groups["release"]
), f"Release versions do not match: {semver_release} != {pep440_groups['release']}"
assert semver_release == pep440_groups["release"], (
f"Release versions do not match: {semver_release} != {pep440_groups['release']}"
)

# Check prerelease status
semver_pre = semver_groups["prerelease"]
Expand All @@ -67,35 +67,35 @@ def assert_versions_match(semver_version: str, pep440_version: str):

if semver_pre:
if semver_pre.startswith("alpha"):
assert (
pep440_groups["pre_l"] == "a"
), "Alpha pre-release status does not match"
assert not pep440_groups[
"dev"
], "Semver pre-release cannot also be a PEP440 dev build"
assert pep440_groups["pre_l"] == "a", (
"Alpha pre-release status does not match"
)
assert not pep440_groups["dev"], (
"Semver pre-release cannot also be a PEP440 dev build"
)

if semver_pre.startswith("beta"):
assert (
pep440_groups["pre_l"] == "b"
), "Beta pre-release status does not match"
assert not pep440_groups[
"dev"
], "Semver pre-release cannot also be a PEP440 dev build"
assert pep440_groups["pre_l"] == "b", (
"Beta pre-release status does not match"
)
assert not pep440_groups["dev"], (
"Semver pre-release cannot also be a PEP440 dev build"
)

if semver_pre.startswith("rc"):
assert (
pep440_groups["pre_l"] == "rc"
), "Release candidate pre-release status does not match"
assert not pep440_groups[
"dev"
], "Semver pre-release cannot also be a PEP440 dev build"
assert pep440_groups["pre_l"] == "rc", (
"Release candidate pre-release status does not match"
)
assert not pep440_groups["dev"], (
"Semver pre-release cannot also be a PEP440 dev build"
)

if semver_pre.startswith("dev"):
assert pep440_groups["dev_l"] == "dev", "Dev build status does not match"

assert (
semver_groups["buildmetadata"] == pep440_groups["local"]
), f"Local/build metadata component does not match: {semver_groups['buildmetadata']} != {pep440_groups['local']}"
assert semver_groups["buildmetadata"] == pep440_groups["local"], (
f"Local/build metadata component does not match: {semver_groups['buildmetadata']} != {pep440_groups['local']}"
)


def random_string(length):
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ commands =
[testenv:lint]
base_python = python3.12
skip_install = true
deps = ruff
deps = ruff==0.9.1
commands =
ruff check python/cog
ruff format --check python
Expand Down
Loading