PyTorch 2.0 Takes a Leap Forward in Performance and Innovation

A release packed with exciting new features — from torch.compile to improved performance and support for Dynamic Shapes and Distributed systems.

Eduardo Alvarez
9 min read · Mar 22, 2023
Image by Author

So what is PyTorch 2.0 all about?

PyTorch 2.0 builds on the success of the 1.x series, which has seen numerous iterations and innovations, including the move to the newly formed PyTorch Foundation, part of the Linux Foundation. PyTorch’s strength lies in its first-class Python integration, imperative style, simplicity of the API and options, and, most importantly, a highly collaborative community.

The latest version of PyTorch continues to offer the same eager mode development and user experience while fundamentally changing how PyTorch operates at the compiler level. PyTorch 2.0 provides faster performance and support for Dynamic Shapes and Distributed systems.

One of the significant features of PyTorch 2.0 is torch.compile, a new feature that improves the scalability and performance of deep learning models in PyTorch. torch.compile is fully additive and optional, which means that PyTorch 2.0 is 100% backward compatible by definition. It is built on new technologies, including:

  • TorchDynamo — captures PyTorch programs safely using Python Frame Evaluation Hooks.
  • AOTAutograd — overloads PyTorch’s autograd engine as a tracing autodiff for generating ahead-of-time backward traces.
  • PrimTorch — canonicalizes ~2000+ PyTorch operators down to a closed set of ~250 primitive operators, which developers can target to build a complete PyTorch backend.
  • TorchInductor — a deep-learning compiler that generates fast code for multiple accelerators and backends.

All of these technologies are written in Python and support dynamic shapes, making them flexible and easily hackable and lowering the barrier of entry for developers.
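If you want to see which of these compiler backends are available in your own install, TorchDynamo keeps a registry you can query. A minimal sketch (note that torch._dynamo is technically an internal namespace, so its helpers may move between releases):

import torch
import torch._dynamo as dynamo

# TorchInductor ("inductor") is the default, but TorchDynamo can hand the
# captured graph to any backend registered here.
print(dynamo.list_backends())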

To validate these technologies, the PyTorch team used a diverse set of 163 open-source models spanning various machine learning domains: image classification, object detection, and image generation; NLP tasks such as language modeling, Q&A, and sequence classification; recommender systems; and reinforcement learning.

Let’s dig a little deeper into the technical updates

As PyTorch models grow more complex, optimizing them for deployment can be challenging. One solution is to use a compiler to translate high-level PyTorch code into lower-level, hardware-specific instructions, resulting in faster and more efficient code.

Over the years, PyTorch has developed several compiler projects to tackle this challenge. In this article, we’ll dive into the PyTorch compiler’s three main parts and explore some of PyTorch’s latest developments in this field.

Figure 1. The PyTorch compilation process — Image Source

Graph Acquisition

Graph acquisition is the process of capturing a PyTorch model’s computation graph. Over the years, PyTorch built several graph acquisition tools, including torch.jit.trace, TorchScript, FX tracing, and Lazy Tensors. However, none of them struck the right balance between flexibility and speed. TorchScript, for example, required significant code changes and was a non-starter for many PyTorch users.
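To make the flexibility gap concrete, here is a small sketch of the kind of data-dependent control flow that record-and-replay tools like torch.jit.trace struggle with: tracing records only the branch taken by the example input, while the 2.0 compiler keeps both branches correct by guarding on the condition.

import torch

def f(x):
    # Data-dependent control flow: torch.jit.trace would record only the branch
    # taken by the example input and silently replay it for every future input.
    if x.sum() > 0:
        return x * 2
    return x - 1

compiled_f = torch.compile(f)      # TorchDynamo guards on the condition instead
print(compiled_f(torch.ones(3)))   # takes the first branch
print(compiled_f(-torch.ones(3)))  # still correct: takes the second branch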

Earlier this year, PyTorch introduced TorchDynamo, a new approach that uses a CPython feature called the Frame Evaluation API. This approach acquired the graph 99% of the time, correctly and safely, with negligible overhead. It also did not require changes to the original code, providing flexibility and speed in one tool.
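If you are curious what TorchDynamo actually captures, you can pass your own callable as the backend; it receives the captured torch.fx.GraphModule plus example inputs and must return something callable. The printing backend below is just an inspection aid, not an optimization:

import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # TorchDynamo hands us the captured FX graph; print it and run it unmodified.
    print(gm.graph)
    return gm.forward

def fn(x, y):
    return torch.relu(x) + y

compiled_fn = torch.compile(fn, backend=inspect_backend)
compiled_fn(torch.randn(4), torch.randn(4))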

Graph Lowering

Graph lowering is the process of converting high-level code into lower-level code, typically in the form of an intermediate representation (IR). PyTorch’s TorchInductor project takes inspiration from how PyTorch users write high-performance custom kernels, increasingly using the Triton language. TorchInductor uses a pythonic define-by-run loop level IR to map PyTorch models automatically into generated Triton code on GPUs and C++/OpenMP on CPUs. It contains only ~50 operators and is implemented in Python, making it easily hackable and extensible.
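If you want to peek at what the lowering actually produces, one option is the compiler's debug dump, which writes out the FX graphs and the generated Triton or C++/OpenMP kernels. My understanding is that the TORCH_COMPILE_DEBUG environment variable controls this, though the debug tooling may evolve between releases:

# Launch with:  TORCH_COMPILE_DEBUG=1 python lowering_demo.py
# The intermediate artifacts (FX graphs plus the Triton / C++ kernels that
# TorchInductor generates) land under ./torch_compile_debug/.
import torch

def pointwise(x):
    return torch.sin(x) + torch.cos(x)

compiled = torch.compile(pointwise)   # TorchInductor is the default backend
compiled(torch.randn(1024))           # first call triggers lowering and code generation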

Graph Compilation

Graph compilation is the process of translating an IR into hardware-specific instructions. For PyTorch 2.0, PyTorch introduced AOTAutograd, a tool that captures user-level code and backpropagation. AOTAutograd leverages PyTorch’s torch_dispatch extensibility mechanism to trace through the Autograd engine, allowing for the capture of the backward pass “ahead of time.” This tool accelerates the forward and backward pass using TorchInductor, resulting in faster and more efficient training.
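To get a feel for what an ahead-of-time backward trace looks like, the functorch.compile frontend that ships with PyTorch 2.0 exposes aot_function, which lets you attach your own compiler to the forward and backward graphs. The sketch below just prints the traced graphs instead of optimizing them; treat it as illustrative, since this frontend's API has shifted between releases:

import torch
from functorch.compile import aot_function

def fn(a, b):
    return (torch.sin(a) * b).sum()

def print_graph(fx_module: torch.fx.GraphModule, example_inputs):
    # Called once with the forward graph and once with the ahead-of-time backward graph.
    print(fx_module.code)
    return fx_module

aot_fn = aot_function(fn, fw_compiler=print_graph, bw_compiler=print_graph)

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
aot_fn(a, b).backward()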

Stable Primitive Operators

Writing a backend for PyTorch is challenging, given that PyTorch has 1200+ operators and 2000+ if various overloads for each operator are considered. The PrimTorch project is working on defining smaller and more stable operator sets. PyTorch programs can consistently be lowered to these operator sets, which can be used for compilers or exported as-is. PrimTorch aims to define two operator sets: Prim ops with ~250 operators and ATen ops with ~750 canonical operators. These sets are designed to improve developer experience while ensuring efficient code generation.
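If you want to poke at these operator sets yourself, the torch._decomp module holds the registry of decompositions that rewrite composite ATen operators in terms of the smaller sets. A quick sketch (this is an internal module, so the entry points may change):

import torch
from torch._decomp import get_decompositions

# Each entry maps an ATen operator overload to a Python function that expresses
# it using simpler core operators.
decomps = get_decompositions([torch.ops.aten.addmm, torch.ops.aten.native_layer_norm])
for op, decomposition in decomps.items():
    print(op, "->", decomposition)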

torch.compile might be your new best friend. . .

Training and inference on large models can be slow due to overhead in the framework. To address this, PyTorch offers a simple function torch.compile that wraps a PyTorch model and returns a compiled version. Let’s briefly explore how to use torch.compile to optimize the performance of PyTorch models.

The code snippet below is a naïve example showcasing the use of torch.compile with a ResNet-50 model.

import torch
import torchvision.models as models

# Build a standard torchvision ResNet-50 and an SGD optimizer
model = models.resnet50()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

# Wrap the model; compilation is deferred until the first forward call
compiled_model = torch.compile(model)

# A dummy batch of 8 RGB images
x = torch.randn(8, 3, 112, 112)
optimizer.zero_grad()
out = compiled_model(x)   # forward pass runs through the compiled graph
out.sum().backward()      # backward pass is captured ahead of time by AOTAutograd
optimizer.step()

The torch.compile function holds a reference to the input model and compiles its forward function to a more optimized version. You can adjust various options for the compiler using the mode, dynamic, fullgraph, and backend arguments.

def torch.compile(model: Callable,
                  *,
                  mode: Optional[str] = "default",
                  dynamic: bool = False,
                  fullgraph: bool = False,
                  backend: Union[str, Callable] = "inductor",
                  # advanced backend options go here as kwargs
                  **kwargs
                  ) -> torch._dynamo.NNOptimizedModule

The snippet above shows the various arguments used by the function. Let’s unpack some of these arguments:

  • The mode argument specifies what the compiler should optimize while compiling. The default mode tries to compile efficiently without using extra memory or taking too long to compile. Other modes include reduce-overhead, which reduces framework overhead and helps speed up small models, and max-autotune, which tries to generate the fastest code but takes a very long time to compile.
  • The dynamic argument specifies whether to enable the code path for dynamic shapes. Specific compiler optimizations cannot be applied to dynamic shaped programs, so making it explicit whether you want a compiled program with dynamic or static shapes will help the compiler give you better-optimized code.
  • The fullgraph argument is similar to Numba's nopython. It compiles the entire program into a single graph or gives an error explaining why it could not do so. Most users don't need to use this mode, but if you are very performance-conscious, you can try it.
  • The backend argument specifies which compiler backend to use. By default, TorchInductor is used, but a few others are available.

When you run the compiled model for the first time, it is compiled, so that first call takes longer; subsequent runs are fast. Depending on your needs, you may want a different mode.
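Here is a rough sketch of how you might compare modes and observe the one-time compilation cost; the actual numbers will depend entirely on your model and hardware:

import time
import torch
import torchvision.models as models

model = models.resnet50().eval()
x = torch.randn(8, 3, 224, 224)

# "reduce-overhead" targets framework overhead, which matters most for small
# models and small batches; "max-autotune" searches harder for fast kernels at
# the cost of much longer compile times.
compiled = torch.compile(model, mode="reduce-overhead")

with torch.no_grad():
    t0 = time.time()
    compiled(x)                                   # first call pays the compilation cost
    print(f"first call:  {time.time() - t0:.2f}s")

    t0 = time.time()
    compiled(x)                                   # later calls reuse the compiled code
    print(f"second call: {time.time() - t0:.2f}s")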

What is new for distributed systems?

The PyTorch library offers two distributed training wrappers: DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP). While both wrappers have been proven effective in compiled mode, FSDP (currently in Beta) allows users to tune which submodules are wrapped and offers more configuration options. However, due to its higher level of system complexity, some compatibility issues may arise with specific models or configurations.

DDP (Figure 2) relies on overlapping AllReduce communications with backward computation and grouping smaller per-layer AllReduce operations into ‘buckets’ for greater efficiency. However, when combined with AOTAutograd functions compiled by TorchDynamo, communication overlap is prevented. This issue can be resolved by compiling separate subgraphs for each ‘bucket’ and allowing communication ops to happen outside and between the subgraphs.

Figure 2. DDP — Image Source
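A minimal sketch of combining DDP with torch.compile, assuming the usual torchrun launch (which provides the environment variables init_process_group needs); one common pattern is to wrap with DDP first and then compile the wrapped module so that graph breaks can be inserted at bucket boundaries:

import torch
import torch.distributed as dist
import torchvision.models as models
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")           # assumes a torchrun launch on GPUs
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = models.resnet50().cuda()
ddp_model = DDP(model, device_ids=[local_rank])

# Compiling the DDP-wrapped module lets subgraphs be compiled per gradient
# bucket, so AllReduce can still overlap with backward computation.
compiled_model = torch.compile(ddp_model)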

On the other hand, FSDP (Figure 3) allows users to specify an auto_wrap_policy argument to indicate which submodules of their model to wrap together in an FSDP instance used for state sharding. By wrapping each “transformer block” in a separate FSDP instance, only the full state of one transformer block must be materialized at one time. Dynamo will insert graph breaks at the boundary of each FSDP instance, allowing forward (and backward) communication ops to happen outside the graphs and in parallel to computation.

Figure 3. FSDP — Image Source
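And a sketch of the FSDP side, assuming a toy transformer built from nn.TransformerEncoderLayer blocks; the auto_wrap_policy shown is the stock transformer policy, and use_orig_params=True is the setting generally needed for FSDP to cooperate with torch.compile in 2.0:

import functools
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group(backend="nccl")           # assumes a torchrun launch on GPUs
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6
).cuda()

# Wrap each transformer block in its own FSDP instance so only one block's
# full parameters need to be materialized at a time.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
compiled_model = torch.compile(fsdp_model)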

Without wrapping submodules in separate instances, FSDP can fall back to operating similarly to DDP, but this configuration has only been tested with TorchDynamo for functionality, not performance.

How can you start transitioning your 1.x code to 2.0?

To take advantage of the new compiled mode in PyTorch 2.0, you can optimize your model with just one line of code: model = torch.compile(model). It's worth noting that if your existing 1.x code works correctly, it should continue to work as-is; no migration is required.

Although the main benefits of this feature are seen during training, it can also be used for inference, provided the compiled model runs faster than eager mode for your workload.

import torch

def train(model, dataloader):
    model = torch.compile(model)
    for batch in dataloader:
        run_epoch(model, batch)

def infer(model, input):
    model = torch.compile(model)
    return model(**input)

PyTorch 2.0 operates in eager mode by default, just like PyTorch 1.x. This means that each line of Python is executed sequentially. However, if you want to take advantage of PyTorch 2.0’s new features, you can wrap your model with model = torch.compile(model).

When you do this, your model goes through three necessary steps before execution. First, it undergoes graph acquisition, where it is rewritten as blocks of subgraphs. Next, it goes through graph lowering, where all PyTorch operations are decomposed into their constituent kernels specific to the chosen backend. Finally, the graph is compiled, and the kernels call their corresponding low-level device-specific operations.

Partnering PyTorch 2.0 with Intel Extension for PyTorch

Intel® Extension for PyTorch* 2.0.0-cpu has been released alongside PyTorch 2.0 and brings a host of exciting new features and optimizations. One of the highlights of this release is the experimental Fast BERT optimization, which uses a new technique from Intel to speed up BERT workloads. The implementation is integrated into Intel® Extension for PyTorch and particularly benefits BERT model training. A new API, ipex.fast_bert, is provided to try out this optimization; for more details, refer to the extension's Fast BERT feature documentation.
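Based on the release notes, trying the experimental Fast BERT path is meant to be roughly a one-line swap on a BERT model. The sketch below assumes a Hugging Face transformers BertModel and a bfloat16 dtype; since the API is experimental, the exact arguments may differ, so check the extension's documentation:

import torch
import intel_extension_for_pytorch as ipex
from transformers import BertModel   # assumed here purely for illustration

model = BertModel.from_pretrained("bert-base-uncased").eval()

# Experimental Fast BERT optimization; the dtype choice is an assumption, see
# the IPEX docs for supported dtypes and training vs. inference usage.
model = ipex.fast_bert(model, dtype=torch.bfloat16)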

Another optimization introduced in this release is multi-head attention (MHA) optimization with Flash Attention. The Intel team optimized the MHA module with the Flash Attention technique, inspired by the FlashAttention paper from Stanford. This reduces memory consumption for large language models and provides better inference performance for models like BERT and Stable Diffusion.

Intel® Extension for PyTorch is enabled as a backend of torch.compile, which can leverage this new PyTorch API’s power of graph capture and provide additional optimization based on these graphs. Using this new feature is quite simple, as shown in the code snippet below:

import torch
import intel_extension_for_pytorch as ipex
...
model = ipex.optimize(model)                   # apply the extension's operator and graph optimizations
model = torch.compile(model, backend='ipex')   # use the extension as the torch.compile backend

You don’t have to use this extension to benefit from running on Intel hardware — Intel contributes optimizations to the main stock package on a regular basis. However, the extension offers early access to functionality prior to its inclusion in stock PyTorch releases.

Discussion and Summary

PyTorch 2.0 marks a significant step forward for the platform, delivering faster performance and powerful new features while retaining the simplicity and ease of use that has made it a favorite among developers.

Undoubtedly the feature that will impact the most developers is torch.compile(), as it helps address scalability concerns with rapidly growing model sizes and brings a healthy performance boost out of the box with little to no code changes. More advanced features like FSDP, which is in beta, will bring deeper customizability to distributed training workloads.

To access an additional level of performance optimization on Intel hardware, don't forget to try the 2.0.0-cpu release of the Intel® Extension for PyTorch, which already supports torch.compile and includes some language-model-specific optimizations.

I hope this gave you a taste of some of the features you can expect in PyTorch 2.0 — this is likely just the tip of the iceberg, and many more exciting additions will be made through minor releases.

One of the most exciting parts of the broader PyTorch 2.0 ecosystem is TorchMultimodal, which will be the topic of one of my future articles.



Eduardo Alvarez

AI Performance Optimization Lead @ AMD | Working on Operational AI, Performance Optimization, Scalable Deployments, and Applied ML | ex-Intel Corp.