Skip to content

Commit

Permalink
fix: gate condition + better logging on child invokations
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurhauer committed Nov 21, 2024
1 parent 29b3262 commit 8adb5f0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 5 additions & 3 deletions models/node/gate/gate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ def _initialize_buffer_options(self, buffer_options: dict) -> None:
"""
self.clear_input_buffer_if_condition_not_met = buffer_options['clear_input_buffer_if_condition_not_met']
self.clear_input_buffer_if_condition_met = buffer_options['clear_input_buffer_if_condition_met']
self.clear_output_buffer_if_condition_not_met = buffer_options['clear_output_buffer_if_condition_not_met']
self.clear_output_buffer_if_condition_met = buffer_options['clear_output_buffer_if_condition_met']
self._gate_bypass_condition_met = False

def _run(self, data: FrameworkData, input_name: str) -> None:
self.print(f'Inserting data in input buffer {input_name}')
self._insert_new_input_data(data, input_name)
gate_bypass_condition_met = self._check_gate_condition()
if not gate_bypass_condition_met:
self._gate_bypass_condition_met = self._check_gate_condition()
if not self._gate_bypass_condition_met:
if self.clear_input_buffer_if_condition_not_met:
self.print('Clearing input buffer because condition was not met')
self._clear_input_buffer()
Expand All @@ -86,7 +88,7 @@ def _check_gate_condition(self) -> bool:
raise NotImplementedError()

def _is_next_node_call_enabled(self) -> bool:
return self._output_buffer[self.OUTPUT_MAIN].get_data_count() > 0
return self._gate_bypass_condition_met

def _get_inputs(self) -> List[str]:
return [
Expand Down
4 changes: 3 additions & 1 deletion models/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def add_child(self, output_name: str, node: Node, input_name: str):
name=self.name)
self._children[output_name].append(
{
'input_name': input_name,
'node': node,
'run': lambda data: node.run(data, input_name),
'run_': lambda data: node.run(),
Expand All @@ -254,6 +255,7 @@ def _call_children(self):
continue
output_children = self._children[output_name]
for child in output_children:
self.print(f'Output {output_name} calling child {child["node"].name} input {child["input_name"]} ({output.get_data_count()} samples)')
child['run'](output)

def _thread_runner(self):
Expand Down Expand Up @@ -372,7 +374,7 @@ def dispose(self) -> None:

def print(self, message: str, exception: Exception = None) -> None:
if self._enable_log or not exception is None:
print(f'{time.time()} - {self._MODULE_NAME}.{self.name} - {message}')
print(f'{time.time()} - {self._MODULE_NAME}.{self.name} - {message}\n')
if exception:
print('Stack trace:')
traceback.print_exc()
Expand Down

0 comments on commit 8adb5f0

Please sign in to comment.