Scaling TF Training to Large Amounts of Data Part 1: Dealing with Large Datasets

Introduction

At Reverie Labs, we focus on building state-of-the-art deep learning models to predict important molecular properties. One method we use in building such models is self-supervised pre-training, which we run over tens of millions of molecules.

This requires overcoming two major hurdles: working around peculiarities in how Tensorflow handles large datasets, and dealing with long training jobs that may need to be paused and resumed. In this blog post, we address the former and discuss building Tensorflow data pipelines that handle data at this scale effectively. The latter will be addressed in a follow-up post.

Setup

To ground the concepts discussed in this post, let’s start with the following setup. The numbers and cloud platforms below are chosen to illustrate the problem, but the findings apply generally.

  • We have 10 million samples, stored in TFRecord files. Each file contains 1,000 samples, for 10,000 files total. If you are unfamiliar with TFRecord data, check out this example.
  • The data is also hosted on a cloud service such as AWS S3 and is on the order of hundreds of GB.
  • We are training models through a cloud service such as AWS EC2, and each EC2 instance downloads a local copy of the data.

Main Difficulties

Under these constraints, three issues arise:

  • The data is large enough that keeping a copy on the disk of every EC2 instance is costly.
  • Dynamically creating training/validation splits of the data can be computationally inefficient.
  • Tensorflow reads data serially, so properly shuffling the data is not straightforward.

We will examine solutions to each of these problems.

Reading Data Directly From S3

In a typical setting where models are trained on EC2 instances, the TFRecord files are first downloaded from S3 onto the instance. We can then use this local data to construct our Tensorflow Dataset.

import os
import tensorflow as tf

cloud_path = "<S3 URI containing TFRecord files>"
local_path = "<local path to store TFRecord files>"
download_from_s3(cloud_path, local_path)  # helper function that downloads S3 files to local disk
filepaths = [os.path.join(local_path, filename) for filename in os.listdir(local_path)]
tf_dataset = tf.data.TFRecordDataset(filepaths)
tf_dataset = tf_dataset.map(_parse_function)  # decode each serialized record

However, the cost of on-demand disk storage adds up when the dataset is hundreds of gigabytes, and hyperparameter searches across multiple machines make it worse, since every machine needs its own local copy.

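Luckily, Tensorflow supports reading TFRecord files directly from S3: simply pass in the cloud paths instead of the local paths. (In more recent Tensorflow releases, S3 support lives in the separate tensorflow-io package, which must be imported to register the s3:// scheme.)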

import s3fs
import tensorflow as tf

cloud_path = "<S3 URI containing TFRecord files>"
fs = s3fs.S3FileSystem()
# fs.ls returns "bucket/key" paths without a scheme, so prepend "s3://"
filepaths = [f"s3://{path}" for path in fs.ls(cloud_path, refresh=True)]
tf_dataset = tf.data.TFRecordDataset(filepaths)
tf_dataset = tf_dataset.map(_parse_function)

However, this comes with a problem: an I/O bottleneck. Loading data from S3 (cloud) adds latency compared with reading it from the EC2 instance’s own disk (local), which can slow down model training. There are a few ways to check how this latency affects your workflow. Let’s first create a local and a cloud version of the dataset, each with 50K examples.

import os
import s3fs
import tensorflow as tf

cloud_path = "<S3 URI containing TFRecord files>"
local_path = "<local path to store TFRecord files>"
fs = s3fs.S3FileSystem()

# Only grab the first 50 files (50 x 1,000 = 50K samples)
cloud_files = [f"s3://{path}" for path in fs.ls(cloud_path, refresh=True)][:50]
for cloud_file in cloud_files:
    download_from_s3(cloud_file, local_path)
local_files = [os.path.join(local_path, os.path.basename(cloud_file)) for cloud_file in cloud_files]

# Create local dataset
local_data = tf.data.TFRecordDataset(local_files)
local_data = local_data.map(_parse_function)

# Create cloud version
cloud_data = tf.data.TFRecordDataset(cloud_files)
cloud_data = cloud_data.map(_parse_function)

We can measure the impact of this latency in two ways:

  • Benchmark dataset loading: A quick way to check the throughput of your data pipeline is to time how long it takes to iterate through the dataset, for example with the tqdm module. The snippet below prints the throughput (iterations per second) and the total time taken to iterate through each dataset, so you can compare the cloud and local copies. More sophisticated timing methods exist, but this is a quick way to get started. The goal is to get the throughput of the cloud dataset to match that of the local dataset. More on that later.
from tqdm import tqdm

for data in [local_data, cloud_data]:
    for sample in tqdm(iter(data)):  # tqdm reports iterations/second and elapsed time
        pass
  • However, the throughputs don’t strictly need to match. What matters is that training is not bottlenecked by I/O, so it is also useful to measure how much training slows down when switching to the cloud dataset. Again, the goal is for training on the cloud dataset to be as fast as training on the local one.
for data in [local_data, cloud_data]:
    data = data.shuffle(10000).batch(32)
    model = build_model()  # hypothetical helper that returns a compiled Keras model
    model.fit(data, epochs=1, verbose=1)
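For a number that is easier to compare than a progress bar, throughput can also be measured directly. The sketch below is one way to do it, assuming eager execution (the TF 2 default); measure_throughput is a hypothetical helper. It skips a few warm-up elements so that one-off costs such as opening files are not counted, then times a fixed number of samples.

```python
import itertools
import time


def measure_throughput(dataset, num_samples=10_000, warmup=100):
    """Return elements/second over up to num_samples elements, after a warm-up.

    Works with any iterable, including a tf.data.Dataset in eager mode.
    """
    iterator = iter(dataset)
    # Consume a few elements first so start-up costs are excluded from the timing.
    for _ in itertools.islice(iterator, warmup):
        pass
    start = time.perf_counter()
    count = sum(1 for _ in itertools.islice(iterator, num_samples))
    elapsed = time.perf_counter() - start
    return count / elapsed if elapsed > 0 else float("inf")
```

Calling measure_throughput(local_data) and measure_throughput(cloud_data) yields two samples-per-second figures that can be compared directly; on a batched dataset the same function reports batches per second.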

There are several techniques for reducing I/O bottlenecks. The first is to set the num_parallel_reads argument of the TFRecordDataset constructor, which reads multiple TFRecord files in parallel (using multiple threads) and interleaves their records. Note that this changes the order in which samples are produced. Adding a prefetch step at the end of the pipeline also helps, since it lets the pipeline prepare upcoming batches while the model trains on the current one. Putting it all together, the following pipeline reads files efficiently from the cloud:

import s3fs
import tensorflow as tf

cloud_path = "<S3 URI containing TFRecord files>"
fs = s3fs.S3FileSystem()
filepaths = [f"s3://{path}" for path in fs.ls(cloud_path, refresh=True)]
tf_dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=4)  # read 4 files at a time
tf_dataset = tf_dataset.map(_parse_function)
tf_dataset = tf_dataset.shuffle(10000).batch(32)
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)  # overlap data loading with training

This pipeline may not entirely eliminate I/O bottlenecks during training. Results vary case by case, so it’s worth weighing any increase in training time against the savings in disk storage cost. In our case, we observed little change in training time when loading directly from S3.
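The earlier note that num_parallel_reads changes the dataset order can be made concrete. Our understanding is that TFRecordDataset with num_parallel_reads=k interleaves records from k files at a time, taking one record from each in turn. The pure-Python function below is a simplified model of that ordering, not Tensorflow’s implementation: interleave_order is a hypothetical name, and the model ignores threading and any non-deterministic ordering options.

```python
from collections import deque


def interleave_order(files, cycle_length):
    """Simplified model of interleaving records from several files.

    files: list of lists, each holding one file's records in order.
    Up to cycle_length files are open at once; records are taken one at a
    time from each open file in turn, and when a file runs out, the next
    unopened file takes over its slot.
    """
    pending = deque(iter(f) for f in files)
    slots = [pending.popleft() for _ in range(min(cycle_length, len(pending)))]
    out = []
    while slots:
        i = 0
        while i < len(slots):
            try:
                out.append(next(slots[i]))
                i += 1
            except StopIteration:
                if pending:
                    slots[i] = pending.popleft()  # open the next file in this slot
                else:
                    slots.pop(i)  # nothing left to open; close this slot
    return out


files = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"], ["d0", "d1"]]
print(interleave_order(files, cycle_length=1))  # files read back to back
print(interleave_order(files, cycle_length=2))  # ['a0', 'b0', 'a1', 'b1', 'c0', 'd0', 'c1', 'd1']
```

With a cycle length of 1 the files are read back to back; with larger values, consecutive samples come from different files.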

Split Data Earlier Rather than Later

Oftentimes, it is convenient to first load all of the data into one TFRecordDataset and then split the dataset dynamically. This way, we can choose the proportion of data in the train/validation sets on the fly, as well as the splitting method. For example, when splitting molecules into train/validation sets, we can split them randomly, or we can split them so that the validation set contains particular substructures not seen in training. However, doing this naively wastes computation.

This can be illustrated with the following example:

import tensorflow as tf

tf_dataset = tf.data.TFRecordDataset(filepaths)
tf_dataset = tf_dataset.map(_parse_function)
train_inds, val_inds = scaffold_split(tf_dataset)  # Dataset splitting method (returns index arrays)
tf_dataset = tf_dataset.enumerate()  # elements become (index, data) pairs

# Create new train TFDataset with only data from train_inds
train_data = tf_dataset.filter(lambda i, data: tf.math.reduce_any(i == train_inds))
train_data = train_data.map(lambda i, data: data)  # drop the index

# Repeat for val_inds
val_data = tf_dataset.filter(lambda i, data: tf.math.reduce_any(i == val_inds))
val_data = val_data.map(lambda i, data: data)

Let’s say we split the data such that 20% of the data goes into our validation set. The current dataset pipeline for val_data looks like this:

At first glance, this is what we want. If we call next(val_data.as_numpy_iterator()), we get data2. However, to reach that datapoint, we actually load, parse, and index data0, then discard it because it doesn’t pass the filter. The same happens for data1. Only when we reach data2 does a sample pass the filter and get returned.

This is inefficient. Even though we only yield a fraction of the dataset (20%), we load and parse the entire dataset, and loading and parsing are the computationally expensive parts of this pipeline! The problem is exacerbated in large-scale pre-training, where a validation split of 1% or even 0.1% of the data is often sufficient.
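The waste is easy to see in a toy model. The pure-Python sketch below stands in for a lazy tf.data pipeline: parse plays the role of _parse_function and the membership test plays the role of .filter. It counts how often the expensive step runs when parsing happens before the filter versus after it. run_pipeline and its arguments are hypothetical, and in a real pipeline the raw records are still read from disk or S3 either way.

```python
def run_pipeline(records, keep_indices, filter_first):
    """Toy lazy pipeline: returns (kept outputs, number of parse calls)."""
    keep = set(keep_indices)
    calls = {"parse": 0}

    def parse(record):
        calls["parse"] += 1  # stands in for the expensive _parse_function
        return record.upper()

    def pipeline():
        for i, record in enumerate(records):  # like .enumerate()
            if filter_first:
                if i in keep:  # filter the raw record, parse only survivors
                    yield parse(record)
            else:
                parsed = parse(record)  # parse everything...
                if i in keep:           # ...then throw most of it away
                    yield parsed

    outputs = list(pipeline())
    return outputs, calls["parse"]


records = [f"rec{i}" for i in range(100)]
val_indices = range(0, 100, 5)  # a 20% "validation split"
late = run_pipeline(records, val_indices, filter_first=False)
early = run_pipeline(records, val_indices, filter_first=True)
print(late[1], early[1])  # 100 20
```

Both orderings yield the same 20 records, but parsing before filtering runs the expensive step five times as often.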

There are two solutions. The first is to move the filtering step to the very beginning of the pipeline, before parsing, so that unwanted samples are dropped early and little work is wasted. Every record is still read and indexed, but we keep full flexibility in how the dataset is split. This gives the following code:

import tensorflow as tf

tf_dataset = tf.data.TFRecordDataset(filepaths)  # raw, unparsed records
train_inds, val_inds = scaffold_split(tf_dataset)  # Dataset splitting method
tf_dataset = tf_dataset.enumerate()

# Create new train TFDataset with only data from train_inds; parse only what survives the filter
train_data = tf_dataset.filter(lambda i, data: tf.math.reduce_any(i == train_inds))
train_data = train_data.map(lambda i, data: data)
train_data = train_data.map(_parse_function)

# Repeat for val_inds
val_data = tf_dataset.filter(lambda i, data: tf.math.reduce_any(i == val_inds))
val_data = val_data.map(lambda i, data: data)
val_data = val_data.map(_parse_function)
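One more cost hides in both versions above: tf.math.reduce_any(i == train_inds) compares each index against every entry of the index array, so with millions of training indices the filter performs millions of comparisons per element. A sketch of a constant-time alternative, assuming the validation indices are already computed and the total sample count is known: build a boolean mask once and look each index up with tf.gather. split_by_mask is a hypothetical helper; for 10 million samples the mask takes about 10 MB.

```python
import numpy as np
import tensorflow as tf


def split_by_mask(dataset, val_indices, num_samples):
    """Split a dataset into (train, val) with an O(1) membership lookup per element.

    val_indices: integer positions (in dataset order) of validation samples.
    num_samples: total number of elements in the dataset.
    """
    # Build the mask once on the host: mask[j] is True iff sample j is in validation.
    mask = np.zeros(num_samples, dtype=bool)
    mask[np.asarray(val_indices, dtype=np.int64)] = True
    is_val = tf.constant(mask)

    indexed = dataset.enumerate()  # yields (int64 index, record) pairs
    val_data = indexed.filter(lambda i, record: tf.gather(is_val, i))
    train_data = indexed.filter(lambda i, record: tf.logical_not(tf.gather(is_val, i)))

    def drop_index(i, record):
        return record

    return train_data.map(drop_index), val_data.map(drop_index)
```

It would replace the two filter blocks, e.g. train_data, val_data = split_by_mask(tf.data.TFRecordDataset(filepaths), val_inds, num_samples=10_000_000), followed by .map(_parse_function) on each split so that only surviving records are parsed.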

The second solution is to bypass the problem altogether and split the train and validation sets into distinct TFRecord files BEFORE loading them in as TFRecordDatasets. This removes all redundant operations, but we also lose the ability to flexibly split our data. However, for large-scale pre-training, this flexibility is not as important, and can be worth sacrificing for computational efficiency. Therefore, our dataset loading code now looks like the following:

import tensorflow as tf

# train_filepaths / val_filepaths: disjoint lists of TFRecord files, split ahead of time
# Load in training data
train_data = tf.data.TFRecordDataset(train_filepaths, num_parallel_reads=4)
train_data = train_data.map(_parse_function)
# Load in validation data
val_data = tf.data.TFRecordDataset(val_filepaths, num_parallel_reads=4)
val_data = val_data.map(_parse_function)

If we’re using a 99%/1% split, we can reserve 100 of the 10,000 TFRecord files for the validation split. Assuming the data was shuffled before being written out to these files, reserving whole files is equivalent to a random train/validation split.
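A minimal sketch of reserving whole files for validation, assuming (as above) that samples were shuffled before being written out. split_files is a hypothetical helper; sorting the paths and using a fixed seed make the split reproducible, so every training job and hyperparameter-search worker sees the same validation files.

```python
import random


def split_files(filepaths, val_fraction=0.01, seed=0):
    """Return (train_filepaths, val_filepaths), reserving whole files for validation."""
    files = sorted(filepaths)            # independent of listing order
    random.Random(seed).shuffle(files)   # same permutation on every machine
    num_val = max(1, round(len(files) * val_fraction))
    return files[num_val:], files[:num_val]


# Hypothetical S3 paths matching the setup: 10,000 files.
paths = [f"s3://bucket/shard-{i:05d}.tfrecord" for i in range(10_000)]
train_filepaths, val_filepaths = split_files(paths)
print(len(train_filepaths), len(val_filepaths))  # 9900 100
```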

Shuffle the Files as well as the Data

The last difficulty revolves around dataset shuffling. Tensorflow shuffles data with the tf.data.Dataset.shuffle function (see its API documentation), which fills a buffer with buffer_size elements, randomly samples from that buffer, and replaces each sampled element with the next one from the input. Because Tensorflow accesses data serially, perfect shuffling is only obtained when the buffer size is greater than or equal to the number of samples, which is not tractable for a dataset this large. Currently, shuffling for the training dataset looks like this:

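import tensorflow as tf

buffer_size = 10000  # example value; in practice limited by available memory
# Load in training data
train_data = tf.data.TFRecordDataset(train_filepaths, num_parallel_reads=4)
train_data = train_data.map(_parse_function)
train_data = train_data.shuffle(buffer_size)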

In the large-scale pre-training regime, where num_samples >> buffer_size, the relative ordering of two samples in different TFRecord files will, in most cases, not change between epochs. This is illustrated below: with a buffer much smaller than the distance between File 0 and File 3, samples from File 3 only enter the buffer after most of File 0 has already been emitted, so they almost never come before File 0’s samples.
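The mechanism can be reproduced with a small pure-Python model of buffer-based shuffling: fill a buffer with buffer_size elements, emit a randomly chosen one, and put the next input element in its place. buffered_shuffle is a hypothetical helper that follows the documented behavior of Dataset.shuffle, not Tensorflow’s actual implementation. With 10 sequential "files" of 1,000 samples and a buffer of 500, no sample from the last file can appear among the first 8,501 outputs.

```python
import itertools
import random


def buffered_shuffle(items, buffer_size, rng):
    """Model of buffer-based shuffling (buffer_size >= 1).

    Fill a buffer with the first buffer_size items, then repeatedly emit a
    random buffered item and replace it with the next input item. Once the
    input is exhausted, the leftover buffer is emitted in random order.
    """
    iterator = iter(items)
    buffer = list(itertools.islice(iterator, buffer_size))  # initial fill
    out = []
    for item in iterator:
        j = rng.randrange(len(buffer))
        out.append(buffer[j])  # emit a random element from the buffer...
        buffer[j] = item       # ...and refill its slot from the input
    rng.shuffle(buffer)        # drain what is left
    return out + buffer


rng = random.Random(0)
samples = list(range(10_000))  # 10 "files" x 1,000 samples, read back to back
shuffled = buffered_shuffle(samples, buffer_size=500, rng=rng)
print(shuffled[:10])  # all drawn from the first 500 samples, i.e. from file 0
```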

The workaround is to add a second shuffle step at the level of TFRecord files. The number of files is tiny compared to the number of samples, and file names take up little memory, so the entire list of file names fits in a shuffle buffer and can be shuffled perfectly.

import tensorflow as tf

# Shuffle filenames (train_filepaths is the list of training TFRecord paths)
train_files = tf.data.Dataset.from_tensor_slices(train_filepaths)
train_files = train_files.shuffle(len(train_filepaths))  # buffer holds every file name: a full shuffle
# Load in data
train_data = tf.data.TFRecordDataset(train_files, num_parallel_reads=4)
train_data = train_data.map(_parse_function)
train_data = train_data.shuffle(buffer_size)

This works because TFRecordDataset accepts either a list of file names (which is what we had been doing) or a tf.data.Dataset that yields file names (which is what is done here). Since shuffle reshuffles on every iteration by default (its reshuffle_each_iteration argument defaults to True), the file order changes after each pass through the dataset.
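A quick way to check the reshuffling behavior, using hypothetical file names: iterate the shuffled file-name dataset twice and compare the two orders. Both passes contain every file exactly once and, with overwhelming probability for 20 files, in a different order.

```python
import tensorflow as tf

# Hypothetical file names; in practice these would be the training TFRecord paths.
filenames = [f"shard-{i:05d}.tfrecord" for i in range(20)]

# A buffer as large as the list gives a uniform shuffle of the file names, and
# reshuffle_each_iteration (True by default) draws a new order on every pass.
file_ds = tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))

epoch_1 = [name.decode() for name in file_ds.as_numpy_iterator()]
epoch_2 = [name.decode() for name in file_ds.as_numpy_iterator()]
print(epoch_1[:3])
print(epoch_2[:3])
```

Keras’ model.fit creates a fresh iterator for each epoch, so a pipeline built on such a dataset reads the files in a new order every epoch.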

As we can see in the figure below, the relative ordering of samples in different files can now change between epochs.

Although this still does not give us perfect shuffling (samples within a file are more likely to occur near one another), the result is much closer than before!

Further Thoughts

Although dealing with large datasets for large-scale pre-training can be cumbersome, Tensorflow provides the tools needed to get around these hurdles. Before building your own dataset pipelines, it is useful to understand the tools the tf.data API puts at your disposal. The Tensorflow website also has a guide on optimizing tf.data pipeline performance, which covers techniques that we at Reverie Labs use on a daily basis.

We are Hiring!

If this type of work excites you, check out our careers page! Our team includes a mix of industry-experienced engineering and biotech professionals. We're actively hiring engineers across our tech stack, including Machine Learning Engineers, Senior Data Scientists, and Full Stack Engineers to work on exciting challenges critical to our approach to developing life-saving cancer drugs. You will work with a YC-backed team that is growing in size and scope. You can read more about us at www.reverielabs.com, and please reach out if you're interested in learning more.