Commit

Merge pull request tensorflow#8256 from stagedml:tokenizer-update
PiperOrigin-RevId: 301949311
tensorflower-gardener committed Mar 20, 2020
2 parents 27207a2 + 30579e0 commit 2416dd9
Showing 2 changed files with 118 additions and 54 deletions.
122 changes: 82 additions & 40 deletions official/nlp/transformer/utils/tokenizer.py
@@ -28,6 +28,8 @@
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf


# pylint: disable=g-complex-comprehension
PAD = "<pad>"
PAD_ID = 0
EOS = "<EOS>"
@@ -45,27 +47,36 @@

_UNDEFINED_UNICODE = u"\u3013"


def alphanumeric_char_set():
return set(
six.unichr(i)
for i in xrange(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))


# Set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in xrange(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()
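
For orientation, membership in this set is decided purely by Unicode general category; a quick sketch of what it contains (assuming the module above is in scope):

# Every code point whose category starts with "L" (letter) or "N" (number)
# is in the set; punctuation and whitespace are not.
print(u"a" in _ALPHANUMERIC_CHAR_SET)  # True  (category Ll)
print(u"7" in _ALPHANUMERIC_CHAR_SET)  # True  (category Nd)
print(u"-" in _ALPHANUMERIC_CHAR_SET)  # False (category Pd)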

# min_count is the minimum number of times a subtoken must appear in the data
# before it is added to the vocabulary. The value is found using binary
# search to obtain the target vocabulary size.
_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
_MAX_MIN_COUNT = 1000 # max value to use when binary searching for min_count
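
The comment above describes a bisection between these two bounds; here is a minimal standalone sketch of that search, where build_vocab is a hypothetical stand-in for the real subtoken generation (the vocabulary shrinks as min_count grows):

def find_min_count(build_vocab, target_size, threshold,
                   lo=_MIN_MIN_COUNT, hi=_MAX_MIN_COUNT):
  # Bisect min_count until the generated vocabulary is within `threshold`
  # of the target size.
  while lo < hi:
    mid = (lo + hi) // 2
    size = len(build_vocab(mid))
    if abs(size - target_size) <= threshold:
      return mid
    if size > target_size:
      lo = mid + 1  # too many subtokens: prune rarer ones
    else:
      hi = mid - 1  # too few subtokens: keep rarer ones
  return lo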


class Subtokenizer(object):
"""Encodes and decodes strings to/from integer IDs."""

def __init__(self, vocab_file, reserved_tokens=None):
def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
"""Initializes class, creating a vocab file if data_files is provided."""
tf.compat.v1.logging.info("Initializing Subtokenizer from file %s." %
vocab_file)

if master_char_set is None:
master_char_set = _ALPHANUMERIC_CHAR_SET

if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS

@@ -78,13 +89,20 @@ def __init__(self, vocab_file, reserved_tokens=None):
self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))

# Create cache to speed up subtokenization
self._cache_size = 2 ** 20
self._cache_size = 2**20
self._cache = [(None, None)] * self._cache_size
self._master_char_set = master_char_set

@staticmethod
def init_from_files(
vocab_file, files, target_vocab_size, threshold, min_count=None,
file_byte_limit=1e6, reserved_tokens=None, correct_strip=True):
def init_from_files(vocab_file,
files,
target_vocab_size,
threshold,
min_count=None,
file_byte_limit=1e6,
reserved_tokens=None,
correct_strip=True,
master_char_set=None):
"""Create subtoken vocabulary based on files, and save vocab to file.
Args:
@@ -101,34 +119,41 @@ def init_from_files(
reserved_tokens: List of string tokens that are guaranteed to be at the
beginning of the subtoken vocabulary list.
correct_strip: Whether to convert text to unicode before stripping.
master_char_set: the set of characters used to split text into tokens. Defaults to the set of alphanumeric characters.
Returns:
Subtokenizer object
"""
if master_char_set is None:
master_char_set = _ALPHANUMERIC_CHAR_SET
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS

if tf.io.gfile.exists(vocab_file):
tf.compat.v1.logging.info("Vocab file already exists (%s)" % vocab_file)
else:
tf.compat.v1.logging.info("Begin steps to create subtoken vocabulary...")
token_counts = _count_tokens(files, file_byte_limit, correct_strip)
token_counts = _count_tokens(files, file_byte_limit, correct_strip,
master_char_set)
alphabet = _generate_alphabet_dict(token_counts)
subtoken_list = _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_vocab_size, threshold, min_count,
reserved_tokens)
tf.compat.v1.logging.info("Generated vocabulary with %d subtokens." %
len(subtoken_list))
_save_vocab_file(vocab_file, subtoken_list)
return Subtokenizer(vocab_file)
return Subtokenizer(vocab_file, master_char_set=master_char_set)

def encode(self, raw_string, add_eos=False):
"""Encodes a string into a list of int subtoken ids."""
ret = []
tokens = _split_string_to_tokens(native_to_unicode(raw_string))
tokens = _split_string_to_tokens(
native_to_unicode(raw_string), self._master_char_set)
for token in tokens:
ret.extend(self._token_to_subtoken_ids(token))
if add_eos:
assert EOS in self.subtoken_list, \
"Can't append 'EOS' because it is not in list of known subtokens."
ret.append(EOS_ID)
return ret
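
A hedged usage sketch of the encode/decode round trip with the new hook (the vocab path and sample text are hypothetical):

subtokenizer = Subtokenizer("vocab.ende.32768",
                            master_char_set=alphanumeric_char_set())
ids = subtokenizer.encode("Hello world")  # list of int subtoken ids
text = subtokenizer.decode(ids)           # round-trips to "Hello world"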

@@ -161,13 +186,14 @@ def decode(self, subtokens):
"Subtokens argument passed into decode() must be a list of integers.")

return _unicode_to_native(
_join_tokens_to_string(self._subtoken_ids_to_tokens(subtokens)))
_join_tokens_to_string(
self._subtoken_ids_to_tokens(subtokens), self._master_char_set))

def _subtoken_ids_to_tokens(self, subtokens):
"""Convert list of int subtoken ids to a list of string tokens."""
escaped_tokens = "".join([
self.subtoken_list[s] for s in subtokens
if s < len(self.subtoken_list)])
self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
])
escaped_tokens = escaped_tokens.split("_")

# All tokens in the vocabulary list have been escaped (see _escape_token())
@@ -204,30 +230,30 @@ def _load_vocab_file(vocab_file, reserved_tokens=None):

def native_to_unicode(s):
"""Convert string to unicode (required in Python 2)."""
try: # Python 2
try: # Python 2
return s if isinstance(s, unicode) else s.decode("utf-8")
except NameError: # Python 3
return s


def _unicode_to_native(s):
"""Convert string from unicode to native format (required in Python 2)."""
try: # Python 2
try: # Python 2
return s.encode("utf-8") if isinstance(s, unicode) else s
except NameError: # Python 3
return s


def _split_string_to_tokens(text):
def _split_string_to_tokens(text, master_char_set):
"""Splits text to a list of string tokens."""
if not text:
return []
ret = []
token_start = 0
# Classify each character in the input string
is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
is_master = [c in master_char_set for c in text]
for pos in xrange(1, len(text)):
if is_alnum[pos] != is_alnum[pos - 1]:
if is_master[pos] != is_master[pos - 1]:
token = text[token_start:pos]
if token != u" " or token_start == 0:
ret.append(token)
@@ -237,12 +263,12 @@ def _split_string_to_tokens(text):
return ret


def _join_tokens_to_string(tokens):
def _join_tokens_to_string(tokens, master_char_set):
"""Join a list of string tokens into a single string."""
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
token_is_master = [t[0] in master_char_set for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
if i > 0 and token_is_master[i - 1] and token_is_master[i]:
ret.append(u" ")
ret.append(token)
return "".join(ret)
@@ -324,7 +350,10 @@ def match(m):
return _UNESCAPE_REGEX.sub(match, token)


def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
def _count_tokens(files,
file_byte_limit=1e6,
correct_strip=True,
master_char_set=None):
"""Return token counts of words in the files.
Samples file_byte_limit bytes from each file, and counts the words that appear
@@ -337,11 +366,15 @@ def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
vocabulary generation for PY2. Setting correct_strip to False in PY2
reproduces the previous common public result; setting it to True lets
PY2 and PY3 produce a consistent vocabulary.
master_char_set: the set of characters used to split text into tokens. Defaults to the set of alphanumeric characters.
Returns:
Dictionary mapping tokens to the number of times they appear in the sampled
lines from the files.
"""
if master_char_set is None:
master_char_set = _ALPHANUMERIC_CHAR_SET

token_counts = collections.defaultdict(int)

for filepath in files:
@@ -362,7 +395,8 @@ def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
counter = 0

# Add words to token counts
for token in _split_string_to_tokens(native_to_unicode(line)):
for token in _split_string_to_tokens(
native_to_unicode(line), master_char_set):
token_counts[token] += 1
return token_counts
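
A hedged illustration of the resulting counts on a throwaway file (the temp path is arbitrary; single spaces between words never reach the counts because the splitter drops them):

import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
  f.write("the cat sat\nthe cat\n")
counts = _count_tokens([f.name], master_char_set=alphanumeric_char_set())
# counts == {u"the": 2, u"cat": 2, u"sat": 1}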

@@ -394,9 +428,12 @@ def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
return ret
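
For reference, the splitter above is a greedy longest-match scan; an illustration with a hypothetical vocabulary:

# With subtoken_dict = {u"he": 0, u"ll": 1, u"o_": 2, u"hell": 3} and
# max_subtoken_length = 4:
#   _split_token_to_subtokens(u"hello_", subtoken_dict, 4)
#   -> [u"hell", u"o_"]
# At each position the longest known prefix wins, so u"he" and u"ll" are
# never considered once u"hell" matches.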


def _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_size, threshold, min_count=None,
reserved_tokens=None):
def _generate_subtokens_with_target_vocab_size(token_counts,
alphabet,
target_size,
threshold,
min_count=None,
reserved_tokens=None):
"""Generate subtoken vocabulary close to the target size."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
@@ -449,8 +486,8 @@ def _generate_alphabet_dict(iterable, reserved_tokens=None):
return alphabet


def _count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length):
def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
max_subtoken_length):
"""Count number of times subtokens appear, and generate new subtokens.
Args:
@@ -468,8 +505,8 @@ def _count_and_gen_subtokens(
subtoken_counts = collections.defaultdict(int)
for token, count in six.iteritems(token_counts):
token = _escape_token(token, alphabet)
subtokens = _split_token_to_subtokens(
token, subtoken_dict, max_subtoken_length)
subtokens = _split_token_to_subtokens(token, subtoken_dict,
max_subtoken_length)

# Generate new subtokens by taking substrings from token.
start = 0
@@ -503,8 +540,10 @@ def _filter_and_bucket_subtokens(subtoken_counts, min_count):
return subtoken_buckets
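
A hedged sketch of the bucketing (the counts are illustrative): subtokens below min_count are dropped, and survivors are grouped into sets indexed by length:

# _filter_and_bucket_subtokens({u"a": 5, u"ab": 3, u"b": 1}, min_count=2)
# -> [set(), {u"a"}, {u"ab"}]  -- index i holds subtokens of length i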


def _gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens=None):
def _gen_new_subtoken_list(subtoken_counts,
min_count,
alphabet,
reserved_tokens=None):
"""Generate candidate subtokens ordered by count, and new max subtoken length.
Add subtokens to the candidate list in order of length (longest subtokens
@@ -575,9 +614,11 @@ def _gen_new_subtoken_list(
return subtoken_list, max_subtoken_length


def _generate_subtokens(
token_counts, alphabet, min_count, num_iterations=4,
reserved_tokens=None):
def _generate_subtokens(token_counts,
alphabet,
min_count,
num_iterations=4,
reserved_tokens=None):
"""Create a list of subtokens in decreasing order of frequency.
Args:
Expand Down Expand Up @@ -609,8 +650,9 @@ def _generate_subtokens(

# Create dict mapping subtoken->count, with additional subtokens created
# from substrings taken from the tokens.
subtoken_counts = _count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length)
subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
subtoken_dict,
max_subtoken_length)

# Generate new list of subtokens sorted by subtoken count.
subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
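
Finally, an end-to-end sketch of building and using a vocabulary with the new parameter (all paths and sizes are hypothetical):

subtokenizer = Subtokenizer.init_from_files(
    "vocab.ende.32768",          # written here if it does not already exist
    ["train.en", "train.de"],    # raw text files to sample tokens from
    target_vocab_size=32768,
    threshold=327,
    master_char_set=alphanumeric_char_set())
ids = subtokenizer.encode("Hello world")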
