diff --git a/trieregex.py b/trieregex.py index 280515f..90b606f 100644 --- a/trieregex.py +++ b/trieregex.py @@ -5,6 +5,7 @@ class Memoizer: + __slots__ = ['func', 'cache'] def __init__(self, func): self.func = func @@ -26,6 +27,7 @@ def _clear_cache(self): class TrieRegEx(): + __slots__ = ['_trie', '_initials', '_finals'] def __init__(self, *words: str) -> None: self._trie = {} # type: Dict[str: dict] @@ -33,27 +35,26 @@ def __init__(self, *words: str) -> None: self._finals = defaultdict(int) self.add(*words) + def _adjust_initials_finals(self, word, increase=True): + if increase: + self._initials[word[0]] += 1 + self._finals[word[-1]] += 1 + else: + self._initials[word[0]] -= 1 + self._finals[word[-1]] -= 1 + @Memoizer def add(self, *words: str) -> None: self.regex.clear_cache() # better performance to clear just once for word in words: if word != '' and not self.has(word): - self.adjust_initials_finals(word) + self._adjust_initials_finals(word) trie = self._trie for char in word: if char not in trie: trie[char] = {} trie = trie[char] trie['**'] = {} - # self.add.reset() - - def adjust_initials_finals(self, word, increase=True): - if increase: - self._initials[word[0]] += 1 - self._finals[word[-1]] += 1 - else: - self._initials[word[0]] -= 1 - self._finals[word[-1]] -= 1 def remove(self, *words: str) -> None: self.add.clear_cache() # better performance to clear just once @@ -64,7 +65,7 @@ def remove(self, *words: str) -> None: is_end = i == len(word) if is_end and self.has(word[:i]): remove_word = True - self.adjust_initials_finals(word, increase=False) + self._adjust_initials_finals(word, increase=False) if remove_word: node = self._trie for j in range(i-1): @@ -113,22 +114,22 @@ def regex(self, trie: dict = None, reset: bool = True) -> str: key = list(trie.keys())[0] if key == '**': return '' - return escape(key) + self.regex(trie[key], False) + return f'{escape(key)}{self.regex(trie[key], False)}' else: - sequences = [escape(key) + self.regex(trie[key], False) + sequences = [f'{escape(key)}{self.regex(trie[key], False)}' for key in trie if key != '**'] sequences.sort(key=lambda x: (-len(x), x)) # for easier inspection if len(sequences) == 1: result = sequences[0] if len(sequences[0]) > 1: - result = '(?:{})'.format(result) - elif len(sequences) == len(''.join(sequences)): - result = '[{}]'.format(''.join(sequences)) + result = f'(?:{result})' + elif len(sequences) == len("".join(sequences)): + result = f'[{"".join(sequences)}]' else: - result = '(?:{})'.format('|'.join(sequences)) + result = f'(?:{"|".join(sequences)})' if '**' in trie: - result = result + '?' + return f'{result}?' return result