Updated MNIST dataset download link #93

Open · wants to merge 1 commit into base: master
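This pull request replaces the MNIST download URL in get_mnist_pkl_gz: the old deeplearning.net link, which appears to be defunct, is kept as a comment, and the file is now fetched from a figshare mirror of the same mnist.pkl.gz. The remaining changes in the diff look like mechanical formatting cleanups of nengo_extras/data.py (double quotes, black-style line wrapping) with no behavioral effect.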
nengo_extras/data.py · 129 changes: 65 additions & 64 deletions
@@ -10,7 +10,7 @@
 from .compat import is_integer, is_iterable, pickle_load_bytes, urlretrieve


-data_dir = nengo.rc.get('nengo_extras', 'data_dir')
+data_dir = nengo.rc.get("nengo_extras", "data_dir")


 def get_file(filename, url):
@@ -23,32 +23,33 @@ def get_file(filename, url):


 def get_cifar10_tar_gz():
-    filename = os.path.join(data_dir, 'cifar-10-python.tar.gz')
-    url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+    filename = os.path.join(data_dir, "cifar-10-python.tar.gz")
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
     return get_file(filename, url)


 def get_cifar100_tar_gz():
-    filename = os.path.join(data_dir, 'cifar-100-python.tar.gz')
-    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
+    filename = os.path.join(data_dir, "cifar-100-python.tar.gz")
+    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
     return get_file(filename, url)


 def get_ilsvrc2012_tar_gz():
-    filename = os.path.join(data_dir, 'ilsvrc-2012-batches-test3.tar.gz')
-    url = 'http://files.figshare.com/5370887/ilsvrc-2012-batches-test3.tar.gz'
+    filename = os.path.join(data_dir, "ilsvrc-2012-batches-test3.tar.gz")
+    url = "http://files.figshare.com/5370887/ilsvrc-2012-batches-test3.tar.gz"
     return get_file(filename, url)


 def get_mnist_pkl_gz():
-    filename = os.path.join(data_dir, 'mnist.pkl.gz')
-    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
+    filename = os.path.join(data_dir, "mnist.pkl.gz")
+    # url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
+    url = "https://ndownloader.figshare.com/files/25635053"
     return get_file(filename, url)


 def get_svhn_tar_gz():
-    filename = os.path.join(data_dir, 'svhn-py-colmajor.tar.gz')
-    url = 'https://files.figshare.com/7868377/svhn-py-colmajor.tar.gz'
+    filename = os.path.join(data_dir, "svhn-py-colmajor.tar.gz")
+    url = "https://files.figshare.com/7868377/svhn-py-colmajor.tar.gz"
     return get_file(filename, url)


@@ -87,17 +88,18 @@ def load_cifar10(filepath=None, n_train=5, n_test=1, label_names=False):
     # helper for reading each batch file
     def read_tar_batch(tar, name):
         data = unpickle_tarfile(tar, name)
-        return data[b'data'], np.array(data[b'labels'])
+        return data[b"data"], np.array(data[b"labels"])

     filepath = os.path.expanduser(filepath)
-    with tarfile.open(filepath, 'r:gz') as tar:
+    with tarfile.open(filepath, "r:gz") as tar:
         if n_train < 1:
             train = (np.array([]), np.array([]))
         else:
             train = ([], [])
             for i in range(n_train):
                 data, labels = read_tar_batch(
-                    tar, 'cifar-10-batches-py/data_batch_%d' % (i+1))
+                    tar, "cifar-10-batches-py/data_batch_%d" % (i + 1)
+                )
                 train[0].append(data)
                 train[1].append(labels)

@@ -106,11 +108,11 @@ def read_tar_batch(tar, name):
         if n_test < 1:
             test = (np.array([]), np.array([]))
         else:
-            test = read_tar_batch(tar, 'cifar-10-batches-py/test_batch')
+            test = read_tar_batch(tar, "cifar-10-batches-py/test_batch")

         if label_names:
-            meta = unpickle_tarfile(tar, 'cifar-10-batches-py/batches.meta')
-            names = meta[b'label_names']
+            meta = unpickle_tarfile(tar, "cifar-10-batches-py/batches.meta")
+            names = meta[b"label_names"]

     return (train, test) + ((names,) if label_names else ())

@@ -143,17 +145,17 @@ def load_cifar100(filepath=None, fine_labels=True, label_names=False):
     # helper for reading each batch file
     def read_tar_batch(tar, name):
         data = unpickle_tarfile(tar, name)
-        return data[b'data'], np.array(
-            data[b'fine_labels' if fine_labels else b'coarse_labels'])
+        return data[b"data"], np.array(
+            data[b"fine_labels" if fine_labels else b"coarse_labels"]
+        )

     filepath = os.path.expanduser(filepath)
-    with tarfile.open(filepath, 'r:gz') as tar:
-        train = read_tar_batch(tar, 'cifar-100-python/train')
-        test = read_tar_batch(tar, 'cifar-100-python/test')
+    with tarfile.open(filepath, "r:gz") as tar:
+        train = read_tar_batch(tar, "cifar-100-python/train")
+        test = read_tar_batch(tar, "cifar-100-python/test")
         if label_names:
-            meta = unpickle_tarfile(tar, 'cifar-100-python/meta')
-            names = meta[
-                b'fine_label_names' if fine_labels else b'coarse_label_names']
+            meta = unpickle_tarfile(tar, "cifar-100-python/meta")
+            names = meta[b"fine_label_names" if fine_labels else b"coarse_label_names"]

     return (train, test) + ((names,) if label_names else ())

@@ -192,19 +194,18 @@ def load_ilsvrc2012(filepath=None, n_files=None):
     # helper for reading each batch file
     def read_tar_batch(tar, name):
         data = unpickle_tarfile(tar, name)
-        return data[b'data'], data[b'labels']  # JPEG strings, labels
+        return data[b"data"], data[b"labels"]  # JPEG strings, labels

     def bytes_to_array(b):
         image = PIL.Image.open(io.BytesIO(b))
-        array = np.array(image, dtype=np.uint8).reshape(
-            image.size[0], image.size[1], 3)
+        array = np.array(image, dtype=np.uint8).reshape(image.size[0], image.size[1], 3)
         array = np.transpose(array, (2, 0, 1))
         return array

     filepath = os.path.expanduser(filepath)
-    with tarfile.open(filepath, 'r:gz') as tar:
+    with tarfile.open(filepath, "r:gz") as tar:
         names = tar.getnames()
-        regex = re.compile(r'.*/data_batch_([0-9]+\.[0-9]+)')
+        regex = re.compile(r".*/data_batch_([0-9]+\.[0-9]+)")
         matches = [regex.match(name) for name in names]
         matches = [match for match in matches if match]

@@ -229,9 +230,9 @@ def bytes_to_array(b):
         labels = np.array(raw_labels)
         labels.shape = (n_images,)

-        meta = unpickle_tarfile(tar, 'batches.meta')
-        data_mean = meta[b'data_mean'].reshape(image_shape)
-        label_names = meta[b'label_names']
+        meta = unpickle_tarfile(tar, "batches.meta")
+        data_mean = meta[b"data_mean"].reshape(image_shape)
+        label_names = meta[b"label_names"]

     return images, labels, data_mean, label_names

@@ -241,10 +242,10 @@ def load_ilsvrc2012_metadata(filepath=None):
         filepath = get_ilsvrc2012_tar_gz()

     filepath = os.path.expanduser(filepath)
-    with tarfile.open(filepath, 'r:gz') as tar:
-        meta = unpickle_tarfile(tar, 'batches.meta')
-        data_mean = meta[b'data_mean'].reshape((3, 256, 256))
-        label_names = meta[b'label_names']
+    with tarfile.open(filepath, "r:gz") as tar:
+        meta = unpickle_tarfile(tar, "batches.meta")
+        data_mean = meta[b"data_mean"].reshape((3, 256, 256))
+        label_names = meta[b"label_names"]

     return data_mean, label_names

@@ -275,19 +276,20 @@ def load_mnist(filepath=None, validation=False):
         filepath = get_mnist_pkl_gz()

     filepath = os.path.expanduser(filepath)
-    with gzip.open(filepath, 'rb') as f:
+    with gzip.open(filepath, "rb") as f:
         train_set, valid_set, test_set = pickle_load_bytes(f)

     if validation:
         return train_set, valid_set, test_set
     else:  # combine valid into train
-        train_set = (np.vstack((train_set[0], valid_set[0])),
-                     np.hstack((train_set[1], valid_set[1])))
+        train_set = (
+            np.vstack((train_set[0], valid_set[0])),
+            np.hstack((train_set[1], valid_set[1])),
+        )
         return train_set, test_set


-def load_svhn(filepath=None, n_train=9, n_test=3, data_mean=False,
-              label_names=False):
+def load_svhn(filepath=None, n_train=9, n_test=3, data_mean=False, label_names=False):
     """Load the SVHN dataset.

     Parameters
@@ -318,31 +320,29 @@ def load_svhn(filepath=None, n_train=9, n_test=3, data_mean=False,

     def read_tar_batch(tar, name):
         data = unpickle_tarfile(tar, name)
-        return data[b'data'], np.array(data[b'labels'])
+        return data[b"data"], np.array(data[b"labels"])

     def load_batches(tar, inds):
         if len(inds) < 1:
             return (np.array([]), np.array([]))

         batches = ([], [])
         for i in inds:
-            data, labels = read_tar_batch(
-                tar, 'svhn-py-colmajor/data_batch_%d' % i)
+            data, labels = read_tar_batch(tar, "svhn-py-colmajor/data_batch_%d" % i)
             batches[0].append(data.T)
             batches[1].append(labels)

-        return (np.vstack(batches[0]).reshape((-1,) + shape),
-                np.hstack(batches[1]))
+        return (np.vstack(batches[0]).reshape((-1,) + shape), np.hstack(batches[1]))

     filepath = os.path.expanduser(filepath)
-    with tarfile.open(filepath, 'r:gz') as tar:
-        train = load_batches(tar, list(range(1, n_train+1)))
-        test = load_batches(tar, list(range(10, n_test+10)))
+    with tarfile.open(filepath, "r:gz") as tar:
+        train = load_batches(tar, list(range(1, n_train + 1)))
+        test = load_batches(tar, list(range(10, n_test + 10)))

         if label_names or data_mean:
-            meta = unpickle_tarfile(tar, 'svhn-py-colmajor/batches.meta')
-            data_mean = (meta[b'data_mean'].reshape(shape),) if data_mean else ()
-            label_names = (meta[b'label_names'],) if label_names else ()
+            meta = unpickle_tarfile(tar, "svhn-py-colmajor/batches.meta")
+            data_mean = (meta[b"data_mean"].reshape(shape),) if data_mean else ()
+            label_names = (meta[b"label_names"],) if label_names else ()

     return (train, test) + data_mean + label_names

@@ -362,11 +362,11 @@ def spasafe_name(name, pre_comma_only=True):
     if len(name) == 0:
         raise ValueError("Empty name.")

-    if pre_comma_only and ',' in name:
-        name = name.split(',')[0]  # part before first comma
+    if pre_comma_only and "," in name:
+        name = name.split(",")[0]  # part before first comma
     name = name.strip()
-    name = re.sub(r'(\s|-|,)+', '_', name)  # repl space/hyphen/comma w undersc
-    name = re.sub('(^[^a-zA-Z]+)|[^a-zA-Z0-9_]+', '', name)  # del other chars
+    name = re.sub(r"(\s|-|,)+", "_", name)  # repl space/hyphen/comma w undersc
+    name = re.sub("(^[^a-zA-Z]+)|[^a-zA-Z0-9_]+", "", name)  # del other chars
     name = name[0].upper() + name[1:]  # capitalize first letter
     return name
@@ -383,8 +383,9 @@ def spasafe_names(label_names, pre_comma_only=True):
     pre_comma_only : boolean
         Only use the part of a name before a/the first comma.
     """
-    vocab_names = [spasafe_name(name, pre_comma_only=pre_comma_only)
-                   for name in label_names]
+    vocab_names = [
+        spasafe_name(name, pre_comma_only=pre_comma_only) for name in label_names
+    ]

     # number duplicates
     unique = set()
@@ -398,7 +399,7 @@ def spasafe_names(label_names, pre_comma_only=True):
     duplicates = {name: 0 for name in duplicates}
     for i, name in enumerate(vocab_names):
         if name in duplicates:
-            vocab_names[i] = '%s%d' % (name, duplicates[name])
+            vocab_names[i] = "%s%d" % (name, duplicates[name])
             duplicates[name] += 1

     return vocab_names
@@ -423,7 +424,8 @@ def one_hot_from_labels(labels, classes=None, dtype=float):
     n = labels.shape[0]

     if np.issubdtype(labels.dtype, np.integer) and (
-            classes is None or is_integer(classes)):
+        classes is None or is_integer(classes)
+    ):
         index = labels
         index_min, index_max = index.min(), index.max()
         n_classes = (index_max + 1) if classes is None else classes
@@ -465,11 +467,10 @@ def __init__(self, beta=1e-2, gamma=1e-5):
         self.V = None
         self.Sinv = None

-    def contrast_normalize(self, X, remove_mean=True, beta=None,
-                           hard_beta=True):
+    def contrast_normalize(self, X, remove_mean=True, beta=None, hard_beta=True):
         X = np.asarray(X, dtype=np.float64)
         if X.ndim != 2:
-            raise ValueError('contrast_normalize requires flat patches')
+            raise ValueError("contrast_normalize requires flat patches")

         Xc = X - X.mean(axis=1)[:, None] if remove_mean else X
         l2 = (Xc * Xc).sum(axis=1)
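As a quick end-to-end check of the new link, something along these lines should work (a minimal sketch, not part of the PR; it assumes network access to figshare and a writable nengo_extras data_dir):

from nengo_extras.data import load_mnist

# The first call downloads mnist.pkl.gz from the figshare mirror into
# data_dir; later calls reuse the cached file.
(X_train, y_train), (X_test, y_test) = load_mnist()

# With validation=False (the default), the 10k validation images are
# stacked onto the 50k training images.
print(X_train.shape, y_train.shape)  # (60000, 784) (60000,)
print(X_test.shape, y_test.shape)    # (10000, 784) (10000,)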