# Motivation

In the previous post, I looked at how a molecular graph is constructed and how messages are passed around in an MPNN architecture. In this post, I'll take a look at how the graph neural net can be trained. The details of the message passing update vary by implementation; here we use the one from this paper.

Again, many of the code examples come from the chemprop repository; I edited them for the sake of simplicity.

# Recap

Let's briefly recap how the molecular graph is constructed and how messages are passed around. As shown in the above figure, each atom and each bond is labeled $x$ and $e$, respectively. Here we are discussing D-MPNN, which represents the graph with directed edges: for each bond between two atoms $v$ and $w$, there are two directed bonds, $e_{vw}$ and $e_{wv}$.

The initial hidden message is constructed as $h_{vw}^0 = \tau (W_i \mathrm{cat}(x_v, e_{vw}))$, where $W_i$ is a learned matrix and $\tau$ is an activation function. The above figure shows how two messages, $m_{57}$ and $m_{54}$, are constructed. First, the message from atom $v$ to $w$ is the sum of the hidden states of all the incoming bonds to $v$ (excluding the one originating from $w$). Then the message is multiplied by the learned matrix $W_m$, and the initial hidden message is added to form the new hidden message for depth 1. This is repeated several times so that messages are passed to greater depths.
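Written out, the update described in this paragraph looks roughly as follows. The symbols $m_{vw}^{t+1}$ and $h_{vw}^{t+1}$ are my own shorthand for "message at depth $t+1$" and "hidden state at depth $t+1$"; they are chosen to stay consistent with the other equations in this section and are not copied from the figure.

$$m_{vw}^{t+1} = \sum_{k \in N(v) \setminus w} h_{kv}^t$$

$$h_{vw}^{t+1} = \tau(h_{vw}^0 + W_m m_{vw}^{t+1})$$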

After the messages have been passed for the given number of steps, the hidden states of all incoming bonds are summed into a final message for each atom, and the hidden state of each atom is computed as follows:

$$m_v = \sum_{k \in N(v)} h_{kv}^t$$

$$h_v = \tau(W_a \mathrm{cat} (x_v, m_v))$$

Finally, the readout phase uses the sum of all $h_v$ to obtain the feature vector of the molecule and property prediction is carried out using a fully-connected feed forward network.
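To make the recap concrete, here is a minimal plain-Python sketch of one message passing step on a toy graph. Everything in it is an assumption for illustration: the three-atom chain, the two-dimensional hidden states, the identity matrix used for $W_m$, and the helper names `relu`, `matvec` and `message_passing_step`. It uses the same trick as the encoder code later in the post: sum all incoming hidden states at the source atom, then subtract the reverse bond.

```python
from typing import List

Vector = List[float]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def matvec(W: List[Vector], x: Vector) -> Vector:
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def message_passing_step(h, h0, a2b, b2a, b2revb, W_m):
    """One update: h_vw <- relu(h0_vw + W_m * (sum of h_kv over k in N(v), k != w))."""
    size = len(h[0])
    # Sum of the hidden states of all bonds coming into each atom
    atom_sum = []
    for incoming in a2b:
        total = [0.0] * size
        for b in incoming:
            total = [t + x for t, x in zip(total, h[b])]
        atom_sum.append(total)

    new_h = []
    for b in range(len(h)):
        # Incoming sum at the source atom v, minus the reverse bond w -> v
        m = [s - r for s, r in zip(atom_sum[b2a[b]], h[b2revb[b]])]
        new_h.append(relu([a + c for a, c in zip(h0[b], matvec(W_m, m))]))
    return new_h


# Toy chain 0 - 1 - 2 with directed bonds:
# b0 = 0->1, b1 = 1->0, b2 = 1->2, b3 = 2->1
a2b = [[1], [0, 3], [2]]   # atom -> incoming bonds
b2a = [0, 1, 1, 2]         # bond -> atom it comes from
b2revb = [1, 0, 3, 2]      # bond -> reverse bond

h0 = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]]  # made-up initial hidden states
W_m = [[1.0, 0.0], [0.0, 1.0]]                         # identity, for readability

h1 = message_passing_step(h0, h0, a2b, b2a, b2revb, W_m)
print(h1)       # [[1.0, 0.0], [0.0, 4.0], [3.0, 0.0], [0.0, 3.0]]

# m_v: sum of the hidden states of all bonds coming into each atom
m_atoms = [[sum(h1[b][i] for b in incoming) for i in range(2)] for incoming in a2b]
print(m_atoms)  # [[0.0, 4.0], [1.0, 3.0], [3.0, 0.0]]
```

Note that the two terminal bonds (`b0` and `b3`) keep their initial state: the only bond coming into their source atom is their own reverse, which is excluded.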

# Train Data

As an example, I'll use Enamine Real's diversity discovery set composed of 10240 compounds. This dataset contains some molecular properties, such as ClogP and TPSA, so we should be able to train a GCNN that predicts those properties.

For this example, let's train using ClogP values.

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib

from io import BytesIO
import pandas as pd
import numpy as np
from IPython.display import SVG

# RDKit
import rdkit
from rdkit.Chem import PandasTools
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import rdRGroupDecomposition
from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem import Draw
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults

DrawingOptions.bondLineWidth=1.8
IPythonConsole.ipython_useSVG=True
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.warning')
print(rdkit.__version__)

# pytorch
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from torch import nn

# misc
from typing import Dict, Iterator, List, Optional, Union, OrderedDict, Tuple
from tqdm.notebook import tqdm
from functools import reduce

2020.03.2


# we will define a class which holds various parameter for D-MPNN
class TrainArgs:
smiles_column = None
no_cuda = False
gpu = None
num_workers = 8
batch_size = 50
no_cache_mol = False
dataset_type = 'regression'
task_names = []
seed = 0
hidden_size = 300
bias = False
depth = 3
dropout = 0.0
undirected = False
aggregation = 'mean'
aggregation_norm = 100
ffn_num_layers = 2
ffn_hidden_size = 300
init_lr = 1e-4
max_lr = 1e-3
final_lr = 1e-4
num_lrs = 1
warmup_epochs = 2.0
epochs = 30

@property
def device(self) -> torch.device:
"""The :code:torch.device on which to load and process data and models."""
if not self.cuda:
return torch.device('cpu')

return torch.device('cuda', self.gpu)

@device.setter
def device(self, device: torch.device) -> None:
self.cuda = device.type == 'cuda'
self.gpu = device.index

@property
def cuda(self) -> bool:
"""Whether to use CUDA (i.e., GPUs) or not."""
return not self.no_cuda and torch.cuda.is_available()

@cuda.setter
def cuda(self, cuda: bool) -> None:
self.no_cuda = not cuda

args = TrainArgs()
args.data_path = 'files/enamine_discovery_diversity_set_10240.csv'
args.target_column = 'ClogP'
args.smiles_column = 'SMILES'
args.dataset_type = 'regression'
args.task_names = [args.target_column]
args.num_tasks = 1

df = pd.read_csv(args.data_path)
df.head()

| | Name | SMILES | Catalog ID | PlateID | Well | MW (desalted) | ClogP | HBD | TPSA | RotBonds |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | NaN | CN(C(=O)NC1CCOc2ccccc21)C(c1ccccc1)c1ccccn1 | Z447596076 | 1186474-R-001 | A02 | 373.448 | 2.419 | 1 | 54.46 | 4 |
| 1 | NaN | Cn1cc(C(=O)N2CCC(OC3CCOC3)CC2)c(C2CC2)n1 | Z2180753156 | 1186474-R-001 | A03 | 319.399 | -0.570 | 0 | 56.59 | 4 |
| 2 | NaN | CC(=O)N(C)C1CCN(C(=O)c2ccccc2-c2ccccc2C(=O)O)CC1 | Z2295858832 | 1186474-R-001 | A04 | 380.437 | 0.559 | 1 | 77.92 | 4 |
| 3 | NaN | COCC1(CNc2cnccc2C#N)CCNCC1 | Z2030994006 | 1186474-R-001 | A05 | 260.335 | 0.902 | 2 | 69.97 | 5 |
| 4 | NaN | CCCCOc1ccc(-c2nnc3n2CCCC3)cc1OC | Z273627850 | 1186474-R-001 | A06 | 301.383 | 3.227 | 0 | 49.17 | 6 |

from random import Random

# Cache of graph featurizations
CACHE_GRAPH = True
SMILES_TO_GRAPH = {}

def cache_graph():
return CACHE_GRAPH

def set_cache_graph(cache_graph):
global CACHE_GRAPH
CACHE_GRAPH = cache_graph

# Cache of RDKit molecules
CACHE_MOL = True
SMILES_TO_MOL: Dict[str, Chem.Mol] = {}

def cache_mol() -> bool:
r"""Returns whether RDKit molecules will be cached."""
return CACHE_MOL

def set_cache_mol(cache_mol: bool) -> None:
r"""Sets whether RDKit molecules will be cached."""
global CACHE_MOL
CACHE_MOL = cache_mol

# Atom feature sizes
MAX_ATOMIC_NUM = 100
ATOM_FEATURES = {
'atomic_num': list(range(MAX_ATOMIC_NUM)),
'degree': [0, 1, 2, 3, 4, 5],
'formal_charge': [-1, -2, 1, 2, 0],
'chiral_tag': [0, 1, 2, 3],
'num_Hs': [0, 1, 2, 3, 4],
'hybridization': [
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2
],
}

# Distance feature sizes
PATH_DISTANCE_BINS = list(range(10))
THREE_D_DISTANCE_MAX = 20
THREE_D_DISTANCE_STEP = 1
THREE_D_DISTANCE_BINS = list(range(0, THREE_D_DISTANCE_MAX + 1, THREE_D_DISTANCE_STEP))

# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2
EXTRA_ATOM_FDIM = 0
BOND_FDIM = 14

def get_atom_fdim():
"""Gets the dimensionality of the atom feature vector."""
return ATOM_FDIM + EXTRA_ATOM_FDIM

def get_bond_fdim(atom_messages=False):
"""Gets the dimensionality of the bond feature vector.
"""
return BOND_FDIM + (not atom_messages) * get_atom_fdim()

def onek_encoding_unk(value: int, choices: List[int]):
encoding = [0] * (len(choices) + 1)
index = choices.index(value) if value in choices else -1
encoding[index] = 1

return encoding

def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None):
"""Builds a feature vector for an atom.
"""
features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
[1 if atom.GetIsAromatic() else 0] + \
[atom.GetMass() * 0.01]  # scaled to about the same range as other features
if functional_groups is not None:
features += functional_groups
return features

def bond_features(bond: Chem.rdchem.Bond):
"""Builds a feature vector for a bond.
"""
if bond is None:
fbond = [1] + [0] * (BOND_FDIM - 1)
else:
bt = bond.GetBondType()
fbond = [
0,  # bond is not None
bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
(bond.GetIsConjugated() if bt is not None else 0),
(bond.IsInRing() if bt is not None else 0)
]
fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
return fbond
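To see what the extra "unknown" slot does, here is a small standalone sketch of the one-hot encoding used by `atom_features` and `bond_features`. It re-implements `onek_encoding_unk` as I read it from the code above, so that it runs without RDKit; the example values are made up.

```python
from typing import List


def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """One-hot encode `value`; anything not in `choices` goes to the last slot."""
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding


degree_choices = [0, 1, 2, 3, 4, 5]

# A degree of 2 is a known choice, so position 2 is set
print(onek_encoding_unk(2, degree_choices))  # [0, 0, 1, 0, 0, 0, 0]

# A degree of 7 is not in the list, so the final "unknown" slot is set
print(onek_encoding_unk(7, degree_choices))  # [0, 0, 0, 0, 0, 0, 1]

# The position follows the order of `choices`, not the numeric value
print(onek_encoding_unk(0, [-1, -2, 1, 2, 0]))  # [0, 0, 0, 0, 1, 0]
```

This is why each feature contributes `len(choices) + 1` entries to `ATOM_FDIM`.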

class MoleculeDatapoint:
def __init__(self,
smiles: str,
targets: List[Optional[float]] = None,
row: OrderedDict = None):

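self.smiles = smiles
self.targets = targets
self.features = []
self.row = row
# Keep the raw values so that reset_features_and_targets() can restore them
self.raw_features, self.raw_targets = self.features, self.targets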

@property
def mol(self) -> Chem.Mol:
"""Gets the corresponding list of RDKit molecules for the corresponding SMILES."""
mol = SMILES_TO_MOL.get(self.smiles, Chem.MolFromSmiles(self.smiles))
if cache_mol():
SMILES_TO_MOL[self.smiles] = mol
return mol

def set_features(self, features: np.ndarray) -> None:
"""Sets the features of the molecule.
"""
self.features = features

def extend_features(self, features: np.ndarray) -> None:
"""Extends the features of the molecule.
"""
self.features = np.append(self.features, features) if self.features is not None else features

def num_tasks(self) -> int:
"""Returns the number of prediction tasks.
"""
return len(self.targets)

def set_targets(self, targets: List[Optional[float]]):
"""Sets the targets of a molecule.
"""
self.targets = targets

def reset_features_and_targets(self) -> None:
"""Resets the features and targets to their raw values."""
self.features, self.targets = self.raw_features, self.raw_targets

class MoleculeDataset(Dataset):
def __init__(self, data: List[MoleculeDatapoint]):
self._data = data
self._scaler = None
self._batch_graph = None
self._random = Random()

def smiles(self) -> List[str]:
return [d.smiles for d in self._data]

def mols(self) -> List[Chem.Mol]:
return [d.mol for d in self._data]

def targets(self) -> List[List[Optional[float]]]:
return [d.targets for d in self._data]

def num_tasks(self) -> int:
return self._data[0].num_tasks() if len(self._data) > 0 else None

def set_targets(self, targets: List[List[Optional[float]]]) -> None:
assert len(self._data) == len(targets)
for i in range(len(self._data)):
self._data[i].set_targets(targets[i])

def reset_features_and_targets(self) -> None:
for d in self._data:
d.reset_features_and_targets()

def __len__(self) -> int:
return len(self._data)

def __getitem__(self, item) -> Union[MoleculeDatapoint, List[MoleculeDatapoint]]:
return self._data[item]

def batch_graph(self):
if self._batch_graph is None:
self._batch_graph = []

mol_graphs = []
for d in self._data:
mol_graphs_list = []
if d.smiles in SMILES_TO_GRAPH:
mol_graph = SMILES_TO_GRAPH[d.smiles]
else:
mol_graph = MolGraph(d.mol)
if cache_graph():
SMILES_TO_GRAPH[d.smiles] = mol_graph
mol_graphs.append([mol_graph])

self._batch_graph = [BatchMolGraph([g[i] for g in mol_graphs]) for i in range(len(mol_graphs[0]))]

return self._batch_graph

def features(self) -> List[np.ndarray]:
"""
Returns the features associated with each molecule (if they exist).

:return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features.
"""
if len(self._data) == 0 or self._data[0].features is None:
return None

return [d.features for d in self._data]

def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
"""Selects the message features from source corresponding to the atom or bond indices in index.
"""
index_size = index.size()             # (num_atoms/num_bonds, max_num_bonds)
suffix_dim = source.size()[1:]        # (hidden_size,)
final_size = index_size + suffix_dim  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

target = source.index_select(dim=0, index=index.view(-1)) # (num_atoms/num_bonds * max_num_bonds, hidden_size)
target = target.view(final_size)                          # (num_atoms/num_bonds, max_num_bonds, hidden_size)
return target

class MolGraph:
def __init__(self, mol, atom_descriptors=None):
# Convert SMILES to RDKit molecule if necessary
if type(mol) == str:
mol = Chem.MolFromSmiles(mol)

self.n_atoms = 0  # number of atoms
self.n_bonds = 0  # number of bonds
self.f_atoms = []  # mapping from atom index to atom features
self.f_bonds = []  # mapping from bond index to concat(in_atom, bond) features
self.a2b = []  # mapping from atom index to incoming bond indices
self.b2a = []  # mapping from bond index to the index of the atom the bond is coming from
self.b2revb = []  # mapping from bond index to the index of the reverse bond

# Get atom features
self.f_atoms = [atom_features(atom) for atom in mol.GetAtoms()]
if atom_descriptors is not None:
self.f_atoms = [f_atoms + descs.tolist() for f_atoms, descs in zip(self.f_atoms, atom_descriptors)]

self.n_atoms = len(self.f_atoms)

# Initialize atom to bond mapping for each atom
for _ in range(self.n_atoms):
self.a2b.append([])

# Get bond features
for a1 in range(self.n_atoms):
for a2 in range(a1 + 1, self.n_atoms):
bond = mol.GetBondBetweenAtoms(a1, a2)

if bond is None:
continue

f_bond = bond_features(bond)
self.f_bonds.append(self.f_atoms[a1] + f_bond)
self.f_bonds.append(self.f_atoms[a2] + f_bond)

# Update index mappings
b1 = self.n_bonds
b2 = b1 + 1
self.a2b[a2].append(b1)  # b1 = a1 --> a2
self.b2a.append(a1)
self.a2b[a1].append(b2)  # b2 = a2 --> a1
self.b2a.append(a2)
self.b2revb.append(b2)
self.b2revb.append(b1)
self.n_bonds += 2

class BatchMolGraph:
"""A BatchMolGraph represents the graph structure and featurization of a batch of molecules.
"""

def __init__(self, mol_graphs: List[MolGraph]):
self.atom_fdim = get_atom_fdim()
self.bond_fdim = get_bond_fdim()

# Start n_atoms and n_bonds at 1 b/c zero padding
self.n_atoms = 1  # number of atoms (start at 1 b/c need index 0 as padding)
self.n_bonds = 1  # number of bonds (start at 1 b/c need index 0 as padding)
self.a_scope = []  # list of tuples indicating (start_atom_index, num_atoms) for each molecule
self.b_scope = []  # list of tuples indicating (start_bond_index, num_bonds) for each molecule

# All start with zero padding so that indexing with zero padding returns zeros
f_atoms = [[0] * self.atom_fdim]  # atom features
f_bonds = [[0] * self.bond_fdim]  # combined atom/bond features
a2b = [[]]  # mapping from atom index to incoming bond indices
b2a = [0]  # mapping from bond index to the index of the atom the bond is coming from
b2revb = [0]  # mapping from bond index to the index of the reverse bond
for mol_graph in mol_graphs:
f_atoms.extend(mol_graph.f_atoms)
f_bonds.extend(mol_graph.f_bonds)

for a in range(mol_graph.n_atoms):
a2b.append([b + self.n_bonds for b in mol_graph.a2b[a]])

for b in range(mol_graph.n_bonds):
b2a.append(self.n_atoms + mol_graph.b2a[b])
b2revb.append(self.n_bonds + mol_graph.b2revb[b])

self.a_scope.append((self.n_atoms, mol_graph.n_atoms))
self.b_scope.append((self.n_bonds, mol_graph.n_bonds))
self.n_atoms += mol_graph.n_atoms
self.n_bonds += mol_graph.n_bonds

self.max_num_bonds = max(1, max(len(in_bonds) for in_bonds in a2b))  # max with 1 to fix a crash in rare case of all single-heavy-atom mols

self.f_atoms = torch.FloatTensor(f_atoms)
self.f_bonds = torch.FloatTensor(f_bonds)
self.a2b = torch.LongTensor([a2b[a] + [0] * (self.max_num_bonds - len(a2b[a])) for a in range(self.n_atoms)])
self.b2a = torch.LongTensor(b2a)
self.b2revb = torch.LongTensor(b2revb)
self.b2b = None  # try to avoid computing b2b b/c O(n_atoms^3)
self.a2a = None  # only needed if using atom messages

def get_components(self, atom_messages: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor,
torch.LongTensor, torch.LongTensor, torch.LongTensor,
List[Tuple[int, int]], List[Tuple[int, int]]]:
return self.f_atoms, self.f_bonds, self.a2b, self.b2a, self.b2revb, self.a_scope, self.b_scope

def get_b2b(self) -> torch.LongTensor:
"""Computes (if necessary) and returns a mapping from each bond index to all the incoming bond indices.
"""
if self.b2b is None:
b2b = self.a2b[self.b2a]  # num_bonds x max_num_bonds
# b2b includes reverse edge for each bond so need to mask out
revmask = (b2b != self.b2revb.unsqueeze(1).repeat(1, b2b.size(1))).long()  # num_bonds x max_num_bonds
self.b2b = b2b * revmask

return self.b2b

def get_a2a(self) -> torch.LongTensor:
"""Computes (if necessary) and returns a mapping from each atom index to all neighboring atom indices.
"""
if self.a2a is None:
# b = a1 --> a2
# a2b maps a2 to all incoming bonds b
# b2a maps each bond b to the atom it comes from a1
# thus b2a[a2b] maps atom a2 to neighboring atoms a1
self.a2a = self.b2a[self.a2b]  # num_atoms x max_num_bonds

return self.a2a
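The index bookkeeping in `BatchMolGraph` is easier to follow on a tiny example. The sketch below is a plain-Python imitation of the batching step, not the chemprop implementation: the function name `batch_indices`, the dictionary layout and the two toy molecules are all assumptions of mine. It shows how index 0 is reserved for padding and how each molecule's local indices are shifted by the running atom and bond counts.

```python
from typing import Dict, List, Tuple


def batch_indices(graphs: List[Dict]) -> Tuple[list, list, list, list]:
    """Merge per-molecule index maps into batch-wide ones (index 0 = padding)."""
    n_atoms, n_bonds = 1, 1            # start at 1 because index 0 is padding
    a2b, b2a, b2revb = [[]], [0], [0]  # padding entries
    a_scope = []                       # (start_atom_index, num_atoms) per molecule

    for g in graphs:
        mol_atoms, mol_bonds = len(g["a2b"]), len(g["b2a"])
        for a in range(mol_atoms):
            a2b.append([b + n_bonds for b in g["a2b"][a]])
        for b in range(mol_bonds):
            b2a.append(n_atoms + g["b2a"][b])
            b2revb.append(n_bonds + g["b2revb"][b])
        a_scope.append((n_atoms, mol_atoms))
        n_atoms += mol_atoms
        n_bonds += mol_bonds

    # Pad every row of a2b with 0 so it can become a rectangular tensor
    width = max(1, max(len(row) for row in a2b))
    a2b = [row + [0] * (width - len(row)) for row in a2b]
    return a2b, b2a, b2revb, a_scope


# Toy molecule 1: chain 0 - 1 - 2 (bonds 0->1, 1->0, 1->2, 2->1)
chain = {"a2b": [[1], [0, 3], [2]], "b2a": [0, 1, 1, 2], "b2revb": [1, 0, 3, 2]}
# Toy molecule 2: two bonded atoms (bonds 0->1, 1->0)
pair = {"a2b": [[1], [0]], "b2a": [0, 1], "b2revb": [1, 0]}

a2b, b2a, b2revb, a_scope = batch_indices([chain, pair])
print(a2b)      # [[0, 0], [2, 0], [1, 4], [3, 0], [6, 0], [5, 0]]
print(b2a)      # [0, 1, 2, 2, 3, 4, 5]
print(b2revb)   # [0, 2, 1, 4, 3, 6, 5]
print(a_scope)  # [(1, 3), (4, 2)]
```

Because row 0 of the feature and message tensors is all zeros, the padded `0` entries in `a2b` add nothing when the incoming messages are summed.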

# prepare data set
data = MoleculeDataset([
MoleculeDatapoint(
smiles=row[args.smiles_column],
targets=[row[args.target_column]]
) for i, row in df.iterrows()
])

# split data into train, validation and test set
random = Random()
sizes = [0.8, 0.1, 0.1]

indices = list(range(len(data)))
random.shuffle(indices)

train_size = int(sizes[0] * len(data))
train_val_size = int((sizes[0] + sizes[1]) * len(data))

train = [data[i] for i in indices[:train_size]]
val = [data[i] for i in indices[train_size:train_val_size]]
test = [data[i] for i in indices[train_val_size:]]

train_data = MoleculeDataset(train)
val_data = MoleculeDataset(val)
test_data = MoleculeDataset(test)
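The split above uses an unseeded `Random()`, so the three sets change on every run. If you want a reproducible split, one option is to seed the shuffle. The helper below is a sketch under that assumption; `split_indices` and its `seed` argument are my own names and are not part of chemprop.

```python
from random import Random
from typing import List, Sequence, Tuple


def split_indices(n: int,
                  sizes: Sequence[float] = (0.8, 0.1, 0.1),
                  seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle range(n) with a fixed seed and cut it into train/val/test index lists."""
    indices = list(range(n))
    Random(seed).shuffle(indices)

    train_end = int(sizes[0] * n)
    val_end = int((sizes[0] + sizes[1]) * n)
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]


train_idx, val_idx, test_idx = split_indices(10240)
print(len(train_idx), len(val_idx), len(test_idx))  # 8192 1024 1024
```

The index lists can then be used in place of `indices[:train_size]` and friends when building the three `MoleculeDataset` objects.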


# MPNN Model

Let's create an MPNN model. The model is composed of an encoder and a feed-forward network (FFN). The encoder is the same as the one we discussed before, and the FFN is a straightforward fully connected network.

# Atom feature sizes
MAX_ATOMIC_NUM = 100
ATOM_FEATURES = {
'atomic_num': list(range(MAX_ATOMIC_NUM)),
'degree': [0, 1, 2, 3, 4, 5],
'formal_charge': [-1, -2, 1, 2, 0],
'chiral_tag': [0, 1, 2, 3],
'num_Hs': [0, 1, 2, 3, 4],
'hybridization': [
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2
],
}

# Distance feature sizes
PATH_DISTANCE_BINS = list(range(10))
THREE_D_DISTANCE_MAX = 20
THREE_D_DISTANCE_STEP = 1
THREE_D_DISTANCE_BINS = list(range(0, THREE_D_DISTANCE_MAX + 1, THREE_D_DISTANCE_STEP))

# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2
EXTRA_ATOM_FDIM = 0
BOND_FDIM = 14

def get_atom_fdim() -> int:
"""Gets the dimensionality of the atom feature vector."""
return ATOM_FDIM + EXTRA_ATOM_FDIM

def get_bond_fdim() -> int:
"""Gets the dimensionality of the bond feature vector.
"""
return BOND_FDIM + get_atom_fdim()

def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
encoding = [0] * (len(choices) + 1)
index = choices.index(value) if value in choices else -1
encoding[index] = 1

return encoding

def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None) -> List[Union[bool, int, float]]:
"""Builds a feature vector for an atom.
"""
features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
[1 if atom.GetIsAromatic() else 0] + \
[atom.GetMass() * 0.01]  # scaled to about the same range as other features
if functional_groups is not None:
features += functional_groups
return features

def initialize_weights(model: nn.Module) -> None:
"""Initializes the weights of a model in place.
"""
for param in model.parameters():
if param.dim() == 1:
nn.init.constant_(param, 0)
else:
nn.init.xavier_normal_(param)

class MPNEncoder(nn.Module):
def __init__(self, args, atom_fdim, bond_fdim):
super(MPNEncoder, self).__init__()
self.atom_fdim = atom_fdim
self.bond_fdim = bond_fdim
self.hidden_size = args.hidden_size
self.bias = args.bias
self.depth = args.depth
self.dropout = args.dropout
self.layers_per_message = 1
self.undirected = False
self.atom_messages = False
self.device = args.device
self.aggregation = args.aggregation
self.aggregation_norm = args.aggregation_norm

self.dropout_layer = nn.Dropout(p=self.dropout)
self.act_func = nn.ReLU()
self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False)

# Input
input_dim = self.bond_fdim
self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)
w_h_input_size = self.hidden_size

# Shared weight matrix across depths (default)
self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)
self.W_o = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size)

def forward(self, mol_graph):
"""Encodes a batch of molecular graphs.
"""
f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components()
f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.to(self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to(self.device), b2revb.to(self.device)

input = self.W_i(f_bonds)  # num_bonds x hidden_size
message = self.act_func(input)  # num_bonds x hidden_size

# Message passing
for depth in range(self.depth - 1):
# m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
# message      a_message = sum(nei_a_message)      rev_message
nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
rev_message = message[b2revb]  # num_bonds x hidden
message = a_message[b2a] - rev_message  # num_bonds x hidden

message = self.W_h(message)
message = self.act_func(input + message)  # num_bonds x hidden_size
message = self.dropout_layer(message)  # num_bonds x hidden

a2x = a2b
nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

# Readout
mol_vecs = []
for i, (a_start, a_size) in enumerate(a_scope):
if a_size == 0:
mol_vecs.append(self.cached_zero_vector)
else:
cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
mol_vec = cur_hiddens  # (num_atoms, hidden_size)
if self.aggregation == 'mean':
mol_vec = mol_vec.sum(dim=0) / a_size
elif self.aggregation == 'sum':
mol_vec = mol_vec.sum(dim=0)
elif self.aggregation == 'norm':
mol_vec = mol_vec.sum(dim=0) / self.aggregation_norm
mol_vecs.append(mol_vec)

mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

return mol_vecs  # num_molecules x hidden

class MPN(nn.Module):
def __init__(self, args, atom_fdim=None, bond_fdim=None):
super(MPN, self).__init__()
self.atom_fdim = atom_fdim or get_atom_fdim()
self.bond_fdim = bond_fdim or get_bond_fdim()
self.device = args.device
self.encoder = MPNEncoder(args, self.atom_fdim, self.bond_fdim)

def forward(self, batch):
"""Encodes a batch of molecules.
"""
if type(batch[0]) != BatchMolGraph:
batch = [mol2graph(b) for b in batch]

encodings = [self.encoder(batch[0])]
output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)
return output

class MoleculeModel(nn.Module):
def __init__(self, args, featurizer=False):
super(MoleculeModel, self).__init__()

self.classification = args.dataset_type == 'classification'
self.featurizer = featurizer

self.output_size = args.num_tasks

if self.classification:
self.sigmoid = nn.Sigmoid()

self.create_encoder(args)
self.create_ffn(args)

initialize_weights(self)

def create_encoder(self, args):
self.encoder = MPN(args)

def create_ffn(self, args):
first_linear_dim = args.hidden_size
dropout = nn.Dropout(args.dropout)
activation = nn.ReLU()

# Create FFN layers
if args.ffn_num_layers == 1:
ffn = [
dropout,
nn.Linear(first_linear_dim, self.output_size)
]
else:
ffn = [
dropout,
nn.Linear(first_linear_dim, args.ffn_hidden_size)
]
for _ in range(args.ffn_num_layers - 2):
ffn.extend([
activation,
dropout,
nn.Linear(args.ffn_hidden_size, args.ffn_hidden_size),
])
ffn.extend([
activation,
dropout,
nn.Linear(args.ffn_hidden_size, self.output_size),
])

# Create FFN model
self.ffn = nn.Sequential(*ffn)

def featurize(self, batch):
"""Computes feature vectors of the input by running the model except for the last layer.
"""
return self.ffn[:-1](self.encoder(batch))

def forward(self, batch):
output = self.ffn(self.encoder(batch))

# Don't apply sigmoid during training b/c using BCEWithLogitsLoss
if self.classification and not self.training:
output = self.sigmoid(output)

return output

model = MoleculeModel(args)
model = model.to(args.device)


As shown below, the model consists of an encoder and an FFN. The encoder has three learned matrices and the FFN has two fully connected layers.

model

MoleculeModel(
(encoder): MPN(
(encoder): MPNEncoder(
(dropout_layer): Dropout(p=0.0, inplace=False)
(act_func): ReLU()
(W_i): Linear(in_features=147, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=433, out_features=300, bias=True)
)
)
(ffn): Sequential(
(0): Dropout(p=0.0, inplace=False)
(1): Linear(in_features=300, out_features=300, bias=True)
(2): ReLU()
(3): Dropout(p=0.0, inplace=False)
(4): Linear(in_features=300, out_features=1, bias=True)
)
)
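The layer sizes in this printout can be checked by hand from the feature definitions. The following sketch redoes that arithmetic in plain Python; the dictionary `n_choices` simply lists the lengths of the lists in `ATOM_FEATURES`, and `linear_params` is a hypothetical helper of mine that counts the weights of a linear layer.

```python
# Number of choices per one-hot atom feature (lengths of the lists in ATOM_FEATURES)
n_choices = {"atomic_num": 100, "degree": 6, "formal_charge": 5,
             "chiral_tag": 4, "num_Hs": 5, "hybridization": 5}

# +1 per feature for the "unknown" slot, +2 for aromaticity and scaled mass
atom_fdim = sum(n + 1 for n in n_choices.values()) + 2
bond_fdim = 14 + atom_fdim  # BOND_FDIM plus the features of the source atom
hidden = 300                # args.hidden_size and args.ffn_hidden_size


def linear_params(n_in: int, n_out: int, bias: bool) -> int:
    """Number of weights (and biases, if any) in a linear layer."""
    return n_in * n_out + (n_out if bias else 0)


layers = {
    "W_i": linear_params(bond_fdim, hidden, bias=False),
    "W_h": linear_params(hidden, hidden, bias=False),
    "W_o": linear_params(atom_fdim + hidden, hidden, bias=True),
    "ffn.1": linear_params(hidden, hidden, bias=True),
    "ffn.4": linear_params(hidden, 1, bias=True),
}
total = sum(layers.values())

print(atom_fdim, bond_fdim, atom_fdim + hidden)  # 133 147 433
print(total)                                     # 354901
```

The values 147 and 433 match the `in_features` of `W_i` and `W_o` in the printout, and the linear layers together hold roughly 355k weights.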

# Train the MPNN

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Adam, Optimizer

class NoamLR(_LRScheduler):
"""
Noam learning rate scheduler with piecewise linear increase and exponential decay.

The learning rate increases linearly from init_lr to max_lr over the course of
the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch).
Then the learning rate decreases exponentially from max_lr to final_lr over the
course of the remaining total_steps - warmup_steps (where total_steps =
total_epochs * steps_per_epoch). This is roughly based on the learning rate
schedule from "Attention is All You Need" (https://arxiv.org/abs/1706.03762), section 5.3.
"""
def __init__(self,
optimizer: Optimizer,
warmup_epochs: List[Union[float, int]],
total_epochs: List[int],
steps_per_epoch: int,
init_lr: List[float],
max_lr: List[float],
final_lr: List[float]):

assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \
len(max_lr) == len(final_lr)

self.num_lrs = len(optimizer.param_groups)

self.optimizer = optimizer
self.warmup_epochs = np.array(warmup_epochs)
self.total_epochs = np.array(total_epochs)
self.steps_per_epoch = steps_per_epoch
self.init_lr = np.array(init_lr)
self.max_lr = np.array(max_lr)
self.final_lr = np.array(final_lr)

self.current_step = 0
self.lr = init_lr
self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int)
self.total_steps = self.total_epochs * self.steps_per_epoch
self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps

self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps))

super(NoamLR, self).__init__(optimizer)

def get_lr(self) -> List[float]:
return list(self.lr)

def step(self, current_step: int = None):
if current_step is not None:
self.current_step = current_step
else:
self.current_step += 1

for i in range(self.num_lrs):
if self.current_step <= self.warmup_steps[i]:
self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i]
elif self.current_step <= self.total_steps[i]:
self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i]))
else:  # theoretically this case should never be reached since training should stop at total_steps
self.lr[i] = self.final_lr[i]

self.optimizer.param_groups[i]['lr'] = self.lr[i]
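To get a feel for the schedule, here is a stateless plain-Python version of the same formula for a single learning rate. The function name `noam_lr` is mine. The default values come from `TrainArgs`, and `steps_per_epoch=163` assumes 8192 training molecules with a batch size of 50 (`8192 // 50`).

```python
def noam_lr(step: int,
            init_lr: float = 1e-4,
            max_lr: float = 1e-3,
            final_lr: float = 1e-4,
            warmup_epochs: float = 2.0,
            total_epochs: int = 30,
            steps_per_epoch: int = 163) -> float:
    """Learning rate at a given optimizer step: linear warmup, then exponential decay."""
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    total_steps = total_epochs * steps_per_epoch

    if step <= warmup_steps:
        # Linear increase from init_lr to max_lr
        return init_lr + step * (max_lr - init_lr) / warmup_steps
    if step <= total_steps:
        # Exponential decay from max_lr to final_lr
        gamma = (final_lr / max_lr) ** (1 / (total_steps - warmup_steps))
        return max_lr * gamma ** (step - warmup_steps)
    return final_lr


# Learning rate at the start, at the end of warmup, halfway, and at the last step
for step in (0, 326, 2445, 4890):
    print(step, noam_lr(step))
```

With these settings the rate climbs from `1e-4` to `1e-3` during the first two epochs (326 steps) and then decays back towards `1e-4` by step 4890.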


import threading

def construct_molecule_batch(data):
data = MoleculeDataset(data)
data.batch_graph()  # Forces computation and caching of the BatchMolGraph for the molecules
return data

class MoleculeSampler(Sampler):
def __init__(self, dataset, shuffle=False, seed=0):
super(Sampler, self).__init__()

self.dataset = dataset
self.shuffle = shuffle
self._random = Random(seed)
self.positive_indices = self.negative_indices = None
self.length = len(self.dataset)

def __iter__(self):
indices = list(range(len(self.dataset)))
if self.shuffle:
self._random.shuffle(indices)
return iter(indices)

def __len__(self):
return self.length

class MoleculeDataLoader(DataLoader):
def __init__(self,
dataset: MoleculeDataset,
batch_size: int = 50,
num_workers: int = 8,
shuffle: bool = False,
seed: int = 0):

self._dataset = dataset
self._batch_size = batch_size
self._num_workers = num_workers
self._shuffle = shuffle
self._seed = seed
self._context = None
self._class_balance = False
self._timeout = 0
is_main_thread = threading.current_thread() is threading.main_thread()

if not is_main_thread and self._num_workers > 0:
self._context = 'forkserver'  # In order to prevent a hanging
self._timeout = 3600  # Just for sure that the DataLoader won't hang

self._sampler = MoleculeSampler(
dataset=self._dataset,
shuffle=self._shuffle,
seed=self._seed
)

super(MoleculeDataLoader, self).__init__(
dataset=self._dataset,
batch_size=self._batch_size,
sampler=self._sampler,
num_workers=self._num_workers,
collate_fn=construct_molecule_batch,
multiprocessing_context=self._context,
timeout=self._timeout
)

@property
def targets(self) -> List[List[Optional[float]]]:
if self._class_balance or self._shuffle:
raise ValueError('Cannot safely extract targets when class balance or shuffle are enabled.')

return [self._dataset[index].targets for index in self._sampler]

@property
def iter_size(self) -> int:
return len(self._sampler)

def __iter__(self) -> Iterator[MoleculeDataset]:
return super(MoleculeDataLoader, self).__iter__()

# Create data loaders
train_data_loader = MoleculeDataLoader(
dataset=train_data,
batch_size=args.batch_size,
num_workers=8,
shuffle=True,
seed=args.seed
)
val_data_loader = MoleculeDataLoader(
dataset=val_data,
batch_size=args.batch_size,
num_workers=8
)
test_data_loader = MoleculeDataLoader(
dataset=test_data,
batch_size=args.batch_size,
num_workers=8
)

# optimizer
params = [{'params': model.parameters(), 'lr': args.init_lr, 'weight_decay': 0}]
optimizer = Adam(params)

# scheduler
scheduler = NoamLR(
optimizer=optimizer,
warmup_epochs=[args.warmup_epochs],
total_epochs=[args.epochs] * args.num_lrs,
steps_per_epoch=len(train_data) // args.batch_size,
init_lr=[args.init_lr],
max_lr=[args.max_lr],
final_lr=[args.final_lr]
)

# loss function
loss_func = nn.MSELoss(reduction='none')

# train loop
model.train()
loss_sum = iter_count = 0
n_iter = 0

for batch in tqdm(train_data_loader, total=len(train_data_loader), leave=False):
mol_batch, target_batch = batch.batch_graph(), batch.targets()
mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

# Run model
model.zero_grad()
preds = model(mol_batch)

# Move tensors to correct device
mask = mask.to(preds.device)
targets = targets.to(preds.device)
class_weights = torch.ones(targets.shape, device=preds.device)

loss = loss_func(preds, targets) * class_weights * mask
loss = loss.sum() / mask.sum()

loss_sum += loss.item()
iter_count += 1

loss.backward()

optimizer.step()

if isinstance(scheduler, NoamLR):
scheduler.step()

n_iter += len(batch)
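The `mask` in the loop above is there so that molecules with a missing target (`None`) do not contribute to the loss. As a sanity check of that idea, here is a plain-Python sketch of the same masked mean squared error. The function name `masked_mse` and the numbers are made up for illustration, and it assumes at least one target is present.

```python
from typing import List, Optional


def masked_mse(preds: List[List[float]],
               targets: List[List[Optional[float]]]) -> float:
    """Mean squared error over the entries whose target is not None."""
    total, count = 0.0, 0
    for pred_row, target_row in zip(preds, targets):
        for p, t in zip(pred_row, target_row):
            if t is None:  # same effect as multiplying the loss by mask = 0
                continue
            total += (p - t) ** 2
            count += 1
    return total / count   # same as loss.sum() / mask.sum()


preds = [[1.0, 2.0], [3.0, 4.0]]
targets = [[1.5, None], [2.0, 4.0]]

# (0.25 + 1.0 + 0.0) / 3 present targets
print(masked_mse(preds, targets))
```

In this post every molecule has a ClogP value, so the mask is all ones and the loss reduces to the ordinary MSE.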


Training for the first epoch is finished. Let's take a look at how well the model predicts ClogP.

model.eval()

initial_preds = []

for batch in tqdm(val_data_loader, disable=False, leave=False):
    # Prepare batch
    batch: MoleculeDataset
    mol_batch = batch.batch_graph()

    # Make predictions
    with torch.no_grad():
        batch_preds = model(mol_batch)

    batch_preds = batch_preds.data.cpu().numpy()

    # Collect vectors
    batch_preds = batch_preds.tolist()
    initial_preds.extend(batch_preds)

from sklearn.metrics import mean_squared_error

# valid_preds and valid_targets have shape (num_tasks, data_size)
targets = val_data_loader.targets
metric_func = mean_squared_error

valid_preds = [[] for _ in range(args.num_tasks)]
valid_targets = [[] for _ in range(args.num_tasks)]
for i in range(args.num_tasks):
    # Use the validation predictions collected above (initial_preds),
    # not `preds`, which still holds the last training batch.
    for j in range(len(initial_preds)):
        if targets[j][i] is not None:  # Skip those without targets
            valid_preds[i].append(initial_preds[j][i])
            valid_targets[i].append(targets[j][i])

result = metric_func(valid_targets[i], valid_preds[i])
print('MSE:', result)

MSE: 3.0658553721115887

fig = plt.figure(figsize=(4,4))
plt.scatter(valid_targets[i], valid_preds[i])
plt.xlabel('Target ClogP')
plt.ylabel('Predicted ClogP')
plt.show()

There is almost no correlation after the first epoch. Let's train for a few more epochs and see if the prediction improves.
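Rather than judging correlation from the scatter plot alone, it can be quantified. Below is a minimal sketch of the Pearson correlation coefficient in plain Python. The helper name `pearson_r` and the toy numbers are made up for illustration and are not values from this dataset.

```python
import math


def pearson_r(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("Need two sequences of equal length >= 2.")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        raise ValueError("Correlation is undefined for a constant sequence.")
    return cov / math.sqrt(var_x * var_y)


# Toy data (not from the post): predictions that track the targets closely.
toy_targets = [1.0, 2.0, 3.0, 4.0]
toy_preds = [1.1, 1.9, 3.2, 3.8]
print(pearson_r(toy_targets, toy_preds))
```

With the lists built above, the call would be `pearson_r(valid_targets[0], valid_preds[0])`. A value near 0 would confirm the visual impression of no correlation.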

# Re-create the optimizer first so that the scheduler wraps
# the same optimizer that is stepped in the loop below.
optimizer = Adam(params)

scheduler = NoamLR(
    optimizer=optimizer,
    warmup_epochs=[args.warmup_epochs],
    total_epochs=[args.epochs] * args.num_lrs,
    steps_per_epoch=len(train_data) // args.batch_size,
    init_lr=[args.init_lr],
    max_lr=[args.max_lr],
    final_lr=[args.final_lr]
)

loss_func = nn.MSELoss(reduction='none')

metric_func = mean_squared_error

for epoch in tqdm(range(args.epochs)):

    # train
    model.train()

    loss_sum = iter_count = 0
    n_iter = 0

    for batch in tqdm(train_data_loader, total=len(train_data_loader), leave=False):
        mol_batch, target_batch = batch.batch_graph(), batch.targets()
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)

        loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()

        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

    # eval

    model.eval()

    preds = []

    for batch in tqdm(val_data_loader, disable=False, leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch = batch.batch_graph()

        # Make predictions
        with torch.no_grad():
            batch_preds = model(mol_batch)

        batch_preds = batch_preds.data.cpu().numpy()

        # Collect vectors
        batch_preds = batch_preds.tolist()
        preds.extend(batch_preds)

    # valid_preds and valid_targets have shape (num_tasks, data_size)
    num_tasks = 1
    targets = val_data_loader.targets

    valid_preds = [[] for _ in range(num_tasks)]
    valid_targets = [[] for _ in range(num_tasks)]
    for i in range(num_tasks):
        for j in range(len(preds)):
            if targets[j][i] is not None:  # Skip those without targets
                valid_preds[i].append(preds[j][i])
                valid_targets[i].append(targets[j][i])

    result = metric_func(valid_targets[i], valid_preds[i])
    print(epoch, result)

0 0.5059666700564436
1 0.6747574089163861
2 0.29365952162247105
3 0.2443422001168768
4 0.22893787807471938
5 0.24693692581926335
6 0.18068380381421822
7 0.16628313459548472
8 0.15646424430563233
9 0.14714968670153764
10 0.1484021422019028
11 0.1497949169195903
12 0.1332405691696682
13 0.12293908860736892
14 0.12480342381848945
15 0.12231600820223316
16 0.12469534214189118
17 0.12742151396520932
18 0.10947115821963195
19 0.11363330373825292
20 0.10276127977714987
21 0.10135583119599778
22 0.10164295643187546
23 0.09673642613524305
24 0.09645447176601984
25 0.0980465410170695
26 0.09245372955020192
27 0.08906743111716783
28 0.0879800185507161
29 0.09138532813108272
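The log shows that validation MSE does not decrease monotonically (epochs 1 and 5 are worse than the epochs before them), so it can be useful to keep track of the best epoch instead of blindly taking the last one. Here is a minimal sketch. `best_epoch` is a hypothetical helper, and the list holds the first six values from the log above, rounded to four decimals.

```python
def best_epoch(scores):
    """Return (epoch_index, score) for the lowest validation score."""
    if not scores:
        raise ValueError("scores must not be empty.")
    epoch = min(range(len(scores)), key=scores.__getitem__)
    return epoch, scores[epoch]


# First six validation MSE values from the log above, rounded.
val_mse = [0.5060, 0.6748, 0.2937, 0.2443, 0.2289, 0.2469]
epoch, score = best_epoch(val_mse)
print(epoch, score)
```

In the actual training loop, one option would be to store a copy of `model.state_dict()` whenever a new best score is seen, so that the best model can be restored afterwards.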


Over the course of 30 training epochs, the MSE improved quickly and fell below 0.1 by the end! We reached this level of prediction without ever writing code specifically for computing ClogP. Let's take a look at the target vs. predicted values.

fig = plt.figure(figsize=(4,4))
plt.scatter(valid_targets[i], valid_preds[i])
plt.xlabel('Target ClogP')
plt.ylabel('Predicted ClogP')
plt.show()

# Conclusion

In this post, I trained a graph neural network to predict ClogP. Within 30 epochs, the validation MSE dropped below 0.1. Given that only very simple atom and bond features were used, the model was able to "learn" to predict the property fairly quickly.
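Since MSE is in squared units, taking the square root gives an error in the same units as ClogP, which is easier to interpret. The short calculation below uses the two MSE values reported earlier in this post (after the first epoch and at the last logged epoch). The variable names are my own.

```python
import math

# Validation MSE values reported earlier in this post.
first_epoch_mse = 3.0658553721115887
final_epoch_mse = 0.09138532813108272

# RMSE is the square root of MSE and has the same units as ClogP.
rmse_first = math.sqrt(first_epoch_mse)
rmse_final = math.sqrt(final_epoch_mse)

print(f"RMSE after first epoch: {rmse_first:.2f}")
print(f"RMSE after last epoch:  {rmse_final:.2f}")
```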

Now that we have a trained model, a few things I'd like to try:

• compare the performance of this model with more traditional models
• try different parameters, such as depth
• try alternative featurization, e.g., add whether a bond is rotatable to bond_features, and so on
• add long-range connections; the current network is limited to chemical bonds, but longer-range interactions may also be important