import os
import subprocess
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Libraries used for hierarchical clustering and for creating the dendrogram.
# Which clustering algorithm to use is left up to the user.
from scipy.sparse import tril
from scipy.sparse import triu

# Path to the Schrodinger installation; change as needed.
SCHRODINGER = '/opt/schrodinger/suites2019-4'
# The backend commands in this file resolve the Schrodinger install from the
# SCHRODINGER *environment* variable (e.g. "$SCHRODINGER/run ..."), which a
# plain Python assignment does not set. Export the constant above, but never
# override a value the caller's environment already provides.
os.environ.setdefault('SCHRODINGER', SCHRODINGER)

# Location of the python scripts
script_dir = "./backend_scripts"

# Location of the Fab maestro files
fab_dir = './Fabs_Nov_2019/Fab-models-BioClustering-2019'
# Location of the faux epitopes
faux_epi_dir = './faux_epitopes'
# Location of the interaction fields
if_dir = './interaction_fields'

def safe_mkdir(dir, overwrite=False):
    """
    Create directory ``dir`` (and any missing parents) unless it already exists.

    Parameters
    ----------
    dir : str
        Directory path to create.
    overwrite : bool
        When ``dir`` already exists: if True, print a warning and return so
        the caller can write into it (existing files are not removed here);
        if False, print a message and exit the interpreter with status 1.
    """
    if not os.path.isdir(dir):
        # os.makedirs instead of a shell "mkdir": paths containing spaces or
        # shell metacharacters are handled verbatim, parents are created, and
        # failures raise instead of being silently ignored.
        os.makedirs(dir)
        return
    if overwrite:
        print("Warning: '%s' already exists, will overwrite its content" %
              dir)
    else:
        print("'%s' already exists, set overwrite=True to overwrite" % dir)
        # Non-zero status so shells/pipelines see this as an error.
        sys.exit(1)
def gen_file_list(dir):
    """
    Return the paths of all entries in ``dir``, sorted by entry name.

    The order is deterministic and identical to ``gen_names(dir)``, so the
    two lists can be paired index-by-index.
    """
    return [os.path.join(dir, item) for item in sorted(os.listdir(dir))]


def gen_names(dir):
    """
    Return the entry names in ``dir`` truncated at their first '.'.

    Entries are sorted by their full name, i.e. in the same order as
    ``gen_file_list(dir)``.
    """
    return [item.split('.')[0] for item in sorted(os.listdir(dir))]
def gen_faux_epitopes(fab_dir, faux_epi_dir, overwrite=False):
    """
    Wrapper function for step 1 - Generate faux epitopes.

    For each Fab file in ``fab_dir`` run ``$SCHRODINGER/run`` on the backend
    script with the Fab file as input and
    ``<faux_epi_dir>/faux_epitope_<name>.mae`` as output.

    Parameters
    ----------
    fab_dir : str
        Directory containing the Fab Maestro files.
    faux_epi_dir : str
        Output directory, created through safe_mkdir.
    overwrite : bool
        Forwarded to safe_mkdir.

    Returns
    -------
    list of str
        Output file names (without directory), one per Fab, in processing
        order.
    """
    safe_mkdir(faux_epi_dir, overwrite=overwrite)
    # Same resolution the shell applied to "$SCHRODINGER", falling back to the
    # module-level constant when the environment variable is not set.
    schrodinger_run = os.environ.get('SCHRODINGER', SCHRODINGER) + '/run'
    faux_epitope_list = []
    ts = time.time()
    fab_mae_list = gen_file_list(fab_dir)
    fab_names = gen_names(fab_dir)
    print(fab_names)
    N = len(fab_names)
    for i, (fab_mae, fab_name) in enumerate(zip(fab_mae_list, fab_names)):
        print("\rGenerating faux epitopes %d/%d" % (i + 1, N), end="\r\r")
        faux_epit_mae = "faux_epitope_%s.mae" % fab_name
        # An argument list (no shell) passes file names containing spaces or
        # shell metacharacters through verbatim.
        # NOTE(review): only the directory ``script_dir + '/'`` is passed;
        # the backend script's file name appears to be missing — confirm.
        cmd = [schrodinger_run, script_dir + '/', fab_mae,
               faux_epi_dir + '/' + faux_epit_mae]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print("\nWarning: command exited with status %d: %s" %
                  (result.returncode, ' '.join(cmd)))
        faux_epitope_list.append(faux_epit_mae)
    te = time.time()
    duration = (te - ts) / 60
    print("\n>>> Faux epitope generation took %.2f minutes" % duration)
    return faux_epitope_list
def extract_binding_sites(fab_dir, faux_epi_dir, if_dir, overwrite=False):
    """
    Wrapper function for step 2 - Generate MIFs.

    For each Fab file in ``fab_dir`` run ``$SCHRODINGER/run`` on the backend
    script with ``-L <faux epitope> <fab file> -i <if_dir>/if_<name>.json``.

    Parameters
    ----------
    fab_dir : str
        Directory containing the Fab Maestro files.
    faux_epi_dir : str
        Directory holding ``faux_epitope_<name>.mae`` files (output of step 1).
    if_dir : str
        Output directory, created through safe_mkdir.
    overwrite : bool
        Forwarded to safe_mkdir.

    Returns
    -------
    list of str
        Interaction-field file names (without directory), one per Fab.
    """
    safe_mkdir(if_dir, overwrite=overwrite)
    # Same resolution the shell applied to "$SCHRODINGER", falling back to the
    # module-level constant when the environment variable is not set.
    schrodinger_run = os.environ.get('SCHRODINGER', SCHRODINGER) + '/run'
    if_list = []
    ts = time.time()
    fab_mae_list = gen_file_list(fab_dir)
    fab_names = gen_names(fab_dir)
    # It's important to make sure the faux epitope is matched with the right
    # Fab, otherwise this step will result in a traceback, so generate the
    # faux epitope file list in the exact order of the Fab file list.
    faux_epi_list = [
        faux_epi_dir + "/faux_epitope_%s.mae" % item for item in fab_names
    ]
    N = len(fab_names)
    for i in range(N):
        duration = (time.time() - ts) / 60
        # i calculations have completed at this point (none on the first pass).
        per_calc = duration / i if i else 0.0
        print(
            "\rGenerating interaction field %d/%d, current running time ~ %.2f min. (~ %.2f min. per calculation)"
            % (i + 1, N, duration, per_calc),
            end="\r\r")
        if_name = "if_%s.json" % fab_names[i]
        # An argument list (no shell) passes file names verbatim.
        # NOTE(review): only the directory ``script_dir + '/'`` is passed;
        # the backend script's file name appears to be missing — confirm.
        cmd = [schrodinger_run, script_dir + '/', '-L', faux_epi_list[i],
               fab_mae_list[i], '-i', if_dir + '/' + if_name]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print("\nWarning: command exited with status %d: %s" %
                  (result.returncode, ' '.join(cmd)))
        if_list.append(if_name)
    # Measure after the loop so the total includes the last calculation.
    duration = (time.time() - ts) / 60
    print("\n>>> MIFs generation took %.2f minutes" % duration)
    return if_list
def gen_sim_mat(SCHRODINGER, fab_dir, if_dir):
    """
    Compute the pairwise similarity matrix between binding sites using the
    Phase Shape approach.

    Every ordered pair (i, j), i != j, of interaction-field files in
    ``if_dir`` is compared by running ``<SCHRODINGER>/run``; the last
    whitespace-separated token of its stdout is parsed as the score. Since
    (i, j) and (j, i) are computed separately the matrix need not be
    symmetric.

    Parameters
    ----------
    SCHRODINGER : str
        Schrodinger installation directory.
    fab_dir : str
        Not used by this function; kept for interface compatibility.
    if_dir : str
        Directory containing the interaction-field files.

    Returns
    -------
    sim_mat : numpy.ndarray, shape (N, N)
        1.0 on the diagonal; [i, j] is the score for pair (i, j), or -1.0 if
        no number could be parsed from the command output.
    N : int
        Number of interaction-field files.
    IDs : list of str
        File stems with their first "_"-separated prefix removed, in matrix
        order.
    """
    ts = time.time()
    if_list = gen_file_list(if_dir)
    # Derive names from the very same listing so if_names[k] always
    # describes if_list[k].
    if_names = [os.path.basename(path).split('.')[0] for path in if_list]
    N = len(if_list)
    # -1.0 marks pairs without a score; the diagonal is self-similarity.
    sim_mat = np.full((N, N), -1.0)
    np.fill_diagonal(sim_mat, 1.0)
    Ntot = N * N - N
    counter = 1
    # Remove only the leading prefix (e.g. "if_") so IDs that themselves
    # contain "_" stay whole and distinct.
    IDs = [name.split('_', 1)[1] for name in if_names]
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            duration = (time.time() - ts) / 60
            completed = counter - 1  # pairs finished so far
            per_pair = duration / completed if completed else 0.0
            print(
                "\rComparing pair: %d-%d (%d/%d pairs), current running time ~ %.2f min. (~ %.2f min. per pair)"
                % (i + 1, j + 1, counter, Ntot, duration, per_pair),
                end="\r\r")
            result = subprocess.run(
                [
                    '%s/run' % SCHRODINGER,
                    # NOTE(review): this empty argument looks like a
                    # placeholder for the comparison script name — confirm.
                    '',
                    if_list[i], if_list[j],
                    # Optional flags:
                    #'--force-gpu',
                    #'--maxMapPerAtom','4',
                    #'--maxAtomsToMap','32'
                ],
                stdout=subprocess.PIPE)
            output = result.stdout.decode(errors='replace')
            try:
                sim_mat[i, j] = float(output.split()[-1])
            except (IndexError, ValueError):
                # Keep the long-running job alive; the pair stays at -1.0.
                print("\nWarning: no similarity score for pair %d-%d "
                      "(exit status %d); leaving -1.0" %
                      (i + 1, j + 1, result.returncode))
            counter += 1
    return sim_mat, N, IDs
def plot_sim_mat(data, IDs, plot_title, cmap, plot_width, plot_height, dpi):
    """
    Draw ``data`` as a labelled heat map and save it to
    ``<plot_title with spaces replaced by '_'>.png``.

    The color scale is fixed to [0, 1]; values below 0 are drawn in 'snow'.

    Parameters
    ----------
    data : 2-D array
        Matrix to plot; rows map to the y axis, columns to the x axis.
    IDs : sequence of str
        Tick labels, used for both axes (one per row/column).
    plot_title : str
        Figure title, also used to build the output file name.
    cmap : str or Colormap
        Colormap passed to matplotlib.
    plot_width, plot_height : float
        Figure size in inches.
    dpi : int
        Figure resolution.
    """
    import copy

    fig, ax = plt.subplots(figsize=(plot_width, plot_height), dpi=dpi)
    ny, nx = data.shape
    # Work on a copy so set_under does not mutate a shared/registered
    # colormap, and configure it before imshow/colorbar so both use it.
    under_cmap = copy.copy(plt.get_cmap(cmap))
    under_cmap.set_under('snow')
    color_map = ax.imshow(
        data,
        cmap=under_cmap,
        aspect='equal',
        vmin=0,
        vmax=1)
    # One tick per cell: the number of ticks must equal the number of labels.
    ax.set_xticks(np.arange(nx), minor=False)
    ax.set_yticks(np.arange(ny), minor=False)
    ax.set_xticklabels(IDs)
    ax.set_yticklabels(IDs)
    ax.xaxis.set_ticks_position('bottom')
    fig.colorbar(color_map, ax=ax, fraction=0.046, pad=0.04)
    ax.tick_params(axis='x', labelrotation=90)
    ax.set_title('%s' % plot_title)
    fig.tight_layout()
    fig.savefig('%s.png' % plot_title.replace(' ', '_'))
def save_sym_mat_csv(mat, IDs, name):
    """
    Write ``mat`` as a labelled CSV named ``<name>_<n>x<n>.csv``.

    The header row holds an empty cell followed by ``IDs``; each data row
    starts with its ID. ``n`` is the number of rows of ``mat``.
    """
    frame = pd.DataFrame(mat, columns=IDs)
    frame.insert(loc=0, column='', value=IDs, allow_duplicates=True)
    size = len(frame.index)
    out_path = '%s_%dx%d.csv' % (name, size, size)
    frame.to_csv(out_path, index=False)
    print("Matrix saved as %s" % out_path)
def sim2dist(S, plot=True):
    """
    Convert the similarity matrix S into a symmetric distance matrix D.

    S is symmetrized as (S + S.T) / 2 and scaled so its largest entry is 1;
    distances are 1 - similarity, scaled so the largest distance is 1.

    Parameters
    ----------
    S : array_like, shape (n, n)
        Similarity matrix, possibly asymmetric. It is not modified.
    plot : bool, default True
        If True, draw the symmetrized similarity matrix and then D, each with
        a colorbar, on the current pyplot axes.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Symmetric distance matrix.

    Exits the interpreter with status 1 if S is not square.
    """
    S = np.asarray(S, dtype=float)
    n, m = S.shape
    if n != m:
        print("Error, input matrix is not square: %d x %d" % (n, m))
        sys.exit(1)
    # Average each (i, j) with (j, i). The diagonal stays S[i, i]; counting it
    # more than once would inflate the maximum used to normalize below and
    # distort every off-diagonal distance.
    S_sym = (S + S.T) / 2.0
    s_max = np.amax(S_sym) if S_sym.size else 0.0
    if s_max > 0:
        S_sym = S_sym / s_max
    if plot:
        cm = plt.imshow(S_sym)
        plt.colorbar(cm, fraction=0.046, pad=0.04)
    D = 1 - S_sym
    d_max = np.amax(D) if D.size else 0.0
    # Avoid 0/0 -> NaN when all distances are zero (e.g. a 1 x 1 input).
    if d_max > 0:
        D = D / d_max
    if plot:
        cm = plt.imshow(D)
        plt.colorbar(cm, fraction=0.046, pad=0.04)
    return D
def cond_1d_dist_mat(D, triag):
    """
    Convert a full distance matrix into a condensed 1-D distance vector.

    The entries strictly above the diagonal are listed row by row:
    (0, 1), (0, 2), ..., (1, 2), ... (the condensed layout used by
    scipy.spatial.distance.pdist). The diagonal is excluded.

    Parameters
    ----------
    D : array_like or scipy sparse matrix, shape (n, n)
        Distance matrix.
    triag : {'upper', 'lower'}
        Which triangle to read. 'upper' takes D[i, j] for i < j; 'lower'
        takes D[j, i] for i < j. For a symmetric D both give the same vector.

    Returns
    -------
    numpy.ndarray, shape (n * (n - 1) / 2,)

    Exits the interpreter with status 1 if ``triag`` is invalid.
    """
    if hasattr(D, 'toarray'):  # accept scipy sparse matrices
        D = D.toarray()
    D = np.asarray(D)
    if triag == 'upper':
        Dt = D
    elif triag == 'lower':
        Dt = D.T
    else:
        print('Error: triag can be either "upper" or "lower"')
        sys.exit(1)
    # Select by position, not by value: zero distances (e.g. identical sites)
    # are valid entries and must keep their slot in the condensed vector.
    rows, cols = np.triu_indices(Dt.shape[0], k=1, m=Dt.shape[1])
    return Dt[rows, cols]
