Skip to content

Commit be6cc64

Browse files
committed
CST-based indentation discovery
indentation_kit.py: prefer CST-based indentation discovery. Fall back to text-based discovery if CST-based discovery is unavailable or yields no indentation data. Passes indentation tests for codeeditor.py.
1 parent 2e675f8 commit be6cc64

File tree

8 files changed

+83
-50
lines changed

8 files changed

+83
-50
lines changed

src/cedarscript_editor/cedarscript_editor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,26 @@ def _update_command(self, cmd: UpdateCommand):
156156
region, action, lines, RangeSpec.EMPTY, identifier_finder
157157
)
158158
content = IndentationInfo.shift_indentation(
159-
content_range.read(lines), lines, search_range.indent, relindent_level
159+
content_range.read(lines), lines, search_range.indent, relindent_level,
160+
identifier_finder
160161
)
161162
content = (region, content)
162163
case _:
163164
match action:
164165
case MoveClause(insert_position=region, relative_indentation=relindent_level):
165166
content = IndentationInfo.shift_indentation(
166-
move_src_range.read(lines), lines, search_range.indent, relindent_level
167+
move_src_range.read(lines), lines, search_range.indent, relindent_level,
168+
identifier_finder
167169
)
168170
case DeleteClause():
169171
pass
170172
case _:
171173
raise ValueError(f'Invalid content: {content}')
172174

173-
self._apply_action(action, lines, search_range, content, range_spec_to_delete=move_src_range)
175+
self._apply_action(
176+
action, lines, search_range, content,
177+
range_spec_to_delete=move_src_range, identifier_finder=identifier_finder
178+
)
174179

175180
write_file(file_path, lines)
176181

@@ -179,7 +184,8 @@ def _update_command(self, cmd: UpdateCommand):
179184
@staticmethod
180185
def _apply_action(
181186
action: EditingAction, lines: Sequence[str], range_spec: RangeSpec, content: str | None = None,
182-
range_spec_to_delete: RangeSpec | None = None
187+
range_spec_to_delete: RangeSpec | None = None,
188+
identifier_finder: IdentifierFinder | None = None
183189
):
184190
match action:
185191

@@ -199,7 +205,7 @@ def _apply_action(
199205
case ReplaceClause() | InsertClause():
200206
match content:
201207
case str():
202-
content = IndentationInfo.from_content(lines).apply_relative_indents(
208+
content = IndentationInfo.from_content(lines, identifier_finder).apply_relative_indents(
203209
content, range_spec.indent
204210
)
205211
case Sequence():

src/cedarscript_editor/tree_sitter_identifier_finder.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class TreeSitterIdentifierFinder(IdentifierFinder):
3434
"""
3535

3636
def __init__(self, fname: str, source: str | Sequence[str], parent_restriction: ParentRestriction = None):
37+
super().__init__()
3738
self.parent_restriction = parent_restriction
3839
match source:
3940
case str() as s:
@@ -65,6 +66,7 @@ def __call__(
6566
# Returns IdentifierBoundaries
6667
return self._find_identifier(marker, parent_restriction)
6768

69+
6870
def _find_identifier(self,
6971
marker: Marker,
7072
parent_restriction: ParentRestriction
@@ -84,39 +86,16 @@ def _find_identifier(self,
8486
"""
8587
query_info_key = marker.type
8688
identifier_name = marker.value
87-
match marker.type:
88-
case 'method':
89-
query_info_key = 'function'
9089
try:
9190
all_restrictions: list[ParentRestriction] = [parent_restriction]
9291
# Extract parent name if using dot notation
9392
if '.' in identifier_name:
9493
*parent_parts, identifier_name = identifier_name.split('.')
9594
all_restrictions.append("." + '.'.join(reversed(parent_parts)))
9695

96+
identifier_type = marker.type
9797
# Get all node candidates first
98-
candidate_nodes = (
99-
self.language.query(self.query_info[query_info_key].format(name=identifier_name))
100-
.captures(self.tree.root_node)
101-
)
102-
if not candidate_nodes:
103-
return None
104-
105-
# Convert captures to boundaries and filter by parent
106-
candidates: list[IdentifierBoundaries] = []
107-
for ib in capture2identifier_boundaries(candidate_nodes, self.lines):
108-
# For methods, verify the immediate parent is a class
109-
if marker.type == 'method':
110-
if not ib.parents or not ib.parents[0].parent_type.startswith('class'):
111-
continue
112-
# Check parent restriction (e.g., specific class name)
113-
candidate_matched_all_restrictions = True
114-
for pr in all_restrictions:
115-
if not ib.match_parent(pr):
116-
candidate_matched_all_restrictions = False
117-
break
118-
if candidate_matched_all_restrictions:
119-
candidates.append(ib)
98+
candidates = self.find_identifiers(query_info_key, identifier_name, all_restrictions)
12099
except Exception as e:
121100
raise ValueError(f"Unable to capture nodes for {marker}: {e}") from e
122101

@@ -141,6 +120,34 @@ def _find_identifier(self,
141120
return result.location_to_search_range(relative_position_type)
142121
return result
143122

123+
def find_identifiers(
124+
self, identifier_type: str, name: str, all_restrictions: list[ParentRestriction] = []
125+
) -> list[IdentifierBoundaries]:
126+
if not self.language:
127+
return []
128+
match identifier_type:
129+
case 'method':
130+
identifier_type = 'function'
131+
candidate_nodes = self.language.query(self.query_info[identifier_type].format(name=name)).captures(self.tree.root_node)
132+
if not candidate_nodes:
133+
return []
134+
# Convert captures to boundaries and filter by parent
135+
candidates: list[IdentifierBoundaries] = []
136+
for ib in capture2identifier_boundaries(candidate_nodes, self.lines):
137+
# For methods, verify the immediate parent is a class
138+
if identifier_type == 'method':
139+
if not ib.parents or not ib.parents[0].parent_type.startswith('class'):
140+
continue
141+
# Check parent restriction (e.g., specific class name)
142+
candidate_matched_all_restrictions = True
143+
for pr in all_restrictions:
144+
if not ib.match_parent(pr):
145+
candidate_matched_all_restrictions = False
146+
break
147+
if candidate_matched_all_restrictions:
148+
candidates.append(ib)
149+
return candidates
150+
144151

145152
def _get_by_offset(obj: Sequence, offset: int):
146153
if 0 <= offset < len(obj):

src/text_manipulation/cst_kit.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import runtime_checkable, Protocol, Sequence
2-
2+
from functools import cached_property
33
from cedarscript_ast_parser import Marker, Segment, RelativeMarker, RelativePositionType, MarkerType, BodyOrWhole
44

55
from .range_spec import IdentifierBoundaries, RangeSpec, ParentRestriction
@@ -15,3 +15,12 @@ def __call__(
1515
) -> IdentifierBoundaries | RangeSpec | None:
1616
"""Find identifier boundaries for a given marker or segment."""
1717
pass
18+
19+
def find_identifiers(
20+
self, identifier_type: str, name: str, all_restrictions: list[ParentRestriction] | None = None
21+
) -> list[IdentifierBoundaries]:
22+
pass
23+
24+
@cached_property
25+
def find_all_callables(self) -> list[IdentifierBoundaries]:
26+
return self.find_identifiers('function', r'.*')

src/text_manipulation/indentation_kit.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def shift_indentation(cls,
104104
[' def example():', ' print('Hello')']
105105
:param target_lines:
106106
"""
107-
context_indent_char_count = cls.from_content(target_lines).char_count
107+
context_indent_char_count = cls.from_content(target_lines, identifier_finder).char_count
108108
return (cls.
109-
from_content(content).
109+
from_content(content, identifier_finder).
110110
_replace(char_count=context_indent_char_count).
111111
_shift_indentation(
112112
content, target_reference_indentation_count, relindent_level
@@ -146,19 +146,31 @@ def from_content(
146146
character count by analyzing patterns and using GCD.
147147
"""
148148
# TODO Always send str?
149-
lines = [x for x in content.splitlines() if x.strip()] if isinstance(content, str) else content
150-
151-
indentations = [extract_indentation(line) for line in lines if line.strip()]
152-
has_zero_indent = any((i == '' for i in indentations))
153-
indentations = [indent for indent in indentations if indent]
154-
155-
if not indentations:
156-
return cls(4, ' ', 0, True, "No indentation found. Assuming 4 spaces (PEP 8).")
157-
158-
indent_chars = Counter(indent[0] for indent in indentations)
159-
dominant_char = ' ' if indent_chars.get(' ', 0) >= indent_chars.get('\t', 0) else '\t'
160-
161-
indent_lengths = [len(indent) for indent in indentations]
149+
indent_lengths = []
150+
if identifier_finder:
151+
indent_lengths = []
152+
for ib in identifier_finder.find_all_callables:
153+
if ib.whole and ib.whole.indent:
154+
indent_lengths.append(ib.whole.indent)
155+
if ib.body and ib.body.indent:
156+
indent_lengths.append(ib.body.indent)
157+
has_zero_indent = any((i == 0 for i in indent_lengths))
158+
159+
if not (indent_lengths):
160+
lines = [x for x in content.splitlines() if x.strip()] if isinstance(content, str) else content
161+
indentations = [extract_indentation(line) for line in lines if line.strip()]
162+
has_zero_indent = any((i == '' for i in indentations))
163+
indentations = [indent for indent in indentations if indent]
164+
165+
if not indentations:
166+
return cls(4, ' ', 0, True, "No indentation found. Assuming 4 spaces (PEP 8).")
167+
168+
indent_chars = Counter(indent[0] for indent in indentations)
169+
dominant_char = ' ' if indent_chars.get(' ', 0) >= indent_chars.get('\t', 0) else '\t'
170+
171+
indent_lengths = [len(indent) for indent in indentations]
172+
else:
173+
dominant_char = ' '
162174

163175
char_count = 1
164176
if dominant_char != '\t':
@@ -167,7 +179,7 @@ def from_content(
167179
min_indent_chars = 0 if has_zero_indent else min(indent_lengths) if indent_lengths else 0
168180
min_indent_level = min_indent_chars // char_count
169181

170-
consistency = all(len(indent) % char_count == 0 for indent in indentations if indent)
182+
consistency = all(indent_len % char_count == 0 for indent_len in indent_lengths if indent_len)
171183
match dominant_char:
172184
case ' ':
173185
domcharstr = 'space'

src/text_manipulation/line_kit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def extract_indentation(line: str) -> str:
4444
>>> extract_indentation("No indentation")
4545
''
4646
"""
47-
return line[:len(line) - len(line.lstrip())]
47+
return line[:get_line_indent_count(line)]

tests/corpus/refactor-benchmark.indentation-size-discovery/chat.xml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ FROM FILE "codeeditor.py"
88
MOVE METHOD "__get_brackets"
99
INSERT BEFORE CLASS "CodeEditor"
1010
RELATIVE INDENTATION 0;
11-
```
1211

1312
-- 1. Move the method to become a top-level function.
1413
UPDATE CLASS "AutosaveForPlugin"
@@ -23,5 +22,5 @@ FROM FILE "base.py"
2322
MOVE METHOD "adapt_method_mode"
2423
INSERT BEFORE CLASS "BaseHandler"
2524
RELATIVE INDENTATION 0;
26-
25+
```
2726
</no-train>

0 commit comments

Comments
 (0)