"""Hierarchically cluster DUD-E targets by their pocket-score vectors and plot heatmaps.

For every DUD-E target listed in DUDE_PDBID.csv the script:
  * resolves the pocket feature (.ff) file for the target (hand-picked for a few
    targets, otherwise derived from the crystal ligand via ``get_FF``);
  * picks the model weights of the cross-validation fold whose test set
    contains the target, loads ``<target>_<weights>_y_prob.dat`` and averages
    it over axis 0 to obtain one score vector (one matrix row) per target.

The stacked matrix is clustered by rows, by columns and by both (rows
re-clustered after the column ordering). Each variant is saved as a
dendrogram + heatmap figure and as a plain heatmap. Both axes are labelled
with target names, and all files go to ../RESULTS_GRAPH_CNN/statistics/sda/.
"""
import os
import sys
import time
import numpy
import theano
import theano.tensor as T
import bisect
from scipy.io import matlab
import math
import scipy.ndimage
import matplotlib as mpl
mpl.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from layers import *
import random
import argparse
import numpy as np
# read_csv_DUDE, parse_crystal_lig and cut_ligand_all_atoms are not defined in
# this file; they come from the project star imports.
from io_utils_DUDE_neg_poc import *
from mol_graph import *
from collections import OrderedDict
from sklearn.metrics import confusion_matrix
from scipy.cluster.hierarchy import dendrogram, linkage
import scipy
import pylab
import scipy.cluster.hierarchy as sch

# Not used below; kept because other code may rely on these module globals.
input_dir = '../DATA_GRAPH_CNN/DUDE_ff_drugFEATURE/'
input_ext = '.ff'
files = [f for f in os.listdir(input_dir)
         if os.path.isfile(os.path.join(input_dir, f))]

DATA_DIR = '../DATA_GRAPH_CNN'
DUDE_DIR = os.path.join(DATA_DIR, 'DUDE')
DUDE_PDBID_CSV = os.path.join(DUDE_DIR, 'DUDE_PDBID.csv')
PROB_DIR = '../RESULTS_GRAPH_CNN/poc_scores_single_ac/sda'
STATS_DIR = '../RESULTS_GRAPH_CNN/statistics/sda'
HEATMAP_CMAP = 'Spectral_r'


def save(path, ext='png', close=True, verbose=True):
    """Save the current matplotlib figure as an EPS file.

    The file written is ``<path>.<ext>.eps``. *ext* only becomes part of the
    file name; the format is always EPS at dpi=900.

    Args:
        path: destination without extension. A bare file name is saved into
            the current directory. Missing parent directories are created.
        ext: suffix inserted into the file name before ``.eps``.
        close: close the current figure after saving.
        verbose: write progress messages to stdout.
    """
    directory, name = os.path.split(path)
    directory = directory or '.'
    if not os.path.exists(directory):
        os.makedirs(directory)
    savepath = os.path.join(directory, "%s.%s" % (name, ext))
    eps_path = savepath + '.eps'
    if verbose:
        # Report the file that is actually written, including the '.eps' suffix.
        sys.stdout.write("Saving figure to '%s'... " % eps_path)
        sys.stdout.flush()
    plt.savefig(eps_path, format='eps', dpi=900)
    if close:
        plt.close()
    if verbose:
        print("Done")


def sum_sqroot(v):
    """Return the Euclidean norm ``sqrt(sum(x * x for x in v))`` of *v*."""
    total = 0
    for x in v:
        total = total + x * x
    return numpy.sqrt(total)


def get_FF(target, PDB_ID, DUDE_ctr):
    """Return the ``.ff`` file name for a ligand of PDB entry *PDB_ID*.

    ``cut_ligand_all_atoms(PDB_ID, DUDE_ctr, False)`` is queried. At the call
    site *DUDE_ctr* is the value returned by ``parse_crystal_lig``. Each
    returned ligand is unpacked as ``[lig_ID, lig_chain, lig_no, ctr]``.
    When several ligands are returned, the LAST one determines the name.

    Args:
        target: unused; kept for call compatibility.
        PDB_ID: PDB identifier (string) used as the file-name prefix.
        DUDE_ctr: forwarded unchanged to ``cut_ligand_all_atoms``.

    Returns:
        ``'<PDB_ID>_<lig_ID>_<lig_chain>_<lig_no>.ff'``, or None when no
        ligand is found.
    """
    ligs = cut_ligand_all_atoms(PDB_ID, DUDE_ctr, False)
    if not ligs:
        print('cant find ligs' + '\n')
    ff_name = None
    for lig_ID, lig_chain, lig_no, _ctr in ligs:
        ff_name = '%s_%s_%s_%s.ff' % (PDB_ID, lig_ID, lig_chain, lig_no)
    return ff_name


def _load_pdb_to_target():
    """Map each PDB ID listed in DUDE_PDBID.csv to its DUD-E target name."""
    data = read_csv_DUDE(DUDE_PDBID_CSV, "Target Name", "PDB")
    # data[0] holds target names and data[1] PDB IDs; later duplicates win.
    return dict(zip(data[1], data[0]))


def _load_fold_targets(split, fold):
    """Return the target names of ``DUDE_<split>_PDB_fold_<fold>.txt``.

    Each line of the fold file is a comma-separated list of PDB IDs. Empty
    entries are skipped, and every ID is translated to its target name (an
    unknown ID raises KeyError). The resulting set is printed.

    Args:
        split: 'test' or 'train'.
        fold: fold number (anything ``%s``-formattable).
    """
    pdb_to_target = _load_pdb_to_target()
    fold_path = os.path.join(DATA_DIR, 'DUDE_%s_PDB_fold_%s.txt' % (split, fold))
    targets = set()
    with open(fold_path) as fold_file:
        for line in fold_file:
            # strip() (not strip('\n')) so CRLF files do not leave '\r' in IDs.
            for pdb in line.strip().split(','):
                if pdb != '':
                    targets.add(pdb_to_target[pdb])
    print("%s set fold %s:" % (split, fold))
    print(targets)
    return targets


def load_test_pro(fold):
    """Return (and print) the set of test-target names for *fold*."""
    return _load_fold_targets('test', fold)


def load_train_pro(fold):
    """Return (and print) the set of train-target names for *fold*."""
    return _load_fold_targets('train', fold)


def _draw_dendrogram(fig, rect, link, orientation):
    """Draw the dendrogram of linkage *link* in new axes at *rect*; return its leaf order.

    ``sch.dendrogram`` draws on the current axes, and ``fig.add_axes`` makes
    the new axes current, so *fig* must be the current pyplot figure.
    """
    ax = fig.add_axes(rect)
    tree = sch.dendrogram(link, orientation=orientation)
    ax.set_xticks([])
    ax.set_yticks([])
    return tree['leaves']


def _draw_heatmap(fig, matrix, x_labels, y_labels):
    """Draw *matrix* in the main panel of *fig* with a colorbar and tick labels.

    Colour limits are fixed to [0, 1]. *x_labels* name the columns and
    *y_labels* name the rows, in display order.
    """
    axmatrix = fig.add_axes([0.3, 0.1, 0.6, 0.6])
    im = axmatrix.matshow(matrix, vmin=0, vmax=1, aspect='auto', origin='lower',
                          cmap=plt.get_cmap(HEATMAP_CMAP))
    axcolor = fig.add_axes([0.91, 0.1, 0.02, 0.6])
    fig.colorbar(im, cax=axcolor)
    axmatrix.set_xticks(range(len(x_labels)))
    axmatrix.set_xticklabels(x_labels, rotation='vertical', fontsize=10)
    axmatrix.xaxis.set_label_position('top')
    axmatrix.xaxis.tick_top()
    axmatrix.set_yticks(range(len(y_labels)))
    axmatrix.set_yticklabels(y_labels, fontsize=10)
    axmatrix.yaxis.set_label_position('left')
    axmatrix.yaxis.tick_left()


def _save_plain_matshow(matrix, x_labels, y_labels, path):
    """Save *matrix* as a simple heatmap (colour range [0, 1]) with small target labels."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix, vmin=0, vmax=1, cmap=plt.get_cmap(HEATMAP_CMAP),
                     interpolation='nearest')
    fig.colorbar(cax)
    ax.set_xticks(numpy.arange(0, len(x_labels), 1))
    ax.set_xticklabels(x_labels, rotation='vertical', fontsize=5)
    ax.set_yticks(numpy.arange(0, len(y_labels), 1))
    ax.set_yticklabels(y_labels, fontsize=5)
    save(path, ext="png", close=True, verbose=True)


# ---------------------------------------------------------------------------
# Resolve the .ff file for every DUD-E target (insertion order = plot order).
# ---------------------------------------------------------------------------
target_FF_dict = OrderedDict()
# Hand-picked .ff files for these targets (alternative for INHA: 4trj_NAD.ff).
tar_ff_dict_5 = {'DRD3': '3pbl_ETQ.ff', 'KIT': '3g0e_B49.ff', 'INHA': '4trj_665.ff',
                 'FNTA': '3e37_ED5.ff', 'HIVINT': '3nf7_CIW.ff'}
filename = DUDE_PDBID_CSV
data = read_csv_DUDE(filename, "Target Name", "PDB")
for target, PDB_ID in zip(data[0], data[1]):
    if target in tar_ff_dict_5:
        target_FF = tar_ff_dict_5[target]
    else:
        crystal_path = os.path.join(DUDE_DIR, target.lower(), 'crystal_ligand.mol2')
        with open(crystal_path) as crystal_lig:
            lig_ctr = parse_crystal_lig(list(crystal_lig))
        target_FF = get_FF(target, PDB_ID, lig_ctr)
    if target_FF is None:
        print("SOMETHING IS VERY WRONG")
    target_FF_dict[target] = target_FF

train_pro = load_train_pro(0)
test_pro = load_test_pro(0)

split_tar_order = list(target_FF_dict)

# ---------------------------------------------------------------------------
# Model weights per fold.
# Earlier configurations, kept for reference:
#   fold 0: DUDE_sda_both_fix0_160000 | DUDE_pocSDA_poc_50_fold_0_140000
#           DUDE_pocSDA_poc_50_ES_fold_0_130000
#   fold 1: 50_35_f1_160000 | sda_both_fix0_f1_120000
#           pocSDA_poc_50_fold_1_140000 | DUDE_pocSDA_poc_50_ES_fold_1_120000
#   fold 2: 50_35_f2_220000 | sda_both_fix0_f2_110000
#           early_both_fix_0_fold_2_90000 | pocSDA_poc_50_fold_2_140000
#           DUDE_pocSDA_poc_50_ES_fold_2_90000
#   fold 3: sda_both_fix0_f3_120000 | early_both_fix_0_fold_3_130000
#           pocSDA_poc_50_fold_3_120000 | DUDE_pocSDA_poc_50_ES_fold_3_100000
#           DUDE_pocSDA_poc_50_ES_fold_3_90000_switch_ff
#           DUDE_pocSDA_poc_50_ES_fold_3_90000_ALL_ff
#   95% identity: DUDE_pocSDA_ES_div_deg_tanh_95_fold_{0,1,2,3}_{110000,110000,90000,120000}
#                 result_ID = 'test_mols_sda_pocsda_ES_f0_11_f1_11_f2_9_f3_12_95iden'
#   80% identity: DUDE_pocSDA_ES_div_deg_tanh_fold_{0,1,2,3}_{100000,80000,90000,120000}
#                 result_ID = 'test_mols_sda_pocsda_ES_f0_10_f1_8_f2_9_f3_12_80iden'
#   result_ID = 'test_mols_sda_pocsda_poc50_ES_f3_10_f0_13_f2_9'
# ---------------------------------------------------------------------------
Weight_fold_0 = 'DUDE_pocSDA_poc_50_ES_fold_0_130000_all_tars'
Weight_fold_1 = 'DUDE_pocSDA_poc_50_ES_fold_1_120000_all_tars'
Weight_fold_2 = 'DUDE_pocSDA_poc_50_ES_fold_2_90000_all_tars'
Weight_fold_3 = 'DUDE_pocSDA_poc_50_ES_fold_3_100000_switch_ff_all_tars'
result_ID = 'test_mols_sda_pocsda_poc50_ES_f0_13_f1_12_f2_9_f3_10_all_tars'

test_tar_0 = test_pro  # same file as load_test_pro(0); avoid re-reading it
test_tar_1 = load_test_pro(1)
test_tar_2 = load_test_pro(2)
test_tar_3 = load_test_pro(3)
# Targets missing from the fold files, assigned to a fold by hand.
test_tar_1.update(['KIT', 'FNTA'])
test_tar_2.update(['ADRB1', 'DRD3', 'HIVINT'])
test_tar_3.add('INHA')

# Each fold's test set paired with that fold's weights; the first match wins.
fold_assignment = [(test_tar_0, Weight_fold_0),
                   (test_tar_1, Weight_fold_1),
                   (test_tar_2, Weight_fold_2),
                   (test_tar_3, Weight_fold_3)]

all_prob = []
for target in split_tar_order:
    print(target)
    for fold_targets, Weight_ID in fold_assignment:
        if target in fold_targets:
            break
    else:
        # Fail loudly: continuing would reuse the previous target's weights
        # (or raise NameError on the first target) and corrupt the matrix.
        raise RuntimeError("SOMETHING WRONG! -> " + target)
    # NOTE(review): if these .dat files were written with ndarray.dump()
    # (pickle), numpy >= 1.16.3 needs allow_pickle=True here -- confirm.
    prob = numpy.load(os.path.join(PROB_DIR, target + '_' + Weight_ID + '_y_prob.dat'))
    all_prob.append(numpy.mean(prob, axis=0))  # one row per target
all_prob = numpy.array(all_prob)
print(all_prob.shape)
num_tar = all_prob.shape[0]
print("len(all_prob)")
print(all_prob.shape)

# ---------------------------------------------------------------------------
# Hierarchical clustering
# ---------------------------------------------------------------------------
clus_method = 'ward'
# clus_method = 'single'

# 1) Cluster rows. Both dendrograms use the same row linkage, so both axes
#    share one ordering.
fig = pylab.figure(figsize=(16, 16))
row_link = sch.linkage(all_prob, method=clus_method)
idx1 = _draw_dendrogram(fig, [0.055, 0.1, 0.2, 0.6], row_link, 'right')
idx2 = _draw_dendrogram(fig, [0.3, 0.745, 0.6, 0.2], row_link, 'top')
n_row_ = all_prob[idx1, :][:, idx2]
tar_idx1 = [split_tar_order[o] for o in idx1]
tar_idx2 = [split_tar_order[o] for o in idx2]
_draw_heatmap(fig, n_row_, tar_idx2, tar_idx1)
save(os.path.join(STATS_DIR, "hierarchy_" + clus_method + '_' + result_ID),
     ext="png", close=True, verbose=True)
_save_plain_matshow(n_row_, tar_idx2, tar_idx1,
                    os.path.join(STATS_DIR, "hier_" + clus_method + '_' + result_ID))

# 2) Cluster columns (rows of the transposed matrix).
all_prob_trans = numpy.transpose(all_prob)
fig = pylab.figure(figsize=(16, 16))
col_link = sch.linkage(all_prob_trans, method=clus_method)
idx1 = _draw_dendrogram(fig, [0.055, 0.1, 0.2, 0.6], col_link, 'right')
idx2 = _draw_dendrogram(fig, [0.3, 0.745, 0.6, 0.2], col_link, 'top')
print(idx1)
print(idx2)
print(all_prob_trans.shape)
n_row_ = all_prob_trans[idx1, :][:, idx2]
tar_idx1 = [split_tar_order[o] for o in idx1]
tar_idx2 = [split_tar_order[o] for o in idx2]
_draw_heatmap(fig, n_row_, tar_idx2, tar_idx1)
save(os.path.join(STATS_DIR, "hierarchy_col_" + clus_method + '_' + result_ID),
     ext="png", close=True, verbose=True)
_save_plain_matshow(n_row_, tar_idx2, tar_idx1,
                    os.path.join(STATS_DIR, "hier_col_" + clus_method + '_' + result_ID))

# 3) Cluster both. Order both axes by the column clustering, then re-cluster
#    the rows of that reordered matrix.
fig = pylab.figure(figsize=(16, 16))
idx2 = _draw_dendrogram(fig, [0.3, 0.745, 0.6, 0.2], col_link, 'top')
n_col_trans = numpy.transpose(all_prob_trans[idx2, :][:, idx2])  # == all_prob[idx2][:, idx2]
idx1 = _draw_dendrogram(fig, [0.055, 0.1, 0.2, 0.6],
                        sch.linkage(n_col_trans, method=clus_method), 'right')
n_row_ = n_col_trans[idx1, :]
tar_idx2 = [split_tar_order[o] for o in idx2]
# idx1 indexes rows of n_col_trans, which are already in idx2 order.
tar_idx1 = [tar_idx2[o] for o in idx1]
_draw_heatmap(fig, n_row_, tar_idx2, tar_idx1)
save(os.path.join(STATS_DIR, "hierarchy_both_" + clus_method + '_' + result_ID),
     ext="png", close=True, verbose=True)
_save_plain_matshow(n_row_, tar_idx2, tar_idx1,
                    os.path.join(STATS_DIR, "hier_both_" + clus_method + '_' + result_ID))

# ---------------------------------------------------------------------------
# Disabled: reorder the "both" heatmap by hand-curated target groups.
#
# Older grouping:
# Group0 = ['ABL1','AKT1','AKT2','BRAF','CDK2','CSF1R','EGFR','FAK1','FGFR1','IGF1R','JAK2','KPCB','LCK','MAPK2','MET','MK01','MK10','MK14','MP2K1','PLK1','ROCK1','SRC','TGFR1','UROK','HXK4']
# Group1 = ['FPPS','GLCM','HDAC2','HDAC8','GRIK1','DEF','AMPC','NRAM','ACE','PUR2','TYSY','AA2AR','SAHH','HS90A','PNPH','COMT','DPP4','LKHA4','CXCR4','NOS1','ADA']
# Group2 = ['ALDR','FKB1A','FABP4','ANDR','ESR1','ESR2','GCR','MCR','PRGR','RXRA','THB','PGH1','PGH2','PYRD','PDE5A','CP3A4','DHI1','PPARA','PPARD','PPARG']
# Group3 = ['PA2GA','ADRB1','ADRB2','PTN1','HMDH','HIVPR','RENI','FA10','FA7','TRY1']
# Group4 = ['ADA17','BACE1','CASP3','MMP13','THRB','TRYB1','KITH']
# Group5 = ['CP2C9','ITAL','HIVINT','PYGM','KIF11','DYR','HIVRT','PARP1','GRIA2','AOFB','CAH2','ACES']
#
# Current grouping:
# Group0 = ['ABL1','AKT1','AKT2','BRAF','CDK2','CSF1R','EGFR','FAK1','FGFR1','IGF1R','JAK2','KPCB','LCK','MAPK2','MET','MK01','MK10','MK14','MP2K1','PLK1','ROCK1','SRC','TGFR1','HXK4']
# Group1 = ['FPPS','GLCM','HDAC2','HDAC8','GRIK1','DEF','AMPC','NRAM','PUR2','TYSY','SAHH','HS90A','PNPH','COMT','NOS1','ADA']
# Group2 = ['ACE','UROK','LKHA4','ADA17','BACE1','CASP3','MMP13','THRB','TRYB1','FA10','DPP4','TRY1','FA7','HIVPR','RENI']
# Group3 = ['HMDH','AA2AR','CXCR4','ADRB1','ADRB2']
# Group4 = ['ANDR','ESR1','ESR2','GCR','MCR','PRGR','RXRA','THB','PPARA','PPARD','PPARG','PGH1','PGH2','PYRD','PDE5A','CP3A4','DHI1','ALDR','FKB1A','FABP4','PA2GA']
# Group5 = ['KITH','CP2C9','ITAL','PTN1','HIVINT','PYGM','KIF11','DYR','HIVRT','PARP1','GRIA2','AOFB','CAH2','ACES']
#
# row_order = Group0 + Group1 + Group2 + Group3 + Group4 + Group5
# col_order = row_order
# row_val_dict = dict(zip(tar_idx1, n_row_))
# new_row = numpy.array([row_val_dict[t] for t in row_order])
# col_val_dict = dict(zip(tar_idx2, numpy.transpose(new_row)))
# final = numpy.transpose(numpy.array([col_val_dict[t] for t in col_order]))
# _save_plain_matshow(final, col_order, row_order,
#                     os.path.join(STATS_DIR, "hier_both_by_group_" + clus_method + '_' + result_ID))
# ---------------------------------------------------------------------------