from __future__ import print_function import networkx as nx import os, pickle import sys from rdkit import Chem from rdkit.Chem import Draw from rdkit.Chem.Draw import DrawingOptions import os import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt import numpy from mol_graph import * from extract_pts import * import json from process_poc_pretrain import * from io_utils_DUDE_ROC import * from rdkit.Chem import AllChem from cairosvg import svg2png from cairosvg import svg2ps from cairosvg import svg2pdf from rdkit.Chem import rdDepictor from rdkit.Chem.Draw import rdMolDraw2D max_poc_degrees = 20 max_nodes_in_poc = 50 max_mol_degrees = 6 max_nodes_in_mol = 60 figsize = (120, 120) highlight_color = (30.0/255.0, 100.0/255.0, 255.0/255.0) # A nice light blue. def get_FF(target,PDB_ID,DUDE_ctr): print (target, PDB_ID, DUDE_ctr) pro_name=target ff_name = None ligs=cut_ligand_all_atoms(PDB_ID,DUDE_ctr,False) if ligs == []: print('cant find ligs'+'\n') for l in ligs: [lig_ID,lig_chain,lig_no, ctr]=l #ff_name=PDB_ID+'_'+str(lig_ID)+'_'+str(lig_chain)+'_'+str(lig_no)+'.ff' ff_name=PDB_ID+'_'+str(lig_ID)+'.ff' return ff_name def load_DUDE_dataset(): filename = "../DATA_GRAPH_CNN/DUDE/DUDE_PDBID.csv" data=read_csv_DUDE(filename,"Target Name","PDB") ofile=open('../DATA_GRAPH_CNN/DUDE/DUDE_PDB_list.txt','w') target_lig=[] target_FF_dict={} for i in range(0,len(data[0])): target = data[0][i] PDB_ID = data[1][i] if PDB_ID.lower() in current_PDBs: ofile.write(PDB_ID+',') target_actives_file=open('../DATA_GRAPH_CNN/DUDE/'+target.lower()+'/'+'actives_final.ism') target_decoys_file=open('../DATA_GRAPH_CNN/DUDE/'+target.lower()+'/'+'decoys_final.ism') crystal_lig = open('../DATA_GRAPH_CNN/DUDE/'+target.lower()+'/'+'crystal_ligand.mol2') receptor=open('../DATA_GRAPH_CNN/DUDE/'+target.lower()+'/'+'receptor.pdb') lig_ctr=parse_crystal_lig(list(crystal_lig)) target_FF = get_FF(target,PDB_ID,lig_ctr) # Target FF if target_FF!=None: target_FF_dict[target]=target_FF target_actives_list=list(target_actives_file) target_decoys_list=list(target_decoys_file) tar_lig_pos=[] tar_lig_neg=[] # print "actives" for active in target_actives_list: eles=active.split() # print eles smiles=eles[0] # print smiles tar_lig_pos.append(smiles) # print "decoys" for decoy in target_decoys_list: eles=decoy.split() # print eles smiles=eles[0] # print smiles tar_lig_neg.append(smiles) target_lig.append([data[0][i],data[1][i],lig_ctr,tar_lig_pos,tar_lig_neg]) ofile.write('\n') with open('target_FF_dict','w') as myfile: json.dump(target_FF_dict,myfile) return target_lig, target_FF_dict def save(path, ext='png', close=True, verbose=True): # Extract the directory and filename from the given path directory = os.path.split(path)[0] filename = "%s.%s" % (os.path.split(path)[1], ext) if directory == '': directory = '.' if not os.path.exists(directory): os.makedirs(directory) savepath = os.path.join(directory, filename) plt.savefig(savepath+'.eps', format='eps',dpi=900) #plt.savefig(savepath) if close: plt.close() if verbose: print("Done") class PocketGraph(object): def __init__(self): self.nodes = {} # dict of lists of nodes, keyed by node type def new_node(self, ntype, pos, features=None, env_ix=None): new_node = Env(ntype, pos, features, env_ix) self.nodes.setdefault(ntype, []).append(new_node) return new_node def add_subgraph(self, subgraph): old_nodes = self.nodes new_nodes = subgraph.nodes for ntype in set(old_nodes.keys()) | set(new_nodes.keys()): old_nodes.setdefault(ntype, []).extend(new_nodes.get(ntype, [])) def get_degree(self, ntype): all_node_degree=[] for node in self.nodes[ntype]: all_node_degree.append(len(node.get_neighbors(ntype))) return numpy.array(all_node_degree) def feature_array(self, ntype): assert ntype in self.nodes return np.array([node.features for node in self.nodes[ntype]]) def pos_array(self, ntype): assert ntype in self.nodes return np.array([node.pos for node in self.nodes[ntype]]) def neighbor_list(self, self_ntype, neighbor_ntype): assert self_ntype in self.nodes and neighbor_ntype in self.nodes neighbor_idxs = {n : i for i, n in enumerate(self.nodes[neighbor_ntype])} for self_node in self.nodes[self_ntype]: # print "indi number of env:" neighbors=self_node.get_neighbors(neighbor_ntype) return [[neighbor_idxs[neighbor] for neighbor in self_node.get_neighbors(neighbor_ntype)] for self_node in self.nodes[self_ntype]] def env_ix_array(self): return np.array([node.env_ix for node in self.nodes['env']]) class Env(object): __slots__ = ['ntype', 'features', '_neighbors', 'pos', 'env_ix'] def __init__(self, ntype, pos, features, env_ix): self.ntype = ntype self.features = features self._neighbors = [] self.pos = pos self.env_ix = env_ix def add_neighbors(self, neighbor_list): for neighbor in neighbor_list: self._neighbors.append(neighbor) neighbor._neighbors.append(self) def get_neighbors(self, ntype): return [n for n in self._neighbors if n.ntype == ntype] def dist(env_1,env_2): return np.sqrt(np.sum((env_1.pos-env_2.pos)**2)) #files = ['1a00.ff'] def pocket_ff_to_numpy(ff_list,max_ff, min_ff, mean_ff): input_dir = '../DATA_GRAPH_CNN/ALL_ff/' input_ext = '.ff' files = [ f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir,f))] print ("ff_list") all_pocket=[] dat_num=0 for fn in ff_list: FV=[] #print f site_ID=fn.strip('.ff') # print "processing..." ele = fn.split('_') PDB = fn[0:4] lig = ele[1].split('.')[0] correct_fn = PDB.lower()+'_'+lig+'.ff' f = open(os.path.join(input_dir,correct_fn)) infile=list(f) for line in infile: S=line.split() if S!=[]: if len(S[0])>3: if S[0][0:3]=="Env": # print line feature_vec=numpy.zeros((480,)) for i in range (0,480): #S[1]-S[480] if max_ff[i]-min_ff[i]!=0: feature_vec[i]=(float(S[i+1])-min_ff[i])/(max_ff[i]-min_ff[i]) else: feature_vec[i]=float(S[i+1]) # print S[481:] x=float(S[482]) y=float(S[483]) z=float(S[484]) pos=[x,y,z] pos=numpy.array(pos) # feature_vec=feature_vec-min_ff # feature_vec/=(max_ff-min_ff) # feature_vec=feature_vec*2 -1 T=[feature_vec,pos] T=numpy.array(T) FV.append(T) f.close() vectors = numpy.array(FV) all_pocket.append([fn,vectors]) return all_pocket def get_pocket_FEATURE_attributes(ff_list,max_ff, min_ff, mean_ff): pockets = pocket_ff_to_numpy(ff_list,max_ff, min_ff, mean_ff) big_graph, mask = graph_from_pocket_tuple(pockets) env_features = big_graph.feature_array('env') return env_features, big_graph, mask def sum_and_stack(features, mask, num_envs, fp_length, max_nodes): # features: 250, 512 # mask: 250 features = features * mask.dimshuffle(0,'x') stacked_features = T.reshape(features, (int(num_envs/max_nodes), int(max_nodes), int(fp_length)), ndim=3) return T.sum(stacked_features,axis=1) def sum_and_stack_atoms(features, mask, num_input, max_nodes_in_mol, fp_length): # features: 250, 512 # mask: 250 features = features * mask.dimshuffle(0,'x') stacked_features = T.reshape(features, (int(num_input), int(max_nodes_in_mol), int(fp_length)), ndim=3) return T.sum(stacked_features,axis=1) def relu(X): """Rectified linear units (relu)""" return T.maximum(0,X) class PocketGraph(object): def __init__(self): self.nodes = {} # dict of lists of nodes, keyed by node type def new_node(self, ntype, pos, features=None, env_ix=None): new_node = Env(ntype, pos, features, env_ix) self.nodes.setdefault(ntype, []).append(new_node) return new_node def add_subgraph(self, subgraph): old_nodes = self.nodes new_nodes = subgraph.nodes for ntype in set(old_nodes.keys()) | set(new_nodes.keys()): old_nodes.setdefault(ntype, []).extend(new_nodes.get(ntype, [])) def get_degree(self, ntype): all_node_degree=[] for node in self.nodes[ntype]: all_node_degree.append(len(node.get_neighbors(ntype))) return numpy.array(all_node_degree) def feature_array(self, ntype): assert ntype in self.nodes return np.array([node.features for node in self.nodes[ntype]]) def pos_array(self, ntype): assert ntype in self.nodes return np.array([node.pos for node in self.nodes[ntype]]) def neighbor_list(self, self_ntype, neighbor_ntype): assert self_ntype in self.nodes and neighbor_ntype in self.nodes neighbor_idxs = {n : i for i, n in enumerate(self.nodes[neighbor_ntype])} for self_node in self.nodes[self_ntype]: # print "indi number of env:" neighbors=self_node.get_neighbors(neighbor_ntype) return [[neighbor_idxs[neighbor] for neighbor in self_node.get_neighbors(neighbor_ntype)] for self_node in self.nodes[self_ntype]] def env_ix_array(self): return np.array([node.env_ix for node in self.nodes['env']]) class Env(object): __slots__ = ['ntype', 'features', '_neighbors', 'pos', 'env_ix'] def __init__(self, ntype, pos, features, env_ix): self.ntype = ntype self.features = features self._neighbors = [] self.pos = pos self.env_ix = env_ix def add_neighbors(self, neighbor_list): for neighbor in neighbor_list: self._neighbors.append(neighbor) neighbor._neighbors.append(self) def get_neighbors(self, ntype): return [n for n in self._neighbors if n.ntype == ntype] def dist(env_1,env_2): return np.sqrt(np.sum((env_1.pos-env_2.pos)**2)) def pad_neighbors(neighbors): pad_neighbors = numpy.zeros((len(neighbors),len(neighbors))) for i in range (len(neighbors)): entry = neighbors[i] for e in range (len(entry)): pad_neighbors[i][entry[e]]=1 return pad_neighbors def pad_neighbors_bond(neighbors,num_atoms,num_bonds): pad_neighbors = numpy.zeros((num_atoms,num_bonds)) for i in range (len(neighbors)): entry = neighbors[i] for e in range (len(entry)): pad_neighbors[i][entry[e]]=1 return pad_neighbors def pad_degree(degrees,max_degrees): pad_degrees = numpy.zeros((len(degrees),max_degrees)) for i in range (len(degrees)): degree = degrees[i] pad_degrees[i][degree]=1 return pad_degrees def get_pocket_attributes(ff_list,max_ff, min_ff, mean_ff): pockets = pocket_ff_to_numpy(ff_list,max_ff, min_ff, mean_ff) big_graph, mask = graph_from_pocket_tuple(pockets) env_features = big_graph.feature_array('env') env_neighbors = big_graph.neighbor_list('env','env') env_neighbors = pad_neighbors(env_neighbors) env_degrees = big_graph.get_degree('env') env_degrees = pad_degree(env_degrees, max_poc_degrees) return env_features, env_neighbors, env_degrees, mask def get_atom_bond_dim(smiles_tuple): big_graph, mask = graph_from_smiles_tuple(smiles_tuple) mol_atom_features = big_graph.feature_array('atom') mol_bond_features = big_graph.feature_array('bond') num_atom_features = mol_atom_features.shape[1] num_bond_features = mol_bond_features.shape[1] return num_atom_features, num_bond_features def get_mol_attributes(smiles_tuple): big_graph, mask, rd_idx = graph_from_smiles_tuple(smiles_tuple) num_atoms = len(big_graph.nodes['atom']) num_bonds = len(big_graph.nodes['bond']) mol_atom_features = big_graph.feature_array('atom') mol_bond_features = big_graph.feature_array('bond') mol_atom_neighbors = big_graph.neighbor_list('atom','atom') mol_bond_neighbors = big_graph.neighbor_list('atom','bond') mol_atom_neighbors = pad_neighbors(mol_atom_neighbors) mol_bond_neighbors = pad_neighbors_bond(mol_bond_neighbors,num_atoms,num_bonds) mol_degrees = big_graph.get_degree('atom') mol_degrees = pad_degree(mol_degrees, max_mol_degrees) return mol_atom_features, mol_bond_features, mol_atom_neighbors, mol_bond_neighbors, mol_degrees, mask, rd_idx def graph_from_pocket_tuple(pocket_Env_list): graph_list = [graph_from_Env(s) for s in pocket_Env_list] big_graph = PocketGraph() for i in range (len(graph_list)): subgraph = graph_list[i] graph, mask = subgraph big_graph.add_subgraph(graph) if i ==0: big_graph_mask = mask else: big_graph_mask = numpy.concatenate((big_graph_mask, mask), axis=0) return big_graph, big_graph_mask def graph_from_smiles_tuple(smiles_tuple): graph_list = [graph_from_smiles(s) for s in smiles_tuple] big_graph = MolGraph() count=0 for i in range (len(graph_list)): subgraph = graph_list[i] graph, mask = subgraph if graph is None: continue else: big_graph.add_subgraph(graph) if count ==0: big_graph_mask = mask else: big_graph_mask = numpy.concatenate((big_graph_mask, mask), axis=0) count=count+1 big_graph_rd_idx = big_graph.rdkit_ix_array() return big_graph, big_graph_mask, big_graph_rd_idx def graph_from_smiles(smiles): graph = MolGraph() mol = MolFromSmiles(smiles) if not mol: print ("Could not parse SMILES string:") print (smiles) return None, None else: atoms_by_rd_idx = {} # for atom in mol.GetAtoms(): # new_atom_node = graph.new_node('atom', features=atom_features(atom), rdkit_ix=atom.GetIdx()) # atoms_by_rd_idx[atom.GetIdx()] = new_atom_node # dummpy_atom_shape=atom_features(atom).shape for atom in mol.GetAtoms(): features = atom_features(atom) if features[0]==False: return None, None new_atom_node = graph.new_node('atom', features=features[1], rdkit_ix=atom.GetIdx()) atoms_by_rd_idx[atom.GetIdx()] = new_atom_node dummpy_atom_shape=features[1].shape for bond in mol.GetBonds(): atom1_node = atoms_by_rd_idx[bond.GetBeginAtom().GetIdx()] atom2_node = atoms_by_rd_idx[bond.GetEndAtom().GetIdx()] new_bond_node = graph.new_node('bond', features=bond_features(bond)) new_bond_node.add_neighbors((atom1_node, atom2_node)) atom1_node.add_neighbors((atom2_node,)) num_of_atoms = len(mol.GetAtoms()) mask = numpy.zeros((max_nodes_in_mol,)) # print "num_of_atoms" # print num_of_atoms for i in range(num_of_atoms): mask[i]=1 if num_of_atoms0: for i in range (0,len(graph.nodes())): for j in range (i+1,len(graph.nodes())): env_1=graph.nodes()[i] env_2=graph.nodes()[j] pos_1 = numpy.array(graph.node[env_1]['pos']) pos_2 = numpy.array(graph.node[env_2]['pos']) if np.sqrt(np.sum((pos_1-pos_2)**2))<7: graph.add_edge(env_1,env_2) return [site_ID, graph] def draw_pockets_with_highlights(pocket_ff, sali_score, max_ff, min_ff, mean_ff, pend): [pocket_name,subgraph] = nx_graph_from_Env(pocket_ff, sali_score,max_ff, min_ff, mean_ff) pos=nx.get_node_attributes(subgraph,'pos') score=nx.get_node_attributes(subgraph,'score') labels=nx.get_node_attributes(subgraph,'res') float_score=[] for i in range(len(score)): s = score[i] float_score.append(s) float_score=numpy.array(float_score) Dxy_pos={} Dyz_pos={} Dzx_pos={} for i in range(len(pos)): p = pos[i] Dxy_pos[i]=p[0:2] Dyz_pos[i]=p[1:3] Dzx_pos[i]=[p[2],p[0]] nx.draw_networkx_edges(subgraph,Dxy_pos) nx.draw_networkx_nodes(subgraph,pos=Dxy_pos,node_color=float_score,cmap=plt.cm.jet) nx.draw_networkx_labels(subgraph,Dxy_pos,labels,font_size=12) vmin = 0#min(float_score) vmax = 1#max(float_score) sm = plt.cm.ScalarMappable(cmap=plt.cm.jet, norm=plt.Normalize(vmin=vmin, vmax=vmax)) sm._A = [] plt.colorbar(sm) save("../SALI_figs_sda_both_paper/"+"poc_xy_"+pocket_name[0:4]+'_'+pend+".png") nx.draw_networkx_edges(subgraph,Dyz_pos) nx.draw_networkx_nodes(subgraph,pos=Dyz_pos,node_color=float_score,cmap=plt.cm.jet) nx.draw_networkx_labels(subgraph,Dyz_pos,labels,font_size=12) vmin = 0#min(float_score) vmax = 1#max(float_score) sm = plt.cm.ScalarMappable(cmap=plt.cm.jet, norm=plt.Normalize(vmin=vmin, vmax=vmax)) sm._A = [] plt.colorbar(sm) save("../SALI_figs_sda_both_paper/"+"poc_yz_"+pocket_name[0:4]+'_'+pend+".png") nx.draw_networkx_edges(subgraph,Dzx_pos) nx.draw_networkx_nodes(subgraph,pos=Dzx_pos,node_color=float_score,cmap=plt.cm.jet) nx.draw_networkx_labels(subgraph,Dzx_pos,labels,font_size=12) vmin = 0#min(float_score) vmax = 1#max(float_score) sm = plt.cm.ScalarMappable(cmap=plt.cm.jet, norm=plt.Normalize(vmin=vmin, vmax=vmax)) sm._A = [] plt.colorbar(sm) save("../SALI_figs_sda_both_paper/"+"poc_zx_"+pocket_name[0:4]+'_'+pend+".png") def moltosvg(mol,highlight, molSize=(500,500),kekulize=True): mc = Chem.Mol(mol.ToBinary()) if kekulize: try: Chem.Kekulize(mc) except: mc = Chem.Mol(mol.ToBinary()) if not mc.GetNumConformers(): rdDepictor.Compute2DCoords(mc) drawer = rdMolDraw2D.MolDraw2DSVG(molSize[0],molSize[1]) drawer.DrawMolecule(mc,highlightAtoms=highlight) # drawer.DrawMolecule(mc) drawer.FinishDrawing() svg = drawer.GetDrawingText() # It seems that the svg renderer used doesn't quite hit the spec. # Here are some fixes to make it work in the notebook, although I think # the underlying issue needs to be resolved at the generation step return svg#.replace('svg:','') def draw_molecule_with_highlights(smiles, sali_score, thres, rd_idx,pend): mol = Chem.MolFromSmiles(smiles) num_atom = len(mol.GetAtoms()) print ("smiles") print (smiles) drawoptions = DrawingOptions() drawoptions.selectColor = highlight_color drawoptions.elemDict = {} # Don't color nodes based on their element. drawoptions.bgColor=None print ("numpy.sum(sali_score,axis=1)") print (numpy.sum(sali_score,axis=1)) num_atom = max(numpy.where(sali_score)[0])+1 print ("num_atom") print (num_atom) sali_score = sali_score[0:num_atom,:] sum_s = numpy.sum(sali_score,axis=1) print ("sum_s") print (sum_s) arg_sort_sum = numpy.argsort(sum_s) print ("arg_sort_sum") print (arg_sort_sum) sorted_sum = sum_s[arg_sort_sum] print ("sorted_sum") print (sorted_sum) num_positive_atoms = len(numpy.where(sorted_sum>0)[0]) print ("num_positive_atoms") print (num_positive_atoms) positive_sum = sorted_sum[-num_positive_atoms:] highlight_list_our_ixs = arg_sort_sum[-num_positive_atoms:] print ("positive_sum") print (positive_sum) if len(positive_sum)>10: positive_sum=positive_sum[-10:] highlight_list_our_ixs=highlight_list_our_ixs[-10:] print ("highlight_list_our_ixs") print (highlight_list_our_ixs) print ("rd_idx") print (rd_idx) max_s = numpy.max(sum_s) min_s = numpy.min(sum_s) diff_s = max_s-min_s norm_s = (sum_s-min_s)/diff_s highlight = [int(rd_idx[our_ix]) for our_ix in highlight_list_our_ixs] print ("highlight") print (highlight) print ("highlight_list_our_ixs") print (highlight_list_our_ixs) print ("rd_idx") print (rd_idx) svg2png(moltosvg(mol,highlight),write_to="../SALI_FP_MOL_paper/"+pend+'.png',dpi=600) svg2ps(moltosvg(mol,highlight),write_to="../SALI_FP_MOL_paper/"+pend+'.ps',dpi=600) svg2pdf(moltosvg(mol,highlight),write_to="../SALI_FP_MOL_paper/"+pend+'.pdf',dpi=600) def plot(file_name,Weights_ID): all_poc, max_ff, min_ff, mean_ff = get_all_pocs() print ("I'm here!!!") pickle_keys = numpy.load(file_name).keys() W_hid_=numpy.load(file_name)[pickle_keys[6]] # pro_name_set = load_test_pro(fold) # with open('target_FF_dict','r') as infile: # target_FF_dict = json.load(infile) batch_size = 5 # pro_name_set = ['SRC']#,]#'ESR1','BACE1']#, 'BACE1', # M_index = [3] pro_name_set = ['HDAC2'] M_index = [0] # pro_name_set = ['ESR1'] # M_index = [0] if os.path.isfile('../DATA_GRAPH_CNN/all_targets.json'): with open('../DATA_GRAPH_CNN/all_targets.json','r') as infile: all_targets = json.load(infile) else: all_targets = gene_target_smiles_for_ROC_dict() with open('../DATA_GRAPH_CNN/all_targets.json','w') as myfile: json.dump(all_targets,myfile) for pro_n in all_targets.keys(): print (pro_n) if pro_n in pro_name_set: # print ("all_targets") # print (all_targets[pro_n]) test_pockets=[] test_mols=[] labels_test=[] test_target_mol = all_targets[pro_n] print (len(test_target_mol)) for entry in test_target_mol: pocket_name=entry[0] sm_name=entry[1] label_name=entry[2] p_name=entry[3] test_pockets.append(pocket_name) test_mols.append(sm_name) labels_test.append(label_name) # ###### # #689: 8, b0 # #466: 24, b0 # test_pockets=test_pockets[24:] # test_mols=test_mols[24:] # labels_test=labels_test[24:] # ###### # for i in range(50): # smiles=test_mols[i] # mol = Chem.MolFromSmiles(smiles) # num_atom = len(mol.GetAtoms()) # print ("i, smiles") # print (i, smiles) # drawoptions = DrawingOptions() # drawoptions.selectColor = highlight_color # drawoptions.elemDict = {} # Don't color nodes based on their element. # drawoptions.bgColor=None # svg2png(moltosvg(mol,None),write_to="../SALI_FP_MOL_paper/"+pro_n+"_mol_"+str(i)+'.png') # this_poc = target_FF_dict[pro_n] input_dir='../SALI_GRAPH_CNN/'+pro_n+'/' pockets = test_pockets smiles = test_mols labels = labels_test #(n_batches-2),(n_batches-1),1,2, for minibatch_index in M_index: print ("minibatch_index") print (minibatch_index) if os.path.isfile('../SALI_GRAPH_CNN/'+pro_n+"/hid_out_batch_"+str(minibatch_index)+'_'+Weights_ID+'_'+pro_n+'.npy'): hid_out = numpy.load('../SALI_GRAPH_CNN/'+pro_n+"/hid_out_batch_"+str(minibatch_index)+'_'+Weights_ID+'_'+pro_n+'.npy') grad_hid = numpy.load('../SALI_GRAPH_CNN/'+pro_n+"/grad_hid_out_batch_"+str(minibatch_index)+'_'+Weights_ID+'_'+pro_n+'.npy') poc_fp = numpy.load('../SALI_GRAPH_CNN/'+pro_n+'/sum_poc_fp_layer1_batch_'+str(minibatch_index)+"_"+Weights_ID+'_'+pro_n+'.npy') mol_fp = numpy.load('../SALI_GRAPH_CNN/'+pro_n+'/sum_mol_fp_layer1_batch_'+str(minibatch_index)+"_"+Weights_ID+'_'+pro_n+'.npy') sali_hid = hid_out[0,:]*grad_hid[0,:] # this 0 is for the 0th example hid_ind = numpy.where(sali_hid>0)[0] poc_fp = poc_fp[0,:] # this 0 is for the 0th example mol_fp = mol_fp[0,:] # hid_ind=[20,73] print ("there are files!") for h_idx in hid_ind: # hid that has positive gradient hid_score = sali_hid[h_idx] print (pro_n,"h_idx",h_idx) print (pro_n,"hid_score",hid_score) poc_w = W_hid_[0:512,h_idx] mol_w = W_hid_[512:,h_idx] x_f=poc_fp*poc_w y_f=mol_fp*mol_w poc_final_idx=numpy.where(x_f>0)[0] # this 0 is just to index it out of np.where array mol_final_idx=numpy.where(y_f>0)[0] # this 0 is just to index it out of np.where array # ######################### # ###### Plot POCKET ###### # ######################### files = [ f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir,f))] files = [ f for f in files if ('poc_layer1') in f ] files = [ f for f in files if ('hid_'+str(h_idx)+'_') in f ] files = [ f for f in files if ('batch_'+str(minibatch_index)+'_') in f ] files = [ f for f in files if '.npy' in f ] files = [ f for f in files if pro_n in f ] files = [ f for f in files if Weights_ID in f] all_poc_sali_score = numpy.zeros((250,480)) for fn in files: poc_sali_score = numpy.load('../SALI_GRAPH_CNN/'+pro_n+'/'+fn) # poc_sali_score = numpy.load('../SALI_GRAPH_CNN/'+fn) all_poc_sali_score = all_poc_sali_score + poc_sali_score this_poc = pockets[minibatch_index * batch_size: (minibatch_index + 1) * batch_size][0] draw_pockets_with_highlights(this_poc,all_poc_sali_score,max_ff, min_ff, mean_ff,pend=pro_n+'_poc_layer1_hid_'+str(h_idx)+'_batch_'+str(minibatch_index)) for fp in poc_final_idx: files = [ f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir,f))] files = [ f for f in files if ('poc_layer1') in f ] files = [ f for f in files if ('hid_'+str(h_idx)+'_') in f ] files = [ f for f in files if ('_f_'+str(fp)+'_') in f ] files = [ f for f in files if ('batch_'+str(minibatch_index)+'_') in f ] files = [ f for f in files if '.npy' in f ] files = [ f for f in files if pro_n in f ] files = [ f for f in files if Weights_ID in f] all_poc_sali_score = numpy.zeros((250,480)) for fn in files: poc_sali_score = numpy.load('../SALI_GRAPH_CNN/'+pro_n+'/'+fn) # poc_sali_score = numpy.load('../SALI_GRAPH_CNN/'+fn) all_poc_sali_score = all_poc_sali_score + poc_sali_score this_poc = pockets[minibatch_index * batch_size: (minibatch_index + 1) * batch_size][0] draw_pockets_with_highlights(this_poc,all_poc_sali_score,max_ff, min_ff, mean_ff, pend=pro_n+'_poc_layer1_hid_'+str(h_idx)+'_fp_'+str(fp)+'_batch_'+str(minibatch_index)) ######################### ####### Plot MOL ######## ######################### files = [ f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir,f))] files = [ f for f in files if ('mol_atom_layer1') in f ] files = [ f for f in files if ('hid_'+str(h_idx)) in f ] files = [ f for f in files if ('batch_'+str(minibatch_index)+'_') in f ] files = [ f for f in files if '.npy' in f ] files = [ f for f in files if 'dat' not in f ] files = [ f for f in files if Weights_ID in f] # files = [ f for f in files if pro_n in f ] batch_mol = smiles[minibatch_index * batch_size: (minibatch_index + 1) * batch_size] mol_atom_features, mol_bond_features, mol_atom_neighbors, mol_bond_neighbors, mol_degrees, mol_mask, rd_idx = get_mol_attributes(batch_mol) all_mol_sali_score = numpy.zeros((300,62)) for fn in files: mol_sali_score = numpy.load('../SALI_GRAPH_CNN/'+pro_n+'/'+fn) # mol_sali_score = numpy.load('../sali_graph_0410/'+pro_n+'/'+fn) all_mol_sali_score = all_mol_sali_score + mol_sali_score this_mol = batch_mol[0] draw_molecule_with_highlights(this_mol, all_mol_sali_score[0:60]*mol_atom_features[0:60], thres=0.0, rd_idx=rd_idx, pend=pro_n+'_mol_layer1_hid_'+str(h_idx)+'_batch_'+str(minibatch_index)) for fp in mol_final_idx: files = [ f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir,f))] files = [ f for f in files if ('mol_atom_layer1') in f ] files = [ f for f in files if ('hid_'+str(h_idx)) in f ] files = [ f for f in files if ('_f_'+str(fp)+'_') in f ] files = [ f for f in files if ('batch_'+str(minibatch_index)+'_') in f ] files = [ f for f in files if '.npy' in f ] files = [ f for f in files if 'dat' not in f ] files = [ f for f in files if Weights_ID in f] # files = [ f for f in files if pro_n in f ] batch_mol = smiles[minibatch_index * batch_size: (minibatch_index + 1) * batch_size] mol_atom_features, mol_bond_features, mol_atom_neighbors, mol_bond_neighbors, mol_degrees, mol_mask, rd_idx = get_mol_attributes(batch_mol) all_mol_sali_score = numpy.zeros((300,62)) for fn in files: mol_sali_score = numpy.load('../SALI_GRAPH_CNN/'+pro_n+'/'+fn) all_mol_sali_score = all_mol_sali_score + mol_sali_score this_mol = batch_mol[0] print (this_mol) print ("mol_atom_features[0:60]") print (mol_atom_features[0:60]) print ("all_mol_sali_score[0:60]") print (all_mol_sali_score[0:60]) draw_molecule_with_highlights(this_mol, all_mol_sali_score[0:60]*mol_atom_features[0:60], thres=0.0, rd_idx=rd_idx, pend=pro_n+'_mol_layer1_hid_'+str(h_idx)+'_fp_'+str(fp)+'_batch_'+str(minibatch_index)) if __name__ == '__main__': Weights_ID = 'sda_fix0_assay_poc_5_nodec_e_1_190000' Weights_file_name='../weights_Graph/weight_SDA_both_fix_0_assay_poc_5_all_folds_nodec_e_1_190000.zip' plot(Weights_file_name,Weights_ID)