Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
Convert sharded_file_utils function names from CamelCase to snake_case
Browse files Browse the repository at this point in the history
for consistency with the rest of Nucleus.

PiperOrigin-RevId: 228227520
  • Loading branch information
Genomics team in Google Brain authored and Copybara-Service committed Jan 7, 2019
1 parent 4bfef2a commit 39070af
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 36 deletions.
40 changes: 20 additions & 20 deletions nucleus/io/sharded_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ShardError(Exception):
"""An I/O error."""


def ParseShardedFileSpec(spec): # pylint:disable=invalid-name
def parse_sharded_file_spec(spec):
"""Parse a sharded file specification.
Args:
Expand Down Expand Up @@ -71,12 +71,12 @@ def ParseShardedFileSpec(spec): # pylint:disable=invalid-name
return m.group(2), int(m.group(3)), suffix


def _ShardWidth(num_shards): # pylint:disable=invalid-name
def _shard_width(num_shards):
"""Return the width of the shard matcher based on the number of shards."""
return max(5, int(math.floor(math.log10(num_shards)) + 1))


def GenerateShardedFilenames(spec): # pylint:disable=invalid-name
def generate_sharded_filenames(spec):
"""Generate the list of filenames corresponding to the sharding path.
Args:
Expand All @@ -89,16 +89,16 @@ def GenerateShardedFilenames(spec): # pylint:disable=invalid-name
Raises:
ShardError: If spec is not a valid sharded file specification.
"""
basename, num_shards, suffix = ParseShardedFileSpec(spec)
basename, num_shards, suffix = parse_sharded_file_spec(spec)
files = []
width = _ShardWidth(num_shards)
width = _shard_width(num_shards)
format_str = '{{0}}-{{1:0{0}}}-of-{{2:0{0}}}{{3}}'.format(width)
for i in range(num_shards):
files.append(format_str.format(basename, i, num_shards, suffix))

return files

def GlobListShardedFilePatterns(comma_separated_patterns, sep=','): # pylint:disable=invalid-name
def glob_list_sharded_file_patterns(comma_separated_patterns, sep=','):
"""Generate list of filenames corresponding to `comma_separated_patterns`.
Args:
Expand All @@ -112,11 +112,11 @@ def GlobListShardedFilePatterns(comma_separated_patterns, sep=','): # pylint:di
return sorted(set([
f
for pattern in comma_separated_patterns.split(sep)
for f in gfile.Glob(NormalizeToShardedFilePattern(pattern))
for f in gfile.Glob(normalize_to_sharded_file_pattern(pattern))
]))


def GenerateShardedFilePattern(basename, num_shards, suffix): # pylint:disable=invalid-name
def generate_sharded_file_pattern(basename, num_shards, suffix):
"""Generate a sharded file pattern.
Args:
Expand All @@ -126,13 +126,13 @@ def GenerateShardedFilePattern(basename, num_shards, suffix): # pylint:disable=
Returns:
pattern:
"""
width = _ShardWidth(num_shards)
width = _shard_width(num_shards)
specifier = '?' * width
format_str = '{{0}}-{{1}}-of-{{2:0{0}}}{{3}}'.format(width)
return format_str.format(basename, specifier, num_shards, suffix)


def NormalizeToShardedFilePattern(spec_or_pattern): # pylint:disable=invalid-name
def normalize_to_sharded_file_pattern(spec_or_pattern):
"""Take a sharding spec or sharding file pattern and return a sharded pattern.
The input can be a sharding spec(e.g '/some/file@10') or a sharded file
Expand All @@ -145,13 +145,13 @@ def NormalizeToShardedFilePattern(spec_or_pattern): # pylint:disable=invalid-na
A sharded file pattern.
"""
try:
basename, num_shards, suffix = ParseShardedFileSpec(spec_or_pattern)
basename, num_shards, suffix = parse_sharded_file_spec(spec_or_pattern)
except ShardError:
return spec_or_pattern
return GenerateShardedFilePattern(basename, num_shards, suffix)
return generate_sharded_file_pattern(basename, num_shards, suffix)


def IsShardedFileSpec(spec): # pylint:disable=invalid-name
def is_sharded_file_spec(spec):
"""Returns True if spec is a sharded file specification."""
m = SHARD_SPEC_PATTERN.match(spec)
return m is not None
Expand All @@ -160,7 +160,7 @@ def IsShardedFileSpec(spec): # pylint:disable=invalid-name
# TODO(mdepristo): retire when GenerateShardedFilename is added to library.
def sharded_filename(spec, i):
"""Gets a path appropriate for writing the ith file of a sharded spec."""
return GenerateShardedFilenames(spec)[i]
return generate_sharded_filenames(spec)[i]


# TODO(b/64046543): Improve the return value (instead of using tuple). It hurts
Expand Down Expand Up @@ -198,11 +198,11 @@ def resolve_filespecs(shard, *filespecs):
raise ValueError('filespecs must have at least one element.')

master = filespecs[0]
master_is_sharded = IsShardedFileSpec(master)
master_is_sharded = is_sharded_file_spec(master)

master_num_shards = 0
if master_is_sharded:
_, master_num_shards, _ = ParseShardedFileSpec(master)
_, master_num_shards, _ = parse_sharded_file_spec(master)
if shard >= master_num_shards or shard < 0:
raise ValueError('Invalid shard={} value with master={} sharding'.format(
shard, master))
Expand All @@ -214,15 +214,15 @@ def resolve_one(filespec):
if not filespec:
return filespec

is_sharded = IsShardedFileSpec(filespec)
is_sharded = is_sharded_file_spec(filespec)
if master_is_sharded != is_sharded:
raise ValueError('Master={} and {} have inconsistent sharding'.format(
master, filespec))

if not is_sharded: # Not sharded => filespec is the actual filename.
return filespec

_, filespec_num_shards, _ = ParseShardedFileSpec(filespec)
_, filespec_num_shards, _ = parse_sharded_file_spec(filespec)
if filespec_num_shards != master_num_shards:
raise ValueError('Master={} and {} have inconsistent sharding'.format(
master, filespec))
Expand Down Expand Up @@ -251,7 +251,7 @@ def maybe_generate_sharded_filenames(filespec):
"""
if not isinstance(filespec, six.string_types):
raise TypeError('Invalid filespec: %s' % filespec)
if IsShardedFileSpec(filespec):
return GenerateShardedFilenames(filespec)
if is_sharded_file_spec(filespec):
return generate_sharded_filenames(filespec)
else:
return [filespec]
21 changes: 11 additions & 10 deletions nucleus/io/sharded_file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,14 @@ class ShardsTest(parameterized.TestCase):
def testParseShardedFileSpec(self, spec, expected_basename,
expected_num_shards, expected_suffix):

basename, num_shards, suffix = io.ParseShardedFileSpec(spec)
basename, num_shards, suffix = io.parse_sharded_file_spec(spec)
self.assertEqual(basename, expected_basename)
self.assertEqual(num_shards, expected_num_shards)
self.assertEqual(suffix, expected_suffix)

def testParseShardedFileSpecInvalid(self):
self.assertRaises(io.ShardError, io.ParseShardedFileSpec, '/dir/foo/bar@0')
self.assertRaises(io.ShardError,
io.parse_sharded_file_spec, '/dir/foo/bar@0')

@parameterized.named_parameters(
('no_suffix', '/dir/foo/bar@3', [
Expand All @@ -145,11 +146,11 @@ def testParseShardedFileSpecInvalid(self):
]),
)
def testGenerateShardedFilenames(self, spec, expected):
names = io.GenerateShardedFilenames(spec)
names = io.generate_sharded_filenames(spec)
self.assertEqual(names, expected)

def testGenerateShardedFilenamesManyShards(self):
names = io.GenerateShardedFilenames('/dir/foo/bar@100000')
names = io.generate_sharded_filenames('/dir/foo/bar@100000')
self.assertEqual(len(names), 100000)
self.assertEqual(names[99999], '/dir/foo/bar-099999-of-100000')

Expand All @@ -158,7 +159,7 @@ def testGenerateShardedFilenamesManyShards(self):
('zero_shards', '/dir/foo/bar@0'),
)
def testGenerateShardedFilenamesError(self, spec):
self.assertRaises(io.ShardError, io.GenerateShardedFilenames, spec)
self.assertRaises(io.ShardError, io.generate_sharded_filenames, spec)

@parameterized.named_parameters(
('basic', '/dir/foo/bar@3', True),
Expand All @@ -168,7 +169,7 @@ def testGenerateShardedFilenamesError(self, spec):
('not_spec', '/dir/foo/bar', False),
)
def testIsShardedFileSpec(self, spec, expected):
actual = io.IsShardedFileSpec(spec)
actual = io.is_sharded_file_spec(spec)
self.assertEqual(actual, expected,
'io.IshShardedFileSpec({0}) is {1} expected {2}'.format(
spec, actual, expected))
Expand All @@ -182,8 +183,8 @@ def testIsShardedFileSpec(self, spec, expected):
def testGenerateShardedFilePattern(self, basename, num_shards, suffix,
expected):

self.assertEqual(
io.GenerateShardedFilePattern(basename, num_shards, suffix), expected)
self.assertEqual(io.generate_sharded_file_pattern(
basename, num_shards, suffix), expected)

@parameterized.named_parameters(
('no_spec', '/dir/foo/bar', '/dir/foo/bar'),
Expand All @@ -193,7 +194,7 @@ def testGenerateShardedFilePattern(self, basename, num_shards, suffix,
('many', '/dir/foo/bar@12345678', '/dir/foo/bar-????????-of-12345678'),
)
def testNormalizeToShardedFilePattern(self, spec, expected):
self.assertEqual(expected, io.NormalizeToShardedFilePattern(spec))
self.assertEqual(expected, io.normalize_to_sharded_file_pattern(spec))

@parameterized.named_parameters(
('no_spec', 'no_spec', ['no_spec']),
Expand All @@ -217,7 +218,7 @@ def testGlobListShardedFilePatterns(self, specs, expected_files):
full_specs = ','.join(
[test_utils.test_tmpfile(spec) for spec in specs.split(',')])
self.assertEqual(sorted(set(expected_full_files)),
io.GlobListShardedFilePatterns(full_specs))
io.glob_list_sharded_file_patterns(full_specs))

if __name__ == '__main__':
absltest.main()
12 changes: 6 additions & 6 deletions nucleus/io/tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def read_tfrecords(path, proto=None, max_records=None, options=None):
if not options:
options = make_tfrecord_options(path)

if sharded_file_utils.IsShardedFileSpec(path):
paths = sharded_file_utils.GenerateShardedFilenames(path)
if sharded_file_utils.is_sharded_file_spec(path):
paths = sharded_file_utils.generate_sharded_filenames(path)
else:
paths = [path]

Expand Down Expand Up @@ -105,8 +105,8 @@ class and returns a value on which the comparison for sorted ordering is
if options is None:
options = make_tfrecord_options(path)

if sharded_file_utils.IsShardedFileSpec(path):
paths = sharded_file_utils.GenerateShardedFilenames(path)
if sharded_file_utils.is_sharded_file_spec(path):
paths = sharded_file_utils.generate_sharded_filenames(path)
else:
paths = [path]

Expand Down Expand Up @@ -141,9 +141,9 @@ def write_tfrecords(protos, output_path, options=None):
if not options:
options = make_tfrecord_options(output_path)

if sharded_file_utils.IsShardedFileSpec(output_path):
if sharded_file_utils.is_sharded_file_spec(output_path):
with contextlib2.ExitStack() as stack:
_, n_shards, _ = sharded_file_utils.ParseShardedFileSpec(output_path)
_, n_shards, _ = sharded_file_utils.parse_sharded_file_spec(output_path)
writers = [
stack.enter_context(
make_tfrecord_writer(sharded_file_utils.sharded_filename(
Expand Down

0 comments on commit 39070af

Please sign in to comment.