3. Graph Neural Networks

The biggest difficulty for deep learning with molecules is the choice and computation of “descriptors”. Graph neural networks (GNNs) are a category of deep neural networks whose inputs are graphs. As usual, they are composed of specific layers that take a graph as input, and those layers are what we’re interested in. You can find reviews of GNNs in Dwivedi et al.[DJL+20], Bronstein et al.[BBL+17], and Wu et al.[WPC+20]. GNNs can be used for everything from coarse-grained molecular dynamics [LWC+20] to predicting NMR chemical shifts [YCW20] to modeling dynamics of solids [XFLW+19]. Before we dive too deep into them, we must first understand how a graph is represented and how molecules are converted into graphs.

3.1. Representing a Graph

A graph \(\mathbf{G}\) is a set of nodes \(\mathbf{V}\) and edges \(\mathbf{E}\). In our setting, node \(i\) is defined by a feature vector \(\vec{v}_i\), so that the set of nodes can be written as a rank 2 tensor. The edges can be represented as an adjacency matrix \(\mathbf{E}\), where \(e_{ij} = 1\) means nodes \(i\) and \(j\) are connected by an edge. In many fields, graphs are immediately restricted to be directed and acyclic, which simplifies the math. Molecules are instead undirected and can have cycles (rings). Thus, our adjacency matrices are always symmetric: \(e_{ij} = e_{ji}\). Often our edges themselves have features, so that \(e_{ij}\) is itself a vector, and the adjacency matrix becomes a rank 3 tensor. Examples of edge features are covalent bond order or the distance between two nodes.

../_images/methanol.jpg

Fig. 3.2 Methanol with atoms numbered so that we can convert it to a graph.

Let’s see how a graph can be constructed from a molecule. Consider methanol, shown in Fig. 3.2. I’ve numbered the atoms so that we have an order for defining the nodes and edges. First, the node features. You can use anything for node features, but we’ll often begin with one-hot encoded feature vectors:

Node   C   H   O
1      0   1   0
2      0   1   0
3      0   1   0
4      1   0   0
5      0   0   1
6      0   1   0

\(\mathbf{V}\) will be the combined feature vectors of these nodes. The adjacency matrix \(\mathbf{E}\) will look like:

       1   2   3   4   5   6
1      0   0   0   1   0   0
2      0   0   0   1   0   0
3      0   0   0   1   0   0
4      1   1   1   0   1   0
5      0   0   0   1   0   1
6      0   0   0   0   1   0

Take a moment to study these two arrays. For example, notice that rows 1, 2, and 3 have only the 4th column as non-zero. That’s because atoms 1–3 are bonded only to the carbon (atom 4). Also, the diagonal is always 0 because atoms cannot bond to themselves.
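
To make this concrete, here is a minimal NumPy sketch (my own illustration, using the atom numbering above shifted to 0-indexing) that builds \(\mathbf{V}\) and \(\mathbf{E}\) by hand and checks the properties we just discussed.

import numpy as np

# one-hot node features in the order [C, H, O] for atoms 1-6 of methanol
V = np.array([[0, 1, 0],   # 1: H
              [0, 1, 0],   # 2: H
              [0, 1, 0],   # 3: H
              [1, 0, 0],   # 4: C
              [0, 0, 1],   # 5: O
              [0, 1, 0]])  # 6: H

# adjacency matrix: a 1 for every bond (0-indexed atom pairs)
E = np.zeros((6, 6))
for u, v in [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)]:
    E[u, v] = E[v, u] = 1

print(np.all(E == E.T), np.all(np.diag(E) == 0))  # True True: symmetric, no self-bonds
print(E.sum(axis=1))                              # degrees: [1. 1. 1. 4. 2. 1.]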

You can find a similar process for converting crystals into graphs in Xie et al. [XG18]. We’ll now write a function that can convert a SMILES string into this representation.

3.2. Running This Notebook

Click the launch badge above to open this page as an interactive Google Colab notebook. See details below on installing packages, either in your own environment or on Google Colab.

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import warnings
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
import networkx as nx
warnings.filterwarnings('ignore')
sns.set_context('notebook')
sns.set_style('dark',  {'xtick.bottom':True, 'ytick.left':True, 'xtick.color': '#666666', 'ytick.color': '#666666',
                        'axes.edgecolor': '#666666', 'axes.linewidth':     0.8 , 'figure.dpi': 300})
color_cycle = ['#1BBC9B', '#F06060', '#5C4B51', '#F3B562', '#6e5687']
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle) 
soldata = pd.read_csv('https://dataverse.harvard.edu/api/access/datafile/3407241?format=original&gbrecs=true')
np.random.seed(0)
my_elements = {6: 'C', 8: 'O', 1: 'H'}

The hidden cell below defines our function smiles2graph. It creates one-hot node feature vectors for the elements C, H, and O, and an adjacency tensor whose edge features are one-hot encoded bond orders.

def smiles2graph(sml):
    '''Takes a valid SMILES string and
    returns the graph as (node features, adjacency tensor)
    '''
    m = rdkit.Chem.MolFromSmiles(sml)
    m = rdkit.Chem.AddHs(m)
    order_string = {rdkit.Chem.rdchem.BondType.SINGLE: 1,
                    rdkit.Chem.rdchem.BondType.DOUBLE: 2,
                    rdkit.Chem.rdchem.BondType.TRIPLE: 3,
                    rdkit.Chem.rdchem.BondType.AROMATIC: 4}
    N = len(list(m.GetAtoms()))
    # one-hot element features in the order of my_elements (C, O, H)
    nodes = np.zeros((N, len(my_elements)))
    lookup = list(my_elements.keys())
    for i in m.GetAtoms():
        nodes[i.GetIdx(), lookup.index(i.GetAtomicNum())] = 1

    # adjacency tensor with one-hot bond order as the edge feature
    adj = np.zeros((N, N, 5))
    for j in m.GetBonds():
        u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        order = j.GetBondType()
        if order in order_string:
            order = order_string[order]
        else:
            warnings.warn('Ignoring bond order ' + str(order))
            continue
        adj[u, v, order] = 1
        adj[v, u, order] = 1
    return nodes, adj
nodes, adj = smiles2graph('CO')
nodes
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]])

3.3. A Graph Neural Network

A graph neural network (GNN) is a neural network with two defining attributes:

  1. Its input is a graph

  2. Its output is permutation invariant

The first point is clear. For the second, a graph permutation means re-ordering our nodes. In our methanol example above, we could have easily made the carbon atom 1 instead of atom 4. Our new adjacency matrix would then be:

       1   2   3   4   5   6
1      0   1   1   1   1   0
2      1   0   0   0   0   0
3      1   0   0   0   0   0
4      1   0   0   0   0   0
5      1   0   0   0   0   1
6      0   0   0   0   1   0

A GNN is permutation invariant if its output is insensitive to these kinds of exchanges. There may exist GNNs that are not permutation invariant, especially for trees where it is possible to deterministically order all nodes, but all the GNNs used in chemistry, and most deep learning work on graphs, are permutation invariant.

3.3.1. A simple GNN

We will often mention a GNN when we really mean a layer from a GNN. Most GNNs implement a specific layer that can deal with graphs, and so usually we are only concerned with this layer. Let’s see an example of a simple layer for a GNN:

(3.9)\[\begin{equation} f_k = \sigma\left( \sum_i \sum_jn_{ij}w_{jk} \right) \end{equation}\]

This equation says we multiply every node feature by trainable weights \(w_{jk}\), sum over all nodes and features, and then apply an activation. This yields a single feature vector for the graph. Is this equation permutation invariant? Yes, because the node index \(i\) is summed over, so re-ordering the nodes does not affect the output.

Let’s see an example that is similar, but not permutation invariant:

(3.10)\[\begin{equation} f_k = \sigma\left( \sum_i n_{ik}w_{ik} \right) \end{equation}\]

This is a small change: we now have one weight vector per node. That makes the trainable weights depend on the ordering of the nodes, so if we swap the node ordering, the weights no longer align. If we input the same methanol molecule twice but with two atom numbers switched, we would get different answers even though the molecule is unchanged. These simple examples differ from real GNNs in two important ways: (i) they give a single feature vector for the whole graph, which throws away per-node information, and (ii) they do not use the adjacency matrix. Let’s now see a real GNN that outputs per-node features, uses the adjacency matrix, and maintains permutation invariance.
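
To make the difference concrete, here is a small NumPy check (a sketch with arbitrary random weights, reusing the smiles2graph function and NumPy import from above) that evaluates equation (3.9) and the per-node-weights variant (3.10) on methanol and on a copy with the carbon and oxygen indices swapped.

nodes, adj = smiles2graph('CO')
perm = [1, 0, 2, 3, 4, 5]           # swap the carbon and oxygen indices
nodes_perm = nodes[perm]

np.random.seed(0)
w_shared = np.random.normal(size=(nodes.shape[1], 4))   # w_jk, shared by all nodes
w_per_node = np.random.normal(size=nodes.shape)         # w_ik, one weight vector per node
sigma = np.tanh

# equation (3.9): shared weights, summed over nodes -> permutation invariant
print(np.allclose(sigma((nodes @ w_shared).sum(axis=0)),
                  sigma((nodes_perm @ w_shared).sum(axis=0))))    # True

# equation (3.10): weights tied to the node index -> not permutation invariant
print(np.allclose(sigma((nodes * w_per_node).sum(axis=0)),
                  sigma((nodes_perm * w_per_node).sum(axis=0))))  # False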

3.4. Kipf & Welling GCN

One of the first popular GNNs was the Kipf & Welling graph convolutional network (GCN) [KW16]. Although some people consider GCNs to be a broad class of GNNs, we’ll use GCN to refer specifically to the Kipf & Welling GCN. Thomas Kipf has written an excellent article introducing the GCN; I will not repeat it here, so please take a look at it.

The input to a GCN layer is \(\mathbf{V}\), \(\mathbf{E}\), and it outputs an updated \(\mathbf{V}'\). Each node feature vector is updated by averaging the feature vectors of its neighbors, as determined by \(\mathbf{E}\). This choice of averaging over neighbors is what makes a GCN layer permutation invariant. Averaging over neighbors is not trainable on its own, so we also multiply the neighbor features by a trainable matrix before the averaging, which gives the GCN the ability to learn. In Einstein notation, this process is:

(3.11)\[\begin{equation} v_{il} = \sigma\left(\frac{1}{d_i}e_{ij}v_{jk}w_{lk}\right) \end{equation}\]

where \(i\) is the node we’re considering, \(j\) is the neighbor index, \(k\) is the input node feature, \(l\) is the output node feature, \(d_i\) is the degree of node \(i\) (which makes it an average instead of a sum), \(e_{ij}\) isolates neighbors so that all non-neighbor \(v_{jk}\)s are zero, \(\sigma\) is our activation, and \(w_{lk}\) are the trainable weights. This equation is a mouthful, but it really is just the average over neighbors with a trainable matrix thrown in. One common modification is to make all nodes neighbors of themselves, so that the output node features \(v_{il}\) depend on the input features \(v_{ik}\). We do not need to change our equation; we just give the adjacency matrix \(1\)s on the diagonal instead of \(0\)s by adding the identity matrix during pre-processing.

Building understanding of the GCN is important for understanding other GNNs. You can view the GCN layer as a way to “communicate” between a node and its neighbors. The output for node \(i\) will depend only on its immediate neighbors. For chemistry, this is not satisfactory, so you can stack multiple layers. If you have two layers, then the output for node \(i\) will include information about node \(i\)’s neighbors’ neighbors. Another important detail in GCNs is that the averaging procedure accomplishes two goals: (i) it gives permutation invariance by removing the effect of neighbor order and (ii) it prevents a change in magnitude of the node features. A sum would accomplish (i) but would cause the magnitude of the node features to grow after each layer. Of course, you could put a batch normalization layer after each GCN layer to keep output magnitudes stable, but averaging is simpler.

../_images/gnn_10_0.png

Fig. 3.3 Intermediate step of the graph convolution layer. The center node is being updated by averaging its neighbors’ features.

../_images/gcn.gif

Fig. 3.4 Animation of the graph convolution layer. The left is input, right is output node features. Note that two layers are shown (see title change).

To help understand the GCN layer, look at Fig. 3.3, which shows an intermediate step of the GCN layer. Each node feature is represented here as a one-hot encoded vector at input. The animation in Fig. 3.4 shows the averaging process over neighbor features. To make this animation easy to follow, the trainable weights and activation functions are not shown. Note that the animation repeats for a second layer. Watch how the “information” about there being an oxygen atom in the molecule only reaches every atom after two layers. Most GNNs operate in a similar way, so try to understand how this animation works.
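
A quick way to see this propagation numerically (a small sketch reusing smiles2graph and the self-loop trick described above): the \(k\)-th power of the adjacency matrix is nonzero exactly where information can travel in \(k\) layers.

nodes, adj = smiles2graph('CO')
adj_mat = np.sum(adj, axis=-1) + np.eye(adj.shape[0])  # adjacency with self-loops
# column 1 is the oxygen: which atoms have "heard about" the oxygen after k layers?
print((np.linalg.matrix_power(adj_mat, 1) > 0).astype(int)[:, 1])  # [1 1 0 0 0 1]
print((np.linalg.matrix_power(adj_mat, 2) > 0).astype(int)[:, 1])  # [1 1 1 1 1 1]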

3.4.1. GCN Implementation

Let’s now create a tensor implementation of the GCN. We’ll skip the activation and trainable weights for now. We must first compute our rank 2 adjacency matrix; the smiles2graph code above computes an adjacency tensor with edge feature vectors. We can fix that with a simple reduction, adding the identity at the same time:

nodes, adj = smiles2graph('CO')
adj_mat = np.sum(adj, axis=-1) + np.eye(adj.shape[0])
adj_mat
array([[1., 1., 1., 1., 1., 0.],
       [1., 1., 0., 0., 0., 1.],
       [1., 0., 1., 0., 0., 0.],
       [1., 0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 1.]])

To compute the degree of each node, we can do another reduction:

degree = np.sum(adj_mat, axis=-1)
degree
array([5., 3., 2., 2., 2., 2.])

Now we can put all these pieces together with Einstein notation. As a sanity check, carbon (node 0) has five neighbors once we include its self-loop: itself, three hydrogens, and the oxygen, so its new feature vector should be the average \([0.2, 0.2, 0.6]\) in the \([C, O, H]\) basis.

print(nodes[0])
# note to divide by degree, make the input 1 / degree
new_nodes = np.einsum('i,ij,jk->ik', 1 / degree, adj_mat, nodes)
print(new_nodes[0])
[1. 0. 0.]
[0.2 0.2 0.6]

To implement this as a layer in Keras, we put the code above into a new Layer subclass. The code is relatively straightforward, but you can read up on the function names and the Layer class in the Keras documentation. The three main changes are that we create trainable parameters self.w and use them in the einsum, we apply an activation self.activation, and we output both our new node features and the adjacency matrix. The reason to output the adjacency matrix is so that we can stack multiple GCN layers without having to pass the adjacency matrix each time.

class GCNLayer(tf.keras.layers.Layer):
    '''Implementation of GCN as layer'''
    def __init__(self, activation=None,**kwargs):
        # constructor, which just calls super constructor
        # and turns requested activation into a callable function
        super(GCNLayer, self).__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)
    
    def build(self, input_shape):
        # create trainable weights
        node_shape, adj_shape = input_shape
        self.w = self.add_weight(shape=(node_shape[2], node_shape[2]),
                                name='w')
        
    def call(self, inputs):
        # split input into nodes, adj
        nodes, adj = inputs 
        # compute degree
        degree = tf.reduce_sum(adj, axis=-1)
        # GCN equation
        new_nodes = tf.einsum('bi,bij,bjk,kl->bil', 1 / degree, adj, nodes, self.w)
        out = self.activation(new_nodes)
        return out, adj

We can now try out our layer:

gcnlayer = GCNLayer('relu')
# we insert a batch axis here
gcnlayer((nodes[np.newaxis,...], adj_mat[np.newaxis,...]))
WARNING:tensorflow:Layer gcn_layer is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
(<tf.Tensor: shape=(1, 6, 3), dtype=float32, numpy=
 array([[[0.        , 0.        , 0.03632121],
         [0.        , 0.        , 0.        ],
         [0.        , 0.16119003, 0.        ],
         [0.        , 0.16119003, 0.        ],
         [0.        , 0.16119003, 0.        ],
         [0.        , 0.        , 0.3373425 ]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 6, 6), dtype=float32, numpy=
 array([[[1., 1., 1., 1., 1., 0.],
         [1., 1., 0., 0., 0., 1.],
         [1., 0., 1., 0., 0., 0.],
         [1., 0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 0., 1.]]], dtype=float32)>)

It outputs (1) the new node features and (2) the adjacency matrix. Let’s make sure we can stack these and apply the GCN multiple times

x = (nodes[np.newaxis,...], adj_mat[np.newaxis,...])
for i in range(2):
    x = gcnlayer(x)

It works! Why do we see zeros, though? Probably because some values were negative and were removed by our ReLU activation. This will be less of an issue once we train the weights and increase the feature dimension.
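
As a quick check that the ReLU is the culprit, here is the same call sketched without an activation (passing no activation gives the identity in Keras), which should produce no exact zeros.

# hypothetical check: the same layer with a linear (identity) activation
linear_gcn = GCNLayer()
out_nodes, _ = linear_gcn((nodes[np.newaxis, ...], adj_mat[np.newaxis, ...]))
print(out_nodes)  # negative values survive, so (up to the random weights) no exact zeros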

3.5. Solubility Example

We’ll now revisit predicting solubility with GCNs. Remember that before we used the features included with the dataset; now we can use the molecular structures directly. Our GCN layers output node-level features, but to predict solubility we need a graph-level feature. We’ll see later how to be more sophisticated about this, but for now let’s just take the average over all node features after our GCN layers. This is simple, permutation invariant, and gets us from node-level to graph-level features. Here’s an implementation:

class GRLayer(tf.keras.layers.Layer):
    '''A GNN layer that computes average over all node features'''
    def __init__(self, name='GRLayer', **kwargs):
        super(GRLayer, self).__init__(name=name, **kwargs)
    
    def call(self, inputs):
        nodes, adj = inputs
        reduction = tf.reduce_mean(nodes, axis=1)
        return reduction
    

To complete our deep solubility predictor, we add some dense layers and make sure the final layer has a single output with no activation, since we’re doing regression. Note this model is defined using the Keras functional API, which is necessary when you have multiple inputs.

ninput = tf.keras.Input((None,100,))
ainput = tf.keras.Input((None,None,))
# GCN block
x = GCNLayer('relu')([ninput, ainput])
x = GCNLayer('relu')(x)
x = GCNLayer('relu')(x)
x = GCNLayer('relu')(x)
# reduce to graph features
x = GRLayer()(x)
# standard layers
x = tf.keras.layers.Dense(16, 'tanh')(x)
x = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=(ninput, ainput), outputs=x)

Where does the 100 come from? This dataset contains many elements, so we cannot use our size 3 one-hot encodings; previously we only had C, H, and O. This is a good time to update our smiles2graph function to deal with this:

def gen_smiles2graph(sml):
    '''Takes a valid SMILES string and
    returns the graph as (node features, adjacency matrix)
    '''
    m = rdkit.Chem.MolFromSmiles(sml)
    m = rdkit.Chem.AddHs(m)
    order_string = {rdkit.Chem.rdchem.BondType.SINGLE: 1,
                    rdkit.Chem.rdchem.BondType.DOUBLE: 2,
                    rdkit.Chem.rdchem.BondType.TRIPLE: 3,
                    rdkit.Chem.rdchem.BondType.AROMATIC: 4}
    N = len(list(m.GetAtoms()))
    # one-hot encode elements by atomic number, up to element 100
    nodes = np.zeros((N, 100))
    for i in m.GetAtoms():
        nodes[i.GetIdx(), i.GetAtomicNum()] = 1

    # plain adjacency matrix (bond order is discarded here)
    adj = np.zeros((N, N))
    for j in m.GetBonds():
        u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        order = j.GetBondType()
        if order in order_string:
            order = order_string[order]
        else:
            warnings.warn('Ignoring bond order ' + str(order))
            continue
        adj[u, v] = 1
        adj[v, u] = 1
    adj += np.eye(N)
    return nodes, adj
nodes, adj = gen_smiles2graph('CO')
model((nodes[np.newaxis,...], adj[np.newaxis,...]))
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-0.0141629]], dtype=float32)>

It outputs one number! That’s always nice to have. Now we need to do some work to get a trainable dataset. Our dataset is a little complex because our features are tuples of tensors \((\mathbf{V}, \mathbf{E})\), so each example is a tuple of tuples: \(\left((\mathbf{V}, \mathbf{E}), y\right)\). We use a generator, which is just a Python function that can yield multiple times; ours yields once for every training example. Then we pass it to the from_generator Dataset constructor, which requires explicit declaration of the types and shapes of the examples.

def example():
    for i in range(len(soldata)):
        graph = gen_smiles2graph(soldata.SMILES[i])        
        sol = soldata.Solubility[i]
        yield graph, sol
data = tf.data.Dataset.from_generator(example, output_types=((tf.float32, tf.float32), tf.float32), 
                                      output_shapes=((tf.TensorShape([None, 100]), tf.TensorShape([None, None])), tf.TensorShape([])))

Whew, that’s a lot. Now we can do our usual splitting of the dataset.

test_data = data.take(200)
val_data = data.skip(200).take(200)
train_data = data.skip(400)

And finally, time to train.

model.compile('adam', loss='mean_squared_error')
result = model.fit(train_data.batch(1), validation_data=val_data.batch(1),  epochs=20, verbose=0)
plt.plot(result.history['loss'], label='training')
plt.plot(result.history['val_loss'], label='validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
../_images/gnn_38_0.png

This model is definitely underfit. One reason is that our batch size is 1: because the number of atoms varies between molecules, Keras/TensorFlow has trouble batching our data when there are two unknown dimensions. You can fix this by batching manually or by padding all molecules up to the size of the largest one. In any case, this example shows how to use GCN layers in a complete model.
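
As a sketch of the padding approach (my own addition with a hypothetical padded_train variable, not part of the model above), tf.data can pad every example in a batch up to the largest molecule with padded_batch. Note that padding introduces all-zero rows, so the GCNLayer would also need a small guard against zero degrees before dividing.

# pad node and adjacency tensors to the largest molecule within each batch
padded_train = train_data.padded_batch(
    32,
    padded_shapes=((tf.TensorShape([None, 100]), tf.TensorShape([None, None])),
                   tf.TensorShape([])))

# inside GCNLayer.call, something like this would avoid dividing by zero for padded atoms:
#   degree = tf.clip_by_value(tf.reduce_sum(adj, axis=-1), 1.0, 1e9)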

3.6. Message Passing Viewpoint

A broader way to view a GCN layer is as a kind of “message-passing” layer. You first compute a message coming from each neighboring node:

(3.12)\[\begin{equation} \vec{e}_{{s_i}j} = \vec{n}_{{s_i}j} \mathbf{W} \end{equation}\]

where \(\vec{n}_{{s_i}j}\) means the \(j\)th neighbor of node \(i\); the \(s_i\) stands for senders to \(i\). This is how a GCN computes its messages: a weight matrix times each neighbor’s node features. After computing the messages \(\vec{e}_{{s_i}j}\), we aggregate them using a function that is invariant to the order of the neighbors:

(3.13)\[\begin{equation} \vec{e}_{i} = \frac{1}{|\vec{e}_{{s_i}j}|}\sum \vec{e}_{{s_i}j} \end{equation}\]

In the GCN this aggregation is just a mean. Finally, we update our node using the aggregated message in the GCN:

(3.14)\[\begin{equation} \vec{n}^{'}_{i} = \sigma(\vec{e}_i) \end{equation}\]

where \(n^{'}\) indicates the new node features. This is simply the activated aggregated message. Writing it out this way, you can see how it is possible to make small changes. One important paper by Gilmer et al. explored some of these choices and described how this general idea of message passing layers does well in learning to predict molecular energies from quantum mechanics [GSR+17]. Examples of changes to the above GCN equations are to include edge information when computing the neighbor messages or use a dense neural network layer in place of \(\sigma\).
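
To connect these equations back to code, here is a minimal NumPy sketch (with arbitrary random weights, reusing smiles2graph from above) that performs one GCN update on methanol as the three message-passing steps: compute messages, aggregate, update.

nodes, adj = smiles2graph('CO')
adj_mat = np.sum(adj, axis=-1) + np.eye(adj.shape[0])  # adjacency with self-loops
W = np.random.normal(size=(nodes.shape[1], nodes.shape[1]))

# 1. messages: each sender's node features times the weight matrix (eq 3.12)
messages = nodes @ W
# 2. aggregate: average the messages over each node's neighbors (eq 3.13)
degree = adj_mat.sum(axis=1, keepdims=True)
agg = adj_mat @ messages / degree
# 3. update: activate the aggregated message (eq 3.14)
new_nodes = np.tanh(agg)
print(new_nodes.shape)  # (6, 3): one updated feature vector per atom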

3.7. Gated Graph Neural Network

One common variant of the message passing layer is the gated graph neural network (GGN) [LTBZ15]. It replaces the last equation, the node update, with

(3.15)\[\begin{equation} \vec{n}^{'}_{i} = \textrm{GRU}(\vec{n}_i, \vec{e}_i) \end{equation}\]

where \(\textrm{GRU}(\cdot, \cdot)\) is a gated recurrent unit [CGCB14]. The interesting property of a GRU is that it has trainable parameters, giving the model a bit more flexibility, but those parameters do not change as you stack more layers. A GRU is usually used for modeling sequences of undetermined length, like a sentence. What’s nice about this is that you can stack arbitrarily many GGN layers without increasing the number of trainable parameters (assuming you make \(\mathbf{W}\) the same at each layer). Thus GGNs are well suited to large graphs, like a large protein or a large unit cell.
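
As a minimal sketch of what such an update could look like (my own illustration using Keras’s GRUCell with made-up feature sizes, not the paper’s exact implementation): the aggregated message plays the role of the GRU input, and the current node features play the role of its hidden state.

feature_dim = 16
gru = tf.keras.layers.GRUCell(feature_dim)

node_feats = tf.random.normal((6, feature_dim))       # current features, one row per atom
agg_messages = tf.random.normal((6, feature_dim))     # aggregated neighbor messages
new_node_feats, _ = gru(agg_messages, [node_feats])   # node update, as in eq (3.15)
print(new_node_feats.shape)  # (6, 16)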

3.8. Pooling

Within the message-passing viewpoint, and for GNNs in general, the way that messages from neighbors are combined is a key step. This is sometimes called pooling, since it’s similar to the pooling layer used in convolutional neural networks. Just like pooling in convolutional neural networks, there are multiple reduction operations you can use. You typically see a sum or mean in GNNs, but you can be quite sophisticated, as in Graph Isomorphism Networks [XHLJ18]. We’ll see an example of self-attention in the attention chapter, which is a more complex reduction that could be used for pooling. It can be tempting to focus on this step, but it has been found empirically that the choice of pooling is not so important [LDLio19,MSK20].
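
As a small illustration (with made-up messages), the common choices differ only in which reduction is applied across the neighbor axis:

# three made-up neighbor messages for one node; pooling differs only in the reduction
neighbor_msgs = np.random.normal(size=(3, 4))
print(neighbor_msgs.mean(axis=0))  # mean pooling (what the GCN uses)
print(neighbor_msgs.sum(axis=0))   # sum pooling
print(neighbor_msgs.max(axis=0))   # max pooling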

3.9. Battaglia General Equations

As you can see, message passing is a general way to view GNN layers. Battaglia et al. [BHB+18] went further and created a general set of equations that captures nearly all GNNs. They broke the GNN layer down into 3 update equations, like the node update equation we saw in the message-passing layer, and 3 aggregation equations (6 equations total). There is one new concept in these equations: graph feature vectors. A graph feature vector is a set of features that represents the whole graph or molecule. For example, when computing solubility it may be useful to build up a per-molecule feature vector that is eventually used to compute solubility. Any kind of per-molecule quantity, like energy, can be expressed as a graph-level feature vector.

The first step in these equations is updating the edge feature vectors, written as \(\vec{e}_k\), which we haven’t seen yet but is certainly possible:

(3.16)\[\begin{equation} \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) \end{equation}\]

where \(\vec{e}_k\) is the feature vector of edge \(k\), \(\vec{v}_{rk}\) is the receiving node feature vector for edge \(k\), \(\vec{v}_{sk}\) is the sending node feature vector for edge \(k\), \(\vec{u}\) is the global graph feature vector, and \(\phi^e\) is one of the three update functions that define the GNN layer. Note that these are meant to be general expressions; you define \(\phi^e\) for your specific GNN layer. The output edge updates are then aggregated with the first aggregation function:

(3.17)\[\begin{equation} \bar{e}^{'}_i = \rho^{e\rightarrow v}\left( E_i^{'}\right) \end{equation}\]

where \(\rho^{e\rightarrow v}\) is our defined aggregation function and \(E_i^{'}\) represents all edges into or out of node \(i\). With the aggregated edges, which play the role of the messages from before, we can compute the node update:

(3.18)\[\begin{equation} \vec{v}^{'}_i = \phi^v\left( \bar{e}^{'}_i, \vec{v}_i, \vec{u}\right) \end{equation}\]

This concludes the usual steps of a GNN layer. If you are updating global attributes or aggregating nodes or edges, the following additional steps may be defined:

(3.19)\[\begin{equation} \bar{e}^{'} = \rho^{e\rightarrow u}\left( E^{'}\right) \end{equation}\]

This equation aggregates all messages across the whole graph. Then we can aggregate the new nodes across the whole graph:

(3.20)\[\begin{equation} \bar{v}^{'} = \rho^{v\rightarrow u}\left( V^{'}\right) \end{equation}\]

Then, we can compute the update to the global feature vector as:

(3.21)\[\begin{equation} \vec{u}^{'} = \phi^u\left( \bar{e}^{'},\bar{v}^{'}, \vec{u}\right) \end{equation}\]

3.9.1. Reformulating GCN into Battaglia equations

Let’s see how the GCN is presented in this form. We first compute our neighbor messages for all possible neighbors. Since our graph is undirected, we’ll just by convention use the senders as the “neighbor”

(3.22)\[\begin{equation} \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) = \vec{v}_{sk} \mathbf{W} \end{equation}\]

To aggregate our messages, we average them.

(3.23)\[\begin{equation} \bar{e}^{'}_i = \rho^{e\rightarrow v}\left( E_i^{'}\right) = \frac{1}{|E_i^{'}|}\sum E_i^{'} \end{equation}\]

Our node update is then the activation:

(3.24)\[\begin{equation} \vec{v}^{'}_i = \phi^v\left( \bar{e}^{'}_i, \vec{v}_i, \vec{u}\right) = \sigma(\bar{e}^{'}_i) \end{equation}\]

We could include the self-loop above by using \(\sigma(\bar{e}^{'}_i + \vec{v}_i)\) instead. The other functions are not used, so that is the complete set.
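
Here is a minimal NumPy sketch of these three functions applied to the methanol graph (hypothetical variable names like edge_index are my own, and a real implementation would vectorize this):

nodes, adj = smiles2graph('CO')
adj_mat = np.sum(adj, axis=-1)
# directed edge list: each bond contributes a (receiver, sender) pair in both directions
edge_index = [(r, s) for r in range(len(nodes)) for s in range(len(nodes)) if adj_mat[r, s]]
W = np.random.normal(size=(nodes.shape[1], nodes.shape[1]))

phi_e = lambda v_s: v_s @ W                  # eq (3.22): message from the sender
rho_ev = lambda msgs: np.mean(msgs, axis=0)  # eq (3.23): average incoming messages
phi_v = lambda agg, v: np.tanh(agg + v)      # eq (3.24), with the self-loop variant

new_edges = [phi_e(nodes[s]) for r, s in edge_index]
new_nodes = np.stack([
    phi_v(rho_ev([e for e, (r, s) in zip(new_edges, edge_index) if r == i]), nodes[i])
    for i in range(len(nodes))
])
print(new_nodes.shape)  # (6, 3)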

3.10. Nodes vs Edges

You’ll find that most GNNs use the node-update equation in the Battaglia equations but do not update edges. For example, the GCN updates nodes at each layer but keeps the edges constant. Some recent work has shown that updating edges can be important when the edges carry geometric information, such as when the input graph is a molecule and the edges are distances between atoms [KGrossGunnemann19]. As we’ll see in the chapter on equivariances (Input Data & Equivariances), one of the key properties of neural networks on geometric data (i.e., Cartesian xyz coordinates) is rotation equivariance. Klicpera et al. [KGrossGunnemann19] showed that you can achieve this if you do edge updates and encode the edge vectors in a rotation-equivariant basis built from spherical harmonics and Bessel functions. These kinds of edge-updating GNNs can be used to predict protein structure [JES+20].

3.11. Common Architecture Motifs and Comparisons

We’ve now seen message-passing layers, GCNs, GGNs, and the generalized Battaglia equations. You’ll find common motifs in the architectures, like gating, attention, and pooling strategies. For example, gated GNNs (GGNs) can be combined with attention pooling to create gated attention GNNs (GAANs) [ZSX+18]. GraphSAGE is similar to a GCN, but it samples neighbors when pooling, making the neighbor updates fixed in size [HYL17]. You’ll see the suffix “sage” when a model samples over neighbors while pooling. These can all be represented in the Battaglia equations, but you should be aware of these names.

The enormous variety of architectures has led to work on identifying the “best” or most general GNN architecture [DJL+20,EPBM19,SMBGunnemann18]. Unfortunately, the question of which GNN architecture is best is as difficult as “what benchmark problems are best?”, so there are no agreed-upon conclusions. However, those papers are great resources on training, hyperparameters, and reasonable starting guesses, and I highly recommend reading them before designing your own GNN. There has been some theoretical work showing that simple architectures, like GCNs, cannot distinguish between certain simple graphs [XHLJ18]. How much this matters in practice depends on your data. Ultimately, there is so much variety in hyperparameters, data equivariances, and training decisions that you should think carefully about how much the GNN architecture matters before exploring it in too much depth.

3.12. Do we need graphs?

It is possible to convert a graph into a string under certain constraints on the graph, and molecules specifically can be converted into strings. This means you can use sequence deep learning layers (e.g., transformers) and avoid the complexities of graph neural networks. SMILES is one common approach to converting molecular graphs into strings. However, SMILES is not surjective, meaning there are strings in the SMILES alphabet that do not correspond to a valid molecule. It is also not injective, meaning there are multiple valid SMILES strings for each molecule. Some of the early work on using SMILES focused on teaching generative models (e.g., VAEs) to learn to make valid SMILES. Then, hilariously, someone realized you could just create a new way of converting molecules into strings in which every string is valid, leading to SELFIES [KHaseN+20]. With SELFIES you can trivially generate molecules. I don’t think there is a paper on it yet, but there is a “superstition” that working with molecular graphs is more robust because SELFIES and SMILES can change quite a bit when a molecular graph undergoes a small change (e.g., ring-opening). It is an unresolved question whether GNNs are necessary or whether we can do everything with SELFIES.
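
As a small illustration of the many-to-one mapping, using the RDKit imported above: 'CO' and 'OC' are both valid SMILES for methanol, and RDKit’s canonicalization maps them to a single representative.

m1 = rdkit.Chem.MolFromSmiles('CO')
m2 = rdkit.Chem.MolFromSmiles('OC')
print(rdkit.Chem.MolToSmiles(m1), rdkit.Chem.MolToSmiles(m2))  # CO CO -- same canonical SMILES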

3.13. Relevant Videos

3.13.1. Intro to GNNs

3.13.2. Overview of GNN with Molecule, Compiler Examples

3.14. Chapter Summary

  • Molecules can be represented as graphs using one-hot encoded feature vectors that give the elemental identity of each node (atom) and an adjacency matrix that shows immediate neighbors (bonded atoms).

  • Graph neural networks are a category of deep neural networks that have graphs as inputs.

  • One of the early GNNs is the Kipf & Welling GCN. A GCN layer takes the node feature vectors and the adjacency matrix as input and returns updated node feature vectors. The GCN is permutation invariant because it averages over neighbors.

  • A GCN can be viewed as a message-passing layer, in which we have senders and receivers. Messages are computed from neighboring nodes, which when aggregated update that node.

  • A gated graph neural network is a variant of the message passing layer, for which the nodes are updated according to a gated recurrent unit function.

  • The aggregation of messages is sometimes called pooling, to which there are multiple reduction operations.

  • The Battaglia equations encompass almost all GNNs in a set of 6 update and aggregation equations.

3.15. Cited References

DJL+20

Vijay Prakash Dwivedi, Chaitanya K Joshi, Thomas Laurent, Yoshua Bengio, and Xavier Bresson. Benchmarking graph neural networks. arXiv preprint arXiv:2003.00982, 2020.

BBL+17

Michael M Bronstein, Joan Bruna, Yann LeCun, Arthur Szlam, and Pierre Vandergheynst. Geometric deep learning: going beyond euclidean data. IEEE Signal Processing Magazine, 34(4):18–42, 2017.

WPC+20

Zonghan Wu, Shirui Pan, Fengwen Chen, Guodong Long, Chengqi Zhang, and S Yu Philip. A comprehensive survey on graph neural networks. IEEE Transactions on Neural Networks and Learning Systems, 2020.

LWC+20

Zhiheng Li, Geemi P Wellawatte, Maghesree Chakraborty, Heta A Gandhi, Chenliang Xu, and Andrew D White. Graph neural network based coarse-grained mapping prediction. Chemical Science, 11(35):9524–9531, 2020.

YCW20

Ziyue Yang, Maghesree Chakraborty, and Andrew D White. Predicting chemical shifts with graph neural networks. bioRxiv, 2020.

XFLW+19

Tian Xie, Arthur France-Lanord, Yanming Wang, Yang Shao-Horn, and Jeffrey C Grossman. Graph dynamical networks for unsupervised learning of atomic scale dynamics in materials. Nature communications, 10(1):1–9, 2019.

XG18

Tian Xie and Jeffrey C. Grossman. Crystal graph convolutional neural networks for an accurate and interpretable prediction of material properties. Phys. Rev. Lett., 120:145301, Apr 2018. URL: https://link.aps.org/doi/10.1103/PhysRevLett.120.145301, doi:10.1103/PhysRevLett.120.145301.

KW16

Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.

GSR+17

Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. arXiv preprint arXiv:1704.01212, 2017.

LTBZ15

Yujia Li, Daniel Tarlow, Marc Brockschmidt, and Richard Zemel. Gated graph sequence neural networks. arXiv preprint arXiv:1511.05493, 2015.

CGCB14

Junyoung Chung, Caglar Gulcehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.

XHLJ18

Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In International Conference on Learning Representations. 2018.

LDLio19

Enxhell Luzhnica, Ben Day, and Pietro Liò. On graph classification networks, datasets and baselines. arXiv preprint arXiv:1905.04682, 2019.

MSK20

Diego Mesquita, Amauri Souza, and Samuel Kaski. Rethinking pooling in graph neural networks. Advances in Neural Information Processing Systems, 2020.

BHB+18

Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, and others. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.

KGrossGunnemann19

Johannes Klicpera, Janek Groß, and Stephan Günnemann. Directional message passing for molecular graphs. In International Conference on Learning Representations. 2019.

JES+20

Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael JL Townshend, and Ron Dror. Learning from protein structure with geometric vector perceptrons. arXiv preprint arXiv:2009.01411, 2020.

ZSX+18

Jiani Zhang, Xingjian Shi, Junyuan Xie, Hao Ma, Irwin King, and Dit-Yan Yeung. Gaan: gated attention networks for learning on large and spatiotemporal graphs. arXiv preprint arXiv:1803.07294, 2018.

HYL17

Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Advances in neural information processing systems, 1024–1034. 2017.

EPBM19

Federico Errica, Marco Podda, Davide Bacciu, and Alessio Micheli. A fair comparison of graph neural networks for graph classification. In International Conference on Learning Representations. 2019.

SMBGunnemann18

Oleksandr Shchur, Maximilian Mumme, Aleksandar Bojchevski, and Stephan Günnemann. Pitfalls of graph neural network evaluation. arXiv preprint arXiv:1811.05868, 2018.

KHaseN+20

Mario Krenn, Florian Häse, AkshatKumar Nigam, Pascal Friederich, and Alan Aspuru-Guzik. Self-referencing embedded strings (selfies): a 100% robust molecular string representation. Machine Learning: Science and Technology, 1(4):045024, 2020.