20 changes: 15 additions & 5 deletions fiftyone/factory/repos/delegated_operation_doc.py
@@ -15,7 +15,7 @@
ExecutionRunState,
ExecutionProgress,
)
from fiftyone.operators.types import Pipeline
from fiftyone.operators.types import Pipeline, PipelineRunInfo

logger = logging.getLogger(__name__)

@@ -63,8 +63,8 @@ def __init__(

# distributed task fields
self.parent_id = None # Only on children
self.pipeline_index = 0
self.pipeline = None # Only on pipeline parent
self.pipeline_run_info = None # Only on pipeline parent

@property
def num_distributed_tasks(self):
@@ -130,8 +130,10 @@ def from_pymongo(self, doc: dict):
if "updated_at" in doc["status"]:
self.status.updated_at = doc["status"]["updated_at"]

if "pipeline" in doc and doc["pipeline"] is not None:
self.pipeline = Pipeline.from_json(doc["pipeline"])
self.pipeline = Pipeline.from_json(doc.get("pipeline"))
self.pipeline_run_info = PipelineRunInfo.from_json(
doc.get("pipeline_run_info")
)

return self

@@ -142,7 +144,13 @@ def to_pymongo(self) -> dict:
# try to copy because it may contain big, complicated, non-serializable
# objects that may cause issues with copying.

ignore_keys = {"_doc", "id", "context", "pipeline"}
ignore_keys = {
"_doc",
"id",
"context",
"pipeline",
"pipeline_run_info",
}
d = {
k: copy.deepcopy(v)
for k, v in self.__dict__.items()
@@ -154,5 +162,7 @@
}
if self.pipeline:
d["pipeline"] = self.pipeline.to_json()
if self.pipeline_run_info:
d["pipeline_run_info"] = self.pipeline_run_info.to_json()

return d
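
A minimal round-trip sketch of the serialization above, assuming the document class from this file; the operator URI and field values are hypothetical:

from fiftyone.factory.repos.delegated_operation_doc import (
    DelegatedOperationDocument,
)
from fiftyone.operators.types import Pipeline, PipelineRunInfo

doc = DelegatedOperationDocument()
doc.pipeline = Pipeline()
doc.pipeline.stage("@test/op1", name="one")  # hypothetical operator URI
doc.pipeline_run_info = PipelineRunInfo(active=True, stage_index=0)

d = doc.to_pymongo()
# both fields are stored as plain JSON dicts, not dataclass instances
assert d["pipeline"] == doc.pipeline.to_json()
assert d["pipeline_run_info"] == doc.pipeline_run_info.to_json()

# from_pymongo() rebuilds the dataclasses; from_json() is None-safe,
# so older documents without these fields still load
doc2 = DelegatedOperationDocument().from_pymongo(d)
assert doc2.pipeline_run_info == doc.pipeline_run_info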
68 changes: 64 additions & 4 deletions fiftyone/operators/_types/pipeline.py
@@ -6,7 +6,7 @@
|
"""
import dataclasses
from typing import Any, Mapping, Optional
from typing import Any, List, Mapping, Optional


@dataclasses.dataclass
@@ -26,16 +26,38 @@ class PipelineStage:
operator_uri: str

# Optional
always_run: bool = False
name: Optional[str] = None
num_distributed_tasks: Optional[int] = None
params: Optional[Mapping[str, Any]] = None

# Custom __init__ that accepts and discards unknown kwargs for forward compatibility
def __init__(
self,
operator_uri: str,
always_run: bool = False,
name: Optional[str] = None,
num_distributed_tasks: Optional[int] = None,
params: Optional[Mapping[str, Any]] = None,
**_,
):
# Manually initialize the defined dataclass fields (replaces the generated __init__)
self.operator_uri = operator_uri
self.always_run = always_run
self.name = name
self.num_distributed_tasks = num_distributed_tasks
self.params = params
self.__post_init__()

def __post_init__(self):
if not self.operator_uri:
raise ValueError("operator_uri must be a non-empty string")

if self.num_distributed_tasks is not None:
self.num_distributed_tasks = int(self.num_distributed_tasks)
self.num_distributed_tasks = (
int(self.num_distributed_tasks)
if self.num_distributed_tasks is not None
else None
)
if (
self.num_distributed_tasks is not None
and self.num_distributed_tasks < 1
@@ -63,19 +85,29 @@ class Pipeline:

stages: list[PipelineStage] = dataclasses.field(default_factory=list)

# Custom __init__ that accepts and discards unknown kwargs for forward compatibility
def __init__(self, stages: Optional[list[PipelineStage]] = None, **kwargs):
# Manually initialize the defined dataclass fields (replaces the generated __init__)
self.stages = stages if stages is not None else []
# kwargs are implicitly discarded

def stage(
self,
operator_uri,
always_run=False,
name=None,
num_distributed_tasks=None,
params=None,
# kwargs accepted for forward compatibility
**kwargs # pylint: disable=unused-argument
**kwargs,
):
"""Adds a stage to the end of the pipeline.
Args:
operator_uri: the URI of the operator to use for the stage
always_run: if True, this stage runs even when the pipeline
is inactive (e.g., after a failure), enabling
cleanup/finalization stages
name: the name of the stage
num_distributed_tasks: the number of distributed tasks to use
for the stage, optional
@@ -88,9 +120,11 @@
"""
stage = PipelineStage(
operator_uri=operator_uri,
always_run=always_run,
name=name,
num_distributed_tasks=num_distributed_tasks,
params=params,
**kwargs,
)
self.stages.append(stage)
return stage
@@ -109,6 +143,11 @@ def from_json(cls, json_dict):
Args:
json_dict: a JSON / python dict representation of the pipeline
"""
if json_dict is None:
return None

if isinstance(json_dict, list):
json_dict = {"stages": json_dict}
stages = [
PipelineStage(**stage) for stage in json_dict.get("stages") or []
]
@@ -128,3 +167,24 @@
JSON / python dict representation of the pipeline
"""
return dataclasses.asdict(self)


@dataclasses.dataclass
class PipelineRunInfo:
"""Information about a pipeline run.
Unlike the pipeline definition, this class is considered mutable.
"""

active: bool = True
expected_children: Optional[List[int]] = None
Contributor:
Seems fine to me, just curious: when/how is this expected_children going to be used?

Contributor:
Never mind, I should have read the description better:

expected_children: either None if active, otherwise a list whose length matches len(pipeline.stages), containing the number of children we should expect to see for each stage. Stages that have been skipped are 0.

Still curious where this will be set/read, though.
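
For illustration (hypothetical values): a finished three-stage pipeline whose middle stage was skipped might carry the following run info.

from fiftyone.operators.types import PipelineRunInfo

# Hypothetical example: three-stage pipeline, middle stage skipped
run_info = PipelineRunInfo(
    active=False,  # the run is no longer active
    expected_children=[4, 0, 1],  # one count per stage; skipped stage -> 0
    stage_index=2,
)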

stage_index: int = 0

@classmethod
def from_json(cls, doc: dict):
if doc is None:
return None
return cls(**doc)

def to_json(self):
return dataclasses.asdict(self)
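
A short usage sketch of the new surface, assuming the code above; the operator URIs and the future_field kwarg are hypothetical:

from fiftyone.operators.types import Pipeline, PipelineStage

pipeline = Pipeline()
pipeline.stage("@test/compute", num_distributed_tasks=4)
# always_run stages still execute after an earlier stage fails,
# which is what enables cleanup/finalization stages
pipeline.stage("@test/cleanup", always_run=True)

# unknown kwargs are accepted and discarded, so stage documents written
# by newer versions can still be loaded by this code
stage = PipelineStage(operator_uri="@test/op", future_field=123)
assert not hasattr(stage, "future_field")

# from_json() also tolerates a bare list of stage dicts
assert Pipeline.from_json(pipeline.to_json()["stages"]) == pipeline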
2 changes: 1 addition & 1 deletion fiftyone/operators/types.py
@@ -6,5 +6,5 @@
|
"""

from ._types.pipeline import Pipeline, PipelineStage
from ._types.pipeline import Pipeline, PipelineRunInfo, PipelineStage
from ._types.types import *
20 changes: 17 additions & 3 deletions tests/unittests/factory/delegated_operation_doc_tests.py
@@ -63,13 +63,27 @@ def test_serialize_pipeline(self):
op_doc = repos.DelegatedOperationDocument()

op_doc.pipeline = Pipeline(
[
PipelineStage(name="one", operator_uri="@test/op1"),
PipelineStage(name="two", operator_uri="@test/op2"),
stages=[
PipelineStage(
name="one",
operator_uri="@test/op1",
num_distributed_tasks=5,
params={"foo": "bar"},
),
PipelineStage(
name="two", operator_uri="@test/op2", always_run=True
),
]
)
op_doc.pipeline_run_info = (
repos.delegated_operation_doc.PipelineRunInfo(
stage_index=5, active=False, expected_children=[1, 2, 3, 4]
)
)
out = op_doc.to_pymongo()
assert out["pipeline"] == op_doc.pipeline.to_json()
assert out["pipeline_run_info"] == op_doc.pipeline_run_info.to_json()
op_doc2 = repos.DelegatedOperationDocument()
op_doc2.from_pymongo(out)
assert op_doc2.pipeline == op_doc.pipeline
assert op_doc2.pipeline_run_info == op_doc.pipeline_run_info
11 changes: 9 additions & 2 deletions tests/unittests/operators/delegated_tests.py
@@ -206,8 +206,15 @@ def test_delegate_operation(self, mock_load_dataset, mock_get_operator):

pipeline = Pipeline(
[
PipelineStage(name="one", operator_uri="@test/op1"),
PipelineStage(name="two", operator_uri="@test/op2"),
PipelineStage(
name="one",
operator_uri="@test/op1",
num_distributed_tasks=5,
params={"foo": "bar"},
),
PipelineStage(
name="two", operator_uri="@test/op2", always_run=True
),
]
)
doc = self.svc.queue_operation(
31 changes: 28 additions & 3 deletions tests/unittests/operators/types_tests.py
@@ -29,6 +29,7 @@ def test_pipeline_type(self):
name="stage2",
num_distributed_tasks=5,
params={"foo": "bar"},
always_run=True,
)
pipeline = types.Pipeline(stages=[stage1, stage2])
self.assertListEqual(pipeline.stages, [stage1, stage2])
@@ -37,9 +38,10 @@
pipeline.stage(stage1.operator_uri)
pipeline.stage(
stage2.operator_uri,
stage2.name,
stage2.num_distributed_tasks,
stage2.params,
always_run=stage2.always_run,
name=stage2.name,
num_distributed_tasks=stage2.num_distributed_tasks,
params=stage2.params,
)
self.assertListEqual(pipeline.stages, [stage1, stage2])

@@ -52,6 +54,7 @@ def test_serialize(self):
name="stage2",
num_distributed_tasks=5,
params={"foo": "bar"},
always_run=True,
),
]
)
@@ -65,12 +68,14 @@
"name": None,
"num_distributed_tasks": None,
"params": None,
"always_run": False,
},
{
"operator_uri": "my/uri2",
"name": "stage2",
"num_distributed_tasks": 5,
"params": {"foo": "bar"},
"always_run": True,
},
],
},
@@ -93,3 +98,23 @@ def test_validation(self):
pipe = types.Pipeline()
with self.assertRaises(ValueError):
pipe.stage("my/uri", num_distributed_tasks=-5)

def test_none(self):
self.assertIsNone(types.Pipeline.from_json(None))
self.assertIsNone(types.PipelineRunInfo.from_json(None))

def test_pipeline_run_info(self):
run_info = types.PipelineRunInfo(
active=False, stage_index=2, expected_children=[1, 2]
)
dict_rep = run_info.to_json()
self.assertEqual(
dict_rep,
{
"active": False,
"stage_index": 2,
"expected_children": [1, 2],
},
)
new_obj = types.PipelineRunInfo.from_json(dict_rep)
self.assertEqual(new_obj, run_info)