JAX Logic Auditor
Traces data flow through JAX transformations, catches axis/dimension mismatches, and identifies where new code may break existing JAX constraints.
Overview
JAX's functional transformation model (jit, vmap, scan, pmap) introduces subtle constraints that are easy to violate — especially when modifying existing code. The JAX Logic Auditor traces data flow through these transformations and catches issues before they surface as cryptic runtime errors or silent correctness bugs.
| Property | Details |
|---|---|
| Tools | Read, Grep, Glob (read-only) |
| Auto-Dispatch | Yes — when implementing or modifying JAX-based components |
| Trigger | JAX transformation changes (vmap, scan, jit), model architecture, training loop |
End-to-End Data Flow Tracking
This is the auditor's primary responsibility. Before checking any JAX-specific concern, it builds a complete picture of how data enters the system, transforms step by step, and exits.
- Full input-to-output trace — follows data from raw input (observation, sensor data, dataset batch) through every function call and transformation to the final output (action, prediction, loss value)
- Annotate every intermediate tensor — shape, dtype, semantic axis labels, and value range at each step
- Track consumption, duplication, merging — where an input feeds multiple branches, where branches merge, where information is discarded
- Identify semantic changes — a tensor might keep the same shape but mean something different after a transformation (raw observation vs normalized observation vs embedded representation)
- Track every axis across the full path: if input has axes `(batch, time, features)`, trace exactly where batch gets vmapped away, time gets scanned over, and features get projected
- Flag silent broadcasting: when two tensors interact, verify broadcast semantics are intentional, not accidental shape compatibility
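The silent-broadcasting check can be illustrated with a small sketch. The function names and shapes here are hypothetical, chosen only to show how a `(batch,)` tensor combined with a `(batch, 1)` tensor yields a `(batch, batch)` result without any error, and how `jax.eval_shape` can annotate shapes without running the computation.

```python
import jax
import jax.numpy as jnp


def td_error_buggy(rewards, values):
    # rewards: (batch,)   values: (batch, 1)
    # Broadcasting silently produces (batch, batch) instead of (batch,).
    return rewards - values


def td_error_fixed(rewards, values):
    # Drop the trailing singleton axis so both operands are (batch,).
    return rewards - values[:, 0]


rewards = jnp.ones((4,))
values = jnp.zeros((4, 1))

# jax.eval_shape traces shapes and dtypes only; nothing is computed.
buggy_shape = jax.eval_shape(td_error_buggy, rewards, values).shape
fixed_shape = jax.eval_shape(td_error_fixed, rewards, values).shape
print("buggy:", buggy_shape, "fixed:", fixed_shape)
```

Both versions run without complaint, which is why the auditor compares the traced shape against the intended semantic axes rather than relying on the absence of errors.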
Axis and Dimension Tracking
Within JAX transforms, axis tracking becomes critical:
- Verify `vmap` `in_axes`/`out_axes` map to the correct semantic dimensions
- Check that `scan` carry and input/output shapes are consistent across iterations
- Flag any `reshape`, `squeeze`, or `expand_dims` that could silently swap semantic axes
- Verify that batch and time dimensions don't get mixed up, the #1 source of silent bugs in sequential RL/robotics code
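As a minimal illustration of why `reshape` is flagged, the sketch below assumes a hypothetical rollout buffer laid out as `(time, batch)`. Reshaping to `(batch, time)` gives the right shape but scrambles which element belongs to which environment, whereas `jnp.swapaxes` actually moves the axes.

```python
import jax.numpy as jnp

# Hypothetical buffer with axes (time=2, batch=3).
x = jnp.arange(6).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]

# Same target shape, different meaning:
wrong = x.reshape(3, 2)              # reinterprets memory order
right = jnp.swapaxes(x, 0, 1)        # true (batch, time) view

print(wrong.shape, right.shape)      # identical shapes
print(wrong[1], right[1])            # different contents for "batch element 1"
```

Because the two results have identical shapes, a purely numerical shape check cannot tell them apart; only the semantic axis labels can.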
Scan Auditing
jax.lax.scan is powerful but has strict constraints:
- Carry shape constancy — carry state shape must remain constant across iterations (JAX requirement)
- Purity — the scan body function must be pure: no side effects, no external state mutation
- Carry vs xs: traces what enters via carry vs what enters via `xs` (scanned inputs), a common source of confusion
- Scan length: verifies the scan length matches the expected time/sequence dimension
- Initial carry — checks that initial carry values have correct shapes and dtypes
- Missing scans — flags accumulation patterns that should use scan but don't (manual for-loops inside jit)
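The carry/`xs` distinction above can be sketched with a small running-sum example. The `step` function and the 0.5 decay factor are illustrative assumptions, not part of the auditor itself.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # carry: running decayed sum, shape () and float32 on every iteration.
    # x: one slice of xs along its leading (time) axis.
    new_carry = 0.5 * carry + x
    # Return (next carry, per-step output); the outputs are stacked into ys.
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0])   # scanned inputs; leading axis sets scan length
init = jnp.zeros(())              # must match the shape/dtype step returns as carry

final, ys = jax.lax.scan(step, init, xs)
print(final, ys)
```

The scan length is taken from the leading axis of `xs`, and `init` fixes the carry shape and dtype for all iterations, which is why both are checked together.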
vmap / pmap Correctness
- Verify vmapped functions don't contain operations that implicitly assume batch position
- Check for hardcoded axis indices that break under vmap (e.g., `x[0]` when 0 is now the vmapped axis)
- Verify `in_axes` and `out_axes`: `None` for broadcasted args, integer for batched args
- For nested vmap (envs then agents), verify axis ordering is consistent
- For pmap: verify shapes account for the device axis, check `axis_name` usage in collective operations (`psum`, `pmean`)
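A short sketch of the `in_axes` and hardcoded-index checks, using hypothetical `predict` and `first_feature` functions written for a single example:

```python
import jax
import jax.numpy as jnp


def predict(obs, w):
    # Written for ONE example: obs is (features,), w is (features, actions).
    return obs @ w


obs = jnp.ones((5, 3))   # (batch, features)
w = jnp.ones((3, 2))     # shared weights with no batch axis

# obs is batched along axis 0; None broadcasts w to every example.
out = jax.vmap(predict, in_axes=(0, None))(obs, w)   # (batch, actions)


def first_feature(single_obs):
    # Inside vmap the batch axis is gone, so index 0 is feature 0,
    # not "example 0 of the batch".
    return single_obs[0]


firsts = jax.vmap(first_feature)(jnp.arange(15.0).reshape(5, 3))
print(out.shape, firsts)
```

Code that used `x[0]` to mean "first batch element" before being wrapped in `vmap` changes meaning silently, which is the pattern the auditor looks for.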
JIT Compatibility
- Flag Python control flow that depends on traced values (`if`/`else` on JAX arrays without `jax.lax.cond`)
- Check for side effects inside jitted functions (`print`, `list.append`, global mutation)
- Verify all function inputs are valid JAX types or pytrees of them (watch for Python lists whose length varies between calls and dicts with dynamic keys, since a changed pytree structure forces retracing)
- Check for shape-dependent recompilation triggers: dynamic shapes cause retracing
- Flag `jax.debug.print` vs regular `print` (regular print only executes during tracing)
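The control-flow and printing checks can be sketched as follows. The function names are hypothetical; the point is that a Python `if` on a traced value fails under `jit`, while `jax.lax.cond` or `jnp.where` expresses the same branch in a traceable way.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_cond(x):
    # `if x > 0:` would raise here because x is a tracer under jit.
    # jax.debug.print runs on every call, unlike a plain print.
    jax.debug.print("abs_cond called with x = {}", x)
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


@jax.jit
def abs_where(x):
    # Elementwise alternative: both branches are computed, then selected.
    return jnp.where(x > 0, x, -x)


a = abs_cond(jnp.float32(-3.0))
b = abs_where(jnp.array([-1.0, 2.0]))
print(a, b)
```

`jax.lax.cond` expects a scalar predicate, so `jnp.where` is usually the simpler choice for elementwise conditions on arrays.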
Pytree Handling
- Verify custom classes used as JAX inputs are properly registered as pytrees
- Check that `tree_map`, `tree_leaves`, and `tree_unflatten` preserve expected structure
- Verify flax/equinox module state is handled correctly through transformations
- Flag namedtuples or dataclasses that might not be recognized as pytrees
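One way a custom class can be registered as a pytree is sketched below. `Params` and `apply` are hypothetical names; the registration uses `jax.tree_util.register_pytree_node_class`, which expects `tree_flatten` and `tree_unflatten` methods on the class.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass
class Params:  # hypothetical parameter container
    w: jax.Array
    b: jax.Array

    def tree_flatten(self):
        # (children that JAX should trace, static auxiliary data)
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


params = Params(w=jnp.ones((3,)), b=jnp.zeros(()))
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
leaves = jax.tree_util.tree_leaves(params)


@jax.jit
def apply(p, x):
    # Works under jit only because Params is a registered pytree.
    return jnp.dot(p.w, x) + p.b


y = apply(params, jnp.ones((3,)))
print(type(doubled).__name__, len(leaves), y)
```

Without the registration, an unregistered dataclass passed to a jitted function is rejected as an invalid JAX type instead of being flattened into its array leaves.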
Gradient and Autodiff Correctness
- Verify `jax.grad` is applied to scalar-output functions (or `jax.jacobian`/`jax.value_and_grad` is used appropriately)
- Check `stop_gradient` placement, especially in actor-critic, VQ-VAE codebook, or target network patterns
- Verify `custom_vjp`/`custom_jvp` rules are mathematically correct and shape-consistent
- Flag non-differentiable operations (argmax, integer indexing) in gradient paths
- Verify straight-through estimator implementations when used
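A minimal sketch of a straight-through estimator built from `stop_gradient`, assuming rounding as the non-differentiable operation (the function names are illustrative). The forward pass returns the rounded value, while the backward pass treats the operation as the identity.

```python
import jax
import jax.numpy as jnp


def quantize_ste(x):
    # Forward: x + (round(x) - x) == round(x).
    # Backward: the stop_gradient term contributes nothing, so d/dx is 1.
    return x + jax.lax.stop_gradient(jnp.round(x) - x)


def loss_ste(x):
    return jnp.sum(quantize_ste(x) ** 2)


def loss_plain(x):
    # Rounding without the estimator: the gradient is zero everywhere.
    return jnp.sum(jnp.round(x) ** 2)


x = jnp.array([0.4, 1.6])
value, grads_ste = jax.value_and_grad(loss_ste)(x)
grads_plain = jax.grad(loss_plain)(x)
print(value, grads_ste, grads_plain)
```

Comparing the two gradients is a quick way to confirm that `stop_gradient` wraps the residual term rather than the whole expression, which would silence the gradient entirely.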
Random Key Management
- Verify PRNG keys are split correctly and never reused
- Check that `jax.random.split` produces enough subkeys for all random operations
- Flag any random operation using the same key more than once (produces correlated samples)
- In scan loops: verify keys are either split per-iteration or passed via `xs`
- In vmap: verify each batch element gets a unique key (common bug: same key vmapped produces identical samples)
Parallel Environment Patterns
- Trace env step function: obs, action, reward, done, info shapes through vectorized execution
- Verify environment reset logic handles per-env done flags correctly (auto-reset patterns)
- Check observation normalization statistics are computed across the correct axes
- Verify advantage estimation (GAE) scans over the time axis, not the env axis
- Flag episode boundary handling that could leak information across episodes
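The GAE check can be sketched with a reverse `jax.lax.scan` over the time axis. This is an illustrative implementation under assumed conventions, not a reference one: all rollout tensors are `(time, envs)`, `dones[t]` marks that the episode ended at step `t`, and `last_value` bootstraps the final step.

```python
import jax
import jax.numpy as jnp


def gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # rewards, values, dones: (time, envs); last_value: (envs,)
    next_values = jnp.concatenate([values[1:], last_value[None]], axis=0)

    def step(adv_next, inputs):
        r, v, v_next, d = inputs              # each is (envs,)
        not_done = 1.0 - d
        # Mask bootstrap and accumulated advantage at episode boundaries.
        delta = r + gamma * v_next * not_done - v
        adv = delta + gamma * lam * not_done * adv_next
        return adv, adv

    init = jnp.zeros_like(last_value)         # carry shape stays (envs,)
    # reverse=True walks time backwards; outputs keep the original time order.
    _, advantages = jax.lax.scan(
        step, init, (rewards, values, next_values, dones), reverse=True
    )
    return advantages


rewards = jnp.ones((3, 2))
values = jnp.zeros((3, 2))
dones = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])  # env 0 ends at t=1
advantages = gae(rewards, values, dones, jnp.zeros((2,)), gamma=1.0, lam=1.0)
print(advantages)
```

The carry has shape `(envs,)` and the scan runs over the leading time axis. If the inputs were laid out as `(envs, time)`, the same code would scan over environments and mix unrelated episodes, which is the mix-up this check targets.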
Output Format
The auditor produces a structured report containing:
- End-to-End Data Flow — complete input-to-output trace with every intermediate step annotated
- Transformation Stack — diagram of jit/vmap/scan/pmap nesting overlaid on the data flow
- Shape Trace at Boundaries — full shape with semantic axis labels at every transformation boundary
- Issues Found — each with location, current vs expected behavior, severity, and suggested fix
- Verified Correct — components confirmed correct with reasoning
- Warnings — patterns that aren't bugs now but could break if modified
Three principles guide every audit:
- Shapes are semantic: don't just check numerical compatibility; verify that a dimension of size 64 actually represents what the code assumes.
- Transformation boundaries are where bugs live: most JAX bugs happen at the interface between transformed and untransformed code.
- Statefulness is the enemy: any pattern that smuggles state (closures over mutable objects, global counters) is suspect inside transformed code.