diff --git a/ccat_reranker.py b/ccat_reranker.py index 82d20a4..5cfe0c1 100644 --- a/ccat_reranker.py +++ b/ccat_reranker.py @@ -23,7 +23,8 @@ def after_cat_recalls_memories(cat) -> None: cat.working_memory['episodic_memories'] = recent_docs else: print("#HicSuntGattones") - + + #TODO refactor if settings["SBERT"]: model = CrossEncoder(settings["ranker"]) if cat.working_memory['declarative_memories']: @@ -35,6 +36,15 @@ def after_cat_recalls_memories(cat) -> None: cat.working_memory['declarative_memories'] = sbert_docs else: print("#HicSuntGattones") + else: + if cat.working_memory['declarative_memories']: + if settings["LITM"]: + litm_docs = litm(cat.working_memory['declarative_memories']) + cat.working_memory['declarative_memories'] = litm_docs + else: + print("#HicSuntGattones") + else: + print("#HicSuntGattones") if settings["FILTER"]: if cat.working_memory['procedural_memories']: