Decode is repetitive: why caching primitives and kernels matters

Engineering Team


LLM inference feels slow largely because decode is expensive at scale. Prefill runs once per prompt, but decode runs once per generated token, so any fixed per-step overhead is multiplied across the entire output. Every decode step re-executes all transformer layers, which makes the decode loop the place to optimize. Cutting per-token overhead with primitive caching and graph capture delivers significant performance improvements.

The hidden cost of recreating primitives

Consider OneDNN, a popular kernel library for x86 platforms. OneDNN sits above raw kernels: it doesn't just execute operations, it decides how they execute on a particular machine.

The central abstraction in OneDNN is the primitive. A primitive isn't just a kernel; it's a fully specified execution plan for an operation. Given exact tensor shapes, layouts, data types, post-operations, and target CPU, the primitive encodes how that operation will run.

That distinction matters because creating a primitive isn't lightweight.

When your code constructs a primitive descriptor for a matmul, OneDNN performs a planning phase. It resolves memory layouts and strides, determines required reorders, selects an implementation optimized for the available ISA (AVX2, AVX-512, AMX), chooses blocking and vectorization strategies, fixes post-ops such as bias or activation, computes scratchpad requirements, and may JIT-compile a microkernel on first encounter.

All of that happens before a single multiply-accumulate executes.

Many implementations create a fresh primitive for every op call during token generation. During decode, however, none of the inputs that affect primitive creation change from token to token: weight shapes, hidden size, and post-ops are all fixed. If primitives are rebuilt per token, OneDNN repeats the same planning work on every step.

Primitive creation time varies by operation and CPU, but it is not negligible. For common matmul configurations in transformer layers, creating a primitive_desc and primitive takes 10–100 microseconds, and longer on the first run because of JIT compilation. The matmul execution itself, especially in decode, can take about as long or even less.

Scale that to a real model. A LLaMA-1B-class model executes ~80–100 matmuls per token, so generating 1,000 tokens creates ~100,000 primitives if nothing is cached. Even at an unrealistically low 1 microsecond per creation, that adds up to 0.1 seconds spent planning rather than computing; at the 10–100 microseconds quoted above, it is roughly 1–10 ms per token, or 1–10 seconds over 1,000 tokens. This is why decode often appears CPU-bound in profiles despite doing little math. The system isn't slow because the kernels are bad; it's slow because it repeats the same work when it could do it once and cache the result.
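The arithmetic above can be checked with a small back-of-the-envelope helper. It assumes, as the text does, about 100 matmuls per token and one fresh primitive per matmul; `planning_overhead` is an illustrative name of our own, not a library function.

```python
from typing import Tuple


def planning_overhead(matmuls_per_token: int, tokens: int,
                      create_us: float) -> Tuple[float, float]:
    """Time spent creating primitives when none are cached.

    Returns (milliseconds per token, seconds for the whole run).
    """
    per_token_us = matmuls_per_token * create_us
    per_token_ms = per_token_us / 1_000
    per_run_s = per_token_us * tokens / 1_000_000
    return per_token_ms, per_run_s


# ~100 matmuls per token, 1,000 generated tokens, three creation costs.
for create_us in (1, 10, 100):
    ms, s = planning_overhead(100, 1_000, create_us)
    print(f"{create_us:>3} us/creation: {ms:.1f} ms/token, {s:.1f} s per 1,000 tokens")
```

The per-token figure is the one that matters for decode latency; the per-run figure is what shows up as wasted wall-clock time over a long generation.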

What caching a primitive actually buys you

A OneDNN primitive is a fully resolved execution plan. Once created, executing it is cheap.

Caching means OneDNN thinks once. You create the primitive descriptor and primitive ahead of time, store them in a cache keyed by shape, layout, data type, and attributes, then reuse them for every token.

Decode no longer alternates between "plan" and "execute." It becomes a tight loop that only executes.
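The caching pattern described above can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions: `MatmulKey`, `PrimitiveCache`, and `hypothetical_build` are illustrative names, not OneDNN APIs, and the layer count and shapes are made up. In real OneDNN code, the builder would create the matmul primitive descriptor and primitive. Because a OneDNN primitive receives its tensors as arguments at execute time, one primitive can serve every layer whose matmul has the same shape, layout, data type, and attributes.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class MatmulKey:
    """Everything that affects planning; frozen so it can be a dict key."""
    src_shape: Tuple[int, int]
    weight_shape: Tuple[int, int]
    layout: str                  # memory format tag, e.g. "ab"
    dtype: str                   # e.g. "bf16"
    post_ops: Tuple[str, ...] = ()


class PrimitiveCache:
    """Run the expensive planner once per distinct key, then reuse the plan."""

    def __init__(self, build: Callable[[MatmulKey], object]) -> None:
        self._build = build
        self._plans: Dict[MatmulKey, object] = {}
        self.builds = 0  # number of times the planner actually ran

    def get(self, key: MatmulKey) -> object:
        plan = self._plans.get(key)
        if plan is None:
            plan = self._build(key)
            self._plans[key] = plan
            self.builds += 1
        return plan


def hypothetical_build(key: MatmulKey) -> str:
    # Stand-in for creating a primitive descriptor + primitive.
    return f"matmul{key.src_shape}x{key.weight_shape}:{key.dtype}"


# Hypothetical per-layer projection shapes for a small decoder, batch size 1.
HIDDEN, FFN = 2048, 8192
PROJECTIONS = [(HIDDEN, 3 * HIDDEN), (HIDDEN, HIDDEN), (HIDDEN, FFN), (FFN, HIDDEN)]

cache = PrimitiveCache(hypothetical_build)
for token in range(1_000):          # decode steps
    for layer in range(16):         # transformer layers
        for k, n in PROJECTIONS:
            plan = cache.get(MatmulKey((1, k), (k, n), "ab", "bf16"))
            # A real engine would now execute `plan`, passing this layer's
            # weight buffer as a runtime argument.

print(cache.builds)  # 4: four plans serve all 64,000 lookups
```

Planning happens on the first decode step only; every later step is pure dictionary lookups followed by execution. Warming the cache before decode starts (for example, during prefill) moves even that first-step cost out of the token loop.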

The effect is dramatic: you're not making execution faster, you're removing work entirely. Decode performance improves not by a few percent but by tens of percent or more. In our experiments with a 1B Llama model in BF16, enabling primitive caching alone raised decode speed from ~9 tok/s to ~15 tok/s, roughly a two-thirds improvement in tokens/s. A jump that large shows how much of each decode step was being spent on avoidable planning overhead.
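Converting the reported throughput into per-token latency makes the saving concrete. This sketch uses only the two figures quoted above (~9 and ~15 tok/s), so its outputs are as approximate as those inputs; `latency_ms` is an illustrative helper name.

```python
def latency_ms(tokens_per_s: float) -> float:
    """Average wall-clock time per decoded token, in milliseconds."""
    return 1_000.0 / tokens_per_s


before, after = 9.0, 15.0   # approximate tok/s without and with caching
speedup = after / before    # throughput ratio
saved_ms = latency_ms(before) - latency_ms(after)

print(f"{latency_ms(before):.1f} ms -> {latency_ms(after):.1f} ms per token")  # 111.1 ms -> 66.7 ms per token
print(f"{speedup:.2f}x throughput, about {saved_ms:.0f} ms saved per token")   # 1.67x throughput, about 44 ms saved per token
```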

Why this shows up on GPU too

On CUDA, the same pattern exists under different names.

Instead of OneDNN primitives, you have cuBLASLt matmul descriptors and algorithm selection, or cuDNN frontend operation graphs and execution plans. If you let these libraries run heuristics or rebuild plans per token, you repeat the same mistake.

Algorithm selection and plan construction are expensive—searching kernel variants, deciding epilogues, sizing workspaces. None of that belongs in decode's hot path.

When you cache the chosen cuBLASLt algorithm or cuDNN execution plan, you lock in a specific kernel configuration. Execution becomes cheap and predictable. This is the CUDA equivalent of "compiling to a binary" for inference: the kernel choice is fixed, and decode just runs it.
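The same idea applies to algorithm selection, shown here as a minimal Python sketch. `GemmProblem`, `select_algorithm`, and its toy cost model are hypothetical placeholders for a real heuristic query (for example, cuBLASLt's `cublasLtMatmulAlgoGetHeuristic`) or a cuDNN execution-plan build; `functools.lru_cache` simply memoizes the chosen result per distinct problem.

```python
from functools import lru_cache
from typing import NamedTuple


class GemmProblem(NamedTuple):
    """Everything that can change the kernel choice (hashable, so cacheable)."""
    m: int
    n: int
    k: int
    dtype: str
    epilogue: str          # e.g. "none", "bias", "bias_gelu"
    workspace_bytes: int   # workspace limit can rule out some kernels


search_calls = 0  # counts how often the expensive search actually runs


@lru_cache(maxsize=None)
def select_algorithm(problem: GemmProblem) -> str:
    """Hypothetical stand-in for heuristic search or autotuning."""
    global search_calls
    search_calls += 1
    # Toy cost model: prefer tiles that divide n evenly, then larger tiles.
    candidates = (64, 128, 256)
    best = min(candidates, key=lambda tile: (problem.n % tile, -tile))
    return f"tile{best}_{problem.epilogue}"


decode_gemm = GemmProblem(m=1, n=2048, k=2048, dtype="bf16",
                          epilogue="bias", workspace_bytes=4 << 20)
for _ in range(10_000):   # every decode step asks for the same problem
    algo = select_algorithm(decode_gemm)

print(algo, search_calls)  # tile256_bias 1
```

Because the key is a `NamedTuple`, it is hashable and compares by value, so any code path that describes the same GEMM hits the same cache entry. Leaving a field out of the key (the epilogue or workspace limit, say) risks reusing a kernel choice that is wrong for a different configuration.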

Where CUDA Graphs naturally fit

Even after caching primitives and plans, decode still launches many small kernels. Kernel launch overhead becomes the next limiting factor.

CUDA Graphs address this by capturing an entire decode step (all matmuls, normalizations, and elementwise ops) into a single executable graph. Replaying the graph per token replaces dozens or hundreds of kernel launches with a single graph launch. In our real-world tests, switching to CUDA Graphs often improved decode speed by nearly 10%.

But CUDA Graphs only work after you've done the earlier work. Shapes must be stable. Kernel choices must be stable. Memory must be allocated ahead of time. CUDA Graphs aren't an alternative to caching—they're the payoff after caching makes decode deterministic.
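A captured CUDA graph records kernel launches together with the buffers and launch parameters they used, so every replay reads from and writes to the same memory. The toy Python model below is not the CUDA API (the real entry points include `cudaStreamBeginCapture`, `cudaGraphInstantiate`, and `cudaGraphLaunch`, or `torch.cuda.CUDAGraph` in PyTorch); it is a CPU-only sketch, with hypothetical names throughout, of why shapes must be fixed and memory preallocated: new inputs are copied into the captured buffers before each replay.

```python
from typing import Callable, List

Buffer = List[float]


class ToyGraph:
    """Toy model of capture/replay: ops are recorded once and replayed as-is."""

    def __init__(self) -> None:
        self._ops: List[Callable[[], None]] = []

    def record(self, op: Callable[[], None]) -> None:
        self._ops.append(op)  # captured, not executed

    def replay(self) -> None:
        for op in self._ops:  # one "launch" runs the whole recorded step
            op()


def make_scale(src: Buffer, dst: Buffer, factor: float) -> Callable[[], None]:
    # Stand-in for a matmul; binds the buffer objects themselves,
    # much as a captured kernel binds device pointers.
    def op() -> None:
        for i, x in enumerate(src):
            dst[i] = factor * x
    return op


def make_add(src: Buffer, dst: Buffer, bias: float) -> Callable[[], None]:
    # Stand-in for an elementwise op.
    def op() -> None:
        for i, x in enumerate(src):
            dst[i] = x + bias
    return op


def copy_into(dst: Buffer, src: Buffer) -> None:
    # Captured work assumes fixed sizes, so a shape change means re-capture.
    if len(src) != len(dst):
        raise ValueError("shape changed since capture; re-capture required")
    dst[:] = src


# Preallocate every buffer before capture (fixed hidden size of 4).
static_in, static_mid, static_out = [0.0] * 4, [0.0] * 4, [0.0] * 4

graph = ToyGraph()
graph.record(make_scale(static_in, static_mid, 2.0))
graph.record(make_add(static_mid, static_out, 1.0))

outputs = []
for token_input in ([1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]):
    copy_into(static_in, token_input)  # new data goes into the captured buffer
    graph.replay()
    outputs.append(list(static_out))

print(outputs)  # [[3.0, 5.0, 7.0, 9.0], [2.0, 2.0, 2.0, 2.0]]
```

Allocating a fresh input buffer per token would not work here: the recorded ops still point at the original buffers, which is exactly why memory has to be set up before capture.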

Why inference code looks different from training code

This is why inference-optimized systems diverge from training stacks.

Training tolerates dynamism: changing shapes, flexible graphs, frequent reallocations. Decode doesn't. Decode rewards repetition. Every time you rebuild a primitive, reselect a kernel, or reallocate a workspace, decode makes you pay for it again on every token.

Caching primitives in OneDNN, caching plans and algorithms in CUDA, and capturing execution with CUDA Graphs express the same idea:

Do all the thinking once. Then stop thinking and just execute.

Ready to Get Started?

OpenInfer is now available! Sign up today to gain access and experience these performance gains for yourself.