Training and Inference of LLMs with PyTorch Fully Sharded Data Parallel and Better Transformer

In this blog post we show how to perform efficient, optimized distributed training and inference of large language models on the Spark platform, using PyTorch's Fully Sharded Data Parallel and Better Transformer implementations.

In this implementation, we combine Microsoft Fabric for data preparation and model inference, and Azure Databricks for model training, having all our data under Microsoft Fabric's OneLake.

The code for this blog is available at this GitHub repository, as a series of PySpark notebooks for Microsoft Fabric and Azure Databricks.


With the popularization of large pre-trained language models and the abundance of business data, many enterprises are looking into customizing those models with their own data and using them as part of their analytics workflows and applications.

In many instances, that customization implies fine-tuning a pre-trained language model in an efficient way, minimizing training time and computational resources, and then using the fine-tuned model, also in an efficient way, as part of a data enrichment process, an analytics workflow, or a business intelligence application.

Here we show a recipe for implementing that workflow using PyTorch's recent optimizations for model training and inference. Specifically, we show how to train PyTorch models at scale using the Fully Sharded Data Parallel approach, and how to run model inference at scale using the Better Transformer optimizations, both in the Apache Spark environment.

We also show here how an organization can implement that workflow, by splitting the work between Microsoft Fabric and Azure Databricks, benefiting from both platforms.

Microsoft Fabric is the all-in-one analytics solution for enterprises seeking everything from data movement to data engineering, data science, real-time analytics, data warehousing, and business intelligence in an easy-to-use unified platform.

Azure Databricks is an Apache Spark-centric analytics platform, with strong capabilities for data science workloads.

By implementing data preparation and model inference on Microsoft Fabric, organizations can benefit from the simplicity of a unified data analytics platform, using familiar languages, tools, and file formats in a single platform, and having their data managed in a single place, minimizing the need for data movement and different data formats across different platforms and tools.

By implementing model training on Azure Databricks, organizations can benefit from having access to modern GPU-based infrastructure to train large language models at scale, also using the same familiar languages and file formats, and seamlessly accessing the data layer managed by Microsoft Fabric.

In Figure 1 we see a high-level conceptual view of how that workflow can be implemented. We highlight Microsoft Fabric's OneLake as the central data layer, seamlessly accessed by both Microsoft Fabric and Azure Databricks. We also highlight the usage of open standards for data processing and model training and inference, using PySpark, PyTorch, and Parquet files.


Figure 1: high-level conceptual view of the proposed workflow

The Use Case

To exemplify a typical use case, we consider the task of fine-tuning a pre-trained Transformer model from Hugging Face for text classification, and then using the fine-tuned model to perform batch inference.

We want to do this in an efficient manner, optimizing computational resource consumption for both model training and inference. For that, we use PyTorch's Fully Sharded Data Parallel and Better Transformer, respectively.

The dataset we use is the Rotten Tomatoes movie review dataset. It is a simple dataset with only two columns: text, which is the movie review, and label, which is either 1 (positive review) or 0 (negative review).

We will get the pre-trained model and fine-tune it for text classification using that dataset. In this way, we leverage the knowledge the pre-trained model already has about natural language and augment it with the specific knowledge of how to classify movie reviews as positive or negative. Once we have the fine-tuned model, we can use it to classify new reviews.

Note that this is a toy task, with a relatively small model and dataset. The goal is not the task itself, but rather to show how the technologies presented here work and what benefits we can get from using them, and to make the code easy to reproduce. Nevertheless, the same building blocks used here can be applied to real-world tasks on a much larger scale.

Data Preparation on Spark

Before we start fine-tuning the model, we need to extract the numeric features from the text, which are used as inputs to the model. For Hugging Face models, this is facilitated by the Transformers library through its tokenizer classes.

This is an embarrassingly parallel task, which we can perform on Spark through pandas UDFs (user-defined functions) over Spark DataFrames.
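As a rough sketch of what this step can look like (this is not the code from data_preparation.ipynb), the snippet below wraps a Hugging Face tokenizer in a pandas UDF that returns the input_ids and attention_mask for each review. The checkpoint name, the maximum length, and the helper names tokenize_texts and make_tokenize_udf are assumptions made for illustration; the post does not specify them.

```python
from typing import Callable, Dict, List

MODEL_NAME = "distilbert-base-uncased"  # assumption: the post does not name the checkpoint
MAX_LENGTH = 128                        # assumption: illustrative padding/truncation length


def tokenize_texts(texts: List[str], tokenizer: Callable[..., Dict]) -> Dict[str, List[List[int]]]:
    """Turn raw strings into fixed-length input_ids and attention_mask lists."""
    encoded = tokenizer(texts, padding="max_length", truncation=True, max_length=MAX_LENGTH)
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }


def make_tokenize_udf():
    """Build a pandas UDF for a string column (needs pyspark, pandas, and transformers)."""
    import pandas as pd
    from pyspark.sql.functions import pandas_udf
    from transformers import AutoTokenizer

    @pandas_udf("input_ids array<int>, attention_mask array<int>")
    def tokenize_udf(text: pd.Series) -> pd.DataFrame:
        # The tokenizer is created inside the UDF so that it is built on the executor.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        return pd.DataFrame(tokenize_texts(text.tolist(), tokenizer))

    return tokenize_udf
```

Assuming a Spark DataFrame df with a text column, df.withColumn("features", make_tokenize_udf()("text")) would add a struct column holding both arrays, and Spark would apply the function to each partition in parallel.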

Please refer to data_preparation.ipynb for implementation details.

Model Training with PyTorch's Fully Sharded Data Parallel on Spark

The model training is performed through PyTorch's distributed training on Spark, using PySpark's TorchDistributor on Azure Databricks.
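To give an idea of how TorchDistributor is used, here is a minimal sketch rather than the code from the training notebook. The function names describe_process, train_fn, and launch, as well as the number of processes, are hypothetical. The sketch assumes a GPU cluster, and it relies on TorchDistributor starting each process with the usual RANK, LOCAL_RANK, and WORLD_SIZE environment variables.

```python
import os
from typing import Mapping, Optional


def describe_process(env: Optional[Mapping[str, str]] = None) -> dict:
    """Read the rank information that the launcher sets for each process."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }


def train_fn(epochs: int) -> dict:
    """Runs once in every process started by TorchDistributor."""
    import torch
    import torch.distributed as dist

    info = describe_process()
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(info["local_rank"])
    # ... build the model, the data loader, and the training loop for `epochs` here ...
    dist.destroy_process_group()
    return info


def launch(num_processes: int = 2, epochs: int = 1):
    """Start distributed training from the Spark driver (needs pyspark >= 3.4)."""
    from pyspark.ml.torch.distributor import TorchDistributor

    distributor = TorchDistributor(num_processes=num_processes, local_mode=False, use_gpu=True)
    # Extra positional arguments are forwarded to train_fn.
    return distributor.run(train_fn, epochs)
```

Keeping the training logic in a plain function makes it possible to run the same function on a single machine for debugging before launching it on the cluster.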

We also optimize the model training with PyTorch's Fully Sharded Data Parallel (FSDP). Instead of replicating the full model on every GPU, FSDP shards the model parameters, gradients, and optimizer states across the GPUs. This optimizes GPU memory utilization, allowing us to work with larger models and larger batch sizes, and results in faster training.
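The following sketch illustrates the idea under simplifying assumptions; it is not the code from the training notebook, and the helper names params_per_gpu, wrap_with_fsdp, and build_optimizer are hypothetical. The first helper is a back-of-the-envelope estimate of how many parameters each GPU stores between steps, ignoring the parameters that FSDP gathers temporarily during the forward and backward passes. The other two show one way to wrap a model and create its optimizer inside the training function.

```python
import math


def params_per_gpu(total_params: int, world_size: int, sharded: bool) -> int:
    """Approximate number of parameters each GPU stores between training steps."""
    # DDP keeps a full replica on every GPU; fully sharded FSDP keeps about 1/world_size.
    return math.ceil(total_params / world_size) if sharded else total_params


def wrap_with_fsdp(model):
    """Wrap a model with FSDP; call this after torch.distributed is initialized."""
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Without an auto_wrap_policy the whole model becomes a single FSDP unit,
    # which is simple but reduces peak memory less than wrapping each layer.
    return FSDP(model, device_id=torch.cuda.current_device())


def build_optimizer(fsdp_model, learning_rate: float = 2e-5):
    """Create the optimizer after wrapping so that it tracks the sharded parameters."""
    import torch

    return torch.optim.AdamW(fsdp_model.parameters(), lr=learning_rate)
```

For example, under this estimate a model with 66 million parameters trained on 4 GPUs would keep about 16.5 million parameters per GPU when sharded, compared to all 66 million with a replicated approach. The learning rate shown is only a placeholder.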

In our sample code we noticed a speedup of 3.6X when using FSDP, compared to PyTorch's Distributed Data Parallel (DDP), and we were able to double the batch size for training.

Please refer to model_training_fsdp.ipynb for implementation details.

Batch Inference with PyTorch's Better Transformer on Spark

Like the data preparation task, batch inference is also an embarrassingly parallel task, which we can perform on Spark through pandas UDFs over Spark DataFrames.
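A minimal sketch of such a UDF is shown below, assuming the fine-tuned model was saved in the Hugging Face format at a location reachable from the executors. It is not the code from the inference notebook, and the names logits_to_labels, make_predict_udf, and model_path are hypothetical. It uses the iterator variant of pandas UDFs so that the model is loaded once per task and reused for every batch.

```python
from typing import Iterator, List


def logits_to_labels(logits: List[List[float]]) -> List[int]:
    """Pick the index of the largest logit in each row (0 = negative, 1 = positive)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def make_predict_udf(model_path: str):
    """Build a pandas UDF for batch inference (needs pyspark, pandas, torch, transformers)."""
    import pandas as pd
    import torch
    from pyspark.sql.functions import pandas_udf
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    @pandas_udf("int")
    def predict_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
        # Load the tokenizer and model once, then reuse them for every batch.
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path).eval()
        for texts in batches:
            inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():
                logits = model(**inputs).logits
            yield pd.Series(logits_to_labels(logits.tolist()))

    return predict_udf
```

With a DataFrame df that has a text column, df.withColumn("prediction", make_predict_udf(model_path)("text")) would add the predicted label to each row.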

Here we optimize this task using PyTorch's Better Transformer. It optimizes the execution of certain operations in the model building blocks implemented in PyTorch, such as TransformerEncoder, TransformerEncoderLayer, and MultiheadAttention, making the overall model inference process faster on both CPUs and GPUs. We use the Better Transformer implementation that is available through the integration with the Hugging Face Optimum library.

In our sample code we observed a speedup of 6X when using Better Transformer, compared to the standard model implementation.
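For readers who want to measure the effect on their own workload, the sketch below shows one way to convert a model through Optimum and to compare timings. It is an illustration under assumptions, not the code from the inference notebook: the helper names load_optimized_model, average_seconds, and speedup are hypothetical, and the BetterTransformer class is only available in Optimum versions that still ship this integration.

```python
import time
from typing import Callable


def load_optimized_model(model_path: str):
    """Load a fine-tuned classifier and convert it with Better Transformer."""
    from optimum.bettertransformer import BetterTransformer
    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model = BetterTransformer.transform(model)
    # The optimized execution path is intended for inference, so switch to eval mode.
    return model.eval()


def average_seconds(fn: Callable[[], object], repeats: int = 5) -> float:
    """Average wall-clock time of calling fn, in seconds."""
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def speedup(baseline_seconds: float, optimized_seconds: float) -> float:
    """How many times faster the optimized run is than the baseline."""
    return baseline_seconds / optimized_seconds
```

Timing the same batch of inputs with the standard model and with the converted model, and passing both averages to speedup, gives a number that can be compared with the one reported above. Results will vary with the model, hardware, and batch size.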

Please refer to model_inference_optim.ipynb for implementation details.

In Conclusion

In this post we show how to use PyTorch's FSDP and Better Transformer, on Spark clusters, to accelerate and optimize model training and inference.

We also show how easy it is to combine processing on Microsoft Fabric and Azure Databricks, having all data stored in a single location, and using the same open standards for data processing, model training, and inference.


This article was originally published by Microsoft's AI - Customer Engineering Team Blog. You can find the original article here.