diff --git a/ac_test.go b/ac_test.go index 57f2b8e..2c7dfe2 100644 --- a/ac_test.go +++ b/ac_test.go @@ -204,6 +204,14 @@ func TestACBlices(t *testing.T) { } } +func TestNonASCIIDictionary(t *testing.T) { + dict := []string{"hello world", "こんにちは世界"} + _, err := CompileString(dict) + if err != nil { + t.Errorf("error compiling matcher: %s", err) + } +} + var ( source1 = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/30.0.1599.101 Safari/537.36" source1b = []byte(source1) diff --git a/acascii/ac.go b/acascii/ac.go index b3e1979..006db35 100644 --- a/acascii/ac.go +++ b/acascii/ac.go @@ -11,10 +11,14 @@ package acascii import ( "container/list" + "errors" ) const maxchar = 128 +// ErrNotASCII is returned when the dictionary input is not ASCII +var ErrNotASCII = errors.New("non-ASCII input") + // A node in the trie structure used to implement Aho-Corasick type node struct { root bool // true if this is the root @@ -88,7 +92,7 @@ func (m *Matcher) getFreeNode() *node { // buildTrie builds the fundamental trie structure from a set of // blices. -func (m *Matcher) buildTrie(dictionary [][]byte) { +func (m *Matcher) buildTrie(dictionary [][]byte) error { // Work out the maximum size for the trie (all dictionary entries // are distinct plus the root). This is used to preallocate memory @@ -112,6 +116,9 @@ func (m *Matcher) buildTrie(dictionary [][]byte) { n := m.root for i, b := range blice { idx := int(b) + if idx >= maxchar { + return ErrNotASCII + } c := n.child[idx] if c == nil { @@ -185,10 +192,11 @@ func (m *Matcher) buildTrie(dictionary [][]byte) { } m.trie = m.trie[:m.extent] + return nil } // buildTrieString builds the fundamental trie structure from a []string -func (m *Matcher) buildTrieString(dictionary []string) { +func (m *Matcher) buildTrieString(dictionary []string) error { // Work out the maximum size for the trie (all dictionary entries // are distinct plus the root). This is used to preallocate memory @@ -212,6 +220,10 @@ func (m *Matcher) buildTrieString(dictionary []string) { for _, blice := range dictionary { n := m.root for i := 0; i < len(blice); i++ { + index := int(blice[i]) + if index >= maxchar { + return ErrNotASCII + } b := int(blice[i]) c := n.child[b] if c == nil { @@ -285,13 +297,15 @@ func (m *Matcher) buildTrieString(dictionary []string) { } m.trie = m.trie[:m.extent] + return nil } // Compile creates a new Matcher using a list of []byte func Compile(dictionary [][]byte) (*Matcher, error) { m := new(Matcher) - m.buildTrie(dictionary) - // no error for now + if err := m.buildTrie(dictionary); err != nil { + return nil, err + } return m, nil } @@ -308,7 +322,9 @@ func MustCompile(dictionary [][]byte) *Matcher { // of strings (this is a helper to make initialization easy) func CompileString(dictionary []string) (*Matcher, error) { m := new(Matcher) - m.buildTrieString(dictionary) + if err := m.buildTrieString(dictionary); err != nil { + return nil, err + } return m, nil } @@ -415,7 +431,7 @@ func (m *Matcher) Match(in []byte) bool { for _, b := range in { c := int(b) if c > maxchar { - c = 0 + c = 0 } if !n.root && n.child[c] == nil { n = n.fails[c] @@ -443,7 +459,7 @@ func (m *Matcher) MatchString(in string) bool { for idx := 0; idx < slen; idx++ { c := int(in[idx]) if c >= maxchar { - c = 0 + c = 0 } if !n.root && n.child[c] == nil { n = n.fails[c] diff --git a/acascii/ac_test.go b/acascii/ac_test.go index 8bebf44..904f5c3 100644 --- a/acascii/ac_test.go +++ b/acascii/ac_test.go @@ -204,6 +204,14 @@ func TestACBlices(t *testing.T) { } } +func TestNonASCIIDictionary(t *testing.T) { + dict := []string{"hello world", "こんにちは世界"} + _, err := CompileString(dict) + if err == nil { + t.Errorf("expected error compiling ASCII matcher") + } +} + var ( source1 = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/30.0.1599.101 Safari/537.36" source1b = []byte(source1)