# XLA
Accelerated Linear Algebra (XLA) is a linear algebra compiler that optimizes model runtime across different hardware and frameworks.
This guide will look specifically at how to accelerate TensorFlow models with XLA.
## TensorFlow
XLA can potentially accelerate a TensorFlow model without making any source code changes. It is already packaged with the TensorFlow library, and it is triggered with `jit_compile` in any graph-creating function such as `tf.function`.
If you're using Keras methods like `fit` and `predict`, enable XLA by passing `jit_compile=True` to `compile`.

```py
model.compile(jit_compile=True)
```
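As a minimal sketch of the full flow, the toy Keras model and random data below (both illustrative, not part of Transformers) show training and prediction with XLA enabled:

```py
import tensorflow as tf

# Toy model and random data, purely for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse", jit_compile=True)

x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))

model.fit(x, y, epochs=1)  # training steps run as XLA-compiled graphs
_ = model.predict(x)       # prediction runs through XLA as well
```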
XLA can be used to accelerate any arbitrary `tf.function`.

Models with a TensorFlow implementation like GPT2, T5, OPT, and Whisper are XLA compatible. The speedup depends on the model, but in general, TensorFlow models in Transformers get a ~100x speedup.
### Functions
A typical forward pass in a TensorFlow model is shown below. To run a forward pass with XLA, wrap the model with `tf.function` and set `jit_compile=True`.
```py
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)
# Generate random inputs for the model.
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))

# Run a forward pass. Instead of calling the model directly with
# `_ = model(random_inputs)`, wrap it in tf.function with jit_compile=True.
xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs)
```
The model's default `call` function is used to compile the XLA graph. But if you want to compile any other model function with XLA, wrap it with `tf.function`.
```py
my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)
```
### Text generation
You can also compile other model functions with XLA. For example, enable XLA for text generation by wrapping `generate()` with `tf.function`.
```py
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Will error if the minimal version of Transformers is not installed.
from transformers.utils import check_min_version

check_min_version("4.21.0")

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]

xla_generate = tf.function(model.generate, jit_compile=True)

tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)

decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
"Generated -- TensorFlow is an open-source, open-source, distributed-source application framework for the"
```
### Tracing
When an XLA-enabled function is executed for the first time, TensorFlow infers the computation graph in a process known as tracing. This is a time-consuming step, but any consecutive calls to the function are much faster because the computation graph doesn't have to be traced again.

To ensure a function is only traced once, the inputs must have the same shape as when the graph was built. This usually isn't an issue for fixed input shapes like images, but it can be an issue for inputs with variable shapes like text, as the sketch below demonstrates.
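The behavior is easy to observe with a small standalone function (a toy sketch, not a Transformers API). The Python body, including the `print`, only runs while tracing, and a new input shape forces a new trace:

```py
import tensorflow as tf

@tf.function(jit_compile=True)
def double(x):
    print("Tracing!")  # executes during tracing only, not on compiled calls
    return x * 2

double(tf.constant([1.0, 2.0]))       # prints "Tracing!" and compiles
double(tf.constant([3.0, 4.0]))       # same shape, reuses the compiled graph
double(tf.constant([1.0, 2.0, 3.0]))  # new shape, prints "Tracing!" again
```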
One way to handle this is to pad your text so it always has the same shape. Configure padding options such as `pad_to_multiple_of` in the tokenizer.
```py
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]

xla_generate = tf.function(model.generate, jit_compile=True)

# Call tokenizer with padding options.
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")

generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
```
In addition to the input shape, a change to any of the generation options at any point also triggers retracing.
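Continuing with the same `xla_generate` function from above, changing `num_beams` produces a new graph even though the input shape is unchanged:

```py
# num_beams differs from the earlier calls, so this call traces and compiles again.
generated_tokens = xla_generate(**tokenized_input, num_beams=4)
```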
## Resources
Learn more about XLA with the following resources.
- A notebook demonstrating XLA-compatible encoder-decoder and decoder-only text generation models.
- The Faster Text Generation with TensorFlow and XLA blog post compares benchmarks for XLA-compatible models and provides a friendly introduction to XLA in TensorFlow.
- The How HF中国镜像站 improved Text Generation performance with XLA blog post discusses the design philosophy behind adding XLA to TensorFlow models in Transformers.
- The Introduction to graphs and tf.function guide.
- The Better performance with tf.function guide.
- The XLA documentation.