import numpy as np
import h5py
import matplotlib.pyplot as plot
from feature_extraction import time_samples, signal_len


# --- Input locations and model selection ------------------------------------
folder_MA_regr = 'results/MA_regressions/'
# NOTE(review): not referenced anywhere in the visible part of this script.
folder_speeds = 'results/speed_estimations/'
model = 'NN_1000-200-50-10-1_reg1e-3_lossMSE'

# Regression results of the selected model, opened read-only. The plotting loop
# reads the '<vehicle>_MA_test' datasets from this file.
hf = h5py.File(f"{folder_MA_regr}regression_{model}.h5", 'r')

# Vehicle identifiers as they appear in dataset keys and dataset file names.
vehicles = [
    "CitroenC4Picasso",
    "Mazda3",
    "MercedesAMG550",
    "NissanQashqai",
    "OpelInsignia",
    "Peugeot3008",
    "Peugeot307",
    "RenaultCaptur",
    "RenaultScenic",
    "VWPassat",
]

# Two-line display names used as column titles, in the same order as `vehicles`.
vehicles_full = [
    "Citroen C4\nPicasso",
    "Mazda 3\nSkyactive",
    "Mercedes\nAMG 550",
    "Nissan\nQashqai",
    "Opel\nInsignia",
    "Peugeot\n3008",
    "Peugeot\n307",
    "Renault\nCaptur",
    "Renault\nScenic",
    "VW Passat\nB7",
]

# Speeds to plot (labelled in km/h on the subplots): row i belongs to
# vehicles[i], each column is one subplot row of the figure.
speeds = np.array([
    [35, 54, 74, 101],  # CitroenC4Picasso
    [33, 52, 81, 103],  # Mazda3
    [38, 58, 78, 100],  # MercedesAMG550
    [40, 61, 85, 102],  # NissanQashqai
    [38, 61, 80, 100],  # OpelInsignia
    [40, 60, 83, 100],  # Peugeot3008
    [33, 56, 79, 101],  # Peugeot307
    [30, 52, 76, 98],   # RenaultCaptur
    [35, 52, 86, 101],  # RenaultScenic
    [39, 65, 81, 100],  # VWPassat
])

# Grid dimensions of the figure: one column per vehicle, one row per speed.
n_veh = len(vehicles)
n_speeds = speeds.shape[1]

# Figure size is specified in centimetres and converted to inches for matplotlib.
fig_width_in = 17.8 / 2.54
fig_height_in = 7.2 / 2.54
plot.figure(num=113, figsize=(fig_width_in, fig_height_in), dpi=300)

# Small fonts for the dense subplot grid; these apply to axes created afterwards.
plot.rcParams.update({
    'axes.titlesize': 8,   # fontsize of the axes title
    'axes.labelsize': 5,   # fontsize of the x and y labels
    'xtick.labelsize': 4,  # fontsize of the x tick labels
    'ytick.labelsize': 4,  # fontsize of the y tick labels
})

# Time axis spanning one full signal.
time = np.linspace(0, signal_len, num=time_samples)

# Index taken along the first axis of each '<vehicle>_MA_test' dataset.
# NOTE(review): presumably selects one training/evaluation run - confirm.
run = 4

# Half-width of the displayed window, in seconds and in samples: time_win
# seconds are shown before and after the ground-truth maximum.
time_win = 3
window_fraction = time_win / signal_len
samps = int(window_fraction * time_samples)
# Fill the subplot grid: one column per vehicle, one row per nominal speed.
# Each subplot shows the ground-truth curve (dashed black) and the regression
# output (red), cut to a window of +/- samps samples around the ground-truth
# maximum and plotted against time relative to that maximum.
for i in range(n_veh):
    veh = vehicles[i]

    # Take index `run` along the first axis of the stored test output and
    # arrange it as rows of time_samples values each.
    MA_test = np.array(hf[veh + '_MA_test'], dtype=np.float64)[run, :]
    MA_test = MA_test.reshape(-1, time_samples)

    # The context manager closes the dataset file even if a read raises.
    with h5py.File(f"datasets/data_{veh}.h5", 'r') as hf2:
        y_test_GT = np.array(hf2['labels'], dtype=np.float64)
        speed_test_GT = np.array(hf2['speed'], dtype=np.float64)

    for j in range(n_speeds):
        # First recording whose stored speed equals the requested one.
        matches = np.argwhere(speed_test_GT == speeds[i, j]).flatten()
        if matches.size == 0:
            raise ValueError(
                f"speed {speeds[i, j]} not found in 'speed' dataset of "
                f"datasets/data_{veh}.h5"
            )
        speed_ind = matches[0]

        ax = plot.subplot(n_speeds, n_veh, j * n_veh + i + 1)

        max_ind = np.argmax(y_test_GT[speed_ind, :])
        # Clamp the window start at 0: a negative start would be interpreted
        # as an index from the end of the array and yield an empty/wrong slice
        # when the maximum lies closer than `samps` to the signal start.
        window = slice(max(max_ind - samps, 0), max_ind + samps + 1)
        time_cut = time[window] - time[max_ind]
        plot.plot(time_cut, y_test_GT[speed_ind, window],
                  color='k', linestyle='--', linewidth=0.25)
        plot.plot(time_cut, MA_test[speed_ind, window],
                  color='r', linewidth=0.3)

        # Largest value of either full-length curve; used for y ticks and for
        # placing the speed label.
        maxv = max(np.max(y_test_GT[speed_ind, :]), np.max(MA_test[speed_ind, :]))

        if j == 0:
            # Vehicle name only above the top row of each column.
            plot.title(f"{vehicles_full[i]}", fontdict={'fontsize': 6}, pad=3)
        plot.xticks([], [])
        yticks = np.arange(3, maxv, 3)
        plot.yticks(yticks)

        # Speed label near the upper-left corner; the x position scales with
        # the window half-width (-2.7 s for time_win = 3).
        ax.text(-0.9 * time_win, 0.92 * maxv, f"{speeds[i,j]}km/h", fontsize=4)
        if j == n_speeds - 1:
            # Time ticks and axis label only on the bottom row.
            plot.xticks([-time_win, 0, time_win], [-time_win, 0, time_win])
            plot.xlabel('Time [s]', labelpad=0)
            ax.tick_params(axis='x', pad=1, length=1)
        plot.autoscale(enable=True, axis='x', tight=True)
        ax.tick_params(direction='in', axis='y', pad=0.8, length=1)

        plot.grid(color='#555555', linestyle=':', linewidth=0.3, axis='y')

# Tight margins and small gaps so the subplot grid fills the whole figure.
plot.gcf().subplots_adjust(bottom=0.055, top=0.925, left=0.011, right=0.992, hspace=0.1, wspace=0.15)
# plot.savefig("C:/Moji radovi/Robust acoustic vehicle speed estimation from single microphone measurements/Fig6.png")

# Everything read from the regression file was copied into NumPy arrays while
# plotting, so release the handle before showing the figure (plot.show() may
# block until the window is closed).
hf.close()
plot.show()