Skip to content

Commit

Permalink
fix: various optimization fixes + CSP transform on continuous data + roc auc metric node
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurhauer committed Nov 21, 2024
1 parent c5311dc commit 29b3262
Show file tree
Hide file tree
Showing 20 changed files with 257 additions and 76 deletions.
1 change: 1 addition & 0 deletions application.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _get_node(self, node_name: str) -> Node:
)
node_config['name'] = node_name
node: Node = node_type.from_config_json(node_config)
self._nodes[node_name] = node
self.graphviz_representation += f'\n{node.build_graphviz_representation()}'
for output_name in node_config['outputs']:
if type(node_config['outputs'][output_name]) is not list:
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_config_path(args: Namespace) -> str:
return args.config

def get_config_data(config_path:str):
configuration_file = open(config_path, 'r')
configuration_file = open(config_path, 'r', encoding='utf-8')
config_data = json.load(configuration_file)
configuration_file.close()
return config_data
Expand Down
8 changes: 3 additions & 5 deletions models/framework_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import copy
from typing import Final, List, Dict
from models.exception.invalid_parameter_value import InvalidParameterValue
from models.exception.non_compatible_data import NonCompatibleData
Expand Down Expand Up @@ -149,20 +148,19 @@ def get_channels_as_set(self):
self._channels_set = set(self.channels)
return self._channels_set

def extend(self, input_data: FrameworkData):
def extend(self, data: FrameworkData):
"""This method is used to extend the ``FrameworkData`` object with the data that is input.
The data that is input is checked to ensure that it is compatible with the data that
is already stored in the ``FrameworkData`` object. If the data is compatible, then the
data is extended. If the data is not compatible, then an exception is raised.
:param input_data: The data that is to be extended.
:type input_data: ``FrameworkData``
:param data: The data that is to be extended.
:type data: ``FrameworkData``
:raises NonCompatibleData: Raised when the data that is being input is not compatible with the data that is already stored in the ``FrameworkData`` object.
:return: None
"""
data = copy.deepcopy(input_data)
if len(data.channels) == 0:
return
if not data.has_data():
Expand Down
10 changes: 4 additions & 6 deletions models/node/gate/dynamicgate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import abc
from typing import List, Dict, Final
from typing import Final

from models.exception.invalid_parameter_value import InvalidParameterValue
from models.exception.missing_parameter import MissingParameterError
from models.framework_data import FrameworkData
from models.node.gate.gate_node import Gate
from models.node.node import Node


class DynamicGate(Gate):
Expand All @@ -32,16 +30,16 @@ def _validate_parameters(self, parameters: dict):
@abc.abstractmethod
def _initialize_parameter_fields(self, parameters: dict):
super()._initialize_parameter_fields(parameters)
self._condition_script = parameters['condition']
self._condition = compile(f'condition_result={self._condition_script}', '', 'exec')
condition = parameters['condition']
self._condition_script = f'condition_result={condition}'

def _initialize_buffer_options(self, buffer_options: dict) -> None:
super()._initialize_buffer_options(buffer_options)

@abc.abstractmethod
def _check_gate_condition(self) -> bool:
local_variables = {"condition_data": self._input_buffer[self.INPUT_CONDITION]}
exec(self._condition, globals(), local_variables)
exec(self._condition_script, globals(), local_variables)
condition_result = local_variables['condition_result']
if type(condition_result) is not bool:
raise InvalidParameterValue(
Expand Down
14 changes: 14 additions & 0 deletions models/node/gate/gate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def _validate_parameters(self, parameters: dict):
raise InvalidParameterValue(module=self._MODULE_NAME, name=self.name,
parameter='buffer_options.clear_input_buffer_if_condition_not_met',
cause='must_be_bool')
if 'clear_input_buffer_if_condition_met' not in parameters['buffer_options']:
raise MissingParameterError(
module=self._MODULE_NAME,
name=self.name,
parameter='buffer_options.clear_input_buffer_if_condition_met'
)
if type(parameters['buffer_options']['clear_input_buffer_if_condition_met']) is not bool:
raise InvalidParameterValue(module=self._MODULE_NAME, name=self.name,
parameter='buffer_options.clear_input_buffer_if_condition_met',
cause='must_be_bool')
if 'clear_output_buffer_if_condition_met' not in parameters['buffer_options']:
raise MissingParameterError(
module=self._MODULE_NAME,
Expand All @@ -51,6 +61,7 @@ def _initialize_buffer_options(self, buffer_options: dict) -> None:
:type buffer_options: dict
"""
self.clear_input_buffer_if_condition_not_met = buffer_options['clear_input_buffer_if_condition_not_met']
self.clear_input_buffer_if_condition_met = buffer_options['clear_input_buffer_if_condition_met']
self.clear_output_buffer_if_condition_met = buffer_options['clear_output_buffer_if_condition_met']

def _run(self, data: FrameworkData, input_name: str) -> None:
Expand All @@ -66,6 +77,9 @@ def _run(self, data: FrameworkData, input_name: str) -> None:
self.print('Clearing output buffer because condition was met')
self._clear_output_buffer()
self._insert_new_output_data(self._input_buffer[self.INPUT_SIGNAL], self.OUTPUT_MAIN)
if self.clear_input_buffer_if_condition_met:
self.print('Clearing input buffer because condition was met')
self._clear_input_buffer()

@abc.abstractmethod
def _check_gate_condition(self) -> bool:
Expand Down
7 changes: 4 additions & 3 deletions models/node/generator/file/csvfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def _initialize_parameter_fields(self, parameters: dict):
self.timestamp_column_name = parameters['timestamp_column_name'] \
if 'timestamp_column_name' in parameters \
else None
self._init_csv_reader()

def _init_csv_reader(self) -> None:
"""This method initializes the CSV reader object. It opens the CSV file and creates a CSV reader object that will be used to read the file.
"""
self.print(f'{self.file_path} opened')
self._csv_file = open(self.file_path)
self._csv_reader = csv.DictReader(self._csv_file)

Expand All @@ -116,11 +116,12 @@ def _is_next_node_call_enabled(self) -> bool:
return self._output_buffer[self.OUTPUT_TIMESTAMP].has_data()

def _is_generate_data_condition_satisfied(self) -> bool:
return not self._csv_file.closed
return True

def _generate_data(self) -> Dict[str, FrameworkData]:
"""This method reads the csv file and store the data in a FrameworkData object.
"""
self._init_csv_reader()
main_data = FrameworkData(self.sampling_frequency, self.channel_column_names)
timestamp_data = FrameworkData(self.sampling_frequency)
for row_index, row in enumerate(self._csv_reader):
Expand All @@ -132,7 +133,7 @@ def _generate_data(self) -> Dict[str, FrameworkData]:
timestamp_data.input_data_on_channel(data=[row_timestamp])
self._csv_file.close()

print(f'{self.file_path} closed')
self.print(f'{self.file_path} closed')
return {
self.OUTPUT_MAIN: main_data,
self.OUTPUT_TIMESTAMP: timestamp_data
Expand Down
30 changes: 18 additions & 12 deletions models/node/generator/file/csvfilearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from models.exception.invalid_parameter_value import InvalidParameterValue
from models.exception.missing_parameter import MissingParameterError
from models.framework_data import FrameworkData
from models.node.generator.generator_node import GeneratorNode
from models.node.generator.single_run_generator_node import SingleRunGeneratorNode


Expand Down Expand Up @@ -108,18 +107,25 @@ def _generate_data(self) -> Dict[str, FrameworkData]:
main_data = FrameworkData(self.sampling_frequency, self.channel_column_names)
timestamp_data = FrameworkData(self.sampling_frequency)
for file in self.file_paths:
self._csv_file = open(file)
self.print(f'{file} opened')
csv_reader = csv.DictReader(self._csv_file)
for row_index, row in enumerate(csv_reader):
if row_index == 0 and self.channel_column_names is None:
self.channel_column_names = row.keys()
with open(file) as csv_file:
self.print(f'{file} opened')
csv_reader = csv.DictReader(csv_file)
rows = list(csv_reader)

if self.channel_column_names is None:
self.channel_column_names = rows[0].keys()

for channel_name in self.channel_column_names:
main_data.input_data_on_channel([float(row[channel_name])], channel_name)
row_timestamp = row_index if self._should_generate_timestamp() else row[self.timestamp_column_name]
timestamp_data.input_data_on_channel(data=[row_timestamp])
self._csv_file.close()
self.print('closed')
channel_data = [float(row[channel_name]) for row in rows]
main_data.input_data_on_channel(channel_data, channel_name)

if self._should_generate_timestamp():
timestamps = list(range(len(rows)))
else:
timestamps = [row[self.timestamp_column_name] for row in rows]

timestamp_data.input_data_on_channel(timestamps)
self.print(f'{file} closed')

return {
self.OUTPUT_MAIN: main_data,
Expand Down
16 changes: 10 additions & 6 deletions models/node/node.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations
import abc
import copy
import threading
import traceback
import time
from queue import Queue
from threading import Thread, Event
from threading import Thread, Event, Condition
from typing import List, Dict, Final, Any

from models.exception.invalid_parameter_value import InvalidParameterValue
Expand Down Expand Up @@ -38,13 +37,13 @@ def __init__(self, parameters=None) -> None:
self._initialize_children()

self._child_input_relation: Dict[Node, List[str]] = {}

self._is_disposed = False
# Threading attributes
self.local_storage = Queue()
self.running = False
self.thread = None
self.new_data_available = False
self.condition = threading.Condition()
self.condition = Condition()
self._stop_event = Event()
self.is_running_main_process = False
self.thread = Thread(target=self._thread_runner, name=self.name)
Expand Down Expand Up @@ -191,7 +190,7 @@ def _insert_new_input_data(self, data: FrameworkData, input_name: str):
:param input_name: Node input name.
:type input_name: str
"""
self._input_buffer[input_name].extend(copy.deepcopy(data))
self._input_buffer[input_name].extend(data)
if self._should_print_buffer_size:
self._print_buffer_size('input', self._input_buffer)

Expand All @@ -204,7 +203,7 @@ def _insert_new_output_data(self, data: FrameworkData, output_name: str):
:param output_name: Node output name.
:type output_name: str
"""
self._output_buffer[output_name].extend(copy.deepcopy(data))
self._output_buffer[output_name].extend(data)
if self._should_print_buffer_size:
self._print_buffer_size('output', self._output_buffer)

Expand Down Expand Up @@ -251,6 +250,8 @@ def _call_children(self):
"""
for output_name in self._get_outputs():
output = self._output_buffer[output_name]
if output.get_data_count()==0:
continue
output_children = self._children[output_name]
for child in output_children:
child['run'](output)
Expand Down Expand Up @@ -346,6 +347,9 @@ def _get_outputs(self) -> List[str]:
def dispose_all(self) -> None:
"""Disposes itself and all its children nodes
"""
if self._is_disposed:
return
self._is_disposed = True
self._dispose_all_children()
self.dispose()
self._dispose()
Expand Down
1 change: 0 additions & 1 deletion models/node/output/display/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,3 @@ def dispose(self) -> None:
"""
self._clear_output_buffer()
self._clear_input_buffer()
super().dispose()
14 changes: 4 additions & 10 deletions models/node/processing/encoder/onehottosingle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import abc
from typing import Final, Dict

import numpy as np

from models.framework_data import FrameworkData
from models.node.processing.processing_node import ProcessingNode
from typing import List
Expand Down Expand Up @@ -67,16 +69,8 @@ def _process(self, data: Dict[str, FrameworkData]) -> Dict[str, FrameworkData]:
self.print('encoding...')
raw_data = data[self.INPUT_MAIN]
encoded_data: FrameworkData = FrameworkData(sampling_frequency_hz=raw_data.sampling_frequency)
for data_index in range(0, raw_data.get_data_count()):
found_for_index = False
for channel_index, channel in enumerate(raw_data.channels):
if raw_data.get_data_at_index(data_index)[channel] > 0:
encoded_data.input_data_on_channel([channel_index+1])
found_for_index = True
break
if not found_for_index:
encoded_data.input_data_on_channel([0])

encoded = np.argmax(raw_data.get_data_as_2d_array(), axis=0)
encoded_data.input_data_on_channel(encoded)
self.print('encoded!')
return {
self.OUTPUT_MAIN: encoded_data
Expand Down
Empty file.
Loading

0 comments on commit 29b3262

Please sign in to comment.