diff --git a/devices/rotary_encoder.py b/devices/rotary_encoder.py index 7386847..79e964e 100644 --- a/devices/rotary_encoder.py +++ b/devices/rotary_encoder.py @@ -14,6 +14,7 @@ def __init__( falling_event=None, bytes_per_sample=2, reverse=False, + triggers=None, ): assert output in ("velocity", "position"), "ouput argument must be 'velocity' or 'position'." assert bytes_per_sample in (2, 4), "bytes_per_sample must be 2 or 4" @@ -28,6 +29,7 @@ def __init__( self.position = 0 self.velocity = 0 self.sampling_rate = sampling_rate + Analog_input.__init__( self, None, @@ -37,6 +39,7 @@ def __init__( rising_event, falling_event, data_type={2: "h", 4: "i"}[bytes_per_sample], + triggers=triggers, ) def read_sample(self): diff --git a/devices/schmitt_trigger.py b/devices/schmitt_trigger.py new file mode 100644 index 0000000..f1e0952 --- /dev/null +++ b/devices/schmitt_trigger.py @@ -0,0 +1,108 @@ +from pyControl.hardware import IO_object, assign_ID, interrupt_queue +import pyControl.framework as fw +import pyControl.state_machine as sm + + +class Crossing: + above = "above" + below = "below" + none = "none" + + +class SchmittTrigger(IO_object): + """ + Generates framework events when an analog signal goes above an upper threshold and/or below a lower threshold. + The rising event is triggered when signal > upper bound, falling event is triggered when signal < lower bound. + + This trigger implements hysteresis, which is a technique to prevent rapid oscillations or "bouncing" of events: + - Hysteresis creates a "dead zone" between the upper and lower thresholds + - Once a rising event is triggered (when signal crosses above the upper bound), + it cannot be triggered again until the signal has fallen below the lower bound + - Similarly, once a falling event is triggered (when signal crosses below the lower bound), + it cannot be triggered again until the signal has risen above the upper bound + + This behavior is particularly useful for noisy signals that might otherwise rapidly cross a single threshold + multiple times, generating unwanted repeated events. + """ + + def __init__(self, bounds, rising_event=None, falling_event=None): + if rising_event is None and falling_event is None: + raise ValueError("Either rising_event or falling_event or both must be specified.") + self.rising_event = rising_event + self.falling_event = falling_event + self.bounds = bounds + self.timestamp = 0 + assign_ID(self) + + def run_start(self): + self.set_bounds(self.bounds) + + def set_bounds(self, threshold): + if isinstance(threshold, tuple): + threshold_requirements_str = "The threshold must be a tuple of two integers (lower_bound, upper_bound) where lower_bound <= upper_bound." + if len(threshold) != 2: + raise ValueError("{} is not a valid threshold. {}".format(threshold, threshold_requirements_str)) + lower, upper = threshold + if not upper >= lower: + raise ValueError( + "{} is not a valid threshold because the lower bound {} is greater than the upper bound {}. {}".format( + threshold, lower, upper, threshold_requirements_str + ) + ) + self.upper_threshold = upper + self.lower_threshold = lower + else: + raise ValueError("{} is not a valid threshold. {}".format(threshold, threshold_requirements_str)) + self.reset_crossing = True + + content = {"bounds": (self.lower_threshold, self.upper_threshold)} + if self.rising_event is not None: + content["rising_event"] = self.rising_event + if self.falling_event is not None: + content["falling_event"] = self.falling_event + fw.data_output_queue.put( + fw.Datatuple( + fw.current_time, + fw.THRSH_TYP, + "s", + str(content), + ) + ) + + def _initialise(self): + # Set event codes for rising and falling events. + self.rising_event_ID = sm.events[self.rising_event] if self.rising_event in sm.events else False + self.falling_event_ID = sm.events[self.falling_event] if self.falling_event in sm.events else False + self.threshold_active = self.rising_event_ID or self.falling_event_ID + + def _process_interrupt(self): + # Put event generated by threshold crossing in event queue. + if self.was_above: + fw.event_queue.put(fw.Datatuple(self.timestamp, fw.EVENT_TYP, "i", self.rising_event_ID)) + else: + fw.event_queue.put(fw.Datatuple(self.timestamp, fw.EVENT_TYP, "i", self.falling_event_ID)) + + @micropython.native + def check(self, sample): + if self.reset_crossing: + # this gets run when the first sample is taken and whenever the threshold is changed + self.reset_crossing = False + self.was_above = sample > self.upper_threshold + self.was_below = sample < self.lower_threshold + self.last_crossing = Crossing.none + return + is_above_threshold = sample > self.upper_threshold + is_below_threshold = sample < self.lower_threshold + + if is_above_threshold and not self.was_above and self.last_crossing != Crossing.above: + self.timestamp = fw.current_time + self.last_crossing = Crossing.above + if self.rising_event_ID: + interrupt_queue.put(self.ID) + elif is_below_threshold and not self.was_below and self.last_crossing != Crossing.below: + self.timestamp = fw.current_time + self.last_crossing = Crossing.below + if self.falling_event_ID: + interrupt_queue.put(self.ID) + + self.was_above, self.was_below = is_above_threshold, is_below_threshold diff --git a/source/communication/data_logger.py b/source/communication/data_logger.py index 2b1a628..e70f4e5 100755 --- a/source/communication/data_logger.py +++ b/source/communication/data_logger.py @@ -73,7 +73,7 @@ def write_info_line(self, subtype, content, time=0): self.data_file.write(self.tsv_row_str("info", time, subtype, content)) def tsv_row_str(self, rtype, time, subtype="", content=""): - time_str = f"{time/1000:.3f}" if isinstance(time, int) else time + time_str = f"{time / 1000:.3f}" if isinstance(time, int) else time return f"{time_str}\t{rtype}\t{subtype}\t{content}\n" def copy_task_file(self, data_dir, tasks_dir, dir_name="task_files"): @@ -140,6 +140,8 @@ def data_to_string(self, new_data, prettify=False, max_len=60): var_str += f'\t\t\t"{var_name}": {var_value}\n' var_str += "\t\t\t}" data_string += self.tsv_row_str("variable", time, nd.subtype, content=var_str) + elif nd.type == MsgType.THRSH: # Threshold + data_string += self.tsv_row_str("threshold", time, nd.subtype, content=nd.content) elif nd.type == MsgType.WARNG: # Warning data_string += self.tsv_row_str("warning", time, content=nd.content) elif nd.type in (MsgType.ERROR, MsgType.STOPF): # Error or stop framework. diff --git a/source/communication/message.py b/source/communication/message.py index 30b19c9..43ab719 100644 --- a/source/communication/message.py +++ b/source/communication/message.py @@ -16,6 +16,7 @@ class MsgType(Enum): ERROR = b"!!" # Error STOPF = b"X" # Stop framework ANLOG = b"A" # Analog + THRSH = b"T" # Threshold @classmethod def from_byte(cls, byte_value): @@ -51,5 +52,9 @@ def get_subtype(self, subtype_char): "t": "task", "a": "api", "u": "user", + "s": "trigger", + }, + MsgType.THRSH: { + "s": "set", }, }[self][subtype_char] diff --git a/source/communication/pycboard.py b/source/communication/pycboard.py index 8cf5835..748c6c9 100644 --- a/source/communication/pycboard.py +++ b/source/communication/pycboard.py @@ -492,7 +492,7 @@ def process_data(self): self.timestamp = msg_timestamp if msg_type in (MsgType.EVENT, MsgType.STATE): content = int(content_bytes.decode()) # Event/state ID. - elif msg_type in (MsgType.PRINT, MsgType.WARNG): + elif msg_type in (MsgType.PRINT, MsgType.WARNG, MsgType.THRSH): content = content_bytes.decode() # Print or error string. elif msg_type == MsgType.VARBL: content = content_bytes.decode() # JSON string diff --git a/source/pyControl/framework.py b/source/pyControl/framework.py index c2f697a..f76cd1a 100644 --- a/source/pyControl/framework.py +++ b/source/pyControl/framework.py @@ -24,6 +24,7 @@ class pyControlError(BaseException): # Exception for pyControl errors. VARBL_TYP = b"V" # Variable change : (time, VARBL_TYP, [g]et/user_[s]et/[a]pi_set/[p]rint/s[t]art/[e]nd, json_str) WARNG_TYP = b"!" # Warning : (time, WARNG_TYP, "", print_string) STOPF_TYP = b"X" # Stop framework : (time, STOPF_TYP, "", "") +THRSH_TYP = b"T" # Threshold : (time, THRSH_TYP, [s]et) # Event_queue ----------------------------------------------------------------- diff --git a/source/pyControl/hardware.py b/source/pyControl/hardware.py index 9e03afd..60521f1 100644 --- a/source/pyControl/hardware.py +++ b/source/pyControl/hardware.py @@ -235,25 +235,35 @@ class Analog_input(IO_object): # streams data to computer. Optionally can generate framework events when voltage # goes above / below specified value theshold. - def __init__(self, pin, name, sampling_rate, threshold=None, rising_event=None, falling_event=None, data_type="H"): - if rising_event or falling_event: - self.threshold = Analog_threshold(threshold, rising_event, falling_event) - else: - self.threshold = False + def __init__( + self, + pin, + name, + sampling_rate, + threshold=None, + rising_event=None, + falling_event=None, + data_type="H", + triggers=None, + ): + self.triggers = triggers if triggers is not None else [] + if threshold is not None: + self.triggers.append(Analog_threshold(threshold, rising_event, falling_event)) + self.timer = pyb.Timer(available_timers.pop()) if pin: # pin argument can be None when Analog_input subclassed. self.ADC = pyb.ADC(pin) self.read_sample = self.ADC.read self.name = name - self.Analog_channel = Analog_channel(name, sampling_rate, data_type) + self.channel = Analog_channel(name, sampling_rate, data_type) assign_ID(self) def _run_start(self): # Start sampling timer, initialise threshold, aquire first sample. - self.timer.init(freq=self.Analog_channel.sampling_rate) + self.timer.init(freq=self.channel.sampling_rate) self.timer.callback(self._timer_ISR) - if self.threshold: - self.threshold.run_start(self.read_sample()) + for trigger in self.triggers: + trigger.run_start() self._timer_ISR(0) def _run_stop(self): @@ -263,9 +273,10 @@ def _run_stop(self): def _timer_ISR(self, t): # Read a sample to the buffer, update write index. sample = self.read_sample() - self.Analog_channel.put(sample) - if self.threshold: - self.threshold.check(sample) + self.channel.put(sample) + if self.triggers: + for trigger in self.triggers: + trigger.check(sample) def record(self): # For backward compatibility. pass @@ -286,15 +297,21 @@ class Analog_channel(IO_object): # data array bytes (variable) def __init__(self, name, sampling_rate, data_type, plot=True): - assert data_type in ("b", "B", "h", "H", "i", "I"), "Invalid data_type." - assert not any( - [name == io.name for io in IO_dict.values() if isinstance(io, Analog_channel)] - ), "Analog signals must have unique names." + if data_type not in ("b", "B", "h", "H", "i", "I"): + raise ValueError("Invalid data_type.") + if any([name == io.name for io in IO_dict.values() if isinstance(io, Analog_channel)]): + raise ValueError( + "Analog signals must have unique names.{} {}".format( + name, [io.name for io in IO_dict.values() if isinstance(io, Analog_channel)] + ) + ) + self.name = name assign_ID(self) self.sampling_rate = sampling_rate self.data_type = data_type self.plot = plot + self.bytes_per_sample = {"b": 1, "B": 1, "h": 2, "H": 2, "i": 4, "I": 4}[data_type] self.buffer_size = max(4, min(256 // self.bytes_per_sample, sampling_rate // 10)) self.buffers = (array(data_type, [0] * self.buffer_size), array(data_type, [0] * self.buffer_size)) @@ -345,15 +362,14 @@ def send_buffer(self, run_stop=False): class Analog_threshold(IO_object): - # Generates framework events when an analog signal goes above or below specified threshold. + # Generates framework events when an analog signal goes above or below specified threshold value. - def __init__(self, threshold=None, rising_event=None, falling_event=None): - assert isinstance( - threshold, int - ), "Integer threshold must be specified if rising or falling events are defined." - self.threshold = threshold + def __init__(self, threshold, rising_event=None, falling_event=None): + if rising_event is None and falling_event is None: + raise ValueError("Either rising_event or falling_event or both must be specified.") self.rising_event = rising_event self.falling_event = falling_event + self.threshold = threshold self.timestamp = 0 self.crossing_direction = False assign_ID(self) @@ -364,8 +380,8 @@ def _initialise(self): self.falling_event_ID = sm.events[self.falling_event] if self.falling_event in sm.events else False self.threshold_active = self.rising_event_ID or self.falling_event_ID - def run_start(self, sample): - self.above_threshold = sample > self.threshold + def run_start(self): + self.set_threshold(self.threshold) def _process_interrupt(self): # Put event generated by threshold crossing in event queue. @@ -376,14 +392,40 @@ def _process_interrupt(self): @micropython.native def check(self, sample): + if self.reset_above_threshold: + # this gets run when the first sample is taken and whenever the threshold is changed + self.reset_above_threshold = False + self.above_threshold = sample > self.threshold + return new_above_threshold = sample > self.threshold if new_above_threshold != self.above_threshold: # Threshold crossing. self.above_threshold = new_above_threshold if (self.above_threshold and self.rising_event_ID) or (not self.above_threshold and self.falling_event_ID): self.timestamp = fw.current_time self.crossing_direction = self.above_threshold + interrupt_queue.put(self.ID) + def set_threshold(self, threshold): + if not isinstance(threshold, int): + raise ValueError(f"Threshold must be an integer, got {type(threshold).__name__}.") + self.threshold = threshold + self.reset_above_threshold = True + + content = {"value": self.threshold} + if self.rising_event is not None: + content["rising_event"] = self.rising_event + if self.falling_event is not None: + content["falling_event"] = self.falling_event + fw.data_output_queue.put( + fw.Datatuple( + fw.current_time, + fw.THRSH_TYP, + "s", + str(content), + ) + ) + # Digital Output -------------------------------------------------------------- diff --git a/tasks/example/running_wheel.py b/tasks/example/running_wheel.py index 09452ad..21e604c 100644 --- a/tasks/example/running_wheel.py +++ b/tasks/example/running_wheel.py @@ -1,9 +1,12 @@ # Example of using a rotary encoder to measure running speed and trigger events when # running starts and stops. The subject must run for 10 seconds to trigger reward delivery, # then stop running for 5 seconds to initiate the next trial. +# If while running the subject exceeds a bonus velocity threshold, they earn a bonus +# and the reward duration is extended by a bonus duration. from pyControl.utility import * from devices import * +from pyControl.hardware import Analog_threshold # Variables. @@ -11,7 +14,20 @@ v.stop_time = 5 * second # Time subject must stop running to intiate the next trial. v.reward_duration = 100 * ms # Time reward solenoid is open for. v.velocity_threshold = 100 # Minimum encoder velocity treated as running (encoder counts/second). +v.bonus_velocity_threshold = 5000 # Encoder velocity that triggers bonus reward (encoder counts/second). +v.give_bonus = False # Whether to give bonus reward. +v.bonus_reward_duration = 50 * ms # Time to add to reward duration if bonus is earned. +running_trigger = Analog_threshold( + threshold=v.velocity_threshold, + rising_event="started_running", + falling_event="stopped_running", +) + +bonus_trigger = Analog_threshold( + threshold=v.bonus_velocity_threshold, + rising_event="bonus_earned", +) # Instantiate hardware - would normally be in a seperate hardware definition file. board = Breakout_1_2() # Breakout board. @@ -21,9 +37,7 @@ name="running_wheel", sampling_rate=100, output="velocity", - threshold=v.velocity_threshold, - rising_event="started_running", - falling_event="stopped_running", + triggers=[running_trigger, bonus_trigger], ) # Running wheel must be plugged into port 1 of breakout board. solenoid = Digital_output(board.port_2.POW_A) # Reward delivery solenoid. @@ -40,6 +54,7 @@ events = [ "started_running", "stopped_running", + "bonus_earned", "run_timer", "stopped_timer", "reward_timer", @@ -70,10 +85,13 @@ def running_for_reward(event): # If subject runs for long enough go to reward state. # If subject stops go back to trial start. if event == "entry": + v.give_bonus = False set_timer("run_timer", v.run_time) elif event == "stopped_running": disarm_timer("run_timer") goto_state("trial_start") + elif event == "bonus_earned": + v.give_bonus = True elif event == "run_timer": goto_state("reward") @@ -81,7 +99,7 @@ def running_for_reward(event): def reward(event): # Deliver reward then go to inter trial interval. if event == "entry": - timed_goto_state("inter_trial_interval", v.reward_duration) + timed_goto_state("inter_trial_interval", v.reward_duration + v.bonus_reward_duration * v.give_bonus) solenoid.on() elif event == "exit": solenoid.off()