Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions fiftyone/factory/repos/delegated_operation_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ExecutionRunState,
ExecutionProgress,
)
from fiftyone.operators.types import Pipeline
from fiftyone.operators.types import Pipeline, PipelineRunInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,8 +63,8 @@ def __init__(

# distributed task fields
self.parent_id = None # Only on children
self.pipeline_index = 0
self.pipeline = None # Only on pipeline parent
self.pipeline_run_info = None # Only on pipeline parent

@property
def num_distributed_tasks(self):
Expand Down Expand Up @@ -130,8 +130,10 @@ def from_pymongo(self, doc: dict):
if "updated_at" in doc["status"]:
self.status.updated_at = doc["status"]["updated_at"]

if "pipeline" in doc and doc["pipeline"] is not None:
self.pipeline = Pipeline.from_json(doc["pipeline"])
self.pipeline = Pipeline.from_json(doc.get("pipeline"))
self.pipeline_run_info = PipelineRunInfo.from_json(
doc.get("pipeline_run_info")
)

return self

Expand All @@ -154,5 +156,7 @@ def to_pymongo(self) -> dict:
}
if self.pipeline:
d["pipeline"] = self.pipeline.to_json()
if self.pipeline_run_info:
d["pipeline_run_info"] = self.pipeline_run_info.to_json()

return d
54 changes: 50 additions & 4 deletions fiftyone/operators/_types/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
|
"""
import dataclasses
from typing import Any, Mapping, Optional
from typing import Any, List, Mapping, Optional


@dataclasses.dataclass
Expand All @@ -26,16 +26,33 @@ class PipelineStage:
operator_uri: str

# Optional
always_run: bool = False
name: Optional[str] = None
num_distributed_tasks: Optional[int] = None
params: Optional[Mapping[str, Any]] = None

def __init__(
    self,
    operator_uri: str,
    always_run: bool = False,
    name: Optional[str] = None,
    num_distributed_tasks: Optional[int] = None,
    params: Optional[Mapping[str, Any]] = None,
    **kwargs,  # unknown keys are accepted and discarded
):
    """Initializes a pipeline stage.

    A hand-written ``__init__`` is used instead of the dataclass-generated
    one so that unknown keyword arguments (e.g. extra keys present in
    documents serialized by newer versions) are silently discarded rather
    than raising a ``TypeError``.

    Args:
        operator_uri: the URI of the operator to execute; must be non-empty
        always_run: whether to always run this stage — presumably even when
            earlier stages did not succeed (TODO confirm against executor)
        name: an optional name for the stage
        num_distributed_tasks: an optional number of distributed tasks for
            this stage
        params: an optional dict of parameters for the operator
        **kwargs: additional keyword arguments, which are ignored
    """
    self.operator_uri = operator_uri
    self.always_run = always_run
    self.name = name
    self.num_distributed_tasks = num_distributed_tasks
    self.params = params

    # Apply the same validation the dataclass-generated __init__ would
    self.__post_init__()

def __post_init__(self):
if not self.operator_uri:
raise ValueError("operator_uri must be a non-empty string")

if self.num_distributed_tasks is not None:
self.num_distributed_tasks = int(self.num_distributed_tasks)
if (
self.num_distributed_tasks is not None
and self.num_distributed_tasks < 1
Expand Down Expand Up @@ -63,14 +80,20 @@ class Pipeline:

stages: list[PipelineStage] = dataclasses.field(default_factory=list)

def __init__(
    self,
    stages: Optional[list[PipelineStage]] = None,
    **kwargs,  # unknown keys are accepted and discarded
):
    """Initializes a pipeline.

    A hand-written ``__init__`` is used instead of the dataclass-generated
    one so that unknown keyword arguments (e.g. extra keys present in
    documents serialized by newer versions) are silently discarded rather
    than raising a ``TypeError``.

    Args:
        stages: an optional list of :class:`PipelineStage` instances
        **kwargs: additional keyword arguments, which are ignored
    """
    self.stages = stages if stages is not None else []

def stage(
self,
operator_uri,
name=None,
num_distributed_tasks=None,
params=None,
# kwargs accepted for forward compatibility
**kwargs # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Adds a stage to the end of the pipeline.

Expand Down Expand Up @@ -109,6 +132,8 @@ def from_json(cls, json_dict):
Args:
json_dict: a JSON / python dict representation of the pipeline
"""
if isinstance(json_dict, list):
json_dict = {"stages": json_dict}
stages = [
PipelineStage(**stage) for stage in json_dict.get("stages") or []
]
Expand All @@ -128,3 +153,24 @@ def to_json(self):
JSON / python dict representation of the pipeline
"""
return dataclasses.asdict(self)


@dataclasses.dataclass
class PipelineRunInfo:
    """Information about a pipeline run.

    Unlike the pipeline definition, this class is considered mutable.
    """

    # Whether the pipeline run is still active
    active: bool = True
    # Per-stage expected child counts — presumably the number of child
    # tasks expected for each stage; TODO confirm against the executor
    expected_children: Optional[List[int]] = None
    # Index of the current pipeline stage
    stage_index: int = 0

    @classmethod
    def from_json(cls, doc: Optional[dict]) -> Optional["PipelineRunInfo"]:
        """Builds a :class:`PipelineRunInfo` from its JSON/dict representation.

        Args:
            doc: a dict as produced by :meth:`to_json`, or ``None``

        Returns:
            a :class:`PipelineRunInfo`, or ``None`` if ``doc`` is ``None``
        """
        if doc is None:
            return None

        # Discard unknown keys for forward compatibility, mirroring the
        # lenient parsing of ``Pipeline`` and ``PipelineStage``
        known = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in doc.items() if k in known})

    def to_json(self) -> dict:
        """Returns a JSON / python dict representation of this instance."""
        return dataclasses.asdict(self)
2 changes: 1 addition & 1 deletion fiftyone/operators/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
|
"""

from ._types.pipeline import Pipeline, PipelineStage
from ._types.pipeline import Pipeline, PipelineRunInfo, PipelineStage
from ._types.types import *
20 changes: 17 additions & 3 deletions tests/unittests/factory/delegated_operation_doc_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,27 @@ def test_serialize_pipeline(self):
op_doc = repos.DelegatedOperationDocument()

op_doc.pipeline = Pipeline(
[
PipelineStage(name="one", operator_uri="@test/op1"),
PipelineStage(name="two", operator_uri="@test/op2"),
stages=[
PipelineStage(
name="one",
operator_uri="@test/op1",
num_distributed_tasks=5,
params={"foo": "bar"},
),
PipelineStage(
name="two", operator_uri="@test/op2", always_run=True
),
]
)
op_doc.pipeline_run_info = (
repos.delegated_operation_doc.PipelineRunInfo(
stage_index=5, active=False, expected_children=[1, 2, 3, 4]
)
)
out = op_doc.to_pymongo()
assert out["pipeline"] == op_doc.pipeline.to_json()
assert out["pipeline_run_info"] == op_doc.pipeline_run_info.to_json()
op_doc2 = repos.DelegatedOperationDocument()
op_doc2.from_pymongo(out)
assert op_doc2.pipeline == op_doc.pipeline
assert op_doc2.pipeline_run_info == op_doc.pipeline_run_info
11 changes: 9 additions & 2 deletions tests/unittests/operators/delegated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,15 @@ def test_delegate_operation(self, mock_load_dataset, mock_get_operator):

pipeline = Pipeline(
[
PipelineStage(name="one", operator_uri="@test/op1"),
PipelineStage(name="two", operator_uri="@test/op2"),
PipelineStage(
name="one",
operator_uri="@test/op1",
num_distributed_tasks=5,
params={"foo": "bar"},
),
PipelineStage(
name="two", operator_uri="@test/op2", always_run=True
),
]
)
doc = self.svc.queue_operation(
Expand Down
19 changes: 19 additions & 0 deletions tests/unittests/operators/types_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_serialize(self):
name="stage2",
num_distributed_tasks=5,
params={"foo": "bar"},
always_run=True,
),
]
)
Expand All @@ -65,12 +66,14 @@ def test_serialize(self):
"name": None,
"num_distributed_tasks": None,
"params": None,
"always_run": False,
},
{
"operator_uri": "my/uri2",
"name": "stage2",
"num_distributed_tasks": 5,
"params": {"foo": "bar"},
"always_run": True,
},
],
},
Expand All @@ -93,3 +96,19 @@ def test_validation(self):
pipe = types.Pipeline()
with self.assertRaises(ValueError):
pipe.stage("my/uri", num_distributed_tasks=-5)

def test_pipeline_run_info(self):
    """Verifies that PipelineRunInfo round-trips through its JSON form."""
    expected = {
        "active": False,
        "stage_index": 2,
        "expected_children": [1, 2],
    }
    run_info = types.PipelineRunInfo(
        active=False, stage_index=2, expected_children=[1, 2]
    )

    serialized = run_info.to_json()
    self.assertEqual(serialized, expected)

    # Deserializing the dict must reproduce an equal object
    round_tripped = types.PipelineRunInfo.from_json(serialized)
    self.assertEqual(round_tripped, run_info)
Loading