Training Transformers for Practical Drug Discovery with Tensor2Tensor
We're releasing a Colab notebook for training Transformer networks on a wide range of drug discovery tasks using Tensor2Tensor.
The introduction of attention-based language models like BERT and GPT-2 transformed machine learning almost overnight. Yet despite the overwhelming success of these models in NLP, adoption of the Transformer architecture in cheminformatics has so far been limited to sequence-to-sequence tasks like reaction prediction, which is just one class of tasks in the rapidly growing problem space of molecular machine learning.
Though the Transformer can learn effective representations of heterogeneous data, it's also a fairly complex model that requires non-trivial infrastructure to configure and train. That's why we're releasing a Colab notebook containing a premade pipeline for training Transformers on SMILES datasets. Our code is built on Tensor2Tensor, a robust framework developed by Google that spawned much of the original wave of Transformer research. By leveraging Tensor2Tensor's flexible dataset and model infrastructure, we're able to train the Transformer directly on a broad set of regression- and classification-based property prediction problems that are highly relevant to drug discovery (e.g., ADME, potency, and solubility). We're also releasing a Python-based SMILES tokenizer that we developed internally at Reverie for training NLP-style models on SMILES.
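Our released tokenizer isn't reproduced here, but a short example shows what SMILES tokenization has to handle. Multi-character atoms such as Cl and Br, and bracket atoms such as [NH4+], must stay single tokens, which a naive character split would break. The sketch below is a minimal regex-based tokenizer, assuming the atom-level pattern popularized by the Molecular Transformer (Schwaller et al.) rather than Reverie's own rules; the function name is hypothetical.

```python
import re

# Atom-level SMILES pattern popularized by the Molecular Transformer
# (Schwaller et al.). Illustrative stand-in only, NOT Reverie's tokenizer.
SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def tokenize_smiles(smiles: str) -> list[str]:
    """Split a SMILES string into atom, bond, branch, and ring-closure tokens.

    tokenize_smiles is a hypothetical helper name used for illustration.
    """
    tokens = SMILES_REGEX.findall(smiles)
    # findall silently skips characters the pattern does not cover (e.g. a
    # lone '%' or a non-organic-subset atom written outside brackets), so
    # check that the tokens reassemble the input exactly.
    if "".join(tokens) != smiles:
        raise ValueError(f"Could not fully tokenize SMILES: {smiles!r}")
    return tokens


# Aspirin, then a salt whose bracket atoms stay single tokens.
print(tokenize_smiles("CC(=O)Oc1ccccc1C(=O)O"))
print(tokenize_smiles("[NH4+].[Cl-]"))
```

The round-trip check is a deliberate design choice: a silent drop of an unrecognized character would produce a different molecule, so failing loudly is safer.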
This Colab is designed to take you end-to-end with the Transformer on a toy problem from the QM9 dataset. We'll walk through everything from data preprocessing and tokenization, to training a Transformer model, to decoding predictions on new SMILES strings. Our hope is that this work will lower the activation energy for working with the Transformer and catalyze further research using this powerful model.
Transformer Colab Notebook
To run the code interactively, open the notebook in Colab.
What is a Transformer?
The Transformer is a neural network architecture that uses attention layers as its primary building block. When the original Transformer paper, Attention Is All You Need (Vaswani et al., 2017), first came out, many researchers were surprised that a model that used no convolution or recurrence outperformed Google's existing seq2seq neural machine translation models, which relied heavily on recurrent layers.
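The core operation is scaled dot-product attention, defined in the paper as softmax(QK^T / sqrt(d_k)) V. Below is a minimal NumPy sketch of that formula. It omits the learned query/key/value projections and the multiple heads that a real Transformer layer uses, so treat it as an illustration of the math rather than a layer implementation.

```python
import numpy as np


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    q: (..., n_q, d_k), k: (..., n_k, d_k), v: (..., n_k, d_v)
    mask: optional boolean array broadcastable to (..., n_q, n_k);
          True marks positions that may be attended to.
    """
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    if mask is not None:
        # Disallowed positions get a large negative score, so softmax ~ 0.
        scores = np.where(mask, scores, -1e9)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 tokens, model width 8
# Self-attention: queries, keys, and values all come from the same sequence.
# (A real layer would first apply learned projections W_Q, W_K, W_V.)
out, attn = scaled_dot_product_attention(x, x, x)
print(out.shape, attn.shape)
```

Each row of the attention matrix is a probability distribution over the input tokens, so every output token is a weighted average of the value vectors. That is how information flows between positions without convolution or recurrence.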
Since then, Transformer architectures have become a staple of machine learning, especially in NLP. Currently, attention-based models like Google's BERT and OpenAI's GPT-2 achieve state-of-the-art results on most NLP benchmarks. Variants of the Transformer architecture have also been developed for images, graphs, and other modalities.
The Illustrated Transformer provides a nice high-level introduction to core Transformer concepts like self-attention and positional encoding. For a more technical walkthrough, the Harvard NLP group has put together The Annotated Transformer, which provides a section-by-section implementation of the paper in PyTorch.
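Because attention by itself ignores token order, the Transformer adds a positional encoding to each token embedding. A minimal NumPy sketch of the sinusoidal encoding from the original paper follows, with sine on even dimensions and cosine on odd ones. Library implementations can differ in details such as interleaving versus concatenating the sine and cosine channels, so this is the paper's formula rather than any particular framework's code.

```python
import numpy as np


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    positions = np.arange(max_len)[:, None]          # (max_len, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]     # (1, ceil(d_model / 2))
    angles = positions / np.power(10000.0, even_dims / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    # For odd d_model there is one fewer cosine column than sine columns.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return pe


# Added to token embeddings of shape (seq_len, d_model) before the first layer.
token_embeddings = np.random.default_rng(0).normal(size=(12, 16))
inputs = token_embeddings + sinusoidal_positional_encoding(12, 16)
print(inputs.shape)
```

Each dimension pair oscillates at a different wavelength, which gives every position a distinct signature that the attention layers can use.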
What is Tensor2Tensor?
Tensor2Tensor is a deep learning library developed by Google, and the code behind Attention Is All You Need was released as part of it. From the Tensor2Tensor README:
Tensor2Tensor, or T2T for short, is a library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research. T2T was developed by researchers and engineers in the Google Brain team and a community of users.
One great feature of Tensor2Tensor is that it's designed to make it easy for the open-source community to add new ML tasks by subclassing its flexible Problem class (discussed in detail in the Colab). In fact, one of the key subclasses we use in the Colab, Text2RegressionProblem, was originally written by us at Reverie to facilitate training models on molecular regression problems, and is now part of the T2T codebase.
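In T2T, a text problem supplies its raw training examples from a generate_samples method that yields dictionaries keyed by "inputs" and "targets". The sketch below illustrates only that pattern, in plain Python and without importing T2T, for a SMILES regression dataset stored as CSV. The column names and function name are hypothetical, and we assume that a regression problem yields a float target; in a real pipeline this logic would live inside generate_samples of a registered Text2RegressionProblem subclass.

```python
import csv
import io
from typing import Dict, Iterable, Iterator, Union


def generate_regression_samples(
    lines: Iterable[str],
    smiles_column: str = "smiles",  # hypothetical column names
    target_column: str = "target",
) -> Iterator[Dict[str, Union[str, float]]]:
    """Yield {"inputs": SMILES, "targets": float} dicts, one per CSV row.

    Rows with an empty SMILES field or a missing/non-numeric target are skipped.
    """
    for row in csv.DictReader(lines):
        smiles = (row.get(smiles_column) or "").strip()
        try:
            target = float(row[target_column])
        except (KeyError, TypeError, ValueError):
            continue
        if smiles:
            yield {"inputs": smiles, "targets": target}


# Made-up toy values, for illustration only.
toy_csv = io.StringIO("smiles,target\nCCO,0.5\nc1ccccc1,1.25\nCCN,not_a_number\n")
for sample in generate_regression_samples(toy_csv):
    print(sample)
```

Writing the generator lazily keeps memory flat, which matters once a dataset grows to millions of SMILES.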
Why aren't you using Trax?
Tensor2Tensor is now in what Google calls "maintenance mode" as the Brain team shifts to a more streamlined successor called Trax. Though Trax has been under steady development, it was in its infancy when we started this project last year, and it remains experimental. In the future, we may port our SMILES Transformer implementation to Trax. In the meantime, T2T is still a mature, feature-complete framework that works well for our application.
Why train on SMILES? Aren't they a bad representation of molecules?
In short: yes. SMILES are a poor representation of molecules, especially compared to graph-based representations or more sophisticated featurizations (e.g., an ensemble of 3D conformers). However, SMILES are also cheap, plentiful, and memory-efficient, which makes them easy to work with when training a model like the Transformer on millions of molecules. Near-zero featurization cost also pays off at inference time, making it practical to use models like the Transformer to screen huge libraries for particular properties of interest.
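To make the featurization-cost point concrete, here is a minimal sketch. Once a vocabulary exists, turning a SMILES string into model input is just a regex split plus a dictionary lookup, with no conformer generation or descriptor calculation. The regex is the atom-level pattern popularized by the Molecular Transformer. The special tokens, padding scheme, and function names are hypothetical; T2T and other frameworks provide their own text encoders.

```python
import re
from typing import Dict, List, Sequence

# Atom-level SMILES pattern popularized by the Molecular Transformer.
SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)
PAD, UNK = "<pad>", "<unk>"  # hypothetical special tokens


def build_vocab(smiles_list: Sequence[str]) -> Dict[str, int]:
    """Map every token seen in the corpus to an integer id (0 and 1 reserved)."""
    tokens = sorted({tok for s in smiles_list for tok in SMILES_REGEX.findall(s)})
    vocab = {PAD: 0, UNK: 1}
    for tok in tokens:
        vocab[tok] = len(vocab)
    return vocab


def encode(smiles: str, vocab: Dict[str, int], max_len: int) -> List[int]:
    """Tokenize, look up ids (unknown tokens -> UNK), then truncate/pad to max_len."""
    ids = [vocab.get(tok, vocab[UNK]) for tok in SMILES_REGEX.findall(smiles)]
    ids = ids[:max_len]
    return ids + [vocab[PAD]] * (max_len - len(ids))


vocab = build_vocab(["CCO", "c1ccccc1"])
print(encode("CCO", vocab, max_len=5))  # [3, 3, 4, 0, 0]
```

Both steps are linear in string length, so encoding a library of millions of molecules takes little time compared with running the model itself.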
On a more theoretical level, even though SMILES encode a molecule's graph only implicitly, as a linear string, it's an open question whether that limitation actually matters for powerful ML architectures. There's a direct analog in NLP: even though you could argue that language is inherently tree-structured, we can still obtain useful language models by training on raw text, without explicitly passing in parse trees as inputs.
Interestingly, much recent research in NLP has investigated the structure of learned representations in large-scale language models. Work in the so-called field of "BERTology" has found that these models implicitly recapitulate elements of traditional NLP pre-processing pipelines, and that their internal representations may mimic classical, tree-like structures. In a similar vein, a Transformer trained on enough SMILES might learn to reconstruct latent graph representations of molecules, though investigating that hypothesis thoroughly will take time and research funding!
Does this code achieve state-of-the-art performance on QM9?
In the Colab, we use QM9 as a toy problem to showcase the mechanics of the training and evaluation pipeline. The goal of this work isn't to beat SOTA on any particular benchmark using the Transformer, but rather to highlight how you might go about training a model to do so using only SMILES. That said, graphs provide an incredibly useful inductive bias for molecular ML tasks, so it's not surprising that the top models on QM9 tend to use graph featurizations. Recent work on Graph Transformer Networks may provide the best of both worlds by uniting the Transformer architecture with graph featurizations.
How big of a Transformer do I need to get good performance?
Transformer models get their competitive advantage from sheer scale, in both model size and training data. However, unlike in NLP, where hundreds of gigabytes of text are readily available for training, good chemistry data is comparatively scarce.
That said, we've found that scaling up does lead to better task performance. Currently, the largest Transformer we've trained at Reverie has 44M parameters and was trained on 77M+ unique SMILES using an adapted version of BERT's token-masking technique. These results were obtained on internal data, but in a future post we may go into more detail about our own experiments and results with Transformers on practical drug discovery tasks.
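We don't describe our adaptation here, so the sketch below shows only the standard recipe from the BERT paper applied to tokenized SMILES. Each token is selected for prediction with probability 0.15. A selected token is replaced by a mask symbol 80% of the time, by a random vocabulary token 10% of the time, and left unchanged 10% of the time. The mask symbol and function name are hypothetical.

```python
import random
from typing import List, Optional, Sequence, Tuple

MASK = "<mask>"  # hypothetical mask symbol


def bert_style_mask(
    tokens: Sequence[str],
    vocab_tokens: Sequence[str],
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[str], List[Optional[str]]]:
    """Apply BERT's 80/10/10 masking recipe to a list of SMILES tokens.

    Returns the corrupted tokens and, per position, the original token the
    model must predict (None where no prediction is required).
    """
    if rng is None:
        rng = random.Random()
    corrupted: List[str] = []
    labels: List[Optional[str]] = []
    for tok in tokens:
        if rng.random() < mask_prob:
            labels.append(tok)
            r = rng.random()
            if r < 0.8:
                corrupted.append(MASK)                     # 80%: mask symbol
            elif r < 0.9:
                corrupted.append(rng.choice(vocab_tokens))  # 10%: random token
            else:
                corrupted.append(tok)                      # 10%: unchanged
        else:
            labels.append(None)
            corrupted.append(tok)
    return corrupted, labels


tokens = ["C", "C", "(", "=", "O", ")", "O"]
corrupted, labels = bert_style_mask(tokens, ["C", "O", "N", "(", ")", "="], rng=random.Random(1))
print(corrupted, labels)
```

The random and unchanged cases exist so the model cannot rely on the mask symbol alone to know which positions to predict.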
Are there any other SMILES-based Transformer implementations available?
Yes! The Molecular Transformer from Philippe Schwaller and colleagues is a similar project with a public codebase. Their model is based on OpenNMT-py, a machine translation framework written in PyTorch. While the underlying Transformer architecture is similar, their pipeline is specifically designed to do reaction prediction as a sequence-to-sequence translation task. Since one of the goals of our project was to open the Transformer up to a broader range of regression- and classification-based property prediction tasks, we built our implementation on Tensor2Tensor, which offers more flexibility in the range of downstream tasks it supports.
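Reaction SMILES write a reaction as reactants>agents>products, with the middle field often empty (reactants>>products). Framing reaction prediction as translation means turning each such string into a source sequence and a target sequence. Here is a minimal sketch. Folding the agents into the source with '.' is our assumption and only one of several reasonable choices, not necessarily how the Molecular Transformer preprocesses its data; the function name is hypothetical.

```python
from typing import Tuple


def split_reaction_smiles(rxn: str) -> Tuple[str, str]:
    """Split 'reactants>agents>products' into a (source, target) pair.

    Assumption for illustration: agents (reagents, catalysts, solvents), if
    present, are appended to the reactants with '.', so the encoder sees
    everything on the left-hand side and the decoder predicts the products.
    """
    parts = rxn.split(">")
    if len(parts) != 3:
        raise ValueError(f"Expected 'reactants>agents>products', got {rxn!r}")
    reactants, agents, products = parts
    source = ".".join(p for p in (reactants, agents) if p)
    return source, products


# Esterification of ethanol with acetic acid, with and without an acid agent.
print(split_reaction_smiles("CCO.CC(=O)O>>CCOC(C)=O"))
print(split_reaction_smiles("CCO.CC(=O)O>[H+]>CCOC(C)=O"))
```

Both sides would then be tokenized at the atom level and fed to an encoder-decoder Transformer, exactly as a sentence pair would be in machine translation.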
Where can I learn more about Reverie Labs?
You can read more about us at reverielabs.com. Please reach out if you’re interested in learning more! And stay tuned for more Chemistry + ML content on our engineering blog: blog.reverielabs.com.