Research Validation
Four domain-specific validation gates for ML/robotics code: shape, gradient, overfit, and regression. Not TDD — research-specific validation.
Overview
Research Validation enforces four validation gates specific to ML and robotics code. Unlike traditional test-driven development, these gates target the failure modes that actually matter in research: wrong shapes, blocked gradients, inability to overfit small data, and silent regressions. The gates run in order — if an earlier gate fails, later gates are skipped until the failure is fixed.
| Property | Details |
|---|---|
| Trigger | "validate this", "test the implementation", "check if this works", or auto-triggered after implementation tasks |
| Active Modes | Engineer |
| Output | Validation scripts in scratch/{investigation}/tests/ |
| Checkpoint | G3 presentation after validation |
These are validation scripts that live in scratch/ (gitignored). They verify this specific implementation, not permanent tests. For infrastructure code that benefits from TDD, use a standard testing framework.
The Four Validation Gates
Run these in order. If an earlier gate fails, don't proceed to later gates — fix the failure first.
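The gate ordering above can be sketched as a small runner that stops at the first failure. This is an illustrative stub, not part of the skill; real gates would raise `AssertionError` from the validation scripts in `scratch/`:

```python
# Sketch of the gate ordering: gates run in sequence and the first failure
# halts the run, so later gates are skipped until it is fixed. The gate
# bodies here are stubs; real gates raise AssertionError on failure.
def run_gates(gates):
    results = {}
    for name, gate in gates:
        try:
            gate()
            results[name] = "pass"
        except AssertionError as e:
            results[name] = f"FAIL: {e}"
            break  # don't proceed to later gates
    return results

def overfit_gate():
    raise AssertionError("loss flat after 100 steps")

gates = [
    ("shape", lambda: None),
    ("gradient", lambda: None),
    ("overfit", overfit_gate),
    ("regression", lambda: None),  # skipped: overfit failed first
]
print(run_gates(gates))
# {'shape': 'pass', 'gradient': 'pass', 'overfit': 'FAIL: loss flat after 100 steps'}
```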
Gate 1: Shape Validation
Purpose: Verify that a forward pass with known input produces expected output shape and dtype.
| Check | What to Verify |
|---|---|
| Input | Create test input with known shape and dtype matching expected data format |
| Forward pass | Run end-to-end through the model |
| Output shape | Must match design doc specification |
| Output dtype | Correct dtype (especially float32 vs float64 mismatches) |
| Intermediate shapes | Check at key points in the pipeline |
Written to: scratch/{investigation}/tests/test_shapes.py
```python
import jax
import jax.numpy as jnp

# model, params, and the B/T/F/D dimensions come from the
# investigation under test.
def test_forward_shapes():
    """Verify forward pass produces expected output shapes."""
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, shape=(B, T, F))
    output = model.apply(params, x)
    assert output.shape == (B, T, D), \
        f"Expected {(B, T, D)}, got {output.shape}"
    assert output.dtype == jnp.float32
```
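For the intermediate-shape check, `jax.eval_shape` is useful: it traces shapes and dtypes without running any computation, so it stays cheap for large models. The two-stage pipeline and the B/T/F/D values below are illustrative stand-ins, not part of any real investigation:

```python
import jax
import jax.numpy as jnp

# Hypothetical two-stage pipeline; B/T/F/D are illustrative dimensions.
B, T, F, D = 2, 16, 8, 4

def encode(x, w_enc):           # (B, T, F) -> (B, T, D)
    return jnp.tanh(x @ w_enc)

def pipeline(x, w_enc, w_out):  # (B, T, F) -> (B, T, D)
    h = encode(x, w_enc)
    # Intermediate-shape check at a key point in the pipeline.
    assert h.shape == (B, T, D), f"intermediate shape {h.shape}"
    return h @ w_out

w_enc = jnp.zeros((F, D))
w_out = jnp.zeros((D, D))
x = jnp.zeros((B, T, F))

# eval_shape returns a ShapeDtypeStruct; no FLOPs are spent.
out = jax.eval_shape(pipeline, x, w_enc, w_out)
print(out.shape, out.dtype)  # (2, 16, 4) float32
```

Shape assertions inside the traced function still work because shapes are static during tracing.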
Gate 2: Gradient Validation
Purpose: Verify that loss.backward() produces non-zero gradients for all trainable parameters and that stop_gradient boundaries are correct.
| Check | Pass Criteria | Fail Indicates |
|---|---|---|
| Non-zero gradients | Every trainable parameter has non-zero gradient | Dead path, blocked gradient flow |
| Frozen parameters | Parameters that should be frozen have zero gradient | Missing stop_gradient |
| Gradient magnitudes | Within reasonable range | Exploding or vanishing gradients |
| Finite values | All gradients are finite (no NaN/inf) | Numerical instability |
Written to: scratch/{investigation}/tests/test_gradients.py
```python
import jax
import jax.numpy as jnp

# model, params, and the B/T/F dimensions come from the
# investigation under test.
def test_gradient_flow():
    """Verify gradients flow to all trainable parameters."""
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, shape=(B, T, F))

    def loss_fn(params):
        output = model.apply(params, x)
        return jnp.mean(output ** 2)

    grads = jax.grad(loss_fn)(params)
    # tree_leaves_with_path yields (path, leaf) pairs; keystr renders
    # the path as a readable parameter name.
    for path, grad in jax.tree_util.tree_leaves_with_path(grads):
        name = jax.tree_util.keystr(path)
        assert jnp.any(grad != 0), f"Zero gradient for {name}"
        assert jnp.all(jnp.isfinite(grad)), \
            f"Non-finite gradient for {name}"
```
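The frozen-parameter row of the table can be checked the same way: parameters behind a `jax.lax.stop_gradient` boundary must receive exactly zero gradient. A minimal self-contained sketch, with illustrative parameter names and shapes:

```python
import jax
import jax.numpy as jnp

# Minimal sketch: 'enc' is trainable, 'frozen' is held fixed via
# jax.lax.stop_gradient. Names and shapes are illustrative.
params = {"enc": jnp.ones((3, 3)), "frozen": jnp.ones((3,))}
x = jnp.ones((4, 3))

def loss_fn(params):
    h = x @ params["enc"]
    # Block gradient flow into the frozen parameter.
    h = h + jax.lax.stop_gradient(params["frozen"])
    return jnp.mean(h ** 2)

grads = jax.grad(loss_fn)(params)
assert jnp.all(grads["frozen"] == 0), "frozen param received gradient"
assert jnp.any(grads["enc"] != 0), "trainable param has zero gradient"
```

The mirror-image failure modes map to the table: a frozen parameter with nonzero gradient means a missing `stop_gradient`; a trainable one with zero gradient means a dead path.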
Gate 3: Overfit Validation
Purpose: Verify the model can memorize a tiny batch. If it can't overfit 5 samples in 100 steps, nothing else matters — this is the most informative single test for research code.
If a model can't memorize 5 samples, there's a fundamental bug. Don't waste time on downstream validation until this passes.
| Check | Pass Criteria |
|---|---|
| 5 samples, 100 steps | Loss must decrease significantly (>90% reduction) |
| Loss trajectory | Consistent decrease, not just fluctuation |
Written to: scratch/{investigation}/tests/test_overfit.py
```python
def test_overfit_small_batch():
    """Model must be able to overfit 5 samples."""
    batch = get_small_batch(n=5)
    initial_loss = compute_loss(params, batch)
    final_params = train_n_steps(params, batch, n_steps=100)
    final_loss = compute_loss(final_params, batch)
    assert final_loss < initial_loss * 0.1, (
        f"Failed to overfit: "
        f"initial={initial_loss:.4f}, "
        f"final={final_loss:.4f}"
    )
```
Gate 4: Regression Validation
Purpose: Verify that existing configs produce identical output before and after the change.
| Check | Pass Criteria |
|---|---|
| Existing configs | Run inference with fixed seed, compare output to baseline |
| Output comparison | Any difference in output for existing configs is a regression |
Implementation: Delegated to the regression-guard subagent with the list of changed files and existing configs.
"I only changed one file" — but that file is imported by 15 others. Always run regression-guard.
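The core of the regression check, independent of the subagent plumbing, is a fixed-seed run compared exactly against a stored baseline. A sketch, where `model_fn` is a placeholder for the real entry point and the baseline would be captured by the pre-change code:

```python
import jax
import jax.numpy as jnp

def model_fn(x):
    return jnp.tanh(x) * 2.0   # placeholder for the real model

key = jax.random.PRNGKey(42)   # fixed seed => deterministic input
x = jax.random.normal(key, (4, 8))

baseline = model_fn(x)         # captured before the change
# ... apply the code change, then re-run with the same seed ...
current = model_fn(x)

# Any numeric difference on an existing config is a regression,
# so compare exactly rather than with a tolerance.
assert jnp.array_equal(baseline, current), "regression detected"
```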
Checkpoint G3: Presentation Format
After running validation gates, results are presented to the user.
If all gates pass:
```
All 4 validation gates passed for [component].
Shape ✓, Gradient ✓, Overfit ✓, Regression ✓.
Moving to next component?
```
If any gate fails:
```markdown
## Validation Results: [Component]

### Passed
- ✓ Shape validation: output (B, T, D) as expected
- ✓ Gradient flow: all parameters receive non-zero gradients

### Failed
- ✗ Overfit test: loss decreased from 2.3 to 1.8
  (only 22% reduction in 100 steps)
  - Expected: >90% reduction
  - Possible causes: learning rate too low, loss term dominating,
    gradient blocked

### My Assessment
[Interpretation of results and suggested next step]

### Question
[Ask user how to proceed — fix the failure or investigate further?]
```
When to Run Each Gate
| Gate | When |
|---|---|
| Shape | After any model architecture change or new component |
| Gradient | After any change to loss function, model, or training loop |
| Overfit | After implementation of any trainable component |
| Regression | After any code change, no exceptions |