Part III · Inference Internals & Production Serving
Chapter 40

Model optimization and compilation

"Your model is already correct. The question is whether it's fast enough to matter."

In Chapter 26 we reduced weight size with quantization. In Chapter 38 we saw how custom kernels extract hardware performance. This chapter is the layer that sits between them: the compilation and optimization stack that transforms a Python model into something the hardware can run at maximum efficiency. Quantization changes what you compute. Kernels change how primitives run. Compilation changes how those primitives are assembled, fused, scheduled, and dispatched.

The tools in this chapter — torch.compile, ONNX, TensorRT, XLA — are the difference between a model that runs and a model that runs fast. They’re often misunderstood as competing options when they’re really a ladder: each tool closes a different gap between what the math requires and what the hardware actually does.

By the end of this chapter you will be able to:

  • Explain the 4-stage compilation pipeline and where each optimizer fits.
  • Apply torch.compile to get a free 1.5-2× speedup with one line of code.
  • Export a model to ONNX and run it on ONNX Runtime.
  • Decide when TensorRT is worth the engineering cost.
  • Identify graph breaks, fusion failures, and shape problems in compiled models.
  • Choose the right optimizer for your deployment constraint.

Outline:

  1. Why your model is slower than it should be.
  2. Graph-level optimizations.
  3. torch.compile: the modern PyTorch optimizer.
  4. torch.export and AOTAutograd.
  5. ONNX: the interchange format.
  6. TensorRT: NVIDIA’s inference optimizer.
  7. XLA and JAX compilation.
  8. Operator fusion in practice.
  9. Dynamic shapes and the compilation problem.
  10. The optimization decision tree.
  11. Profiling before and after.
  12. Common pitfalls.
  13. The mental model.

40.1 Why your model is slower than it should be

Start with a fact: the GPU is underutilized almost all the time.

A 70B Llama model on H100 has a theoretical compute ceiling of 989 TFLOPs/s for BF16. A typical inference run achieves 15-30% of that. The rest is wasted on:

  • Kernel launch overhead. Each PyTorch op dispatches a separate CUDA kernel. Launch overhead is ~5-10 µs per kernel. A single forward pass of a large model can dispatch thousands of kernels. That’s milliseconds gone before the GPU does any math.
  • Memory round-trips. Elementwise operations (GELU, bias add, residual) that could share a register tile instead write to HBM, wait, and read back. A typical elementwise op achieves ~5% of peak FLOP/s because the memory round-trip dominates. Fusion (Section 40.8) is the fix.
  • Operator sub-optimality. PyTorch’s eager mode dispatches one operation at a time. The dispatcher can’t see that layernorm → linear → gelu → linear could be a single optimized pass. It calls four separate kernels, each with its own launch overhead and HBM traffic.
  • Graph-level redundancy. Constant subexpressions are recomputed. Unused branches execute. Layouts get transposed unnecessarily between ops.

The optimization stack exists to close these gaps. It does so through a 4-stage pipeline:

The 4-stage model compilation pipeline: Python eager code becomes a graph IR, which is optimized, which is lowered to hardware-specific kernels. Stage 1, Python eager: slow, dispatches one kernel per op. Stage 2, graph IR (FX / HLO): the compiler sees the full program and can run graph passes. Stage 3, optimized graph: fused ops, constant folding, optimized layouts. Stage 4, hardware kernels (Triton / TensorRT): maximum occupancy, low launch overhead. torch.compile takes you from stage 1 to stage 4 automatically; TensorRT maximizes the stage 3 → 4 transition for NVIDIA.
Each stage closes a different performance gap — the graph IR enables global optimizations that Python dispatch cannot; hardware-specific kernels extract the last bits of throughput.

The numbers matter. On a typical transformer layer in BF16:

Optimization                     | Representative gain
Operator fusion (elementwise)    | 2-4× for the fused ops
torch.compile (overall model)    | 1.5-2× end-to-end
TensorRT (vs PyTorch eager)      | 2-4× end-to-end
FlashAttention (vs naive attn)   | 2-8× (sequence-length dependent)

These stack partially, not fully — you can’t apply all four and get 50×. The actual bottleneck shifts as you optimize. But the first pass through this stack is often the highest-return engineering work in a deployment project.
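Why the gains don't multiply: each optimization only accelerates the fraction of runtime it touches. An Amdahl-style sketch makes this concrete (the fractions here are invented for illustration):

```python
def speedup(fraction_affected: float, local_speedup: float) -> float:
    """Amdahl's law: overall speedup when only part of the runtime is accelerated."""
    return 1.0 / ((1.0 - fraction_affected) + fraction_affected / local_speedup)

# Suppose elementwise ops are 30% of runtime and fusion makes them 3x faster:
s = speedup(0.30, 3.0)
print(f"overall: {s:.2f}x")   # well below the local 3x

# Even an infinite speedup on that 30% caps out at 1 / 0.7:
cap = speedup(0.30, 1e9)
print(f"ceiling: {cap:.2f}x")
```

This is why the bottleneck "shifts as you optimize": once the fused ops are fast, the remaining 70% dominates and a different tool is needed.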

40.2 Graph-level optimizations

A graph IR (intermediate representation) is a directed acyclic graph where nodes are operations and edges are tensors. Once you have the graph, you can run optimization passes over it. These passes are cheap to apply and often “free” performance.

Constant folding. If a subgraph depends only on constants (not on runtime inputs), compute it at compile time. Example: a mask computed from a fixed sequence length doesn’t need to run at inference time — fold it into a constant tensor. In practice this eliminates initializer computations, repeated arange calls, and any static shape arithmetic.

Dead code elimination (DCE). Remove ops whose outputs are never used. In a model with optional features (e.g., a training-mode dropout that’s disabled at inference), DCE removes the dead branches. PyTorch’s FX graph representation makes this trivial: traverse the graph, mark used nodes, delete unreachable ones.

Common subexpression elimination (CSE). If the same computation appears twice (e.g., x * scale computed in two branches), compute it once and share the result. LLMs often have this in attention variants where Q, K, V share a normalization step.

  • Operator fusion. The big one. Covered separately in Section 40.8. Instead of running 3 separate kernels (elementwise add, GELU, elementwise mul), emit a single kernel that does all three in registers without touching HBM. Free for the same math; 2-10× faster in practice.

Layout optimization. Convolution frameworks obsess over NCHW vs NHWC (channels first vs channels last). Transformers care less about layout, but the same principle applies: some kernels expect row-major, some expect column-major, some run faster on a particular memory layout. If the graph optimizer can choose layouts once and route ops to match, you avoid expensive transpose ops between layers. In PyTorch, the memory_format=torch.channels_last hint engages this. For LLMs the relevant equivalent is choosing whether to keep the weight matrix in row or column major depending on whether it’s the left or right operand of the GEMM.

Shape propagation and specialization. If you compile for a fixed input shape, the optimizer knows exact tensor dimensions at compile time. It can choose tile sizes, unroll loops, and select specialized kernels. Static shapes are faster; dynamic shapes are more general. The tension is the subject of Section 40.9.

Why don’t frameworks apply all these automatically in eager mode? Because eager mode processes one op at a time. You can’t do CSE across ops you haven’t seen yet. You can’t do layout optimization without a global view of the graph. Graph capture is the prerequisite for graph optimization.
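To make the passes concrete, here is a toy graph IR with constant folding and dead-code elimination over it. This is a teaching sketch with an invented node format, not the real torch.fx data structures:

```python
# Toy graph IR: each node is (op, inputs). "const" nodes carry a value.
graph = {
    "a":    ("const", 2),
    "b":    ("const", 3),
    "c":    ("mul", ["a", "b"]),   # depends only on constants -> foldable
    "x":    ("input", None),
    "y":    ("add", ["x", "c"]),
    "dead": ("mul", ["x", "x"]),   # output never used -> DCE removes it
}
outputs = ["y"]

def constant_fold(g):
    """Replace ops whose inputs are all constants with a const node."""
    ops = {"mul": lambda p, q: p * q, "add": lambda p, q: p + q}
    changed = True
    while changed:
        changed = False
        for name, (op, args) in list(g.items()):
            if op in ops and all(g[a][0] == "const" for a in args):
                vals = [g[a][1] for a in args]
                g[name] = ("const", ops[op](*vals))
                changed = True
    return g

def dce(g, outs):
    """Keep only nodes reachable from the outputs."""
    live, stack = set(), list(outs)
    while stack:
        n = stack.pop()
        if n in live:
            continue
        live.add(n)
        _, args = g[n]
        if isinstance(args, list):
            stack.extend(args)
    return {n: g[n] for n in g if n in live}

g = dce(constant_fold(dict(graph)), outputs)
print(g)  # "c" folded to const 6; "dead" (and the now-unused "a", "b") are gone
```

Note the ordering: folding `c` to a constant is what makes `a` and `b` dead, which is why compilers run these passes repeatedly until a fixed point.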

40.3 torch.compile: the modern PyTorch optimizer

torch.compile is PyTorch’s answer to the question: “can we get TensorFlow graph-mode speed without making users write graphs?” The answer, since PyTorch 2.0, is mostly yes.

The one-liner:

model = torch.compile(model)

That’s it. Under the hood, three systems activate:

TorchDynamo is the graph capturer. It hooks into Python’s bytecode evaluation, intercepts PyTorch operations, and traces them into an FX graph (the torch.fx intermediate representation). Critically, it handles Python-level control flow: if your model has an if statement, Dynamo either specializes the compiled path (generating separate graphs for each branch) or falls back to eager execution at that point (a “graph break”).

AOTAutograd (Ahead-of-Time Autograd) partitions the graph into forward and backward subgraphs before execution. This allows the backward pass to be compiled and optimized as well. For inference-only use, AOTAutograd is less critical, but for training it’s what makes the backward pass fast.

TorchInductor is the code generator. It takes the FX graph, runs graph-level optimizations (fusion, CSE, loop tiling), and emits Triton kernels for GPU (or C++ for CPU). This is where the speed comes from. Inductor fuses elementwise operations, chooses memory layouts, and generates code specialized to your shapes.

torch.compile architecture: TorchDynamo intercepts Python bytecode and traces the model into an FX graph; AOTAutograd partitions the joint graph into forward and backward subgraphs; TorchInductor applies fusion and tiling and emits Triton (GPU) or C++ (CPU) kernels, which are cached per input shape and dtype. A graph break forces a fallback to eager Python execution at that point.
torch.compile captures the graph via TorchDynamo's bytecode inspection and lowers it to Triton kernels via TorchInductor — a graph break forces a fallback to Python execution and ends the compiled region.

Modes

torch.compile has three modes that trade compilation time for runtime speed:

model = torch.compile(model)                               # default
model = torch.compile(model, mode="reduce-overhead")       # minimize launch overhead
model = torch.compile(model, mode="max-autotune")          # grid-search tile sizes

default: Fuses elementwise ops and does basic optimizations. Compilation takes 30-60 seconds. Good for most cases.

reduce-overhead: Uses CUDA Graphs (Chapter 39) to eliminate kernel launch overhead. Best for small batch sizes where launch overhead is the dominant cost. Requires static shapes.

max-autotune: Profile-guided tile selection for GEMM and attention. Compilation takes minutes. Best for production serving of a fixed model at fixed shapes.

Graph breaks

A graph break occurs when Dynamo encounters something it can’t trace:

  • Data-dependent control flow: if x.sum() > 0: — the condition depends on a runtime value, so Dynamo can’t trace a single static graph.
  • Unsupported operations: some PyTorch ops, most torch.Tensor-to-Python conversions, print statements.
  • Python-level side effects: writing to external state, calling non-PyTorch functions.

Each break forces Dynamo to restart graph capture from that point, creating multiple small graphs instead of one large one. Small graphs have fewer fusion opportunities. To see breaks:

torch._dynamo.explain(model)(inputs)

To avoid breaks: keep control flow static, avoid .item() calls (which convert tensors to Python scalars and force a break), and compile with torch.compile(model, fullgraph=True), which raises an error at the first break and names its cause instead of silently falling back to eager.

When torch.compile helps and when it doesn’t

Helps most:

  • Custom model architectures without optimized kernels.
  • Training loops (both forward and backward benefit).
  • Elementwise-heavy ops (the fusion wins are real).
  • Small models where launch overhead is significant.

Helps less:

  • Models already using FlashAttention and cuBLAS (those kernels are already optimal; Inductor won’t beat them).
  • Batches with highly variable shapes (recompilation kills the benefit).
  • Already-optimized serving stacks like vLLM (they’ve replaced PyTorch’s default dispatch for the hot paths).

40.4 torch.export and AOTAutograd

torch.compile is JIT (just-in-time): the model gets compiled on first run. torch.export is AOT (ahead-of-time): it exports a serializable graph that can be saved, inspected, and passed to downstream tooling.

exported = torch.export.export(model, args=(example_input,))
# exported is a serializable ExportedProgram
torch.export.save(exported, "model.pt2")

The difference matters for deployment:

  • Tracing (how early PyTorch JIT worked) records the computation path of a single input. If control flow produces different paths for different inputs, tracing only captures one.
  • Capturing (how TorchDynamo works) symbolically traces the graph, representing shapes as symbolic variables and preserving branches. The result is a graph that’s correct for a range of inputs.

torch.export uses capturing, not tracing. This means:

  1. You get a graph that faithfully represents the model’s logic (modulo graph breaks, which become explicit guards).
  2. The graph can be shipped to TensorRT, ONNX, or mobile runtimes.
  3. The graph is inspectable — you can read the FX IR, verify what’s being computed, and debug.

Dynamic shapes in export

By default, torch.export specializes for the exact input shapes provided. If you want to export for variable sequence lengths:

from torch.export import Dim

seq_len = Dim("seq_len", min=1, max=2048)
exported = torch.export.export(
    model,
    args=(example_input,),
    dynamic_shapes={"input_ids": {1: seq_len}}
)

This marks dimension 1 as dynamic. The exported graph includes guards that validate the input at runtime and shape-polymorphic kernels that work for any valid length. It’s slightly slower than static but handles variable-length inference correctly.

40.5 ONNX: the interchange format

ONNX (Open Neural Network Exchange) is a standardized intermediate representation for neural network models. The pitch: train in PyTorch, export to ONNX, run on any runtime that supports ONNX — which includes Microsoft’s ONNX Runtime, NVIDIA’s TensorRT (via the ONNX parser), mobile runtimes, and dozens of others.

ONNX is a directed graph of operators (nodes) connected by typed tensors (edges). Each operator has a standardized semantics defined in the ONNX spec. There are ~160 standard operators covering all common neural network operations. Third-party ops get wrapped as “custom ops.”

Exporting from PyTorch

import torch.onnx

torch.onnx.export(
    model,
    (example_input,),
    "model.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq_len"},
                  "logits": {0: "batch", 1: "seq_len"}},
    opset_version=18,
)

The dynamic_axes argument marks dimensions that will vary at runtime. Omitting this exports a static model that only accepts exactly the shapes used during export.

The opset versioning problem

Every ONNX release defines a new opset version. Opset 18 (2023) differs from opset 17, which differs from opset 16. Each PyTorch operation maps to an ONNX opset-N operator. When PyTorch updates an op or adds a new one, it may not yet have an ONNX mapping in the current opset. The result: you try to export a model that uses a new PyTorch op (like a custom attention variant), and the ONNX exporter fails with “operator not supported in opset 17.”

The workarounds:

  1. Upgrade to the latest opset (18+).
  2. Register a custom op in ONNX.
  3. Use torch.export + a dedicated ONNX dynamo exporter (the new path, still maturing as of early 2026).

Opset versioning is the single biggest friction in the ONNX ecosystem. Budget time for it.

ONNX Runtime

ONNX Runtime (ORT) is Microsoft’s optimizing inference runtime for ONNX models. It runs graph-level optimization passes (similar to those in Section 40.2), picks hardware-specific execution providers (CUDA, TensorRT, OpenVINO, CoreML, CPU), and dispatches to the best available kernel.

import onnxruntime as ort

sess = ort.InferenceSession("model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
output = sess.run(["logits"], {"input_ids": input_ids_np})

ORT’s optimization levels (set via GraphOptimizationLevel):

  • ORT_ENABLE_BASIC: DCE, constant folding, shape inference.
  • ORT_ENABLE_EXTENDED: adds op fusion and memory layout optimization.
  • ORT_ENABLE_ALL: platform-specific optimizations; the graph may become non-portable.

ORT typically gives 1.5-2× speedup over vanilla PyTorch eager for standard transformer ops, through fusion and the CUDA execution provider. It won’t beat a hand-optimized TensorRT engine but gets you 70-80% of the way there without the TensorRT engineering investment.

When ONNX makes sense

Use ONNX when:

  • You need to run on non-NVIDIA hardware (AMD ROCm, Intel OpenVINO, ARM, Apple CoreML).
  • You need model portability across environments (train on NVIDIA, deploy on CPU, run on mobile).
  • Your model is standard transformer architecture without exotic ops.
  • You want a runtime-agnostic deployment artifact.

Skip ONNX when:

  • You’re targeting cutting-edge LLM features: paged attention, speculative decoding, rotary embeddings with fused kernels. These ops either have no ONNX mapping or map to slow implementations.
  • You’re on NVIDIA hardware and want maximum performance (use TensorRT directly).
  • Your model changes frequently (ONNX export is not a fast iteration loop).

40.6 TensorRT: NVIDIA’s inference optimizer

TensorRT is NVIDIA’s inference optimizer. It takes a model (from ONNX, from PyTorch, or from its own API), runs an aggressive optimization pass, and builds an engine — a compiled, hardware-specific binary that runs as fast as possible on a specific GPU.

The TensorRT workflow has three phases:

Parse. Import the model from ONNX or a native TensorRT network definition API.

Optimize. This is TensorRT’s value. It runs:

  • Layer fusion: fuse Conv → BN → ReLU into one kernel; fuse multi-head attention projections.
  • Precision calibration: convert to FP16 or INT8 automatically, with optional calibration data to minimize accuracy loss.
  • Kernel auto-tuning: for each op and shape, profile multiple kernel implementations and pick the fastest one on the target GPU.
  • Memory layout optimization: choose formats that minimize memory traffic.
  • CUDA Graph integration: wrap the compiled graph in a CUDA Graph for low-latency repeated execution.

Build engine. Serialize the optimized, hardware-specific engine to disk.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
parser.parse_from_file("model.onnx")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # enable FP16
engine = builder.build_serialized_network(network, config)
with open("model.trt", "wb") as f:
    f.write(engine)

At inference:

runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(open("model.trt", "rb").read())
context = engine.create_execution_context()
# set input/output bindings, then:
context.execute_v2(bindings)

Precision calibration

TensorRT’s INT8 and FP8 calibration is one of its most powerful features. You provide a calibration dataset (a few hundred representative inputs), and TensorRT measures activation distributions to choose per-layer scales that minimize quantization error. Compare to Chapter 26: this is automatic post-training quantization done inside the optimizer, rather than requiring a separate AWQ or GPTQ pass. The quality is often comparable to manual calibration if the calibration data is representative.
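The core idea behind calibration is small: observe activation magnitudes on representative data and derive a per-layer scale. A deliberately simplified sketch of that idea — real TensorRT calibration uses entropy-based methods to pick clipping thresholds, not the plain abs-max shown here:

```python
def calibrate_scale(activations, qmax=127):
    """Abs-max calibration: map the observed range onto int8 [-127, 127]."""
    amax = max(abs(v) for v in activations)
    return amax / qmax

def quantize(x, scale, qmax=127):
    q = round(x / scale)
    return max(-qmax, min(qmax, q))   # clamp to the int8 range

def dequantize(q, scale):
    return q * scale

# "Calibration data": pretend these are sampled activations for one layer.
acts = [0.02, -1.5, 0.7, 3.1, -2.9, 0.4]
scale = calibrate_scale(acts)

# Round-trip error stays within half a quantization step for in-range values:
for x in acts:
    err = abs(dequantize(quantize(x, scale), scale) - x)
    assert err <= scale / 2
print(f"scale = {scale:.5f}")
```

The reason calibration data must be representative is visible here: the scale is set by the observed range, so unseen outliers at inference time get clipped, and an unrepresentative range wastes quantization levels.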

The engine is hardware-specific

This is the most important caveat about TensorRT. A .trt engine built on an A100 does not run on an H100. It must be rebuilt. The kernel auto-tuning phase picks operations specific to the GPU architecture (SM version, Tensor Core generation). Portability is zero.

Concretely: if you build your TensorRT engine on your development machine (say, a 4090) and then deploy to a datacenter A100, you’ll get an error. You must build the engine in the same environment where it will run, or build separate engines for each target GPU type.

TensorRT-LLM

TensorRT-LLM is NVIDIA’s LLM serving layer built on TensorRT. It adds:

  • LLM-specific ops: PagedAttention, continuous batching, speculative decoding.
  • FP8 support with Hopper-specific kernels.
  • CUDA Graph integration for low-latency decode.
  • In-flight batching support.

TensorRT-LLM is often the fastest available serving stack for NVIDIA hardware. The costs: it’s complex to deploy, the API differs from PyTorch, models must be converted to TensorRT-LLM’s format, and model support lags behind PyTorch. As of early 2026 it supports Llama, Mistral, Mixtral, Qwen, Falcon, and a handful of others — but not every architecture.

When to use TensorRT

Use TensorRT when:

  • You’re on NVIDIA hardware and need maximum throughput or minimum latency.
  • The model is stable (not changing weekly).
  • You can afford the build time (large models take 30-60 minutes to build an engine).
  • You’ve exhausted torch.compile and ONNX Runtime.

Skip TensorRT when:

  • You’re not on NVIDIA hardware.
  • The model changes frequently (every rebuild is 30-60 min).
  • You’re using cutting-edge features (vLLM or SGLang likely have better-tuned kernels for those).
  • You’re in early development — the iteration cycle is too slow.

40.7 XLA and JAX compilation

XLA (Accelerated Linear Algebra) is Google’s compiler for ML workloads. It takes a high-level IR called HLO (High Level Operations) and compiles it to optimized kernels for GPU, TPU, or CPU.

XLA is the compiler underneath:

  • JAX (by default: every JAX operation runs through XLA).
  • TensorFlow 2 (XLA mode via @tf.function).
  • PyTorch/XLA (a bridge that routes PyTorch ops through XLA, targeting TPUs and GPU).

How XLA differs from TorchInductor

Both do graph-level optimization and kernel generation, but the design philosophies differ:

                 | XLA                                     | TorchInductor
Scope            | Whole-program compilation               | Per-subgraph
IR               | HLO (high-level)                        | FX graph
Fusion           | Aggressive, across the entire program   | Within captured regions
Target           | GPU, TPU, CPU                           | GPU (Triton), CPU (C++)
Kernel strategy  | Emit fused LLVM/PTX from scratch        | Call Triton for GPU; LLVM for CPU
Shape handling   | Requires static shapes or shape inference | Dynamic shapes via symbolic dims

XLA’s “whole-program” approach is more aggressive: it can fuse operations across transformer layers, hoist loop-invariant computations, and perform rematerialization (recomputing rather than storing activations to save memory). For training on TPUs this is the dominant approach. For GPU inference it tends to be slower than TorchInductor+Triton because XLA doesn’t have the same depth of GPU-specific kernel tuning.

JAX’s jit-by-default

JAX’s value proposition: functions decorated with @jax.jit are traced once (with abstract shape-carrying values) and compiled via XLA. The result is cached and reused. There’s no eager mode performance penalty in production because you’re expected to jit everything.

@jax.jit
def forward(params, x):
    # all ops run through XLA
    return model.apply(params, x)

First call: tracing + compilation (~seconds). Subsequent calls: cached kernel dispatch (~microseconds overhead).

JAX’s compilation model is clean. Its functional API is easy to reason about. The research community has embraced it (Gemini’s internals are JAX-based, as is much of Google Brain’s training infrastructure).

But production LLM serving on NVIDIA GPUs in 2025-26 is dominated by PyTorch-based stacks (vLLM, SGLang, TGI) for a practical reason: the ecosystem of LLM-specific kernels (FlashAttention, PagedAttention, Marlin INT4) is developed for PyTorch/CUDA first. JAX/XLA bindings come later if at all. Running JAX in production serving means either writing your own JAX-compatible implementations of these kernels (non-trivial) or accepting slower fallback implementations.

The exception: TPU serving, where JAX + XLA is the only real option and the ecosystem is correspondingly richer.

40.8 Operator fusion in practice

Fusion is the single most impactful optimization in the compilation stack. It deserves its own section.

The mechanism: instead of 3 separate kernels that each read from and write to HBM, emit 1 fused kernel that loads the input once, applies all three operations in registers, and writes the output once.

Operator fusion: before fusion, each kernel in a three-op elementwise chain (e.g., bias_add → gelu → dropout) reads its input from HBM and writes its output back — three kernel launches and six HBM transfers in total. After fusion, one kernel loads x once, applies all three operations in registers, and writes the output once: one launch, two HBM transfers, and a 2-4× speedup for the chain.
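The arithmetic behind that picture, as a tiny memory-traffic model (BF16, 2 bytes per element; counts are per element of the tensor):

```python
BYTES = 2  # BF16

def traffic_unfused(num_ops):
    """Each chained elementwise kernel reads its input and writes its output."""
    return num_ops * 2 * BYTES   # 1 read + 1 write per op, per element

def traffic_fused():
    """The fused kernel reads the input once and writes the result once."""
    return 2 * BYTES

chain = 3  # e.g. bias_add -> gelu -> dropout
print(f"unfused: {traffic_unfused(chain)} B/elem, fused: {traffic_fused()} B/elem")
print(f"traffic reduction: {traffic_unfused(chain) / traffic_fused():.0f}x")
# For memory-bound ops, runtime tracks bytes moved, so the fused chain runs
# roughly 3x faster -- inside the 2-4x range quoted earlier in this chapter.
```

The model ignores launch overhead entirely; in practice the fused kernel also saves two launches, which is why short chains at small batch sizes can beat the pure bandwidth ratio.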

Common fusion targets in transformers

Elementwise chains. bias_add → GELU → dropout is three ops in eager mode. Fused, it’s one kernel. TorchInductor does this automatically. So does the Liger-kernel library, with hand-tuned Triton versions.

Attention fusion (FlashAttention). Chapter 25 covers this in depth. The key point here: Q @ K.T → scale → mask → softmax → @ V is fused into a single kernel that tiles over the sequence length, keeping the S × S attention score matrix in SRAM instead of materializing it in HBM. This is the single largest memory saving in transformer inference — the O(s²) activation is never written to HBM.

QKV projection. Instead of three separate GEMMs for Q, K, V projections, use one GEMM with a wider output matrix: X @ [W_Q | W_K | W_V] → split. This reduces kernel launch overhead and improves GEMM efficiency (wider GEMMs hit higher Tensor Core utilization).
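The QKV trick is pure bookkeeping: concatenating the three weight matrices column-wise gives one wide GEMM whose output splits back into Q, K, V. A pure-Python check with toy matrices (the actual win — one launch and higher Tensor Core utilization — only shows up on a GPU):

```python
def matmul(A, B):
    """Naive matrix multiply, for small illustration matrices only."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def hcat(*mats):
    """Concatenate matrices along columns: [W_Q | W_K | W_V]."""
    return [sum(rows, []) for rows in zip(*mats)]

X   = [[1, 2], [3, 4]]   # (tokens, d_model)
W_Q = [[1, 0], [0, 1]]
W_K = [[2, 0], [0, 2]]
W_V = [[0, 1], [1, 0]]

# Three separate GEMMs:
Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)

# One wide GEMM, then split the output columns:
QKV = matmul(X, hcat(W_Q, W_K, W_V))
Q2 = [row[0:2] for row in QKV]
K2 = [row[2:4] for row in QKV]
V2 = [row[4:6] for row in QKV]

assert (Q, K, V) == (Q2, K2, V2)
print("fused QKV projection matches the three separate projections")
```

Because the math is identical, this fusion is always safe; the only cost is the split, which is a view (no copy) in frameworks with strided tensors.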

LayerNorm fused with matmul input. Some serving stacks fuse the RMSNorm before an attention layer with the first GEMM in that layer, reducing the number of HBM passes over the activation.

How to check if fusion happened

The torch.profiler trace is the ground truth. If you see separate CUDA kernel calls for bias_add_kernel, gelu_kernel, etc., fusion didn’t happen. If you see a single Triton or cuDNN fused kernel covering all three, it did.

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
    output = model(input)
print(p.key_averages().table(sort_by="cuda_time_total"))

In Nsight Systems (nsys), look at the CUDA kernel timeline. Fused ops appear as single wider bars; unfused ops appear as many thin bars with small gaps (kernel launch overhead) between them.

40.9 Dynamic shapes and the compilation problem

LLM inference is inherently dynamic. Sequence lengths vary. Batch sizes vary. In speculative decoding, the number of accepted tokens per step varies. This creates a fundamental tension with static-shape compilers.

The static-shape world (TensorRT default):

TensorRT builds an engine for specific input shapes. If your model is deployed for variable-length requests, you have two options:

  1. Padding: always pad inputs to a fixed maximum length (e.g., 2048 tokens). Simple, but wasteful — a 10-token request runs at 2048-token cost.
  2. Multiple engine profiles: build separate engine profiles for a set of shape buckets (e.g., [1-128, 129-512, 513-2048]). TensorRT selects the smallest profile that fits the input. More complex, but avoids waste.

TensorRT’s optimization profile API:

profile = builder.create_optimization_profile()
profile.set_shape("input_ids",
    min=(1, 1),    # (batch, seq_len) min
    opt=(1, 512),  # optimal
    max=(8, 2048)  # max
)
config.add_optimization_profile(profile)
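On the serving side, profile selection reduces to picking the smallest bucket that fits the request. A sketch with the bucket boundaries from the example above (the helper name is made up):

```python
import bisect

# Upper bounds of the shape buckets: engines/profiles built for
# lengths <=128, <=512, <=2048.
BUCKETS = [128, 512, 2048]

def pick_bucket(seq_len: int) -> int:
    """Return the padded length: the smallest bucket >= seq_len."""
    i = bisect.bisect_left(BUCKETS, seq_len)
    if i == len(BUCKETS):
        raise ValueError(f"seq_len {seq_len} exceeds the largest bucket")
    return BUCKETS[i]

print(pick_bucket(10))    # a 10-token request pads to 128, not 2048
print(pick_bucket(129))   # just over a boundary -> next bucket up
print(pick_bucket(2048))  # exactly at the max -> largest bucket
```

The bucket boundaries are a tuning knob: more buckets mean less padding waste but more engines to build, store, and keep warm.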

The dynamic-shape world (torch.compile, ONNX Runtime):

Both torch.compile and ORT support dynamic shapes natively. The compiled code handles variable lengths without recompilation by generating shape-polymorphic kernels (or by using symbolic shapes that get resolved at runtime).

The cost: shape-polymorphic kernels can’t be as aggressively specialized. A static-shape kernel for seq_len=512 can use exactly 512-iteration loops, unroll them, and choose tile sizes tuned for exactly that shape. A dynamic-shape kernel must handle any length, which forces conservative tiling choices.

In practice for LLM inference:

  • Prefill has variable sequence lengths. Static-shape compilers require padding or multiple profiles. Dynamic-shape handles it natively.
  • Decode is always seq_len=1 (one new token per step). Static shapes are fine; compile once.
  • KV cache growth is dynamic. This is why PagedAttention (Chapter 22) was needed — it breaks the assumption that the KV cache is a contiguous tensor.

The standard serving stack answer (vLLM, SGLang): don’t use TensorRT for the full model. Use carefully hand-picked kernels for the hot ops (FlashAttention, paged attention) that handle dynamism natively, and let the rest run through PyTorch. This sidesteps the entire static-shape problem.

40.10 The optimization decision tree

Don’t stack compilers that do the same thing. Don’t run a model through torch.compile, export it to ONNX, run it through ORT, and then build a TensorRT engine from the ORT output. Each compiler introduces its own representation and optimization passes; stacking them creates conflicts and makes debugging impossible.

Instead, pick one optimizer per deployment context and go deep on it.

graph TD
  Start[Have a trained model to deploy] --> Q1{On NVIDIA GPU?}
  Q1 -->|No| Q2{Cross-platform needed?}
  Q1 -->|Yes| Q3{LLM with paged attention / speculative decoding?}
  Q2 -->|Yes| ONNX[ONNX Runtime + ORT EP]
  Q2 -->|No| XLA[XLA / CoreML / OpenVINO]
  Q3 -->|Yes| VLLM[vLLM or SGLang — they optimize internally]
  Q3 -->|No| Q4{Stable model, fixed shapes, latency-critical?}
  Q4 -->|Yes| TRT[TensorRT or TensorRT-LLM]
  Q4 -->|No| Compile[torch.compile — start here]
  Compile --> Q5{Still not fast enough?}
  Q5 -->|Yes + NVIDIA| TRT
  Q5 -->|Yes + cross-platform| ONNX
  Q5 -->|No| Done[Ship it]

Start with torch.compile (one line, free) and move right only if you’ve measured and it’s not enough.

Rules of thumb:

  1. torch.compile first. One line. No format conversion. Reverting is easy. Get a baseline.
  2. LLM serving: use a serving framework. vLLM (Chapter 44) and SGLang include their own optimized kernels, continuous batching, and PagedAttention. They outperform naive torch.compile for LLM serving. Don’t reimplement what vLLM already does.
  3. TensorRT for non-LLM NVIDIA inference. Computer vision, classification, detection, embedding generation — these are the sweet spot. Fixed shapes, stable models, NVIDIA hardware. TensorRT will win.
  4. ONNX for portability. If your deployment targets multiple hardware types or your production environment can’t guarantee NVIDIA, ONNX Runtime is the right layer.
  5. Don’t double-optimize. If you’re running vLLM, don’t also run torch.compile on the model — vLLM has already replaced the hot paths. You’ll get graph breaks and confusion.
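The decision tree can be restated as a small helper. The function and its string labels are illustrative, not an API — it encodes a starting point, not a final answer:

```python
def choose_optimizer(nvidia: bool, cross_platform: bool,
                     llm_serving: bool, stable_fixed_shapes: bool) -> str:
    """Mirror the decision tree from this section."""
    if not nvidia:
        return "ONNX Runtime" if cross_platform else "XLA / CoreML / OpenVINO"
    if llm_serving:
        return "vLLM or SGLang"           # they optimize internally
    if stable_fixed_shapes:
        return "TensorRT / TensorRT-LLM"  # latency-critical, stable model
    return "torch.compile"                # start here; escalate after measuring

print(choose_optimizer(nvidia=True, cross_platform=False,
                       llm_serving=True, stable_fixed_shapes=False))
```

Note that "torch.compile" is the default branch: everything else requires a measured reason to leave it.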

40.11 Profiling before and after

The optimizer decision tree is useless without measurement. Never assume an optimization helped. Always measure.

The three things to check

(1) Did fusion happen? Use torch.profiler or Nsight Systems. Count the number of distinct CUDA kernel calls during one forward pass. After fusion, the number should drop; fused kernels should appear in the trace. If you still see bias_add_kernel and gelu_kernel as separate calls, fusion failed.

(2) Did precision lower? For TensorRT FP16 or INT8 builds, use Nsight Compute (ncu) to check which tensor cores are engaged. H100 FP16 uses 16-bit Tensor Cores (989 TFLOPs/s ceiling). INT8 doubles that. FP8 on Hopper uses the e4m3 Transformer Engine. If your optimization should have engaged FP16/FP8 but Nsight shows FP32 operations, something in the precision setup failed — look for missing autocast scopes or layers that were excluded from precision lowering.

(3) Did occupancy increase? Nsight Compute reports SM occupancy (the fraction of SMs actively running warps). Low occupancy on a GEMM usually means your batch or matrix dimensions don’t tile efficiently. A fused kernel may have higher occupancy than the equivalent unfused ops because the fused version can schedule more work per SM.

torch.profiler

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    with_stack=True,
    record_shapes=True,
) as prof:
    with torch.profiler.record_function("forward"):
        output = model(input)

# Print summary
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

# Export Chrome trace
prof.export_chrome_trace("trace.json")

Load trace.json in chrome://tracing or Perfetto UI to see the kernel timeline.
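Check (1) above can be automated from the profiler output. A minimal sketch — the model and the count_distinct_ops helper are illustrative, not a standard API; run it once on the eager model and once on the compiled model and compare the counts:

```python
import torch
import torch.nn as nn

def count_distinct_ops(model, example_input):
    # Profile one forward pass and collect the distinct op/kernel names.
    # After successful fusion, the compiled model should report fewer names.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    with torch.profiler.profile(activities=activities) as prof:
        with torch.no_grad():
            model(example_input)
    return {evt.key for evt in prof.key_averages()}

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).eval()
before = count_distinct_ops(model, torch.randn(8, 64))
print(f"{len(before)} distinct ops before compilation")
```

If the count barely moves after torch.compile, inspect the trace for the separate elementwise kernels that should have fused.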

ONNX Runtime profiler

ORT has a built-in profiler:

opts = ort.SessionOptions()
opts.enable_profiling = True
sess = ort.InferenceSession("model.onnx", sess_options=opts, providers=["CUDAExecutionProvider"])
sess.run(...)
profile_file = sess.end_profiling()
# profile_file is a JSON, load in chrome://tracing

Before/after comparison template

Track these three metrics before and after each optimization:

Metric            Tool                                 What to record
p50/p95 latency   time.perf_counter or Nsight          ms/token or ms/request
Throughput        tokens/sec                           total output tokens / wall time
GPU memory        torch.cuda.max_memory_allocated()    peak MiB per forward pass

If latency didn’t improve, check whether the optimizer is actually engaged — compilation failures often fall back to eager silently. If throughput improved but output quality dropped, suspect a precision regression; if memory grew, look for compilation caches or CUDA-graph buffer pools holding extra allocations.
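The latency and throughput rows can be scripted with nothing but the standard library. A minimal sketch — run_once is a stand-in for one forward pass that returns the number of tokens it produced:

```python
import time
import statistics

def benchmark(run_once, iters=50, warmup=5):
    # Exclude compile/cache warmup from the timed region.
    for _ in range(warmup):
        run_once()
    latencies, tokens = [], 0
    start = time.perf_counter()
    for _ in range(iters):
        t0 = time.perf_counter()
        tokens += run_once()
        latencies.append((time.perf_counter() - t0) * 1000)  # ms per call
    wall = time.perf_counter() - start
    latencies.sort()
    return {
        "p50_ms": statistics.median(latencies),
        "p95_ms": latencies[int(0.95 * (len(latencies) - 1))],
        "tokens_per_s": tokens / wall,
    }

# Dummy workload: pretend each call emits 32 tokens.
stats = benchmark(lambda: 32)
```

For the GPU memory row, call torch.cuda.reset_peak_memory_stats() before and torch.cuda.max_memory_allocated() after the timed loop.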

40.12 Common pitfalls

Graph breaks in torch.compile.

The most common cause is data-dependent control flow:

# This causes a graph break:
if x.sum() > threshold:   # scalar comparison forces Python execution
    x = x * 2

Fix by restructuring as a masked operation:

# No graph break:
mask = (x.sum() > threshold).float()
x = x * (1 + mask)

Other common break causes: .item(), print(tensor), in-place operations on views, non-PyTorch library calls inside the model forward.

ONNX export failures.

Export fails when PyTorch ops don’t map to ONNX operators. Common cases:

  • torch.einsum with complex subscripts: rephrase as matmul.
  • Custom attention variants: register a custom ONNX op or rewrite.
  • Dynamic slicing with runtime-determined indices: use dynamic_axes and ensure indices are ONNX-compatible.
  • torch.compile-generated code: always export from the original, uncompiled model.

TensorRT engine not portable.

You build the engine on your dev machine (RTX 4090) and try to load it on the production server (A100). TensorRT refuses with an engine version or device mismatch error: engines are specialized to a specific compute capability and TensorRT version. Solution: always build TensorRT engines in the production environment (for example with trtexec --saveEngine there, loading via --loadEngine), and include the GPU’s compute capability (SM version) in your artifact naming.
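One lightweight way to enforce the naming advice — a convention sketch, not a TensorRT API; the engine_name helper is hypothetical:

```python
import torch

def engine_name(model_tag, trt_version="10.0"):
    # Encode the compute capability in the filename so an engine built on
    # one GPU generation is never silently loaded on another.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
    else:
        major, minor = 0, 0   # no GPU visible; placeholder for the sketch
    return f"{model_tag}_sm{major}{minor}_trt{trt_version}.engine"

print(engine_name("resnet50"))  # e.g. resnet50_sm90_trt10.0.engine on H100
```

A loader that refuses any engine whose SM tag doesn’t match the local device turns the mismatch into a clear startup error instead of a cryptic TensorRT failure.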

Calibration data not representative.

TensorRT INT8 calibration uses your calibration dataset to set per-layer scales. If the calibration data has different distribution than production inputs (e.g., you calibrated on short sequences but serve long ones), outlier activations in production will be clipped to the calibration range and you’ll see quality degradation. Fix: calibrate on a random sample from your actual traffic distribution. A few hundred examples is usually enough; the distribution matters more than the count.
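A sketch of the fix — sampling calibration inputs uniformly from logged production traffic (production_logs and the helper are illustrative stand-ins):

```python
import random

def sample_calibration(production_logs, n=512, seed=0):
    # Uniform random sample from real traffic so the calibration set's
    # length/content distribution matches what the engine will see.
    rng = random.Random(seed)   # fixed seed: reproducible calibration builds
    return rng.sample(production_logs, min(n, len(production_logs)))

calib = sample_calibration([f"req-{i}" for i in range(10_000)], n=512)
```

If traffic is bimodal (e.g., short chat turns and long documents), stratify the sample so both modes are represented.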

Stacking incompatible optimizers.

Running torch.compile and then exporting to ONNX from the compiled model can produce ONNX graphs full of Triton-emitted ops that have no ONNX equivalents. Always export ONNX from the original (uncompiled) model. Similarly, don’t build a TensorRT engine from an already-ORT-optimized ONNX model — the double-optimization can produce worse results than building TensorRT from the clean export.

Warmup and caching costs.

torch.compile adds 30-60 seconds of compilation overhead on first use. TensorRT builds take 30-60 minutes. In production, this means:

  • Don’t call torch.compile inside a request handler. Compile at startup, cache the compiled model.
  • Serialize and reload TensorRT engines across restarts. Build once, run many.
  • ONNX Runtime’s extended optimization can take seconds on large models — also do this at startup.
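The first bullet in code form — a sketch of compiling once at startup and serving from the cached module (backend="eager" keeps the example cheap to run; real deployments would use the default Inductor backend):

```python
import torch
import torch.nn as nn

# Startup: pay compilation cost once, before the server accepts traffic.
model = nn.Linear(64, 64).eval()
compiled = torch.compile(model, backend="eager")  # "eager" backend: sketch only
with torch.no_grad():
    compiled(torch.randn(1, 64))   # warmup call triggers actual compilation

# Hot path: the request handler only ever touches the compiled module.
def handle_request(x):
    with torch.no_grad():
        return compiled(x)

out = handle_request(torch.randn(2, 64))
```

The same pattern applies to TensorRT (deserialize the engine at startup) and ONNX Runtime (create the InferenceSession at startup).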

40.13 The mental model

Your model goes through four stages. Each stage closes a different performance gap. The question is where your bottleneck is and which stage to target.

Stage 1: Python eager. Correct. Flexible. Slow. One kernel per operation, thousands of CUDA kernel launches per forward pass, no global view of the computation.

Stage 2: Graph IR. Captured. Visible. Enables global optimization. The FX graph (PyTorch), HLO (XLA), or ONNX IR is where graph passes run: constant folding, dead code elimination, CSE, layout optimization.

Stage 3: Optimized graph. Fused. Pruned. Layout-optimal. The key transformations happen here: operator fusion (Section 40.8) eliminates HBM round-trips; precision lowering (Chapter 26) reduces bandwidth and enables Tensor Core dispatch; kernel auto-tuning selects the fastest implementation for your shapes.

Stage 4: Hardware-specific kernels. Compiled. Device-bound. Maximum performance. The Triton-generated fused kernels from TorchInductor, or the TensorRT-built engine, or the CUTLASS-backed matmuls. These are the actual GPU instructions that run.

The optimizer choice maps to the stage you’re targeting:

  • torch.compile: automates stages 2 → 4 in one call, with dynamic-shape support.
  • torch.export: gives you a portable stage-2 artifact.
  • ONNX: portable stage-2/3 representation; ORT runs stage 3 → 4.
  • TensorRT: aggressive stage-3 → 4 optimization for NVIDIA hardware.
  • XLA/JAX: whole-program stage-2 → 4 with cross-hardware portability.
  • vLLM/SGLang: hand-coded stage-4 kernels for LLM-specific operations, bypassing the compilation stack entirely for the hottest paths.

The mistake is to apply optimizers without knowing which stage is your bottleneck. If your model is memory-bandwidth-bound (as most LLM decode passes are), optimizing compute fusion (stage 3) won’t help much — you need to reduce model size via quantization (Chapter 26) or reduce memory traffic via better batching (Chapter 44). Profile first. Know your bottleneck. Then pick the optimizer that closes that specific gap.


Read it yourself

  • The torch.compile documentation and tutorial at pytorch.org/tutorials. The “Introduction to torch.compile” tutorial is the fastest path to working code.
  • The TorchInductor design doc on the PyTorch blog. Explains how Dynamo → FX → Inductor → Triton works.
  • Ansel et al., PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation (2024). The official paper.
  • ONNX specification at onnx.ai/onnx/intro/concepts.html. Read the operator semantics section.
  • The ONNX Runtime performance tuning guide at onnxruntime.ai.
  • The TensorRT developer guide at docs.nvidia.com/deeplearning/tensorrt. Read chapters on building engines and calibration.
  • The XLA overview at openxla.org/xla/architecture. Short and worth reading.
  • The PyTorch profiler tutorial for understanding torch.profiler output.

Practice

  1. Apply torch.compile(model) to a small transformer you have access to. Measure latency before and after. Use torch._dynamo.explain(model)(inputs) to find any graph breaks and fix one.
  2. Export a model to ONNX with dynamic_axes for batch and sequence length. Load it in ONNX Runtime with the CUDA execution provider. Measure latency vs PyTorch eager.
  3. What is the difference between torch.compile and torch.export? When would you use each?
  4. A colleague says “I’ll just use TensorRT for my LLM — it’s the fastest NVIDIA option.” Name two specific scenarios where this advice is wrong and what to use instead.
  5. You profile a compiled model and find that gelu_kernel and bias_add_kernel still appear as separate CUDA calls. List three possible reasons fusion failed.
  6. Why does TensorRT require rebuilding the engine when you move from an A100 to an H100? What specifically is hardware-specific about the built engine?
  7. Your INT8 TensorRT engine produces correct outputs on the calibration set but lower quality on production inputs. Diagnose the likely cause and describe the fix.
  8. Stretch: Take a transformer model, export to ONNX (opset 18), and run both the original PyTorch model and the ORT model with profiling enabled. Compare the number of distinct CUDA kernels invoked in each case. Which is lower and why?