Research Validation

Four domain-specific validation gates for ML/robotics code: shape, gradient, overfit, and regression. Not TDD — research-specific validation.

Overview

Research Validation enforces four validation gates specific to ML and robotics code. Unlike traditional test-driven development, these gates target the failure modes that actually matter in research: wrong shapes, blocked gradients, inability to overfit small data, and silent regressions. The gates run in order — if an earlier gate fails, later gates are skipped until the failure is fixed.

| Property | Details |
| --- | --- |
| Trigger | "validate this", "test the implementation", "check if this works", or auto-triggered after implementation tasks |
| Active modes | Engineer |
| Output | Validation scripts in scratch/{investigation}/tests/ |
| Checkpoint | G3 presentation after validation |
Validation Scripts, Not a Test Suite

These are validation scripts that live in scratch/ (gitignored). They verify this specific implementation, not permanent tests. For infrastructure code that benefits from TDD, use a standard testing framework.

The Four Validation Gates

Run these in order. If an earlier gate fails, don't proceed to later gates — fix the failure first.

Gate 1: Shape Validation

Purpose: Verify that a forward pass with known input produces expected output shape and dtype.

| Check | What to verify |
| --- | --- |
| Input | Create test input with known shape and dtype matching the expected data format |
| Forward pass | Run end-to-end through the model |
| Output shape | Must match the design doc specification |
| Output dtype | Correct dtype (especially float32 vs float64 mismatches) |
| Intermediate shapes | Check at key points in the pipeline |

Written to: scratch/{investigation}/tests/test_shapes.py

import jax
import jax.numpy as jnp

# B, T, F, D, model, and params are defined by the investigation's config.

def test_forward_shapes():
    """Verify forward pass produces expected output shapes."""
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, shape=(B, T, F))

    output = model.apply(params, x)

    assert output.shape == (B, T, D), \
        f"Expected (B, T, D), got {output.shape}"
    assert output.dtype == jnp.float32
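When the forward pass is expensive, the same shape and dtype checks can run without spending any compute by tracing the function with jax.eval_shape. A minimal self-contained sketch (the forward function and the dimensions are illustrative stand-ins, not the investigation's real model):

```python
import jax
import jax.numpy as jnp

B, T, F, D = 2, 16, 8, 4  # example dimensions

def forward(x):
    # Stand-in for model.apply(params, x): projects F -> D.
    w = jnp.ones((F, D), dtype=jnp.float32)
    return x @ w

# eval_shape traces with abstract values only -- no FLOPs are spent,
# so this works even when a real forward pass would be too slow.
abstract_out = jax.eval_shape(forward, jnp.zeros((B, T, F), jnp.float32))
assert abstract_out.shape == (B, T, D)
assert abstract_out.dtype == jnp.float32
```

This is useful as a first, near-instant pass before running the full test_shapes.py script.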

Gate 2: Gradient Validation

Purpose: Verify that differentiating the loss produces non-zero gradients for all trainable parameters and that stop_gradient boundaries are correct.

| Check | Pass criteria | Fail indicates |
| --- | --- | --- |
| Non-zero gradients | Every trainable parameter has a non-zero gradient | Dead path, blocked gradient flow |
| Frozen parameters | Parameters that should be frozen have zero gradient | Missing stop_gradient |
| Gradient magnitudes | Within a reasonable range | Exploding or vanishing gradients |
| Finite values | All gradients are finite (no NaN/inf) | Numerical instability |

Written to: scratch/{investigation}/tests/test_gradients.py

import jax
import jax.numpy as jnp

# B, T, F, model, and params are defined by the investigation's config.

def test_gradient_flow():
    """Verify gradients flow to all trainable parameters."""
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, shape=(B, T, F))

    def loss_fn(params):
        output = model.apply(params, x)
        return jnp.mean(output ** 2)

    grads = jax.grad(loss_fn)(params)

    for path, grad in jax.tree_util.tree_leaves_with_path(grads):
        name = jax.tree_util.keystr(path)
        assert jnp.any(grad != 0), f"Zero gradient for {name}"
        assert jnp.all(jnp.isfinite(grad)), \
            f"Non-finite gradient for {name}"
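The frozen-parameter check from the table is the complement of the test above: parameters behind a stop_gradient boundary must receive exactly zero gradient. A self-contained sketch (the two-parameter dict model is an illustrative assumption):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 3))

params = {
    "trainable": jnp.ones((3, 2)),
    "frozen": jnp.ones((3, 2)),
}

def loss_fn(params):
    # stop_gradient blocks the backward path through the frozen weights.
    frozen = jax.lax.stop_gradient(params["frozen"])
    out = x @ params["trainable"] + x @ frozen
    return jnp.mean(out ** 2)

grads = jax.grad(loss_fn)(params)

assert jnp.any(grads["trainable"] != 0), "Trainable params got zero gradient"
assert jnp.all(grads["frozen"] == 0), "Frozen params leaked gradient"
```

If the frozen assertion fails, a stop_gradient is missing somewhere on the path from the frozen parameters to the loss.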

Gate 3: Overfit Validation

Purpose: If the model can't overfit 5 samples in 100 steps, nothing else matters. This is the most informative single test for research code.

The Most Important Gate

If a model can't memorize 5 samples, there's a fundamental bug. Don't waste time on downstream validation until this passes.

| Check | Pass criteria |
| --- | --- |
| 5 samples, 100 steps | Loss must decrease significantly (>90% reduction) |
| Loss trajectory | Consistent decrease, not just fluctuation |

Written to: scratch/{investigation}/tests/test_overfit.py

# get_small_batch, compute_loss, and train_n_steps are
# investigation-specific helpers.

def test_overfit_small_batch():
    """Model must be able to overfit 5 samples."""
    batch = get_small_batch(n=5)

    initial_loss = compute_loss(params, batch)
    final_params = train_n_steps(params, batch, n_steps=100)
    final_loss = compute_loss(final_params, batch)

    assert final_loss < initial_loss * 0.1, (
        f"Failed to overfit: "
        f"initial={initial_loss:.4f}, "
        f"final={final_loss:.4f}"
    )
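The train_n_steps helper is investigation-specific; a minimal version is just plain gradient descent on one fixed batch. The sketch below, with a toy linear model standing in for model.apply, shows the full overfit check end to end (all names and dimensions here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

def compute_loss(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]  # stand-in for model.apply
    return jnp.mean((pred - y) ** 2)

def train_n_steps(params, batch, n_steps=100, lr=0.1):
    """Plain gradient descent on one fixed small batch."""
    grad_fn = jax.jit(jax.grad(compute_loss))
    for _ in range(n_steps):
        grads = grad_fn(params, batch)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g,
                                        params, grads)
    return params

# Toy 5-sample batch with a realizable target (y = 2x + 1),
# so a correct model must be able to drive the loss near zero.
x = jnp.arange(5.0).reshape(5, 1)
y = 2.0 * x + 1.0
params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}

initial_loss = compute_loss(params, (x, y))
final_loss = compute_loss(train_n_steps(params, (x, y)), (x, y))
assert final_loss < 0.1 * initial_loss  # >90% reduction
```

Swapping the toy loss for the real model's loss function turns this directly into test_overfit.py.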

Gate 4: Regression Validation

Purpose: Verify that existing configs produce identical output before and after the change.

| Check | Pass criteria |
| --- | --- |
| Existing configs | Run inference with a fixed seed, compare output to baseline |
| Output comparison | Any difference in output for existing configs is a regression |

Implementation: Delegated to the regression-guard subagent with the list of changed files and existing configs.

Don't Skip Regression Testing

"I only changed one file" — but that file is imported by 15 others. Always run regression-guard.
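The baseline comparison at the core of this gate can be sketched as follows; run_inference, the config argument, and the baseline path are illustrative assumptions, and the real work is delegated to regression-guard:

```python
import numpy as np

def check_regression(run_inference, config, baseline_path):
    """Compare current output against a saved baseline for one config.

    run_inference must be deterministic (fixed seed inside), and the
    baseline must have been saved before the change with:
        np.save(baseline_path, run_inference(config))
    """
    current = np.asarray(run_inference(config))
    baseline = np.load(baseline_path)
    # Exact match expected: any drift on an existing config is a regression.
    if not np.array_equal(current, baseline):
        diff = np.abs(current - baseline).max()
        raise AssertionError(
            f"Regression for {config}: max abs diff {diff}")
```

Exact equality (rather than a tolerance) is deliberate here: for unchanged configs with a fixed seed, any difference at all indicates the change touched a code path it should not have.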

G3 Checkpoint: Presentation Format

After running validation gates, results are presented to the user.

If all gates pass:

"All 4 validation gates passed for [component].
Shape ✓, Gradient ✓, Overfit ✓, Regression ✓.
Moving to next component?"

If any gate fails:

## Validation Results: [Component]

### Passed
- ✓ Shape validation: output (B, T, D) as expected
- ✓ Gradient flow: all parameters receive non-zero gradients

### Failed
- ✗ Overfit test: loss decreased from 2.3 to 1.8
  (only 22% reduction in 100 steps)
  - Expected: >90% reduction
  - Possible causes: learning rate too low,
    loss term dominating, gradient blocked

### My Assessment
[Interpretation of results and suggested next step]

### Question
[Ask user how to proceed —
 fix the failure or investigate further?]

When to Run Each Gate

| Gate | When |
| --- | --- |
| Shape | After any model architecture change or new component |
| Gradient | After any change to the loss function, model, or training loop |
| Overfit | After implementation of any trainable component |
| Regression | After any code change, no exceptions |