From 763e8693ae112fe82ca9b36fc1e587be5d3493e6 Mon Sep 17 00:00:00 2001 From: Raphael Ahrens Date: Sat, 29 Mar 2025 11:38:06 +0100 Subject: [PATCH] Refactor of the `serialize` and `to_serializable` function This commit tries to tackle #268 and rewrite the `serialize` function to handle less class specific cases and move it into the single dispatch function `to_serializable`. --- pytm/pytm.py | 62 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/pytm/pytm.py b/pytm/pytm.py index a711687..060c272 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -1278,7 +1278,6 @@ def get_table(db, klass): db.close() - class Controls: """Controls implemented by/on and Element""" @@ -2004,52 +2003,75 @@ def to_serializable(val): @to_serializable.register(TM) def ts_tm(obj): - return serialize(obj, nested=True) + result = serialize(obj, nested=True, ignore=( + "_sf", "_duplicate_ignored_attrs", "_threats", "_elements", "assumptions")) + result["elements"] = [e for e in obj._elements if isinstance(e, (Actor, Asset))] + result["assumptions"] = list(obj.assumptions) + return result @to_serializable.register(Controls) @to_serializable.register(Data) +@to_serializable.register(Finding) +def _(obj): + return serialize(obj, nested=False) + + @to_serializable.register(Threat) -@to_serializable.register(Element) +def _(obj): + result = serialize(obj, nested=False, ignore=["target"]) + result["target"] = [v.__name__ for v in obj.target] + return result + + @to_serializable.register(Finding) +def _(obj): + return serialize(obj, nested=False, ignore=["element"]) + + +@to_serializable.register(Element) def ts_element(obj): - return serialize(obj, nested=False) + result = serialize(obj, nested=False, ignore=("_is_drawn", "uuid", "levels", "sourceFiles", "assumptions", "findings")) + result["levels"] = list(obj.levels) + result["sourceFiles"] = list(obj.sourceFiles) + result["assumptions"] = list(obj.assumptions) + result["findings"] = [v.id for v in obj.findings] + return result -def serialize(obj, nested=False): +@to_serializable.register(Actor) +@to_serializable.register(Asset) +def _(obj): + # Note that we use the ts_element function defined for the Element class + result = ts_element(obj) + result["__class__"] = obj.__class__.__name__ + return result + + +def serialize(obj, nested=False, ignore=None): """Used if *obj* is an instance of TM, Element, Threat or Finding.""" klass = obj.__class__ result = {} - if isinstance(obj, (Actor, Asset)): - result["__class__"] = klass.__name__ + if ignore is None: + ignore = [] + for i in dir(obj): if ( i.startswith("__") or callable(getattr(klass, i, {})) - or ( - isinstance(obj, TM) - and i in ("_sf", "_duplicate_ignored_attrs", "_threats") - ) - or (isinstance(obj, Element) and i in ("_is_drawn", "uuid")) - or (isinstance(obj, Finding) and i == "element") + or i in ignore ): continue value = getattr(obj, i) - if isinstance(obj, TM) and i == "_elements": - value = [e for e in value if isinstance(e, (Actor, Asset))] if value is not None: if isinstance(value, (Element, Data)): value = value.name - elif isinstance(obj, Threat) and i == "target": - value = [v.__name__ for v in value] - elif i in ("levels", "sourceFiles", "assumptions"): - value = list(value) elif ( not nested and not isinstance(value, str) and isinstance(value, Iterable) ): - value = [v.id if isinstance(v, Finding) else v.name for v in value] + value = [v.name for v in value] result[i.lstrip("_")] = value return result