Skip to content
Extraits de code Groupes Projets
Valider a0fa316d rédigé par Adrian Kneip's avatar Adrian Kneip
Parcourir les fichiers

Upload New File

parent 2ef3acd5
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
# Copyright 2017 Bert Moons
# This file is part of QNN.
# QNN is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# QNN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# The code for QNN is based on BinaryNet: https://github.com/MatthieuCourbariaux/BinaryNet
# You should have received a copy of the GNU General Public License
# along with QNN. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
from my_datasets import my_cifar10
from my_datasets import my_mnist
def load_dataset(dataset):
if (dataset == "CIFAR-10"):
print('Loading CIFAR-10 dataset...')
#train_set_size = 45000
#train_set = cifar10(which_set="train", start=0, stop=train_set_size)
#valid_set = cifar10(which_set="train", start=train_set_size, stop=50000)
#test_set = cifar10(which_set="test")
(train_set_X,train_set_Y),(valid_set_X,valid_set_Y) = my_cifar10.load_data()
train_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., train_set_X), 1.), (-1, 3, 32, 32)),(0,2,3,1))
valid_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., valid_set_X), 1.), (-1, 3, 32, 32)),(0,2,3,1))
#test_set.X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., test_set.X), 1.), (-1, 3, 32, 32)),(0,2,3,1))
# flatten targets
train_set_Y = np.hstack(train_set_Y)
valid_set_Y = np.hstack(valid_set_Y)
#test_set.y = np.hstack(test_set.y)
# Onehot the targets
train_set_Y = np.float32(np.eye(10)[train_set_Y])
valid_set_Y = np.float32(np.eye(10)[valid_set_Y])
#test_set.y = np.float32(np.eye(10)[test_set.y])
# for hinge loss
train_set_Y = 2 * train_set_Y - 1.
valid_set_Y = 2 * valid_set_Y - 1.
#test_set.y = 2 * test_set.y - 1.
# enlarge train data set by mirrroring
x_train_flip = train_set_X[:, :, ::-1, :]
y_train_flip = train_set_Y
train_set_X = np.concatenate((train_set_X, x_train_flip), axis=0)
train_set_Y = np.concatenate((train_set_Y, y_train_flip), axis=0)
elif (dataset == "MNIST"):
print('Loading MNIST dataset...')
#train_set_size = 50000
#train_set = mnist(which_set="train", start=0, stop=train_set_size)
#valid_set = mnist(which_set="train", start=train_set_size, stop=60000)
#test_set = mnist(which_set="test")
path_to_file = '../my_datasets/mnist.npz'
(train_set_X,train_set_Y),(valid_set_X,valid_set_Y) = my_mnist.load_data(path_to_file)
train_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., train_set_X), 1.), (-1, 1, 28, 28)),(0,2,3,1))
valid_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., valid_set_X), 1.), (-1, 1, 28, 28)),(0,2,3,1))
#test_set.X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., test_set.X), 1.), (-1, 1, 28, 28)),(0,2,3,1))
# flatten targets
train_set_Y = np.hstack(train_set_Y)
valid_set_Y = np.hstack(valid_set_Y)
#test_set.y = np.hstack(test_set.y)
# Onehot the targets
train_set_Y = np.float32(np.eye(10)[train_set_Y])
valid_set_Y = np.float32(np.eye(10)[valid_set_Y])
#test_set.y = np.float32(np.eye(10)[test_set.y])
# for hinge loss
train_set_Y = 2 * train_set_Y - 1.
valid_set_Y = 2 * valid_set_Y - 1.
#test_set.y = 2 * test_set.y - 1.
# enlarge train data set by mirrroring
x_train_flip = train_set_X[:, :, ::-1, :]
y_train_flip = train_set_Y
train_set_X = np.concatenate((train_set_X, x_train_flip), axis=0)
train_set_Y = np.concatenate((train_set_Y, y_train_flip), axis=0)
else:
print("wrong dataset given")
train_set = (train_set_X,train_set_Y)
valid_set = (valid_set_X,valid_set_Y)
return train_set, valid_set
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter