diff --git a/src/main/java/automata/DfaPattern.java b/src/main/java/automata/DfaPattern.java index fd509db..aadc414 100644 --- a/src/main/java/automata/DfaPattern.java +++ b/src/main/java/automata/DfaPattern.java @@ -1,6 +1,7 @@ package automata; import automata.codegen.CompiledDfa; +import automata.graph.StandardCodeUnits; import automata.graph.MatchMode; import automata.graph.Tdfa; import automata.graph.Tnfa; @@ -121,8 +122,8 @@ public CompiledDfaPattern( super(pattern); // NFAs - final Tnfa nfaWithoutWildcard = Tnfa.parse(pattern, flags, true, false); - final Tnfa nfaWithWildcard = Tnfa.parse(pattern, flags, true, true); + final Tnfa nfaWithoutWildcard = Tnfa.parse(pattern, StandardCodeUnits.UTF_16, flags, true, false); + final Tnfa nfaWithWildcard = Tnfa.parse(pattern, StandardCodeUnits.UTF_16, flags, true, true); // DFAs final Tdfa matchesDfa = Tdfa.fromTnfa(nfaWithoutWildcard, MatchMode.FULL, optimized); @@ -235,8 +236,8 @@ public InterpretableDfaPattern( super(pattern); this.printDebugInfo = printDebugInfo; - final Tnfa nfaWithoutWildcard = Tnfa.parse(pattern, flags, true, false); - final Tnfa nfaWithWildcard = Tnfa.parse(pattern, flags, true, true); + final Tnfa nfaWithoutWildcard = Tnfa.parse(pattern, StandardCodeUnits.UTF_16, flags, true, false); + final Tnfa nfaWithWildcard = Tnfa.parse(pattern, StandardCodeUnits.UTF_16, flags, true, true); this.matchesDfa = Tdfa.fromTnfa(nfaWithoutWildcard, MatchMode.FULL, optimized); this.lookingAtDfa = Tdfa.fromTnfa(nfaWithoutWildcard, MatchMode.PREFIX, optimized); diff --git a/src/main/java/automata/graph/RegexNfaBuilder.java b/src/main/java/automata/graph/RegexNfaBuilder.java index e390045..5d4331b 100644 --- a/src/main/java/automata/graph/RegexNfaBuilder.java +++ b/src/main/java/automata/graph/RegexNfaBuilder.java @@ -6,13 +6,10 @@ import automata.parser.CodePointSetVisitor; import automata.util.IntRange; import automata.util.IntRangeSet; +import java.util.Stack; import java.util.Optional; import 
java.util.OptionalInt; -import java.util.Map; -import java.util.LinkedList; -import java.util.HashMap; import java.util.function.UnaryOperator; -import java.util.stream.Collectors; /** * Regex AST visitor which can be used to build up the corresponding NFA. @@ -30,6 +27,15 @@ public abstract class RegexNfaBuilder extends CodePointSetVisitor implements RegexVisitor>, IntRangeSet> { + /** + * Code units to be used in the NFA. + */ + public final StandardCodeUnits codeUnits; + + public RegexNfaBuilder(StandardCodeUnits codeUnits) { + this.codeUnits = codeUnits; + } + /** * Summon a fresh state identifier. * @@ -87,107 +93,56 @@ public UnaryOperator> visitEpsilon() { return UnaryOperator.identity(); } - /** - * Break down the input code point set into a mapping of low ranges to high - * ranges. - * - * The values in the output should be disjoint and union out to a subset of - * the high surrogate range. The values in the keys won't necessarily be - * disjoint, but they should all be in the low surrogate range. Since the - * high surrogate range is just {@code 0xD800–0xDBFF} (1024 values), the - * total size of the output map is at most 1024 entries. 
- * - * @param codePointSet input code point set - * @return mapping from low surrogate ranges to high surrogate ranges - */ - public static Map supplementaryCodeUnitRanges( - IntRangeSet codePointSet - ) { - - // TODO: this doesn't need to be a map since we scan high codepoints in order - // and only ever update the last one - final var supplementaryCodeUnits = new HashMap>(); - - for (IntRange range : codePointSet.difference(IntRangeSet.of(CodePoints.BMP_RANGE)).ranges()) { - - int rangeStartHi = Character.highSurrogate(range.lowerBound()); - int rangeStartLo = Character.lowSurrogate(range.lowerBound()); - - int rangeEndHi = Character.highSurrogate(range.upperBound()); - int rangeEndLo = Character.lowSurrogate(range.upperBound()); - - if (rangeStartHi == rangeEndHi) { - // Add the _only_ range - supplementaryCodeUnits - .computeIfAbsent(rangeStartHi, k -> new LinkedList<>()) - .addLast(IntRange.between(rangeStartLo, rangeEndLo)); - } else { - // Add the first range - supplementaryCodeUnits - .computeIfAbsent(rangeStartHi, k -> new LinkedList<>()) - .addLast(IntRange.between(rangeStartLo, Character.MAX_LOW_SURROGATE)); - - // Add the last range - supplementaryCodeUnits - .computeIfAbsent(rangeEndHi, k -> new LinkedList<>()) - .addLast(IntRange.between(Character.MIN_LOW_SURROGATE, rangeEndLo)); - - // Add everything in between - for (int hi = rangeStartHi + 1; hi <= rangeEndHi - 1; hi++) { - supplementaryCodeUnits - .computeIfAbsent(hi, k -> new LinkedList<>()) - .addLast(CodePoints.LOW_SURROGATE_RANGE); - } - } - } - - return supplementaryCodeUnits - .entrySet() - .stream() - .collect( - Collectors.groupingBy( - e -> new IntRangeSet(e.getValue()), - Collectors.mapping( - e -> IntRangeSet.of(IntRange.single(e.getKey())), - Collectors.collectingAndThen(Collectors.toList(), IntRangeSet::union) - ) - ) - ); - } - public UnaryOperator> visitCharacterClass(IntRangeSet codePointSet) { if (!codePointSet.difference(IntRangeSet.of(CodePoints.UNICODE_RANGE)).isEmpty()) { 
throw new IllegalArgumentException("Codepoints outside the unicode range aren't allowed"); } - /* Code unit transitions corresponding to the basic multilingual plane. - * By definition of the BMP, this means these are exactly one code unit. - */ - final var basicCodeUnits = codePointSet.intersection(IntRangeSet.of(CodePoints.BMP_RANGE)); - - /* Mapping from the first (high) 16-bit code unit to the range of second - * (low) 16-bit code units. There are `0xDBFF - 0xD800 + 1 = 1024` high - * code points, so this map will have between 0 and 1024 entries. - */ - final var supplementaryCodeUnits = supplementaryCodeUnitRanges(codePointSet); + final var suffixTrie = codeUnits.codeUnitRangeSuffixTrie(codePointSet); // How many code units wide is the class? - final Optional classSize = - supplementaryCodeUnits.isEmpty() ? Optional.of(1) : - basicCodeUnits.isEmpty() ? Optional.of(2) : - Optional.empty(); + final Optional classSize = suffixTrie.inSetDepth(); + + // Depth first traversal of the tree + record TraversalEntry( + TrieSet subTrie, + int distanceToRoot, + IntRangeSet codeUnitToParent, + Q parentNode + ) { } return (NfaState toState) -> { - final Q to = toState.state(); - final Q start = freshState(); - if (!basicCodeUnits.isEmpty()) { - addCodeUnitsState(start, basicCodeUnits, to); + final var start = freshState(); + + // Depth first traversal of the suffix trie + final var toVisit = new Stack>(); + for (final var childEntry : suffixTrie.children.entrySet()) { + toVisit.push(new TraversalEntry<>( + childEntry.getValue(), + 1, + childEntry.getKey(), + toState.state() + )); } + while (!toVisit.isEmpty()) { + final var entry = toVisit.pop(); + final Q thisNode = entry.subTrie.inSet ? 
start : freshState(); + + // Visit this node + addCodeUnitsState(thisNode, entry.codeUnitToParent, entry.parentNode); + if (entry.subTrie.inSet && !entry.subTrie.children.isEmpty()) { + throw new IllegalStateException("One code unit sequence cannot be a suffix of another"); + } - for (var loAndHigh : supplementaryCodeUnits.entrySet()) { - final Q hiEnd = freshState(); - addCodeUnitsState(start, loAndHigh.getValue(), hiEnd); - addCodeUnitsState(hiEnd, loAndHigh.getKey(), to); + // Plan to visit children + for (final var childEntry : entry.subTrie.children.entrySet()) { + toVisit.push(new TraversalEntry<>( + childEntry.getValue(), + entry.distanceToRoot + 1, + childEntry.getKey(), + thisNode + )); + } } final var fixedGroup = toState @@ -195,6 +150,7 @@ public UnaryOperator> visitCharacterClass(IntRangeSet codePointSet) .flatMap(loc -> classSize.map(loc::addDistance)); return new NfaState(start, toState.insideRepetition(), toState.unavoidable(), fixedGroup); + }; } diff --git a/src/main/java/automata/graph/StandardCodeUnits.java b/src/main/java/automata/graph/StandardCodeUnits.java new file mode 100644 index 0000000..f6c265a --- /dev/null +++ b/src/main/java/automata/graph/StandardCodeUnits.java @@ -0,0 +1,180 @@ +package automata.graph; + +import automata.util.IntRange; +import automata.util.IntRangeSet; +import automata.parser.CodePoints; +import automata.graph.TrieSet; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public enum StandardCodeUnits { + + /** + * Codepoints are turned into one or two 16-bit code units. + * + * Java strings internally use this. + */ + UTF_16(IntRange.between(Character.MIN_VALUE, Character.MAX_VALUE)) { + + /** + * Break down the input code point set into a mapping of low ranges to high + * ranges. + * + * The values in the output should be disjoint and union out to a subset of + * the high surrogate range. 
The values in the keys won't necessarily be + * disjoint, but they should all be in the low surrogate range. Since the + * high surrogate range is just {@code 0xD800–0xDBFF} (1024 values), the + * total size of the output map is at most 1024 entries. + * + * @param codePointSet input code point set + * @return mapping from low surrogate ranges to high surrogate ranges + */ + private static Map supplementaryCodeUnitRanges( + IntRangeSet codePointSet + ) { + + // TODO: this doesn't need to be a map since we scan high codepoints in order + // and only ever update the last one + final var supplementaryCodeUnits = new HashMap>(); + + for (IntRange range : codePointSet.difference(IntRangeSet.of(CodePoints.BMP_RANGE)).ranges()) { + + int rangeStartHi = Character.highSurrogate(range.lowerBound()); + int rangeStartLo = Character.lowSurrogate(range.lowerBound()); + + int rangeEndHi = Character.highSurrogate(range.upperBound()); + int rangeEndLo = Character.lowSurrogate(range.upperBound()); + + if (rangeStartHi == rangeEndHi) { + // Add the _only_ range + supplementaryCodeUnits + .computeIfAbsent(rangeStartHi, k -> new LinkedList<>()) + .addLast(IntRange.between(rangeStartLo, rangeEndLo)); + } else { + // Add the first range + supplementaryCodeUnits + .computeIfAbsent(rangeStartHi, k -> new LinkedList<>()) + .addLast(IntRange.between(rangeStartLo, Character.MAX_LOW_SURROGATE)); + + // Add the last range + supplementaryCodeUnits + .computeIfAbsent(rangeEndHi, k -> new LinkedList<>()) + .addLast(IntRange.between(Character.MIN_LOW_SURROGATE, rangeEndLo)); + + // Add everything in between + for (int hi = rangeStartHi + 1; hi <= rangeEndHi - 1; hi++) { + supplementaryCodeUnits + .computeIfAbsent(hi, k -> new LinkedList<>()) + .addLast(CodePoints.LOW_SURROGATE_RANGE); + } + } + } + + return supplementaryCodeUnits + .entrySet() + .stream() + .collect( + Collectors.groupingBy( + e -> new IntRangeSet(e.getValue()), + Collectors.mapping( + e -> 
IntRangeSet.of(IntRange.single(e.getKey())), + Collectors.collectingAndThen(Collectors.toList(), IntRangeSet::union) + ) + ) + ); + } + + @Override + public TrieSet codeUnitRangeSuffixTrie(IntRangeSet codePointSet) { + + /* Code unit transitions corresponding to the basic multilingual plane. + * By definition of the BMP, this means these are exactly one code unit. + */ + final var basicCodeUnits = codePointSet.intersection(IntRangeSet.of(CodePoints.BMP_RANGE)); + + /* Mapping from the first (high) 16-bit code unit to the range of second + * (low) 16-bit code units. There are `0xDBFF - 0xD800 + 1 = 1024` high + * code points, so this map will have between 0 and 1024 entries. + */ + final var supplementaryCodeUnits = supplementaryCodeUnitRanges(codePointSet); + + final var output = new TrieSet(); + output.add(List.of(basicCodeUnits)); + for (final var entry : supplementaryCodeUnits.entrySet()) { + output.add(List.of(entry.getKey(), entry.getValue())); + } + return output; + } + }, + + /** + * Codepoints are turned into one to four 8-bit code units. + */ + UTF_8(IntRange.between(0, 0xFF)) { + + // Ranges of code points taking 1, 2, 3, and 4 code unit + private static final IntRange ONE_BYTE_RANGE = IntRange.between(0, 0x7F); + private static final IntRange TWO_BYTE_RANGE = IntRange.between(0x80, 0x7FF); + private static final IntRange THREE_BYTE_RANGE = IntRange.between(0x800, 0xFFFF); + private static final IntRange FOUR_BYTE_RANGE = IntRange.between(0x10000, 0x10FFFF); + + @Override + public TrieSet codeUnitRangeSuffixTrie(IntRangeSet codePointSet) { + throw new AssertionError("unimplemented"); + } + }, + + /** + * Codepoints each map to exactly one 32-bit code unit. 
+ */ + UTF_32(IntRange.between(Character.MIN_CODE_POINT, Character.MAX_CODE_POINT)) { + @Override + public TrieSet codeUnitRangeSuffixTrie(IntRangeSet codePointSet) { + final var output = new TrieSet(); + output.add(List.of(codePointSet)); + return output; + } + }; + + /** + * Full range of possible code units. + */ + public final IntRange codeUnitRange; + + private StandardCodeUnits(IntRange codeUnitRange) { + this.codeUnitRange = codeUnitRange; + } + + /** + * Converts a set of code points into a trie containing reversed code units. + * + *

The output trie must obey a handful of invariants: + * + *

    + *
  • + * The children of any node in the trie should have disjoint subsets of + * the valid code unit range. + * + * As a consequence, any code point {@code c} with code unit + * representation {@code u0, u1, ... un} can be traced to a node in the trie + * by walking the suffix trie from the root using the reversed code units + * {@code un, ... u1, u0}. A code point is in the trie set if this process + * ends at a node which has {@code inSet = true}. + * + *
  • + * A code point is in the input set if and only if it is in the output set + * (using the definition of "in" from the previous invariant). + *
+ * + *

This is the information needed to efficiently create a small DFA for the + * range of code points. + * + * @param codePointRange range of code points to include in the trie + * @return trie ranges of reversed code units + */ + public abstract TrieSet codeUnitRangeSuffixTrie(IntRangeSet codePointSet); + +} diff --git a/src/main/java/automata/graph/Tnfa.java b/src/main/java/automata/graph/Tnfa.java index 8f4a3f7..c625918 100644 --- a/src/main/java/automata/graph/Tnfa.java +++ b/src/main/java/automata/graph/Tnfa.java @@ -101,11 +101,12 @@ private Tnfa( */ public static Tnfa parse( String pattern, + StandardCodeUnits codeUnits, int flags, boolean wrappingGroup, boolean wildcardPrefix ) throws PatternSyntaxException { - final var builder = new Builder(); + final var builder = new Builder(codeUnits); final var visited = RegexParser.parse( builder, pattern, @@ -117,6 +118,10 @@ public static Tnfa parse( } public static class Builder extends RegexNfaBuilder { + public Builder(StandardCodeUnits codeUnits) { + super(codeUnits); + } + // Used to buffer up transitions record BufferedTransition(int from, TnfaTransition transition, int to) { } diff --git a/src/main/java/automata/util/TrieSet.java b/src/main/java/automata/util/TrieSet.java new file mode 100644 index 0000000..0b275fc --- /dev/null +++ b/src/main/java/automata/util/TrieSet.java @@ -0,0 +1,82 @@ +package automata.graph; + +import java.util.HashMap; +import java.util.Map; +import java.util.Stack; +import java.util.Optional; + +class TrieSet { + + public boolean inSet = false; + + public final Map> children = new HashMap>(); + + boolean hasChildren() { + return !children.isEmpty(); + } + + /** + * Register a path in the trie. 
+ * + * @param path path which should be created in the trie starting from the root + */ + public void add(Iterable path) { + final var iterator = path.iterator(); + var nextSet = this; + while (iterator.hasNext()) { + nextSet = nextSet.children.computeIfAbsent(iterator.next(), k -> new TrieSet()); + } + nextSet.inSet = true; + } + + /** + * Check if a path is in the trie. + * + * @param path path to check + * @return whether the path is in the trie + */ + public boolean contains(Iterable path) { + final var iterator = path.iterator(); + var nextSet = this; + while (iterator.hasNext()) { + nextSet = nextSet.children.get(iterator.next()); + if (nextSet == null) { + return false; + } + } + return nextSet.inSet; + } + + /** + * If every entry in the set is at the same depth, return that depth. + */ + public Optional inSetDepth() { + int depth = 0; + + var thisLevel = new Stack>(); + var nextLevel = new Stack>(); + thisLevel.push(this); + boolean encounteredInSet = false; + + while (!encounteredInSet && !thisLevel.isEmpty()) { + + // Visit the next level + while (!thisLevel.isEmpty()) { + var entry = thisLevel.pop(); + encounteredInSet = encounteredInSet || entry.inSet; + nextLevel.addAll(entry.children.values()); + } + + // Swap the stacks + var temp = thisLevel; + thisLevel = nextLevel; + nextLevel = temp; + + depth++; + } + + return (encounteredInSet && thisLevel.isEmpty()) + ? Optional.of(depth - 1) + : Optional.empty(); + } +} diff --git a/src/test/scala/automata/graph/Dot.scala b/src/test/scala/automata/graph/Dot.scala index 5b0a887..6d856b0 100644 --- a/src/test/scala/automata/graph/Dot.scala +++ b/src/test/scala/automata/graph/Dot.scala @@ -8,7 +8,7 @@ class Dot extends AnyFunSpec { val re = "((a)*|b)(ab|b)" describe(s"dot-graph for $re") { - val tnfa = Tnfa.parse(re, 0, true, false) + val tnfa = Tnfa.parse(re, StandardCodeUnits.UTF_16, 0, true, false) val tdfa = Tdfa.fromTnfa(tnfa, MatchMode.FULL, true) it("tnfa") {