import h5py
import numpy as np
import project_tools as pt
from importlib import reload
reload(pt)

###############################################################################
# Speed-class bounds, built from one list of consecutive edges: class k has
# lower bound _class_edges[k] and upper bound _class_edges[k + 1]. How the
# bounds are treated (inclusive/exclusive) is decided by pt.speed_to_class.
_class_edges = [0, 25, 35, 45, 55, 65, 75, 85, 95, 106]
speed_classes_min = np.array(_class_edges[:-1])
speed_classes_max = np.array(_class_edges[1:])
num_classes = len(speed_classes_min)
###############################################################################

# Directory holding the stored estimation results, and the model whose
# results file ("speed_estimations_<model>.h5") is read below.
folder = 'results/speed_estimations/'
model = 'NN_1000-200-50-10-1_reg1e-3_lossMSE'
# Vehicles evaluated; each name prefixes its datasets inside the HDF5 file.
vehicles = [
    "CitroenC4Picasso",
    "Mazda3",
    "MercedesAMG550",
    "NissanQashqai",
    "OpelInsignia",
    "Peugeot3008",
    "Peugeot307",
    "RenaultCaptur",
    "RenaultScenic",
    "VWPassat",
]
n_veh = len(vehicles)

### Print speed estimation RMSE as in Table 2
# For each vehicle: root-mean-square error between the clamped speed
# estimates ('<vehicle>_speeds_est_all') and the ground truth
# ('<vehicle>_speeds_gt'), followed by the unweighted mean over vehicles.
print(f"{'Vehicle':16}{'RMSE':>10}")
print("-" * 26)

errs = np.zeros((n_veh, ))
# Open the results file once for all vehicles; the context manager also
# closes it if a dataset is missing or a computation raises.
with h5py.File(f'{folder}speed_estimations_{model}.h5', 'r') as hf:
    for veh_ind, vehicle in enumerate(vehicles):
        speed_est = np.array(hf[vehicle + '_speeds_est_all'], dtype=np.float64)
        speed_gt = np.array(hf[vehicle + '_speeds_gt'], dtype='int')
        # Clamp estimates into [0, 105] before scoring.
        speed_est = np.clip(speed_est, 0, 105)

        err = np.sqrt(np.mean((speed_est - speed_gt) ** 2))
        errs[veh_ind] = err

        print(f"{vehicle:16}{err:10.2f}")

# Keep the per-vehicle values in `errs`; the average gets its own name.
errs_mean = np.mean(errs, axis=0)
print("-" * 26)
print(f"{'Average':16}{errs_mean:10.2f}\n\n")


### Print the probability of predicting speed class as in Table 3
# For each vehicle, the percentage of samples whose estimated speed class
# differs from the ground-truth class by exactly 0, exactly 1, exactly 2,
# more than 2, and at most 1 class. The last row averages these percentages
# over vehicles (unweighted).
delta1 = '\u0394=0'
delta2 = '|\u0394|=1'
delta3 = '|\u0394|=2'
delta4 = '|\u0394|>2'
delta5 = '|\u0394|<=1'
print(f"{'Vehicle':16}{delta1:>7}{delta2:>7}{delta3:>7}{delta4:>7}{delta5:>7}")
print("-" * 51)

err_tot = np.zeros((5,))
# Open the results file once for all vehicles; the context manager also
# closes it if a dataset is missing or a computation raises.
with h5py.File(f'{folder}speed_estimations_{model}.h5', 'r') as hf:
    for vehicle in vehicles:
        speed_est = np.array(hf[vehicle + '_speeds_est_all'], dtype=np.float64)
        speed_gt = np.array(hf[vehicle + '_speeds_gt'], dtype='int')
        # Clamp estimates into [0, 105] before classifying.
        speed_est = np.clip(speed_est, 0, 105)

        speed_est_class = pt.speed_to_class(speed_est, speed_classes_min, speed_classes_max)
        speed_gt_class = pt.speed_to_class(speed_gt, speed_classes_min, speed_classes_max)

        # NOTE(review): assumes pt.speed_to_class returns a signed integer
        # array; an unsigned dtype would wrap around on this subtraction.
        diff_speed = np.abs(speed_est_class - speed_gt_class)

        err0 = np.sum(diff_speed == 0)
        err1 = np.sum(diff_speed == 1)
        err2 = np.sum(diff_speed == 2)
        err2p = np.sum(diff_speed > 2)
        # Columns: |d|=0, |d|=1, |d|=2, |d|>2, |d|<=1.
        counts = np.array([err0, err1, err2, err2p, err0 + err1])
        # The first four buckets partition all samples, so their sum is the total.
        pct = 100 * counts / counts[:4].sum()
        err_tot += pct

        print(f"{vehicle:16}" + "".join(f"{p:6.1f}%" for p in pct))

err_tot /= n_veh
print("-" * 51)
print(f"{'Average':16}" + "".join(f"{p:6.1f}%" for p in err_tot))