Skip to content

Commit

Permalink
Use f-strings, rename/move function
Browse files (browse the repository at this point in the history)
  • Loading branch information
ermanh committed Jun 21, 2020
1 parent 44d5f67 commit b219226
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions trieregex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


class Memoizer:
__slots__ = ['func', 'cache']

def __init__(self, func):
self.func = func
Expand All @@ -26,34 +27,34 @@ def _clear_cache(self):


class TrieRegEx():
__slots__ = ['_trie', '_initials', '_finals']

# Set up an empty trie and the two per-character tallies, then insert
# every word passed to the constructor by forwarding them to add().
def __init__(self, *words: str) -> None:
# Nested dicts: each key maps to the sub-trie that follows it.
self._trie = {} # type: Dict[str, dict]
# Tally keyed by a word's first character (missing keys read as 0).
self._initials = defaultdict(int)
# Tally keyed by a word's last character (missing keys read as 0).
self._finals = defaultdict(int)
self.add(*words)

# Update the first-character and last-character tallies for `word`.
# With increase=True (the default) each tally goes up by one; with
# increase=False each goes down by one.
# `word` is indexed at [0] and [-1] with no guard, so it must be non-empty
# (an empty string would raise IndexError).
def _adjust_initials_finals(self, word, increase=True):
if increase:
self._initials[word[0]] += 1
self._finals[word[-1]] += 1
else:
# Entries are only decremented, never deleted, so a key can stay at 0.
self._initials[word[0]] -= 1
self._finals[word[-1]] -= 1

@Memoizer
def add(self, *words: str) -> None:
self.regex.clear_cache() # better performance to clear just once
for word in words:
if word != '' and not self.has(word):
self.adjust_initials_finals(word)
self._adjust_initials_finals(word)
trie = self._trie
for char in word:
if char not in trie:
trie[char] = {}
trie = trie[char]
trie['**'] = {}
# self.add.reset()

# Update the first-character and last-character tallies for `word`.
# With increase=True (the default) each tally goes up by one; with
# increase=False each goes down by one.
# `word` is indexed at [0] and [-1] with no guard, so it must be non-empty
# (an empty string would raise IndexError).
def adjust_initials_finals(self, word, increase=True):
if increase:
self._initials[word[0]] += 1
self._finals[word[-1]] += 1
else:
# Entries are only decremented, never deleted, so a key can stay at 0.
self._initials[word[0]] -= 1
self._finals[word[-1]] -= 1

def remove(self, *words: str) -> None:
self.add.clear_cache() # better performance to clear just once
Expand All @@ -64,7 +65,7 @@ def remove(self, *words: str) -> None:
is_end = i == len(word)
if is_end and self.has(word[:i]):
remove_word = True
self.adjust_initials_finals(word, increase=False)
self._adjust_initials_finals(word, increase=False)
if remove_word:
node = self._trie
for j in range(i-1):
Expand Down Expand Up @@ -113,22 +114,22 @@ def regex(self, trie: dict = None, reset: bool = True) -> str:
key = list(trie.keys())[0]
if key == '**':
return ''
return escape(key) + self.regex(trie[key], False)
return f'{escape(key)}{self.regex(trie[key], False)}'

else:
sequences = [escape(key) + self.regex(trie[key], False)
sequences = [f'{escape(key)}{self.regex(trie[key], False)}'
for key in trie if key != '**']
sequences.sort(key=lambda x: (-len(x), x)) # for easier inspection

if len(sequences) == 1:
result = sequences[0]
if len(sequences[0]) > 1:
result = '(?:{})'.format(result)
elif len(sequences) == len(''.join(sequences)):
result = '[{}]'.format(''.join(sequences))
result = f'(?:{result})'
elif len(sequences) == len("".join(sequences)):
result = f'[{"".join(sequences)}]'
else:
result = '(?:{})'.format('|'.join(sequences))
result = f'(?:{"|".join(sequences)})'

if '**' in trie:
result = result + '?'
return f'{result}?'
return result

0 comments on commit b219226

Please sign in to comment.