Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright 2017 Bert Moons
# This file is part of QNN.
# QNN is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# QNN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# The code for QNN is based on BinaryNet: https://github.com/MatthieuCourbariaux/BinaryNet
# You should have received a copy of the GNU General Public License
# along with QNN. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
from my_datasets import my_cifar10
from my_datasets import my_mnist
def _scale_to_nhwc(X, channels, size):
    """Scale pixels from [0, 255] to [-1, 1] and reorder to (N, H, W, C).

    The data is first reshaped to (N, channels, size, size) and then
    transposed so the channel axis comes last.
    """
    # NOTE(review): the reshape to (-1, channels, size, size) assumes the
    # loader yields channel-first (or flat, channel-major) pixels. If it
    # already returns channels-last (N, H, W, C) arrays, this reshape would
    # scramble multi-channel images — confirm against the loader.
    X = np.subtract(np.multiply(2. / 255., X), 1.)
    X = np.reshape(X, (-1, channels, size, size))
    return np.transpose(X, (0, 2, 3, 1))


def _hinge_targets(Y, num_classes=10):
    """Turn integer class labels into one-hot float32 targets in {-1, +1}."""
    # Flatten targets, e.g. (N, 1) -> (N,); same result as np.hstack(Y)
    # but without a Python-level loop over rows.
    Y = np.reshape(Y, -1)
    # One-hot encode directly in float32 (no float64 intermediate).
    Y = np.eye(num_classes, dtype=np.float32)[Y]
    # Map {0, 1} to {-1, +1} as required by the hinge loss.
    return 2 * Y - 1.


def _mirror_augment(X, Y):
    """Append a horizontally mirrored copy of every (N, H, W, C) image.

    The labels are duplicated so they stay aligned with the images.
    """
    # Axis 2 is the width axis after the (0, 2, 3, 1) transpose.
    X = np.concatenate((X, X[:, :, ::-1, :]), axis=0)
    Y = np.concatenate((Y, Y), axis=0)
    return X, Y


def load_dataset(dataset, mnist_path='../my_datasets/mnist.npz', mirror=True):
    """Load and preprocess CIFAR-10 or MNIST for hinge-loss training.

    Args:
        dataset: Either "CIFAR-10" or "MNIST" (exact, case-sensitive).
        mnist_path: Path of the MNIST ``.npz`` file handed to
            ``my_mnist.load_data``. It is resolved relative to the current
            working directory.
        mirror: If true (the default, matching the original behaviour),
            the training set is doubled by appending horizontally mirrored
            images.

    Returns:
        ``(train_set, valid_set)``, each a tuple ``(X, Y)``:

        * ``X`` has shape (N, H, W, C), with pixels scaled from [0, 255]
          to [-1, 1].
        * ``Y`` has shape (N, 10), is float32 and contains -1 / +1 entries.

        The validation set is never mirrored.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset == "CIFAR-10":
        print('Loading CIFAR-10 dataset...')
        # NOTE(review): the second pair returned by the loader is used as
        # the validation set; confirm whether it is the official test split.
        (train_X, train_Y), (valid_X, valid_Y) = my_cifar10.load_data()
        channels, size = 3, 32
    elif dataset == "MNIST":
        print('Loading MNIST dataset...')
        (train_X, train_Y), (valid_X, valid_Y) = my_mnist.load_data(mnist_path)
        channels, size = 1, 28
        # NOTE(review): mirroring handwritten digits yields glyphs that are
        # not valid digits; kept on by default for backward compatibility.
    else:
        # Previously this only printed a message and then crashed with an
        # UnboundLocalError on the undefined arrays below.
        raise ValueError("wrong dataset given: {!r}".format(dataset))

    train_X = _scale_to_nhwc(train_X, channels, size)
    valid_X = _scale_to_nhwc(valid_X, channels, size)
    train_Y = _hinge_targets(train_Y)
    valid_Y = _hinge_targets(valid_Y)

    if mirror:
        train_X, train_Y = _mirror_augment(train_X, train_Y)

    return (train_X, train_Y), (valid_X, valid_Y)