-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar_10.py
40 lines (31 loc) · 1.11 KB
/
cifar_10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- coding: utf-8 -*-
"""CIFAR-10.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1gz2WNaAci8BVhAa4ON5xSsjZwzi7zIp_
"""
import pickle

import numpy as np
import tensorflow as tf
def load_dataset(path, batchid):
    """Load one CIFAR-10 training batch file from disk.

    Args:
        path: Directory containing the ``data_batch_<batchid>`` files.
        batchid: Number appended to ``data_batch_`` (this file iterates 1..5).

    Returns:
        A ``(features, labels)`` tuple. ``features`` is the batch's ``data``
        array reshaped from flat rows into ``(N, 3, 32, 32)`` and transposed
        to channel-last ``(N, 32, 32, 3)``. ``labels`` is the batch's
        ``labels`` entry, returned unchanged.
    """
    batch_path = f'{path}/data_batch_{batchid}'
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only call this on dataset files obtained from a trusted source.
    with open(batch_path, mode='rb') as file:
        # encoding= is needed to load pickles written by Python 2
        # (presumably how these batches were produced).
        batch = pickle.load(file, encoding='latin1')
    data = batch['data']
    # Each row holds the three colour planes back to back. Reshape to
    # (N, C, H, W), then move channels last to get (N, H, W, C).
    features = data.reshape((len(data), 3, 32, 32)).transpose(0, 2, 3, 1)
    # The CIFAR-10 python batches store the class indices under 'labels'.
    labels = batch['labels']
    return features, labels
def load_names():
    """Return a fresh list of the ten CIFAR-10 class names.

    The names are in a fixed order. That order presumably matches label
    indices 0-9; confirm against the dataset's ``batches.meta``.
    """
    class_names = (
        'airplane automobile bird cat deer dog frog horse ship truck'
    ).split()
    return class_names
def normalise(x):
    """Min-max scale ``x`` into the range [0, 1].

    Args:
        x: Array-like of numbers. It must be non-empty, because min/max of an
            empty array raises ValueError.

    Returns:
        A new array of the same shape. The global minimum of ``x`` maps to 0
        and the global maximum maps to 1. If every element is equal, a
        float64 array of zeros is returned instead of dividing by zero.
    """
    x = np.asarray(x)
    # Cast integer input to float before subtracting. Otherwise signed types
    # such as int8 can overflow in ``x - min_val``.
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)
    min_val = x.min()
    max_val = x.max()
    value_range = max_val - min_val
    if value_range == 0:
        return np.zeros(x.shape, dtype=np.float64)
    return (x - min_val) / value_range
def one_hot_encoding(x, num_classes=10):
    """Convert integer class labels into one-hot rows.

    Args:
        x: 1-D sequence of integer labels, each in ``[0, num_classes)``.
        num_classes: Width of each one-hot row. Defaults to 10, the number
            of CIFAR-10 classes.

    Returns:
        A float array of shape ``(len(x), num_classes)``. Row ``i`` is 1 at
        column ``x[i]`` and 0 everywhere else.

    Raises:
        ValueError: If any label is negative or ``>= num_classes``.
    """
    labels = np.asarray(x, dtype=int)
    # Validate explicitly. A negative index would otherwise wrap around
    # silently instead of failing.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(f'labels must be in [0, {num_classes})')
    encoded = np.zeros((len(labels), num_classes))
    # Integer-array indexing sets one cell per row in a single step.
    encoded[np.arange(len(labels)), labels] = 1
    return encoded
def preprocess_save(normalise, one_hot_encoding, features, labels, file):
    """Transform a batch of features and labels, then pickle it to ``file``.

    Args:
        normalise: Callable applied to ``features``. Callers pass it in, so
            this parameter shadows the module-level function of the same name.
        one_hot_encoding: Callable applied to ``labels``; likewise shadows the
            module-level function.
        features: Raw feature data passed to ``normalise``.
        labels: Raw labels passed to ``one_hot_encoding``.
        file: Path of the output file. It is opened in ``'wb'`` mode, so an
            existing file is overwritten.
    """
    features = normalise(features)
    labels = one_hot_encoding(labels)
    # A with-block guarantees the handle is flushed and closed even if
    # pickling fails.
    with open(file, 'wb') as handle:
        pickle.dump((features, labels), handle)
def preprocess_save_data(path,normalise,one_hot_encoding):
n_batches=5
valid_features=[]
valid_labels=[]
for batch_i in range(1,n_batches+1):