#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct

import torch

def read_MNIST_label_file(fname):
    """Read a gzip-compressed IDX1 (MNIST label) file.

    The file starts with a big-endian header: a 4-byte magic number
    (0x00000801) followed by a 4-byte unsigned item count. After the header
    comes one unsigned byte per label.

    Args:
        fname: path to a ``*-labels-idx1-ubyte.gz`` file.

    Returns:
        A writable 1-D ``np.uint8`` array holding one entry per label.

    Raises:
        ValueError: if the header is truncated, the magic number is not the
            IDX1 unsigned-byte magic, or the file holds fewer label bytes
            than the header declares.
    """
    # The files are shipped compressed, so use gzip.open rather than open.
    # The context manager closes the file even if parsing fails.
    with gzip.open(fname, 'rb') as fp:
        header = fp.read(8)
        if len(header) < 8:
            raise ValueError(f"{fname}: truncated IDX1 header")
        # The header integers are big-endian (non-native on x86), hence '>'.
        magic, nitems = struct.unpack('>II', header)
        if magic != 0x00000801:
            raise ValueError(
                f"{fname}: bad IDX1 label magic number {magic:#010x}")
        bts = fp.read(nitems)

    if len(bts) != nitems:
        raise ValueError(
            f"{fname}: expected {nitems} labels, found {len(bts)}")
    # frombuffer returns a read-only view of `bts`. Copy it so callers get an
    # ordinary writable array.
    return np.frombuffer(bts, dtype=np.uint8).copy()

def read_MNIST_image_file(fname):
    """Read a gzip-compressed IDX3 (MNIST image) file.

    The file starts with a big-endian header of four unsigned 32-bit
    integers: the magic number (0x00000803), the image count, the row count
    and the column count. After the header, the pixels are stored as one
    unsigned byte each, image after image, in row-major order.

    Args:
        fname: path to a ``*-images-idx3-ubyte.gz`` file.

    Returns:
        A writable ``np.uint8`` array of shape ``(nitems, nrows, ncols)``.

    Raises:
        ValueError: if the header is truncated, the magic number is not the
            IDX3 unsigned-byte magic, or the pixel data is shorter than the
            header declares.
    """
    # The context manager closes the file even if parsing fails.
    with gzip.open(fname, 'rb') as fp:
        header = fp.read(16)
        if len(header) < 16:
            raise ValueError(f"{fname}: truncated IDX3 header")
        # The header integers are big-endian, hence '>'.
        magic, nitems, nrows, ncols = struct.unpack('>IIII', header)
        if magic != 0x00000803:
            raise ValueError(
                f"{fname}: bad IDX3 image magic number {magic:#010x}")
        npixels = nitems * nrows * ncols
        # One bulk read replaces the original per-image read/copy loop.
        bts = fp.read(npixels)

    if len(bts) != npixels:
        raise ValueError(
            f"{fname}: expected {npixels} pixel bytes, found {len(bts)}")
    # Copy so the result is writable rather than a read-only buffer view.
    return (np.frombuffer(bts, dtype=np.uint8)
            .reshape((nitems, nrows, ncols))
            .copy())

#The mnist dataset is small enough to fit entirely in memory
def mnist_load(baseloc="../training_data"):
    """Load the MNIST train/test labels and images as float32 torch tensors.

    Args:
        baseloc: directory containing the four standard MNIST ``.gz`` files.
            The default ``"../training_data"`` is resolved relative to the
            current working directory, as before.

    Returns:
        A list ``[labels_train, labels_test, images_train, images_test]``.
        Each element is a ``torch.float32`` tensor with
        ``requires_grad=False``.
    """
    # Same read order as before: label files first, then image files.
    labels_train = read_MNIST_label_file(
        os.path.join(baseloc, "train-labels-idx1-ubyte.gz"))
    labels_test = read_MNIST_label_file(
        os.path.join(baseloc, "t10k-labels-idx1-ubyte.gz"))
    images_train = read_MNIST_image_file(
        os.path.join(baseloc, "train-images-idx3-ubyte.gz"))
    images_test = read_MNIST_image_file(
        os.path.join(baseloc, "t10k-images-idx3-ubyte.gz"))

    # torch.tensor copies the numpy data. The labels are deliberately kept as
    # float32 as well, because existing callers receive them in that dtype.
    return [
        torch.tensor(arr, dtype=torch.float32, requires_grad=False)
        for arr in (labels_train, labels_test, images_train, images_test)
    ]

if __name__ == "__main__":
    # Smoke test: load everything, then report the shape of each tensor.
    loaded = mnist_load()
    print("Loaded MNIST Data")
    for tensor in loaded:
        print(tensor.shape)