load_data.py
    # Copyright 2017 Bert Moons
    
    # This file is part of QNN.
    
    # QNN is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.
    
    # QNN is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    # GNU General Public License for more details.
    
    # The code for QNN is based on BinaryNet: https://github.com/MatthieuCourbariaux/BinaryNet
    
    # You should have received a copy of the GNU General Public License
    # along with QNN.  If not, see <http://www.gnu.org/licenses/>.
    
    import numpy as np
    from my_datasets import my_cifar10
    from my_datasets import my_mnist
    
    
    def load_dataset(dataset):
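        """Load and preprocess the CIFAR-10 or MNIST training/validation data.

        Returns (train_set, valid_set), each a tuple (X, Y): images scaled to
        [-1, 1] in NHWC layout, targets one-hot encoded and remapped to {-1, +1}
        for hinge loss; the training images are doubled by horizontal mirroring.
        """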
        if dataset == "CIFAR-10":
    
            print('Loading CIFAR-10 dataset...')
    
            #train_set_size = 45000
            #train_set = cifar10(which_set="train", start=0, stop=train_set_size)
            #valid_set = cifar10(which_set="train", start=train_set_size, stop=50000)
            #test_set = cifar10(which_set="test")
            (train_set_X,train_set_Y),(valid_set_X,valid_set_Y) = my_cifar10.load_data()
    
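            # scale pixels from [0, 255] to [-1, 1], reshape to (N, 3, 32, 32) and transpose to NHWC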
            train_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., train_set_X), 1.), (-1, 3, 32, 32)),(0,2,3,1))
            valid_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., valid_set_X), 1.), (-1, 3, 32, 32)),(0,2,3,1))
            #test_set.X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., test_set.X), 1.), (-1, 3, 32, 32)),(0,2,3,1))
       
            # flatten targets
            train_set_Y = np.hstack(train_set_Y)
            valid_set_Y = np.hstack(valid_set_Y)
            #test_set.y = np.hstack(test_set.y)
    
            # One-hot encode the targets
            train_set_Y = np.float32(np.eye(10)[train_set_Y])
            valid_set_Y = np.float32(np.eye(10)[valid_set_Y])
            #test_set.y = np.float32(np.eye(10)[test_set.y])
    
            # remap one-hot targets from {0, 1} to {-1, +1} for hinge loss
            train_set_Y = 2 * train_set_Y - 1.
            valid_set_Y = 2 * valid_set_Y - 1.
            #test_set.y = 2 * test_set.y - 1.
            
            # enlarge the training set by appending horizontally mirrored copies
            x_train_flip = train_set_X[:, :, ::-1, :]
            y_train_flip = train_set_Y
            train_set_X = np.concatenate((train_set_X, x_train_flip), axis=0)
            train_set_Y = np.concatenate((train_set_Y, y_train_flip), axis=0)
    
        elif dataset == "MNIST":
    
            print('Loading MNIST dataset...')
    
            #train_set_size = 50000
            #train_set = mnist(which_set="train", start=0, stop=train_set_size)
            #valid_set = mnist(which_set="train", start=train_set_size, stop=60000)
            #test_set = mnist(which_set="test")
            path_to_file = '../my_datasets/mnist.npz'
            (train_set_X,train_set_Y),(valid_set_X,valid_set_Y) = my_mnist.load_data(path_to_file)
    
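            # scale pixels from [0, 255] to [-1, 1], reshape to (N, 1, 28, 28) and transpose to NHWC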
            train_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., train_set_X), 1.), (-1, 1, 28, 28)),(0,2,3,1))
            valid_set_X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., valid_set_X), 1.), (-1, 1,  28, 28)),(0,2,3,1))
            #test_set.X = np.transpose(np.reshape(np.subtract(np.multiply(2. / 255., test_set.X), 1.), (-1, 1,  28, 28)),(0,2,3,1))
            
            # flatten targets
            train_set_Y = np.hstack(train_set_Y)
            valid_set_Y = np.hstack(valid_set_Y)
            #test_set.y = np.hstack(test_set.y)
    
            # One-hot encode the targets
            train_set_Y = np.float32(np.eye(10)[train_set_Y])
            valid_set_Y = np.float32(np.eye(10)[valid_set_Y])
            #test_set.y = np.float32(np.eye(10)[test_set.y])
    
            # remap one-hot targets from {0, 1} to {-1, +1} for hinge loss
            train_set_Y = 2 * train_set_Y - 1.
            valid_set_Y = 2 * valid_set_Y - 1.
            #test_set.y = 2 * test_set.y - 1.
            
            # enlarge the training set by appending horizontally mirrored copies
            x_train_flip = train_set_X[:, :, ::-1, :]
            y_train_flip = train_set_Y
            train_set_X = np.concatenate((train_set_X, x_train_flip), axis=0)
            train_set_Y = np.concatenate((train_set_Y, y_train_flip), axis=0)

        else:
            raise ValueError("Unknown dataset '%s', expected 'CIFAR-10' or 'MNIST'" % dataset)
    
        train_set = (train_set_X,train_set_Y)
        valid_set = (valid_set_X,valid_set_Y)
        return train_set, valid_set
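

    if __name__ == "__main__":
        # Minimal usage sketch (illustrative only): assumes the my_datasets
        # package is importable and, for MNIST, that ../my_datasets/mnist.npz
        # exists as referenced in load_dataset above.
        (x_train, y_train), (x_valid, y_valid) = load_dataset("MNIST")
        print("train:", x_train.shape, y_train.shape)
        print("valid:", x_valid.shape, y_valid.shape)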