Introduction to pypomp

This document showcases the basic functions in pypomp. We demonstrate them by building the linear Gaussian model (available as pypomp.LG()) from scratch.
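
For reference, the model assembled from the rinit, rproc, and dmeas components defined later in this document can be summarized as follows. This is a reading of the code shown here rather than pypomp's formal definition, and the symbols X_n (latent state), Y_n (observation), and T (number of observations) are notation assumed for this summary.

```latex
\begin{aligned}
X_0 &= (1, 1)^\top, \\
X_n \mid X_{n-1} &\sim \mathcal{N}(A X_{n-1},\, Q), \\
Y_n \mid X_n &\sim \mathcal{N}(C X_n,\, R), \qquad n = 1, \dots, T.
\end{aligned}
```

In this example C is the identity matrix, so the measurement mean is simply X_n.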

First, we import the necessary packages.

import jax.numpy as jnp
import jax.random
import pypomp as po
from pypomp.LG import Generate_data

Because the model parameters will all be passed in one JAX array, it can be helpful to define functions that pack the values into the array or unpack them into a more natural form.

def get_thetas(theta):
    """
    Unpack a length-16 theta vector into the 2x2 matrices A, C, Q, and R.
    """
    A = theta[0:4].reshape(2, 2)
    C = theta[4:8].reshape(2, 2)
    Q = theta[8:12].reshape(2, 2)
    R = theta[12:16].reshape(2, 2)
    return A, C, Q, R


def transform_thetas(A, C, Q, R):
    """
    Pack the A, C, Q, and R matrices into a single 1D array.
    """
    return jnp.concatenate([A.flatten(), C.flatten(), Q.flatten(), R.flatten()])
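
As a quick sanity check, packing and then unpacking should return the original matrices. The sketch below repeats the two helpers so that it runs on its own; the matrices A0, C0, Q0, and R0 are arbitrary values chosen only for this check.

```python
import jax.numpy as jnp


def get_thetas(theta):
    """Unpack a length-16 vector into the 2x2 matrices A, C, Q, R."""
    A = theta[0:4].reshape(2, 2)
    C = theta[4:8].reshape(2, 2)
    Q = theta[8:12].reshape(2, 2)
    R = theta[12:16].reshape(2, 2)
    return A, C, Q, R


def transform_thetas(A, C, Q, R):
    """Pack the four 2x2 matrices into one length-16 vector."""
    return jnp.concatenate([A.flatten(), C.flatten(), Q.flatten(), R.flatten()])


# Arbitrary illustrative matrices (not the model's parameter values).
A0 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
C0 = jnp.eye(2)
Q0 = 0.5 * jnp.eye(2)
R0 = 2.0 * jnp.eye(2)

theta0 = transform_thetas(A0, C0, Q0, R0)
roundtrip = get_thetas(theta0)

# Every unpacked matrix should equal the matrix that was packed.
roundtrip_ok = all(
    bool(jnp.array_equal(m, m0))
    for m, m0 in zip(roundtrip, (A0, C0, Q0, R0))
)
print(theta0.shape, roundtrip_ok)  # (16,) True
```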

Next, we define the components of the model. The decorators ensure that each function's arguments are in the order required for jax.vmap() to work internally.

@po.RInit
def rinit(params, J, covars=None):
    """Initial state process simulator for the linear Gaussian model"""
    return jnp.ones((J, 2))

@po.RProc
def rproc(state, params, key, covars=None):
    """Process simulator for the linear Gaussian model"""
    A, C, Q, R = get_thetas(params)
    key, subkey = jax.random.split(key)
    return jax.random.multivariate_normal(
        key=subkey,
        mean=A @ state, cov=Q
    )

@po.DMeas
def dmeas(y, state, params):
    """Measurement model distribution for the linear Gaussian model"""
    A, C, Q, R = get_thetas(params)
    # The measurement mean is C @ state; C is the identity in this example.
    return jax.scipy.stats.multivariate_normal.logpdf(y, C @ state, R)
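
To see why argument order matters for jax.vmap(), here is a plain-JAX sketch of vectorizing a single-particle process step over J particles. It illustrates the general idea only and is not pypomp's internal code; step_one_particle and step_all_particles are hypothetical names, and the values of A and Q are placeholders.

```python
import jax
import jax.numpy as jnp


def step_one_particle(state, A, Q, key):
    """Advance one 2-vector state by a single draw from N(A @ state, Q)."""
    return jax.random.multivariate_normal(key, mean=A @ state, cov=Q)


# in_axes lists, per positional argument, which axis to map over:
# map over axis 0 of `state` and `key`, and share A and Q across particles.
step_all_particles = jax.vmap(step_one_particle, in_axes=(0, None, None, 0))

J = 5
A = jnp.eye(2)          # placeholder transition matrix
Q = 0.01 * jnp.eye(2)   # placeholder state-noise covariance
states = jnp.ones((J, 2))
keys = jax.random.split(jax.random.key(0), J)  # one key per particle

new_states = step_all_particles(states, A, Q, keys)
print(new_states.shape)  # (5, 2)
```

Because in_axes is positional, a wrapper that fixes the argument order lets the same mapping specification be reused for any user-supplied function.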

Now we define the parameter values as JAX arrays.

# Transition matrix
A = jnp.array([
    [jnp.cos(0.2), -jnp.sin(0.2)],
    [jnp.sin(0.2),  jnp.cos(0.2)]
])
# Measurement matrix
C = jnp.eye(2)
# Covariance matrix of state noise
Q = jnp.array([
    [1, 1e-4],
    [1e-4, 1]
]) / 100
# Covariance matrix of measurement noise
R = jnp.array([
    [1, .1],
    [.1, 1]
]) / 10
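
Before using these values, it can be worth confirming that they are valid: A is a rotation by 0.2 radians, so it should be orthogonal, and Q and R must be symmetric positive definite to be covariance matrices. The following optional check is a sketch; the variable names a_is_rotation, q_is_spd, and r_is_spd and the helper is_spd are introduced here for illustration.

```python
import jax.numpy as jnp

A = jnp.array([
    [jnp.cos(0.2), -jnp.sin(0.2)],
    [jnp.sin(0.2),  jnp.cos(0.2)]
])
Q = jnp.array([[1, 1e-4], [1e-4, 1]]) / 100
R = jnp.array([[1, .1], [.1, 1]]) / 10


def is_spd(M):
    """Return True if M is symmetric with strictly positive eigenvalues."""
    symmetric = bool(jnp.allclose(M, M.T))
    positive = bool(jnp.all(jnp.linalg.eigvalsh(M) > 0))
    return symmetric and positive


# A rotation matrix satisfies A @ A.T = I (up to floating-point error).
a_is_rotation = bool(jnp.allclose(A @ A.T, jnp.eye(2), atol=1e-6))
q_is_spd = is_spd(Q)
r_is_spd = is_spd(R)
print(a_is_rotation, q_is_spd, r_is_spd)  # True True True
```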

We do not have real-world data, so we generate a synthetic data set instead.

Y = Generate_data(T = 100)
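
For intuition about what simulating from a linear Gaussian model involves, here is a minimal plain-JAX sketch. It is not the implementation of Generate_data(), whose internals are not shown in this document; simulate_lg is a hypothetical helper, and the starting state of ones is an assumption borrowed from rinit above.

```python
import jax
import jax.numpy as jnp


def simulate_lg(key, A, C, Q, R, T, x0):
    """Simulate T observations: X_n ~ N(A X_{n-1}, Q), Y_n ~ N(C X_n, R)."""

    def step(x, step_key):
        proc_key, meas_key = jax.random.split(step_key)
        # Draw the next latent state, then an observation of that state.
        x_new = jax.random.multivariate_normal(proc_key, mean=A @ x, cov=Q)
        y = jax.random.multivariate_normal(meas_key, mean=C @ x_new, cov=R)
        return x_new, y

    # One key per time step; scan carries the state and stacks the observations.
    _, ys = jax.lax.scan(step, x0, jax.random.split(key, T))
    return ys


A = jnp.array([
    [jnp.cos(0.2), -jnp.sin(0.2)],
    [jnp.sin(0.2),  jnp.cos(0.2)]
])
C = jnp.eye(2)
Q = jnp.array([[1, 1e-4], [1e-4, 1]]) / 100
R = jnp.array([[1, .1], [.1, 1]]) / 10

ys = simulate_lg(jax.random.key(0), A, C, Q, R, T=100, x0=jnp.ones(2))
print(ys.shape)  # (100, 2)
```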

Finally, we construct the POMP model object.

LG_mod = po.Pomp(
    rinit=rinit,
    rproc=rproc,
    dmeas=dmeas,
    theta=transform_thetas(A, C, Q, R),
    ys=Y,
)

Now we can run the particle filter on the model object to estimate the negative log-likelihood of the data under the model.

pfilter_results = LG_mod.pfilter(J=200, key=jax.random.key(123))
float(pfilter_results)
549.2139892578125
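
To make the idea of a particle filter concrete, the following is a minimal bootstrap particle filter for a two-dimensional linear Gaussian model, written in plain JAX. It is a sketch under stated assumptions and not pypomp's pfilter implementation: bootstrap_pfilter is a hypothetical function, it uses simple multinomial resampling at every step, it returns the log-likelihood estimate (negate it to get a negative log-likelihood), and the observations are placeholder random numbers used only to exercise the code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal as mvn


def bootstrap_pfilter(key, ys, A, C, Q, R, J):
    """Estimate the log-likelihood of ys using J particles (illustrative only)."""

    def step(carry, inputs):
        particles, loglik = carry
        y, step_key = inputs
        prop_key, resample_key = jax.random.split(step_key)
        # Propagate: draw each particle x' from N(A x, Q).
        particles = jax.random.multivariate_normal(
            prop_key, mean=particles @ A.T, cov=Q
        )
        # Weight: log density of y under N(C x, R) for each particle.
        logw = jax.vmap(lambda x: mvn.logpdf(y, C @ x, R))(particles)
        # The log of the average weight estimates the conditional log-likelihood.
        loglik = loglik + logsumexp(logw) - jnp.log(J)
        # Resample particle indices with probability proportional to the weights.
        idx = jax.random.categorical(resample_key, logw, shape=(J,))
        return (particles[idx], loglik), None

    # All particles start at (1, 1); the running log-likelihood starts at 0.
    init = (jnp.ones((J, 2)), jnp.zeros(()))
    keys = jax.random.split(key, len(ys))
    (_, loglik), _ = jax.lax.scan(step, init, (ys, keys))
    return loglik


A = jnp.array([
    [jnp.cos(0.2), -jnp.sin(0.2)],
    [jnp.sin(0.2),  jnp.cos(0.2)]
])
C = jnp.eye(2)
Q = 0.01 * jnp.eye(2)
R = 0.1 * jnp.eye(2)

# Placeholder observations (not simulated from the model).
ys = 0.5 * jax.random.normal(jax.random.key(1), (50, 2))
loglik = bootstrap_pfilter(jax.random.key(2), ys, A, C, Q, R, J=200)
print(float(loglik))
```

The estimate is random: it depends on the key and becomes less variable as J grows.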