Transformers documentation

XLA

Accelerated Linear Algebra (XLA) is a linear algebra compiler that optimizes model runtime across different hardware and frameworks.

This guide will look specifically at how to accelerate TensorFlow models with XLA.

TensorFlow

XLA can potentially accelerate a TensorFlow model without any source code changes. It is already packaged with the TensorFlow library, and it is triggered by setting jit_compile=True in any graph-creating function such as tf.function.

If you’re using Keras methods like fit and predict, enable XLA by passing jit_compile=True to compile.

model.compile(jit_compile=True)
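A slightly fuller sketch of the Keras route follows. The tiny model, the random data, and the optimizer and loss choices are illustrative assumptions, not part of this guide. The only XLA-specific part is the jit_compile=True argument to compile, which then applies to both fit and predict.

```python
import numpy as np
import tensorflow as tf

# Illustrative model: 10 input features, 5 output classes.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(5, activation="softmax"),
])

# jit_compile=True asks Keras to compile the train and predict steps with XLA.
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", jit_compile=True)

# Random data, for illustration only.
x = np.random.rand(32, 10).astype("float32")
y = np.random.randint(0, 5, size=(32,))

model.fit(x, y, epochs=1, batch_size=8, verbose=0)
preds = model.predict(x, verbose=0)
print(preds.shape)  # (32, 5)
```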

XLA can also be used to accelerate arbitrary functions wrapped with tf.function, not only Keras models.
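As a minimal sketch, assuming a TensorFlow 2.x install, the function below is a hypothetical hand-written dense layer compiled with XLA through tf.function(jit_compile=True). XLA may fuse the matrix multiplication, bias add, and ReLU into fewer kernels, although the actual gain depends on the hardware and input sizes.

```python
import tensorflow as tf

# jit_compile=True compiles this plain Python function with XLA.
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
    # Matmul, bias add, and ReLU are candidates for XLA operator fusion.
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal((4, 3))
w = tf.random.normal((3, 2))
b = tf.zeros((2,))

y = dense_relu(x, w, b)
print(y.shape)  # (4, 2)
```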

Models with a TensorFlow implementation, such as GPT2, T5, OPT, and Whisper, are XLA compatible. The speedup depends on the model, but TensorFlow models in Transformers can run up to ~100x faster with XLA.

Functions

A typical forward pass in a TensorFlow model is shown below. To run a forward pass with XLA, wrap the model with tf.function and set jit_compile=True.

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)
# Generate random inputs for the model.
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))

# Run a forward pass without XLA.
_ = model(random_inputs)

# Run a forward pass with XLA by wrapping the model with tf.function.
xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs)

By default, the model's call function is what gets compiled into the XLA graph. To compile any other model function with XLA, wrap it with tf.function.

# model.my_xla_fn stands in for whichever model method you want to compile.
my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)
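A minimal sketch of compiling a method other than the default call is shown below. The Scorer module and its score method are hypothetical stand-ins for model.my_xla_fn; the pattern of passing a bound method to tf.function with jit_compile=True is the part that carries over.

```python
import tensorflow as tf

class Scorer(tf.Module):
    """Hypothetical module whose score method is not its default call."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal((10, 1)))

    def score(self, x):
        # Map each input row to a probability-like score.
        return tf.sigmoid(tf.matmul(x, self.w))

scorer = Scorer()

# Wrap the bound method, just like model.my_xla_fn above.
xla_score = tf.function(scorer.score, jit_compile=True)
out = xla_score(tf.random.normal((16, 10)))
print(out.shape)  # (16, 1)
```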

Text generation

Other model functions can be compiled with XLA as well. For example, enable XLA for text generation by wrapping generate() with tf.function.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
# Will error if the minimal version of Transformers is not installed.
from transformers.utils import check_min_version

check_min_version("4.21.0")

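tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left")
# GPT-2 has no padding token, so reuse its end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token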
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]

xla_generate = tf.function(model.generate, jit_compile=True)

tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)

decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
# Generated -- TensorFlow is an open-source, open-source, distributed-source application framework for the

Tracing

The first time an XLA-enabled function is executed, TensorFlow infers its computation graph in a process known as tracing, and XLA then compiles that graph. This is a time-consuming step, but subsequent calls to the function are much faster because the graph doesn't have to be traced and compiled again.

To ensure a function is only traced once, the inputs must have the same shape as when the graph was built. This usually isn’t an issue for fixed input shapes like images, but it can be an issue for inputs with variable shapes like text.
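One way to observe tracing is through Python side effects, which run only while TensorFlow traces a function and not when it reuses the compiled graph. The square_sum function and trace_log list below are hypothetical, for illustration: a second call with the same shape reuses the graph, while a new shape triggers another trace.

```python
import tensorflow as tf

trace_log = []

@tf.function(jit_compile=True)
def square_sum(x):
    # This Python line runs only during tracing, so trace_log
    # records one entry per trace.
    trace_log.append(x.shape)
    return tf.reduce_sum(x * x)

square_sum(tf.ones((2, 3)))   # first call: traces and compiles
square_sum(tf.zeros((2, 3)))  # same shape and dtype: reuses the graph
print(len(trace_log))         # 1

square_sum(tf.ones((4, 3)))   # new shape: traces again
print(len(trace_log))         # 2
```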

One way to handle this is to pad your text so inputs fall into a small, fixed set of shapes. Configure padding options such as pad_to_multiple_of in the tokenizer, which rounds the padded sequence length up to a multiple of the given value.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

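tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left")
# GPT-2 has no padding token, so reuse its end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token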
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]

xla_generate = tf.function(model.generate, jit_compile=True)

# Call tokenizer with padding options.
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")

generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
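To see why pad_to_multiple_of limits tracing, the sketch below mimics how it rounds the length of a single sequence up to the next multiple. padded_length is a hypothetical helper written for illustration, not a tokenizer API, and the count of traces assumes the batch size and generation options stay fixed.

```python
import math

def padded_length(seq_len: int, multiple: int) -> int:
    # Round seq_len up to the nearest multiple of `multiple`.
    return math.ceil(seq_len / multiple) * multiple

# Prompt lengths from 1 to 32 tokens.
lengths = range(1, 33)
unpadded_shapes = set(lengths)
padded_shapes = {padded_length(n, 8) for n in lengths}

print(len(unpadded_shapes))   # 32 distinct shapes, so up to 32 traces
print(sorted(padded_shapes))  # [8, 16, 24, 32], so at most 4 traces
```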

In addition to the input shape, changing any of the generation options also triggers tracing.
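A rough sketch of why this happens: Python values passed to a tf.function become part of its trace signature, so a new value forces a new trace. The scaled_sum function below is hypothetical, and the sketch assumes generation options such as num_beams behave like ordinary Python arguments to the wrapped generate().

```python
import tensorflow as tf

trace_log = []

@tf.function(jit_compile=True)
def scaled_sum(x, scale):
    # Runs only during tracing; scale is a Python int, not a tensor.
    trace_log.append(scale)
    return tf.reduce_sum(x) * scale

x = tf.ones((2, 4))
scaled_sum(x, 2)           # traces for scale=2
scaled_sum(x, 2)           # same Python value: reuses the graph
result = scaled_sum(x, 3)  # new Python value: traces again

print(len(trace_log))  # 2
print(float(result))   # 24.0
```

Keeping generation options fixed across calls, as with input shapes, lets the compiled graph be reused.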

Resources

Learn more about XLA with the following resources.
