import h5py
import numpy as np
import matplotlib.pyplot as plot
from matplotlib.patches import Rectangle
from feature_extraction import time_samples, signal_len

# Number of rows (runs) read from each vehicle's MA datasets below.
runs = 20

# Input locations and the regression model whose MA outputs are analysed.
folder_MA_regr = 'results/MA_regressions/'
folder_datasets = 'datasets/'
model = 'NN_1000-200-50-10-1_reg1e-3_lossMSE'
hf = h5py.File(f'{folder_MA_regr}regression_{model}.h5', 'r')

vehicles = [
    "CitroenC4Picasso",
    "Mazda3",
    "MercedesAMG550",
    "NissanQashqai",
    "OpelInsignia",
    "Peugeot3008",
    "Peugeot307",
    "RenaultCaptur",
    "RenaultScenic",
    "VWPassat",
]
n_veh = len(vehicles)

# Figure size is given in centimetres and converted to inches for matplotlib.
plot.figure(num=13, figsize=(8.6 / 2.54, 6 / 2.54), dpi=300)
# Font sizes for axes titles, axis labels and tick labels.
plot.rc('axes', titlesize=8, labelsize=7)
for tick_group in ('xtick', 'ytick'):
    plot.rc(tick_group, labelsize=7)


# Per-example statistics pooled over all runs and vehicles:
#   offsets     - argmax of the ground-truth label minus argmax of the MA curve
#                 (sample indices, rescaled to time at the end of this block)
#   max_vehs    - maximum of each MA curve on audio with vehicles
#   max_no_vehs - maximum of each MA curve on audio without vehicles
# Pieces are collected in lists and concatenated once, because repeated
# np.append copies the whole array every time (quadratic).
offset_parts, max_veh_parts, max_no_veh_parts = [], [], []
try:
    # The ground-truth labels do not depend on the run, so each dataset file is
    # read once here instead of once per run inside the loop below.
    gt_peaks = []
    for veh in vehicles:
        with h5py.File(folder_datasets + f"data_{veh}.h5", 'r') as hf_data:
            y_test_gt = np.array(hf_data['labels'], dtype=np.float64)
        gt_peaks.append(np.argmax(y_test_gt, axis=1))

    for run in range(runs):
        for veh, gt_peak in zip(vehicles, gt_peaks):
            # Slice the HDF5 datasets directly so only row `run` is read from
            # disk, rather than loading the full dataset and discarding the rest.
            MA_test = np.asarray(hf[f'{veh}_MA_test'][run, :],
                                 dtype=np.float64).reshape(-1, time_samples)
            MA_test_nv = np.asarray(hf[f'{veh}_MA_test_nv'][run, :],
                                    dtype=np.float64).reshape(-1, time_samples)

            offset_parts.append(gt_peak - np.argmax(MA_test, axis=1))
            max_veh_parts.append(np.amax(MA_test, axis=1))
            max_no_veh_parts.append(np.amax(MA_test_nv, axis=1))
finally:
    # Close the regression file even if a dataset is missing or malformed.
    hf.close()

# The leading empty float64 array keeps the results 1-D float64 (as np.append
# onto np.array([]) did) and makes runs == 0 yield empty arrays instead of an error.
offsets = np.concatenate([np.empty(0)] + offset_parts)
max_vehs = np.concatenate([np.empty(0)] + max_veh_parts)
max_no_vehs = np.concatenate([np.empty(0)] + max_no_veh_parts)

# Rescale offsets from sample indices to the time units of signal_len
# (the histogram below labels them as seconds).
offsets *= signal_len / (time_samples - 1)

#%%
# Top panel: histogram of the MA-maximum detection offsets pooled over all
# runs and vehicles, annotated with their mean and standard deviation.
plot.subplot(2,1,1)
ax_offsets = plot.gca()

x_min, x_max = np.min(offsets), np.max(offsets)
plot.hist(offsets, bins=31, range=(x_min, x_max), color='#1f77b4')

plot.autoscale(enable=True, axis='x', tight=True)
plot.title('Histogram of MA maxima detection offsets', pad=2.5)
plot.xlabel('Time [s]', labelpad=0)
ax_offsets.tick_params(direction='out', pad=1, length=2.5)

label = f"Mean: {np.mean(offsets):.3f}\nStd: {np.std(offsets):.3f}"

# Position the annotation in axes-fraction coordinates (0..1) rather than in
# data units. A hard-coded time/count position falls off the axes or onto the
# bars whenever the data (model, runs, vehicles) change.
stats_xy = (0.98, 0.95)  # upper-right corner; adjust if it overlaps bars
plot.text(*stats_xy, label, fontsize=6, ha='right', va='top',
          transform=ax_offsets.transAxes)

#%%
# Bottom panel: histograms of the per-example MA maxima for audio without and
# with vehicles. The gap between the two distributions is shaded when it exists.
plot.subplot(2,1,2)
axes = plot.gca()

plot.hist(max_no_vehs, bins=200, label='Audio without vehicles', color='#ff7f0e')
plot.hist(max_vehs, bins=27, label='Audio with vehicles', color='#1f77b4')

plot.autoscale(enable=True, axis='x', tight=True)
plot.title('Histogram of MA maxima values', pad=2.5)
plot.xlabel('Magnitude', labelpad=0)
axes.set_xticks(np.arange(0, 16, 2))
axes.tick_params(direction='out', pad=1, length=2.5)

# Margin between the largest no-vehicle maximum and the smallest vehicle
# maximum. A positive value means the two distributions do not overlap.
x1 = np.max(max_no_vehs)
x2 = np.min(max_vehs)
print(x2-x1)
if x2 > x1:
    # Only shade a real gap. When the distributions overlap, a zero/negative
    # width patch would be drawn and mislabelled as "No detected values".
    axes.add_patch(Rectangle((x1, 0), x2-x1, axes.get_ylim()[1],
                             facecolor='#2ca02c', alpha=0.2, label='No detected values'))

# Single legend call: it must come after add_patch so the shaded region is
# included. An earlier legend() call would just be replaced by this one.
plot.legend(fontsize=5, handlelength=2, labelspacing=0.2)


plot.gcf().subplots_adjust(bottom=0.118, top=0.938, left=0.085, right=0.993, hspace=0.5)
# plot.savefig("C:/Moji radovi/Robust acoustic vehicle speed estimation from single microphone measurements/Fig7.png")
plot.show()