import h5py
import numpy as np
import matplotlib.pyplot as plot
from scipy.stats import sem, t

###############################################################################

# Location of the stored estimation results and the identifier of the model
# whose results are plotted; both are spliced into the HDF5 file name.
folder = 'results/speed_estimations/'
model = 'NN_1000-200-50-10-1_reg1e-3_lossMSE'

# Vehicles to plot, one panel each. The names double as dataset-name prefixes
# inside the results file.
vehicles = [
    "CitroenC4Picasso",
    "Mazda3",
    "MercedesAMG550",
    "NissanQashqai",
    "OpelInsignia",
    "Peugeot3008",
    "Peugeot307",
    "RenaultCaptur",
    "RenaultScenic",
    "VWPassat",
]
n_veh = len(vehicles)

# NOTE(review): not referenced in the visible part of the script; presumably
# the number of repeated estimation runs stored per vehicle -- confirm.
runs = 20

# One figure hosts every per-vehicle panel; the size is given in centimetres
# and converted to the inches matplotlib expects.
plot.figure(num=123, figsize=(17.5 / 2.54, 9 / 2.54), dpi=300)

# Compact font sizes so a 3x4 grid of panels stays legible.
plot.rc('axes', titlesize=8, labelsize=5)  # axes title / x and y labels
for tick_group in ('xtick', 'ytick'):
    plot.rc(tick_group, labelsize=5)  # tick labels

# Running totals over all vehicles: summed squared estimation error and the
# number of estimates that went into it.
err, speed_samp = 0, 0

# Fill colour of the estimated-speed band.
colors = ['#ff7f0e']

# Two-sided confidence level of the band drawn around the mean estimate.
confidence = 0.95

# The results file is the same for every vehicle, so it is opened once for the
# whole loop; the context manager guarantees the handle is released even if a
# dataset is missing or plotting raises.
with h5py.File(f'{folder}speed_estimations_{model}.h5', 'r') as hf:
    for veh_ind, vehicle in enumerate(vehicles):
        # 3x4 grid: the first eight vehicles fill slots 1-8, the last two are
        # shifted by one so they land in slots 10 and 11 of the bottom row.
        slot = veh_ind + 2 if veh_ind in (8, 9) else veh_ind + 1
        plot.subplot(3, 4, slot)

        # speed_est is 2-D: axis 0 is averaged over below, axis 1 is aligned
        # with the entries of speed_gt.
        # NOTE(review): axis 0 presumably indexes repeated runs -- confirm.
        speed_est = np.array(hf[vehicle + '_speeds_est_all'], dtype=np.float64)
        speed_gt = np.array(hf[vehicle + '_speeds_gt'], dtype=int)

        # Order the samples by increasing ground-truth speed.
        ind_sort = np.argsort(speed_gt)
        speed_gt = speed_gt[ind_sort]
        speed_est = speed_est[:, ind_sort]

        # Accumulate the squared error and sample count across vehicles.
        err += np.sum(np.square(speed_est - speed_gt))
        speed_samp += speed_est.size

        # Mean estimate and half-width of its Student-t confidence interval,
        # computed per ground-truth sample across axis 0.
        speed_est_mean = np.mean(speed_est, axis=0)
        speed_est_std_err = sem(speed_est, axis=0)
        h = speed_est_std_err * t.ppf((1 + confidence) / 2, speed_est.shape[0] - 1)

        sample_idx = np.arange(speed_gt.size)
        plot.plot(speed_gt, color='dodgerblue', linewidth=0.5, label='True speed')
        plot.fill_between(sample_idx,
                          speed_est_mean - h,
                          speed_est_mean + h,
                          color=colors[0],
                          label='Estimated speed',
                          linewidth=0)

        plot.title(f"{vehicle}", fontdict={'fontsize': 7}, pad=3)
        plot.grid(ls=':', lw=0.5)
        plot.autoscale(enable=True, axis='x', tight=True)

        # Thin out the x ticks on longer series; labels are 1-based.
        tick_step = 2 if speed_gt.size < 30 else 3
        tick_pos = sample_idx[0::tick_step]
        plot.xticks(tick_pos, tick_pos + 1)

        axes = plot.gca()
        axes.set_ylim([0, 107])
        axes.tick_params(axis='x', direction='in', pad=2, length=1.5)
        axes.tick_params(axis='y', direction='in', pad=1.5, length=1.5)
        plot.xlabel('Speed index', labelpad=0.5)
        plot.ylabel('Speed [km/h]', labelpad=-1.5)
        plot.legend(loc='lower right', fontsize=4, handlelength=2, labelspacing=0.2)
        

# Tighten the margins and panel spacing so the grid fills the figure.
plot.gcf().subplots_adjust(bottom=0.048, top=0.965, left=0.03, right=0.99, hspace=0.4, wspace=0.17)
# plot.savefig("C:/Moji radovi/Robust acoustic vehicle speed estimation from single microphone measurements/Fig8.png")
plot.show()

# No HDF5 close is needed here: the file handle is released where it is
# opened, in the plotting loop. A second close at this point was redundant and
# would raise NameError if the loop body never ran.