diff --git a/trie/test_trie.py b/trie/test_trie.py index 71b934e..03856fa 100644 --- a/trie/test_trie.py +++ b/trie/test_trie.py @@ -158,9 +158,23 @@ def test_prefix_matching(self): words = ["prefix", "preface", "prepare", "prevent"] for word in words: self.trie.insert(word) - self.assertTrue(all(self.trie.search(word[:i]) for word in words for i in range(3, len(word) + 1))) + + # Test exact word matches + for word in words: + self.assertTrue(self.trie.search(word)) + + # Test prefix matches + for word in words: + for i in range(3, len(word)): + self.assertTrue(self.trie.search(word[:i], is_prefix=True)) + + # Test that "pre" is found as a prefix but not as a complete word + self.assertTrue(self.trie.search("pre", is_prefix=True)) self.assertFalse(self.trie.search("pre")) + # Test non-existent prefix + self.assertFalse(self.trie.search("pra", is_prefix=True)) + def test_edge_cases(self): self.trie.insert("a") self.assertTrue(self.trie.search("a")) diff --git a/trie/trie.py b/trie/trie.py index 552734d..05912f7 100644 --- a/trie/trie.py +++ b/trie/trie.py @@ -19,10 +19,10 @@ def insert(self, word): node = node.children[char] node.is_end_of_word = True - def search(self, word): + def search(self, word, is_prefix=False): """ - Search for a word in the trie. - Returns True if the word is found, False otherwise. + Search for a word or prefix in the trie. + Returns True if the word/prefix is found, False otherwise. Time complexity: O(m), where m is the length of the word. """ node = self.root @@ -30,7 +30,7 @@ def search(self, word): if char not in node.children: return False node = node.children[char] - return node.is_end_of_word + return is_prefix or node.is_end_of_word def delete(self, word): """