Advanced Features Tutorial
This notebook demonstrates advanced ODE solving capabilities:
- First-Order ODEs with Hidden States (Joint state-parameter estimation)
- Second-Order ODEs with Hidden States
- Conservation Constraints (Algebraic constraints)
- Linear Measurements (Time-varying observations)
- JIT Compilation and Differentiation
- Scan-Based Sequential Filtering (efficient jax.lax.scan with decoupled observations)
import jax
import jax.numpy as np
import jax.random as jrandom
import matplotlib.pyplot as plt
from ode_filters.filters import ekf1_sqr_loop, rts_sqr_smoother_loop
from ode_filters.measurement import (
BlackBoxMeasurement,
Conservation,
Measurement,
ODEInformation,
ODEInformationWithHidden,
SecondOrderODEInformation,
SecondOrderODEInformationWithHidden,
TransformedMeasurement,
)
from ode_filters.priors import IWP, JointPrior, taylor_mode_initialization
jit_filter = jax.jit(ekf1_sqr_loop, static_argnums=(2, 3, 4, 5))
jit_smoother = jax.jit(rts_sqr_smoother_loop, static_argnums=(5,))
1. First-Order ODEs with Hidden States
Problem: Exponential Decay with Unknown Rate
Consider a radioactive decay problem where we observe the amount of material but don't know the decay rate:
$$\frac{dx}{dt} = -\lambda x, \quad x(0) = 1.0$$
where $\lambda$ is an unknown parameter we want to infer from (noisy) observations.
We model $\lambda$ as a hidden state with its own prior (e.g., constant or slowly varying).
# True decay rate (unknown to the solver)
lambda_true = 0.5
# Vector field with hidden parameter
def vf_decay(x, lam, *, t):
"""dx/dt = -lambda * x, where lambda is the hidden state"""
return -lam * x
# Initial conditions
x0 = np.array([1.0]) # Initial amount
lambda0 = np.array([0.3]) # Initial guess for decay rate
tspan = (0, 10)
N = 50
Set Up Joint Prior for State and Hidden Parameter
We use JointPrior to combine:
- a prior for the state x (IWP with q=2)
- a prior for the parameter λ (IWP with q=1, since it is slowly varying)
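To make the role of q concrete, here is a minimal sketch of the textbook discrete-time transition matrix A(h) and process-noise covariance Q(h) of a q-times integrated Wiener process in one dimension, with state [x, x', ..., x^(q)] and a scalar diffusion `sigma2`. The helper `iwp_transition` is hypothetical; it is the standard formula, not necessarily how `IWP` builds or scales its matrices (we make no assumption about how its `Xi` argument enters, or whether it preconditions).

```python
import math
import jax.numpy as jnp


def iwp_transition(q: int, h: float, sigma2: float = 1.0):
    """Hypothetical helper: transition A(h) and process noise Q(h) of a
    q-times integrated Wiener process, state [x, x', ..., x^(q)], d = 1."""
    A = jnp.zeros((q + 1, q + 1))
    Q = jnp.zeros((q + 1, q + 1))
    for i in range(q + 1):
        for j in range(q + 1):
            if j >= i:
                # Taylor propagation: x^(i)(t+h) = sum_j h^(j-i)/(j-i)! x^(j)(t)
                A = A.at[i, j].set(h ** (j - i) / math.factorial(j - i))
            # Covariance of the integrated white-noise increment over [t, t+h]
            p = 2 * q + 1 - i - j
            Q = Q.at[i, j].set(
                sigma2 * h**p / (p * math.factorial(q - i) * math.factorial(q - j))
            )
    return A, Q


# q=1 (as for lambda): a "constant velocity" model; a small sigma2 keeps it slowly varying
A_lam, Q_lam = iwp_transition(q=1, h=0.2, sigma2=0.01)
# q=2 (as for x): position, velocity and acceleration coefficients
A_x, Q_x = iwp_transition(q=2, h=0.2, sigma2=0.5)
```

A smaller diffusion shrinks Q(h) and therefore how far the hidden parameter can drift between steps, which is why λ gets a smaller `Xi` than x.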
# Prior for the state x (d=1, q=2)
prior_x = IWP(q=2, d=1, Xi=0.5 * np.eye(1))
# Prior for the hidden parameter lambda (d=1, q=1, smaller diffusion)
prior_lambda = IWP(q=1, d=1, Xi=0.01 * np.eye(1))
# Combine into joint prior
joint_prior = JointPrior(prior_x, prior_lambda)
print(f"State extraction matrix E0_x shape: {joint_prior.E0_x.shape}")
print(f"Hidden extraction matrix E0_hidden shape: {joint_prior.E0_hidden.shape}")
print(f"Derivative extraction matrix E1 shape: {joint_prior.E1.shape}")
State extraction matrix E0_x shape: (1, 5)
Hidden extraction matrix E0_hidden shape: (1, 5)
Derivative extraction matrix E1 shape: (1, 5)
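The (1, 5) shapes are consistent with a joint state that stacks the three coefficients of x on top of the two of λ, which matches how this notebook indexes `mu_0[:3]` and `m_smooth[:, 3]`. The sketch below spells out that layout, assuming (the source does not confirm it) that `JointPrior` combines its components block-diagonally and uses plain selection rows; the transition blocks are illustrative values for a step h = 1.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

q_x, q_lam = 2, 1                 # IWP orders used above for x and lambda
D = (q_x + 1) + (q_lam + 1)       # joint state [x, x', x'', lam, lam'] -> 5

# Selection ("extraction") rows over the joint state (assumed layout)
E0_x = jnp.zeros((1, D)).at[0, 0].set(1.0)              # picks x
E1 = jnp.zeros((1, D)).at[0, 1].set(1.0)                # picks dx/dt
E0_hidden = jnp.zeros((1, D)).at[0, q_x + 1].set(1.0)   # picks lambda

# Independent priors combine block-diagonally: no cross-coupling in the dynamics
A_x = jnp.array([[1.0, 1.0, 0.5], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
A_lam = jnp.array([[1.0, 1.0], [0.0, 1.0]])
A_joint = block_diag(A_x, A_lam)

mu = jnp.array([1.0, -0.3, 0.09, 0.3, 0.0])  # joint initial mean, as printed below
x_val, dx_val, lam_val = E0_x @ mu, E1 @ mu, E0_hidden @ mu
```

Coupling between x and λ therefore enters only through the ODE information (the vector field), not through the prior.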
Initialize Joint State
For first-order ODEs with hidden states, we need to initialize both the state and the hidden parameter.
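For this scalar ODE the Taylor coefficients can be checked by hand: x' = f(x) = -λx and x'' = f'(x) f(x) = λ²x. The sketch below computes the q = 2 coefficients with a single `jax.jvp`, using x'' = J_f(x) f(x) for an autonomous vector field. The helper `second_order_taylor` is hypothetical and is not how `taylor_mode_initialization` works internally; it only reproduces the same numbers for this case.

```python
import jax
import jax.numpy as jnp


def second_order_taylor(f, x0):
    """Hypothetical helper: [x0, x'(0), x''(0)] for an autonomous ODE x' = f(x)."""
    dx = f(x0)                            # first derivative from the vector field
    _, ddx = jax.jvp(f, (x0,), (dx,))     # second derivative: J_f(x0) @ f(x0)
    return jnp.concatenate([x0, dx, ddx])


lam_guess = 0.3
coeffs = second_order_taylor(lambda x: -lam_guess * x, jnp.array([1.0]))
# coeffs is [1.0, -0.3, 0.09] up to float32 rounding, like mu_x in this section
```

Higher orders need repeated nested derivatives, which is where dedicated Taylor-mode AD becomes cheaper than this naive approach.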
# We need to wrap the vector field for initialization
def vf_for_init(x, *, t):
"""For initialization, we use our best guess for lambda"""
return -lambda0[0] * x
# Initialize state coefficients
mu_x, _ = taylor_mode_initialization(vf_for_init, x0, q=2)
# Initialize hidden parameter (constant, so higher derivatives are zero)
mu_lambda = np.concatenate([lambda0, np.zeros(1)]) # [lambda, d_lambda/dt]
# Combine into joint initialization
mu_0 = np.concatenate([mu_x, mu_lambda])
D_total = mu_0.shape[0]
Sigma_0_sqr = np.zeros((D_total, D_total))
print(f"Joint initial state dimension: {D_total}")
print(f"Initial state values: x={mu_0[:3]}, lambda={mu_0[3:]}")
Joint initial state dimension: 5
Initial state values: x=[ 1. -0.3 0.09], lambda=[0.3 0. ]
Generate Synthetic Data and Set Up Measurement
# Generate true solution
ts = np.linspace(tspan[0], tspan[1], N + 1)
x_true = np.exp(-lambda_true * ts)
# Add noise to observations (observe x, not lambda)
key = jrandom.PRNGKey(42)
noise_std = 0.05
z = x_true[1:] + noise_std * jrandom.normal(key, shape=(N,))
z = z.reshape(-1, 1)
z_t = ts[1:]
plt.figure(figsize=(10, 3))
plt.scatter(z_t, z, s=10, alpha=0.5, label="Noisy observations", color="orange")
plt.plot(ts, x_true, "k--", label=f"True (λ={lambda_true})")
plt.xlabel("t"), plt.ylabel("x(t)")
plt.legend(), plt.title("Exponential Decay with Unknown Rate")
plt.show()
# Measurement matrix: observe only x (not lambda)
A = np.array([[1.0]]) # Extract first state component
measurement = Measurement(A, z, z_t, noise=noise_std**2)
# Create ODE measurement model with hidden states
measure = ODEInformationWithHidden(
vf=vf_decay,
E0=joint_prior.E0_x, # Extract x from joint state
E1=joint_prior.E1, # Extract dx/dt
E0_hidden=joint_prior.E0_hidden, # Extract lambda
constraints=[measurement],
)
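Conceptually, ODE information treats the vector field as a zero-valued "observation" of the residual r(m) = E1 m - f(E0_x m, E0_hidden m), and a first-order EKF (the `ekf1` in `ekf1_sqr_loop`) linearizes that residual with its Jacobian. The sketch below evaluates both for the decay problem, assuming the [x, x', x'', λ, λ'] layout suggested by the printed shapes; `ode_residual` is a hypothetical helper, not part of `ode_filters`.

```python
import jax
import jax.numpy as jnp

D = 5                          # joint state [x, x', x'', lam, lam'] (assumed layout)
E0_x = jnp.eye(D)[0:1]         # selects x
E1 = jnp.eye(D)[1:2]           # selects dx/dt
E0_hidden = jnp.eye(D)[3:4]    # selects lambda


def vf_decay(x, lam, *, t):
    return -lam * x


def ode_residual(m, t=0.0):
    """Hypothetical helper: r(m) = E1 m - f(E0_x m, E0_hidden m); zero when m satisfies the ODE."""
    return E1 @ m - vf_decay(E0_x @ m, E0_hidden @ m, t=t)


m = jnp.array([1.0, -0.3, 0.09, 0.3, 0.0])
r = ode_residual(m)               # ~0: the initial coefficients are ODE-consistent for lam = 0.3
H = jax.jacfwd(ode_residual)(m)   # (1, 5) linearization an EKF1 update would use
# H = [lam, 1, 0, x, 0]: the nonzero entry for lambda is what lets data inform the parameter
```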
Run Filter and Smoother
# Run EKF
m_seq, P_sqr, _, _, G_back, d_back, P_back_sqr, *_ = ekf1_sqr_loop(
mu_0, Sigma_0_sqr, joint_prior, measure, tspan, N
)
# Run RTS smoother
m_smooth, P_smooth_sqr = rts_sqr_smoother_loop(
m_seq[-1], P_sqr[-1], np.array(G_back), np.array(d_back), np.array(P_back_sqr), N
)
m_smooth = np.array(m_smooth)
P_smooth_sqr = np.array(P_smooth_sqr)
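The filter and smoother return square-root factors rather than covariances, and the visualization cell recovers covariances with `np.einsum("ijk,ijl->ikl", S, S)`. That contraction computes P_k = S_k^T S_k for every time step k, as the quick check below confirms on random factors. It verifies only what the einsum computes; that the library's factors follow the P = S^T S convention is an assumption implied by the notebook's own usage.

```python
import jax.numpy as jnp
import jax.random as jrandom

S = jrandom.normal(jrandom.PRNGKey(0), (4, 3, 3))   # a batch of 4 square-root factors

P_einsum = jnp.einsum("ijk,ijl->ikl", S, S)         # sum_j S[i, j, k] * S[i, j, l]
P_explicit = jnp.matmul(jnp.swapaxes(S, 1, 2), S)   # S_i^T @ S_i for each i

# Marginal standard deviations per step, as used for the 2-sigma bands
std = jnp.sqrt(jnp.diagonal(P_einsum, axis1=1, axis2=2))
```

Working with S keeps covariances symmetric and positive semidefinite by construction, which is the usual reason for square-root filtering.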
Visualize Results: State and Parameter Estimation
# Extract state x and parameter lambda
x_est = m_smooth[:, 0] # First component is x
lambda_est = m_smooth[:, 3] # Fourth component is lambda (after x, dx, d2x)
# Recover covariances from the square-root factors: P_k = S_k^T S_k for each step k
P_smooth = np.einsum("ijk,ijl->ikl", P_smooth_sqr, P_smooth_sqr)
x_std = np.sqrt(P_smooth[:, 0, 0])
lambda_std = np.sqrt(P_smooth[:, 3, 3])
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
# Plot state x
ax1.scatter(z_t, z, s=10, alpha=0.5, color="orange", label="Observations")
ax1.plot(ts, x_true, "k--", linewidth=2, label="True")
ax1.plot(ts, x_est, "b-", linewidth=2, label="Estimated")
ax1.fill_between(ts, x_est - 2 * x_std, x_est + 2 * x_std, alpha=0.3)
ax1.set_xlabel("t"), ax1.set_ylabel("x(t)")
ax1.set_title("State Estimation")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot parameter lambda
ax2.axhline(
lambda_true, color="k", linestyle="--", linewidth=2, label=f"True λ={lambda_true}"
)
ax2.plot(ts, lambda_est, "r-", linewidth=2, label="Estimated λ")
ax2.fill_between(
ts, lambda_est - 2 * lambda_std, lambda_est + 2 * lambda_std, alpha=0.3, color="red"
)
ax2.set_xlabel("t"), ax2.set_ylabel("λ(t)")
ax2.set_title(f"Parameter Estimation (init: λ₀={lambda0[0]})")
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nFinal estimate: λ = {lambda_est[-1]:.3f} ± {2 * lambda_std[-1]:.3f}")
print(f"True value: λ = {lambda_true}")
print(f"Initial guess: λ₀ = {lambda0[0]}")
Final estimate: λ = 0.288 ± 1.340
True value: λ = 0.5
Initial guess: λ₀ = 0.30000001192092896
2. Second-Order ODEs with Hidden States
Problem: Damped Harmonic Oscillator with Unknown Damping
Consider a spring-mass-damper system where the damping coefficient is unknown:
$$\frac{d^2x}{dt^2} = -\omega^2 x - \gamma \frac{dx}{dt}$$
where $\gamma$ is the unknown damping coefficient we want to infer from position observations.
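For the underdamped case (γ < 2ω) with x(0) = 1 and x'(0) = 0, the exact solution is $x(t) = e^{-\gamma t/2}\left(\cos\omega_d t + \frac{\gamma}{2\omega_d}\sin\omega_d t\right)$ with $\omega_d = \sqrt{\omega^2 - \gamma^2/4}$; the sine term is what makes x'(0) vanish. A cheap way to validate such a reference solution is to differentiate it with JAX and evaluate the ODE residual on a grid, as sketched here with the parameter values used in this section.

```python
import jax
import jax.numpy as jnp

omega, gamma = 1.0, 0.3
omega_d = jnp.sqrt(omega**2 - (gamma / 2) ** 2)   # damped frequency


def x_exact(t):
    """Underdamped solution with x(0) = 1 and x'(0) = 0."""
    return jnp.exp(-gamma * t / 2) * (
        jnp.cos(omega_d * t) + gamma / (2 * omega_d) * jnp.sin(omega_d * t)
    )


dx = jax.grad(x_exact)    # x'(t)
ddx = jax.grad(dx)        # x''(t)

ts = jnp.linspace(0.0, 30.0, 101)
# Residual of x'' + gamma x' + omega^2 x = 0; should vanish up to float32 rounding
residual = jax.vmap(lambda t: ddx(t) + gamma * dx(t) + omega**2 * x_exact(t))(ts)
```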
# System parameters
omega = 1.0 # Natural frequency (known)
gamma_true = 0.3 # True damping coefficient (unknown to solver)
# Vector field: d²x/dt² = f(x, dx/dt, gamma, t)
def vf_damped(x, dx, gamma, *, t):
"""Second-order ODE with hidden damping parameter"""
return -(omega**2) * x - gamma * dx
# Initial conditions
x0_2nd = np.array([1.0]) # Initial position
dx0_2nd = np.array([0.0]) # Initial velocity
gamma0 = np.array([0.1]) # Initial guess for damping
tspan_2nd = (0, 30)
N_2nd = 100
# Prior for state x (q=2 for second-order)
prior_x_2nd = IWP(q=2, d=1, Xi=1.0 * np.eye(1))
# Prior for hidden parameter gamma (q=1, slowly varying)
prior_gamma = IWP(q=1, d=1, Xi=0.01 * np.eye(1))
# Joint prior
joint_prior_2nd = JointPrior(prior_x_2nd, prior_gamma)
# Vector field for initialization
def vf_for_init_2nd(x, dx, *, t):
return -(omega**2) * x - gamma0[0] * dx
# Initialize state (x, dx, d²x)
mu_x_2nd, _ = taylor_mode_initialization(
vf_for_init_2nd, (x0_2nd, dx0_2nd), q=2, order=2
)
# Initialize hidden parameter
mu_gamma = np.concatenate([gamma0, np.zeros(1)])
# Joint initialization
mu_0_2nd = np.concatenate([mu_x_2nd, mu_gamma])
Sigma_0_sqr_2nd = np.zeros((mu_0_2nd.shape[0], mu_0_2nd.shape[0]))
print(f"Initial state: x={mu_0_2nd[0]:.2f}, dx={mu_0_2nd[1]:.2f}, γ={mu_0_2nd[3]:.2f}")
Initial state: x=1.00, dx=0.00, γ=0.10
# True solution (underdamped oscillator with x(0) = 1, x'(0) = 0)
ts_2nd = np.linspace(tspan_2nd[0], tspan_2nd[1], N_2nd + 1)
omega_d = np.sqrt(omega**2 - (gamma_true / 2) ** 2)  # Damped frequency
x_true_2nd = np.exp(-gamma_true * ts_2nd / 2) * (
    np.cos(omega_d * ts_2nd)
    + gamma_true / (2 * omega_d) * np.sin(omega_d * ts_2nd)  # makes x'(0) = 0
)
# Noisy observations of position
key = jrandom.PRNGKey(43)
noise_std_2nd = 0.05
z_2nd = x_true_2nd[1:] + noise_std_2nd * jrandom.normal(key, shape=(N_2nd,))
z_2nd = z_2nd.reshape(-1, 1)
z_t_2nd = ts_2nd[1:]
# Measurement: observe position only
A_2nd = np.array([[1.0]])
measurement_2nd = Measurement(A_2nd, z_2nd, z_t_2nd, noise=noise_std_2nd**2)
# Second-order ODE with hidden states
measure_2nd = SecondOrderODEInformationWithHidden(
vf=vf_damped,
E0=joint_prior_2nd.E0_x,
E1=joint_prior_2nd.E1,
E2=joint_prior_2nd.E2,
E0_hidden=joint_prior_2nd.E0_hidden,
constraints=[measurement_2nd],
)
# Run EKF and Smoother
m_seq_2nd, P_sqr_2nd, _, _, G_back_2nd, d_back_2nd, P_back_sqr_2nd, *_ = ekf1_sqr_loop(
mu_0_2nd, Sigma_0_sqr_2nd, joint_prior_2nd, measure_2nd, tspan_2nd, N_2nd
)
m_smooth_2nd, P_smooth_sqr_2nd = rts_sqr_smoother_loop(
m_seq_2nd[-1],
P_sqr_2nd[-1],
np.array(G_back_2nd),
np.array(d_back_2nd),
np.array(P_back_sqr_2nd),
N_2nd,
)
m_smooth_2nd = np.array(m_smooth_2nd)
P_smooth_sqr_2nd = np.array(P_smooth_sqr_2nd)
# Extract estimates
x_est_2nd = m_smooth_2nd[:, 0]
gamma_est = m_smooth_2nd[:, 3]
P_smooth_2nd = np.einsum("ijk,ijl->ikl", P_smooth_sqr_2nd, P_smooth_sqr_2nd)
x_std_2nd = np.sqrt(P_smooth_2nd[:, 0, 0])
gamma_std = np.sqrt(P_smooth_2nd[:, 3, 3])
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
# Plot position
ax1.scatter(z_t_2nd, z_2nd, s=10, alpha=0.5, color="orange", label="Observations")
ax1.plot(ts_2nd, x_true_2nd, "k--", linewidth=2, label="True")
ax1.plot(ts_2nd, x_est_2nd, "b-", linewidth=2, label="Estimated")
ax1.fill_between(
ts_2nd, x_est_2nd - 2 * x_std_2nd, x_est_2nd + 2 * x_std_2nd, alpha=0.3
)
ax1.set_xlabel("t"), ax1.set_ylabel("x(t)")
ax1.set_title("Position Estimation")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot damping parameter
ax2.axhline(
gamma_true, color="k", linestyle="--", linewidth=2, label=f"True γ={gamma_true}"
)
ax2.plot(ts_2nd, gamma_est, "r-", linewidth=2, label="Estimated γ")
ax2.fill_between(
ts_2nd, gamma_est - 2 * gamma_std, gamma_est + 2 * gamma_std, alpha=0.3, color="red"
)
ax2.set_xlabel("t"), ax2.set_ylabel("γ(t)")
ax2.set_title(f"Damping Coefficient (init: γ₀={gamma0[0]})")
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nFinal estimate: γ = {gamma_est[-1]:.3f} ± {2 * gamma_std[-1]:.3f}")
print(f"True value: γ = {gamma_true}")
print(f"Initial guess: γ₀ = {gamma0[0]}")
Final estimate: γ = 1.506 ± 6.249
True value: γ = 0.3
Initial guess: γ₀ = 0.10000000149011612
3. Conservation Constraints
The SIR epidemiological model has a natural conservation law: $S + I + R = N_{\text{pop}}$ (constant total population). Here the compartments are population fractions, so $N_{\text{pop}} = 1$.
We can enforce this as a hard constraint using the Conservation class.
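One way to see how a linear equality A x = p can be enforced "hard" is to treat it as a measurement with zero noise: a Kalman update with R = 0 moves the mean exactly onto the constraint set and removes all variance along A. The sketch below does this for a Gaussian over (S, I, R); `constrain` is a hypothetical helper illustrating the general idea, not a description of how `Conservation` is applied inside `ode_filters`.

```python
import jax.numpy as jnp


def constrain(m, P, A, p):
    """Hypothetical helper: condition N(m, P) on A x = p exactly (zero-noise Kalman update)."""
    S = A @ P @ A.T                    # innovation covariance, no noise term
    K = P @ A.T @ jnp.linalg.inv(S)    # Kalman gain
    m_new = m + K @ (p - A @ m)
    P_new = P - K @ S @ K.T
    return m_new, P_new


m = jnp.array([0.70, 0.20, 0.15])      # S + I + R = 1.05 violates conservation
P = jnp.diag(jnp.array([0.01, 0.02, 0.01]))
A = jnp.array([[1.0, 1.0, 1.0]])
p = jnp.array([1.0])

m_c, P_c = constrain(m, P, A, p)
# A @ m_c == p, and A @ P_c @ A.T == 0: no uncertainty left in the total
```

The correction is distributed in proportion to each component's variance, so the most uncertain compartment (I here) absorbs the largest share.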
# SIR model parameters
beta_sir = 0.5
gamma_sir = 0.1
def vf_sir(x, *, t):
"""SIR model: x = [S, I, R]"""
return np.array(
[
-beta_sir * x[0] * x[1],
beta_sir * x[0] * x[1] - gamma_sir * x[1],
gamma_sir * x[1],
]
)
x0_sir = np.array([0.99, 0.01, 0.0])
tspan_sir = (0, 100)
N_sir = 100
# Prior
prior_sir = IWP(q=2, d=3, Xi=1.0 * np.eye(3))
mu_0_sir, Sigma_0_sqr_sir = taylor_mode_initialization(vf_sir, x0_sir, q=2)
# Conservation constraint: S + I + R = 1
A_conservation = np.array([[1.0, 1.0, 1.0]])
p_conservation = np.array([1.0])
conservation = Conservation(A_conservation, p_conservation)
# Create measurement model with conservation
measure_sir = ODEInformation(
vf=vf_sir, E0=prior_sir.E0, E1=prior_sir.E1, constraints=[conservation]
)
# Run filter and smoother
m_seq_sir, P_sqr_sir, _, _, G_back_sir, d_back_sir, P_back_sqr_sir, *_ = jit_filter(
mu_0_sir, Sigma_0_sqr_sir, prior_sir, measure_sir, tspan_sir, N_sir
)
m_smooth_sir, _ = jit_smoother(
m_seq_sir[-1],
P_sqr_sir[-1],
np.array(G_back_sir),
np.array(d_back_sir),
np.array(P_back_sqr_sir),
N_sir,
)
m_smooth_sir = np.array(m_smooth_sir)
ts_sir = np.linspace(tspan_sir[0], tspan_sir[1], N_sir + 1)
S_est = m_smooth_sir[:, 0]
I_est = m_smooth_sir[:, 1]
R_est = m_smooth_sir[:, 2]
total = S_est + I_est + R_est
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
# Plot SIR trajectories
ax1.plot(ts_sir, S_est, label="S (Susceptible)", linewidth=2)
ax1.plot(ts_sir, I_est, label="I (Infected)", linewidth=2)
ax1.plot(ts_sir, R_est, label="R (Recovered)", linewidth=2)
ax1.set_xlabel("Time (days)"), ax1.set_ylabel("Population fraction")
ax1.set_title("SIR Model with Conservation Constraint")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Check conservation
ax2.plot(ts_sir, total, "b-", linewidth=2, label="S + I + R")
ax2.axhline(1.0, color="k", linestyle="--", linewidth=2, label="Expected (1.0)")
ax2.set_xlabel("Time (days)"), ax2.set_ylabel("Total population")
ax2.set_title("Conservation Law Verification")
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0.999, 1.001])
plt.tight_layout()
plt.show()
print(f"\nConservation error: max = {np.max(np.abs(total - 1.0)):.2e}")
print(f"Conservation error: mean = {np.mean(np.abs(total - 1.0)):.2e}")
Conservation error: max = 2.38e-07
Conservation error: mean = 4.66e-08
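4. Linear Measurements
Measurement adds time-varying linear observations $A x(t_i) = z_i$ that are active only at the specified times. Below we observe only the predator of the Lotka-Volterra model, at every 5th grid point with noise of standard deviation 0.1, and compare against a reference solve without measurements.
# Lotka-Volterra parameters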
alpha, beta, delta, gamma = 2 / 3, 4 / 3, 1.0, 1.0
def vf_lv(x, *, t):
"""Lotka-Volterra: x = [prey, predator]"""
return np.array(
[alpha * x[0] - beta * x[0] * x[1], delta * x[0] * x[1] - gamma * x[1]]
)
x0_lv = np.array([1.0, 1.0])
tspan_lv = (0, 30)
N_lv = 60
# Get reference solution (without measurements)
prior_lv_ref = IWP(q=2, d=2, Xi=0.5 * np.eye(2))
mu_0_lv_ref, Sigma_0_sqr_lv_ref = taylor_mode_initialization(vf_lv, x0_lv, q=2)
measure_lv_ref = ODEInformation(vf_lv, prior_lv_ref.E0, prior_lv_ref.E1)
m_ref, *_ = ekf1_sqr_loop(
mu_0_lv_ref, Sigma_0_sqr_lv_ref, prior_lv_ref, measure_lv_ref, tspan_lv, N_lv
)
m_ref = np.array(m_ref)
ts_lv = np.linspace(tspan_lv[0], tspan_lv[1], N_lv + 1)
# Create sparse noisy observations of predator (every 5th timestep)
obs_indices = np.arange(5, N_lv + 1, 5)
z_lv = m_ref[obs_indices, 1] + 0.1 * jrandom.normal(
jrandom.PRNGKey(44), shape=(len(obs_indices),)
)
z_lv = z_lv.reshape(-1, 1)
z_t_lv = ts_lv[obs_indices]
# Solve with measurements
prior_lv = IWP(q=2, d=2, Xi=1.0 * np.eye(2))
mu_0_lv, Sigma_0_sqr_lv = taylor_mode_initialization(vf_lv, x0_lv, q=2)
# Measurement matrix: observe only predator (second component)
A_lv = np.array([[0.0, 1.0]])
measurement_lv = Measurement(A_lv, z_lv, z_t_lv, noise=0.01)  # noise variance (std 0.1, as in the synthetic data)
measure_lv = ODEInformation(
vf=vf_lv, E0=prior_lv.E0, E1=prior_lv.E1, constraints=[measurement_lv]
)
m_seq_lv, P_sqr_lv, _, _, G_back_lv, d_back_lv, P_back_sqr_lv, *_ = ekf1_sqr_loop(
mu_0_lv, Sigma_0_sqr_lv, prior_lv, measure_lv, tspan_lv, N_lv
)
m_smooth_lv, _ = rts_sqr_smoother_loop(
m_seq_lv[-1],
P_sqr_lv[-1],
np.array(G_back_lv),
np.array(d_back_lv),
np.array(P_back_sqr_lv),
N_lv,
)
m_smooth_lv = np.array(m_smooth_lv)
# Plot
plt.figure(figsize=(10, 4))
plt.plot(ts_lv, m_ref[:, 0], "b--", label="Prey (reference)", linewidth=2, alpha=0.5)
plt.plot(ts_lv, m_smooth_lv[:, 0], "b-", label="Prey (with measurements)", linewidth=2)
plt.plot(
ts_lv, m_ref[:, 1], "r--", label="Predator (reference)", linewidth=2, alpha=0.5
)
plt.plot(
ts_lv, m_smooth_lv[:, 1], "r-", label="Predator (with measurements)", linewidth=2
)
plt.scatter(z_t_lv, z_lv, s=50, color="orange", zorder=3, label="Observations")
plt.xlabel("t"), plt.ylabel("Population")
plt.title("Lotka-Volterra with Sparse Predator Observations")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
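Only the predator is observed, yet the prey estimate changes as well. In a linear-Gaussian update this happens through the cross-covariance: with H = [0, 1], the gain K = P Hᵀ(H P Hᵀ + R)⁻¹ has a nonzero prey entry whenever prey and predator are correlated (in the ODE filter such correlations arise from the coupled dynamics). The numbers below are made up purely to show that mechanism on a two-dimensional state, without the derivative coefficients the notebook's filter carries.

```python
import jax.numpy as jnp


def kalman_update(m, P, H, R, z):
    """Standard Kalman measurement update for z = H x + v, v ~ N(0, R)."""
    S = H @ P @ H.T + R
    K = P @ H.T @ jnp.linalg.inv(S)
    return m + K @ (z - H @ m), P - K @ S @ K.T


m = jnp.array([1.0, 1.0])                     # illustrative [prey, predator] mean
P = jnp.array([[0.20, 0.08], [0.08, 0.10]])   # correlated uncertainty (made-up values)
H = jnp.array([[0.0, 1.0]])                   # observe the predator only, like A_lv
R = jnp.array([[0.01]])                       # noise variance, like noise=0.01
z = jnp.array([1.3])

m_post, P_post = kalman_update(m, P, H, R, z)
# Prey mean shifts by (0.08 / 0.11) * 0.3, purely through the cross-covariance
```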
5. JIT Compilation and Differentiation
The filter and smoother loops compose naturally with jax.jit and jax.grad,
enabling:
- JIT compilation for faster execution
- Automatic differentiation through the entire filter (e.g., for hyperparameter optimization via the log marginal likelihood)
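Mark prior, measure, tspan, and N as static arguments (positional indices 2 to 5, as in jit_filter defined at the top of the notebook) so that JAX treats them as compile-time constants and traces only the array inputs.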
def vf_logistic(x, *, t):
return x * (1 - x)
x0_jit = np.array([0.01])
tspan_jit = (0, 5)
N_jit = 50
prior_jit = IWP(q=2, d=1, Xi=0.5 * np.eye(1))
mu_0_jit, Sigma_0_sqr_jit = taylor_mode_initialization(vf_logistic, x0_jit, q=2)
# Add synthetic measurements at 30 consecutive grid points (an illustrative linear ramp, not a solution of the ODE)
ts_jit = np.linspace(tspan_jit[0], tspan_jit[1], N_jit + 1)
z_t_jit = ts_jit[10:40]
z_jit = np.array([[0.05 + 0.02 * i] for i in range(30)])
A_jit = np.array([[1.0]])
measurement_jit = Measurement(A_jit, z_jit, z_t_jit, noise=0.01)
measure_jit = ODEInformation(
vf_logistic, prior_jit.E0, prior_jit.E1, constraints=[measurement_jit]
)
# JIT-compiled filter
result = jit_filter(
mu_0_jit, Sigma_0_sqr_jit, prior_jit, measure_jit, tspan_jit, N_jit
)
m_seq_jit = np.array(result[0])
ll = float(result[-1])
print(f"Filtered state shape: {m_seq_jit.shape}")
print(f"Log-likelihood: {ll:.4f}")
Filtered state shape: (51, 3)
Log-likelihood: 123.8705
Gradient of the Log Marginal Likelihood
The filter is fully differentiable via jax.grad. This enables gradient-based
optimization of hyperparameters (e.g., the initial covariance).
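The same pattern works for any hyperparameter that enters the log marginal likelihood. The self-contained sketch below uses a scalar random-walk Kalman filter (not the notebook's ODE filter) with hypothetical data and differentiates its log marginal likelihood with respect to the log of the process-noise variance; the log-parameterization keeps the variance positive during optimization.

```python
import jax
import jax.numpy as jnp

zs = jnp.array([0.0, 0.5, 1.1, 1.4, 2.1, 2.4])   # hypothetical scalar observations
r, m0, P0 = 0.04, 0.0, 1.0                        # measurement noise, prior mean and variance


def log_marginal_likelihood(log_q):
    """Scalar Kalman filter for x_k = x_{k-1} + w_k, z_k = x_k + v_k, with q = exp(log_q)."""
    q = jnp.exp(log_q)
    m, P, ll = m0, P0, 0.0
    for z in zs:
        P = P + q                                  # predict
        S = P + r                                  # innovation variance
        ll = ll - 0.5 * (jnp.log(2 * jnp.pi * S) + (z - m) ** 2 / S)
        K = P / S                                  # Kalman gain
        m, P = m + K * (z - m), P - K * S * K      # update
    return ll


value_and_grad = jax.value_and_grad(log_marginal_likelihood)
log_q0 = jnp.log(0.01)
ll0, g0 = value_and_grad(log_q0)
log_q1 = log_q0 + 0.1 * g0   # one plain gradient-ascent step on the hyperparameter
```

In practice one would hand `value_and_grad` to an optimizer such as Optax instead of stepping by hand.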
def log_likelihood_fn(initial_cov_sqr):
res = jit_filter(
mu_0_jit, initial_cov_sqr, prior_jit, measure_jit, tspan_jit, N_jit
)
return res[-1] # log_likelihood is the last element
grad_fn = jax.grad(log_likelihood_fn)
grad_val = grad_fn(Sigma_0_sqr_jit)
print(f"Gradient shape: {grad_val.shape}")
print(f"Gradient values:\n{grad_val}")
print(f"\nGradient is finite: {bool(np.all(np.isfinite(grad_val)))}")
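Gradient shape: (3, 3)
Gradient values:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Gradient is finite: True
The all-zero gradient is expected here: the initial covariance P_0 = S_0^T S_0 is quadratic in its square-root factor S_0, so its derivative vanishes at S_0 = 0, and the result indicates that Sigma_0_sqr_jit (from taylor_mode_initialization) is the zero matrix. Evaluating grad_fn at a nonzero factor, e.g. 0.1 * np.eye(3), generally gives a nonvanishing gradient.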
6. Scan-Based Sequential Filtering
The sequential update splits each filter step into two stages:
- ODE update -- incorporate ODE residual (linearized at the predicted state)
- Observation update -- incorporate measurements (linearized at the ODE-updated state)
This is beneficial because the observation update sees a tighter prior
(m_ode, P_ode) rather than the diffuse prediction (m_pred, P_pred).
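For a linear-Gaussian model, conditioning on two independent measurements one after the other gives exactly the same posterior as conditioning on both at once. The two-stage scheme therefore changes the result only when the observation model is nonlinear, because its linearization point then moves from the prediction to the ODE-updated state. The sketch below checks the linear case; `H_ode`, `H_obs` and the data are made-up stand-ins, not quantities taken from `ode_filters`.

```python
import jax.numpy as jnp


def kalman_update(m, P, H, R, z):
    S = H @ P @ H.T + R
    K = jnp.linalg.solve(S, H @ P).T    # K = P H^T S^{-1} (P and S are symmetric)
    return m + K @ (z - H @ m), P - K @ S @ K.T


m_pred = jnp.array([0.5, 0.1, 0.0])
P_pred = jnp.diag(jnp.array([0.2, 0.5, 1.0]))

H_ode = jnp.array([[0.3, 1.0, 0.0]])    # stand-in for a linearized ODE residual
R_ode = jnp.array([[1e-6]])
z_ode = jnp.array([0.0])
H_obs = jnp.array([[1.0, 0.0, 0.0]])    # observe the first coefficient
R_obs = jnp.array([[0.01]])
z_obs = jnp.array([0.6])

# Two-stage: ODE update first, then the observation update on (m_ode, P_ode)
m_ode, P_ode = kalman_update(m_pred, P_pred, H_ode, R_ode, z_ode)
m_seq, P_seq = kalman_update(m_ode, P_ode, H_obs, R_obs, z_obs)

# Joint: stack both measurements into a single update from (m_pred, P_pred)
H = jnp.vstack([H_ode, H_obs])
R = jnp.diag(jnp.array([1e-6, 0.01]))
m_joint, P_joint = kalman_update(m_pred, P_pred, H, R, jnp.concatenate([z_ode, z_obs]))
```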
The _scan variant uses jax.lax.scan, which traces the step body only once, so compilation time is O(1) in the number of steps N and memory is used efficiently. Observations are decoupled from the ODE model: pre-compute them into an ObsModel with prepare_observations, then pass the ObsModel to the scan loop as a fixed-shape array structure.
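The idea behind a fixed-shape observation model can be sketched in a few lines: store one (possibly dummy) observation per step together with a boolean mask, run `jax.lax.scan` over all steps, and use `jnp.where` to keep the predicted state wherever the mask is false. This scalar random-walk illustration uses hypothetical names (`z_seq`, `mask`, `step`) and does not reproduce the internals of `ObsModel` or `prepare_observations`.

```python
import jax
import jax.numpy as jnp

N = 10
q, r = 0.01, 0.04
obs_steps = jnp.array([2, 5, 8])
# Fixed-shape observations: one slot per step, most of them unused
z_seq = jnp.zeros(N).at[obs_steps].set(jnp.array([0.2, 0.5, 0.7]))
mask = jnp.zeros(N, dtype=bool).at[obs_steps].set(True)


def step(carry, inputs):
    m, P = carry
    z, has_obs = inputs
    m_pred, P_pred = m, P + q                   # predict (scalar random walk)
    S = P_pred + r                              # update as if an observation were present...
    K = P_pred / S
    m_upd, P_upd = m_pred + K * (z - m_pred), P_pred - K * S * K
    # ...and keep it only where the mask is set: fixed shapes, no Python branching
    m_new = jnp.where(has_obs, m_upd, m_pred)
    P_new = jnp.where(has_obs, P_upd, P_pred)
    return (m_new, P_new), (m_new, P_new)


init = (jnp.array(0.0), jnp.array(1.0))
(m_final, P_final), (ms, Ps) = jax.lax.scan(step, init, (z_seq, mask))
```

Computing the update at every step and discarding it wastes a little arithmetic, but it keeps the loop body identical across steps, which is what lets `scan` compile it once.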
from ode_filters.filters import (
ekf1_sqr_loop_sequential,
ekf1_sqr_loop_sequential_scan,
)
from ode_filters.measurement import prepare_observations
# Logistic ODE with sparse observations
def vf_scan(x, *, t):
return x * (1 - x)
x0_scan = np.array([0.01])
tspan_scan = (0, 10)
N_scan = 200
prior_scan = IWP(q=2, d=1, Xi=0.5 * np.eye(1))
mu_0_scan, Sigma_0_sqr_scan = taylor_mode_initialization(vf_scan, x0_scan, q=2)
# ODE-only measurement model (no observations bundled in)
measure_scan = ODEInformation(vf_scan, prior_scan.E0, prior_scan.E1)
# Create observations separately
ts_scan = np.linspace(tspan_scan[0], tspan_scan[1], N_scan + 1)
z_t_scan = ts_scan[20:180:10]
x_true_scan = 1.0 / (1.0 + 99.0 * np.exp(-z_t_scan))
z_scan = x_true_scan[:, None] + 0.02 * jrandom.normal(
jrandom.PRNGKey(99), shape=(len(z_t_scan), 1)
)
A_scan = np.array([[1.0]])
obs_scan = Measurement(A_scan, z_scan, z_t_scan, noise=0.01)
# Pre-compute observations into fixed-shape ObsModel
obs_model = prepare_observations([obs_scan], prior_scan.E0, ts_scan)  # E0 maps the state to x
print(f"ObsModel shapes:")
print(f" H: {obs_model.H.shape}")
print(f" R_sqr: {obs_model.R_sqr.shape}")
print(f" c_seq: {obs_model.c_seq.shape} (one offset per time step)")
print(f" mask: {obs_model.mask.shape} (boolean mask per step)")
ObsModel shapes:
 H: (1, 3)
 R_sqr: (1, 1)
 c_seq: (200, 1) (one offset per time step)
 mask: (200, 1) (boolean mask per step)
# Run scan-based sequential filter (JIT-compatible out of the box)
jit_scan_filter = jax.jit(
ekf1_sqr_loop_sequential_scan, static_argnums=(2, 3, 4, 5)
)
scan_result = jit_scan_filter(
mu_0_scan, Sigma_0_sqr_scan, prior_scan, measure_scan, tspan_scan, N_scan,
obs_model=obs_model,
)
# Compare with the for-loop sequential filter
loop_result = ekf1_sqr_loop_sequential(
mu_0_scan, Sigma_0_sqr_scan, prior_scan, measure_scan, tspan_scan, N_scan,
observations=[obs_scan],
)
# Results match closely
m_scan = scan_result[0]
m_loop = np.array(loop_result[0])
max_diff = float(np.max(np.abs(m_scan - m_loop)))
print(f"Max difference between scan and for-loop means: {max_diff:.2e}")
print(f"Scan log-likelihood (ODE): {float(scan_result[-2]):.4f}")
print(f"Loop log-likelihood (ODE): {float(loop_result[-2]):.4f}")
print(f"Scan log-likelihood (obs): {float(scan_result[-1]):.4f}")
print(f"Loop log-likelihood (obs): {float(loop_result[-1]):.4f}")
Max difference between scan and for-loop means: 4.35e-06
Scan log-likelihood (ODE): 831.6470
Loop log-likelihood (ODE): 831.6470
Scan log-likelihood (obs): 21.7964
Loop log-likelihood (obs): 21.7964
# The scan filter is fully differentiable -- compute gradient of log-likelihood
def scan_ll(initial_cov_sqr):
res = jit_scan_filter(
mu_0_scan, initial_cov_sqr, prior_scan, measure_scan, tspan_scan, N_scan,
obs_model=obs_model,
)
return res[-2] + res[-1] # total log-likelihood
grad_scan = jax.grad(scan_ll)(Sigma_0_sqr_scan)
print(f"Gradient of scan log-likelihood w.r.t. initial covariance:")
print(f" shape: {grad_scan.shape}")
print(f" finite: {bool(np.all(np.isfinite(grad_scan)))}")
# Plot the scan filter result
fig, ax = plt.subplots(figsize=(10, 4))
x_analytic = 1.0 / (1.0 + 99.0 * np.exp(-ts_scan))
ax.plot(ts_scan, x_analytic, "k--", linewidth=2, label="Analytic solution")
ax.plot(ts_scan, m_scan[:, 0], "b-", linewidth=2, label="Scan filter estimate")
ax.scatter(z_t_scan, z_scan, s=20, color="orange", zorder=3, label="Observations")
ax.set_xlabel("t"), ax.set_ylabel("x(t)")
ax.set_title("Scan-Based Sequential Filter (logistic ODE, N=200)")
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()
Gradient of scan log-likelihood w.r.t. initial covariance:
 shape: (3, 3)
 finite: True
Summary
This notebook demonstrated advanced ODE solving capabilities:
- First-Order ODEs with Hidden States: Joint inference of state and parameters (decay rate)
- Second-Order ODEs with Hidden States: Estimating damping coefficient from observations
- Conservation Constraints: Enforcing algebraic constraints (SIR population conservation)
- Linear Measurements: Sparse, noisy observations at discrete times (Lotka-Volterra)
- JIT Compilation and Differentiation: Gradient-based hyperparameter optimization
- Scan-Based Sequential Filtering: Efficient jax.lax.scan with decoupled observations
Key Classes and Functions
For Hidden States:
- JointPrior(prior_x, prior_u): Combines independent priors for state and hidden parameters
- ODEInformationWithHidden(vf, E0, E1, E0_hidden, ...): First-order ODE with vector field vf(x, u, *, t)
- SecondOrderODEInformationWithHidden(vf, E0, E1, E2, E0_hidden, ...): Second-order ODE with vf(x, dx, u, *, t)
For Constraints:
- Conservation(A, p): Hard algebraic constraint A @ x = p (always active)
- Measurement(A, z, z_t, noise): Time-varying linear observations A @ x = z[t] (active at specified times)
For Scan-Based Filtering:
- prepare_observations(obs_list, E0, ts): Pre-computes observations into a fixed-shape ObsModel
- ekf1_sqr_loop_sequential_scan(...): Scan-based filter with O(1) compilation time
- ekf1_sqr_loop_preconditioned_sequential_scan(...): Preconditioned variant
Other Advanced Features (see documentation):
- BlackBoxMeasurement: Custom observation models with autodiff Jacobians
- TransformedMeasurement: Nonlinear state transformations before measurement
- MaternPrior: Gaussian process priors with specified length scales
- PrecondIWP and preconditioned filters: For numerical stability
All these features can be combined for complex real-world problems!