
Graph Classification using Graph Neural Networks

Graph classification assigns a label or category to an entire graph or network. Unlike traditional classification tasks that deal with individual data instances, graph classification takes the whole graph structure into account, including its nodes, edges and their properties. A graph classifier is a mapping function that predicts the class of an unseen graph from these structural properties. The mapping function is learned from labeled example graphs using supervised learning.

Why Do We Need Graph Classification?

The importance of graph classification stems from the fact that graph data is ubiquitous in today's interconnected world. Graph-based methods, including graph classification, have emerged as the methodology of choice in numerous applications across various domains, including:

1. Bioinformatics: Classifying protein-protein interaction networks or gene regulatory networks can provide insights into disease mechanisms and aid in drug discovery. In fact, one of the best-known success stories of graph neural networks is the discovery of new antibiotics active against drug-resistant bacteria, widely reported in early 2020.

2. Social Network Analysis: Categorizing social networks can help identify communities, detect anomalies (e.g., fake accounts), and understand information diffusion patterns.

3. Cybersecurity: Classifying computer networks can assist in detecting malicious activities, identifying vulnerabilities, and preventing cyber attacks.

4. Chemistry: Classifying molecular graphs can aid in predicting chemical properties, synthesizing new compounds, and understanding chemical reactions.

How Do We Build a Graph Classifier? 

There are two main approaches that we can use to build graph classifiers: kernel-based methods and neural network-based methods.

1. Kernel-based Methods:

These methods rely on defining similarity measures (kernels) between pairs of graphs, which capture their structural and topological properties. Popular kernel-based methods include the random walk kernel, the shortest-path kernel, and the Weisfeiler-Lehman kernel. Once the kernel is defined, traditional kernel-based machine learning algorithms, such as Support Vector Machines (SVMs), can be applied for classification.
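To make the kernel idea concrete, here is a minimal pure-Python sketch of the Weisfeiler-Lehman subtree kernel. It assumes each graph is given as an adjacency dict plus a dict of node labels (a hypothetical input format chosen for illustration), and `wl_features` and `wl_kernel` are illustrative names, not a library API. At every iteration each node's label is replaced by its old label combined with the sorted labels of its neighbours, and the kernel value is the dot product of the two graphs' label-count histograms accumulated over all iterations.

```python
from collections import Counter


def wl_features(adj, labels, iterations=2):
    """Count WL labels over iterations 0..iterations.

    adj:    dict mapping each node to a list of its neighbours
    labels: dict mapping each node to its initial label (e.g. atom type)
    Returns a Counter keyed by (iteration, label).
    """
    current = dict(labels)
    features = Counter((0, lab) for lab in current.values())
    for it in range(1, iterations + 1):
        # Relabel: new label = (old label, sorted neighbour labels).
        # The comprehension reads the previous `current` before it is replaced.
        current = {
            v: (current[v], tuple(sorted(current[u] for u in adj.get(v, ()))))
            for v in current
        }
        features.update((it, lab) for lab in current.values())
    return features


def wl_kernel(g1, g2, iterations=2):
    """Dot product of the two graphs' WL label histograms."""
    f1 = wl_features(*g1, iterations=iterations)
    f2 = wl_features(*g2, iterations=iterations)
    return sum(count * f2[key] for key, count in f1.items())


# Two toy graphs with identical node labels: a triangle and a 3-node path.
triangle = ({0: [1, 2], 1: [0, 2], 2: [0, 1]}, {0: "C", 1: "C", 2: "C"})
path = ({0: [1], 1: [0, 2], 2: [1]}, {0: "C", 1: "C", 2: "C"})

print(wl_kernel(triangle, triangle))  # 27
print(wl_kernel(path, path))          # 19
print(wl_kernel(triangle, path))      # 12
```

Keeping the compound labels as nested tuples, instead of compressing them to integers with a shared lookup table as is usually done, keeps the sketch stateless across graphs at the cost of labels that grow with the number of iterations. A full kernel matrix computed this way can be handed to an SVM that accepts precomputed kernels, such as scikit-learn's `SVC(kernel="precomputed")`.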

2. Neural Network-based Methods:

These methods typically learn low-dimensional representations (embeddings) of the graphs through specialized neural network architectures, such as Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs). The learned embeddings capture the structural information of the graphs and serve as input to standard classifiers, such as feed-forward neural networks or logistic regression models. For details on GCNs and node embeddings, please visit my earlier post on graph convolutional networks.

Both kernel-based and neural network-based methods have their strengths and weaknesses, and the choice depends on factors such as the size and complexity of the graphs, the availability of labeled data, and computational resources. Since graph neural networks are seeing increasingly wide use, we will complete this blog post by going over the steps needed to build a GNN classifier.

Steps for Building a GNN Classifier


We are going to use the MUTAG dataset, which is part of TUDataset, an extensive collection of graph datasets that is easily accessible in the PyTorch Geometric library for building graph applications. MUTAG is a small dataset of 188 graphs, each representing a chemical compound, divided into two classes according to whether or not the compound has a mutagenic effect on a bacterium. Each graph node (an atom) is characterized by seven features. Two example graphs from this dataset are shown below.

Two example graphs from MUTAG dataset
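For readers new to PyTorch Geometric, it may help to see how a single graph is laid out. Loading MUTAG through PyTorch Geometric's `TUDataset` class yields one `Data` object per graph, holding a node feature matrix `x`, a `2 x num_edges` connectivity array `edge_index` that lists every undirected bond in both directions, and a graph label `y`. The NumPy sketch below builds the same `x` and `edge_index` layout for a small made-up molecule; the helper name `to_coo` and the atom-type indices are hypothetical.

```python
import numpy as np

NUM_NODE_FEATURES = 7  # MUTAG encodes each atom's type as a 7-dimensional one-hot vector


def to_coo(num_nodes, undirected_edges, atom_types):
    """Build (x, edge_index) in the layout used by PyTorch Geometric.

    x:          [num_nodes, 7] one-hot node features
    edge_index: [2, 2 * len(undirected_edges)]; each undirected edge appears
                once as (i, j) and once as (j, i)
    """
    x = np.zeros((num_nodes, NUM_NODE_FEATURES), dtype=np.float32)
    x[np.arange(num_nodes), atom_types] = 1.0
    src, dst = zip(*undirected_edges)
    edge_index = np.array([src + dst, dst + src], dtype=np.int64)
    return x, edge_index


# A hypothetical 4-atom graph: atom 1 is bonded to atoms 0, 2 and 3.
x, edge_index = to_coo(4, [(0, 1), (1, 2), (1, 3)], atom_types=[0, 0, 1, 2])
print(x.shape)               # (4, 7)
print(edge_index.tolist())   # [[0, 1, 1, 1, 2, 3], [1, 2, 3, 0, 1, 1]]
```

Storing both directions of each bond lets message passing treat the graph as undirected while the data structure itself stays a simple list of directed edges.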

We will use 150 graphs for training and the remaining 38 for testing. The split into training and test sets is done using the utilities available in PyTorch Geometric.
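A reproducible 150/38 split boils down to shuffling the graph indices with a fixed seed and slicing. The plain-Python sketch below shows the idea; in PyTorch Geometric a comparable effect can be obtained with `dataset.shuffle()` followed by slices such as `dataset[:150]` and `dataset[150:]`, which is one possible approach rather than necessarily the notebook's exact code. The function name `split_indices` and the seed value are assumptions.

```python
import random


def split_indices(num_graphs, num_train, seed=0):
    """Shuffle graph indices reproducibly and split them into train/test lists."""
    indices = list(range(num_graphs))
    # A private RNG keeps the split independent of any global random state.
    random.Random(seed).shuffle(indices)
    return indices[:num_train], indices[num_train:]


train_idx, test_idx = split_indices(num_graphs=188, num_train=150, seed=42)
print(len(train_idx), len(test_idx))  # 150 38
```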


Because the graphs in the dataset are small, mini-batching them is a desirable step in graph classification for better utilization of GPU resources. Mini-batching stacks the adjacency matrices of the graphs in a batch diagonally, creating one giant graph that acts as the input to the GNN. The node feature matrices of the graphs are concatenated along the node dimension to form the node feature matrix of this giant graph. Since no edges connect nodes from different graphs, message passing never mixes information between graphs in the same batch. The idea of mini-batching is illustrated below.

Illustration of mini-batching
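The block-diagonal batching described above can be sketched with dense NumPy arrays as follows. This is an illustration only: PyTorch Geometric's `DataLoader` performs the equivalent operation on sparse `edge_index` arrays by offsetting each graph's node indices, and it also produces a `batch` vector recording which graph every node came from, which the readout step relies on. The `collate` helper and the toy graphs are hypothetical.

```python
import numpy as np


def collate(graphs):
    """Merge a list of (adjacency, features) pairs into one disconnected graph.

    Returns the block-diagonal adjacency matrix, the row-wise stacked node
    feature matrix, and a `batch` vector mapping each node to its graph index.
    """
    sizes = [adj.shape[0] for adj, _ in graphs]
    total = sum(sizes)
    big_adj = np.zeros((total, total))
    offset = 0
    for (adj, _), n in zip(graphs, sizes):
        big_adj[offset:offset + n, offset:offset + n] = adj  # place graph on the diagonal
        offset += n
    x = np.concatenate([feats for _, feats in graphs], axis=0)
    batch = np.repeat(np.arange(len(graphs)), sizes)
    return big_adj, x, batch


triangle = (np.ones((3, 3)) - np.eye(3), np.ones((3, 7)))
edge = (np.array([[0.0, 1.0], [1.0, 0.0]]), np.zeros((2, 7)))
big_adj, x, batch = collate([triangle, edge])
print(big_adj.shape, x.shape)  # (5, 5) (5, 7)
print(batch.tolist())          # [0, 0, 0, 1, 1]
```

The off-diagonal blocks of `big_adj` are all zeros, which is exactly why a GNN applied to the giant graph processes each member graph independently.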

Graph Classifier

We are going to use a three-stage classifier for this task. The first stage generates node embeddings using a message-passing graph convolutional network (GCN). The second stage, also known as the readout layer, aggregates the node embeddings of each graph into a single vector. We will simply average the node embeddings to create the readout vectors, using a built-in PyTorch Geometric function that does this for every graph in a mini-batch. The final stage is the actual classifier, which learns a classification rule from the readout vectors. In the present case, we will simply use a linear classifier to perform binary classification. Finally, we need to specify a suitable loss function, such as cross-entropy, so that the network can be trained to learn the proper weights.
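The three stages can be sketched in NumPy as a single forward pass. This is a minimal illustration under stated assumptions, not the notebook's code: each GCN layer uses the symmetric normalization ReLU(D^-1/2 (A + I) D^-1/2 H W) of Kipf and Welling (PyTorch Geometric's `GCNConv` implements a layer of this form), the hypothetical `mean_readout` plays the role of PyTorch Geometric's `global_mean_pool`, and the classifier is one linear layer producing two logits scored with cross-entropy. The weights are random, the toy batch is made up, and training (gradients and weight updates) is left to a framework such as PyTorch.

```python
import numpy as np


def gcn_layer(adj, h, weight):
    """Stage 1 building block: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    a_hat = adj + np.eye(adj.shape[0])             # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))  # degrees are >= 1 thanks to self-loops
    norm_adj = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(norm_adj @ h @ weight, 0.0)


def mean_readout(h, batch, num_graphs):
    """Stage 2: average the node embeddings of each graph in the batch."""
    sums = np.zeros((num_graphs, h.shape[1]))
    np.add.at(sums, batch, h)                      # scatter-add rows by graph index
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true classes under a softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def forward(adj, x, batch, num_graphs, params):
    h = x
    for weight in params["gcn"]:                   # stage 1: node embeddings
        h = gcn_layer(adj, h, weight)
    g = mean_readout(h, batch, num_graphs)         # stage 2: one vector per graph
    return g @ params["w_out"] + params["b_out"]   # stage 3: linear classifier logits


# Toy mini-batch: a triangle (nodes 0-2) and a single edge (nodes 3-4).
adj = np.zeros((5, 5))
for i, j in [(0, 1), (1, 2), (0, 2), (3, 4)]:
    adj[i, j] = adj[j, i] = 1.0
x = np.eye(7)[[0, 0, 1, 2, 2]]   # hypothetical one-hot atom types, 7 features
batch = np.array([0, 0, 0, 1, 1])
labels = np.array([1, 0])        # hypothetical graph labels

rng = np.random.default_rng(0)
hidden = 16
params = {
    "gcn": [rng.normal(size=(7, hidden)), rng.normal(size=(hidden, hidden)),
            rng.normal(size=(hidden, hidden))],  # three GCN layers
    "w_out": rng.normal(size=(hidden, 2)),
    "b_out": np.zeros(2),
}
logits = forward(adj, x, batch, num_graphs=2, params=params)
loss = cross_entropy(logits, labels)
print(logits.shape, loss)
```

Because the batched adjacency is block-diagonal, its normalized version is block-diagonal too, so batching does not change any individual graph's embedding.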

A complete implementation of all of the steps described above is available at this Colab notebook. Training a GCN with three convolutional layers reaches a test accuracy of 76% in just a few epochs.

Takeaways Before You Leave

Graph Neural Networks (GNNs) offer ways to perform various graph-related prediction and classification tasks. We can use GNNs to make predictions about nodes, for example to designate a node as a friend or a fraud. We can use them to predict whether a link should exist between a pair of nodes, for example to make friend suggestions in a social network. And, of course, we can use GNNs to classify entire graphs, with the applications mentioned earlier.

Another useful library for deep learning with graphs is Deep Graph Library (DGL). You can find a graph classifier implementation using DGL at the following link, if you like.