import os
import numpy as np
import scipy.io.wavfile as wav
from scipy.fftpack import fft
from sklearn.linear_model import LinearRegression
import librosa
import matplotlib.pyplot as plot
import warnings


def load_audio(audio_filename, signal_length=None, sample_rate=44100):
    """Load the first channel of a WAV file as a 1-D array.

    Args:
        audio_filename: path to a WAV file readable by ``scipy.io.wavfile``.
        signal_length: optional duration in seconds, measured at the file's own
            sampling rate (before any resampling). Longer signals are truncated
            to ``int(signal_length * fs)`` samples; shorter ones are zero-padded
            on both sides to exactly that many samples.
        sample_rate: target sampling rate; the signal is resampled with librosa
            when the file's rate differs.

    Returns:
        (y, fs): the signal and its sampling rate (``sample_rate`` after any
        resampling). 16-bit PCM is scaled to [-1, 1); other sample formats are
        returned without scaling.
    """
    # Silence warnings scipy may emit while parsing the file, for this read
    # only: catch_warnings restores the caller's filters afterwards (also when
    # wav.read raises) instead of resetting them globally to "default".
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fs, y = wav.read(audio_filename)
    if y.dtype == np.int16:
        y = y / 32768.0  # short int to float
    if y.ndim == 2:
        y = y[:, 0]  # keep the first channel only
    y = np.asarray(y)

    if signal_length is not None:
        target = int(signal_length * fs)
        if y.size > target:
            # Cut the signal if it's too long
            y = y[:target]
        else:
            # Pad the signal with zeros if it's too short; an odd deficit puts
            # the extra sample on the right so the length is exactly `target`.
            missing = target - y.size
            y = np.pad(y, (missing // 2, missing - missing // 2), mode='constant')

    if fs != sample_rate:
        # Keyword arguments: librosa >= 0.10 makes orig_sr/target_sr keyword-only.
        y = librosa.resample(y, orig_sr=fs, target_sr=sample_rate)
        fs = sample_rate

    return y, fs


def extract_features(file_name, sig_len, win_len, hop_len, mel_band_num, lms_pow, adj_feats, lms_adj_step, f_min=0, f_max=22050):
    """Compute (optionally context-expanded) log-mel features for an audio file.

    Args:
        file_name: WAV file, loaded with ``load_audio`` at its default sample rate.
        sig_len: duration in seconds passed to ``load_audio`` (cut/pad), or None.
        win_len: analysis window length in samples; also used as the FFT length.
        hop_len: hop between consecutive frames in samples.
        mel_band_num: number of mel bands.
        lms_pow: exponent applied to the FFT magnitude before the mel projection.
        adj_feats: number of time-shifted copies added on each side by
            ``expand_lms_feat``; 0 disables the expansion.
        lms_adj_step: shift in frames between those copies.
        f_min, f_max: frequency range of the mel filterbank in Hz.

    Returns:
        (features, time_vec): ``features`` is the transpose of the matrix built
        by ``lms_feature`` / ``expand_lms_feat`` (one row per frame);
        ``time_vec[i]`` is ``i * hop_len / fs`` seconds, the centre of frame i.
    """
    window = np.hamming(win_len)
    y, fs = load_audio(file_name, sig_len)

    # Pad half a window on both sides so frame i is centred on sample i * hop_len.
    y = np.pad(y, int(win_len / 2), mode='constant')
    nb_frames = int((y.size - win_len) / hop_len) + 1

    fft_len = win_len
    n_bins = 1 + int(fft_len / 2)
    # Keyword arguments: librosa >= 0.10 makes these parameters keyword-only,
    # and the same names are accepted by older releases.
    fft_mel_bands = librosa.filters.mel(sr=fs, n_fft=fft_len, n_mels=mel_band_num, fmin=f_min, fmax=f_max).T

    # Frame the whole signal at once: row i is y[i*hop_len : i*hop_len + win_len].
    frame_idx = hop_len * np.arange(nb_frames)[:, np.newaxis] + np.arange(win_len)
    frames = y[frame_idx] * window
    fft_en = np.abs(fft(frames, axis=-1)[:, :n_bins])
    # Avoid exact zeros so the later log stays finite.
    fft_en[fft_en == 0] = np.finfo(np.float32).eps
    lms = np.dot(fft_en ** lms_pow, fft_mel_bands)
    time_vec = hop_len * np.arange(nb_frames) / fs

    lms = lms_feature(lms)
    if adj_feats > 0:
        lms = expand_lms_feat(lms, adj_feats, lms_adj_step)
    features = lms.T

    return features, time_vec


def lms_feature(lms):
    """Convert a mel spectrogram to log scale and repair weak leading frames.

    Args:
        lms: array of shape (frames, bands) with non-negative mel energies, as
            built by ``extract_features``. It is not modified.

    Returns:
        Array of shape (bands, frames) holding the natural log of ``lms``
        (zeros replaced by float32 eps first). If frames 0-1 are, averaged over
        bands, more than 3 log units below frames 2-4, frames 0-1 of every band
        are replaced by a linear extrapolation of frames 2-8.
    """
    eps = np.finfo(np.float32).eps
    # Substitute eps on a copy (never the caller's array) so the log is finite,
    # then switch to (bands, frames) layout.
    lms = np.log(np.where(lms == 0, eps, lms).T)

    # Manage low values at the beginning of interval
    inter_points = 7
    # With <= 2 frames the comparison window lms[:, 2:5] is empty; skip it
    # rather than averaging an empty slice (which only yields NaN + a warning).
    if lms.shape[0] > 0 and lms.shape[1] > 2:
        mean_diff = np.mean(np.mean(lms[:, 2:5], axis=1) - np.mean(lms[:, 0:2], axis=1))
        if mean_diff > 3:
            for k in range(lms.shape[0]):
                inter_val, _, _ = linear_interp(lms[k, 2:2 + inter_points], 2, direction='L')
                lms[k, 0:2] = inter_val

    return lms


def running_mean(x, n):
    """Centred moving average of odd length ``n`` with edge-bias correction.

    ``np.convolve(..., mode='same')`` zero-pads at the borders, so the first and
    last ``(n - 1) // 2`` outputs average fewer than ``n`` real samples; they are
    rescaled so every output is the mean of the samples actually covered.

    Args:
        x: 1-D sequence of numbers.
        n: window length; must be odd and >= 1 (``n == 1`` returns ``x`` as float).

    Returns:
        Float array of smoothed values.

    Raises:
        ValueError: if ``n`` is even or smaller than 1.
    """
    if n % 2 == 0:
        raise ValueError("Filter length is not odd")
    if n < 1:
        raise ValueError("Filter length must be positive")

    aver = np.convolve(x, np.ones((n,)) / n, mode='same')
    n_half = (n - 1) // 2
    if n_half == 0:
        # No edge effect; also avoids aver[-0:] selecting the whole array.
        return aver

    # Fraction of the window that overlaps real samples at positions 0 .. n_half-1.
    corr_length = np.arange(n_half + 1, n) / n
    aver[:n_half] /= corr_length
    aver[-n_half:] /= corr_length[::-1]

    return aver
    

def expand_lms_feat(lms, points, step):
    """Stack time-shifted copies of a log-mel matrix above and below it.

    Rows of ``lms`` are bands and columns are frames (the layout returned by
    ``lms_feature``). Each successive copy shifts the previous one by ``step``
    frames. The ``step`` frames that enter from outside the signal are the
    average of the old edge values and a linear extrapolation (``linear_interp``)
    of a 3-point running mean over the 10 frames nearest that edge.

    Args:
        lms: 2-D array (bands, frames).
        points: number of shifted copies to build in each direction.
        step: shift, in frames, between consecutive copies.

    Returns:
        ``lms`` itself when ``points`` is 0, otherwise the row-wise stack
        [left_points, ..., left_1, lms, right_1, ..., right_points], where
        left_k is delayed and right_k advanced by k * step frames.
    """
    refer_points = 10
    n_rows = lms.shape[0]
    base = lms.copy()

    # Delayed copies; each new one is put in front so the most delayed is first.
    before = []
    prev = base
    for _ in range(points):
        cur = np.zeros_like(base)
        for row in range(n_rows):
            smooth = running_mean(prev[row, :refer_points], 3)
            guess, _, _ = linear_interp(smooth, step, 'L')
            cur[row, :step] = (guess + prev[row, :step]) / 2
            cur[row, step:] = prev[row, :-step]
        before.insert(0, cur)
        prev = cur

    # Advanced copies, appended in order of increasing shift.
    after = []
    prev = base
    for _ in range(points):
        cur = np.zeros_like(base)
        for row in range(n_rows):
            smooth = running_mean(prev[row, -refer_points:], 3)
            guess, _, _ = linear_interp(smooth, step, 'R')
            cur[row, :-step] = prev[row, step:]
            cur[row, -step:] = (guess + prev[row, -step:]) / 2
        after.append(cur)
        prev = cur

    if not before:
        return lms
    return np.concatenate(before + [lms] + after, axis=0)


def linear_interp(array, point_no, direction='L'):
    """Extrapolate a least-squares line fitted to ``array`` past one of its ends.

    The values of ``array`` are placed at x = 0, 1, ..., len(array) - 1 and the
    ordinary least-squares line ``y = coef * x + interc`` is computed in closed
    form on centred data. This is the same model scikit-learn's
    ``LinearRegression`` fits, without building an estimator on every call.

    Args:
        array: values to fit (array-like; flattened to 1-D).
        point_no: number of points to extrapolate.
        direction: 'L' evaluates the line at x = -point_no, ..., -1 (before the
            first sample); 'R' at x = len, ..., len + point_no - 1 (after the last).

    Returns:
        (y_extra, coef, interc): the ``point_no`` extrapolated values in order of
        increasing x, the fitted slope and the fitted intercept (Python floats).

    Raises:
        ValueError: if ``direction`` is neither 'L' nor 'R', or ``array`` is empty.
    """
    # An explicit raise (unlike assert) survives `python -O`.
    if direction not in ('L', 'R'):
        raise ValueError("Extrapolation direction not valid")

    values = np.asarray(array, dtype=float).ravel()
    arr_len = values.size
    if arr_len == 0:
        raise ValueError("Cannot fit a line to an empty array")

    x = np.arange(arr_len, dtype=float)
    x_mean = x.mean()
    y_mean = values.mean()
    x_centred = x - x_mean
    sxx = np.dot(x_centred, x_centred)
    # A single sample has no defined slope: fit a flat line through it.
    coef = float(np.dot(x_centred, values - y_mean) / sxx) if sxx > 0 else 0.0
    interc = float(y_mean - coef * x_mean)

    if direction == 'L':
        x_extra = np.linspace(-point_no, -1, point_no)
    else:
        x_extra = np.linspace(arr_len, arr_len + point_no - 1, point_no)

    y_extra = coef * x_extra + interc
    return y_extra, coef, interc
    

def MA_feature(audio_filename, time, dCPA, create_image=False):
    """Compute the ground-truth MA curve of a recording from its annotation.

    The annotation is the first line of the ``.txt`` file that shares the audio
    file's path stem; it holds the speed and tCPA separated by whitespace. The
    curve is ``v / (beta * v**2 * (time - tCPA)**2 + dCPA**2)`` with
    ``beta = 0.05`` and ``v = speed * 1000 / 3600`` (a km/h -> m/s conversion,
    so the annotated speed is presumably in km/h).

    Args:
        audio_filename: path of the audio file; only its stem is used.
        time: time instant(s) at which to evaluate the curve, in the same unit
            as the annotated tCPA.
        dCPA: distance term of the model.
        create_image: if True, also save a plot to ``temp/MA_<stem>.png``.

    Returns:
        (MA_feat, speed): the curve evaluated at ``time`` and the annotated speed.
    """
    file_name = os.path.splitext(audio_filename)[0]
    with open(file_name + '.txt') as f:
        # split() without an argument also tolerates tabs and repeated spaces.
        line_entries = f.readline().split()
    speed = float(line_entries[0])
    tCPA = float(line_entries[1])

    beta = 0.05
    v = speed * 1000 / 3600
    MA_feat = v / (beta * v**2 * (time - tCPA)**2 + dCPA**2)

    if create_image:
        # Folder for images of ground-truth MA plots; exist_ok avoids the
        # check-then-create race of os.path.exists + os.makedirs.
        os.makedirs('temp', exist_ok=True)

        plot.figure(num=124, figsize=(14, 7))
        plot.plot(time, MA_feat, 'b', linewidth=1.0)
        plot.axvline(x=tCPA, color='k', linestyle=':', linewidth=1)
        plot.title(str(speed))
        # basename honours the platform's path separator, unlike split("/").
        plot.savefig(f"temp/MA_{os.path.basename(file_name)}.png")
        plot.clf()

    return MA_feat, speed


def shuffle_data(x, y, seed=None):
    """Shuffle samples and their targets with one shared random permutation.

    Args:
        x: array whose first axis indexes samples (any number of dimensions).
        y: array aligned with ``x`` along its first axis.
        seed: optional seed accepted by ``np.random.RandomState`` for a
            reproducible shuffle; ``None`` (default) seeds from OS entropy,
            as before.

    Returns:
        (x_shuff, y_shuff): new arrays; the inputs are not modified.
    """
    rand_perm = np.random.RandomState(seed).permutation(x.shape[0])
    # Integer-array indexing already returns copies, so no .copy() is needed.
    return x[rand_perm], y[rand_perm]


def speed_est_feats(MA_feat, center_inds, points):
    """Cut a fixed-width window out of each MA row around a given centre.

    Args:
        MA_feat: 2-D array, one MA curve per row.
        center_inds: one integer centre index per row of ``MA_feat``.
        points: half-width of the window; each output row has 2*points+1 values.

    Returns:
        Float array of shape (rows, 2*points+1). Row i holds
        ``MA_feat[i, c-points : c+points+1]`` for ``c = center_inds[i]``;
        positions falling outside the row (on either or both sides) are zero.
    """
    n_rows, n_cols = MA_feat.shape[0], MA_feat.shape[1]
    width = 2 * points + 1
    speed_feat = np.zeros((n_rows, width))

    for i in range(n_rows):
        # Column of MA_feat that lands in slot 0 of the output window.
        start = int(center_inds[i]) - points
        lo = max(start, 0)
        hi = min(start + width, n_cols)
        # Clip on both sides at once; the window may overrun either or both ends.
        if hi > lo:
            speed_feat[i, lo - start:hi - start] = MA_feat[i, lo:hi]

    return speed_feat


def detect_MA_maxima(MA_regressions):
    """Locate the peak of every MA curve.

    Args:
        MA_regressions: array whose first axis indexes curves (typically 2-D,
            one curve per row).

    Returns:
        Integer array with, for each curve, the index of its largest value
        (first occurrence on ties, following ``np.argmax``).
    """
    n_curves = MA_regressions.shape[0]
    peaks = [np.argmax(MA_regressions[row, :]) for row in range(n_curves)]
    return np.array(peaks, dtype=int)


def speed_to_class(speeds, min_speed, max_speed):
    """Map each speed to the index of the first class interval containing it.

    Class k covers the half-open interval ``[min_speed[k], max_speed[k])``.

    Args:
        speeds: array of speeds of any shape.
        min_speed, max_speed: 1-D arrays of equal length with the inclusive
            lower and exclusive upper bound of every class.

    Returns:
        Array with the same shape and dtype as ``speeds`` holding class indices.

    Raises:
        IndexError: if some speed lies outside every class interval (the type
            the original element-wise lookup raised, kept for compatibility).
    """
    speeds = np.asarray(speeds)
    if speeds.size == 0:
        return np.zeros_like(speeds)
    lower = np.asarray(min_speed)
    upper = np.asarray(max_speed)
    # in_class[..., k] is True where the speed falls into class k.
    in_class = (speeds[..., np.newaxis] >= lower) & (speeds[..., np.newaxis] < upper)
    if not in_class.any(axis=-1).all():
        raise IndexError("speed outside every class interval")
    # argmax over booleans returns the first True, i.e. the lowest matching class.
    return np.argmax(in_class, axis=-1).astype(speeds.dtype)