diff --git a/.envrc b/.envrc index 7aafde3..453b8a2 100644 --- a/.envrc +++ b/.envrc @@ -3,5 +3,6 @@ strict_env layout node +dotenv_if_exists use nix diff --git a/.github/workflows/changeset-check.yml b/.github/workflows/changeset-check.yml index 1004400..96bfe6a 100644 --- a/.github/workflows/changeset-check.yml +++ b/.github/workflows/changeset-check.yml @@ -51,14 +51,24 @@ jobs: typescript: - 'typescript/**' + - name: Prepare Python package to use Changesets + if: steps.filter.outputs.python == 'true' + working-directory: python + run: | + echo "Initializing pnpm project..." + touch pnpm-workspace.yaml + pnpm init + echo "Checking .changeset:" + ls -la .changeset + - name: Check for Python changesets # Only run this step if python files were changed if: steps.filter.outputs.python == 'true' id: python-check - uses: changesets/action@v1 + uses: changesets/action@v1.4.9 continue-on-error: true with: - cwd: ./python + cwd: '${{ github.workspace }}/python' version: false publish: false setupGitUser: true @@ -70,10 +80,10 @@ jobs: # Only run this step if typescript files were changed if: steps.filter.outputs.typescript == 'true' id: typescript-check - uses: changesets/action@v1 + uses: changesets/action@v1.4.9 continue-on-error: true with: - cwd: ./typescript + cwd: '${{ github.workspace }}/typescript' version: false publish: false setupGitUser: true diff --git a/.gitignore b/.gitignore index e858225..f6689dc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +python/README.md + # OS generated files .DS_Store .DS_Store? diff --git a/README.md b/README.md index 9fc8bd8..07b71dc 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,12 @@ Check out [Agentic Coding Guidance](https://brainy.gitbook.io/flow/guides/agenti We would like to extend our deepest gratitude to the creators and contributors of the PocketFlow framework, from which brainyFlow originated as a fork. +## Contributors Wanted! + +We're looking for contributors for all aspects of the project. Whether you're interested in documentation, testing, or implementing features, we'd love your help! + +Get involved by joining our [Discord server](https://discord.gg/N9mVvxRXyH). + ## Liability Disclaimer BrainyFlow is provided "as is" without any warranties or guarantees. diff --git a/cookbook/typescript-agent/package.json b/cookbook/typescript-agent/package.json index 41d46d4..91f3804 100644 --- a/cookbook/typescript-agent/package.json +++ b/cookbook/typescript-agent/package.json @@ -9,14 +9,14 @@ "license": "ISC", "description": "", "devDependencies": { - "@types/node": "^22.14.0", - "tsx": "^4.19.3", + "@types/node": "^22.15.21", + "tsx": "^4.19.4", "typescript": "^5.8.3" }, "dependencies": { "@phukon/duckduckgo-search": "^1.1.0", - "brainyflow": "^0.0.5", - "openai": "^4.91.1", - "yaml": "^2.7.1" + "brainyflow": "^1.0.0", + "openai": "^4.103.0", + "yaml": "^2.8.0" } } diff --git a/cookbook/typescript-chat/package.json b/cookbook/typescript-chat/package.json index f29e116..7d84255 100644 --- a/cookbook/typescript-chat/package.json +++ b/cookbook/typescript-chat/package.json @@ -9,11 +9,11 @@ "license": "ISC", "description": "", "devDependencies": { - "@types/node": "^22.13.17", - "tsx": "^4.19.3", - "typescript": "^5.8.2" + "@types/node": "^22.15.21", + "tsx": "^4.19.4", + "typescript": "^5.8.3" }, "dependencies": { - "openai": "^4.91.1" + "openai": "^4.103.0" } } diff --git a/docs/core_abstraction/flow.md b/docs/core_abstraction/flow.md index d995710..a59fe3a 100644 --- a/docs/core_abstraction/flow.md +++ b/docs/core_abstraction/flow.md @@ -302,7 +302,7 @@ flowchart LR ### Cycle Detection -Loops are created by connecting a node back to a previously executed node. To prevent infinite loops, `Flow` includes cycle detection controlled by the `maxVisits` option in its constructor (default is 5). If a node is visited more times than `maxVisits` during a single run, an error is thrown. +Loops are created by connecting a node back to a previously executed node. To prevent infinite loops, `Flow` includes cycle detection controlled by the `maxVisits` option in its constructor (default is 15). If a node is visited more times than `maxVisits` during a single run, an error is thrown. {% hint style="success" %} This ensures that the flow does not get stuck in an infinite loop. @@ -313,7 +313,7 @@ This ensures that the flow does not get stuck in an infinite loop. const flow = new Flow(startNode, { maxVisits: 10 }) ``` -- The default value for `maxVisits` is `5`. +- The default value for `maxVisits` is `15`. - Set `maxVisits` to `Infinity` or a very large number for effectively no limit (use with caution!). ## Flow Parallelism diff --git a/docs/guides/migrating_from_pocketflow.md b/docs/guides/migrating_from_pocketflow.md index 0515ffe..978e26d 100644 --- a/docs/guides/migrating_from_pocketflow.md +++ b/docs/guides/migrating_from_pocketflow.md @@ -53,6 +53,7 @@ from brainyflow import Node, Flow, SequentialBatchNode # ... etc ### Step 2: Add `async` / `await`: - Add `async` before `def` for your `prep`, `exec`, `post`, and `exec_fallback` methods in Nodes and Flows. +- Remove any `_async` suffix from the method names. - Add `await` before any calls to these methods, `run()` methods, `asyncio.sleep()`, or other async library functions. #### Node Example (Before): @@ -106,32 +107,49 @@ class MyNode(Node): _(Flow methods follow the same pattern)_ -### Step 3: Update Batch Processing Implementation (`*BatchNode` / `*BatchFlow` Removal) +### Step 3: Use `.trigger()` for next actions + +Check all `.post()` methods and replace any `return action` with a call to `self.trigger(action)`. _`return "default"` can be either replaced or removed._ + +### Step 4: Replace memory access methods: + +Replace `shared.get(` by `getattr(shared, ` if `shared` is a `Memory` instance. + +### Step 5: Update Batch Processing Implementation (`*BatchNode` / `*BatchFlow` Removal) PocketFlow had dedicated classes like `BatchNode`, `ParallelBatchNode`, `BatchFlow`, and `ParallelBatchFlow`. BrainyFlow v0.3+ **removes these specialized classes**. -The functionality is now achieved using standard `Node`s and `Flow`s combined with a specific pattern: +The batch functionality is now achieved using standard `Node`s and `Flow`s combined with a specific pattern: -1. **Rename Classes**: +1. **Adopt the Fan-Out Trigger Pattern**: - - Replace `BatchNode`, `AsyncBatchNode`, `ParallelBatchNode`, `AsyncParallelBatchNode` with the standard `brainyflow.Node`. - - Replace `BatchFlow`, `AsyncBatchFlow` with `brainyflow.Flow`. - - Replace `AsyncParallelBatchFlow` with `brainyflow.ParallelFlow`. - - Remember to make `prep`, `exec`, `post` methods `async` as per Step 2. + All `BatchNode` need to be split into two, a **Trigger Node** and a **Processor Node**. -2. **Adopt the Fan-Out Trigger Pattern**: + The **Trigger Node**: + - Use the `prep` method to fetch the list of items to process, as usual. + - Use the `post` method to iterate through these items. For **each item**, calls `self.trigger(action, forkingData={"item": current_item, "index": i, ...})`. The `forkingData` dictionary passes item-specific data into the **local memory** of the triggered successor. (the `action` name can be any of your choice as long as you connect the nodes in the flow; e.g. `process_one`, `default`) - - The node that previously acted as the `BatchNode` (or the starting node of a `BatchFlow`) needs to be refactored into a **Trigger Node**. - - Its `prep` method usually fetches the list of items to process. - - Its `post` method iterates through these items. For **each item**, it calls `self.trigger("process_one", forkingData={"item": current_item, "index": i, ...})`. The `forkingData` dictionary passes item-specific data into the **local memory** of the triggered successor. - - The logic previously in the `exec_one` method of the `BatchNode` must be moved into the `exec` method of a new **Processor Node**. - - This `ProcessorNode` is connected to the `TriggerNode` via the `"process_one"` action (e.g., `trigger_node.on("process_one", processor_node)`). - - The `ProcessorNode`'s `prep` method reads the specific item data (e.g., `memory.item`, `memory.index`) from its **local memory**, which was populated by the `forkingData`. + The **Processor Node**: + - The `ProcessorNode`'s `prep` method reads the specific item data (e.g., `memory.item`, `memory.index`) from its **local memory**, which was populated by the `forkingData` in the trigger node. + - The logic previously in the `exec_one` method of the `BatchNode` must be renamed to `exec`. - Its `post` method typically writes the result back to the **global memory**, often using the index to place it correctly in a shared list or dictionary. -3. **Choose the Right Flow**: + Similarly, `BatchFlow` need to be split into a `Node` and a regular `Flow`: + - Replace the return value of the `prep` method with a `post` method containing trigger calls. + - Instead of `self.params["property"]`, use the usual `memory.property`. + +2. **Choose the Right Flow**: - Wrap the `TriggerNode` and `ProcessorNode` in a standard `brainyflow.Flow` if you need items processed **sequentially**. - Wrap them in a `brainyflow.ParallelFlow` if items can be processed **concurrently**. + - Connect the nodes: `trigger_node >> processor_node` or `trigger_node - action >> processor_node` + +3. **Rename All Classes**: + + - Replace `AsyncParallelBatchFlow` with `brainyflow.ParallelFlow`. + - Replace `AsyncParallelBatchNode`, `ParallelBatchNode`, `AsyncBatchNode`, `BatchNode` with the standard `brainyflow.Node`. + - Replace `AsyncBatchFlow`, `BatchFlow` with `brainyflow.Flow`. + - Remember to make `prep`, `exec`, `post` methods `async` as per Step 2. + #### Example: Translating Text into Multiple Languages @@ -330,7 +348,7 @@ triggerNode.next(processorNode) _(See the [MapReduce design pattern](../design_pattern/mapreduce.md) for more detailed examples of fan-out/aggregate patterns)._ -### Step 4: Run with `asyncio`: +### Step 6: Run with `asyncio`: BrainyFlow code must be run within an async event loop. The standard way is using `asyncio.run()`: @@ -353,9 +371,11 @@ if __name__ == "__main__": Migrating from PocketFlow to BrainyFlow primarily involves: 1. Updating imports to `brainyflow` and adding `import asyncio`. -2. Adding `async` to your Node/Flow method definitions (`prep`, `exec`, `post`, `exec_fallback`). +2. Adding `async` to your Node/Flow method definitions (`prep`, `exec`, `post`, `exec_fallback`) and removing any `_async` suffix from the method names. +3. Replacing any `return action` in `post()` with a call to `self.trigger(action)`. 3. Using `await` when calling `run()` methods and any other asynchronous operations within your methods. 4. Replacing `BatchNode`/`BatchFlow` with the appropriate `Sequential*` or `Parallel*` BrainyFlow classes. 5. Running your main execution logic within an `async def main()` function called by `asyncio.run()`. +6. Replacing all `return action` by `self.trigger(action)` in your Node methods. This transition enables you to leverage the performance and concurrency benefits of asynchronous programming in your workflows. diff --git a/docs/guides/testing.md b/docs/guides/testing.md index 74da52c..3d99ba1 100644 --- a/docs/guides/testing.md +++ b/docs/guides/testing.md @@ -298,7 +298,7 @@ async def mock_llm_logic(prompt: str) -> str: async def test_node_with_mocked_llm(): # Assume MyLlmNode calls utils.call_llm internally # node = MyLlmNode() - # memory = Memory.create({"input": "some text to summarize"}) + # memory = Memory({"input": "some text to summarize"}) # Use patch to replace the actual call_llm with patch('utils.call_llm', new=AsyncMock(side_effect=mock_llm_logic)) as mock_call: @@ -394,7 +394,7 @@ async def test_retry_logic(): global call_count_retry call_count_retry = 0 # Reset counter for test # node = NodeWithRetry() - # memory = Memory.create({}) + # memory = Memory({}) # Patch the external call made within node.exec # Also patch asyncio.sleep to avoid actual waiting diff --git a/docs/installation.md b/docs/installation.md index 8439812..40f3933 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,3 +1,7 @@ +--- +machine-display: false +--- + # Installation BrainyFlow is currently available for both Python and TypeScript. diff --git a/package.json b/package.json index 762ab79..7a42465 100644 --- a/package.json +++ b/package.json @@ -4,10 +4,10 @@ "@ianvs/prettier-plugin-sort-imports": "4.4.1", "@total-typescript/ts-reset": "0.6.1", "@tsconfig/node-lts": "22.0.1", - "@types/node": "22.13.14", + "@types/node": "22.15.21", "prettier": "3.5.3", - "tsx": "4.19.3", - "typescript": "^5.8.2" + "tsx": "4.19.4", + "typescript": "5.8.3" }, - "packageManager": "pnpm@10.8.1" + "packageManager": "pnpm@10.11.0" } diff --git a/prettier.config.cjs b/prettier.config.mjs similarity index 88% rename from prettier.config.cjs rename to prettier.config.mjs index 23400d3..732661c 100755 --- a/prettier.config.cjs +++ b/prettier.config.mjs @@ -1,16 +1,15 @@ /** - * @see https://prettier.io/docs/en/configuration.html + * @see https://prettier.io/docs/configuration * @type {import("prettier").Config} */ - -module.exports = { +const config = { editorconfig: true, semi: false, useTabs: false, singleQuote: true, arrowParens: 'always', tabWidth: 2, - printWidth: 100, + printWidth: 130, trailingComma: 'all', plugins: ['@ianvs/prettier-plugin-sort-imports'], importOrderTypeScriptVersion: '5.4.5', @@ -40,3 +39,5 @@ module.exports = { '^./types$', ], } + +export default config diff --git a/python/.changeset/may.md b/python/.changeset/may.md new file mode 100644 index 0000000..dbe39e8 --- /dev/null +++ b/python/.changeset/may.md @@ -0,0 +1,60 @@ +--- +"brainyflow": major +--- + +# Major refactor and enhancement release + +## Breaking Changes + +* **Memory creation**: If you were creating a new Memory object using the `Memory.create(data)` static class method, you will need to replace it by simply `Memory(data)`. +* **Flow.run()**: If you were traversing the nodes representation given by the `.run()` method, you will need to update your code to use the new `ExecutionTree` class. + +### Core Library Changes +- **Memory class**: Considerable refactor with new deletion methods (`__delattr__`, `__delitem__`) and improved proxy behavior for local memory access +- **Flow class**: Default `maxVisits` increased from 5 to 15 for cycle detection +- **NodeError**: Changed from Exception class to Protocol interface +- **ExecutionTree**: Updated structure for better result aggregation and tracking +- **Type annotations**: Improved throughout with better Generic constraints and Protocol usage, fixing all type errors and inconsistencies + +### API Changes +- Memory deletion operations now support both attribute and item-style deletion +- Error message format updated for cycle detection: now shows "Maximum cycle count (N) reached for ClassName#nodeId" +- Node execution warnings removed for nodes with successors + +## New Features + +### Memory Management +- Added comprehensive deletion support for Memory objects +- New local proxy with isolated deletion operations +- Better memory cloning with forking data support +- Enhanced store management with helper functions + +### Developer Experience +- Improved migration documentation with detailed examples +- Added "Contributors Wanted" section to encourage community participation +- Better test isolation and predictable node ID management +- Enhanced error messages and debugging information + +## Infrastructure Improvements + +### CI/CD Pipeline +- Updated changesets action to v1.4.9 (workaround for github.com/changesets/action/issues/501) + +### Testing +- Comprehensive test suite updates for new Memory functionality +- Added deletion operation tests +- Improved test reliability with BaseNode ID resets +- Better assertion patterns matching new error formats + +### Documentation +- Updated migration guide from PocketFlow with clearer examples +- Enhanced core abstraction documentation +- Improved installation and setup instructions + +## Bug Fixes +- Fixed memory proxy behavior edge cases +- Improved error handling in node execution +- Better cycle detection error reporting +- Enhanced type safety throughout the codebase + +This release represents a significant improvement in memory management, developer experience, and overall library robustness while maintaining the core workflow orchestration capabilities. diff --git a/python/brainyflow.py b/python/brainyflow.py index 18893d6..3c4d493 100644 --- a/python/brainyflow.py +++ b/python/brainyflow.py @@ -1,97 +1,82 @@ +from __future__ import annotations import asyncio import copy import warnings from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Generic, Callable, Union, cast, TypedDict, Literal, overload, Awaitable, Sequence +from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, TypeAlias, TypeVar, Generic, Callable, Union, cast, TypedDict, Literal, overload, Awaitable, Sequence, runtime_checkable DEFAULT_ACTION = 'default' - Action = str SharedStore = Dict[str, Any] - G = TypeVar('G', bound=SharedStore) L = TypeVar('L', bound=SharedStore) T = TypeVar('T') PrepResultT = TypeVar('PrepResultT') ExecResultT = TypeVar('ExecResultT') ActionT = TypeVar('ActionT', bound=str) - +AnyNode: TypeAlias = 'BaseNode[G, Any, Any, Any, Any]' +class ExecutionTree(TypedDict): + order: str + type: str + triggered: Optional[Dict[Action, List["ExecutionTree"]]] class Trigger(TypedDict): - """Represents a triggered action with forking data.""" action: Action forking_data: SharedStore +def _get_from_stores(key: str, primary: SharedStore, secondary: SharedStore | None = None, Error: Type[Exception] = KeyError) -> Any: + if key in primary: return primary[key] + if secondary is not None and key in secondary: return secondary[key] + raise Error(f"Key '{key}' not found in store{'s' if secondary else ''}") + +def _delete_from_stores(key: str, primary: SharedStore, secondary: SharedStore | None = None) -> None: + if key not in primary and (secondary is None or key not in secondary): raise KeyError(key) + if key in primary: del primary[key] + if secondary is not None and key in secondary: del secondary[key] + class Memory(Generic[G, L]): """ - Memory class for managing global and local state. - Memory provides a dual-scope approach to state management: + Manager of global and local state. Provides a dual-scope approach to state management: - Global store: Shared across the entire flow - Local store: Specific to a particular execution path """ - def __init__(self, _global: G, _local: Optional[L] = None): - """Initialize a Memory instance with global and optional local stores.""" - # Directly set attributes in __dict__ to avoid __setattr__ object.__setattr__(self, '_global', _global) - object.__setattr__(self, '_local', _local if _local is not None else cast(L, {})) - - def __getattr__(self, name: str) -> Any: - """Access properties, checking local store first, then global.""" - if name in self._local: - return self._local[name] - if name in self._global: - return self._global[name] - raise AttributeError(f"'Memory' object has no attribute {name!r}") - - def __setattr__(self, name: str, value: Any) -> None: - """Write properties, handling reserved names and local/global interaction.""" - # Reserved property handling - if name in ['global', 'local', '_global', '_local', 'clone', 'create']: - raise ValueError(f"Reserved property '{name}' cannot be set") - - # Remove from local if exists, then set in global - if name in self._local: - del self._local[name] - - # Set in global store - self._global[name] = value - - def __getitem__(self, key: str) -> Any: - """Support dictionary-style access (memory['key']).""" - if key in self._local: - return self._local[key] - return self._global.get(key) - - def __setitem__(self, key: str, value: Any) -> None: - """Support dictionary-style assignment (memory['key'] = value).""" - # Remove from local if exists, then set in global + object.__setattr__(self, '_local', _local if _local else cast(L, {})) + def __getattr__(self, key: str) -> Any: return _get_from_stores(key, self._local, self._global, Error=AttributeError) + def __getitem__(self, key: str) -> Any: return _get_from_stores(key, self._local, self._global) + def _set_value(self, key: str, value: Any) -> None: + assert key not in ['global', 'local', '_global', '_local', 'clone', 'create'], f"Reserved property '{key}' cannot be set" if key in self._local: del self._local[key] self._global[key] = value - - def __contains__(self, key: str) -> bool: - """Support 'in' operator (key in memory).""" - return key in self._local or key in self._global - + def __setattr__(self, name: str, value: Any) -> None: self._set_value(name, value) + def __setitem__(self, key: str, value: Any) -> None: self._set_value(key, value) + def __delattr__(self, key: str) -> None: _delete_from_stores(key, self._global, self._local) + def __delitem__(self, key: str) -> None: _delete_from_stores(key, self._global, self._local) + def __contains__(self, key: str) -> bool: return key in self._local or key in self._global + def clone(self, forking_data: Optional[SharedStore] = None) -> Memory[G, L]: + new_local = copy.deepcopy(self._local) + new_local.update(copy.deepcopy(forking_data or {})) + return Memory[G, L](self._global, cast(L, new_local)) @property def local(self) -> L: - """Access the local store directly.""" - return self._local - - def clone(self, forking_data: Optional[SharedStore] = None) -> 'Memory[G, L]': - """Create a new Memory with shared global store but deep-copied local store.""" - forking_data = forking_data or {} - new_local_store = copy.deepcopy(self._local) - new_local_store.update(copy.deepcopy(forking_data)) - return Memory.create(self._global, cast(L, new_local_store)) - - @staticmethod - def create(global_store: G, local_store: Optional[L] = None) -> 'Memory[G, L]': - """Factory method to create a Memory instance.""" - return Memory(global_store, local_store if local_store is not None else cast(L, {})) + class LocalProxy: + def __init__(self, store: L) -> None: self.store = store + def __getattr__(self, key: str) -> Any: return _get_from_stores(key, self.store, Error=AttributeError) + def __getitem__(self, key: str) -> Any: return _get_from_stores(key, self.store) + def __setattr__(self, key: str, value: Any) -> None: self.store[key] = value + def __setitem__(self, key: str, value: Any) -> None: self.store[key] = value + def __delattr__(self, key: str) -> None: _delete_from_stores(key, self.store) + def __delitem__(self, key: str) -> None: _delete_from_stores(key, self.store) + def __contains__(self, key: str) -> bool: return key in self.store + def __eq__(self, other: object) -> bool: + if isinstance(other, LocalProxy): return self.store == other.store + return self.store == other + def __repr__(self) -> str: return self.store.__repr__() + return cast(L, LocalProxy(self._local)) -class NodeError(Exception): - """Error raised during node execution with retry count information.""" +@runtime_checkable +class NodeError(Protocol): retry_count: int = 0 class BaseNode(Generic[G, L, ActionT, PrepResultT, ExecResultT], ABC): @@ -106,20 +91,18 @@ class BaseNode(Generic[G, L, ActionT, PrepResultT, ExecResultT], ABC): - PrepResultT: Return type of prep method - ExecResultT: Return type of exec method """ - _next_id = 0 def __init__(self) -> None: - """Initialize a BaseNode instance.""" - self.successors: Dict[Action, List['BaseNode']] = {} # dict of action -> list of nodes + self.successors: Dict[Action, List[AnyNode[G]]] = {} # dict of action -> list of nodes self._triggers: List[Trigger] = [] # list of dicts with action and forking_data + self._permeated: List[Action] = [] self._locked: bool = True # Prevent trigger calls outside post() self._node_order: int = BaseNode._next_id BaseNode._next_id += 1 - def clone(self, seen: Optional[Dict['BaseNode', 'BaseNode']] = None) -> 'BaseNode[G, L, ActionT, PrepResultT, ExecResultT]': + def clone(self, seen: Optional[Dict[AnyNode[G], AnyNode[G]]] = None) -> BaseNode[G, L, ActionT, PrepResultT, ExecResultT]: """Create a deep copy of the node including its successors.""" - # Create a deep copy with cycle detection seen = seen or {} if self in seen: return seen[self] @@ -133,6 +116,7 @@ def clone(self, seen: Optional[Dict['BaseNode', 'BaseNode']] = None) -> 'BaseNod if key != 'successors': # Shallow-copy by default; deep-copy lists/dicts/sets to prevent sharing setattr(cloned, key, copy.deepcopy(value) if isinstance(value, (list, dict, set)) else value) + # Clone successors with cycle detection cloned.successors = {} for action, nodes in self.successors.items(): @@ -141,38 +125,31 @@ def clone(self, seen: Optional[Dict['BaseNode', 'BaseNode']] = None) -> 'BaseNod ] return cloned - - def on(self, action: Action, node: 'BaseNode') -> 'BaseNode': + def on(self, action: Action, node: AnyNode[G]) -> AnyNode[G]: """Add a successor node for a specific action.""" if action not in self.successors: self.successors[action] = [] self.successors[action].append(node) return node - def next(self, node: 'BaseNode', action: Action = DEFAULT_ACTION) -> 'BaseNode': + def next(self, node: AnyNode[G], action: Action = DEFAULT_ACTION) -> AnyNode[G]: """Convenience method equivalent to on().""" return self.on(action, node) - # Python-specific syntax sugar - def __rshift__(self, other: 'BaseNode') -> 'BaseNode': + def __rshift__(self, other: AnyNode[G]) -> AnyNode[G]: """Implement node_a >> node_b syntax for default action""" return self.next(other) - def __sub__(self, action: Action) -> 'ActionLinker': + def __sub__(self, action: Action): """Implement node_a - "action" syntax for action selection""" - return self.ActionLinker(self, action) - - class ActionLinker: - """Helper class for action-specific transitions""" - def __init__(self, node: 'BaseNode', action: Action): - self.node = node - self.action = action - - def __rshift__(self, other: 'BaseNode') -> 'BaseNode': - """Implement - "action" >> node_b syntax""" - return self.node.on(self.action, other) - - def get_next_nodes(self, action: Action = DEFAULT_ACTION) -> List['BaseNode']: + that = self + class ActionLinker: + def __rshift__(self, other: AnyNode[G]) -> AnyNode[G]: + """Implement - "action" >> node_b syntax""" + return that.on(action, other) + return ActionLinker() + + def get_next_nodes(self, action: Action = DEFAULT_ACTION) -> List[AnyNode[G]]: """Get successor nodes for a specific action.""" next_nodes = self.successors.get(action, []) if not next_nodes and action != DEFAULT_ACTION and self.successors: @@ -193,8 +170,7 @@ async def post(self, memory: Memory[G, L], prep_res: PrepResultT, exec_res: Exec def trigger(self, action: ActionT, forking_data: Optional[SharedStore] = None) -> None: """Trigger a successor action with optional forking data.""" - if self._locked: - raise RuntimeError("An action can only be triggered inside post()") + assert not self._locked, "An action can only be triggered inside post()" self._triggers.append({ "action": action, @@ -203,7 +179,7 @@ def trigger(self, action: ActionT, forking_data: Optional[SharedStore] = None) - def list_triggers(self, memory: Memory[G, L]) -> List[Tuple[Action, Memory[G, L]]]: """Process triggers or return default.""" - if not self._triggers: + if not self._triggers and not DEFAULT_ACTION in self._permeated: return [(DEFAULT_ACTION, memory.clone())] return [(t["action"], memory.clone(t["forking_data"])) for t in self._triggers] @@ -215,19 +191,15 @@ async def exec_runner(self, memory: Memory[G, L], prep_res: PrepResultT) -> Exec @overload async def run(self, memory: Union[Memory[G, L], G], propagate: Literal[True]) -> List[Tuple[Action, Memory[G, L]]]: ... - @overload async def run(self, memory: Union[Memory[G, L], G], propagate: Literal[False] = False) -> ExecResultT: ... - async def run(self, memory: Union[Memory[G, L], G], propagate: bool = False) -> Union[List[Tuple[Action, Memory[G, L]]], ExecResultT]: - """Run the node's full lifecycle.""" - if self.successors: - warnings.warn("Node won't run successors. Use Flow!", stacklevel=2) - + """Run the node's full lifecycle (prep → exec → post).""" if not isinstance(memory, Memory): - memory = Memory.create(memory) + memory = Memory[G, L](memory) self._triggers = [] + self._permeated = [] prep_res = await self.prep(memory) exec_res = await self.exec_runner(memory, prep_res) @@ -248,7 +220,6 @@ class Node(BaseNode[G, L, ActionT, PrepResultT, ExecResultT]): wait: Seconds to wait between retry attempts cur_retry: Current retry attempt (0-indexed) """ - def __init__(self, max_retries: int = 1, wait: float = 0) -> None: """Initialize a Node with retry configuration.""" super().__init__() @@ -256,7 +227,7 @@ def __init__(self, max_retries: int = 1, wait: float = 0) -> None: self.wait = wait self.cur_retry = 0 - async def exec_fallback(self, prep_res: PrepResultT, error: NodeError) -> ExecResultT: + async def exec_fallback(self, prep_res: PrepResultT, error: Exception) -> ExecResultT: """Called when all retry attempts fail.""" raise error @@ -267,17 +238,16 @@ async def exec_runner(self, memory: Memory[G, L], prep_res: PrepResultT) -> Exec try: return await self.exec(prep_res) except Exception as error: + if not hasattr(error, 'retry_count'): + error.retry_count = attempt + 1 # type: ignore if attempt < self.max_retries - 1: if self.wait > 0: await asyncio.sleep(self.wait) continue - - wrapped = error if isinstance(error, NodeError) else NodeError(str(error)).with_traceback(error.__traceback__) - wrapped.retry_count = attempt + 1 - return await self.exec_fallback(prep_res, wrapped) - raise RuntimeError("Unreachable: exec_runner should have returned or raised in the loop") + return await self.exec_fallback(prep_res, error) + raise RuntimeError("Unreachable: exec_runner should have returned or raised in the loop") # This should never happen if max_retries > 0 -class Flow(BaseNode[G, L, ActionT, PrepResultT, Dict[str, Any]]): +class Flow(BaseNode[G, L, ActionT, PrepResultT, ExecutionTree]): """ Orchestrates the execution of a graph of nodes sequentially. @@ -286,18 +256,17 @@ class Flow(BaseNode[G, L, ActionT, PrepResultT, Dict[str, Any]]): options: Configuration options like max_visits visit_counts: Tracks node visits for cycle detection """ - - def __init__(self, start: BaseNode, options: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, start: AnyNode[G], options: Optional[Dict[str, Any]] = None) -> None: """Initialize a Flow with a start node and options.""" super().__init__() self.start = start - self.options = options or {"max_visits": 5} + self.options = options or {"max_visits": 15} self.visit_counts: Dict[str, int] = {} - async def exec(self, prep_res: PrepResultT) -> Dict[str, Any]: + async def exec(self, prep_res: PrepResultT) -> ExecutionTree: raise RuntimeError("This method should never be called in a Flow") - async def exec_runner(self, memory: Memory[G, L], prep_res: PrepResultT) -> Dict[str, Any]: + async def exec_runner(self, memory: Memory[G, L], prep_res: PrepResultT) -> ExecutionTree: """Run the flow starting from the start node.""" self.visit_counts = {} # Reset visit counts return await self.run_node(self.start, memory) @@ -309,59 +278,52 @@ async def run_tasks(self, tasks: Sequence[Callable[[], Awaitable[T]]]) -> List[T results.append(await task()) return results - async def run_nodes(self, nodes: List[BaseNode], memory: Memory[G, L]) -> List[Any]: + async def run_nodes(self, nodes: List[AnyNode[G]], memory: Memory[G, L]) -> List[ExecutionTree]: """Run a list of nodes with the given memory.""" - tasks: List[Callable[[], Awaitable[Any]]] = [ - lambda n=node, m=memory: self.run_node(n, m) for node in nodes + tasks: List[Callable[[], Awaitable[ExecutionTree]]] = [ + (lambda n=node, m=memory: lambda: self.run_node(n, m))() for node in nodes ] return await self.run_tasks(tasks) - async def run_node(self, node: BaseNode, memory: Memory[G, L]) -> Dict[str, Any]: - """Run a node with cycle detection.""" - node_id = str(node._node_order) - + async def run_node(self, node: AnyNode[G], memory: Memory[G, L]) -> ExecutionTree: + """Run a node with cycle detection and return its execution log.""" + node_order = str(node._node_order) # Check for cycles - current_visit_count = self.visit_counts.get(node_id, 0) + 1 - if current_visit_count > self.options["max_visits"]: - raise RuntimeError( - f"Maximum cycle count reached ({self.options['max_visits']}) for " - f"{node_id}.{node.__class__.__name__}" - ) + current_visit_count = self.visit_counts.get(node_order, 0) + 1 + assert current_visit_count <= self.options["max_visits"], f"Maximum cycle count ({self.options['max_visits']}) reached for {node.__class__.__name__}#{node_order}" + self.visit_counts[node_order] = current_visit_count - self.visit_counts[node_id] = current_visit_count - - # Clone node and run with propagate=True cloned_node = node.clone() triggers = await cloned_node.run(memory.clone(), True) - # Process each trigger and collect results - tasks: List[Callable[[], Awaitable[Tuple[Action, List[Any]]]]] = [] + triggered: Dict[Action, List[ExecutionTree]] = {} + tasks: List[Callable[[], Awaitable[Tuple[Action, List[ExecutionTree]]]]] = [] for action, node_memory in triggers: next_nodes = cloned_node.get_next_nodes(action) - tasks.append( - lambda a=action, nn=next_nodes, nm=node_memory: self._process_trigger(a, nn, nm) - ) - - # Run all trigger tasks and build result tree - tree = await self.run_tasks(tasks) - return {action: results for action, results in tree} - - async def _process_trigger(self, action: Action, next_nodes: List[BaseNode], node_memory: Memory[G, L]) -> Tuple[Action, List[Any]]: - """Process a single trigger.""" - if not next_nodes: - return (action, []) - - results = await self.run_nodes(next_nodes, node_memory) - return (action, results) + if not next_nodes: + # If the sub-node triggered an action that has no successors *within its own definition*, + # that action becomes a terminal trigger for this Flow itself (if Flow is nested). + next_nodes = [n.clone() for n in self.get_next_nodes(action)] + # if next_nodes: + self._permeated.append(action) + + if next_nodes: + tasks.append((lambda act=action, nn_list=next_nodes, nm_mem=node_memory: \ + lambda: self._process_trigger(act, nn_list, nm_mem))()) + + tree: List[Tuple[Action, List[ExecutionTree]]] = await self.run_tasks(tasks) + for action, resulting_node_logs in tree: + triggered[action] = resulting_node_logs + + return { 'order': node_order, 'type': node.__class__.__name__, 'triggered': triggered if triggered else None } + + async def _process_trigger(self, action: Action, next_nodes: List[AnyNode[G]], node_memory: Memory[G, L]) -> Tuple[Action, List[ExecutionTree]]: + """Process a single trigger by running its next_nodes.""" + return (action, await self.run_nodes(next_nodes, node_memory)) class ParallelFlow(Flow[G, L, ActionT, PrepResultT]): - """ - Orchestrates execution of a graph of nodes with parallel branching. - Overrides run_tasks to execute tasks concurrently using asyncio.gather. - """ - + """Orchestrates execution of a graph of nodes with parallel branching.""" async def run_tasks(self, tasks: Sequence[Callable[[], Awaitable[T]]]) -> List[T]: - """Run tasks concurrently using asyncio.gather.""" if not tasks: return [] return await asyncio.gather(*(task() for task in tasks)) diff --git a/python/design.md b/python/design.md index a504302..63dc70d 100644 --- a/python/design.md +++ b/python/design.md @@ -24,7 +24,6 @@ Manages the state accessible to nodes during execution. - **Membership Testing:** Supports the `in` operator via `__contains__` to check for key existence in either local or global scope. - **Local Property:** The `local` property provides direct read access to the `_local` dictionary. - **Cloning:** The `clone(forking_data=None)` method creates a new `Memory` instance. The global store is shared by reference, while the local store is deep-copied. Optional `forking_data` can be provided to initialize or update the new local store. This is crucial for branching and parallel execution to ensure state isolation where needed. -- **Factory Method:** `Memory.create(global_store, local_store=None)` provides a static method for instantiation. ### 2. BaseNode (`brainyflow.BaseNode`) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 59a1eb2..0dc88cd 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -145,7 +145,7 @@ async def exec(self, prep_res): def memory(): """Create a test memory instance.""" global_store = {"initial": "global"} - return Memory.create(global_store) + return Memory(global_store) @pytest.fixture def test_nodes(): diff --git a/python/tests/design.md b/python/tests/design.md index 4c4cc5c..2f28ef3 100644 --- a/python/tests/design.md +++ b/python/tests/design.md @@ -24,7 +24,7 @@ This document outlines the testing strategy for the Python port of the `brainyfl ### 3.1. `Memory` Class - **Initialization:** - - `Memory.create()` correctly initializes global and optional local stores + - `Memory()` correctly initializes global and optional local stores - Global and local stores are properly accessible - **Proxy Behavior (Reading):** - Reads property from local store if present diff --git a/python/tests/test_flow.py b/python/tests/test_flow.py index daed6d2..9e7c1f4 100644 --- a/python/tests/test_flow.py +++ b/python/tests/test_flow.py @@ -1,6 +1,6 @@ import pytest -from unittest.mock import Mock, AsyncMock -from brainyflow import Memory, Node, Flow, DEFAULT_ACTION +from unittest.mock import AsyncMock +from brainyflow import Memory, Node, Flow, DEFAULT_ACTION, BaseNode # --- Helper Node Implementations --- class BaseTestNode(Node): @@ -58,11 +58,19 @@ class TestFlow: def memory(self): """Create a test memory instance.""" global_store = {"initial": "global"} - return Memory.create(global_store) + return Memory(global_store) @pytest.fixture def nodes(self): - """Create test nodes.""" + """Create test nodes. + IMPORTANT: Reset BaseNode._next_id to ensure predictable orders for tests. + This should ideally be handled by a session-scoped or test-scoped fixture + if node creation order varies significantly across test files or setup. + For now, we assume it's reset or nodes are created fresh with predictable IDs. + """ + # Resetting for test predictability. If BaseNode is imported by multiple test files, + # this might need a more robust solution (e.g. pytest_runtest_setup fixture). + BaseNode._next_id = 0 return { "A": BaseTestNode("A"), "B": BaseTestNode("B"), @@ -70,6 +78,13 @@ def nodes(self): "D": BaseTestNode("D") } + @pytest.fixture + def branching_node_fixture(self): + # Separate fixture for branching node to control its _node_order independently if needed + BaseNode._next_id = 0 # Example of resetting if it's the first node in a test + return BranchingNode("Branch") + + class TestInitialization: """Tests for Flow initialization.""" @@ -77,7 +92,7 @@ def test_store_start_node_and_default_options(self, nodes): """Should store the start node and default options.""" flow = Flow(nodes["A"]) assert flow.start == nodes["A"] - assert getattr(flow, "options", {}).get("max_visits") == 5 + assert getattr(flow, "options", {}).get("max_visits") == 15 def test_accept_custom_options(self, nodes): """Should accept custom options.""" @@ -131,43 +146,59 @@ class TestConditionalBranching: async def test_follow_correct_path_based_on_triggered_action(self, nodes, memory): """Should follow the correct path based on triggered action.""" - branching_node = BranchingNode("Branch") - branching_node.on("path_B", nodes["B"]) - branching_node.on("path_C", nodes["C"]) + # Reset BaseNode ID for predictable IDs in this test + BaseNode._next_id = 0 + branching_node = BranchingNode("Branch") # id_Branch = 0 + # nodes A, B, C, D are created by fixture, their IDs will be 0,1,2,3 if nodes fixture is used first + # If branching_node is created first, its ID will be 0. + # Let's use fresh nodes for clarity here if IDs matter deeply for the test logic. + # For this test, memory state check is primary. + + node_b_local = BaseTestNode("B_local") # id_B_local = 1 + node_c_local = BaseTestNode("C_local") # id_C_local = 2 + + branching_node.on("path_B", node_b_local) + branching_node.on("path_C", node_c_local) # Test path B branching_node.set_trigger("path_B") flow_b = Flow(branching_node) - memory_b = Memory.create({}) + memory_b = Memory({}) await flow_b.run(memory_b) assert memory_b.post_Branch is True - assert memory_b.post_B is True - assert getattr(memory_b, "post_C", None) is None + assert memory_b.post_B_local is True + assert getattr(memory_b, "post_C_local", None) is None # Test path C - branching_node.set_trigger("path_C") # Reset trigger - flow_c = Flow(branching_node) # New flow to reset visit counts - memory_c = Memory.create({}) + # Re-create branching_node or use a new Flow instance to reset visit counts + # and ensure post_mock counts are fresh for this path. + BaseNode._next_id = 0 # Reset again if we want same ID for branching_node + branching_node_for_c = BranchingNode("Branch") # id = 0 + node_b_for_c = BaseTestNode("B_for_C") # id = 1 + node_c_for_c = BaseTestNode("C_for_C") # id = 2 + branching_node_for_c.on("path_B", node_b_for_c) + branching_node_for_c.on("path_C", node_c_for_c) + + branching_node_for_c.set_trigger("path_C") + flow_c = Flow(branching_node_for_c) + memory_c = Memory({}) await flow_c.run(memory_c) assert memory_c.post_Branch is True - assert getattr(memory_c, "post_B", None) is None - assert memory_c.post_C is True + assert getattr(memory_c, "post_B_for_C", None) is None + assert memory_c.post_C_for_C is True class TestMemoryHandling: """Tests for memory handling.""" async def test_propagate_global_memory_changes(self, nodes, memory): """Should propagate global memory changes.""" - # Setup mock to modify memory - # Accept extra args passed by post_mock async def modify_memory(mem, prep_res, exec_res): mem.global_A = "set_by_A" nodes["A"].post_mock.side_effect = modify_memory - # Setup mock to verify memory async def verify_memory(mem): assert mem.global_A == "set_by_A" @@ -178,15 +209,18 @@ async def verify_memory(mem): await flow.run(memory) assert memory.global_A == "set_by_A" - assert nodes["B"].prep_mock.call_count == 1 # Ensure B ran + assert nodes["B"].prep_mock.call_count == 1 async def test_isolate_local_memory_using_forking_data(self, nodes, memory): """Should isolate local memory using forkingData.""" - branching_node = BranchingNode("Branch") - branching_node.on("path_B", nodes["B"]) - branching_node.on("path_C", nodes["C"]) + BaseNode._next_id = 0 + branching_node = BranchingNode("Branch") # id 0 + node_b_local = BaseTestNode("B_local") # id 1 + node_c_local = BaseTestNode("C_local") # id 2 + + branching_node.on("path_B", node_b_local) + branching_node.on("path_C", node_c_local) - # Setup mocks to check local memory async def check_b_memory(mem): assert mem.local_data == "for_B" assert mem.common_local == "common" @@ -197,28 +231,37 @@ async def check_c_memory(mem): assert mem.common_local == "common" assert mem.local["local_data"] == "for_C" - nodes["B"].prep_mock.side_effect = check_b_memory - nodes["C"].prep_mock.side_effect = check_c_memory + node_b_local.prep_mock.side_effect = check_b_memory + node_c_local.prep_mock.side_effect = check_c_memory - # Trigger B with specific local data branching_node.set_trigger("path_B", {"local_data": "for_B", "common_local": "common"}) flow_b = Flow(branching_node) - memory_b = Memory.create({"global_val": 1}) + memory_b = Memory({"global_val": 1}) await flow_b.run(memory_b) - assert nodes["B"].prep_mock.call_count == 1 - assert nodes["C"].prep_mock.call_count == 0 - assert getattr(memory_b, "local_data", None) is None # Forked data shouldn't leak to global + assert node_b_local.prep_mock.call_count == 1 + assert node_c_local.prep_mock.call_count == 0 + assert getattr(memory_b, "local_data", None) is None assert getattr(memory_b, "common_local", None) is None - # Trigger C with different local data - branching_node.set_trigger("path_C", {"local_data": "for_C", "common_local": "common"}) - flow_c = Flow(branching_node) # New flow to reset visits - memory_c = Memory.create({"global_val": 1}) + # For path C, use a new branching_node instance or reset mocks for clarity + BaseNode._next_id = 0 + branching_node_for_c = BranchingNode("BranchC") # id 0 + # node_b_local and node_c_local are not reused here to avoid mock call count confusion + node_b_for_c_path = BaseTestNode("B_for_C_Path") # id 1 + node_c_for_c_path = BaseTestNode("C_for_C_Path") # id 2 + node_c_for_c_path.prep_mock.side_effect = check_c_memory + + + branching_node_for_c.on("path_B", node_b_for_c_path) + branching_node_for_c.on("path_C", node_c_for_c_path) + branching_node_for_c.set_trigger("path_C", {"local_data": "for_C", "common_local": "common"}) + + flow_c = Flow(branching_node_for_c) + memory_c = Memory({"global_val": 1}) await flow_c.run(memory_c) - assert nodes["B"].prep_mock.call_count == 1 # Still called once from previous run - assert nodes["C"].prep_mock.call_count == 1 + assert node_c_for_c_path.prep_mock.call_count == 1 assert getattr(memory_c, "local_data", None) is None assert getattr(memory_c, "common_local", None) is None @@ -227,40 +270,34 @@ class TestCycleDetection: async def test_execute_loop_maxvisits_times_before_error(self, nodes): """Should execute a loop exactly maxVisits times before error.""" - loop_count = [0] # Using a list for mutable closure + loop_count = [0] - # Setup mock to increment count async def increment_count(mem): loop_count[0] += 1 mem.count = loop_count[0] nodes["A"].prep_mock.side_effect = increment_count - nodes["A"].next(nodes["A"]) # A -> A loop + nodes["A"].next(nodes["A"]) max_visits = 3 flow = Flow(nodes["A"], {"max_visits": max_visits}) + loop_memory = Memory({}) - # Use a fresh memory for this test - loop_memory = Memory.create({}) - - # Should raise exception when max_visits is exceeded - with pytest.raises(Exception, match=f"Maximum cycle count reached.*{max_visits}"): + with pytest.raises(AssertionError, match=f"Maximum cycle count \\({max_visits}\\) reached for {nodes['A'].__class__.__name__}#{nodes['A']._node_order}"): await flow.run(loop_memory) - # Verify counts assert loop_count[0] == max_visits assert loop_memory.count == max_visits async def test_error_immediately_if_loop_exceeds_maxvisits(self, nodes): """Should throw error immediately if loop exceeds max_visits (e.g. max_visits=2).""" - nodes["A"].next(nodes["A"]) # A -> A loop + nodes["A"].next(nodes["A"]) max_visits = 2 flow = Flow(nodes["A"], {"max_visits": max_visits}) + loop_memory = Memory({}) - loop_memory = Memory.create({}) - - with pytest.raises(Exception, match=f"Maximum cycle count reached.*{max_visits}"): + with pytest.raises(AssertionError, match=f"Maximum cycle count \\({max_visits}\\) reached for {nodes['A'].__class__.__name__}#{nodes['A']._node_order}"): await flow.run(loop_memory) class TestFlowAsNode: @@ -268,24 +305,20 @@ class TestFlowAsNode: async def test_execute_nested_flow_as_single_node_step(self, nodes, memory): """Should execute a nested flow as a single node step.""" - # Sub-flow: B -> C nodes["B"].next(nodes["C"]) sub_flow = Flow(nodes["B"]) - # Main flow: A -> subFlow -> D nodes["A"].next(sub_flow) - sub_flow.next(nodes["D"]) # Connect subFlow's exit to D + sub_flow.next(nodes["D"]) main_flow = Flow(nodes["A"]) await main_flow.run(memory) - # Check execution order assert nodes["A"].post_mock.call_count == 1 - assert nodes["B"].post_mock.call_count == 1 # B ran inside subFlow - assert nodes["C"].post_mock.call_count == 1 # C ran inside subFlow - assert nodes["D"].post_mock.call_count == 1 # D ran after subFlow + assert nodes["B"].post_mock.call_count == 1 + assert nodes["C"].post_mock.call_count == 1 + assert nodes["D"].post_mock.call_count == 1 - # Check memory state assert memory.post_A is True assert memory.post_B is True assert memory.post_C is True @@ -296,18 +329,17 @@ async def test_nested_flow_prep_post_wrap_subflow_execution(self, nodes, memory) nodes["B"].next(nodes["C"]) sub_flow = Flow(nodes["B"]) - # Add prep/post to the subFlow sub_flow.prep = AsyncMock() sub_flow.post = AsyncMock() - # Setup mock side effects to modify memory async def subflow_prep(mem): mem.subflow_prep = True - return None + return None # PrepResultT for Flow is not used by its exec_runner async def subflow_post(mem, prep_res, exec_res): mem.subflow_post = True - return None + # exec_res here is the ExecutionTree from the sub_flow's execution + return None sub_flow.prep.side_effect = subflow_prep sub_flow.post.side_effect = subflow_post @@ -318,105 +350,184 @@ async def subflow_post(mem, prep_res, exec_res): await main_flow.run(memory) assert memory.subflow_prep is True - assert memory.post_B is True # Inner nodes ran + assert memory.post_B is True assert memory.post_C is True assert memory.subflow_post is True - assert memory.post_D is True # D ran after subflow post + assert memory.post_D is True assert sub_flow.prep.call_count == 1 - assert sub_flow.post.call_count == 1 + sub_flow.post.assert_called_once() # Check it was called + # To check args: sub_flow.post.assert_called_with(memory, None, ANY) # prep_res is None, exec_res is the log + + async def test_nested_flow_propagates_terminal_action_to_parent_flow(self, memory): + """Should propagate a terminal action from a sub-flow to the parent flow.""" + BaseNode._next_id = 0 + main_start_node = BaseTestNode("MainStart") # id 0 + sub_node_a = BaseTestNode("SubA") # id 1 + + sub_node_b = BranchingNode("SubB") # id 2 + sub_node_b.set_trigger("sub_flow_completed") + + main_end_node = BaseTestNode("MainEnd") # id 3 + + sub_node_a.next(sub_node_b) + sub_flow = Flow(start=sub_node_a) # id 4 (Flow itself is a BaseNode) + + main_start_node.next(sub_flow) + sub_flow.on("sub_flow_completed", main_end_node) + + main_flow = Flow(start=main_start_node) # id 5 + await main_flow.run(memory) + + assert memory["post_MainStart"] is True + assert memory["post_SubA"] is True + assert memory["post_SubB"] is True + assert memory["post_MainEnd"] is True + + main_start_node.post_mock.assert_called_once() + sub_node_a.post_mock.assert_called_once() + sub_node_b.post_mock.assert_called_once() + main_end_node.post_mock.assert_called_once() class TestResultAggregation: - """Tests for result aggregation.""" + """Tests for result aggregation using ExecutionTree.""" async def test_return_correct_nested_actions_structure_for_simple_flow(self, nodes, memory): - """Should return correct structure for a simple flow.""" - nodes["A"].next(nodes["B"]) # A -> B + """Should return correct structure for a simple flow A -> B.""" + node_a = nodes["A"] + node_b = nodes["B"] + node_a.next(node_b) - flow = Flow(nodes["A"]) + flow = Flow(node_a) result = await flow.run(memory) expected = { - # Results from node A triggering default - DEFAULT_ACTION: [ - { # Results from node B triggering default - DEFAULT_ACTION: [] # Terminal node - } - ] + 'order': str(node_a._node_order), + 'type': node_a.__class__.__name__, + 'triggered': { + DEFAULT_ACTION: [ + { + 'order': str(node_b._node_order), + 'type': node_b.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # Node B is terminal + } + ] + } } - assert result == expected async def test_return_correct_structure_for_branching_flow(self, nodes): """Should return correct structure for branching flow.""" - branching_node = BranchingNode("Branch") - branching_node.on("path_B", nodes["B"]) # Branch -> B on path_B - branching_node.on("path_C", nodes["C"]) # Branch -> C on path_C - nodes["B"].next(nodes["D"]) # B -> D + # Reset BaseNode ID for predictable IDs + BaseNode._next_id = 0 + branching_node = BranchingNode("Branch") # id 0 + node_b = BaseTestNode("B_local_branch") # id 1 + node_c = BaseTestNode("C_local_branch") # id 2 + node_d = BaseTestNode("D_local_branch") # id 3 + + branching_node.on("path_B", node_b) + branching_node.on("path_C", node_c) + node_b.next(node_d) - # Trigger path B + # Test path B: Branch -> B -> D branching_node.set_trigger("path_B") flow_b = Flow(branching_node) - result_b = await flow_b.run(Memory.create({})) + result_b = await flow_b.run(Memory({})) expected_b = { - "path_B": [ - { # Results from Branch triggering path_B - DEFAULT_ACTION: [ - { # Results from B triggering default - DEFAULT_ACTION: [] # Terminal node D + 'order': str(branching_node._node_order), + 'type': branching_node.__class__.__name__, + 'triggered': { + "path_B": [ + { + 'order': str(node_b._node_order), + 'type': node_b.__class__.__name__, + 'triggered': { + DEFAULT_ACTION: [ + { + 'order': str(node_d._node_order), + 'type': node_d.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # Node D is terminal + } + ] } - ] - } - ] + } + ] + } } - assert result_b == expected_b - # Trigger path C - branching_node.set_trigger("path_C") - flow_c = Flow(branching_node) # Reset flow visits - result_c = await flow_c.run(Memory.create({})) + # Test path C: Branch -> C + # Need to use a new branching_node or flow to reset visit counts + BaseNode._next_id = 0 + branching_node_c_path = BranchingNode("BranchCPath") # id 0 + node_b_c_path = BaseTestNode("B_CPath") # id 1 (unused for path C trigger) + node_c_c_path = BaseTestNode("C_CPath") # id 2 + node_d_c_path = BaseTestNode("D_CPath") # id 3 (unused for path C trigger) + + branching_node_c_path.on("path_B", node_b_c_path) + branching_node_c_path.on("path_C", node_c_c_path) + # node_b_c_path.next(node_d_c_path) # Not relevant for path C test + + branching_node_c_path.set_trigger("path_C") + flow_c = Flow(branching_node_c_path) + result_c = await flow_c.run(Memory({})) expected_c = { - "path_C": [ - { # Results from Branch triggering path_C - DEFAULT_ACTION: [] # Terminal node C - } - ] + 'order': str(branching_node_c_path._node_order), + 'type': branching_node_c_path.__class__.__name__, + 'triggered': { + "path_C": [ + { + 'order': str(node_c_c_path._node_order), + 'type': node_c_c_path.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # Node C is terminal + } + ] + } } - assert result_c == expected_c async def test_return_correct_structure_for_multi_trigger(self, nodes, memory): """Should return correct structure for multi-trigger (fan-out).""" - class MultiTrigger(BaseTestNode): + BaseNode._next_id = 0 + class MultiTrigger(BaseTestNode): # Inherits BaseTestNode, so uses its _node_order async def post(self, memory, prep_res, exec_res): await super().post(memory, prep_res, exec_res) self.trigger("out1") self.trigger("out2") - multi_node = MultiTrigger("Multi") - multi_node.on("out1", nodes["B"]) # Multi -> B on out1 - multi_node.on("out2", nodes["C"]) # Multi -> C on out2 + multi_node = MultiTrigger("Multi") # id 0 + node_b = BaseTestNode("B_multi") # id 1 + node_c = BaseTestNode("C_multi") # id 2 + + multi_node.on("out1", node_b) + multi_node.on("out2", node_c) flow = Flow(multi_node) result = await flow.run(memory) - expected = { + expected_triggered = { "out1": [ - { # Results from Multi triggering out1 - DEFAULT_ACTION: [] # Terminal node B + { + 'order': str(node_b._node_order), + 'type': node_b.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # Node B is terminal } ], "out2": [ - { # Results from Multi triggering out2 - DEFAULT_ACTION: [] # Terminal node C + { + 'order': str(node_c._node_order), + 'type': node_c.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # Node C is terminal } ] } - # Verify structure without being sensitive to key order - assert set(result.keys()) == {"out1", "out2"} - assert result["out1"] == expected["out1"] - assert result["out2"] == expected["out2"] + assert result['order'] == str(multi_node._node_order) + assert result['type'] == multi_node.__class__.__name__ + assert result['triggered'] is not None + assert set(result['triggered'].keys()) == {"out1", "out2"} + assert result['triggered']["out1"] == expected_triggered["out1"] + assert result['triggered']["out2"] == expected_triggered["out2"] + diff --git a/python/tests/test_memory.py b/python/tests/test_memory.py index ce09cdd..c4a71c1 100644 --- a/python/tests/test_memory.py +++ b/python/tests/test_memory.py @@ -3,26 +3,25 @@ class TestMemory: """Tests for the Memory class.""" - class TestInitialization: """Tests for Memory initialization.""" def test_initialize_with_global_store_only(self): """Should initialize with global store only.""" global_store = {"g1": "global1"} - memory = Memory.create(global_store) + memory = Memory(global_store) assert memory.g1 == "global1", "Should access global property" - assert memory.local == {}, "Local store should be empty" + assert memory.local == memory._local == {}, "Local store should be empty" def test_initialize_with_global_and_local_stores(self): """Should initialize with global and local stores.""" global_store = {"g1": "global1", "common": "global_common"} local_store = {"l1": "local1", "common": "local_common"} - memory = Memory.create(global_store, local_store) + memory = Memory(global_store, local_store) assert memory.g1 == "global1", "Should access global property" assert memory.l1 == "local1", "Should access local property" assert memory.common == "local_common", "Local should shadow global" - assert memory.local == {"l1": "local1", "common": "local_common"}, "Local store should contain initial local data" + assert memory.local == memory._local == {"l1": "local1", "common": "local_common"}, "Local store should contain initial local data" class TestProxyBehaviorReading: """Tests for Memory proxy reading behavior.""" @@ -32,7 +31,7 @@ def memory(self): """Create a memory instance with both global and local stores.""" global_store = {"g1": "global1", "common": "global_common"} local_store = {"l1": "local1", "common": "local_common"} - return Memory.create(global_store, local_store) + return Memory(global_store, local_store) def test_read_from_local_store_first(self, memory): """Should read from local store first.""" @@ -43,14 +42,17 @@ def test_fall_back_to_global_store_if_not_in_local(self, memory): """Should fall back to global store if property not in local.""" assert memory.g1 == "global1" - def test_return_none_if_property_exists_in_neither_store(self, memory): - """Should raise AttributeError if property exists in neither store.""" - with pytest.raises(AttributeError, match="'Memory' object has no attribute 'non_existent'"): + def test_return_appropriate_error_if_property_exists_in_neither_store(self, memory): + """Should raise AttributeError for attribute access and KeyError for item access if property exists in neither store.""" + with pytest.raises(AttributeError, match="Key 'non_existent' not found in stores"): _ = memory.non_existent + with pytest.raises(KeyError, match="Key 'non_existent_item' not found in stores"): + _ = memory["non_existent_item"] + def test_correctly_access_the_local_property(self, memory): """Should correctly access the local property.""" - assert memory.local == {"l1": "local1", "common": "local_common"} + assert memory.local == memory._local == {"l1": "local1", "common": "local_common"} class TestProxyBehaviorWriting: """Tests for Memory proxy writing behavior.""" @@ -60,7 +62,7 @@ def memory(self): """Create a memory instance with both global and local stores.""" self.global_store = {"g1": "global1", "common": "global_common"} self.local_store = {"l1": "local1", "common": "local_common"} - return Memory.create(self.global_store, self.local_store) + return Memory(self.global_store, self.local_store) def test_write_property_to_global_store_by_default(self, memory): """Should write property to global store by default.""" @@ -83,14 +85,13 @@ def test_remove_property_from_local_store_when_writing_globally(self, memory): assert self.global_store["common"] == "updated_common_globally", "Global store should be updated" assert "common" not in self.local_store, "Property should be removed from local store" assert "common" not in memory.local, "Accessing via memory.local should also show removal" + assert "common" not in memory._local, "Accessing via memory._local should also show removal" def test_throw_error_when_attempting_to_set_reserved_properties(self, memory): """Should throw error when attempting to set reserved properties.""" with pytest.raises(Exception, match="Reserved property 'global' cannot be set"): - # Use the exact reserved name 'global' setattr(memory, 'global', {}) with pytest.raises(Exception, match="Reserved property 'local' cannot be set"): - # Use the exact reserved name 'local' setattr(memory, 'local', {}) with pytest.raises(Exception, match="Reserved property '_global' cannot be set"): memory._global = {} @@ -113,7 +114,7 @@ def memory_setup(self): "common": "local_common", "nested_l": {"val": 2} } - self.memory = Memory.create(self.global_store, self.local_store) + self.memory = Memory(self.global_store, self.local_store) return self.memory def test_create_new_memory_instance_with_shared_global_store(self, memory_setup): @@ -135,24 +136,20 @@ def test_create_deep_clone_of_local_store(self, memory_setup): cloned_memory = memory_setup.clone() # Verify local store is not shared by reference - assert cloned_memory.local is not memory_setup.local, "Local store reference should NOT be shared" - assert cloned_memory.local == self.local_store, "Cloned local store should have same values initially" - - # Modify local via original, check clone - memory_setup.local["l1"] = "modified_local_original" # Modify original's internal local store + assert (cloned_memory.local is not memory_setup.local) and (cloned_memory._local is not memory_setup._local), "Local store reference should NOT be shared" + assert cloned_memory.local == cloned_memory._local == self.local_store, "Cloned local store should have same values initially" - # Read from the clone. Since its local store is independent, it should still find 'l1' locally. + memory_setup.local["l1"] = "modified_local_original" assert cloned_memory.l1 == "local1", "Clone local property should be unaffected by original local changes" - assert cloned_memory.local["l1"] == "local1", "Clone local store internal value should be unchanged" + assert cloned_memory.local["l1"] == cloned_memory._local["l1"] == "local1", "Clone local store internal value should be unchanged" # Modify local via clone, check original cloned_memory.local["l2"] = "added_via_clone_local" # Accessing l2 on the original should raise AttributeError as it wasn't set globally or locally there - with pytest.raises(AttributeError, match="'Memory' object has no attribute 'l2'"): + with pytest.raises(AttributeError, match="Key 'l2' not found in stores"): _ = memory_setup.l2 assert "l2" not in memory_setup.local, "Original local store internal value should be unchanged" - # Test nested objects assert cloned_memory.nested_l == {"val": 2} memory_setup.local["nested_l"]["val"] = 99 assert cloned_memory.nested_l == {"val": 2}, "Nested local object in clone should be unaffected" @@ -168,25 +165,180 @@ def test_correctly_merge_forking_data_into_new_local_store(self, memory_setup): assert cloned_memory.g1 == "global1", "Should still access global property" assert cloned_memory.nested_f == {"val": 3} - # Check internal local store state assert cloned_memory.local == { "l1": "local1", - "common": "forked_common", # Overwritten + "common": "forked_common", "nested_l": {"val": 2}, - "f1": "forked1", # Added - "nested_f": {"val": 3}, # Added + "f1": "forked1", + "nested_f": {"val": 3}, } - # Ensure forkingData was deep cloned forking_data["nested_f"]["val"] = 99 assert cloned_memory.nested_f == {"val": 3}, "Nested object in forked data should have been deep cloned" def test_handle_empty_forking_data(self, memory_setup): """Should handle empty forkingData.""" cloned_memory = memory_setup.clone({}) - assert cloned_memory.local == self.local_store + assert cloned_memory.local == cloned_memory._local == self.local_store def test_handle_cloning_without_forking_data(self, memory_setup): """Should handle cloning without forkingData.""" cloned_memory = memory_setup.clone() - assert cloned_memory.local == self.local_store + assert cloned_memory.local == cloned_memory._local == self.local_store + +class TestMemoryDeletion: + """Tests for the new Memory deletion functionalities.""" + + @pytest.fixture + def memory_for_deletion(self): + """Fixture to create a Memory instance with global and local values for deletion tests.""" + global_store = {"g_only": "global_val", "common_gl": "global_common", "g_shadowed": "global_shadow"} + local_store = {"l_only": "local_val", "common_gl": "local_common", "g_shadowed": "local_shadow_val"} + return Memory(global_store, local_store) + + # Tests for del memory.attr and del memory[key] + def test_delattr_on_memory_deletes_global_only_key(self, memory_for_deletion): + """del memory.attr should delete a key present only in the global store.""" + assert "g_only" in memory_for_deletion + del memory_for_deletion.g_only + + assert "g_only" not in memory_for_deletion + assert "g_only" not in memory_for_deletion._global + assert "g_only" not in memory_for_deletion._local + with pytest.raises(AttributeError): + _ = memory_for_deletion.g_only + + def test_delitem_on_memory_deletes_global_only_key(self, memory_for_deletion): + """del memory[key] should delete a key present only in the global store.""" + assert "g_only" in memory_for_deletion + del memory_for_deletion["g_only"] + + assert "g_only" not in memory_for_deletion + assert "g_only" not in memory_for_deletion._global + assert "g_only" not in memory_for_deletion._local + with pytest.raises(KeyError): + _ = memory_for_deletion["g_only"] + + def test_delattr_on_memory_deletes_local_only_key(self, memory_for_deletion): + """del memory.attr should delete a key present only in the local store.""" + # Note: current _delete_from_stores(key, self._global, self._local) means + # it tries global first, then local. So "l_only" is only in local. + assert "l_only" in memory_for_deletion + del memory_for_deletion.l_only + + assert "l_only" not in memory_for_deletion + assert "l_only" not in memory_for_deletion._global + assert "l_only" not in memory_for_deletion._local + with pytest.raises(AttributeError): + _ = memory_for_deletion.l_only + + def test_delattr_on_memory_deletes_key_from_both_stores_if_present_in_both(self, memory_for_deletion): + """del memory.attr should delete a key from both global and local if it exists in both.""" + assert memory_for_deletion.common_gl == "local_common" + assert "common_gl" in memory_for_deletion._global + assert "common_gl" in memory_for_deletion._local + + del memory_for_deletion.common_gl + + assert "common_gl" not in memory_for_deletion + assert "common_gl" not in memory_for_deletion._global + assert "common_gl" not in memory_for_deletion._local + with pytest.raises(AttributeError): + _ = memory_for_deletion.common_gl + + def test_delattr_on_memory_raises_keyerror_for_non_existent_key(self, memory_for_deletion): + """del memory.attr should raise KeyError for a non-existent key.""" + # __delattr__ uses _delete_from_stores which raises KeyError + with pytest.raises(KeyError, match="'non_existent_attr'"): + del memory_for_deletion.non_existent_attr + + def test_delitem_on_memory_raises_keyerror_for_non_existent_key(self, memory_for_deletion): + """del memory[key] should raise KeyError for a non-existent key.""" + with pytest.raises(KeyError, match="'non_existent_key'"): + del memory_for_deletion["non_existent_key"] + + # Tests for del memory.local.attr and del memory.local[key] + def test_delattr_on_local_proxy_deletes_from_local_store_only(self, memory_for_deletion): + """del memory.local.attr should delete a key only from the local store.""" + assert memory_for_deletion.local.l_only == "local_val" + + del memory_for_deletion.local.l_only + + assert "l_only" not in memory_for_deletion.local + assert "l_only" not in memory_for_deletion._local + with pytest.raises(AttributeError): + _ = memory_for_deletion.local.l_only + + assert "l_only" not in memory_for_deletion._global + + def test_delitem_on_local_proxy_deletes_from_local_store_only(self, memory_for_deletion): + """del memory.local[key] should delete a key only from the local store.""" + assert memory_for_deletion.local["l_only"] == "local_val" + + del memory_for_deletion.local["l_only"] + + assert "l_only" not in memory_for_deletion.local + assert "l_only" not in memory_for_deletion._local + with pytest.raises(KeyError): + _ = memory_for_deletion.local["l_only"] + + def test_delattr_on_local_proxy_does_not_affect_global_store(self, memory_for_deletion): + """del memory.local.attr should not affect the global store, even if key has same name.""" + # 'g_shadowed' is in both: global_store{"g_shadowed": "global_shadow"}, local_store{"g_shadowed": "local_shadow_val"} + assert memory_for_deletion.local.g_shadowed == "local_shadow_val" + assert memory_for_deletion._global["g_shadowed"] == "global_shadow" + + del memory_for_deletion.local.g_shadowed + + assert "g_shadowed" not in memory_for_deletion.local # Deleted from local + assert "g_shadowed" not in memory_for_deletion._local + + # Global store should be untouched + assert memory_for_deletion._global["g_shadowed"] == "global_shadow" + # Main memory object should now see the global value + assert memory_for_deletion.g_shadowed == "global_shadow" + + def test_delattr_on_local_proxy_unshadows_global_key(self, memory_for_deletion): + """del memory.local.attr on a shadowing key should make the global key visible via memory.attr.""" + # 'g_shadowed' is in both, local value shadows global + assert memory_for_deletion.g_shadowed == "local_shadow_val" # Accesses local via memory object + + del memory_for_deletion.local.g_shadowed # Delete from local proxy + + # Now, memory.g_shadowed should access the global value + assert "g_shadowed" not in memory_for_deletion.local + assert memory_for_deletion.g_shadowed == "global_shadow" + assert memory_for_deletion._global["g_shadowed"] == "global_shadow" + + def test_delattr_on_local_proxy_raises_keyerror_for_non_existent_key(self, memory_for_deletion): + """del memory.local.attr should raise KeyError for a non-existent key in local store.""" + # __delattr__ on LocalProxy uses _delete_from_stores which raises KeyError + with pytest.raises(KeyError, match="'non_existent_local_attr'"): + del memory_for_deletion.local.non_existent_local_attr + + def test_delitem_on_local_proxy_raises_keyerror_for_non_existent_key(self, memory_for_deletion): + """del memory.local[key] should raise KeyError for a non-existent key in local store.""" + with pytest.raises(KeyError, match="'non_existent_local_key'"): + del memory_for_deletion.local["non_existent_local_key"] + + def test_contains_check_after_deletions_on_memory(self, memory_for_deletion): + """Verify `in` operator behaves correctly after deletions on Memory instance.""" + memory_for_deletion._global["g_delete_in_test"] = 1 + memory_for_deletion._local["l_delete_in_test"] = 2 + + assert "g_delete_in_test" in memory_for_deletion + assert "l_delete_in_test" in memory_for_deletion + + del memory_for_deletion.g_delete_in_test + assert "g_delete_in_test" not in memory_for_deletion + + del memory_for_deletion.l_delete_in_test + assert "l_delete_in_test" not in memory_for_deletion + + def test_contains_check_after_deletions_on_local_proxy(self, memory_for_deletion): + """Verify `in` operator behaves correctly for LocalProxy after deletions.""" + memory_for_deletion.local["proxy_del_test"] = 3 + assert "proxy_del_test" in memory_for_deletion.local + + del memory_for_deletion.local.proxy_del_test + assert "proxy_del_test" not in memory_for_deletion.local diff --git a/python/tests/test_node.py b/python/tests/test_node.py index b8d4312..d8dbbb8 100644 --- a/python/tests/test_node.py +++ b/python/tests/test_node.py @@ -58,7 +58,7 @@ class TestBaseNodeAndNode: def memory(self): """Create a test memory instance.""" global_store = {"initial": "global"} - return Memory.create(global_store) + return Memory(global_store) class TestLifecycleMethods: """Tests for node lifecycle methods (prep, exec, post).""" @@ -235,7 +235,7 @@ async def test_store_triggers_internally_via_trigger(self, memory): assert triggered_memory.key == "value" # Check forking_data applied locally assert triggered_memory.local["key"] == "value" # Original memory should not have 'key' - with pytest.raises(AttributeError, match="'Memory' object has no attribute 'key'"): + with pytest.raises(AttributeError, match="Key 'key' not found in stores"): _ = memory.key async def test_trigger_throws_error_if_called_outside_post(self, memory): @@ -337,15 +337,6 @@ async def test_run_with_propagate_true_returns_list_triggers_result(self, memory assert triggers[0][0] == "test_action" assert isinstance(triggers[0][1], Memory) - async def test_run_warns_if_called_on_node_with_successors(self, memory): - """run() should warn if called on a node with successors.""" - node_a = SimpleNode() - node_b = SimpleNode() - - node_a.next(node_b) - with pytest.warns(UserWarning, match="won't run successors"): - await node_a.run(memory) - async def test_run_accepts_global_store_directly(self): """run() should accept global store directly.""" node = SimpleNode() @@ -435,7 +426,7 @@ async def test_retry_exec_on_failure(self): """exec should be retried max_retries-1 times upon failure.""" node = ErrorNode(max_retries=3, succeed_after=2) # Succeed on 3rd attempt (after 2 failures) - result = await node.run(Memory.create({})) + result = await node.run(Memory({})) assert result == "success_after_retry" assert node.fail_count == 2 # Should have failed twice @@ -452,7 +443,7 @@ async def mock_fallback(prep_res, error): node.exec_fallback = mock_fallback - result = await node.run(Memory.create({})) + result = await node.run(Memory({})) assert result == "fallback_called" assert node.fail_count == 2 # Should have failed max_retries times @@ -469,7 +460,7 @@ async def mock_fallback(prep_res, error): node.exec_fallback = mock_fallback - result = await node.run(Memory.create({})) + result = await node.run(Memory({})) end_time = asyncio.get_event_loop().time() elapsed = end_time - start_time @@ -489,4 +480,4 @@ async def mock_fallback(prep_res, error): node.exec_fallback = mock_fallback with pytest.raises(ValueError, match="Fallback error"): - await node.run(Memory.create({})) + await node.run(Memory({})) diff --git a/python/tests/test_parallel_flow.py b/python/tests/test_parallel_flow.py index 064a46b..b975de8 100644 --- a/python/tests/test_parallel_flow.py +++ b/python/tests/test_parallel_flow.py @@ -2,7 +2,7 @@ import asyncio import time from unittest.mock import Mock, AsyncMock -from brainyflow import Memory, Node, Flow, ParallelFlow, DEFAULT_ACTION +from brainyflow import Memory, Node, Flow, ParallelFlow, DEFAULT_ACTION, BaseNode, ExecutionTree # Helper sleep function for async tests async def async_sleep(seconds: float): @@ -20,7 +20,6 @@ def __init__(self, id_str): self.next_node_delay = None async def prep(self, memory): - # Read delay from local memory (passed via forkingData) delay = getattr(memory, 'delay', 0) memory[f"prep_start_{self.id}_{getattr(memory, 'id', 'main')}"] = time.time() await self.prep_mock(memory) @@ -36,10 +35,10 @@ async def post(self, memory, prep_res, exec_res): memory[f"post_{self.id}_{getattr(memory, 'id', 'main')}"] = exec_res memory[f"prep_end_{self.id}_{getattr(memory, 'id', 'main')}"] = time.time() - # Trigger default successor, passing the intended delay for the *next* node if set if self.next_node_delay is not None: self.trigger(DEFAULT_ACTION, {"delay": self.next_node_delay, "id": getattr(memory, 'id', None)}) else: + # Even if no specific forking_data for next node, pass the current branch ID self.trigger(DEFAULT_ACTION, {"id": getattr(memory, 'id', None)}) class MultiTriggerNode(Node): @@ -63,102 +62,116 @@ class TestParallelFlow: @pytest.fixture def setup(self): """Create test nodes and memory.""" + BaseNode._next_id = 0 # Reset for predictable IDs global_store = {"initial": "global"} - memory = Memory.create(global_store) - trigger_node = MultiTriggerNode() - node_b = DelayedNode("B") - node_c = DelayedNode("C") - node_d = DelayedNode("D") # For testing sequential after parallel + memory_instance = Memory(global_store) + trigger_node_instance = MultiTriggerNode() # id 0 + node_b_instance = DelayedNode("B") # id 1 + node_c_instance = DelayedNode("C") # id 2 + node_d_instance = DelayedNode("D") # id 3 return { - "memory": memory, + "memory": memory_instance, "global_store": global_store, - "trigger_node": trigger_node, - "node_b": node_b, - "node_c": node_c, - "node_d": node_d + "trigger_node": trigger_node_instance, + "node_b": node_b_instance, + "node_c": node_c_instance, + "node_d": node_d_instance } @pytest.mark.asyncio async def test_execute_triggered_branches_concurrently(self, setup): """Should execute triggered branches concurrently using run_tasks override.""" - delay_b = 0.05 # 50ms - delay_c = 0.06 # 60ms + delay_b = 0.05 + delay_c = 0.06 - # Setup: TriggerNode fans out to B and C with different delays using distinct actions - setup["trigger_node"].add_trigger("process_b", {"id": "B", "delay": delay_b}) - setup["trigger_node"].add_trigger("process_c", {"id": "C", "delay": delay_c}) - setup["trigger_node"].on("process_b", setup["node_b"]) - setup["trigger_node"].on("process_c", setup["node_c"]) + trigger_node = setup["trigger_node"] + node_b = setup["node_b"] + node_c = setup["node_c"] + + trigger_node.add_trigger("process_b", {"id": "B", "delay": delay_b}) + trigger_node.add_trigger("process_c", {"id": "C", "delay": delay_c}) + trigger_node.on("process_b", node_b) + trigger_node.on("process_c", node_c) - parallel_flow = ParallelFlow(setup["trigger_node"]) + parallel_flow = ParallelFlow(trigger_node) start_time = time.time() result = await parallel_flow.run(setup["memory"]) end_time = time.time() duration = end_time - start_time - # --- Assertions --- - - # 1. Check total duration: Should be closer to max(delay_b, delay_c) than sum(delay_b, delay_c) max_delay = max(delay_b, delay_c) sum_delay = delay_b + delay_c print(f"Execution Time: {duration}s (Max Delay: {max_delay}s, Sum Delay: {sum_delay}s)") - # Allow up to 100 ms of overhead and fuzzy‐match against max delay - assert duration < sum_delay + 0.1, f"Duration ({duration}s) should be less than sum ({sum_delay}s) with overhead" + assert duration < sum_delay + 0.1 assert duration == pytest.approx(max_delay, abs=0.1) - - # 2. Check if both nodes executed (via post-execution memory state) + assert setup["memory"][f"post_B_B"] == f"exec_B_slept_{delay_b}" assert setup["memory"][f"post_C_C"] == f"exec_C_slept_{delay_c}" - # 3. Check the aggregated result structure - assert result and isinstance(result, dict), "Result should be a dictionary" - assert "process_b" in result, "Result should contain 'process_b' key" - assert "process_c" in result, "Result should contain 'process_c' key" + assert result and isinstance(result, dict), "Result should be a dictionary (ExecutionTree)" + assert result['order'] == str(trigger_node._node_order) + assert result['type'] == trigger_node.__class__.__name__ - process_b_results = result["process_b"] - process_c_results = result["process_c"] + triggered = result['triggered'] + assert triggered is not None + assert "process_b" in triggered, "Result should contain 'process_b' key in triggered" + assert "process_c" in triggered, "Result should contain 'process_c' key in triggered" - assert isinstance(process_b_results, list) and len(process_b_results) == 1, "'process_b' should be a list with 1 result" - assert isinstance(process_c_results, list) and len(process_c_results) == 1, "'process_c' should be a list with 1 result" + process_b_results_list = triggered["process_b"] + process_c_results_list = triggered["process_c"] + + assert isinstance(process_b_results_list, list) and len(process_b_results_list) == 1 + assert isinstance(process_c_results_list, list) and len(process_c_results_list) == 1 + + # Check structure of individual node logs + expected_log_b: ExecutionTree = { + 'order': str(node_b._node_order), + 'type': node_b.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # DelayedNode is terminal for this action + } + expected_log_c: ExecutionTree = { + 'order': str(node_c._node_order), + 'type': node_c.__class__.__name__, + 'triggered': {DEFAULT_ACTION: []} # DelayedNode is terminal for this action + } - # Check that both branches completed - assert process_b_results[0] == {DEFAULT_ACTION: []} - assert process_c_results[0] == {DEFAULT_ACTION: []} + assert process_b_results_list[0] == expected_log_b + assert process_c_results_list[0] == expected_log_c - # 4. Check total mock calls - assert setup["node_b"].exec_mock.call_count + setup["node_c"].exec_mock.call_count == 2, "Total exec calls across parallel nodes should be 2" + assert node_b.exec_mock.call_count + node_c.exec_mock.call_count == 2 @pytest.mark.asyncio async def test_handle_mix_of_parallel_and_sequential_execution(self, setup): """Should handle mix of parallel and sequential execution.""" - # A (MultiTrigger) -> [B (delay 50ms), C (delay 60ms)] -> D (delay 30ms) - delay_b = 0.05 # 50ms - delay_c = 0.06 # 60ms - delay_d = 0.03 # 30ms - - # Use distinct actions for parallel steps - setup["trigger_node"].add_trigger("parallel_b", {"id": "B", "delay": delay_b}) - setup["trigger_node"].add_trigger("parallel_c", {"id": "C", "delay": delay_c}) + delay_b = 0.05 + delay_c = 0.06 + delay_d = 0.03 + + trigger_node = setup["trigger_node"] + node_b = setup["node_b"] + node_c = setup["node_c"] + node_d = setup["node_d"] + + trigger_node.add_trigger("parallel_b", {"id": "B", "delay": delay_b}) + trigger_node.add_trigger("parallel_c", {"id": "C", "delay": delay_c}) - # Both parallel branches lead to D - setup["trigger_node"].on("parallel_b", setup["node_b"]) - setup["trigger_node"].on("parallel_c", setup["node_c"]) + trigger_node.on("parallel_b", node_b) + trigger_node.on("parallel_c", node_c) - setup["node_b"].next(setup["node_d"]) # B -> D - setup["node_c"].next(setup["node_d"]) # C -> D + node_b.next(node_d) + node_c.next(node_d) - # Set the delay that nodes B and C should pass to node D - setup["node_b"].next_node_delay = delay_d - setup["node_c"].next_node_delay = delay_d + node_b.next_node_delay = delay_d + node_c.next_node_delay = delay_d - parallel_flow = ParallelFlow(setup["trigger_node"]) + parallel_flow = ParallelFlow(trigger_node) start_time = time.time() - await parallel_flow.run(setup["memory"]) + result_mix = await parallel_flow.run(setup["memory"]) # Renamed result to result_mix end_time = time.time() duration = end_time - start_time @@ -166,16 +179,29 @@ async def test_handle_mix_of_parallel_and_sequential_execution(self, setup): print(f"Mixed Execution Time: {duration}s (Expected Min: ~{expected_min_duration}s)") - # Check completion assert setup["memory"][f"post_B_B"] == f"exec_B_slept_{delay_b}" assert setup["memory"][f"post_C_C"] == f"exec_C_slept_{delay_c}" - assert setup["memory"][f"post_D_B"] == f"exec_D_slept_{delay_d}" # D executed after B - assert setup["memory"][f"post_D_C"] == f"exec_D_slept_{delay_d}" # D executed after C + assert setup["memory"][f"post_D_B"] == f"exec_D_slept_{delay_d}" + assert setup["memory"][f"post_D_C"] == f"exec_D_slept_{delay_d}" - # Check timing: D should start only after its respective predecessor (B or C) finishes. - # The whole flow should take roughly max(delay_b, delay_c) + delay_d - assert duration >= expected_min_duration - 0.01, f"Duration ({duration}s) should be >= expected min ({expected_min_duration}s)" - assert duration < expected_min_duration + 0.1, f"Duration ({duration}s) should be reasonably close to expected min ({expected_min_duration}s)" + assert duration >= expected_min_duration - 0.02 # Allow small timing variance + assert duration < expected_min_duration + 0.1 + + assert node_d.exec_mock.call_count == 2 + + # Optionally, assert the structure of result_mix if needed + assert result_mix['order'] == str(trigger_node._node_order) + triggered_mix = result_mix['triggered'] + assert triggered_mix is not None - # Check D was executed twice (once for each incoming path) - assert setup["node_d"].exec_mock.call_count == 2 + path_b_log = triggered_mix['parallel_b'][0] + path_c_log = triggered_mix['parallel_c'][0] + + assert path_b_log['order'] == str(node_b._node_order) + assert path_b_log['triggered'][DEFAULT_ACTION][0]['order'] == str(node_d._node_order) + assert path_b_log['triggered'][DEFAULT_ACTION][0]['triggered'] == {DEFAULT_ACTION: []} + + assert path_c_log['order'] == str(node_c._node_order) + assert path_c_log['triggered'][DEFAULT_ACTION][0]['order'] == str(node_d._node_order) + assert path_c_log['triggered'][DEFAULT_ACTION][0]['triggered'] == {DEFAULT_ACTION: []} + diff --git a/typescript/brainyflow.ts b/typescript/brainyflow.ts index ba65607..ea0a9cc 100644 --- a/typescript/brainyflow.ts +++ b/typescript/brainyflow.ts @@ -1,66 +1,76 @@ export const DEFAULT_ACTION = 'default' as const -export type SharedStore = Record +export type SharedStore = Record type Action = string | typeof DEFAULT_ACTION type NestedActions = Record[]> +export type Memory = GlobalStore & + LocalStore & { + local: LocalStore + clone(forkingData?: T): Memory + _isMemoryObject: true + } + export type NodeError = Error & { retryCount?: number } -export class Memory { - constructor( - private __global: G, - private __local: L = {} as L, - ) {} +interface Trigger { + action: Action + forkingData: L +} - // Allow property access on this object to check local memory first, then global - [key: string]: any +function _get_from_stores(key: string | symbol, closer: SharedStore, further?: SharedStore): unknown { + if (key in closer) return Reflect.get(closer, key) + if (further && key in further) return Reflect.get(further, key) +} - clone(forkingData: T = {} as T): Memory { - return Memory.create(this.__global, { - ...structuredClone(this.__local), - ...structuredClone(forkingData), - }) - } +function _delete_from_stores(key: string | symbol, closer: SharedStore, further?: SharedStore): boolean { + // JS does not usually throw errors when key is not found. Otherwise uncomment: `if (!(key in closer) && (!further || !(key in further))) throw new Error(`Key '${key}' not found in store${further ? "s" : ""}`)` + let removed = false + if (key in closer) removed = Reflect.deleteProperty(closer, key) + if (further && key in further) removed = Reflect.deleteProperty(further, key) || removed + return removed +} - static create( - global: G, - local: L = {} as L, - ): Memory { - return new Proxy(new Memory(global, local), { - get: (target, prop) => { - // if (prop === 'setGlobal') return target.setGlobal.bind(target) - if (prop === 'clone') return target.clone.bind(target) - if (prop === 'local') return target.__local - - // Check local memory first, then fall back to global - if (prop in target.__local) { - return target.__local[prop as string] - } - return target.__global[prop as string] - }, - set: (target, prop, value) => { - if (['global', 'local', '__global', '__local'].includes(prop as string)) - throw new Error(`Reserved property '${String(prop)}' cannot be set`) - - // By default, set in global memory - if (typeof prop === 'string') { - delete target.__local[prop as string] - target.__global[prop as keyof G] = value - return true - } - // For internal properties, set on the target - ;(target as any)[prop] = value - return true - }, - }) +function createProxyHandler(closer: T, further?: SharedStore): ProxyHandler { + return { + get: (target, prop) => { + if (Reflect.has(target, prop)) return Reflect.get(target, prop) + return _get_from_stores(prop, closer, further) + }, + set: (target, prop, value) => { + if (target._isMemoryObject && prop in target) { + throw new Error(`Reserved property '${String(prop)}' cannot be set to ${target}`) + } + _delete_from_stores(prop, closer, further) + if (further) return Reflect.set(further, prop, value) + return Reflect.set(closer, prop, value) + }, + deleteProperty: (target, prop) => _delete_from_stores(prop, closer, further), + has: (target, prop) => Reflect.has(closer, prop) || (further ? Reflect.has(further, prop) : false), } } -interface Trigger { - action: Action - forkingData: L +export function createMemory( + global: GlobalStore, + local: LocalStore = {} as LocalStore, +): Memory { + const localProxy = new Proxy(local, createProxyHandler(local)) + const memory = new Proxy( + { + _isMemoryObject: true, + local: localProxy, + clone: (forkingData: T = {} as T): Memory => + createMemory(global, { + ...structuredClone(local), + ...structuredClone(forkingData), + }), + }, + createProxyHandler(localProxy, global), + ) + + return memory as Memory } export abstract class BaseNode< @@ -91,9 +101,7 @@ export abstract class BaseNode< cloned.successors.set( key, Symbol.iterator in value - ? value.map((node) => - node && typeof node.clone === 'function' ? node.clone(seen) : node, - ) + ? value.map((node) => (node && typeof node.clone === 'function' ? node.clone(seen) : node)) : value, ) } @@ -122,11 +130,7 @@ export abstract class BaseNode< async prep(memory: Memory): Promise {} async exec(prepRes: PrepResult | void): Promise {} - async post( - memory: Memory, - prepRes: PrepResult | void, - execRes: ExecResult | void, - ): Promise {} + async post(memory: Memory, prepRes: PrepResult | void, execRes: ExecResult | void): Promise {} /** * Trigger a child node with optional local memory @@ -144,9 +148,7 @@ export abstract class BaseNode< }) } - private listTriggers( - memory: Memory, - ): [AllowedActions[number], Memory][] { + private listTriggers(memory: Memory): [AllowedActions[number], Memory][] { if (!this.triggers.length) { return [[DEFAULT_ACTION, memory.clone()]] } @@ -154,19 +156,10 @@ export abstract class BaseNode< return this.triggers.map((t) => [t.action, memory.clone(t.forkingData)]) } - protected abstract execRunner( - memory: Memory, - prepRes: PrepResult | void, - ): Promise + protected abstract execRunner(memory: Memory, prepRes: PrepResult | void): Promise - async run( - memory: Memory | GlobalStore, - propagate?: false, - ): Promise> - async run( - memory: Memory | GlobalStore, - propagate: true, - ): Promise> + async run(memory: Memory | GlobalStore, propagate?: false): Promise> + async run(memory: Memory | GlobalStore, propagate: true): Promise> async run( memory: Memory | GlobalStore, propagate?: boolean, @@ -175,8 +168,9 @@ export abstract class BaseNode< console.warn("Node won't run successors. Use Flow!") } - const _memory: Memory = - memory instanceof Memory ? memory : Memory.create(memory) + const _memory: Memory = memory._isMemoryObject + ? (memory as Memory) + : createMemory(memory as GlobalStore) this.triggers = [] const prepRes = await this.prep(_memory) @@ -201,8 +195,8 @@ class RetryNode< ExecResult = any, > extends BaseNode { private curRetry = 0 - private maxRetries: number = 1 - private wait: number = 0 + private maxRetries = 1 + private wait = 0 constructor(options: { maxRetries?: number; wait?: number } = {}) { super() @@ -215,10 +209,7 @@ class RetryNode< throw error } - protected async execRunner( - memory: Memory, - prepRes: PrepResult, - ): Promise { + protected async execRunner(memory: Memory, prepRes: PrepResult): Promise { for (this.curRetry = 0; this.curRetry < this.maxRetries; this.curRetry++) { try { return await this.exec(prepRes) @@ -229,7 +220,6 @@ class RetryNode< } continue } - ;(error as NodeError).retryCount = this.curRetry return await this.execFallback(prepRes, error as NodeError) } @@ -240,15 +230,18 @@ class RetryNode< export const Node = RetryNode -export class Flow< - GlobalStore extends SharedStore = SharedStore, - AllowedActions extends Action[] = Action[], -> extends BaseNode> { +export class Flow extends BaseNode< + GlobalStore, + SharedStore, + AllowedActions, + void, + NestedActions +> { private visitCounts: Map = new Map() constructor( public start: BaseNode, - private options: { maxVisits: number } = { maxVisits: 5 }, + private options: { maxVisits: number } = { maxVisits: 15 }, ) { super() } @@ -257,9 +250,7 @@ export class Flow< throw new Error('This method should never be called in a Flow') } - protected async execRunner( - memory: Memory, - ): Promise> { + protected async execRunner(memory: Memory): Promise> { return await this.runNode(this.start, memory) } @@ -271,23 +262,15 @@ export class Flow< return res } - private async runNodes( - nodes: BaseNode[], - memory: Memory, - ): Promise[]> { + private async runNodes(nodes: BaseNode[], memory: Memory): Promise[]> { return await this.runTasks(nodes.map((node) => () => this.runNode(node, memory))) } - private async runNode( - node: BaseNode, - memory: Memory, - ): Promise> { + private async runNode(node: BaseNode, memory: Memory): Promise> { const nodeId = node.__nodeOrder.toString() const currentVisitCount = (this.visitCounts.get(nodeId) || 0) + 1 if (currentVisitCount > this.options.maxVisits) { - throw new Error( - `Maximum cycle count reached (${this.options.maxVisits}) for ${nodeId}.${node.constructor.name}`, - ) + throw new Error(`Maximum cycle count (${this.options.maxVisits}) reached for ${node.constructor.name}#${nodeId}`) } this.visitCounts.set(nodeId, currentVisitCount) @@ -299,12 +282,7 @@ export class Flow< const tasks = triggers.map(([action, nodeMemory]) => async () => { const nextNodes = clone.getNextNodes(action) - return [ - action, - !nextNodes.length - ? [] - : await this.runNodes(nextNodes, nodeMemory as Memory), - ] + return [action, !nextNodes.length ? [] : await this.runNodes(nextNodes, nodeMemory as Memory)] }) const tree = await this.runTasks(tasks) @@ -312,23 +290,18 @@ export class Flow< } } -export class ParallelFlow< - GlobalStore extends SharedStore = SharedStore, - AllowedActions extends Action[] = Action[], -> extends Flow { +export class ParallelFlow extends Flow< + GlobalStore, + AllowedActions +> { async runTasks(tasks: (() => T)[]): Promise[]> { return await Promise.all(tasks.map((task) => task())) } } -// Make classes available globally in the browser +// Make classes available globally in the browser for UMD bundle // @ts-ignore if (typeof window !== 'undefined' && !globalThis.brainyflow) { // @ts-ignore - globalThis.brainyflow = { - BaseNode, - Node, - Flow, - ParallelFlow, - } + globalThis.brainyflow = { Memory, BaseNode, Node, Flow, ParallelFlow } } diff --git a/typescript/design.md b/typescript/design.md index 5c07d10..93e2b67 100644 --- a/typescript/design.md +++ b/typescript/design.md @@ -33,7 +33,6 @@ The library is built around several key abstractions: - **Get:** Checks the `__local` store first. If the property is not found, it checks the `__global` store. Special properties like `clone` and `local` (accessing `__local`) are handled directly. - **Set:** Writes properties directly to the `__global` store by default after ensuring the property is removed from the `__local` store. Protects reserved property names (`global`, `local`, `__global`, `__local`). - **Cloning (`clone(forkingData?)`):** Creates a _new_ `Memory` instance wrapped in a `Proxy`. The `__global` store reference is shared, but the `__local` store is _deep-cloned_ using `structuredClone`. Optional `forkingData` is merged into the new local store using `structuredClone` as well. This is crucial for state isolation when branching in a `Flow`. -- **Creation (`Memory.create(global, local?)`):** A static factory method to instantiate `Memory` objects wrapped in the necessary `Proxy`. ### 3.2. `BaseNode` Abstract Class @@ -78,7 +77,7 @@ The library is built around several key abstractions: - **Purpose:** Orchestrates the execution of a graph of nodes sequentially, managing state and preventing infinite loops. - **Inheritance:** Extends `BaseNode`. - **Initialization:** - - `constructor(start, options?)`: Requires the starting `BaseNode` of the workflow and accepts optional `options` like `maxVisits` (default 5) for cycle detection. + - `constructor(start, options?)`: Requires the starting `BaseNode` of the workflow and accepts optional `options` like `maxVisits` (default 15) for cycle detection. - **Properties:** - `start`: The entry point node of the flow. - `visitCounts`: A `Map` to track how many times each node (identified by `__nodeOrder`) has been visited during a single `run` execution to detect cycles. diff --git a/typescript/package.json b/typescript/package.json index 7b4b6d0..3e7cf99 100644 --- a/typescript/package.json +++ b/typescript/package.json @@ -42,14 +42,14 @@ "release": "pnpm run build && changeset publish" }, "devDependencies": { - "@changesets/cli": "2.28.1", - "@std/assert": "npm:@jsr/std__assert@1.0.12", - "@types/node": "22.13.14", + "@changesets/cli": "2.29.4", + "@std/assert": "npm:@jsr/std__assert@1.0.13", + "@types/node": "22.15.21", "p-limit": "6.2.0", - "tsup": "8.4.0", - "typescript": "^5.8.2" + "tsup": "8.5.0", + "typescript": "5.8.3" }, - "packageManager": "pnpm@10.7.0", + "packageManager": "pnpm@10.11.0", "publishConfig": { "access": "public" } diff --git a/typescript/tests/design.md b/typescript/tests/design.md index 58ac0c9..68fc5bb 100644 --- a/typescript/tests/design.md +++ b/typescript/tests/design.md @@ -20,8 +20,6 @@ This document outlines the testing strategy for the `brainyflow.ts` library usin ### 3.1. `Memory` Class -- **Initialization:** - - `Memory.create()` correctly initializes global and optional local stores. - **Proxy Behavior (Reading):** - Reads property from local store if present. - Falls back to global store if property not in local store. diff --git a/typescript/tests/flow.test.ts b/typescript/tests/flow.test.ts index 1fc1f4b..6012790 100644 --- a/typescript/tests/flow.test.ts +++ b/typescript/tests/flow.test.ts @@ -1,6 +1,6 @@ import assert from 'node:assert/strict' import { beforeEach, describe, it, mock } from 'node:test' -import { DEFAULT_ACTION, Flow, Memory, Node } from '../brainyflow' +import { createMemory, DEFAULT_ACTION, Flow, Memory, Node } from '../brainyflow' // --- Helper Nodes --- class TestNode extends Node { @@ -54,7 +54,7 @@ describe('Flow Class', () => { beforeEach(() => { globalStore = { initial: 'global' } - memory = Memory.create(globalStore) + memory = createMemory(globalStore) nodeA = new TestNode('A') nodeB = new TestNode('B') nodeC = new TestNode('C') @@ -67,7 +67,7 @@ describe('Flow Class', () => { it('should store the start node and default options', () => { const flow = new Flow(nodeA) assert.strictEqual(flow.start, nodeA) - assert.deepStrictEqual((flow as any).options, { maxVisits: 5 }) + assert.deepStrictEqual((flow as any).options, { maxVisits: 15 }) }) it('should accept custom options', () => { @@ -126,7 +126,7 @@ describe('Flow Class', () => { // Test path B branchingNode.setTrigger('path_B') let flowB = new Flow(branchingNode) - let memoryB = Memory.create({}) + let memoryB = createMemory({}) await flowB.run(memoryB) assert.equal(memoryB.post_Branch, true) assert.equal(memoryB.post_B, true) @@ -135,7 +135,7 @@ describe('Flow Class', () => { // Test path C branchingNode.setTrigger('path_C') // Reset trigger let flowC = new Flow(branchingNode) // Recreate flow to reset visits - let memoryC = Memory.create({}) + let memoryC = createMemory({}) await flowC.run(memoryC) assert.equal(memoryC.post_Branch, true) assert.strictEqual(memoryC.post_B, undefined) @@ -176,9 +176,12 @@ describe('Flow Class', () => { }) // Trigger B with specific local data - branchingNode.setTrigger('path_B', { local_data: 'for_B', common_local: 'common' }) + branchingNode.setTrigger('path_B', { + local_data: 'for_B', + common_local: 'common', + }) let flowB = new Flow(branchingNode) - let memoryB = Memory.create({ global_val: 1 }) + let memoryB = createMemory({ global_val: 1 }) await flowB.run(memoryB) assert.equal(nodeB.prepMock.mock.calls.length, 1) assert.equal(nodeC.prepMock.mock.calls.length, 0) @@ -186,9 +189,12 @@ describe('Flow Class', () => { assert.strictEqual(memoryB.common_local, undefined) // Trigger C with different local data - branchingNode.setTrigger('path_C', { local_data: 'for_C', common_local: 'common' }) + branchingNode.setTrigger('path_C', { + local_data: 'for_C', + common_local: 'common', + }) let flowC = new Flow(branchingNode) // Recreate flow - let memoryC = Memory.create({ global_val: 1 }) + let memoryC = createMemory({ global_val: 1 }) await flowC.run(memoryC) assert.equal(nodeB.prepMock.mock.calls.length, 1) // Called once from previous run assert.equal(nodeC.prepMock.mock.calls.length, 1) // Called once now @@ -210,7 +216,7 @@ describe('Flow Class', () => { const flow = new Flow(nodeA, { maxVisits: maxVisitsAllowed }) // Use a fresh memory object for this specific test's state - const loopMemory = Memory.create<{ count?: number }>({}) + const loopMemory = createMemory<{ count?: number }>({}) // Expect rejection when the (maxVisits + 1)th execution is attempted await assert.rejects( async () => { @@ -218,42 +224,30 @@ describe('Flow Class', () => { await flow.run(loopMemory) // Run with the dedicated memory } catch (e) { // Assert state *inside* the catch block before re-throwing - assert.equal( - loopCount, - maxVisitsAllowed, - `Node should have executed exactly ${maxVisitsAllowed} times before error`, - ) - assert.equal( - loopMemory.count, - maxVisitsAllowed, - `Memory count should be ${maxVisitsAllowed} before error`, - ) + assert.equal(loopCount, maxVisitsAllowed, `Node should have executed exactly ${maxVisitsAllowed} times before error`) + assert.equal(loopMemory.count, maxVisitsAllowed, `Memory count should be ${maxVisitsAllowed} before error`) throw e // Re-throw for assert.rejects to catch } // If it doesn't throw (which it should), fail the test explicitly assert.fail('Flow should have rejected due to cycle limit, but did not.') }, - new RegExp(`Maximum cycle count reached \\(${maxVisitsAllowed}\\)`), + new RegExp(`Maximum cycle count \\(${maxVisitsAllowed}\\) reached`), 'Flow should reject when loop count exceeds maxVisits', ) // Final check on loopCount after rejection is confirmed - assert.equal( - loopCount, - maxVisitsAllowed, - `Node should have executed exactly ${maxVisitsAllowed} times (final check)`, - ) + assert.equal(loopCount, maxVisitsAllowed, `Node should have executed exactly ${maxVisitsAllowed} times (final check)`) }) it('should throw error immediately if loop exceeds maxVisits (e.g., maxVisits=2)', async () => { nodeA.next(nodeA) // A -> A loop const maxVisitsAllowed = 2 const flow = new Flow(nodeA, { maxVisits: maxVisitsAllowed }) - const loopMemory = Memory.create<{ count?: number }>({}) // Fresh memory + const loopMemory = createMemory<{ count?: number }>({}) // Fresh memory await assert.rejects( flow.run(loopMemory), - new RegExp(`Maximum cycle count reached \\(${maxVisitsAllowed}\\)`), // Check error message + new RegExp(`Maximum cycle count \\(${maxVisitsAllowed}\\) reached`), // Check error message 'Flow should reject when loop count exceeds maxVisits (maxVisits=2)', ) }) @@ -337,7 +331,7 @@ describe('Flow Class', () => { // Trigger path B branchingNode.setTrigger('path_B') let flowB = new Flow(branchingNode) - let resultB = await flowB.run(Memory.create({})) + let resultB = await flowB.run(createMemory({})) const expectedB = { path_B: [ @@ -357,7 +351,7 @@ describe('Flow Class', () => { // Trigger path C branchingNode.setTrigger('path_C') let flowC = new Flow(branchingNode) // Reset flow/visits - let resultC = await flowC.run(Memory.create({})) + let resultC = await flowC.run(createMemory({})) const expectedC = { path_C: [ diff --git a/typescript/tests/memory.test.ts b/typescript/tests/memory.test.ts index efa4295..1cb3736 100644 --- a/typescript/tests/memory.test.ts +++ b/typescript/tests/memory.test.ts @@ -1,12 +1,12 @@ import assert from 'node:assert/strict' import { beforeEach, describe, it } from 'node:test' // Import beforeEach -import { Memory } from '../brainyflow' +import { createMemory, Memory } from '../brainyflow' describe('Memory Class', () => { describe('Initialization', () => { it('should initialize with global store only', () => { const global = { g1: 'global1' } - const memory = Memory.create(global) + const memory = createMemory(global) assert.equal(memory.g1, 'global1', 'Should access global property') assert.deepStrictEqual(memory.local, {}, 'Local store should be empty') }) @@ -14,23 +14,19 @@ describe('Memory Class', () => { it('should initialize with global and local stores', () => { const global = { g1: 'global1', common: 'global_common' } const local = { l1: 'local1', common: 'local_common' } - const memory = Memory.create(global, local) + const memory = createMemory(global, local) assert.equal(memory.g1, 'global1', 'Should access global property') assert.equal(memory.l1, 'local1', 'Should access local property') assert.equal(memory.common, 'local_common', 'Local should shadow global') - assert.deepStrictEqual( - memory.local, - { l1: 'local1', common: 'local_common' }, - 'Local store should contain initial local data', - ) + assert.deepStrictEqual(memory.local, { l1: 'local1', common: 'local_common' }, 'Local store should contain initial local data') }) }) describe('Proxy Behavior (Reading)', () => { const global = { g1: 'global1', common: 'global_common' } const local = { l1: 'local1', common: 'local_common' } - const memory = Memory.create(global, local) + const memory = createMemory(global, local) it('should read from local store first', () => { assert.equal(memory.l1, 'local1') @@ -46,7 +42,10 @@ describe('Memory Class', () => { }) it('should correctly access the local property', () => { - assert.deepStrictEqual(memory.local, { l1: 'local1', common: 'local_common' }) + assert.deepStrictEqual(memory.local, { + l1: 'local1', + common: 'local_common', + }) }) }) @@ -58,7 +57,7 @@ describe('Memory Class', () => { beforeEach(() => { global = { g1: 'global1', common: 'global_common' } local = { l1: 'local1', common: 'local_common' } - memory = Memory.create(global, local) + memory = createMemory(global, local) }) it('should write property to global store by default', () => { @@ -80,17 +79,13 @@ describe('Memory Class', () => { assert.equal(memory.common, 'updated_common_globally', 'Should read the new global value') assert.equal(global.common, 'updated_common_globally', 'Global store should be updated') assert.strictEqual(local.common, undefined, 'Property should be removed from local store') - assert.strictEqual( - memory.local.common, - undefined, - 'Accessing via memory.local should also show removal', - ) + assert.strictEqual(memory.local.common, undefined, 'Accessing via memory.local should also show removal') }) it('should throw error when attempting to set reserved properties', () => { - assert.throws(() => (memory.global = {}), /Reserved property 'global' cannot be set/) + assert.throws(() => (memory._isMemoryObject = {}), /Reserved property '_isMemoryObject' cannot be set/) assert.throws(() => (memory.local = {}), /Reserved property 'local' cannot be set/) - // Cannot test setting __global and __local directly as they are private + assert.throws(() => (memory.clone = {}), /Reserved property 'clone' cannot be set/) }) }) @@ -103,10 +98,10 @@ describe('Memory Class', () => { beforeEach(() => { global = { g1: 'global1', common: 'global_common', nestedG: { val: 1 } } local = { l1: 'local1', common: 'local_common', nestedL: { val: 2 } } - memory = Memory.create(global, local) + memory = createMemory(global, local) }) - it('should create a new Memory instance with shared global store reference', () => { + it('should create a createMemory instance with shared global store reference', () => { clonedMemory = memory.clone() assert.notStrictEqual(clonedMemory, memory, 'Cloned memory should be a new instance') @@ -123,11 +118,7 @@ describe('Memory Class', () => { it('should create a deep clone of the local store', () => { clonedMemory = memory.clone() // Verify local store is not shared by reference - assert.notStrictEqual( - clonedMemory.local, - memory.local, - 'Local store reference should NOT be shared', - ) + assert.notStrictEqual(clonedMemory.local, memory.local, 'Local store reference should NOT be shared') assert.deepStrictEqual( clonedMemory.local, local, // Check initial values are the same @@ -150,37 +141,25 @@ describe('Memory Class', () => { // Modify local via clone, check original clonedMemory.local.l2 = 'added_via_clone_local' - assert.strictEqual( - memory.l2, - undefined, - 'Original should not see local changes from clone (reads undefined)', - ) - assert.strictEqual( - memory.local.l2, - undefined, - 'Original local store internal value should be unchanged', - ) + assert.strictEqual(memory.l2, undefined, 'Original should not see local changes from clone (reads undefined)') + assert.strictEqual(memory.local.l2, undefined, 'Original local store internal value should be unchanged') // Test nested objects assert.deepStrictEqual(clonedMemory.nestedL, { val: 2 }) memory.local.nestedL.val = 99 - assert.deepStrictEqual( - clonedMemory.nestedL, - { val: 2 }, - 'Nested local object in clone should be unaffected', - ) + assert.deepStrictEqual(clonedMemory.nestedL, { val: 2 }, 'Nested local object in clone should be unaffected') }) it('should correctly merge forkingData into the new local store', () => { - const forkingData = { f1: 'forked1', common: 'forked_common', nestedF: { val: 3 } } + const forkingData = { + f1: 'forked1', + common: 'forked_common', + nestedF: { val: 3 }, + } clonedMemory = memory.clone(forkingData) assert.equal(clonedMemory.f1, 'forked1', 'Should access forked property') - assert.equal( - clonedMemory.common, - 'forked_common', - 'Forked data should shadow original local and global', - ) + assert.equal(clonedMemory.common, 'forked_common', 'Forked data should shadow original local and global') assert.equal(clonedMemory.l1, 'local1', 'Should still access original local property') assert.equal(clonedMemory.g1, 'global1', 'Should still access global property') assert.deepStrictEqual(clonedMemory.nestedF, { val: 3 }) @@ -196,11 +175,7 @@ describe('Memory Class', () => { // Ensure forkingData was deep cloned forkingData.nestedF.val = 99 - assert.deepStrictEqual( - clonedMemory.nestedF, - { val: 3 }, - 'Nested object in forked data should have been deep cloned', - ) + assert.deepStrictEqual(clonedMemory.nestedF, { val: 3 }, 'Nested object in forked data should have been deep cloned') }) it('should handle empty forkingData', () => { diff --git a/typescript/tests/node.test.ts b/typescript/tests/node.test.ts index b2fe910..9a32dc1 100644 --- a/typescript/tests/node.test.ts +++ b/typescript/tests/node.test.ts @@ -1,6 +1,6 @@ import assert from 'node:assert/strict' import { afterEach, beforeEach, describe, it, mock } from 'node:test' -import { BaseNode, DEFAULT_ACTION, Memory, Node, NodeError } from '../brainyflow' +import { BaseNode, createMemory, DEFAULT_ACTION, Memory, Node, NodeError } from '../brainyflow' // Helper sleep function const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) @@ -71,7 +71,7 @@ describe('BaseNode & Node', () => { beforeEach(() => { globalStore = { initial: 'global' } - memory = Memory.create(globalStore) + memory = createMemory(globalStore) // Reset mocks for SimpleNode if necessary (though node:test often isolates) mock.reset() // Reset all mocks globally for safety }) @@ -191,7 +191,7 @@ describe('BaseNode & Node', () => { assert.equal(triggers.length, 1) const [action, triggeredMemory] = triggers[0] assert.equal(action, 'my_action') - assert.ok(triggeredMemory instanceof Memory) + assert.ok(triggeredMemory._isMemoryObject) assert.equal(triggeredMemory.key, 'value') // Check forkingData applied locally assert.equal(triggeredMemory.local.key, 'value') assert.strictEqual(memory.key, undefined) // Original memory unaffected @@ -211,7 +211,7 @@ describe('BaseNode & Node', () => { assert.equal(triggers.length, 1) const [action, triggeredMemory] = triggers[0] assert.equal(action, DEFAULT_ACTION) - assert.ok(triggeredMemory instanceof Memory) + assert.ok(triggeredMemory._isMemoryObject) assert.notStrictEqual(triggeredMemory, memory) // Should be a clone assert.deepStrictEqual(triggeredMemory.local, {}) // No forking data }) @@ -259,7 +259,7 @@ describe('BaseNode & Node', () => { const triggers = await node.run(memory, true) assert.equal(triggers.length, 1) assert.equal(triggers[0][0], 'test_action') - assert.ok(triggers[0][1] instanceof Memory) + assert.ok(triggers[0][1]._isMemoryObject) }) it('run() should warn if called on a node with successors', async () => { @@ -268,11 +268,7 @@ describe('BaseNode & Node', () => { nodeA.next(nodeB) const warnMock = mock.method(console, 'warn', () => {}) await nodeA.run(memory) - assert.equal( - warnMock.mock.calls.length, - 1, - 'Expected a warning when running a node that has successors', - ) + assert.equal(warnMock.mock.calls.length, 1, 'Expected a warning when running a node that has successors') warnMock.mock.restore() // assert.match(warnMock.mock.calls[0].arguments[0], /Node won't run successors. Use Flow!/); // warnMock.mock.restore(); @@ -287,7 +283,7 @@ describe('BaseNode & Node', () => { assert.equal(node.prep.mock.calls.length, 1) const memoryArg = node.prep.mock.calls[0].arguments[0] as Memory - assert.ok(memoryArg instanceof Memory) + assert.ok(memoryArg._isMemoryObject) assert.equal(memoryArg.initial, 'global_val') // Check property from the passed global assert.equal(memoryArg.count, 5) assert.deepStrictEqual(memoryArg.local, {}) @@ -395,7 +391,11 @@ describe('BaseNode & Node', () => { }) it('should wait between retries if wait > 0', async () => { - const node = new ErrorNode({ maxRetries: 3, wait: 0.05, succeedAfter: 1 }) // Succeed on 2nd attempt + const node = new ErrorNode({ + maxRetries: 3, + wait: 0.05, + succeedAfter: 1, + }) // Succeed on 2nd attempt const startTime = Date.now() const result = await node.run(memory) const endTime = Date.now() diff --git a/typescript/tests/parallelFlow.test.ts b/typescript/tests/parallelFlow.test.ts index 5db9513..67856ab 100644 --- a/typescript/tests/parallelFlow.test.ts +++ b/typescript/tests/parallelFlow.test.ts @@ -1,6 +1,6 @@ import assert from 'node:assert/strict' import { beforeEach, describe, it, mock } from 'node:test' -import { DEFAULT_ACTION, Memory, Node, ParallelFlow } from '../brainyflow' +import { createMemory, DEFAULT_ACTION, Memory, Node, ParallelFlow } from '../brainyflow' // Helper sleep function const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) @@ -68,7 +68,7 @@ describe('ParallelFlow Class', () => { beforeEach(() => { globalStore = { initial: 'global' } - memory = Memory.create(globalStore) + memory = createMemory(globalStore) triggerNode = new MultiTriggerNode() nodeB = new DelayedNode('B') nodeC = new DelayedNode('C') @@ -98,13 +98,8 @@ describe('ParallelFlow Class', () => { // 1. Check total duration: Should be closer to max(delayB, delayC) than sum(delayB, delayC) const maxDelay = Math.max(delayB, delayC) const sumDelay = delayB + delayC - console.log( - `Execution Time: ${duration}ms (Max Delay: ${maxDelay}ms, Sum Delay: ${sumDelay}ms)`, - ) - assert.ok( - duration < sumDelay - 10, - `Duration (${duration}ms) should be significantly less than sum (${sumDelay}ms)`, - ) + console.log(`Execution Time: ${duration}ms (Max Delay: ${maxDelay}ms, Sum Delay: ${sumDelay}ms)`) + assert.ok(duration < sumDelay - 10, `Duration (${duration}ms) should be significantly less than sum (${sumDelay}ms)`) assert.ok( duration >= maxDelay - 5 && duration < maxDelay + 50, // Allow buffer for overhead `Duration (${duration}ms) should be close to max delay (${maxDelay}ms)`, @@ -120,25 +115,15 @@ describe('ParallelFlow Class', () => { assert.ok('process_c' in result, "Result should contain 'process_c' key") const processB_Results = result.process_b const processC_Results = result.process_c - assert.ok( - Array.isArray(processB_Results) && processB_Results.length === 1, - "'process_b' should be an array with 1 result", - ) - assert.ok( - Array.isArray(processC_Results) && processC_Results.length === 1, - "'process_c' should be an array with 1 result", - ) + assert.ok(Array.isArray(processB_Results) && processB_Results.length === 1, "'process_b' should be an array with 1 result") + assert.ok(Array.isArray(processC_Results) && processC_Results.length === 1, "'process_c' should be an array with 1 result") // Check that both branches completed (results are empty objects as DelayedNode has no successors) assert.deepStrictEqual(processB_Results[0], { [DEFAULT_ACTION]: [] }) assert.deepStrictEqual(processC_Results[0], { [DEFAULT_ACTION]: [] }) // 4. Check total mock calls - assert.equal( - nodeB.execMock.mock.calls.length + nodeC.execMock.mock.calls.length, - 2, - 'Total exec calls across parallel nodes should be 2', - ) + assert.equal(nodeB.execMock.mock.calls.length + nodeC.execMock.mock.calls.length, 2, 'Total exec calls across parallel nodes should be 2') }) it('should handle mix of parallel and sequential execution', async () => { @@ -181,10 +166,7 @@ describe('ParallelFlow Class', () => { // Check timing: D should start only after its respective predecessor (B or C) finishes. // The whole flow should take roughly max(delayB, delayC) + delayD - assert.ok( - duration >= expectedMinDuration - 10, - `Duration (${duration}ms) should be >= expected min (${expectedMinDuration}ms)`, - ) + assert.ok(duration >= expectedMinDuration - 10, `Duration (${duration}ms) should be >= expected min (${expectedMinDuration}ms)`) assert.ok( duration < expectedMinDuration + 100, // Allow generous buffer for overhead `Duration (${duration}ms) should be reasonably close to expected min (${expectedMinDuration}ms)`,