|
1 |
| -from enum import StrEnum, auto |
2 |
| -from typing import TypeAlias, NamedTuple, Union |
3 | 1 | from collections.abc import Sequence
|
| 2 | +from dataclasses import dataclass |
| 3 | +from enum import StrEnum, auto |
| 4 | +from typing import TypeAlias, NamedTuple, Union, Optional |
4 | 5 |
|
5 |
| -from tree_sitter import Parser |
6 | 6 | import cedarscript_grammar
|
7 |
| -from dataclasses import dataclass |
| 7 | +from tree_sitter import Parser |
8 | 8 |
|
9 | 9 |
|
10 | 10 | class ParseError(NamedTuple):
|
@@ -72,7 +72,7 @@ def __str__(self):
|
72 | 72 | case 'string' | None:
|
73 | 73 | pass
|
74 | 74 | case _:
|
75 |
| - result += self.marker_subtype.value |
| 75 | + result += self.marker_subtype |
76 | 76 |
|
77 | 77 | result += f" '{self.value.strip()}'"
|
78 | 78 | if self.offset is not None:
|
@@ -138,7 +138,8 @@ def as_marker(self) -> Marker:
|
138 | 138 |
|
139 | 139 | def __str__(self):
|
140 | 140 | wc = self.where_clause
|
141 |
| - if wc: wc = f' ({wc})' |
| 141 | + if wc: |
| 142 | + wc = f' ({wc})' |
142 | 143 | result = f"{str(self.identifier_type).lower()} {self.name}{wc}"
|
143 | 144 | if self.offset is not None:
|
144 | 145 | result += f" at offset {self.offset}"
|
@@ -232,11 +233,51 @@ def files_to_change(self) -> tuple[str, ...]:
|
232 | 233 | # </file-command>
|
233 | 234 |
|
234 | 235 |
|
| 236 | +@dataclass |
| 237 | +class LoopControl(StrEnum): |
| 238 | + BREAK = 'BREAK' |
| 239 | + CONTINUE = 'CONTINUE' |
| 240 | + |
| 241 | + |
| 242 | +@dataclass |
| 243 | +class CaseWhen: |
| 244 | + """Represents a WHEN condition in a CASE statement""" |
| 245 | + empty: bool = False |
| 246 | + regex: Optional[str] = None |
| 247 | + prefix: Optional[str] = None |
| 248 | + suffix: Optional[str] = None |
| 249 | + indent_level: Optional[int] = None |
| 250 | + line_number: Optional[int] = None |
| 251 | + |
| 252 | + |
| 253 | +@dataclass |
| 254 | +class CaseAction: |
| 255 | + """Represents a THEN action in a CASE statement""" |
| 256 | + loop_control: Optional[LoopControl] = None |
| 257 | + remove: bool = False |
| 258 | + replace: Optional[str] = None |
| 259 | + indent: Optional[int] = None |
| 260 | + content: Optional[str | tuple[Region, int | None]] = None |
| 261 | + |
| 262 | + |
| 263 | +@dataclass |
| 264 | +class CaseStatement: |
| 265 | + """Represents a CASE statement with when-then pairs and optional else""" |
| 266 | + cases: list[tuple[CaseWhen, CaseAction]] |
| 267 | + else_action: Optional[CaseAction] = None |
| 268 | + |
| 269 | + |
| 270 | +@dataclass |
| 271 | +class EdScript: |
| 272 | + """Represents an ED script content""" |
| 273 | + script: str |
| 274 | + |
| 275 | + |
235 | 276 | @dataclass
|
236 | 277 | class UpdateCommand(Command):
|
237 | 278 | target: FileOrIdentifierWithin
|
238 | 279 | action: EditingAction
|
239 |
| - content: str | tuple[Region, int | None] | None = None |
| 280 | + content: str | tuple[Region, int | None] | EdScript | CaseStatement | None = None |
240 | 281 |
|
241 | 282 | @property
|
242 | 283 | def files_to_change(self) -> tuple[str, ...]:
|
@@ -601,17 +642,99 @@ def parse_relative_indentation(self, node) -> int | None:
|
601 | 642 | return int(self.find_first_by_type(node.named_children, 'number').text)
|
602 | 643 |
|
603 | 644 | def parse_content(self, node) -> str | tuple[Region, int | None] | None:
|
604 |
| - content = self.find_first_by_type(node.named_children, ['content_literal', 'content_from_segment']) |
| 645 | + content = self.find_first_by_type(node.named_children, [ |
| 646 | + 'content_literal', 'content_from_segment', 'ed_stmt', 'case_stmt' |
| 647 | + ]) |
605 | 648 | if not content:
|
606 | 649 | return None
|
607 | 650 | match content.type:
|
608 | 651 | case 'content_literal':
|
609 | 652 | return self.parse_content_literal(content) # str
|
610 | 653 | case 'content_from_segment':
|
611 | 654 | return self.parse_content_from_segment_clause(content) # tuple[Region, int]
|
| 655 | + case 'ed_stmt': |
| 656 | + return self.parse_ed_stmt(content) # EdScript |
| 657 | + case 'case_stmt': |
| 658 | + return self.parse_case_stmt(content) # CaseStatement |
612 | 659 | case _:
|
613 | 660 | raise ValueError(f"Invalid content type: {content.type}")
|
614 | 661 |
|
| 662 | + def parse_case_stmt(self, node) -> CaseStatement: |
| 663 | + """Parse a CASE statement""" |
| 664 | + cases = [] |
| 665 | + |
| 666 | + # Parse all WHEN-THEN pairs |
| 667 | + current_when = None |
| 668 | + for child in node.children: |
| 669 | + match child.type: |
| 670 | + case 'case_when': |
| 671 | + current_when = self.parse_case_when(child) |
| 672 | + case 'case_action' if current_when is not None: |
| 673 | + action = self.parse_case_action(child) |
| 674 | + cases.append((current_when, action)) |
| 675 | + current_when = None |
| 676 | + |
| 677 | + # Parse optional ELSE clause |
| 678 | + else_action = None |
| 679 | + else_node = self.find_first_by_field_name(node, 'else') |
| 680 | + if else_node: |
| 681 | + else_action = self.parse_case_action(else_node) |
| 682 | + |
| 683 | + return CaseStatement(cases=cases, else_action=else_action) |
| 684 | + |
| 685 | + def parse_case_when(self, node) -> CaseWhen: |
| 686 | + """Parse a WHEN clause in a CASE statement""" |
| 687 | + when = CaseWhen() |
| 688 | + |
| 689 | + if self.find_first_by_field_name(node, 'empty'): |
| 690 | + when.empty = True |
| 691 | + elif regex := self.find_first_by_field_name(node, 'regex'): |
| 692 | + when.regex = self.parse_string(regex) |
| 693 | + elif prefix := self.find_first_by_field_name(node, 'prefix'): |
| 694 | + when.prefix = self.parse_string(prefix) |
| 695 | + elif suffix := self.find_first_by_field_name(node, 'suffix'): |
| 696 | + when.suffix = self.parse_string(suffix) |
| 697 | + elif indent := self.find_first_by_field_name(node, 'indent_level'): |
| 698 | + when.indent_level = int(indent.text) |
| 699 | + elif line_num := self.find_first_by_field_name(node, 'line_number'): |
| 700 | + when.line_number = int(line_num.text) |
| 701 | + |
| 702 | + return when |
| 703 | + |
| 704 | + def parse_case_action(self, node) -> CaseAction: |
| 705 | + """Parse a THEN clause in a CASE statement""" |
| 706 | + action = CaseAction() |
| 707 | + |
| 708 | + # Parse loop control if present |
| 709 | + loop_control = self.find_first_by_type(node.children, 'loop_control') |
| 710 | + if loop_control: |
| 711 | + if self.find_first_by_type(loop_control.children, 'loop_break'): |
| 712 | + action.loop_control = LoopControl.BREAK |
| 713 | + elif self.find_first_by_type(loop_control.children, 'loop_continue'): |
| 714 | + action.loop_control = LoopControl.CONTINUE |
| 715 | + |
| 716 | + # Parse other action types |
| 717 | + if self.find_first_by_field_name(node, 'remove'): |
| 718 | + action.remove = True |
| 719 | + elif replace := self.find_first_by_field_name(node, 'replace'): |
| 720 | + action.replace = self.parse_string(replace) |
| 721 | + elif indent := self.find_first_by_field_name(node, 'indent'): |
| 722 | + action.indent = int(indent.text) |
| 723 | + else: |
| 724 | + # Check for content replacement |
| 725 | + content = self.find_first_by_type(node.children, ['content_literal', 'content_from_segment']) |
| 726 | + if content: |
| 727 | + action.content = self.parse_content(content) |
| 728 | + |
| 729 | + return action |
| 730 | + |
| 731 | + def parse_ed_stmt(self, node) -> EdScript: |
| 732 | + """Parse an ED script statement""" |
| 733 | + ed_script = self.find_first_by_type(node.children, 'string') |
| 734 | + if ed_script is None: |
| 735 | + raise ValueError("No ED script found in ed_stmt") |
| 736 | + return EdScript(script=self.parse_string(ed_script)) |
| 737 | + |
615 | 738 | def parse_singlefile_clause(self, node):
|
616 | 739 | if node is None or node.type != 'singlefile_clause':
|
617 | 740 | raise ValueError("Expected singlefile_clause node")
|
|
0 commit comments