The mathematical objects: tensors, shapes, broadcasting
"Eighty percent of ML bugs are shape bugs. The other twenty percent are also shape bugs, but they only manifest after the model has trained for six hours."
Before we get to attention, transformers, training loops, or any of the things you actually want to learn, we have to spend a chapter on the data structure that holds all of it: the tensor. This is not optional. Every concept in this book will eventually be expressed as operations on tensors, and the bugs you will spend the most time on — by a wide margin — are shape mismatches, broadcasting surprises, contiguous-vs-strided footguns, and dtype precision losses. Master this chapter and you will save yourself thousands of hours over your career.
We will build the picture in layers:
- What a tensor actually is (and why “tensor” is a slight lie).
- The three properties every tensor carries: shape, dtype, device.
- The shape tuple and the conventions you have to memorize.
- Indexing, slicing, and the small set of operators that do the heavy lifting.
- Broadcasting — the rules, the intuition, and the traps.
- Memory layout: strides, contiguity, and the view-vs-reshape distinction.
- dtype: the precision / range / throughput frontier.
- Devices and the cost of host–device transfers.
- The “shape bugs are most bugs” thesis and the discipline that prevents them.
By the end you will be able to read any tensor expression in any ML codebase and predict — without running it — what shape the result will have, where the memory lives, and where the next bug will come from.
1.1 What a tensor actually is
In mathematics, a tensor is a multilinear map between vector spaces. It has all sorts of structure: covariance, contravariance, invariance under coordinate transforms. None of that survives the trip into machine learning.
In machine learning, a tensor is just an n-dimensional array of numbers, plus a small amount of metadata. That is the entire definition. The mathematicians are mildly offended by the abuse of terminology, but the abuse is now universal and you should not fight it.
The dimensionality of the array is called the rank (also “order,” also “ndim”):
- Rank 0: a scalar. A single number. Shape `()`.
- Rank 1: a vector. A 1-D array. Shape `(n,)`.
- Rank 2: a matrix. A 2-D array. Shape `(m, n)`.
- Rank 3: sometimes called a “cube.” Shape `(d0, d1, d2)`.
- Rank 4 and up: no friendly name; just “rank-N tensor.”
For everything in this book — and most of practical ML — you will work with tensors of rank 1 through 5 or 6. Anything higher and you have probably done something wrong.
```python
import torch

scalar = torch.tensor(3.14)              # shape ()
vector = torch.tensor([1.0, 2.0, 3.0])   # shape (3,)
matrix = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
cube = torch.zeros(2, 3, 4)              # shape (2, 3, 4)
batch = torch.zeros(8, 3, 224, 224)      # a batch of 8 RGB 224×224 images
```
That last example is the canonical “batch of images” shape. Memorize the convention. It will haunt you.
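The word "rank" is overloaded. For a tensor it means the number of dimensions, which PyTorch exposes as `.ndim` or `.dim()`. That is unrelated to the linear-algebra rank of a matrix. Here is a small sketch of the difference. The 2×2 matrix is an arbitrary example chosen so that its rows are linearly dependent.

```python
import torch

scalar = torch.tensor(3.14)
batch = torch.zeros(8, 3, 224, 224)

# Tensor "rank" = number of dimensions = len(shape)
assert scalar.ndim == 0 and scalar.shape == torch.Size([])
assert batch.ndim == 4 and batch.dim() == len(batch.shape)

# Linear-algebra rank is a different concept entirely
m = torch.tensor([[1.0, 2.0], [2.0, 4.0]])        # second row = 2 * first row
tensor_rank = m.ndim                              # 2: a rank-2 tensor...
matrix_rank = int(torch.linalg.matrix_rank(m))    # ...whose matrix rank is 1
```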
1.2 The three properties of every tensor
Every tensor — in PyTorch, JAX, TensorFlow, NumPy, all of them — carries three things:
- Shape: a tuple of nonnegative integers giving the size in each dimension.
- Dtype: the numeric type of every element (`float32`, `bfloat16`, `int8`, etc.).
- Device: where the memory lives (CPU, `cuda:0`, MPS, TPU).
```python
x = torch.zeros(8, 3, 224, 224, dtype=torch.float16, device='cuda:0')
x.shape   # torch.Size([8, 3, 224, 224])
x.dtype   # torch.float16
x.device  # device(type='cuda', index=0)
```
The mistake every beginner makes is thinking only about shape. Half the production bugs you will hit are dtype bugs (you computed in fp16 and overflowed) or device bugs (you tried to add a CPU tensor to a CUDA tensor and got a runtime error). Always think about all three.
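The following CPU-only sketch shows the first two failure modes. `describe` is a hypothetical helper that prints all three properties at once; it is not a PyTorch API. The device check runs only when a CUDA GPU is present.

```python
import torch

def describe(t: torch.Tensor) -> str:
    """Hypothetical helper: report shape, dtype and device together."""
    return f"shape={tuple(t.shape)} dtype={t.dtype} device={t.device}"

# dtype bug: 300 * 300 = 90 000 exceeds fp16's largest finite value (65 504)
a = torch.tensor([300.0], dtype=torch.float16)
overflowed = a * a
assert torch.isinf(overflowed).all()
assert torch.isfinite(a.float() * a.float()).all()   # fine in fp32

# device bug: mixing a (non-scalar) CPU tensor with a CUDA tensor raises
if torch.cuda.is_available():
    try:
        torch.ones(3) + torch.ones(3, device="cuda")
    except RuntimeError as err:
        print("device mismatch:", err)

print(describe(a))
```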
1.3 The shape tuple and the conventions you must memorize
Shape is a tuple, and the order of dimensions matters. The conventions are not laws of nature; they are choices that different libraries and frameworks make differently, and they are the source of more confusion than anything else in ML.
The conventions you absolutely must know:
Images
PyTorch uses (N, C, H, W) — batch, channels, height, width. TensorFlow defaults to (N, H, W, C). The PyTorch order is sometimes called “channels-first” or NCHW; the TF order is “channels-last” or NHWC.
The reason PyTorch picked NCHW is that NVIDIA’s cuDNN kernels were faster on NCHW for years. That is no longer a clear win: modern Tensor Cores prefer NHWC for many ops (convolutions in particular). This is why `torch.channels_last` exists. It is a memory-format hint that stores the data physically as NHWC while the tensor’s logical shape stays `(N, C, H, W)`. Confusing? Yes. Welcome.
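The sketch below makes the distinction concrete. The sizes are arbitrary. `channels_last` leaves the shape alone and changes only the strides, so C becomes the fastest-moving dimension in memory. `permute` is a different operation: it actually reorders the logical dimensions into a TensorFlow-style NHWC shape.

```python
import torch

x = torch.randn(2, 3, 4, 5)             # logical NCHW: N=2, C=3, H=4, W=5
assert x.stride() == (60, 20, 5, 1)     # physically NCHW (plain contiguous)

x_cl = x.to(memory_format=torch.channels_last)
assert x_cl.shape == x.shape            # logical shape is still (N, C, H, W)
assert x_cl.stride() == (60, 1, 15, 3)  # but C now has stride 1: physically NHWC
assert x_cl.is_contiguous(memory_format=torch.channels_last)
assert torch.equal(x, x_cl)             # same values, different layout

# Changing the *logical* order to NHWC is a separate operation
x_nhwc = x.permute(0, 2, 3, 1)          # shape (2, 4, 5, 3)
```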
Sequences
Two competing conventions:
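- Batch-first: `(N, S, D)`, meaning batch, sequence, hidden dim. Most modern code uses it, but PyTorch makes you ask for it, e.g. `nn.Transformer(batch_first=True)` or `nn.MultiheadAttention(batch_first=True)`.
- Time-first: `(S, N, D)`, meaning sequence, batch, hidden dim. This is still the default (`batch_first=False`) for PyTorch’s RNN modules and for `nn.Transformer` and `nn.MultiheadAttention`.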
Always check batch_first when you instantiate any sequence layer. If you mix the conventions you will get a model that runs without errors and learns nothing.
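Here is a minimal sketch of that silent failure using `nn.LSTM`. The sizes are arbitrary. The two modules receive the same batch-first tensor and return the same output shape, but one of them reads the axes the wrong way round.

```python
import torch
import torch.nn as nn

N, S, D_in, D_hid = 4, 10, 16, 32
x = torch.randn(N, S, D_in)                    # batch-first data: (N, S, D)

good = nn.LSTM(D_in, D_hid, batch_first=True)  # reads x as (N, S, D)
bad = nn.LSTM(D_in, D_hid)                     # default batch_first=False: reads x as (S, N, D)

out_good, _ = good(x)
out_bad, _ = bad(x)

# Both calls succeed and return the same shape...
assert out_good.shape == out_bad.shape == (N, S, D_hid)
# ...but `bad` treated the 4 batch items as 4 time steps and the
# 10 time steps as 10 independent sequences. Nothing errors.
```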
Attention tensors inside a transformer
When you are inside an attention block, the canonical shape is (N, H, S, D_h) — batch, heads, sequence, head dim. The total hidden dim D = H * D_h gets reshaped into (H, D_h) so each attention head can operate independently. We will spend a lot of time on this in Chapter 6.
A discipline that pays for itself in a week
When you write or read tensor code, annotate every tensor with its expected shape in a comment:
```python
# x: (N, S, D)
x = embedding(input_ids)   # (N, S, D)
q = q_proj(x)              # (N, S, D)
q = q.view(N, S, H, D_h)   # (N, S, H, D_h)
q = q.transpose(1, 2)      # (N, H, S, D_h)
```
When (not if) the shape comment disagrees with the actual code, you have found your bug.
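One way to make the comments executable is to follow each line with an `assert` on its shape. The check is nearly free at runtime and fails at the exact line where the code diverges from the comment. The sketch below assumes arbitrary sizes. A plain `nn.Embedding` and `nn.Linear` stand in for the `embedding` and `q_proj` modules above.

```python
import torch
import torch.nn as nn

N, S, D, H, vocab = 2, 7, 64, 8, 100
D_h = D // H
embedding = nn.Embedding(vocab, D)   # stand-in for the model's embedding layer
q_proj = nn.Linear(D, D)             # stand-in for the query projection
input_ids = torch.randint(0, vocab, (N, S))

x = embedding(input_ids)
assert x.shape == (N, S, D)
q = q_proj(x)
assert q.shape == (N, S, D)
q = q.view(N, S, H, D_h)
assert q.shape == (N, S, H, D_h)
q = q.transpose(1, 2)
assert q.shape == (N, H, S, D_h)
```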
1.4 Indexing and slicing
Indexing tensors uses the same syntax as NumPy arrays. The fundamentals:
```python
x = torch.arange(24).reshape(2, 3, 4)
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])
x[0]          # shape (3, 4): first "page"
x[0, 1]       # shape (4,): first page, second row
x[0, 1, 2]    # shape (): scalar 6
x[:, 1, :]    # shape (2, 4): middle row of every page
x[..., -1]    # shape (2, 3): last column of every page
x[None]       # shape (1, 2, 3, 4): adds a leading singleton dim
x[:, None]    # shape (2, 1, 3, 4): singleton inserted at axis 1
```
Two underused tricks:
- `...` (ellipsis) stands for “as many `:` as needed to fill the remaining dims.” `x[..., -1]` is “the last element of the last dim, regardless of how many leading dims there are.” This makes code rank-polymorphic: the same line works for `(C, H, W)` and `(N, C, H, W)`.
- `None` (or `np.newaxis`) inserts a singleton dimension. This is the bread and butter of broadcasting, which is the next section.
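Both tricks are shown in the short sketch below. `downsample2x` is a hypothetical helper written for illustration. Because it indexes only the last two dims through `...`, the same function handles a single image and a batch.

```python
import torch

def downsample2x(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep every other row and column; leading dims untouched."""
    return img[..., ::2, ::2]

single = torch.randn(3, 8, 8)       # (C, H, W)
batch = torch.randn(4, 3, 8, 8)     # (N, C, H, W)

assert downsample2x(single).shape == (3, 4, 4)
assert downsample2x(batch).shape == (4, 3, 4, 4)

# None inserts a size-1 axis wherever you put it
assert single[None].shape == (1, 3, 8, 8)        # leading (batch) dim
assert single[..., None].shape == (3, 8, 8, 1)   # trailing dim
```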
Advanced (fancy) indexing
You can index with a tensor of integers to gather arbitrary elements:
```python
indices = torch.tensor([0, 2, 1])
x[:, indices, :]   # shape (2, 3, 4): picks rows 0, 2, 1 from each page
```
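Fancy indexing always copies. Plain slice indexing returns a view (more on views in §1.6). This distinction will bite you when you write into the result of a fancy-indexing expression, for example `y = x[:, indices]; y[0] = 0` or the chained `x[:, indices][0] = 0`, and the write silently never reaches the original tensor. Assigning directly through the index, `x[:, indices] = value`, does write into `x`.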
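The following sketch shows the view-versus-copy behavior, reusing the `(2, 3, 4)` example tensor. Writes through a slice reach `x`. Writes into the result of fancy indexing, including the chained form, do not. Direct assignment through the fancy index does.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Basic slicing returns a view: writes go through to x
rows = x[:, 1:, :]                 # (2, 2, 4), shares x's memory
rows[0, 0, 0] = -1
assert x[0, 1, 0] == -1

# Fancy (integer-tensor) indexing returns a copy: writes do NOT go through
indices = torch.tensor([0, 2, 1])
picked = x[:, indices, :]          # (2, 3, 4), new memory
picked[0, 0, 0] = 999
assert x[0, 0, 0] == 0

# Chained indexing writes into a temporary copy, so x is unchanged
x[:, indices][0, 0, 0] = 999
assert x[0, 0, 0] == 0

# Assigning directly through the fancy index does write into x
x[:, indices, :] = 7
assert (x == 7).all()
```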
1.5 Broadcasting
Broadcasting is the mechanism by which tensors of different shapes get combined in element-wise operations without explicit looping or copying. It is one of the two or three most important ideas in this chapter, and getting it wrong is responsible for an enormous fraction of subtle bugs.
The rule, formally
When you apply an element-wise operation to two tensors A and B:
- Pad the shorter shape on the left with 1s until both shapes have the same length.
- For each dimension, the sizes must either be equal, or one of them must be 1. If neither is true, broadcasting fails and you get an error.
- The output shape is the element-wise maximum of the two padded shapes.
- Wherever a dimension has size 1, that tensor is conceptually replicated along that dimension to match the other.
That’s the entire rule. Read it three times. Every broadcasting question reduces to applying it mechanically.
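To see that the rule really is mechanical, here it is as a short function. This is a sketch written for this chapter, not PyTorch's source. It is cross-checked against `torch.broadcast_shapes`. One detail: when one size is 1, the result takes the other size rather than a literal `max`, so a size-0 dim broadcast against a size-1 dim correctly yields 0.

```python
import torch

def broadcast_shape(a: tuple, b: tuple) -> tuple:
    """Apply the broadcasting rule step by step (illustrative sketch)."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)          # step 1: left-pad with 1s
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da != db and da != 1 and db != 1:    # step 2: equal, or one of them is 1
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(db if da == 1 else da)       # step 3: the non-1 size wins
    return tuple(out)

# Cross-check against PyTorch's own implementation of the rule
for a, b in [((8, 3, 224, 224), (3, 1, 1)), ((5, 1), (1, 4)), ((32,), (32, 1))]:
    assert broadcast_shape(a, b) == tuple(torch.broadcast_shapes(a, b))
```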
Worked example: per-channel mean subtraction
```python
A = torch.zeros(8, 3, 224, 224)   # batch of images
B = torch.zeros(3, 1, 1)          # per-channel mean (RGB)
(A - B).shape                     # (8, 3, 224, 224)
```
How? Pad B on the left: (1, 3, 1, 1). Compare to (8, 3, 224, 224):
| dim | A | B (padded) | result |
|---|---|---|---|
| 0 | 8 | 1 | 8 |
| 1 | 3 | 3 | 3 |
| 2 | 224 | 1 | 224 |
| 3 | 224 | 1 | 224 |
Each dim is either equal or 1. Output shape is the elementwise max: (8, 3, 224, 224). The B tensor is conceptually replicated 8× along the batch axis and 224× along each spatial axis — but no actual memory is allocated for those copies. The kernel just reads B[c, 0, 0] for every spatial position.
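You can watch the "no copies" claim directly. `expand` (and `torch.broadcast_to`) materialize the broadcast as a view in which every replicated dimension gets stride 0, so all indices along it land on the same element. The per-channel values below are arbitrary.

```python
import torch

B = torch.arange(3.0).reshape(3, 1, 1)   # per-channel values, strides (1, 1, 1)
B_big = B.expand(8, 3, 224, 224)         # what broadcasting does implicitly

assert B_big.shape == (8, 3, 224, 224)
assert B_big.stride() == (0, 1, 0, 0)    # stride 0 along N, H and W
assert B_big.data_ptr() == B.data_ptr()  # no new memory was allocated
assert B_big[5, 2, 100, 17] == B[2, 0, 0]
```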
Worked example: the masking bug
```python
import torch

N, S, D = 2, 5, 8                 # example sizes (all different)
x = torch.randn(N, S, D)          # (N, S, D)
mask = torch.zeros(N, S)          # (N, S): 1 where padding, 0 elsewhere
x_masked = x * (1 - mask)         # raises RuntimeError: shape mismatch
```
Why does it fail? Pad mask on the left to (1, N, S). Compare to (N, S, D):
| dim | x | mask (padded) | result |
|---|---|---|---|
| 0 | N | 1 | N |
| 1 | S | N | mismatch unless N == S |
| 2 | D | S | mismatch unless D == S |
The dim sizes don’t line up. The fix is to add a trailing dimension to mask explicitly:
```python
x_masked = x * (1 - mask[..., None])   # mask becomes (N, S, 1): broadcasts cleanly
```
This is the single most common broadcasting error in transformer code. Memorize the fix.
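The same trailing-`None` fix applies to the other common masking idiom, `masked_fill` with a boolean mask. The sketch below uses arbitrary, mutually distinct sizes, so the unfixed call raises instead of broadcasting by accident.

```python
import torch

N, S, D = 2, 5, 8
x = torch.randn(N, S, D)
pad = torch.zeros(N, S, dtype=torch.bool)
pad[0, 3:] = True                  # sequence 0 ends with two padding positions

try:
    x.masked_fill(pad, 0.0)        # (N, S) against (N, S, D): broadcast error
    failed = False
except RuntimeError:
    failed = True
assert failed

x_masked = x.masked_fill(pad[..., None], 0.0)   # (N, S, 1) broadcasts over D
assert x_masked.shape == (N, S, D)
assert (x_masked[0, 3:] == 0).all()
assert torch.equal(x_masked[1], x[1])           # unpadded sequence untouched
```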
The trap
Broadcasting will silently do the wrong thing if you have the right shapes by accident:
```python
predictions = torch.randn(32)         # (32,): one prediction per sample in a batch of 32
targets = torch.randn(32, 1)          # (32, 1): accidentally a column vector
loss = (predictions - targets) ** 2   # what shape?
```
Padding predictions on the left gives (1, 32). Comparing to (32, 1):
| dim | predictions | targets | result |
|---|---|---|---|
| 0 | 1 | 32 | 32 |
| 1 | 32 | 1 | 32 |
Both broadcast! The output is (32, 32), a full pairwise difference matrix. Your loss is now the mean over 1024 elements instead of 32, and it compares every prediction with every target. Each prediction is therefore pulled toward the mean of all the targets rather than toward its own. Nothing errors, your model trains, and you spend a week wondering why it’s not converging.
This is one of the most expensive bugs in ML. The defense is to always check shapes explicitly when you build a loss, or use a library like jaxtyping or einops that enforces named dimensions (more in §1.9).
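Here is a sketch of the defense applied to the example above. It squeezes away the stray column dimension and asserts that the shapes agree before reducing. The library loss is shown for comparison. In recent PyTorch versions, `F.mse_loss` also emits a warning when input and target sizes differ, which is one more reason to prefer it over a hand-rolled expression.

```python
import torch
import torch.nn.functional as F

predictions = torch.randn(32)       # (32,)
targets = torch.randn(32, 1)        # (32, 1): the accidental column vector

wrong = (predictions - targets) ** 2
assert wrong.shape == (32, 32)      # silent pairwise broadcast

# Fix: make the shapes agree explicitly, and check before reducing
targets = targets.squeeze(-1)       # (32,)
assert predictions.shape == targets.shape, (predictions.shape, targets.shape)
loss = ((predictions - targets) ** 2).mean()
assert loss.shape == ()
assert torch.allclose(loss, F.mse_loss(predictions, targets))
```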
1.6 Memory layout: strides, contiguity, and the view-vs-reshape distinction
A tensor’s shape is a logical fiction. Underneath, memory is a flat 1-D buffer. The mapping from a multi-dimensional index to a flat offset is determined by the strides — one integer per dimension, telling you how many elements to step in the flat buffer when you increment that dimension by one.
```python
x = torch.arange(12).reshape(3, 4)
# shape:   (3, 4)
# strides: (4, 1)
# memory:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```
To find element x[i, j], you compute i * 4 + j * 1. That is “row-major” or “C order” — rows are contiguous in memory, so iterating along the last axis is the fastest direction to walk.
When you transpose a tensor, PyTorch does not copy memory. It just swaps the strides:
```python
y = x.transpose(0, 1)
# shape:   (4, 3)
# strides: (1, 4)   <-- swapped
```
Same memory, but now reading along the last axis means stepping by 4 in the flat buffer. This is non-contiguous. Most operations still work on non-contiguous tensors, but some kernels demand contiguous input and will either copy automatically (slow) or refuse (error).
You can ask:
```python
y.is_contiguous()   # False
y.contiguous()      # returns a new tensor with the data physically rearranged
```
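The stride arithmetic can be checked by hand. In the sketch below, `flat_offset` is a hypothetical helper that computes `sum(index_k * stride_k)`. It recovers the same element from the flat buffer for both `x` and its transpose. `torch.as_strided` then rebuilds the transpose purely from the swapped strides.

```python
import torch

x = torch.arange(12).reshape(3, 4)
flat = x.view(-1)                        # the underlying buffer, as 1-D

def flat_offset(index, strides):
    """Hypothetical helper: offset into the flat buffer for a multi-dim index."""
    return sum(i * s for i, s in zip(index, strides))

assert x.stride() == (4, 1)
assert flat[flat_offset((2, 3), x.stride())] == x[2, 3]   # 2*4 + 3*1 = 11

y = x.transpose(0, 1)                    # same buffer, swapped strides
assert y.stride() == (1, 4) and not y.is_contiguous()
assert flat[flat_offset((3, 1), y.stride())] == y[3, 1]   # 3*1 + 1*4 = 7
assert torch.equal(torch.as_strided(x, (4, 3), (1, 4)), y)
```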
.view() vs .reshape() — the distinction every Python ML engineer must understand
- `.view(*shape)` asks for a new shape without copying memory. It only works if the new shape is compatible with the existing strides, i.e. the tensor is contiguous (or contiguous “enough” in the relevant dims). If it isn’t, `.view()` raises an error.
- `.reshape(*shape)` does the same thing if it can, and otherwise silently makes a contiguous copy and reshapes that.
So `.reshape()` always succeeds as long as the total number of elements matches; `.view()` is a stricter version that errors instead of silently copying. The reason to ever prefer `.view()` is that when it succeeds, you know there was no copy. In a hot training loop, a silent reshape-induced copy can cost you significant throughput.
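A quick way to observe the difference is to compare `data_ptr()` values, which tell you whether two tensors share the same memory. In this sketch, both methods return views on a contiguous tensor. On the transposed tensor, `.view()` raises and `.reshape()` allocates a new buffer.

```python
import torch

x = torch.arange(12).reshape(3, 4)     # contiguous
y = x.transpose(0, 1)                  # (4, 3), non-contiguous

# On a contiguous tensor, both return views of the same memory
assert x.view(-1).data_ptr() == x.data_ptr()
assert x.reshape(-1).data_ptr() == x.data_ptr()

# On the transposed tensor, view refuses...
try:
    y.view(-1)
    view_ok = True
except RuntimeError:
    view_ok = False
assert not view_ok

# ...while reshape silently allocates a fresh contiguous copy
flat = y.reshape(-1)
assert flat.data_ptr() != x.data_ptr()
assert flat.tolist() == [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
```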
A canonical example from a transformer attention block:
```python
import torch
import torch.nn as nn
N, S, D, H = 2, 10, 64, 8              # example sizes
D_h = D // H
q_proj = nn.Linear(D, D)
x = torch.randn(N, S, D)               # x: (N, S, D)

q = q_proj(x)                          # (N, S, D): contiguous
q = q.view(N, S, H, D_h)               # (N, S, H, D_h): succeeds, still contiguous
q = q.transpose(1, 2)                  # (N, H, S, D_h): strides shuffled, NON-contiguous
q = q.contiguous()                     # force a copy so the next .view() works
q = q.view(N * H, S, D_h)              # collapse batch and heads for the matmul
```
If you skipped the .contiguous() call, the final .view() would error. Alternatively, you could replace the chain with .reshape(N * H, S, D_h) and let it copy silently. The first form is more explicit about cost; the second is more concise. Production code uses both.
Why this matters for performance
GPU kernels — especially the fused attention kernels in FlashAttention, the matmul kernels in CUTLASS, and the optimized convs in cuDNN — generally require their inputs to be contiguous in specific layouts. If they aren’t, the framework either inserts a copy (extra HBM bandwidth, extra latency) or falls back to a slower kernel.
We will revisit this in Chapter 25 on FlashAttention, where the entire optimization is about exploiting the fast SRAM by tiling carefully across the contiguous memory layout. For now, just remember: strides are not free, and “view” is not a synonym for “reshape.”
1.7 dtype: precision, range, and what each costs you
The dtype of a tensor controls how each element is stored. The choices that matter in modern ML, in roughly the order you’ll meet them:
| dtype | bits | mantissa bits | exponent bits | max finite value | typical use |
|---|---|---|---|---|---|
| float64 | 64 | 52 | 11 | ~10³⁰⁸ | almost never in ML: too slow, too big |
| float32 | 32 | 23 | 8 | ~10³⁸ | the default, but rarely optimal for big models |
| float16 | 16 | 10 | 5 | 65 504 | inference; risky for training (underflows) |
| bfloat16 | 16 | 7 | 8 | ~10³⁸ | the modern training default |
| float8 (e4m3) | 8 | 3 | 4 | 448 | H100/H200 inference; tight calibration |
| float8 (e5m2) | 8 | 2 | 5 | 57 344 | H100 gradient storage |
| int8 | 8 | n/a | n/a | 127 | post-training quantization |
| int4 | 4 | n/a | n/a | 7 | weight-only quantization (AWQ, GPTQ) |
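You can query most of this table from PyTorch itself with `torch.finfo`. The mantissa column counts explicitly stored fraction bits, so `eps` (the gap between 1.0 and the next representable value) works out to 2 to the minus mantissa bits. The float8 dtypes (`torch.float8_e4m3fn`, `torch.float8_e5m2`) exist only in recent releases and are left out here to keep the sketch version-independent.

```python
import torch

for dt in (torch.float64, torch.float32, torch.float16, torch.bfloat16):
    fi = torch.finfo(dt)
    # eps == 2 ** -(mantissa bits); max is the largest finite value
    print(f"{str(dt):15} bits={fi.bits:2d}  max={fi.max:.4g}  eps={fi.eps:.3g}")
```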
The three things to remember:
- `bfloat16` has the same exponent range as `float32`, just with less mantissa precision. That is why it became the training default: it doesn’t underflow gradients the way `float16` does, so you don’t need loss scaling. Less precision, same range. We will revisit this in Chapter 13 (mixed-precision training).
- Tensor Core throughput roughly doubles for every halving of bit width. An H100 (SXM) does about 990 TFLOPS in fp16/bf16 and about 1980 TFLOPS in fp8 (dense, no sparsity). This is why quantization is not just a memory optimization; it’s the largest single throughput lever you have. Chapter 26 covers the full quantization landscape.
- Mixed precision: run the expensive ops (matmuls, convolutions) in bf16/fp16, but keep the master copy of the weights and the optimizer state in fp32. This is what `torch.autocast` (or the older `torch.cuda.amp.autocast` spelling) arranges. Parameters stay in fp32 and are cast to the lower precision on the fly for eligible ops, so activations come out in bf16/fp16. With fp16 you also need a gradient scaler (`GradScaler`); with bf16 you usually don’t.
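You can see this split on a CPU with bf16 autocast. The sketch uses a toy `nn.Linear` and random input, both arbitrary. On a GPU the same pattern would use `device_type="cuda"`. After the forward pass, the parameter is still fp32, the activation is bf16, and the gradient lands back in fp32.

```python
import torch

model = torch.nn.Linear(16, 8)          # parameters created in fp32
x = torch.randn(4, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)                        # the matmul runs in bf16

assert model.weight.dtype == torch.float32       # master weights untouched
assert y.dtype == torch.bfloat16                 # activation produced in bf16

y.float().sum().backward()
assert model.weight.grad.dtype == torch.float32  # gradient matches the fp32 parameter
```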
You will get casting bugs. They look like this:
```python
x = torch.randn(10, dtype=torch.float16)
y = torch.randn(10, dtype=torch.float32)
z = x + y   # what dtype?
```
The answer is float32 — PyTorch promotes to the wider type. This is usually what you want, except when it isn’t (silently upcasting an entire activation map can blow your memory budget). Always check .dtype after a binary op if you are not 100% sure.
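There is one wrinkle worth knowing. The "wider type wins" rule applies between dimensioned tensors. Python scalars and 0-dim tensors of the same category (float with float) do not promote, while a higher category (int to float) does. `torch.result_type` lets you ask without computing anything. A short sketch:

```python
import torch

x16 = torch.randn(10, dtype=torch.float16)
x32 = torch.randn(10, dtype=torch.float32)

assert (x16 + x32).dtype == torch.float32             # two dimensioned tensors: wider wins
assert torch.result_type(x16, x32) == torch.float32   # same answer, no computation

# Python scalars and 0-dim tensors of the same category do NOT promote
assert (x16 + 1.0).dtype == torch.float16
assert (x16 + torch.tensor(1.0)).dtype == torch.float16   # 0-dim fp32 tensor

# ...but a higher category does: int tensor + Python float -> default float dtype
assert (torch.arange(3) + 0.5).dtype == torch.get_default_dtype()
```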
1.8 Devices and host–device transfers
Every tensor lives in some memory: CPU RAM, GPU HBM, MPS (Apple Silicon), TPU HBM. PyTorch represents this with the .device attribute.
```python
x = torch.randn(1024, 1024)   # device('cpu')
x_gpu = x.to('cuda:0')        # copy to first GPU
x_back = x_gpu.cpu()          # copy back
```
The non-obvious cost: host–device transfers go over PCIe, which is one to two orders of magnitude slower than GPU HBM.
The discipline: load the data on the GPU once, keep it there, and only move scalars and tiny metadata back to the CPU. The DataLoader’s `pin_memory=True` flag enables faster CPU→GPU transfers by placing batches in pinned (page-locked) host memory. The `non_blocking=True` argument to `.to()` makes the transfer asynchronous so you can overlap it with compute; the asynchrony only applies when the source tensor is pinned.
```python
for batch in loader:
    x = batch['input_ids'].to('cuda', non_blocking=True)
    y = batch['labels'].to('cuda', non_blocking=True)
    # ... compute on GPU ...
```
This is mundane but it is the difference between GPU utilization at 40% and at 95%, and you will be asked about it in interviews.
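Here is a device-agnostic sketch of the same loop that also runs on a machine without a GPU. A `TensorDataset` of random token IDs and labels stands in for a real dataset, so batches arrive as tuples rather than the dicts used above. Pinning is enabled only when the copies are actually going to a GPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in dataset: 64 sequences of 16 token IDs, with binary labels
dataset = TensorDataset(torch.randint(0, 100, (64, 16)), torch.randint(0, 2, (64,)))
loader = DataLoader(
    dataset,
    batch_size=8,
    pin_memory=(device == "cuda"),   # page-locked host buffers only help GPU copies
)

seen = 0
for input_ids, labels in loader:
    input_ids = input_ids.to(device, non_blocking=True)  # async when source is pinned
    labels = labels.to(device, non_blocking=True)
    # ... compute on the device ...
    seen += input_ids.shape[0]
```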
A subtler version: implicit transfers. If you call `loss.item()` or `x.cpu().numpy()` to log a metric inside your training loop, you force a synchronization: the CPU waits until the GPU has finished all the work queued so far. The transfer itself might be fast (a single scalar), but the wait drains the pipeline. The CPU can no longer queue kernels ahead of the GPU, so the GPU sits idle while the next step is launched. The fix is to keep the value on the GPU until you actually need to print it, and to log asynchronously. We will revisit this in Chapter 31 on tail latency.
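The sketch below shows one way to keep the value on the device. It rests on simple, illustrative assumptions: a toy `nn.Linear` regression on random data, a running loss accumulated on the device with `loss.detach()`, and a single `.item()` call (one host sync) every 100 steps instead of every step. The model, data, and logging interval are arbitrary.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

log_every = 100
running_loss = torch.zeros((), device=device)   # accumulator lives on the device
logged = []

for step in range(1, 201):
    x = torch.randn(32, 8, device=device)
    y = torch.randn(32, 1, device=device)
    loss = F.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    running_loss += loss.detach()               # no host sync here
    if step % log_every == 0:
        logged.append(running_loss.item() / log_every)   # one sync per 100 steps
        running_loss.zero_()
```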
1.9 Why shape errors are 80% of ML bugs
The reason shape bugs dominate isn’t that programmers are bad at counting dimensions. It is that the type system can’t help you. A `Tensor` is just a `Tensor`; whether it’s `(N, S, D)` or `(N, D, S)` is invisible to Python’s type checkers. Your function signature says `def attention(q, k, v) -> Tensor` and gives you no way to enforce the shape contract.
There are three disciplines that, together, eliminate most of the bugs.
1. Shape comments on every line
The single most effective habit. Write the expected shape next to every tensor expression:
```python
# x: (N, S, D)
q = q_proj(x)                     # (N, S, D)
q = q.view(N, S, H, D_h)          # (N, S, H, D_h)
q = q.transpose(1, 2)             # (N, H, S, D_h)
attn = q @ k.transpose(-2, -1)    # (N, H, S, S)
attn = attn / (D_h ** 0.5)
attn = attn.softmax(dim=-1)
out = attn @ v                    # (N, H, S, D_h)
```
When the comment and the code disagree, you have found the bug before runtime.
2. einops
The einops library lets you write shape transformations as labeled equations. They are self-documenting and they fail loudly when shapes don’t match expectations:
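```python
from einops import rearrange

# q, k, v: (n, s, d) with d = h * d_h
q, k, v = (rearrange(t, 'n s (h d) -> n h s d', h=H) for t in (q, k, v))
attn = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5   # (n, h, s, s)
attn = attn.softmax(dim=-1)
out = attn @ v                                          # (n, h, s, d_h)
out = rearrange(out, 'n h s d -> n s (h d)')            # (n, s, d)
```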
There are no `.view()` / `.transpose()` chains to misorder when splitting and merging heads. The dimension names are right there in the string. Once you start writing transformer code with einops, you stop writing it without.
3. jaxtyping
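A typing library that lets you annotate function signatures with shape strings. On its own the annotation is only documentation. To enforce it on every call, pair it with a runtime type checker (beartype or typeguard) through jaxtyping’s `jaxtyped` decorator:
```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)   # checks that "n h s d" binds consistently across args
def attention(
    q: Float[Tensor, "n h s d"],
    k: Float[Tensor, "n h s d"],
    v: Float[Tensor, "n h s d"],
) -> Float[Tensor, "n h s d"]:
    attn = (q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5).softmax(dim=-1)
    return attn @ v
```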
When the function is called with the wrong shape, you get an immediate, readable error instead of a silent broadcast in the middle of a backward pass.
You don’t need all three. You should adopt at least one. I recommend einops for reshape-heavy code (transformer guts) and shape comments for everything else.
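If you can’t add a dependency, a few lines of plain PyTorch give you a crude version of the named-dimension check. This sketch is not jaxtyping’s mechanism: `check` is a hypothetical helper that binds each dimension name the first time it sees it and raises if a later tensor disagrees.

```python
import torch

def check(t: torch.Tensor, spec: str, dims: dict) -> torch.Tensor:
    """Hypothetical helper: bind each named dim on first use, enforce it afterwards."""
    names = spec.split()
    if t.ndim != len(names):
        raise ValueError(f"expected {len(names)} dims ({spec}), got shape {tuple(t.shape)}")
    for name, size in zip(names, t.shape):
        if dims.setdefault(name, size) != size:
            raise ValueError(f"dim '{name}': expected {dims[name]}, got {size}")
    return t

def attention(q, k, v):
    dims = {}
    check(q, "n h s d", dims)
    check(k, "n h s d", dims)
    check(v, "n h s d", dims)
    attn = (q @ k.transpose(-2, -1) / dims["d"] ** 0.5).softmax(dim=-1)
    return check(attn @ v, "n h s d", dims)

q = k = v = torch.randn(2, 4, 10, 16)
out = attention(q, k, v)                   # passes: every name binds consistently
```

A call like `attention(q, k.transpose(-2, -1), v)` now fails immediately with a message naming the offending dimension. Without the check, the mismatch would surface later, or not at all.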
1.10 The mental model
If you remember nothing else from this chapter:
- A tensor is a multi-dimensional array plus `(shape, dtype, device)`.
- Every shape is a tuple. The order of dimensions is a convention you must check.
- Broadcasting pads on the left, requires equal-or-1 in every dim, and will silently do the wrong thing if your shapes are wrong by accident.
- Memory is a flat buffer; shape is a fiction implemented by strides. `view` is the no-copy version of `reshape` (it errors rather than copying); `transpose` makes tensors non-contiguous; some kernels demand contiguity.
- dtype is a throughput lever, not just a precision choice. `bfloat16` is the modern training default; `fp8` is the modern inference frontier.
- Host↔device transfers go over PCIe and are about 100× slower than HBM. Move data once.
- The type system cannot save you from shape bugs. Comments, einops, and jaxtyping can.
Carry these with you into Chapter 2, where we use them to build a forward pass from scratch.
Read it yourself
- The PyTorch docs page on broadcasting semantics — the formal rules in 200 words.
- The PyTorch docs page on tensor views — short and clarifying.
- The einops tutorial at einops.rocks — 15 minutes, then read all your transformer code differently for the rest of your life.
- Patrick Kidger’s jaxtyping README on GitHub — the named-shape discipline in practice.
- The NumPy broadcasting docs — same rules, sometimes clearer prose.
- Deep Learning with PyTorch, Stevens et al., chapters 3 and 4 — the long-form treatment of everything in this chapter, with more diagrams.
Practice
- Without running it, predict the output shape of `(torch.zeros(8, 1, 5) + torch.zeros(3, 5)).shape`. (Answer: `(8, 3, 5)`.)
- Why does `torch.randn(32) - torch.randn(32, 1)` produce a `(32, 32)` tensor? Rewrite the expression to actually compute element-wise differences.
- Write a function `to_heads(x, h)` that takes a tensor of shape `(n, s, d)` and returns `(n, h, s, d/h)`, using `einops.rearrange`. Then write the inverse `from_heads(x)`.
- Given a tensor `y = x.transpose(0, 1)`, what does `y.view(-1)` do? What does `y.reshape(-1)` do? Which one errors and which one allocates? (Hint: `y` is non-contiguous.)
- Estimate (in milliseconds) the cost of moving a 1 GB tensor from pinned CPU memory to a GPU over PCIe Gen4 at 25 GB/s. Now estimate the cost of reading the same tensor from HBM at 3 TB/s. The ratio is the entire reason you should pre-load data.
- Stretch: Write a 30-line PyTorch function that takes a batch of variable-length sequences as a list of 1-D tensors and returns a single padded tensor of shape `(N, S_max)` plus a boolean mask of the same shape. Use no Python `for` loop in the body, only tensor ops and broadcasting.
Concept check
1. What is the rank of a tensor with shape (3, 4, 5)?
2. When broadcasting two tensors of shapes (5, 1) and (1, 4), what is the output shape?
3. A tensor of shape (4, 6) is stored with strides (6, 1). After calling .T (transpose), what are the new strides?
4. Which dtype change increases memory footprint while providing the widest dynamic range?