JAX Logic Auditor

Traces data flow through JAX transformations, catches axis/dimension mismatches, and identifies where new code may break existing JAX constraints.

Overview

JAX's functional transformation model (jit, vmap, scan, pmap) introduces subtle constraints that are easy to violate — especially when modifying existing code. The JAX Logic Auditor traces data flow through these transformations and catches issues before they surface as cryptic runtime errors or silent correctness bugs.

Property       Details
Tools          Read, Grep, Glob (read-only)
Auto-Dispatch  Yes, when implementing or modifying JAX-based components
Trigger        JAX transformation changes (vmap, scan, jit), model architecture, training loop

End-to-End Data Flow Tracking

This is the auditor's primary responsibility. Before checking any JAX-specific concern, it builds a complete picture of how data enters the system, transforms step by step, and exits.
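
The auditor only reads code (Read, Grep, Glob), so it builds this trace by reasoning rather than by running anything. To reproduce such a trace by hand, jax.eval_shape is one option. The minimal sketch below assumes a hypothetical three-stage pipeline (embed, pool, head) and propagates abstract shapes and dtypes through each stage without allocating real arrays.

```python
import jax
import jax.numpy as jnp

# Hypothetical three-stage pipeline, used only for illustration.
def embed(tokens, table):
    # tokens: (batch, seq) int32 -> (batch, seq, dim)
    return table[tokens]

def pool(h):
    # Mean over the sequence axis: (batch, seq, dim) -> (batch, dim)
    return h.mean(axis=1)

def head(z, w):
    # (batch, dim) @ (dim, classes) -> (batch, classes)
    return z @ w

# Abstract inputs: shapes and dtypes only, no real data is allocated.
tokens = jax.ShapeDtypeStruct((8, 16), jnp.int32)
table = jax.ShapeDtypeStruct((1000, 64), jnp.float32)
w = jax.ShapeDtypeStruct((64, 10), jnp.float32)

# jax.eval_shape traces each stage and returns the output shape/dtype.
h_spec = jax.eval_shape(embed, tokens, table)
z_spec = jax.eval_shape(pool, h_spec)
logits_spec = jax.eval_shape(head, z_spec, w)

for name, spec in [("embed", h_spec), ("pool", z_spec), ("head", logits_spec)]:
    print(f"{name:>5}: {spec.shape} {spec.dtype}")
```

Each printed line is one hop in the data-flow picture. A stage whose input does not fit the previous stage's output raises an error during tracing rather than deep inside a training run.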

Axis and Dimension Tracking

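Within JAX transforms, axis tracking becomes critical: vmap and pmap remove the mapped axis before the function body sees its inputs, so an axis index that is correct for a single example can refer to a different dimension of the batched array.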
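
A short sketch of how the meaning of an axis index shifts under vmap. The helper center is hypothetical and written for one example, so axis 0 means "features". Applied directly to a batch, the same code runs without error but normalizes along the wrong dimension, which is exactly the kind of bug a purely numerical shape check misses.

```python
import jax
import jax.numpy as jnp

def center(x):
    # Written for a single example of shape (features,):
    # axis 0 here means "features".
    return x - x.mean(axis=0)

batch = jnp.arange(12.0).reshape(3, 4)  # (batch=3, features=4)

# vmap strips the batch axis before calling `center`, so axis 0 inside
# the function is the feature axis of each row.
per_row = jax.vmap(center)(batch)

# Calling `center` on the whole batch silently changes meaning:
# axis 0 is now the batch axis, so each column is centered instead.
per_column = center(batch)

# Reference: explicit per-row centering along axis 1.
expected = batch - batch.mean(axis=1, keepdims=True)
print(jnp.allclose(per_row, expected))     # True
print(jnp.allclose(per_column, expected))  # False
```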

Scan Auditing

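jax.lax.scan is powerful but has strict constraints: the carry returned by the body must match the initial carry in pytree structure, shape, and dtype on every iteration, and all scanned-over inputs must share the same leading-axis length.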
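
The carry constraint is the one most often broken when modifying existing code. The sketch below (illustrative functions running_sum and growing_carry) contrasts a well-formed scan with a common anti-pattern, accumulating results by growing the carry, which JAX rejects while tracing.

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(5.0)  # scanned over its leading axis: 5 steps

def running_sum(carry, x):
    # Carry keeps shape () and dtype float32 on every step.
    new = carry + x
    return new, new  # (next carry, per-step output)

final, history = jax.lax.scan(running_sum, jnp.float32(0.0), xs)
print(final, history)  # final total and the stacked per-step partial sums

def growing_carry(carry, x):
    # Anti-pattern: appending to the carry changes its shape each step.
    return jnp.concatenate([carry, x[None]]), None

try:
    jax.lax.scan(growing_carry, jnp.zeros((0,)), xs)
    grew_ok = True
except TypeError as err:
    # The carry output shape (1,) does not match the input shape (0,).
    grew_ok = False
    print("scan rejected the carry:", type(err).__name__)
```

The fix for the second case is to emit per-step values as the second return value (the stacked ys), as running_sum does, rather than appending to the carry.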

vmap / pmap Correctness
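
vmap's in_axes argument states, per argument, which axis is the batch axis, or None for a value shared across the batch. A minimal sketch with a hypothetical per-example loss (squared_error): parameters are shared, inputs and targets are mapped. pmap uses the same in_axes convention but maps over devices, so the sketch sticks to vmap so that it runs on a single CPU.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Per-example loss: w (3,), x (3,), y () -> scalar
    return (jnp.dot(w, x) - y) ** 2

w = jnp.ones(3)
xs = jnp.ones((8, 3))
ys = jnp.zeros(8)

# Correct: share w across the batch (None), map x and y over axis 0.
losses = jax.vmap(squared_error, in_axes=(None, 0, 0))(w, xs, ys)
print(losses.shape)  # (8,)

# Bug: mapping w over axis 0 treats its 3 entries as 3 batch elements,
# which conflicts with the batch size of 8 in xs and ys.
try:
    jax.vmap(squared_error, in_axes=(0, 0, 0))(w, xs, ys)
    mismatch_caught = False
except ValueError as err:
    mismatch_caught = True
    print("vmap rejected in_axes:", str(err).splitlines()[0])
```

The error only appears because 3 differs from 8. Had the batch size happened to equal the parameter length, vmap would have accepted the call and silently paired each example with a single weight, which is why in_axes deserves a semantic check and not just a shape check.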

JIT Compatibility
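
Under jit, array arguments become abstract tracers, so Python control flow that needs a concrete value fails at trace time. The sketch below (hypothetical functions relu_python, relu_traced, scale_if_training) shows the failure and two common fixes: move the branch into the computation with jnp.where, or mark a configuration argument as static so it is resolved during tracing.

```python
from functools import partial

import jax
import jax.numpy as jnp

def relu_python(x):
    # Python `if` needs a concrete bool, but under jit x is an abstract tracer.
    if x > 0:
        return x
    return jnp.zeros_like(x)

try:
    jax.jit(relu_python)(jnp.float32(2.0))
    python_branch_ok = True
except jax.errors.ConcretizationTypeError:
    python_branch_ok = False

@jax.jit
def relu_traced(x):
    # jnp.where keeps the branch inside the traced computation.
    return jnp.where(x > 0, x, 0.0)

@partial(jax.jit, static_argnames="training")
def scale_if_training(x, training):
    # `training` is a static Python bool, so this branch is resolved at
    # trace time; each distinct value of `training` triggers a retrace.
    if training:
        return x * 0.5
    return x

print(python_branch_ok, relu_traced(jnp.float32(-3.0)),
      scale_if_training(jnp.float32(4.0), training=True))
```

Recent JAX versions raise TracerBoolConversionError here, which subclasses ConcretizationTypeError, so catching the parent class should cover both older and newer releases.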

Pytree Handling
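
Transformations operate on pytrees, so many bugs are structural mismatches rather than shape mismatches. The sketch below uses jax.tree_util to apply a leaf-wise SGD update and to compare two trees. same_tree is a hypothetical helper and the parameter dict is illustrative.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map, tree_structure

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((3, 2), 0.1), "b": jnp.full(2, 0.1)}

def same_tree(a, b):
    # Same container structure (keys, nesting, node types) and same leaf shapes.
    if tree_structure(a) != tree_structure(b):
        return False
    return all(x.shape == y.shape for x, y in zip(tree_leaves(a), tree_leaves(b)))

# SGD step applied leaf by leaf; tree_map requires matching structures.
lr = 0.5
new_params = tree_map(lambda p, g: p - lr * g, params, grads)

print(same_tree(params, grads))       # True
print(same_tree(params, new_params))  # True
# A tuple and a list with identical leaves are different pytrees.
print(tree_structure((1, 2)) == tree_structure([1, 2]))  # False
```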

Gradient and Autodiff Correctness
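
Two common sources of silently wrong gradients are an unintended jax.lax.stop_gradient and operations whose derivative is zero almost everywhere. The minimal sketch below shows both and compares autodiff with a central finite difference (fd_grad, a hypothetical helper): where the two disagree, something is blocking or altering the gradient.

```python
import jax
import jax.numpy as jnp

def full(x):
    return x * x  # d/dx = 2x

def stopped(x):
    # stop_gradient treats the second factor as a constant: d/dx = x
    return x * jax.lax.stop_gradient(x)

def rounded(x):
    # round is piecewise constant, so JAX defines its gradient as zero
    return jnp.round(x) * 2.0

def fd_grad(f, x, eps=1e-2):
    # Central finite difference for a scalar input.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 3.0
print(jax.grad(full)(x))       # 6.0
print(jax.grad(stopped)(x))    # 3.0
print(jax.grad(rounded)(1.3))  # 0.0: the gradient silently vanishes
print(fd_grad(stopped, x))     # ~6.0 numerically, disagreeing with autodiff
```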

Random Key Management
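
JAX's PRNG is explicit: the same key always yields the same numbers. A minimal sketch of the key-reuse bug and the split-before-use fix, using jax.random.PRNGKey (newer releases also offer jax.random.key):

```python
import jax

key = jax.random.PRNGKey(0)

# Anti-pattern: reusing one key produces identical "random" draws.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Correct: split once per consumer and never reuse a consumed key.
key, k1, k2 = jax.random.split(key, 3)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))

print(bool((a == b).all()))  # True: same key, same samples
print(bool((c == d).all()))  # False: independent subkeys
```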

Parallel Environment Patterns
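
One common pattern for parallel environments is to give each environment its own key, vmap the per-environment reset and step functions, and drive the rollout with scan. The environment below is a hypothetical 1-D random walk chosen only to make the shapes concrete; NUM_ENVS, reset, step, and rollout_step are illustrative names, not part of any library.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 4

# Hypothetical toy environment: a 1-D random walk with a scalar state.
def reset(key):
    return jax.random.uniform(key, (), minval=-1.0, maxval=1.0)

def step(state, action, key):
    noise = 0.1 * jax.random.normal(key, ())
    new_state = state + action + noise
    reward = -jnp.abs(new_state)
    return new_state, reward

key = jax.random.PRNGKey(42)
key, reset_key = jax.random.split(key)

# One independent key per environment, then vmap over the env axis.
states = jax.vmap(reset)(jax.random.split(reset_key, NUM_ENVS))
actions = jnp.zeros(NUM_ENVS)

def rollout_step(carry, _):
    states, key = carry
    key, step_key = jax.random.split(key)
    step_keys = jax.random.split(step_key, NUM_ENVS)  # fresh key per env
    states, rewards = jax.vmap(step)(states, actions, step_keys)
    return (states, key), rewards

(final_states, _), rewards = jax.lax.scan(
    rollout_step, (states, key), None, length=10)
print(final_states.shape, rewards.shape)  # (4,) (10, 4)
```

The shapes are what to audit here: the environment axis (4) is the vmapped axis, and scan prepends the time axis (10) to the stacked rewards.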

Output Format

The auditor produces a structured report containing:

Key Principles

Shapes are semantic. Don't just check numerical compatibility; verify that a dimension of size 64 actually represents what the code assumes.

Transformation boundaries are where bugs live. Most JAX bugs happen at the interface between transformed and untransformed code.

Statefulness is the enemy. Any pattern that smuggles state (closures over mutable objects, global counters) is suspect inside transformed code.
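
The statefulness principle can be demonstrated in a few lines. In this sketch (hypothetical functions add_one, scaled, scaled_explicit), a global counter only advances while jit traces, and a closure over a mutable list is read once and baked into the compiled function.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def add_one(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x + 1

for v in (1.0, 2.0, 3.0):
    add_one(jnp.float32(v))
print(trace_count)  # 1: the compiled function is reused, the counter is not

scale = [2.0]

@jax.jit
def scaled(x):
    return x * scale[0]  # read once at trace time and baked in as a constant

first = scaled(jnp.float32(1.0))
scale[0] = 10.0
second = scaled(jnp.float32(1.0))  # cache hit: still multiplies by 2.0
print(first, second)  # 2.0 2.0

@jax.jit
def scaled_explicit(x, s):
    # Passing the value as an argument makes the dependency visible to jit.
    return x * s
```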