Skip to content

Commit

Permalink
fix for spmd classes that rely on batch functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Nov 22, 2024
1 parent 912f559 commit 60cbc71
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
19 changes: 19 additions & 0 deletions onedal/common/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# ==============================================================================

import logging
from contextlib import contextmanager
from types import MethodType
from typing import Any, Callable, Literal, Optional

from onedal import Backend, _default_backend, _spmd_backend
Expand Down Expand Up @@ -59,6 +61,23 @@ def _get_policy(self, queue: Any, *data: Any) -> Any:
return _get_policy


@contextmanager
def DefaultPolicyOverride(instance: Any):
original_method = getattr(instance, "_get_policy", None)
try:
# Inject the new _get_policy method from _default_backend
new_policy_method = inject_policy_manager(_default_backend)
bound_method = MethodType(new_policy_method, instance)
setattr(instance, "_get_policy", bound_method)
yield
finally:
# Restore the original _get_policy method
if original_method is not None:
setattr(instance, "_get_policy", original_method)
else:
delattr(instance, "_get_policy")


def bind_default_backend(module_name: str, lookup_name: Optional[str] = None):
def decorator(method: Callable[..., Any]):
# grab the lookup_name from outer scope
Expand Down
7 changes: 6 additions & 1 deletion onedal/spmd/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# ==============================================================================


from ...common._backend import bind_spmd_backend
from ...common._backend import DefaultPolicyOverride, bind_spmd_backend
from ...covariance import (
IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance,
)
Expand All @@ -27,3 +27,8 @@ def _get_policy(self, queue, *data): ...

@bind_spmd_backend("covariance")
def finalize_compute(self, policy, params, partial_result): ...

def partial_fit(self, X, y=None, queue=None):
# partial fit performed by parent backend, therefore default policy required
with DefaultPolicyOverride(self):
return super().partial_fit(X, y, queue)
13 changes: 11 additions & 2 deletions onedal/spmd/decomposition/incremental_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
# limitations under the License.
# ==============================================================================

from ...common._backend import bind_spmd_backend
from ...common._backend import (
DefaultPolicyOverride,
bind_default_backend,
bind_spmd_backend,
)
from ...decomposition import IncrementalPCA as base_IncrementalPCA


Expand All @@ -26,8 +30,13 @@ class IncrementalPCA(base_IncrementalPCA):
API is the same as for `onedal.decomposition.IncrementalPCA`
"""

@bind_spmd_backend("decomposition")
@bind_spmd_backend("decomposition", lookup_name="_get_policy")
def _get_policy(self, queue, *data): ...

@bind_spmd_backend("decomposition.dim_reduction")
def finalize_train(self, policy, params, partial_result): ...

def partial_fit(self, X, queue):
# partial fit performed by parent backend, therefore default policy required
with DefaultPolicyOverride(self):
return super().partial_fit(X, queue)
7 changes: 6 additions & 1 deletion onedal/spmd/linear_model/incremental_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# ==============================================================================


from ...common._backend import bind_spmd_backend
from ...common._backend import DefaultPolicyOverride, bind_spmd_backend
from ...linear_model import (
IncrementalLinearRegression as base_IncrementalLinearRegression,
)
Expand All @@ -33,3 +33,8 @@ def _get_policy(self): ...

@bind_spmd_backend("linear_model.regression")
def finalize_train(self, *args, **kwargs): ...

def partial_fit(self, X, y, queue):
# partial fit performed by parent backend, therefore default policy required
with DefaultPolicyOverride(self):
return super().partial_fit(X, y, queue)

0 comments on commit 60cbc71

Please sign in to comment.