diff --git a/nucleus/io/sharded_file_utils.py b/nucleus/io/sharded_file_utils.py
index e05a5b9..f6f0be5 100644
--- a/nucleus/io/sharded_file_utils.py
+++ b/nucleus/io/sharded_file_utils.py
@@ -43,7 +43,7 @@ class ShardError(Exception):
   """An I/O error."""


-def ParseShardedFileSpec(spec):  # pylint:disable=invalid-name
+def parse_sharded_file_spec(spec):
   """Parse a sharded file specification.

   Args:
@@ -71,12 +71,12 @@ def ParseShardedFileSpec(spec):  # pylint:disable=invalid-name
   return m.group(2), int(m.group(3)), suffix


-def _ShardWidth(num_shards):  # pylint:disable=invalid-name
+def _shard_width(num_shards):
   """Return the width of the shard matcher based on the number of shards."""
   return max(5, int(math.floor(math.log10(num_shards)) + 1))


-def GenerateShardedFilenames(spec):  # pylint:disable=invalid-name
+def generate_sharded_filenames(spec):
   """Generate the list of filenames corresponding to the sharding path.

   Args:
@@ -89,16 +89,16 @@ def GenerateShardedFilenames(spec):  # pylint:disable=invalid-name
   Raises:
     ShardError: If spec is not a valid sharded file specification.
   """
-  basename, num_shards, suffix = ParseShardedFileSpec(spec)
+  basename, num_shards, suffix = parse_sharded_file_spec(spec)
   files = []
-  width = _ShardWidth(num_shards)
+  width = _shard_width(num_shards)
   format_str = '{{0}}-{{1:0{0}}}-of-{{2:0{0}}}{{3}}'.format(width)
   for i in range(num_shards):
     files.append(format_str.format(basename, i, num_shards, suffix))
   return files


-def GlobListShardedFilePatterns(comma_separated_patterns, sep=','):  # pylint:disable=invalid-name
+def glob_list_sharded_file_patterns(comma_separated_patterns, sep=','):
   """Generate list of filenames corresponding to `comma_separated_patterns`.

   Args:
@@ -112,11 +112,11 @@ def GlobListShardedFilePatterns(comma_separated_patterns, sep=','):  # pylint:disable=invalid-name
   return sorted(set([
       f
       for pattern in comma_separated_patterns.split(sep)
-      for f in gfile.Glob(NormalizeToShardedFilePattern(pattern))
+      for f in gfile.Glob(normalize_to_sharded_file_pattern(pattern))
   ]))


-def GenerateShardedFilePattern(basename, num_shards, suffix):  # pylint:disable=invalid-name
+def generate_sharded_file_pattern(basename, num_shards, suffix):
   """Generate a sharded file pattern.

   Args:
@@ -126,13 +126,13 @@ def GenerateShardedFilePattern(basename, num_shards, suffix):  # pylint:disable=invalid-name
   Returns:
     pattern:
   """
-  width = _ShardWidth(num_shards)
+  width = _shard_width(num_shards)
   specifier = '?' * width
   format_str = '{{0}}-{{1}}-of-{{2:0{0}}}{{3}}'.format(width)
   return format_str.format(basename, specifier, num_shards, suffix)


-def NormalizeToShardedFilePattern(spec_or_pattern):  # pylint:disable=invalid-name
+def normalize_to_sharded_file_pattern(spec_or_pattern):
   """Take a sharding spec or sharding file pattern and return a sharded pattern.

   The input can be a sharding spec(e.g '/some/file@10') or a sharded file
@@ -145,13 +145,13 @@ def NormalizeToShardedFilePattern(spec_or_pattern):  # pylint:disable=invalid-name
     A sharded file pattern.
""" try: - basename, num_shards, suffix = ParseShardedFileSpec(spec_or_pattern) + basename, num_shards, suffix = parse_sharded_file_spec(spec_or_pattern) except ShardError: return spec_or_pattern - return GenerateShardedFilePattern(basename, num_shards, suffix) + return generate_sharded_file_pattern(basename, num_shards, suffix) -def IsShardedFileSpec(spec): # pylint:disable=invalid-name +def is_sharded_file_spec(spec): """Returns True if spec is a sharded file specification.""" m = SHARD_SPEC_PATTERN.match(spec) return m is not None @@ -160,7 +160,7 @@ def IsShardedFileSpec(spec): # pylint:disable=invalid-name # TODO(mdepristo): retire when GenerateShardedFilename is added to library. def sharded_filename(spec, i): """Gets a path appropriate for writing the ith file of a sharded spec.""" - return GenerateShardedFilenames(spec)[i] + return generate_sharded_filenames(spec)[i] # TODO(b/64046543): Improve the return value (instead of using tuple). It hurts @@ -198,11 +198,11 @@ def resolve_filespecs(shard, *filespecs): raise ValueError('filespecs must have at least one element.') master = filespecs[0] - master_is_sharded = IsShardedFileSpec(master) + master_is_sharded = is_sharded_file_spec(master) master_num_shards = 0 if master_is_sharded: - _, master_num_shards, _ = ParseShardedFileSpec(master) + _, master_num_shards, _ = parse_sharded_file_spec(master) if shard >= master_num_shards or shard < 0: raise ValueError('Invalid shard={} value with master={} sharding'.format( shard, master)) @@ -214,7 +214,7 @@ def resolve_one(filespec): if not filespec: return filespec - is_sharded = IsShardedFileSpec(filespec) + is_sharded = is_sharded_file_spec(filespec) if master_is_sharded != is_sharded: raise ValueError('Master={} and {} have inconsistent sharding'.format( master, filespec)) @@ -222,7 +222,7 @@ def resolve_one(filespec): if not is_sharded: # Not sharded => filespec is the actual filename. 
       return filespec
-    _, filespec_num_shards, _ = ParseShardedFileSpec(filespec)
+    _, filespec_num_shards, _ = parse_sharded_file_spec(filespec)
     if filespec_num_shards != master_num_shards:
       raise ValueError('Master={} and {} have inconsistent sharding'.format(
           master, filespec))
@@ -251,7 +251,7 @@ def maybe_generate_sharded_filenames(filespec):
   """
   if not isinstance(filespec, six.string_types):
     raise TypeError('Invalid filespec: %s' % filespec)
-  if IsShardedFileSpec(filespec):
-    return GenerateShardedFilenames(filespec)
+  if is_sharded_file_spec(filespec):
+    return generate_sharded_filenames(filespec)
   else:
     return [filespec]
diff --git a/nucleus/io/sharded_file_utils_test.py b/nucleus/io/sharded_file_utils_test.py
index 7d2c491..e8e2928 100644
--- a/nucleus/io/sharded_file_utils_test.py
+++ b/nucleus/io/sharded_file_utils_test.py
@@ -126,13 +126,14 @@ class ShardsTest(parameterized.TestCase):
   def testParseShardedFileSpec(self, spec, expected_basename,
                                expected_num_shards, expected_suffix):
-    basename, num_shards, suffix = io.ParseShardedFileSpec(spec)
+    basename, num_shards, suffix = io.parse_sharded_file_spec(spec)
     self.assertEqual(basename, expected_basename)
     self.assertEqual(num_shards, expected_num_shards)
     self.assertEqual(suffix, expected_suffix)

   def testParseShardedFileSpecInvalid(self):
-    self.assertRaises(io.ShardError, io.ParseShardedFileSpec, '/dir/foo/bar@0')
+    self.assertRaises(io.ShardError,
+                      io.parse_sharded_file_spec, '/dir/foo/bar@0')

   @parameterized.named_parameters(
       ('no_suffix', '/dir/foo/bar@3', [
@@ -145,11 +146,11 @@ def testParseShardedFileSpecInvalid(self):
       ]),
   )
   def testGenerateShardedFilenames(self, spec, expected):
-    names = io.GenerateShardedFilenames(spec)
+    names = io.generate_sharded_filenames(spec)
     self.assertEqual(names, expected)

   def testGenerateShardedFilenamesManyShards(self):
-    names = io.GenerateShardedFilenames('/dir/foo/bar@100000')
+    names = io.generate_sharded_filenames('/dir/foo/bar@100000')
     self.assertEqual(len(names), 100000)
     self.assertEqual(names[99999], '/dir/foo/bar-099999-of-100000')

@@ -158,7 +159,7 @@ def testGenerateShardedFilenamesManyShards(self):
     ('zero_shards', '/dir/foo/bar@0'),
   )
   def testGenerateShardedFilenamesError(self, spec):
-    self.assertRaises(io.ShardError, io.GenerateShardedFilenames, spec)
+    self.assertRaises(io.ShardError, io.generate_sharded_filenames, spec)

   @parameterized.named_parameters(
       ('basic', '/dir/foo/bar@3', True),
@@ -168,7 +169,7 @@ def testGenerateShardedFilenamesError(self, spec):
       ('not_spec', '/dir/foo/bar', False),
   )
   def testIsShardedFileSpec(self, spec, expected):
-    actual = io.IsShardedFileSpec(spec)
+    actual = io.is_sharded_file_spec(spec)
     self.assertEqual(actual, expected,
                      'io.IshShardedFileSpec({0}) is {1} expected {2}'.format(
                          spec, actual, expected))
@@ -182,8 +183,8 @@ def testIsShardedFileSpec(self, spec, expected):
   def testGenerateShardedFilePattern(self, basename, num_shards, suffix,
                                      expected):
-    self.assertEqual(
-        io.GenerateShardedFilePattern(basename, num_shards, suffix), expected)
+    self.assertEqual(io.generate_sharded_file_pattern(
+        basename, num_shards, suffix), expected)

   @parameterized.named_parameters(
       ('no_spec', '/dir/foo/bar', '/dir/foo/bar'),
@@ -193,7 +194,7 @@ def testGenerateShardedFilePattern(self, basename, num_shards, suffix,
     ('many', '/dir/foo/bar@12345678', '/dir/foo/bar-????????-of-12345678'),
   )
   def testNormalizeToShardedFilePattern(self, spec, expected):
-    self.assertEqual(expected, io.NormalizeToShardedFilePattern(spec))
+    self.assertEqual(expected, io.normalize_to_sharded_file_pattern(spec))

   @parameterized.named_parameters(
       ('no_spec', 'no_spec', ['no_spec']),
@@ -217,7 +218,7 @@ def testGlobListShardedFilePatterns(self, specs, expected_files):
     full_specs = ','.join(
         [test_utils.test_tmpfile(spec) for spec in specs.split(',')])
     self.assertEqual(sorted(set(expected_full_files)),
-                     io.GlobListShardedFilePatterns(full_specs))
+                     io.glob_list_sharded_file_patterns(full_specs))

 if __name__ == '__main__':
   absltest.main()
diff --git a/nucleus/io/tfrecord.py b/nucleus/io/tfrecord.py
index 501b8d8..a1d67bc 100644
--- a/nucleus/io/tfrecord.py
+++ b/nucleus/io/tfrecord.py
@@ -59,8 +59,8 @@ def read_tfrecords(path, proto=None, max_records=None, options=None):
   if not options:
     options = make_tfrecord_options(path)
-  if sharded_file_utils.IsShardedFileSpec(path):
-    paths = sharded_file_utils.GenerateShardedFilenames(path)
+  if sharded_file_utils.is_sharded_file_spec(path):
+    paths = sharded_file_utils.generate_sharded_filenames(path)
   else:
     paths = [path]
@@ -105,8 +105,8 @@ class and returns a value on which the comparison for sorted ordering is
   if options is None:
     options = make_tfrecord_options(path)
-  if sharded_file_utils.IsShardedFileSpec(path):
-    paths = sharded_file_utils.GenerateShardedFilenames(path)
+  if sharded_file_utils.is_sharded_file_spec(path):
+    paths = sharded_file_utils.generate_sharded_filenames(path)
   else:
     paths = [path]
@@ -141,9 +141,9 @@ def write_tfrecords(protos, output_path, options=None):
   if not options:
     options = make_tfrecord_options(output_path)
-  if sharded_file_utils.IsShardedFileSpec(output_path):
+  if sharded_file_utils.is_sharded_file_spec(output_path):
     with contextlib2.ExitStack() as stack:
-      _, n_shards, _ = sharded_file_utils.ParseShardedFileSpec(output_path)
+      _, n_shards, _ = sharded_file_utils.parse_sharded_file_spec(output_path)
       writers = [
           stack.enter_context(
               make_tfrecord_writer(sharded_file_utils.sharded_filename(
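
Reviewer note: a minimal usage sketch of the renamed snake_case helpers, assuming the module is importable as nucleus.io.sharded_file_utils (as it is used from tfrecord.py above); the '/tmp/examples.tfrecord@3' spec and the variable names below are illustrative only, not taken from the change.

# Illustrative only: exercising the helpers renamed in this change.
from nucleus.io import sharded_file_utils

spec = '/tmp/examples.tfrecord@3'  # hypothetical sharded file specification

# The @3 suffix marks this as a sharded file specification.
assert sharded_file_utils.is_sharded_file_spec(spec)

# -> basename '/tmp/examples.tfrecord', num_shards 3 (suffix empty here).
basename, num_shards, suffix = sharded_file_utils.parse_sharded_file_spec(spec)

# -> ['/tmp/examples.tfrecord-00000-of-00003', ..., '...-00002-of-00003'];
# shard indices are zero-padded to at least five digits by _shard_width().
names = sharded_file_utils.generate_sharded_filenames(spec)

# -> '/tmp/examples.tfrecord-?????-of-00003', a glob-style pattern.
pattern = sharded_file_utils.normalize_to_sharded_file_pattern(spec)

# -> '/tmp/examples.tfrecord-00001-of-00003', the filename of shard i=1.
second_shard = sharded_file_utils.sharded_filename(spec, 1)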