# Copyright 2017 Bert Moons

# This file is part of QNN.

# QNN is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# QNN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# The code for QNN is based on BinaryNet: https://github.com/MatthieuCourbariaux/BinaryNet

# You should have received a copy of the GNU General Public License
# along with QNN. If not, see <http://www.gnu.org/licenses/>.

import numpy as np

from my_datasets import my_cifar10
from my_datasets import my_mnist


def _preprocess_images(x, nchw_shape):
    """Scale uint8 pixels into [-1, 1], reshape to NCHW, then transpose to NHWC.

    The original pixel range [0, 255] is mapped linearly via 2/255 * x - 1.
    """
    scaled = np.subtract(np.multiply(2. / 255., x), 1.)
    return np.transpose(np.reshape(scaled, nchw_shape), (0, 2, 3, 1))


def _preprocess_labels(y):
    """Flatten labels, one-hot encode over 10 classes, and map {0,1} -> {-1,+1}.

    The +/-1 encoding is required by the hinge loss used downstream.
    """
    flat = np.hstack(y)
    onehot = np.float32(np.eye(10)[flat])
    return 2 * onehot - 1.


def _mirror_augment(x, y):
    """Double a training set by appending a horizontally flipped copy.

    After the NHWC transpose, axis 2 is image width, so x[:, :, ::-1, :]
    is a left-right mirror; labels are simply duplicated.
    """
    x_flip = x[:, :, ::-1, :]
    return (np.concatenate((x, x_flip), axis=0),
            np.concatenate((y, y), axis=0))


def load_dataset(dataset):
    """Load and preprocess a dataset for QNN training.

    Parameters
    ----------
    dataset : str
        Either "CIFAR-10" or "MNIST".

    Returns
    -------
    (train_set, valid_set) : tuple of (X, Y) pairs
        X is float NHWC image data in [-1, 1]; Y is a float32 one-hot
        matrix mapped to {-1, +1} for hinge loss. The training set is
        doubled by horizontal mirroring.

    Raises
    ------
    ValueError
        If ``dataset`` is not a recognized name. (The original code only
        printed "wrong dataset given" and then crashed with a NameError
        on undefined variables; failing fast is clearer.)
    """
    if dataset == "CIFAR-10":
        print('Loading CIFAR-10 dataset...')
        (train_set_X, train_set_Y), (valid_set_X, valid_set_Y) = my_cifar10.load_data()
        nchw_shape = (-1, 3, 32, 32)
    elif dataset == "MNIST":
        print('Loading MNIST dataset...')
        path_to_file = '../my_datasets/mnist.npz'
        (train_set_X, train_set_Y), (valid_set_X, valid_set_Y) = my_mnist.load_data(path_to_file)
        nchw_shape = (-1, 1, 28, 28)
    else:
        raise ValueError("wrong dataset given: %r (expected 'CIFAR-10' or 'MNIST')" % (dataset,))

    train_set_X = _preprocess_images(train_set_X, nchw_shape)
    valid_set_X = _preprocess_images(valid_set_X, nchw_shape)
    train_set_Y = _preprocess_labels(train_set_Y)
    valid_set_Y = _preprocess_labels(valid_set_Y)

    # Enlarge the train set by mirroring.
    # NOTE(review): the original applied mirroring to MNIST as well, even
    # though digits are not mirror-symmetric; preserved here — confirm intent.
    train_set_X, train_set_Y = _mirror_augment(train_set_X, train_set_Y)

    train_set = (train_set_X, train_set_Y)
    valid_set = (valid_set_X, valid_set_Y)
    return train_set, valid_set