import numpy as np
import os
import random
import sys
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Convolution2D, Dense, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

# Allow GPU memory growth so TensorFlow does not grab all GPU memory up front
if hasattr(tf, 'GPUOptions'):
    # TF 1.x: configure the session through the (standalone) Keras backend
    import keras.backend as K
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    K.tensorflow_backend.set_session(sess)
else:
    # TF 2.x: enable memory growth per physical device
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)


def get_map(pdb_path, all_dist_path, true_npy = False):
    # Load a true distance map saved as .npy; depending on true_npy the file
    # holds either the raw map or a pickled (length, sequence, map) tuple
    seqy = None
    mypath = all_dist_path + pdb_path + '.npy'

    if os.path.exists(mypath):
        if not true_npy:
            cb_map = np.load(mypath, allow_pickle = True)
            ly = len(cb_map)
        else:
            (ly, seqy, cb_map) = np.load(mypath, allow_pickle = True)
    else:
        print('Expected distance map file for', pdb_path, 'not found at', all_dist_path)
        sys.exit(1)

    return cb_map, ly
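
# Round-trip illustration of the expected .npy format for true_npy = True (a
# hypothetical 3-residue map, written to and removed from the working directory)
_toy = np.abs(np.arange(3)[:, None] - np.arange(3)[None, :]).astype(float)
_rec = np.empty(3, dtype=object)
_rec[0], _rec[1], _rec[2] = 3, 'AAA', _toy
np.save('toyA-cb.npy', _rec)
_m, _l = get_map('toyA-cb', './', True)
assert _l == 3 and np.array_equal(_m, _toy)
os.remove('toyA-cb.npy')
del _toy, _rec, _m, _l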

def plot_learning_curves(history, name):
    print('')
    print('Curves..')
    
    plt.clf()
    if 'mean_absolute_error' in history.history:
        plt.plot(history.history['mean_absolute_error'], 'g', label = 'Training MAE')
        plt.plot(history.history['val_mean_absolute_error'], 'b', label = 'Validation MAE')
        plt.xlabel('Epochs')
        plt.ylabel('MAE')
    elif 'accuracy' in history.history:
        plt.plot(history.history['accuracy'], 'g', label = 'Training Accuracy')
        plt.plot(history.history['val_accuracy'], 'b', label = 'Validation Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
    else:
        plt.plot(history.history['mae'], 'g', label = 'Training MAE')
        plt.plot(history.history['val_mae'], 'b', label = 'Validation MAE')
        plt.xlabel('Epochs')
        plt.ylabel('MAE')
    plt.legend()
    plt.savefig(name)
    plt.close()

'''
***** Calculate LDDT here
'''
# Helpers for metrics calculated using numpy scheme
def get_flattened(dmap):
  if dmap.ndim == 1:
    return dmap
  elif dmap.ndim == 2:
    return dmap[np.triu_indices_from(dmap, k=1)]
  else:
    assert False, "ERROR: the passed array must be 1- or 2-dimensional!"

def get_separations(dmap):
  t_indices = np.triu_indices_from(dmap, k=1)
  separations = np.abs(t_indices[0] - t_indices[1])
  return separations
  
# return a 1D boolean array indicating where the sequence separation in the
# upper triangle meets the threshold comparison
def get_sep_thresh_b_indices(dmap, thresh, comparator):
  assert comparator in {'gt', 'lt', 'ge', 'le'}, "ERROR: Unknown comparator for thresholding!"
  separations = get_separations(dmap)
  if comparator == 'gt':
    threshed = separations > thresh
  elif comparator == 'lt':
    threshed = separations < thresh
  elif comparator == 'ge':
    threshed = separations >= thresh
  elif comparator == 'le':
    threshed = separations <= thresh

  return threshed

# return a 1D boolean array indicating where the distance in the
# upper triangle meets the threshold comparison
def get_dist_thresh_b_indices(dmap, thresh, comparator):
  assert comparator in {'gt', 'lt', 'ge', 'le'}, "ERROR: Unknown comparator for thresholding!"
  dmap_flat = get_flattened(dmap)
  if comparator == 'gt':
    threshed = dmap_flat > thresh
  elif comparator == 'lt':
    threshed = dmap_flat < thresh
  elif comparator == 'ge':
    threshed = dmap_flat >= thresh
  elif comparator == 'le':
    threshed = dmap_flat <= thresh
  return threshed
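
# Toy illustration (made-up values) of the helpers above: for a 3 x 3 map the
# upper triangle flattens row-major to pairs (0,1), (0,2), (1,2), giving
# sequence separations 1, 2, 1
_m = np.array([[0., 5., 9.],
               [5., 0., 7.],
               [9., 7., 0.]])
assert np.array_equal(get_flattened(_m), [5., 9., 7.])
assert np.array_equal(get_separations(_m), [1, 2, 1])
assert np.array_equal(get_dist_thresh_b_indices(_m, 8, 'lt'), [True, False, True])
del _m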


# Calculate lDDT using numpy scheme
def get_LDDT(true_map, pred_map, R=15, sep_thresh=6, T_set=[0.5, 1, 2, 4], precision=4):
    '''
    Mariani V, Biasini M, Barbato A, Schwede T.
    lDDT: a local superposition-free score for comparing protein structures and models using distance difference tests.
    Bioinformatics. 2013 Nov 1;29(21):2722-8.
    doi: 10.1093/bioinformatics/btt473.
    Epub 2013 Aug 27.
    PMID: 23986568; PMCID: PMC3799472.
    '''
    
    # Helper for number preserved in a threshold
    def get_n_preserved(ref_flat, mod_flat, thresh):
        err = np.abs(ref_flat - mod_flat)
        n_preserved = (err < thresh).sum()
        return n_preserved
    
    # flatten upper triangles
    true_flat_map = get_flattened(true_map)
    pred_flat_map = get_flattened(pred_map)
    
    # Find the set L of scorable pairs: sequence separation > sep_thresh and
    # true distance < the inclusion radius R
    S_thresh_indices = get_sep_thresh_b_indices(true_map, sep_thresh, 'gt')
    R_thresh_indices = get_dist_thresh_b_indices(true_flat_map, R, 'lt')
    
    L_indices = S_thresh_indices & R_thresh_indices
    
    true_flat_in_L = true_flat_map[L_indices]
    pred_flat_in_L = pred_flat_map[L_indices]
    
    # Number of pairs in L
    L_n = L_indices.sum()
    if L_n == 0:
        # No scorable pairs (e.g. a fully padded window); lDDT is undefined,
        # so return 0.0 by convention here
        return 0.0

    # Calculate lDDT as the mean fraction of preserved distances over T_set
    preserved_fractions = []
    for _thresh in T_set:
        _n_preserved = get_n_preserved(true_flat_in_L, pred_flat_in_L, _thresh)
        _f_preserved = _n_preserved / L_n
        preserved_fractions.append(_f_preserved)
        
    lDDT = np.mean(preserved_fractions)
    if precision > 0:
        lDDT = round(lDDT, precision)
    return lDDT
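
# Minimal sanity check of get_LDDT on a toy map (illustrative only, not part
# of the pipeline): an identical prediction scores a perfect 1.0, while a
# uniform 1.5 A error passes only the 2 A and 4 A thresholds, giving 0.5
_idx = np.arange(10)
_toy_true = np.abs(_idx[:, None] - _idx[None, :]).astype(float)
assert get_LDDT(_toy_true, _toy_true) == 1.0
assert get_LDDT(_toy_true, _toy_true + 1.5) == 0.5
del _idx, _toy_true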

def trrosetta_probindex2dist(index):
    # trRosetta distogram bins are 0.5 A wide: bin k (k = 1..36) maps to a
    # representative distance of 1.75 + 0.5 * k A; any other index falls
    # through to the last bin's distance (19.75 A)
    if 1 <= index <= 36:
        return 1.75 + 0.5 * index
    return 1.75 + 0.5 * 36
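
# Spot-check the bin centers: bin 1 maps to 2.25 A, bin 36 to 19.75 A
assert trrosetta_probindex2dist(1) == 2.25
assert trrosetta_probindex2dist(36) == 19.75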

def trrosetta2maps(a):
    # Collapse a trRosetta distogram (L x L x 37) into an L x L distance map
    # by taking the representative distance of each cell's most probable bin
    if a.ndim != 3 or a.shape[-1] != 37:
        print('ERROR! This does not look like a trRosetta prediction')
        return None
    bin_dist = np.array([trrosetta_probindex2dist(k) for k in range(37)])
    return bin_dist[np.argmax(a, axis=-1)]
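
# Toy illustration (fake distogram, not real trRosetta output): with all
# probability mass in bin 10, every cell maps to 1.75 + 0.5 * 10 = 6.75 A
_toy_hist = np.zeros((2, 2, 37))
_toy_hist[:, :, 10] = 1.0
assert np.allclose(trrosetta2maps(_toy_hist), 6.75)
del _toy_hist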

def get_feature_channels(npz_path, npy_path):
    # Stack the trRosetta distogram, the extra feature block from the .npy
    # file, and the derived distance map into a single multi-channel input
    if os.path.exists(npz_path) and os.path.exists(npy_path):
        x1 = np.load(npz_path)
        a = x1['dist']

        y_pred = trrosetta2maps(a)

        x2 = np.load(npy_path)
        b = x2[0]

        c = np.dstack((a, b, y_pred))

        return c, y_pred
    else:
        missing = npz_path if not os.path.exists(npz_path) else npy_path
        print('Expected feature file not found:', missing)
        sys.exit(1)


def get_input_output(features_dir, true_dist_path, xy_dimension, batch_list, expected_n_channels):
    X_input = np.zeros((len(batch_list), xy_dimension, xy_dimension, expected_n_channels))
    Pred_Y = np.full((len(batch_list), xy_dimension, xy_dimension, 1), 100.0)
    True_Y = np.full((len(batch_list), xy_dimension, xy_dimension, 1), 100.0)
    Y = np.zeros((len(batch_list)))

    for index, item in enumerate(batch_list):
        id_name = item.split('-')[0]

        '''
        Use this index (indx) to reference the npz and npy files from their 
        respective indexed folders
        '''
        indx = int(item.split('-')[1]) + 1
        
        y_true, l = get_map(id_name + '-cb', true_dist_path, True)
        
        '''
        Give the path to npz and npy files according to the location
        '''
        npz_path = features_dir + 'path_to_npz_file'
        npy_path = features_dir + 'path_to_npy_file'

        x, y_pred = get_feature_channels(npz_path, npy_path)

        pred_y = np.reshape(y_pred, (1, l, l, 1))
        true_y = np.reshape(y_true, (1, l, l, 1))


        if l <= xy_dimension:
            X_input[index, 0: l, 0: l, :] = x
            
            Pred_Y[index, 0: l, 0: l, 0] = pred_y[0, 0: l, 0: l, 0]
            True_Y[index, 0: l, 0: l, 0] = true_y[0, 0: l, 0: l, 0]

        else:
            # randomly cropping along the diagonal of the distance map for data augmentation
            rx = random.randint(0, l - xy_dimension)
            ry = rx
            assert rx + xy_dimension <= l
            assert ry + xy_dimension <= l
            
            X_input[index, :, :, :] = x[rx:rx+xy_dimension, ry:ry+xy_dimension, :]
            Pred_Y[index, :, :, 0] = pred_y[0, rx:rx+xy_dimension, ry:ry+xy_dimension, 0]
            True_Y[index, :, :, 0] = true_y[0, rx:rx+xy_dimension, ry:ry+xy_dimension, 0]


    # The regression target for each window is the lDDT of the padded/cropped
    # predicted map against the corresponding true map
    for index in range(len(batch_list)):
        Y[index] = get_LDDT(True_Y[index, :, :, 0], Pred_Y[index, :, :, 0])

    return X_input, Y

class DataGenerator(Sequence):
    def __init__(self, id_list, features_dir, true_distmap_path, dimen, expected_n_channels, batch_size = 2, shuffle = False):
        self.id_list = id_list
        self.features_dir = features_dir
        self.true_distmap_path = true_distmap_path
        self.dimen = dimen
        self.expected_n_channels = expected_n_channels
        self.batch_size = batch_size
        self.shuffle = shuffle

    def on_epoch_end(self):
        # Keras calls on_epoch_end (not on_epoch_begin) on a Sequence, so the
        # shuffle hook must live here to take effect. Shuffling is off by
        # default because the evaluation code below re-iterates the generators
        # and relies on a stable batch order.
        if self.shuffle:
            np.random.shuffle(self.id_list)

    def __len__(self):
        # Integer division drops any trailing partial batch
        return len(self.id_list) // self.batch_size

    def __getitem__(self, index):
        batch_list = self.id_list[index * self.batch_size: (index + 1) * self.batch_size]
        X, Y = get_input_output(self.features_dir, self.true_distmap_path, self.dimen, batch_list, self.expected_n_channels)

        return X, Y


'''
In this list, each id has 5 different indices (according to the folder names of the combinations) 
which are concatenated at the end (eg: 2v9lA-0, 2v9lA-1, 2v9lA-2, etc)
'''
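
# For example, an entry like '2v9lA-2' splits into the chain id '2v9lA' and
# the combination-folder index 2, exactly as parsed in get_input_output above
assert '2v9lA-2'.split('-') == ['2v9lA', '2']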
validation_id_list_path = 'validation_pdb_ext.lst'
training_id_list_path = 'training_pdb_ext.lst'


features_dir = 'location where the npz and npy files are stored'
true_distmap_path = 'pdb_true_dist_cath/'


def getlist(path):
    # Read an id list file, keeping the first whitespace-separated token per line
    with open(path, 'r') as f:
        return [line.split()[0] for line in f.read().splitlines()]
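
# Quick illustration of getlist (hypothetical ids): only the first
# whitespace-separated token of each line is kept
_tmp_lst = '_example_ids.lst'
with open(_tmp_lst, 'w') as _f:
    _f.write('2v9lA-0 ignored-columns\n1abcA-1\n')
assert getlist(_tmp_lst) == ['2v9lA-0', '1abcA-1']
os.remove(_tmp_lst)
del _tmp_lst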


print('')
print('Get the training and validation set')

valid_ids = getlist(validation_id_list_path)
train_ids = getlist(training_id_list_path)

window_dimen = 512
batch_size = 2
expected_n_channels = 564

print ('length of training set:', len(train_ids))

train = DataGenerator(train_ids, features_dir, true_distmap_path, window_dimen, expected_n_channels, batch_size)
validate = DataGenerator(valid_ids, features_dir, true_distmap_path, window_dimen, expected_n_channels, batch_size)

print ('length of train generator', len(train))
print ('length of valid generator', len(validate))


print ('Build Model')

# Earlier model variant, kept for reference (not used below)
def eff_model(window_dimen, expected_n_channels):
    base_model = tf.keras.applications.EfficientNetB0(
        include_top=False,
        weights=None,
        input_shape=(window_dimen, window_dimen, expected_n_channels), 
    )

    model = Sequential()
    model.add(BatchNormalization(input_shape=(window_dimen, window_dimen, expected_n_channels)))
    model.add(base_model)
    model.add(BatchNormalization())
    model.add(Convolution2D(1, 3, padding = 'same', activation="relu"))
    model.add(GlobalAveragePooling2D())

    return model


def eff_new_model(window_dimen, expected_n_channels):
    base_model = tf.keras.applications.EfficientNetB0(
        include_top=False,
        weights=None,
        input_shape=(window_dimen, window_dimen, expected_n_channels),
    )

    model = Sequential()
    model.add(BatchNormalization(input_shape=(window_dimen, window_dimen, expected_n_channels)))
    model.add(base_model)
    model.add(GlobalAveragePooling2D())

    model.add(Dense(1, activation = 'relu'))

    return model


file_weights = 'model_weights' + str(window_dimen) + '-effnetb0.hdf5'

model = eff_new_model(window_dimen, expected_n_channels)

print("Model Summary....")
print(model.summary())

print('Compile Model...')
model.compile(loss = 'mae', optimizer = 'nadam', metrics = ['mae'])


print('')
print('Train..')
training_epochs = 8

if os.path.exists(file_weights):
    print('Loading weights...')
    model.load_weights(file_weights)

history = model.fit(train,
    validation_data = validate,
    callbacks = [ModelCheckpoint(filepath = file_weights, monitor = 'val_loss', save_best_only = True, save_weights_only = True, verbose = 1)],
    verbose = 1,
    max_queue_size = 32,
    workers = 1,
    use_multiprocessing = False,
    shuffle = True,
    epochs = training_epochs)
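
# plot_learning_curves is defined above but never invoked; saving the curves
# after training appears intended (the output filename here is an assumption)
plot_learning_curves(history, 'learning_curves_' + str(window_dimen) + '.png')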

# Evaluate on the validation set: collect the true lDDT targets in generator
# order so they line up with the predictions below
Y = np.zeros(batch_size * len(validate))
for i in range(len(validate)):
    xx, yy = validate[i]
    for j in range(batch_size):
        Y[batch_size * i + j] = yy[j]

model.load_weights(file_weights)
P = model.predict(validate, max_queue_size=10, verbose=1)
P = P.flatten()

np.set_printoptions(formatter = {'float': '{: 0.3f}'.format})
print(P[:10])
print(Y[:10])


P[P > 1.0] = 1.0  # clip to the lDDT upper bound of 1
print ('validation correlation coefficient:', np.corrcoef(P, Y)[0, 1])

# Repeat the evaluation on the training set
Y_train = np.zeros(batch_size * len(train))
for i in range(len(train)):
    xx, yy = train[i]
    for j in range(batch_size):
        Y_train[batch_size * i + j] = yy[j]

P_train = model.predict(train, max_queue_size=10, verbose=1)
P_train = P_train.flatten()

print(P_train[:10])
print(Y_train[:10])

P_train[P_train > 1.0] = 1.0  # clip to the lDDT upper bound of 1
print ('training correlation coefficient:', np.corrcoef(P_train, Y_train)[0, 1])
