import h5py
import numpy as np
import os
from tqdm import tqdm
import project_tools as pt
from importlib import reload
reload(pt)


# ======================================= Project parameters =======================================
# --- Signal framing ---
sample_rate = 44100
signal_len = 10.0                   # Length of one signal (clip), in seconds
time_samples = 400                  # How many time samples the features are calculated at
window_len = 4096                   # Length of the analysis window, in samples
hop_perc = 27                       # Hop as a percentage of the window size (picked so that time_samples = 400)
hop_len = int(hop_perc / 100.0 * window_len)    # Hop length, in samples, for the STFT calculation
# --- LMS features ---
mel_bands = 40                      # Count of Mel-bands used for the Mel-band energy feature
adj_feat = 12                       # How many adjacent LMS features get concatenated
lms_adj_step = 3                    # Time step separating the concatenated adjacent LMS features
feat_num = mel_bands * (1 + 2 * adj_feat)       # Resulting number of features
f_min = 0                           # Lowest frequency covered by the mel bands
f_max = 16000                       # Highest frequency covered by the mel bands
lms_pow = 1                         # Power applied to the FFT in the LMS calculation
# --- Labelling ---
dCPA = 1.5                          # Distance at the closest point of approach
# ==================================================================================================


if __name__ == "__main__":
    # Build one HDF5 dataset per recording folder. For every WAV clip in a
    # folder this computes the feature matrix and the label/speed targets via
    # `project_tools`, looks up the clip's tag in Train_valid_split.txt, and
    # writes everything to datasets/data_<fold>.h5.

    folds = ["CitroenC4Picasso", "Mazda3", "MercedesAMG550", "NissanQashqai", "OpelInsignia", "Peugeot307",
             "Peugeot3008", "RenaultCaptur", "RenaultScenic", "VWPassat", "NoCar", "NoCarTest"]

    output_folder = 'datasets/'
    # Create the output directory up front: otherwise a missing folder would
    # only surface after a whole fold of features had already been extracted.
    os.makedirs(output_folder, exist_ok=True)

    for fold in folds:
        audio_folder = 'audio+annotations/' + fold + '/'
        dataset_name = f'data_{fold}.h5'

        # os.listdir returns entries in arbitrary order; sorting keeps the row
        # order of the stored arrays reproducible between runs.
        audio_files = sorted(file for file in os.listdir(audio_folder) if file.endswith("wav"))
        files_num = len(audio_files)

        # One row per audio file in each of the output arrays.
        features = np.zeros((files_num, time_samples, feat_num))
        labels = np.zeros((files_num, time_samples))
        speed = np.zeros((files_num, ))
        train_valid = [''] * files_num

        # Map "<file name without extension>" -> tag read from the split file.
        # Each non-empty line is expected to hold at least two
        # whitespace-separated fields; blank lines are skipped, and splitting
        # on any whitespace also drops a trailing '\r' from CRLF files.
        tv_txt = {}
        with open(audio_folder + 'Train_valid_split.txt') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                tv_txt[fields[0]] = fields[1]

        for file_index, audio_file in enumerate(tqdm(audio_files)):
            audio_path = audio_folder + audio_file

            # NOTE(review): extract_features presumably returns a
            # (time_samples, feat_num) array plus the matching time axis --
            # confirm against project_tools.
            features[file_index, :, :], time_axis = \
                pt.extract_features(audio_path,
                                    signal_len,
                                    window_len,
                                    hop_len,
                                    mel_bands,
                                    lms_pow,
                                    adj_feat,
                                    lms_adj_step,
                                    f_min,
                                    f_max)

            # The second return value of extract_features is fed to MA_feature,
            # which yields the per-time-sample labels and one speed value.
            labels[file_index, :], speed[file_index] = \
                pt.MA_feature(audio_path, time_axis, dCPA, True)

            # Raises KeyError if the clip has no entry in the split file.
            train_valid[file_index] = tv_txt[os.path.splitext(audio_file)[0]]

        # The context manager closes the HDF5 file even if a write fails.
        with h5py.File(output_folder + dataset_name, 'w') as hf:
            hf.create_dataset('features', data=features, compression="gzip")
            hf.create_dataset('labels', data=labels, compression="gzip")
            hf.create_dataset('speed', data=speed, compression="gzip")
            # h5py cannot store Python str lists directly; save as fixed-length byte strings.
            hf.create_dataset('train_valid', data=np.array(train_valid, dtype='S'), compression="gzip")