diff --git a/CHANGELOG.md b/CHANGELOG.md index a657fd6..c07d4cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Corrections for EOF (Cancel) Handling: Perform proper checks on whether the file is actually completed using the supplied file size and checksum. +- Some fixes for the handling of multiple NAK PDUs in one NAK sequence # [v0.5.1] 2025-02-10 diff --git a/src/cfdppy/handler/dest.py b/src/cfdppy/handler/dest.py index 34b24fb..0347a2c 100644 --- a/src/cfdppy/handler/dest.py +++ b/src/cfdppy/handler/dest.py @@ -37,6 +37,7 @@ from cfdppy.defs import CfdpState from cfdppy.exceptions import ( + AbstractFileDirectiveBase, InvalidDestinationId, InvalidPduDirection, InvalidPduForDestHandler, @@ -68,6 +69,8 @@ ) if TYPE_CHECKING: + from collections.abc import Iterator + from spacepackets.util import UnsignedByteField _LOGGER = logging.getLogger(__name__) @@ -90,7 +93,6 @@ def empty(cls) -> _DestFileParams: crc32=b"", file_size=None, file_name=Path(), - # file_size_eof=None, metadata_only=False, ) @@ -147,7 +149,7 @@ def packets_ready(self) -> bool: class LostSegmentTracker: def __init__(self): - self.lost_segments = {} + self.lost_segments: list[tuple[int, int]] = [] @property def num_lost_segments(self) -> int: @@ -156,68 +158,80 @@ def num_lost_segments(self) -> int: def reset(self) -> None: self.lost_segments.clear() + def sort(self) -> None: + self.lost_segments = sorted(self.lost_segments, key=lambda t: t[0]) + def add_lost_segment(self, lost_seg: tuple[int, int]) -> None: - self.lost_segments.update({lost_seg[0]: lost_seg[1]}) - self.lost_segments = dict(sorted(self.lost_segments.items())) + self.lost_segments.append(lost_seg) def coalesce_lost_segments(self) -> None: - if len(self.lost_segments) <= 1: + """ + Merge overlapping or adjacent segments in self.lost_segments (List[Tuple[int, int]]). + After the call the list is sorted and fully coalesced. + """ + if not self.lost_segments: # empty list → nothing to do return - merged_segments = [] - # Initialize to the first entry. - current_start, current_end = next(iter(self.lost_segments.items())) - for seg_start, seg_end in self.lost_segments.items(): - if seg_start == current_end: - current_end = seg_end + # 1. Sort by segment start + self.lost_segments.sort(key=lambda s: s[0]) + + merged: list[tuple[int, int]] = [] + cur_start, cur_end = self.lost_segments[0] + + for seg_start, seg_end in self.lost_segments[1:]: + # overlap or adjacency + if seg_start <= cur_end: + # extend current segment + cur_end = max(cur_end, seg_end) + # gap found else: - merged_segments.append((current_start, current_end)) - current_start, current_end = seg_start, seg_end + merged.append((cur_start, cur_end)) + cur_start, cur_end = seg_start, seg_end - merged_segments.append((current_start, current_end)) - self.lost_segments = dict(merged_segments) + merged.append((cur_start, cur_end)) # append the last segment + self.lost_segments = merged def remove_lost_segment(self, segment_to_remove: tuple[int, int]) -> bool: - """Please note that this method can only handle the removal of segments - which do not overlap the boundaries of an existing lost segment. It is however able - to remove lost segments which are only a subset of an existing section. + """Remove *exactly* the requested span from the lost-segment list. - Returns - --------- + Any partial overlap with an existing segment raises ValueError. However, removing segments + which are a subset of an existing segment are allowed. - Returns whether the internal dictionary was manipulated in any way. + Returns True when the list was modified, else False. """ - if segment_to_remove[1] - segment_to_remove[0] == 0: + if segment_to_remove[1] <= segment_to_remove[0]: # empty span return False - did_something = False - end = self.lost_segments.get(segment_to_remove[0]) - if end is not None: - if segment_to_remove[1] > end: - raise ValueError("Specified lost segment end exceeds existing lost segment end") - did_something = True - if segment_to_remove[1] == end: - self.lost_segments.pop(segment_to_remove[0]) - elif segment_to_remove[1] < end: - self.lost_segments.pop(segment_to_remove[0]) - # Re-insert the rest of the missing segment - self.lost_segments.update({segment_to_remove[1]: end}) - else: - for seg_start, seg_end in list(self.lost_segments.items()): - if seg_start < segment_to_remove[0] < seg_end: - if segment_to_remove[1] > seg_end: - raise ValueError( - "Specified lost segment end exceeds existing lost segment end" - ) - if segment_to_remove[1] == seg_end: - self.lost_segments.update({seg_start: segment_to_remove[0]}) - else: - self.lost_segments.update({seg_start: segment_to_remove[0]}) - self.lost_segments.update({segment_to_remove[1]: seg_end}) - did_something = True - break - if did_something: - self.lost_segments = dict(sorted(self.lost_segments.items())) - return did_something + self.sort() + + r_start, r_end = segment_to_remove + new_segments: list[tuple[int, int]] = [] + changed = False + + for s_start, s_end in self.lost_segments: + # Case 1: no overlap - keep segment as-is + if r_end <= s_start or r_start >= s_end: + new_segments.append((s_start, s_end)) + continue + + # ----- partial-overlap detection ----- + # 1. removal sticks out on the left + # 2. removal sticks out on the right + # 3. removal is strictly inside the segment + if r_start < s_start or r_end > s_end: + raise ValueError("Partial overlap with an existing lost segment") + + changed = True + # Left remainder + if r_start > s_start: + new_segments.append((s_start, r_start)) + # Right remainder + if r_end < s_end: + new_segments.append((r_end, s_end)) + + self.sort() + if changed: + self.lost_segments = new_segments + return changed @dataclass @@ -249,16 +263,16 @@ def __init__(self): condition_code=ConditionCode.NO_ERROR, ) self.completion_disposition: CompletionDisposition = CompletionDisposition.COMPLETED - self.pdu_conf = PduConfig.empty() + self.pdu_conf: PduConfig = PduConfig.empty() self.fp: _DestFileParams = _DestFileParams.empty() - self.acked_params = _AckedModeParams() - self.positive_ack_params = _PositiveAckProcedureParams() + self.acked_params: _AckedModeParams = _AckedModeParams() + self.positive_ack_params: _PositiveAckProcedureParams = _PositiveAckProcedureParams() class FsmResult: def __init__(self, states: DestStateWrapper): - self.states = states + self.states: DestStateWrapper = states def acknowledge_inactive_eof_pdu(eof_pdu: EofPdu, status: TransactionStatus) -> AckPdu: @@ -324,12 +338,12 @@ def __init__( remote_cfg_table: RemoteEntityConfigTable, check_timer_provider: CheckTimerProvider, ) -> None: - self.cfg = cfg - self.remote_cfg_table = remote_cfg_table - self.states = DestStateWrapper() - self.user = user - self.check_timer_provider = check_timer_provider - self._params = _DestFieldWrapper() + self.cfg: LocalEntityConfig = cfg + self.remote_cfg_table: RemoteEntityConfigTable = remote_cfg_table + self.states: DestStateWrapper = DestStateWrapper() + self.user: CfdpUserBase = user + self.check_timer_provider: CheckTimerProvider = check_timer_provider + self._params: _DestFieldWrapper = _DestFieldWrapper() self._pdus_to_be_sent: deque[PduHolder] = deque() @property @@ -444,14 +458,18 @@ def _check_inserted_packet(self, packet: GenericPduPacket) -> None: raise NoRemoteEntityConfigFound(entity_id=packet.dest_entity_id) if get_packet_destination(packet) == PacketDestination.SOURCE_HANDLER: raise InvalidPduForDestHandler(packet) + if (self.states.state == CfdpState.IDLE) and ( packet.pdu_type == PduType.FILE_DATA - or packet.directive_type != DirectiveType.METADATA_PDU # type: ignore + or ( + isinstance(packet, AbstractFileDirectiveBase) + and packet.directive_type != DirectiveType.METADATA_PDU + ) ): self._handle_first_packet_not_metadata_pdu(packet) if packet.pdu_type == PduType.FILE_DIRECTIVE and ( - packet.directive_type # type: ignore - in [DirectiveType.ACK_PDU, DirectiveType.PROMPT_PDU] + isinstance(packet, AbstractFileDirectiveBase) + and packet.directive_type in [DirectiveType.ACK_PDU, DirectiveType.PROMPT_PDU] and self.states.state == CfdpState.BUSY and self.transmission_mode == TransmissionMode.UNACKNOWLEDGED ): @@ -542,15 +560,15 @@ def __non_idle_fsm(self, packet: GenericPduPacket | None) -> None: self._handle_fd_or_eof_pdu(pdu_holder) if self.states.step == TransactionStep.WAITING_FOR_METADATA: self._handle_waiting_for_missing_metadata(pdu_holder) - self._deferred_lost_segment_handling() + if self._params.acked_params.deferred_lost_segment_detection_active: + self._deferred_lost_segment_handling() if self.states.step == TransactionStep.RECV_FILE_DATA_WITH_CHECK_LIMIT_HANDLING: self._check_limit_handling() if self.states.step == TransactionStep.WAITING_FOR_MISSING_DATA: if packet is not None and pdu_holder.pdu_type == PduType.FILE_DATA: self._handle_fd_pdu(pdu_holder.to_file_data_pdu()) - if self._params.acked_params.deferred_lost_segment_detection_active: - self._reset_nak_activity_parameters() - self._deferred_lost_segment_handling() + if self._params.acked_params.deferred_lost_segment_detection_active: + self._deferred_lost_segment_handling() if self.states.step == TransactionStep.TRANSFER_COMPLETION: self._handle_transfer_completion() if self.states.step == TransactionStep.SENDING_FINISHED_PDU: @@ -602,7 +620,7 @@ def _handle_eof_without_previous_metadata(self, eof_pdu: EofPdu) -> None: assert self._params.transaction_id is not None self.user.eof_recv_indication(self._params.transaction_id) if eof_pdu.condition_code != ConditionCode.NO_ERROR: - self._handle_cancel_eof(eof_pdu) + self._handle_eof_cancel(eof_pdu) self._prepare_eof_ack_packet() self._eof_ack_pdu_done() @@ -685,6 +703,7 @@ def _common_first_packet_handler(self, pdu: GenericPduPacket) -> bool | None: return None def _handle_metadata_packet(self, metadata_pdu: MetadataPdu) -> None: + assert self._params.transaction_id is not None self._params.checksum_type = metadata_pdu.checksum_type self._params.closure_requested = metadata_pdu.closure_requested self._params.acked_params.metadata_missing = False @@ -704,10 +723,11 @@ def _handle_metadata_packet(self, metadata_pdu: MetadataPdu) -> None: raise NoRemoteEntityConfigFound(metadata_pdu.dest_entity_id) if not self._params.fp.metadata_only: self.states.step = TransactionStep.RECEIVING_FILE_DATA - self._init_vfs_handling(Path(metadata_pdu.source_file_name).name) # type: ignore + assert metadata_pdu.source_file_name is not None + self._init_vfs_handling(Path(metadata_pdu.source_file_name).name) else: self.states.step = TransactionStep.TRANSFER_COMPLETION - msgs_to_user_list = None + msgs_to_user_list: None | list[MessageToUserTlv] = None options = metadata_pdu.options_as_tlv() if options is not None: msgs_to_user_list = [] @@ -718,7 +738,7 @@ def _handle_metadata_packet(self, metadata_pdu: MetadataPdu) -> None: None if metadata_pdu.source_file_name is None else metadata_pdu.file_size ) params = MetadataRecvParams( - transaction_id=self._params.transaction_id, # type: ignore + transaction_id=self._params.transaction_id, file_size=file_size_for_indication, source_id=metadata_pdu.source_entity_id, dest_file_name=metadata_pdu.dest_file_name, @@ -757,12 +777,11 @@ def _handle_waiting_for_missing_metadata(self, packet_holder: PduHolder) -> None self._handle_file_data_without_previous_metadata(packet_holder.to_file_data_pdu()) elif packet_holder.pdu_directive_type == DirectiveType.METADATA_PDU: self._handle_metadata_packet(packet_holder.to_metadata_pdu()) + # Reception of missing segments resets the NAK activity parameters. See CFDP 4.6.4.7. if self._params.acked_params.deferred_lost_segment_detection_active: self._reset_nak_activity_parameters() elif packet_holder.pdu_directive_type == DirectiveType.EOF_PDU: # type: ignore self._handle_eof_without_previous_metadata(packet_holder.to_eof_pdu()) - if self._params.acked_params.deferred_lost_segment_detection_active: - self._reset_nak_activity_parameters() def _reset_nak_activity_parameters(self) -> None: assert self._params.acked_params.procedure_timer is not None @@ -827,8 +846,6 @@ def _handle_fd_pdu(self, file_data_pdu: FileDataPdu) -> None: self.user.file_segment_recv_indication(file_segment_indic_params) try: next_expected_progress = offset + len(data) - if self.transmission_mode == TransmissionMode.ACKNOWLEDGED: - self._lost_segment_handling(offset, len(data)) self.user.vfs.write_data(self._params.fp.file_name, data, offset) self._params.finished_params.file_status = FileStatus.FILE_RETAINED @@ -846,14 +863,16 @@ def _handle_fd_pdu(self, file_data_pdu: FileDataPdu) -> None: return # Ensure that the progress value is always incremented self._params.fp.progress = max(next_expected_progress, self._params.fp.progress) + if self.transmission_mode == TransmissionMode.ACKNOWLEDGED: + self._lost_segment_handling(offset, len(data)) except FileNotFoundError: if self._params.finished_params.file_status != FileStatus.FILE_RETAINED: self._params.finished_params.file_status = FileStatus.DISCARDED_FILESTORE_REJECTION - self._declare_fault(ConditionCode.FILESTORE_REJECTION) + _ = self._declare_fault(ConditionCode.FILESTORE_REJECTION) except PermissionError: if self._params.finished_params.file_status != FileStatus.FILE_RETAINED: self._params.finished_params.file_status = FileStatus.DISCARDED_FILESTORE_REJECTION - self._declare_fault(ConditionCode.FILESTORE_REJECTION) + _ = self._declare_fault(ConditionCode.FILESTORE_REJECTION) def _handle_transfer_completion(self) -> None: self._notice_of_completion() @@ -870,9 +889,7 @@ def _lost_segment_handling(self, offset: int, data_len: int) -> None: by dedicated code which is run when the EOF PDU is handled.""" if offset > self._params.acked_params.last_end_offset: lost_segment = (self._params.acked_params.last_end_offset, offset) - self._params.acked_params.lost_seg_tracker.add_lost_segment( - (self._params.acked_params.last_end_offset, offset) - ) + self._params.acked_params.lost_seg_tracker.add_lost_segment(lost_segment) assert self._params.remote_cfg is not None if self._params.remote_cfg.immediate_nak_mode: self._add_packet_to_be_sent( @@ -888,13 +905,39 @@ def _lost_segment_handling(self, offset: int, data_len: int) -> None: self._params.acked_params.last_end_offset = offset + data_len if offset + data_len <= self._params.acked_params.last_start_offset: # Might be a re-requested FD PDU. - self._params.acked_params.lost_seg_tracker.remove_lost_segment( + removed = self._params.acked_params.lost_seg_tracker.remove_lost_segment( (offset, offset + data_len) ) + # Reception of missing segments resets the NAK activity parameters. + # See CFDP 4.6.4.7. + if removed and self._params.acked_params.deferred_lost_segment_detection_active: + self._reset_nak_activity_parameters() + + @staticmethod + def _iter_segment_requests( + lost_segments: Iterator[tuple[int, int]], + metadata_missing: bool, + max_per_pdu: int, + ) -> Iterator[list[tuple[int, int]]]: + """ + Yield lists of segment requests that fit into a single NAK PDU. + + If metadata is missing we prepend a (0, 0) request exactly once. + """ + batch: list[tuple[int, int]] = [] + if metadata_missing: + batch.append((0, 0)) + + for start, end in lost_segments: + batch.append((start, end)) + if len(batch) == max_per_pdu: + yield batch + batch = [] + # final partial batch + if batch: + yield batch def _deferred_lost_segment_handling(self) -> None: - if not self._params.acked_params.deferred_lost_segment_detection_active: - return assert self._params.remote_cfg is not None assert self._params.fp.file_size is not None if ( @@ -912,15 +955,14 @@ def _deferred_lost_segment_handling(self) -> None: self.states.step = TransactionStep.TRANSFER_COMPLETION self._params.acked_params.deferred_lost_segment_detection_active = False return - first_nak_issuance = False + first_nak_issuance = self._params.acked_params.procedure_timer is None # This is the case if this is the first issuance of NAK PDUs # A timer needs to be instantiated, but we do not increment the activity counter yet. - if self._params.acked_params.procedure_timer is None: + if first_nak_issuance: self._params.acked_params.procedure_timer = Countdown.from_seconds( self._params.remote_cfg.nak_timer_interval_seconds ) - first_nak_issuance = True - elif self._params.acked_params.procedure_timer.busy(): + elif self._params.acked_params.procedure_timer.busy(): # pyright: ignore # There were or there was a previous NAK sequence(s). Wait for timeout before issuing # a new NAK sequence. return @@ -930,60 +972,53 @@ def _deferred_lost_segment_handling(self) -> None: == self._params.remote_cfg.nak_timer_expiration_limit ): self._params.finished_params.delivery_code = DeliveryCode.DATA_INCOMPLETE - self._declare_fault(ConditionCode.NAK_LIMIT_REACHED) + _ = self._declare_fault(ConditionCode.NAK_LIMIT_REACHED) return + # This is not the first NAK issuance and the timer expired. max_segments_in_one_pdu = get_max_seg_reqs_for_max_packet_size_and_pdu_cfg( self._params.remote_cfg.max_packet_len, self._params.pdu_conf ) - next_segment_reqs = [] - if self._params.acked_params.metadata_missing: - next_segment_reqs.append((0, 0)) - for ( - start, - end, - ) in self._params.acked_params.lost_seg_tracker.lost_segments.items(): - next_segment_reqs.append((start, end)) - if len(next_segment_reqs) == max_segments_in_one_pdu: - self._add_packet_to_be_sent( - NakPdu( - self._params.pdu_conf, - 0, - self._params.fp.file_size, - next_segment_reqs, - ) - ) - next_segment_reqs = [] - if len(next_segment_reqs) > 0: - self._add_packet_to_be_sent( - NakPdu( - self._params.pdu_conf, - 0, - self._params.fp.file_size, - next_segment_reqs, - ) + + batches = list( + self._iter_segment_requests( + iter(self._params.acked_params.lost_seg_tracker.lost_segments), + self._params.acked_params.metadata_missing, + max_segments_in_one_pdu, ) + ) + + for idx, batch in enumerate(batches): + # first batch → clamp start to 0 + start = 0 if idx == 0 else batch[0][0] + # last batch → clamp end to file size + end = self._params.fp.file_size if idx == len(batches) - 1 else batch[-1][1] + self._add_packet_to_be_sent(NakPdu(self._params.pdu_conf, start, end, batch)) + if not first_nak_issuance: self._params.acked_params.nak_activity_counter += 1 - self._params.acked_params.procedure_timer.reset() + self._params.acked_params.procedure_timer.reset() # pyright: ignore - def _handle_eof_pdu(self, eof_pdu: EofPdu) -> bool | None: - """Returns whether to exit the FSM prematurely.""" + def _handle_eof_pdu(self, eof_pdu: EofPdu) -> None: self._params.fp.crc32 = eof_pdu.file_checksum self._params.fp.file_size = eof_pdu.file_size if self.cfg.indication_cfg.eof_recv_indication_required: assert self._params.transaction_id is not None self.user.eof_recv_indication(self._params.transaction_id) if eof_pdu.condition_code == ConditionCode.NO_ERROR: - regular_completion = self._handle_no_error_eof() + regular_completion = self._handle_eof_no_error(eof_pdu.file_checksum) if not regular_completion: - return None + return else: - self._handle_cancel_eof(eof_pdu) - self._file_transfer_complete_transition() - return False + self._handle_eof_cancel(eof_pdu) + if self.transmission_mode == TransmissionMode.UNACKNOWLEDGED: + self.states.step = TransactionStep.TRANSFER_COMPLETION + elif self.transmission_mode == TransmissionMode.ACKNOWLEDGED: + self._prepare_eof_ack_packet() + self._eof_ack_pdu_done() - def _handle_cancel_eof(self, eof_pdu: EofPdu) -> None: + def _handle_eof_cancel(self, eof_pdu: EofPdu) -> None: + assert self._params.remote_cfg is not None # This is an EOF (Cancel), perform Cancel Response Procedures according to chapter # 4.6.6 of the standard. Set remote ID as fault location. self._trigger_notice_of_completion_canceled( @@ -999,19 +1034,20 @@ def _handle_cancel_eof(self, eof_pdu: EofPdu) -> None: # Empty file, no file data PDU. self._params.finished_params.delivery_code = DeliveryCode.DATA_COMPLETE return - if self._checksum_verify(self._params.fp.progress, self._params.fp.crc32): + if self._checksum_verify(self._params.fp.progress, eof_pdu.file_checksum): self._params.finished_params.delivery_code = DeliveryCode.DATA_COMPLETE return self._params.finished_params.delivery_code = DeliveryCode.DATA_INCOMPLETE - def _handle_no_error_eof(self) -> bool: + def _handle_eof_no_error(self, crc32: bytes) -> bool: """Returns whether the transfer can be completed regularly.""" + assert self._params.fp.file_size is not None # CFDP 4.6.1.2.9: Declare file size error if progress exceeds file size - if self._params.fp.progress > self._params.fp.file_size: # type: ignore + if self._params.fp.progress > self._params.fp.file_size: if self._declare_fault(ConditionCode.FILE_SIZE_ERROR) != FaultHandlerCode.IGNORE_ERROR: return False elif ( - self._params.fp.progress < self._params.fp.file_size # type: ignore + self._params.fp.progress < self._params.fp.file_size ) and self.transmission_mode == TransmissionMode.ACKNOWLEDGED: # CFDP 4.6.4.3.1: The end offset of the last received file segment and the file # size as stated in the EOF PDU is not the same, so we need to add that segment to @@ -1019,9 +1055,8 @@ def _handle_no_error_eof(self) -> bool: self._params.acked_params.lost_seg_tracker.add_lost_segment( (self._params.fp.progress, self._params.fp.file_size) # type: ignore ) - if ( - self.transmission_mode == TransmissionMode.UNACKNOWLEDGED - and not self._checksum_verify(self._params.fp.progress, self._params.fp.crc32) # type: ignore + if self.transmission_mode == TransmissionMode.UNACKNOWLEDGED and not self._checksum_verify( + self._params.fp.progress, crc32 ): self._start_check_limit_handling() return False @@ -1030,14 +1065,15 @@ def _handle_no_error_eof(self) -> bool: return True def _start_deferred_lost_segment_handling(self) -> None: + assert self._params.fp.file_size is not None if self._params.acked_params.metadata_missing: self.states.step = TransactionStep.WAITING_FOR_METADATA else: self.states.step = TransactionStep.WAITING_FOR_MISSING_DATA self._params.acked_params.deferred_lost_segment_detection_active = True self._params.acked_params.lost_seg_tracker.coalesce_lost_segments() - self._params.acked_params.last_start_offset = self._params.fp.file_size # type: ignore - self._params.acked_params.last_end_offset = self._params.fp.file_size # type: ignore + self._params.acked_params.last_start_offset = self._params.fp.file_size + self._params.acked_params.last_end_offset = self._params.fp.file_size self._deferred_lost_segment_handling() def _prepare_eof_ack_packet(self) -> None: @@ -1065,13 +1101,6 @@ def _checksum_verify(self, verify_len: int, expected_crc32: bytes) -> bool: self._declare_fault(ConditionCode.FILE_CHECKSUM_FAILURE) return False - def _file_transfer_complete_transition(self) -> None: - if self.transmission_mode == TransmissionMode.UNACKNOWLEDGED: - self.states.step = TransactionStep.TRANSFER_COMPLETION - elif self.transmission_mode == TransmissionMode.ACKNOWLEDGED: - self._prepare_eof_ack_packet() - self._eof_ack_pdu_done() - def _trigger_notice_of_completion_canceled( self, condition_code: ConditionCode, fault_location: EntityIdTlv ) -> None: @@ -1090,6 +1119,7 @@ def _start_check_limit_handling(self) -> None: self._params.current_check_count = 0 def _notice_of_completion(self) -> None: + assert self._params.transaction_id is not None if self._params.completion_disposition == CompletionDisposition.COMPLETED: # TODO: Execute any filestore requests pass @@ -1103,7 +1133,7 @@ def _notice_of_completion(self) -> None: self._params.finished_params.file_status = FileStatus.DISCARDED_DELIBERATELY if self.cfg.indication_cfg.transaction_finished_indication_required: finished_indic_params = TransactionFinishedParams( - transaction_id=self._params.transaction_id, # type: ignore + transaction_id=self._params.transaction_id, finished_params=self._params.finished_params, status_report=None, ) @@ -1139,7 +1169,7 @@ def _check_limit_handling(self) -> None: if self._checksum_verify(self._params.fp.progress, self._params.fp.crc32): self._params.finished_params.delivery_code = DeliveryCode.DATA_COMPLETE self._params.finished_params.condition_code = ConditionCode.NO_ERROR - self._file_transfer_complete_transition() + self.states.step = TransactionStep.TRANSFER_COMPLETION return if self._params.current_check_count + 1 >= self._params.remote_cfg.check_limit: self._declare_fault(ConditionCode.CHECK_LIMIT_REACHED) diff --git a/src/cfdppy/handler/source.py b/src/cfdppy/handler/source.py index 328fe73..d5416be 100644 --- a/src/cfdppy/handler/source.py +++ b/src/cfdppy/handler/source.py @@ -532,6 +532,7 @@ def _fsm_non_idle(self, packet: AbstractFileDirectiveBase | None) -> None: def _transaction_start(self) -> None: originating_transaction_id = self._check_for_originating_id() self._prepare_file_params() + assert self._params.fp.file_size is not None self._prepare_pdu_conf(self._params.fp.file_size) self._get_next_transfer_seq_num() self._calculate_max_file_seg_len() @@ -551,20 +552,22 @@ def _check_for_originating_id(self) -> TransactionId | None: contains_proxy_put_response = False contains_originating_id = False originating_id = None + assert self._put_req is not None if self._put_req.msgs_to_user is None: return None for msgs_to_user in self._put_req.msgs_to_user: if msgs_to_user.is_reserved_cfdp_message(): reserved_cfdp_msg = msgs_to_user.to_reserved_msg_tlv() - if reserved_cfdp_msg.is_originating_transaction_id(): - contains_originating_id = True - originating_id = reserved_cfdp_msg.get_originating_transaction_id() - if ( - reserved_cfdp_msg.is_cfdp_proxy_operation() - and reserved_cfdp_msg.get_cfdp_proxy_message_type() - == ProxyMessageType.PUT_RESPONSE - ): - contains_proxy_put_response = True + if reserved_cfdp_msg is not None: + if reserved_cfdp_msg.is_originating_transaction_id(): + contains_originating_id = True + originating_id = reserved_cfdp_msg.get_originating_transaction_id() + if ( + reserved_cfdp_msg.is_cfdp_proxy_operation() + and reserved_cfdp_msg.get_cfdp_proxy_message_type() + == ProxyMessageType.PUT_RESPONSE + ): + contains_proxy_put_response = True if not contains_proxy_put_response and contains_originating_id: return originating_id return None @@ -660,6 +663,7 @@ def _prepare_metadata_pdu(self) -> None: def _prepare_metadata_base_params_with_metadata(self) -> MetadataParams: assert self._params.remote_cfg is not None + assert self._put_req is not None return MetadataParams( dest_file_name=self._put_req.dest_file.as_posix(), source_file_name=self._put_req.source_file.as_posix(), diff --git a/tests/test_lost_seg_tracker.py b/tests/test_lost_seg_tracker.py index 8d7d0dd..4d06a1f 100644 --- a/tests/test_lost_seg_tracker.py +++ b/tests/test_lost_seg_tracker.py @@ -8,9 +8,9 @@ def setUp(self) -> None: self.tracker = LostSegmentTracker() def test_basic(self): - self.assertEqual(self.tracker.lost_segments, {}) + self.assertEqual(self.tracker.lost_segments, []) self.tracker.add_lost_segment((0, 500)) - seg_end = self.tracker.lost_segments[0] + seg_end = self.tracker.lost_segments[0][1] self.assertEqual(seg_end, 500) def test_coalesence_0(self): @@ -18,7 +18,7 @@ def test_coalesence_0(self): self.tracker.add_lost_segment((1000, 1500)) self.tracker.coalesce_lost_segments() self.assertEqual(len(self.tracker.lost_segments), 1) - seg_end = self.tracker.lost_segments[500] + seg_end = self.tracker.lost_segments[0][1] self.assertEqual(seg_end, 1500) def test_coalesence_1(self): @@ -27,7 +27,7 @@ def test_coalesence_1(self): self.tracker.add_lost_segment((1500, 1700)) self.tracker.coalesce_lost_segments() self.assertEqual(len(self.tracker.lost_segments), 1) - seg_end = self.tracker.lost_segments[500] + seg_end = self.tracker.lost_segments[0][1] self.assertEqual(seg_end, 1700) def test_coalesence_2(self): @@ -35,42 +35,42 @@ def test_coalesence_2(self): self.tracker.add_lost_segment((1100, 1200)) self.tracker.coalesce_lost_segments() self.assertEqual(len(self.tracker.lost_segments), 2) - self.assertEqual(self.tracker.lost_segments, {500: 1000, 1100: 1200}) + self.assertEqual(self.tracker.lost_segments, [(500, 1000), (1100, 1200)]) def test_removal_0(self): self.tracker.add_lost_segment((0, 500)) self.assertTrue(self.tracker.remove_lost_segment((0, 500))) - self.assertEqual(self.tracker.lost_segments, {}) + self.assertEqual(self.tracker.lost_segments, []) def test_removal_1(self): self.tracker.add_lost_segment((0, 500)) self.assertTrue(self.tracker.remove_lost_segment((0, 200))) - self.assertEqual(self.tracker.lost_segments, {200: 500}) + self.assertEqual(self.tracker.lost_segments, [(200, 500)]) def test_removal_2(self): self.tracker.add_lost_segment((0, 500)) self.assertTrue(self.tracker.remove_lost_segment((300, 500))) - self.assertEqual(self.tracker.lost_segments, {0: 300}) + self.assertEqual(self.tracker.lost_segments, [(0, 300)]) def test_removal_3(self): self.tracker.add_lost_segment((0, 500)) self.assertTrue(self.tracker.remove_lost_segment((300, 400))) - self.assertEqual(self.tracker.lost_segments, {0: 300, 400: 500}) + self.assertEqual(self.tracker.lost_segments, [(0, 300), (400, 500)]) def test_noop_removal_0(self): self.tracker.add_lost_segment((0, 500)) self.assertFalse(self.tracker.remove_lost_segment((500, 1000))) - self.assertEqual(self.tracker.lost_segments, {0: 500}) + self.assertEqual(self.tracker.lost_segments, [(0, 500)]) def test_noop_removal_1(self): self.tracker.add_lost_segment((0, 500)) self.assertFalse(self.tracker.remove_lost_segment((0, 0))) - self.assertEqual(self.tracker.lost_segments, {0: 500}) + self.assertEqual(self.tracker.lost_segments, [(0, 500)]) def test_noop_removal_2(self): self.tracker.add_lost_segment((0, 500)) self.assertFalse(self.tracker.remove_lost_segment((500, 600))) - self.assertEqual(self.tracker.lost_segments, {0: 500}) + self.assertEqual(self.tracker.lost_segments, [(0, 500)]) def test_invalid_removal_0(self): self.tracker.add_lost_segment((0, 500))