"""Graph representation of molecules built from SMILES strings.

A ``MolGraph`` stores ``Node`` objects grouped by node type.  The builders in
this module create nodes of type ``'atom'``, ``'bond'`` and ``'molecule'`` and
link them through symmetric neighbor lists.
"""

import numpy
import numpy as np
from rdkit.Chem import MolFromSmiles

from features import atom_features, bond_features

# Degrees that ``MolGraph.sort_nodes_by_degree`` can bucket nodes into.
mol_degrees = [0, 1, 2, 3, 4, 5]


class MolGraph(object):
    """Collection of typed nodes, stored as lists keyed by node type."""

    def __init__(self):
        self.nodes = {}  # dict of lists of nodes, keyed by node type

    def new_node(self, ntype, features=None, rdkit_ix=None):
        """Create a ``Node``, register it under ``ntype`` and return it."""
        new_node = Node(ntype, features, rdkit_ix)
        self.nodes.setdefault(ntype, []).append(new_node)
        return new_node

    def add_subgraph(self, subgraph):
        """Append every node list of ``subgraph`` to this graph's lists.

        The ``Node`` objects themselves are shared with ``subgraph``, not
        copied.
        """
        for ntype, nodes in subgraph.nodes.items():
            self.nodes.setdefault(ntype, []).extend(nodes)

    def sort_nodes_by_degree(self, ntype):
        """Reorder ``self.nodes[ntype]`` by ascending same-type degree.

        The degree of a node is the number of its neighbors that also have
        type ``ntype``.  In addition to reordering, the nodes of each degree
        are stored under the key ``(ntype, degree)`` for every degree in
        ``mol_degrees``.

        Raises:
            KeyError: if ``ntype`` is not present in the graph, or if a node
                has a degree that is not listed in ``mol_degrees``.
        """
        nodes_by_degree = {degree: [] for degree in mol_degrees}
        for node in self.nodes[ntype]:
            nodes_by_degree[len(node.get_neighbors(ntype))].append(node)

        new_nodes = []
        for degree in mol_degrees:
            cur_nodes = nodes_by_degree[degree]
            self.nodes[(ntype, degree)] = cur_nodes
            new_nodes.extend(cur_nodes)

        self.nodes[ntype] = new_nodes

    def feature_array(self, ntype):
        """Return the features of all ``ntype`` nodes as one numpy array."""
        assert ntype in self.nodes
        return np.array([node.features for node in self.nodes[ntype]])

    def rdkit_ix_array(self):
        """Return the ``rdkit_ix`` of every ``'atom'`` node as a numpy array."""
        return np.array([node.rdkit_ix for node in self.nodes['atom']])

    def neighbor_list(self, self_ntype, neighbor_ntype):
        """List, per ``self_ntype`` node, the indices of its neighbors.

        Indices refer to positions in ``self.nodes[neighbor_ntype]``; the
        outer list follows the order of ``self.nodes[self_ntype]``.
        """
        assert self_ntype in self.nodes and neighbor_ntype in self.nodes
        neighbor_idxs = {n: i for i, n in enumerate(self.nodes[neighbor_ntype])}
        return [[neighbor_idxs[neighbor]
                 for neighbor in self_node.get_neighbors(neighbor_ntype)]
                for self_node in self.nodes[self_ntype]]

    def bond_neighbor_feature(self):
        """Return a dense, symmetric atom-by-atom matrix of bond features.

        Entry ``[i, j]`` (and ``[j, i]``) holds the features of the bond
        whose first two neighbors are atoms ``i`` and ``j``, indexed by
        position in ``self.nodes['atom']``.  Atom pairs without a bond stay
        zero.  A graph without any ``'bond'`` nodes yields an all-zero
        matrix.
        """
        atoms = self.nodes['atom']
        atom_idxs = {n: i for i, n in enumerate(atoms)}
        # NOTE(review): the trailing 6 is presumably the length of the vector
        # returned by bond_features -- confirm against the features module.
        bond_feature_matrix = np.zeros((len(atoms), len(atoms), 6))
        # Graphs built from bond-free molecules never get a 'bond' entry, so
        # fall back to an empty list instead of raising KeyError.
        for b in self.nodes.get('bond', []):
            atom_pairs = b._neighbors
            atom1 = atom_idxs[atom_pairs[0]]
            atom2 = atom_idxs[atom_pairs[1]]
            bond_feature_matrix[atom1, atom2] = b.features
            bond_feature_matrix[atom2, atom1] = b.features
        return bond_feature_matrix

    def get_degree(self, ntype):
        """Return, per ``ntype`` node, its number of ``ntype`` neighbors."""
        return np.array([len(node.get_neighbors(ntype))
                         for node in self.nodes[ntype]])


class Node(object):
    """A graph node with a type, optional features and a neighbor list."""

    __slots__ = ['ntype', 'features', '_neighbors', 'rdkit_ix']

    def __init__(self, ntype, features, rdkit_ix):
        self.ntype = ntype
        self.features = features
        self._neighbors = []
        self.rdkit_ix = rdkit_ix

    def add_neighbors(self, neighbor_list):
        """Link this node with each node in ``neighbor_list``, both ways."""
        for neighbor in neighbor_list:
            self._neighbors.append(neighbor)
            neighbor._neighbors.append(self)

    def get_neighbors(self, ntype):
        """Return the neighbors of this node whose type equals ``ntype``."""
        return [n for n in self._neighbors if n.ntype == ntype]


def graph_from_smiles_tuple(smiles_tuple):
    """Build one ``MolGraph`` holding the graphs of all given SMILES strings.

    Raises:
        ValueError: if any of the SMILES strings cannot be parsed.
    """
    graph_list = [graph_from_smiles(s) for s in smiles_tuple]
    big_graph = MolGraph()
    for subgraph in graph_list:
        big_graph.add_subgraph(subgraph)

    # This sorting allows an efficient (but brittle!) indexing later on.
    big_graph.sort_nodes_by_degree('atom')
    return big_graph


def graph_from_smiles(smiles):
    """Build a ``MolGraph`` for the molecule described by ``smiles``.

    Every atom becomes an ``'atom'`` node and every bond a ``'bond'`` node
    linked to its two atoms; bonded atoms are also linked to each other
    directly.  A single ``'molecule'`` node is linked to all atoms.

    Raises:
        ValueError: if ``MolFromSmiles`` returns a falsy result for
            ``smiles``.
    """
    graph = MolGraph()
    mol = MolFromSmiles(smiles)
    if not mol:
        # Format a single message; passing ``smiles`` as a second positional
        # argument would make str(exc) render as a tuple.
        raise ValueError(
            "Could not parse SMILES string: {}".format(smiles))

    atoms_by_rd_idx = {}
    for atom in mol.GetAtoms():
        new_atom_node = graph.new_node(
            'atom', features=atom_features(atom), rdkit_ix=atom.GetIdx())
        atoms_by_rd_idx[atom.GetIdx()] = new_atom_node

    for bond in mol.GetBonds():
        atom1_node = atoms_by_rd_idx[bond.GetBeginAtom().GetIdx()]
        atom2_node = atoms_by_rd_idx[bond.GetEndAtom().GetIdx()]
        new_bond_node = graph.new_node('bond', features=bond_features(bond))
        new_bond_node.add_neighbors((atom1_node, atom2_node))
        atom1_node.add_neighbors((atom2_node,))

    mol_node = graph.new_node('molecule')
    mol_node.add_neighbors(graph.nodes['atom'])
    return graph