Skip to content

Commit

Permalink
Merge pull request tensorflow#8311 from stagedml:bleu-numoflines
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 301924393
  • Loading branch information
tensorflower-gardener committed Mar 19, 2020
2 parents d697041 + cdfc1dd commit 27207a2
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions official/nlp/transformer/compute_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def __init__(self):
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

def property_chars(self, prefix):
return "".join(six.unichr(x) for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix))
return "".join(
six.unichr(x)
for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix))


uregex = UnicodeRegex()
Expand Down Expand Up @@ -92,9 +94,10 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
tf.io.gfile.GFile(hyp_filename).read()).strip().splitlines()

if len(ref_lines) != len(hyp_lines):
raise ValueError("Reference and translation files have different number of "
"lines. If training only a few steps (100-200), the "
"translation may be empty.")
raise ValueError(
"Reference and translation files have different number of "
"lines (%d VS %d). If training only a few steps (100-200), the "
"translation may be empty." % (len(ref_lines), len(hyp_lines)))
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
Expand All @@ -116,18 +119,23 @@ def main(unused_argv):
def define_compute_bleu_flags():
"""Add flags for computing BLEU score."""
flags.DEFINE_string(
name="translation", default=None,
name="translation",
default=None,
help=flags_core.help_wrap("File containing translated text."))
flags.mark_flag_as_required("translation")

flags.DEFINE_string(
name="reference", default=None,
name="reference",
default=None,
help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")

flags.DEFINE_enum(
name="bleu_variant", short_name="bv", default="both",
enum_values=["both", "uncased", "cased"], case_sensitive=False,
name="bleu_variant",
short_name="bv",
default="both",
enum_values=["both", "uncased", "cased"],
case_sensitive=False,
help=flags_core.help_wrap(
"Specify one or more BLEU variants to calculate. Variants: \"cased\""
", \"uncased\", or \"both\"."))
Expand Down

0 comments on commit 27207a2

Please sign in to comment.