import os

# If you train on Google Colab, uncomment the following three lines
# from google.colab import drive
# drive.mount('/content/drive/')
# os.chdir('/content/drive/My Drive/SpeedEstimation')

import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
from keras import regularizers, activations
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.callbacks import ModelCheckpoint
import keras.backend as K
import time
import project_tools as pt
from feature_extraction import time_samples, feat_num

###############################################################################
test_veh_ind = 0  ### Test vehicle index

runs = 20                           # number of independent trainings of the same architecture
layer_units_NN = [200, 50, 10, 1]   # units per Dense layer; the last entry is the output layer
batch_size = 256
epochs = 200
loss = 'MSE'                        # 'MSE' or 'MAE' (the two options get_nn_model handles)
regul = 1e-3                        # L2 kernel regularization coefficient


def compact_sci_str(value):
    """Format *value* in one-digit scientific notation without exponent zero padding.

    Only the leading zeros of the exponent are dropped (``1e-03`` -> ``1e-3``).
    Removing every ``'0'`` character instead would also eat significant digits,
    e.g. turn ``1e-10`` into ``1e-1`` and make two different coefficients share
    one model name.

    Args:
        value: Number to format.

    Returns:
        The compact string, e.g. ``"1e-3"``, ``"1e-10"`` or ``"1e+3"``.
    """
    mantissa, sep, exponent = f"{value:.0e}".partition('e')
    if not sep:
        # 'inf' and 'nan' are formatted without an exponent part.
        return mantissa
    sign = exponent[0]
    digits = exponent[1:].lstrip('0') or '0'
    return f"{mantissa}e{sign}{digits}"


regul_str = compact_sci_str(regul)
# Model tag: feature count, layer sizes joined by '-', regularization and loss.
NN_model = f"NN_{feat_num}-{'-'.join(map(str, layer_units_NN))}_reg{regul_str}_loss{loss}"

vehicles = [
    "CitroenC4Picasso",
    "Mazda3",
    "MercedesAMG550",
    "NissanQashqai",
    "OpelInsignia",
    "Peugeot3008",
    "Peugeot307",
    "RenaultCaptur",
    "RenaultScenic",
    "VWPassat",
]
# The vehicle selected by test_veh_ind is the one held out for testing.
test_veh_name = vehicles[test_veh_ind]
model_vehicle = f"{NN_model}__{test_veh_name}"

###############################################################################
def loss_MAE(y_true, y_pred):
    """Return the mean absolute difference between predictions and targets.

    Built from Keras backend ops so it can be passed as ``loss`` to ``model.compile``.
    """
    abs_error = K.abs(y_pred - y_true)
    return K.mean(abs_error)


def get_nn_model(input_dim, layer_units, loss, reg_coeff):
    """Build and compile a fully connected regression network.

    Every entry of ``layer_units`` except the last becomes a Dense layer with
    L2 kernel regularization followed by ReLU; the last entry becomes a linear
    Dense output layer without regularization. The model is compiled with the
    Adam optimizer.

    Args:
        input_dim: Number of input features.
        layer_units: Units per layer, output layer last.
        loss: ``'MAE'`` (uses ``loss_MAE``) or ``'MSE'``.
        reg_coeff: L2 regularization coefficient for the regularized layers.

    Returns:
        The compiled ``Sequential`` model.

    Raises:
        ValueError: If ``loss`` is neither ``'MAE'`` nor ``'MSE'``.
    """
    # Fail fast: an unknown loss used to yield an uncompiled model whose error
    # only surfaced later, inside fit().
    if loss not in ('MAE', 'MSE'):
        raise ValueError(f"Unsupported loss {loss!r}; expected 'MAE' or 'MSE'")
    loss_fn = loss_MAE if loss == 'MAE' else 'mse'

    model = Sequential()
    model.add(Dense(layer_units[0], input_dim=input_dim, kernel_regularizer=regularizers.l2(reg_coeff)))
    model.add(Activation(activations.relu))
    for units in layer_units[1:-1]:
        model.add(Dense(units, kernel_regularizer=regularizers.l2(reg_coeff)))
        model.add(Activation(activations.relu))
    model.add(Dense(layer_units[-1], activation='linear'))

    model.compile(optimizer='Adam', loss=loss_fn)
    return model


def scale_data(x_train, x_valid, x_test):
    """Standardize all three splits with a scaler fitted on the training split only.

    Returns:
        The transformed train, validation and test arrays, followed by the
        fitted ``StandardScaler`` so further data can be scaled the same way.
    """
    scaler = StandardScaler().fit(x_train)
    scaled_splits = [scaler.transform(split) for split in (x_train, x_valid, x_test)]
    return (*scaled_splits, scaler)

###############################################################################
# One file per vehicle (same order as `vehicles`), followed by the no-car recordings.
datasets = [f"datasets/data_{vehicle}.h5" for vehicle in vehicles]
datasets.append("datasets/data_NoCar.h5")

#####  Extra no-vehicle files, used only for testing  #####
# The context manager closes the file even if a dataset is missing or a read fails.
with h5py.File("datasets/data_NoCarTest.h5", 'r') as hf:
    x_test_nv = np.array(hf['features'], dtype=np.float64)
    y_test_nv = np.array(hf['labels'], dtype=np.float64)
    speed_test_nv = np.array(hf['speed'], dtype=np.float64)
# Merge the first two axes so every row is one feature vector, and flatten the labels to match.
x_test_nv = x_test_nv.reshape(x_test_nv.shape[0] * x_test_nv.shape[1], x_test_nv.shape[2])
y_test_nv = y_test_nv.reshape(-1, )
###########################################################

# Create the folder that stores the NN model coefficients during training.
# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs('models', exist_ok=True)
    
def _load_train_valid(path):
    """Read one dataset file and split its samples by the stored 'train'/'valid' tags.

    Args:
        path: HDF5 file providing the 'speed', 'features', 'labels' and
            'train_valid' datasets.

    Returns:
        Two ``(features, labels, speed)`` tuples: the first holds the samples
        tagged ``'train'``, the second those tagged ``'valid'``.
    """
    # The context manager closes the file even if a dataset is missing or a read fails.
    with h5py.File(path, 'r') as hf:
        speed = np.array(hf['speed'], dtype=np.float64)
        features = np.array(hf['features'], dtype=np.float64)
        labels = np.array(hf['labels'], dtype=np.float64)
        # Tags are decoded from bytes so they can be compared with str literals.
        tags = np.array([str(item, 'utf-8') for item in np.array(hf['train_valid'])])

    train_ind = np.argwhere(tags == 'train').flatten()
    valid_ind = np.argwhere(tags == 'valid').flatten()
    # Index-array selection returns fresh arrays, so no explicit copy is needed.
    train_part = (features[train_ind, :, :], labels[train_ind, :], speed[train_ind])
    valid_part = (features[valid_ind, :, :], labels[valid_ind, :], speed[valid_ind])
    return train_part, valid_part


for run in range(runs):
    # Hold out the test vehicle's file; every other file (including the no-car one) is used for training.
    dataset_train = datasets.copy()
    del dataset_train[test_veh_ind]
    dataset_test = datasets[test_veh_ind]

    # NOTE(review): the data loaded here is the same in every run; it is still reloaded per run
    # because hoisting is only safe if pt.shuffle_data does not modify its inputs -- confirm.
    train_parts, valid_parts = zip(*(_load_train_valid(path) for path in dataset_train))
    # Concatenate once per array instead of re-copying the growing arrays for every file.
    x_train, y_train, speed_train = (np.concatenate(arrays, axis=0) for arrays in zip(*train_parts))
    x_valid, y_valid, speed_valid = (np.concatenate(arrays, axis=0) for arrays in zip(*valid_parts))

    with h5py.File(dataset_test, 'r') as hf:
        x_test = np.array(hf['features'], dtype=np.float64)
        y_test = np.array(hf['labels'], dtype=np.float64)
        speed_test = np.array(hf['speed'], dtype=np.float64)

    # Merge the first two axes so every row is one feature vector with one scalar label.
    x_train = x_train.reshape(x_train.shape[0] * x_train.shape[1], x_train.shape[2])
    y_train = y_train.reshape(-1, )
    x_valid = x_valid.reshape(x_valid.shape[0] * x_valid.shape[1], x_valid.shape[2])
    y_valid = y_valid.reshape(-1, )
    x_test = x_test.reshape(x_test.shape[0] * x_test.shape[1], x_test.shape[2])
    y_test = y_test.reshape(-1, )

    # Scaling statistics come from the training rows only and are reused for the other splits.
    x_train, x_valid, x_test, scaler1 = scale_data(x_train, x_valid, x_test)
    x_test_nv_1 = scaler1.transform(x_test_nv.copy())

    if run == 0:
        # One row of predictions per run; sizes are only known once the data is loaded.
        y_train_pred = np.zeros((runs, y_train.size))
        y_valid_pred = np.zeros((runs, y_valid.size))
        y_test_pred = np.zeros((runs, y_test.size))
        y_test_pred_nv = np.zeros((runs, y_test_nv.size))

    # Training NN
    K.clear_session()  # for multiple runs, clear the session
    model_NN = get_nn_model(x_train.shape[-1], layer_units_NN, loss, regul)
    # Keep only the weights of the epoch with the lowest validation loss.
    callbacks = [ModelCheckpoint(filepath='models/NNcoefs.hdf5', monitor='val_loss', verbose=0, mode='min', save_best_only=True)]
    x_train_shuff, y_train_shuff = pt.shuffle_data(x_train, y_train)
    history = model_NN.fit(x_train_shuff,
                           y_train_shuff,
                           validation_data=(x_valid, y_valid),
                           shuffle=True,
                           epochs=epochs,
                           verbose=0,
                           batch_size=batch_size,
                           callbacks=callbacks)
    val_loss_NN = history.history['val_loss']
    # Validation loss averaged over the second half of the epochs.
    mean_val_loss = np.mean(val_loss_NN[-int(epochs/2):])

    # Restore the best checkpoint, retrying while loading raises KeyError.
    # NOTE(review): presumably works around the checkpoint file not being fully
    # written/synced yet (e.g. on a mounted drive) -- confirm; the retry is unbounded.
    while True:
        try:
            time.sleep(2)
            model_NN.load_weights('models/NNcoefs.hdf5')
            break
        except KeyError:
            print("Unable to open object. Trying again...")

    # Predictions are regrouped into rows of `time_samples` values, then stored flat per run.
    y_train_pred_NN = model_NN.predict(x_train, verbose=0).reshape(-1, time_samples)
    y_train_pred[run, :] = y_train_pred_NN.flatten()
    y_valid_pred_NN = model_NN.predict(x_valid, verbose=0).reshape(-1, time_samples)
    y_valid_pred[run, :] = y_valid_pred_NN.flatten()
    y_test_pred_NN = model_NN.predict(x_test, verbose=0).reshape(-1, time_samples)
    y_test_pred[run, :] = y_test_pred_NN.flatten()
    y_test_pred_nv_NN = model_NN.predict(x_test_nv_1, verbose=0).reshape(-1, time_samples)
    y_test_pred_nv[run, :] = y_test_pred_nv_NN.flatten()

    # Mean squared error on the held-out vehicle, reported regardless of the training loss.
    mean_test_loss = np.mean((y_test_pred_NN - y_test.reshape(-1, time_samples)) ** 2)
    print(f"Run: {run}, mean_val_loss: {mean_val_loss:.5f}, mean_test_loss: {mean_test_loss:.5f} ")

# Create the output folder first: opening the file for writing fails if its parent
# directory is missing, which would discard the results of all training runs.
results_dir = 'results/MA_regressions'
os.makedirs(results_dir, exist_ok=True)

# Dataset-name suffix -> array; each name is prefixed with the test vehicle's name.
result_arrays = {
    '_MA_train': y_train_pred,
    '_MA_valid': y_valid_pred,
    '_MA_test': y_test_pred,
    '_MA_test_nv': y_test_pred_nv,
    '_speeds_train': speed_train,
    '_speeds_valid': speed_valid,
    '_speeds_test': speed_test,
}
# The context manager closes (and flushes) the file even if a write fails.
with h5py.File(f'{results_dir}/regression_{model_vehicle}.h5', 'w') as hf2:
    for suffix, data in result_arrays.items():
        hf2.create_dataset(test_veh_name + suffix, data=data, compression="gzip")