diff --git a/src/rules_engine/__init__.py b/src/rules_engine/__init__.py index aa9936e..7ac154e 100644 --- a/src/rules_engine/__init__.py +++ b/src/rules_engine/__init__.py @@ -1,8 +1,15 @@ from typing import Any, Callable, TypeVar, Optional +from dataclasses import dataclass T = TypeVar('T') +@dataclass +class Result: + value: Any + message: Optional[str] + + class Rule: def __init__( self, @@ -29,14 +36,21 @@ class RulesEngine: def __init__(self, *rules: Rule) -> None: self.rules = rules + def _get_message(self, rule) -> Optional[str]: + if rule.message: + return rule.message + return rule.condition.__name__ if rule.condition.__name__ != "" else None + def run(self, *args: Any, **kwargs: Any) -> Any: for rule in self.rules: if rule.condition(*args, **kwargs): - return rule.action(*args, **kwargs, message=rule.message) + return Result(value=rule.action(*args, **kwargs), message=self._get_message(rule)) + + return Result(value=None, message="No conditions matched") def run_all(self, *args: Any, **kwargs: Any) -> list: return [ - rule.action(*args, **kwargs, message=rule.message) + Result(value=rule.action(*args, **kwargs), message=self._get_message(rule)) for rule in self.rules if rule.condition(*args, **kwargs) ] diff --git a/tests/test_article_completed_example.py b/tests/test_article_completed_example.py index 5d20843..2bcddb8 100644 --- a/tests/test_article_completed_example.py +++ b/tests/test_article_completed_example.py @@ -2,7 +2,7 @@ import pytest -from src.rules_engine import Otherwise, Rule, RulesEngine, not_, then +from src.rules_engine import Otherwise, Rule, RulesEngine, not_, then, Result Article = namedtuple("Article", "title price image_url stock") @@ -19,68 +19,85 @@ def article_image_missing(article): return not article.image_url -def return_false_and_message(article, message): - return False, message - - @pytest.mark.parametrize( - "article, result", + "article, expected_result, message", [ ( Article( title="Iphone Case", price=1000, image_url="http://localhost/image", stock=None ), - (False, "article stock missing"), + False, + "article stock missing", ), ( Article(title="Iphone Case", price=None, image_url="http://image", stock=10), False, + "article_price_missing", ), ( Article(title="Iphone Case", price=1000, image_url="", stock=10), False, + "article_image_missing", ), ( Article(title="Iphone Case", price=1000, image_url="http://image", stock=10), True, + None, ), ], ) -def test_article_complete_rules(article, result): - assert result == RulesEngine( - Rule(article_stock_missing, return_false_and_message, message="article stock missing"), +def test_article_complete_rules(article, expected_result, message): + result = RulesEngine( + Rule(article_stock_missing, then(False), message="article stock missing"), Rule(article_price_missing, then(False)), Rule(article_image_missing, then(False)), Otherwise(then(True)), ).run(article) + result.value = expected_result + result.message = message + @pytest.mark.parametrize( - "article, result", + "article, expected_result", [ ( Article( title="Iphone Case", price=1000, image_url="http://localhost/image", stock=None ), - ["B", "C"], + [ + Result(value='B', message='article price missing'), + Result(value='C', message='article image missing'), + ], ), ( Article(title="Iphone Case", price=None, image_url="http://image", stock=10), - ["A", "C"], + [ + Result(value='A', message='article stock missing'), + Result(value='C', message='article image missing'), + ], ), ( Article(title="Iphone Case", price=1000, image_url="", stock=10), - ["A", "B"], + [ + Result(value='A', message='article stock missing'), + Result(value='B', message='article price missing'), + ], ), ( Article(title="Iphone Case", price=1000, image_url="http://image", stock=10), - ["A", "B", "C"], + [ + Result(value='A', message='article stock missing'), + Result(value='B', message='article price missing'), + Result(value='C', message='article image missing'), + ], ), ], ) -def test_article_complete_all_rules(article, result): - assert result == RulesEngine( - Rule(not_(article_stock_missing), then("A")), - Rule(not_(article_price_missing), then("B")), - Rule(not_(article_image_missing), then("C")), +def test_article_complete_all_rules(article, expected_result): + result = RulesEngine( + Rule(not_(article_stock_missing), then("A"), message="article stock missing"), + Rule(not_(article_price_missing), then("B"), message="article price missing"), + Rule(not_(article_image_missing), then("C"), message="article image missing"), ).run_all(article) + assert result == expected_result diff --git a/tests/test_operators.py b/tests/test_operators.py index d81098d..58d3d0c 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -3,7 +3,7 @@ from src.rules_engine import Rule, RulesEngine, all_, any_, not_, then, when -def raise_cannot_be_none_error(obj, message): +def raise_cannot_be_none_error(obj): raise ValueError("cannot be None error") @@ -13,7 +13,9 @@ def test_when_then_operator(): with pytest.raises(ValueError): RulesEngine(Rule(when(obj is None), raise_cannot_be_none_error)).run(obj) - assert RulesEngine(Rule(when(obj is not None), then(True))).run(obj) is None + result = RulesEngine(Rule(when(obj is not None), then(True), "obj is None")).run(obj) + assert result.value is None + assert result.message == "No conditions matched" @pytest.mark.parametrize( @@ -27,7 +29,7 @@ def test_when_then_operator(): def test_not_operator(condition, action, result): obj = None - assert RulesEngine(Rule(not_(when(condition)), then(action))).run(obj) is result + assert RulesEngine(Rule(not_(when(condition)), then(action))).run(obj).value is result @pytest.mark.parametrize( @@ -42,19 +44,21 @@ def test_not_operator(condition, action, result): def test_any_operator(conditions, action, result): obj = None - assert RulesEngine(Rule(any_(*conditions), then(action))).run(obj) is result + assert RulesEngine(Rule(any_(*conditions), then(action))).run(obj).value is result @pytest.mark.parametrize( - "conditions,action,result", + "conditions,action,value,message", [ - ([when(False), when(False), when(False)], "A", None), - ([when(True), when(False), when(False)], "A", None), - ([when(True), when(True), when(False)], "A", None), - ([when(True), when(True), when(True)], "A", "A"), + ([when(False), when(False), when(False)], "A", None, "No conditions matched"), + ([when(True), when(False), when(False)], "A", None, "No conditions matched"), + ([when(True), when(True), when(False)], "A", None, "No conditions matched"), + ([when(True), when(True), when(True)], "A", "A", None), ], ) -def test_all_operator(conditions, action, result): +def test_all_operator(conditions, action, value, message): obj = None - assert RulesEngine(Rule(all_(*conditions), then(action))).run(obj) is result + result = RulesEngine(Rule(all_(*conditions), then(action))).run(obj) + assert result.value == value + assert result.message == message