diff --git a/tests/test_regex.py b/tests/test_regex.py index c9a91d3..d0703c0 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -1,5 +1,4 @@ import re -from typing import List import unittest from trieregex import TrieRegEx as TRE @@ -21,30 +20,21 @@ def setUp(self): 'scallion', 'ginger', 'garlic', 'onion', 'galangal' ] - def findall(self, string: str, boundary: str) -> List[str]: - """Helper function. The TrieRegEx.regex() function is called here and - the result of regex matching is returned - """ - pattern = re.compile(f'{boundary}{self.tre.regex()}{boundary}') - return sorted(pattern.findall(string)) - def test_match_all_incrementals(self): self.tre.add(*self.words) - found = self.findall(' '.join(self.words), '\\b') + found = re.findall(f'\\b{self.tre.regex()}\\b', ' '.join(self.words)) - self.assertEqual(found, sorted(self.words)) + self.assertEqual(sorted(found), sorted(self.words)) def test_does_not_match_larger_string(self): self.tre.add('p') - found = self.findall('pe', '\\b') - + found = re.findall(f'\\b{self.tre.regex()}\\b', 'pe') self.assertEqual(found, []) def test_does_not_match_substring(self): my_words = self.words[1:] # leave out 'p' self.tre.add(*my_words) - found = self.findall(' '.join(self.words), '\\b') - + found = re.findall(f'\\b{self.tre.regex()}\\b', ' '.join(self.words)) self.assertEqual( found, sorted(my_words), @@ -56,13 +46,14 @@ def test_empty_trie_returns_empty_string_regex(self): def test_match_all_words(self): self.tre.add(*self.more_words) - found = self.findall(' '.join(sorted(self.more_words)), '\\b') - self.assertEqual(found, sorted(self.more_words)) + pattern = f'\\b{self.tre.regex()}\\b' + found = re.findall(pattern, ' '.join(self.more_words)) + self.assertEqual(sorted(found), sorted(self.more_words)) def test_match_all_words_surrounded_by_spaces(self): words = sorted(self.more_words) self.tre.add(*words) - found = re.findall(f"(?<= ){'|'.join(words)}(?= )", ' '.join(words)) + found = re.findall(f"(?<= ){self.tre.regex()}(?= )", ' '.join(words)) self.assertEqual( found, words[1:-1], diff --git a/trieregex/trieregex.py b/trieregex/trieregex.py index c098062..c88010e 100644 --- a/trieregex/trieregex.py +++ b/trieregex/trieregex.py @@ -24,7 +24,7 @@ def _adjust_initials_finals(self, word, increase=True): @Memoizer def add(self, *words: str) -> None: - self.regex.clear_cache() # better performance to clear just once + self.regex.clear_cache() for word in words: if word != '' and not self.has(word): self._adjust_initials_finals(word) @@ -36,7 +36,7 @@ def add(self, *words: str) -> None: trie['**'] = {} def remove(self, *words: str) -> None: - self.add.clear_cache() # better performance to clear just once + self.add.clear_cache() self.regex.clear_cache() for word in words: remove_word = False @@ -98,7 +98,7 @@ def regex(self, trie: dict = None, reset: bool = True) -> str: else: sequences = [f'{escape(key)}{self.regex(trie[key], False)}' for key in trie if key != '**'] - sequences.sort(key=lambda x: (-len(x), x)) # for easier inspection + sequences.sort(key=lambda x: (-len(x), x)) if len(sequences) == 1: result = sequences[0]