Introduction to pypomp

This document showcases basic functions in pypomp. We demonstrate them by building the linear Gaussian model (available as pypomp.LG()) from scratch.

First, we import the necessary packages.

import jax.numpy as jnp
import jax.random
import pypomp as pp
import pandas as pd
from functools import partial

Because the model parameters are all passed in a single JAX array, it can be helpful to define functions that pack the values into that array and unpack them into a more natural form. We also generate a pseudorandom number generator (PRNG) key to be used with JAX.

def get_thetas(theta):
    """
    Cast a theta vector into A, C, Q, and R matrices as if casting iron.
    """
    A = theta[0:4].reshape(2, 2)
    C = theta[4:8].reshape(2, 2)
    Q = theta[8:12].reshape(2, 2)
    R = theta[12:16].reshape(2, 2)
    return A, C, Q, R


def transform_thetas(A, C, Q, R):
    """
    Take A, C, Q, and R matrices and melt them into a single 1D array.
    """
    return jnp.concatenate([A.flatten(), C.flatten(), Q.flatten(), R.flatten()])

key = jax.random.key(1)
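
As a quick sanity check (not part of the pypomp workflow), we can confirm that the two helpers round-trip a parameter vector. The matrices below are the same values used for theta later on.

A = jnp.array([[jnp.cos(0.2), -jnp.sin(0.2)], [jnp.sin(0.2), jnp.cos(0.2)]])
C = jnp.eye(2)
Q = jnp.array([[0.01, 1e-6], [1e-6, 0.01]])
R = jnp.array([[0.1, 0.01], [0.01, 0.1]])
theta_vec = transform_thetas(A, C, Q, R)  # 1D array of length 16
A2, C2, Q2, R2 = get_thetas(theta_vec)    # recovers the original matrices
assert all(jnp.allclose(x, y) for x, y in zip((A, C, Q, R), (A2, C2, Q2, R2)))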

Next, we define the model mechanics. Each component function is passed to an object constructor that, among other things, checks that all of the necessary arguments are present and in the expected order. This is required because, internally, pypomp uses jax.vmap() to vectorize the component functions so they can be run efficiently for thousands of particles. jax.vmap() maps a function over input arrays by position rather than by keyword, so the argument order must be fixed. Every expected argument must appear in the signature, but the function does not have to use all of them.

@partial(pp.RInit, t0=0.0)
def rinit(theta_, key, covars=None, t0=None):
    """Initial state process simulator for the linear Gaussian model"""
    A, C, Q, R = get_thetas(theta_)
    return jax.random.multivariate_normal(key=key, mean=jnp.array([0, 0]), cov=Q)


@partial(pp.RProc, step_type="fixedstep", nstep=1)
def rproc(X_, theta_, key, covars=None, t=None, dt=None):
    """Process simulator for the linear Gaussian model"""
    A, C, Q, R = get_thetas(theta_)
    return jax.random.multivariate_normal(key=key, mean=A @ X_, cov=Q)


@pp.DMeas
def dmeas(Y_, X_, theta_, covars=None, t=None):
    """Measurement model distribution for the linear Gaussian model"""
    A, C, Q, R = get_thetas(theta_)
    return jax.scipy.stats.multivariate_normal.logpdf(Y_, X_, R)


@partial(pp.RMeas, ydim=2)
def rmeas(X_, theta_, key, covars=None, t=None):
    """Measurement simulator for the linear Gaussian model"""
    A, C, Q, R = get_thetas(theta_)
    return jax.random.multivariate_normal(key=key, mean=C @ X_, cov=R)

The Pomp constructor also requires model parameters. These can be provided either as a single dictionary or as a list of dictionaries, where each dictionary maps parameter names to parameter values. If the parameter sets are provided as a list of dictionaries, methods such as pfilter() run on each set of parameters. Here, we use Pomp.sample_params() to sample sets of parameters from uniform distributions whose bounds are passed as a dictionary of length-2 tuples. Pomp.sample_params() returns a ready-to-use list of dictionaries containing the sampled parameters.

theta = {
    "A11": jnp.cos(0.2),
    "A12": -jnp.sin(0.2),
    "A21": jnp.sin(0.2),
    "A22": jnp.cos(0.2),
    "C11": 1.0,
    "C12": 0.0,
    "C21": 0.0,
    "C22": 1.0,
    "Q11": 0.01,
    "Q12": 1e-6,
    "Q21": 1e-6,
    "Q22": 0.01,
    "R11": 0.1,
    "R12": 0.01,
    "R21": 0.01,
    "R22": 0.1,
}
param_bounds = {k: (v * 0.9, v * 1.1) for k, v in theta.items()}
n = 5  # Number of parameter sets to sample
key, subkey = jax.random.split(key)
theta_list = pp.Pomp.sample_params(param_bounds, n, subkey)
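
The result is an ordinary Python list of parameter dictionaries, so it can be inspected directly:

len(theta_list)       # 5: one dictionary per sampled parameter set
theta_list[0]["A11"]  # a draw from the interval (0.9 * cos(0.2), 1.1 * cos(0.2))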

We do not have real data in this example, so we generate our own. To keep this example clean, we cheat a little by using LG() to construct a complete linear Gaussian model object and then generating the data with simulate() (a future update will make it easier to populate a data-less Pomp object with simulated data).

key, subkey = jax.random.split(key)
T = 100
sims = pp.LG(T=T).simulate(key=subkey)
ys = pd.DataFrame(
    sims[0]["Y_sims"].squeeze(),
    index=range(1, T + 1),
    columns=pd.Index(["Y1", "Y2"]),
)
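
A quick look at the simulated observations confirms their shape and index:

ys.shape   # (100, 2)
ys.head()  # rows indexed 1 through 5, columns Y1 and Y2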

Finally, we construct the POMP model object. Observation times are provided to the Pomp constructor via the pandas.DataFrame row index. If covariates were supplied, their observation times would likewise be taken from the row index of the covariate data frame. Each argument to Pomp is accessible from the object as an attribute.

LG_obj = pp.Pomp(
    rinit=rinit,
    rproc=rproc,
    dmeas=dmeas,
    rmeas=rmeas,
    ys=ys,
    theta=theta_list,
    covars=None,
)
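
For example, the observations and the starting parameter sets can be read back off the object:

LG_obj.ys     # the observation data frame passed to the constructor
LG_obj.theta  # the list of parameter dictionaries passed as theta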

Unlike in the R family of POMP packages, some Pomp methods, including pfilter(), mif(), and train(), yield results by modifying the object in place instead of returning new objects. Results are stored in a list under LG_obj.results_history, where each element corresponds to one method call. Each element includes the results, such as the log-likelihood and parameter estimates where applicable, along with the inputs used for the call, so it is easy to keep track of how each result was computed. If multiple parameter sets are supplied as a list, the method is evaluated at each set and the results for each are stored.

We implement IF2 as the Pomp method mif(). The random walk standard deviation is controlled by sigmas and sigmas_init: sigmas controls the random walk standard deviation after time \(t_0\), whereas sigmas_init controls it at time \(t_0\), allowing the user to perturb initial value parameters just once per iteration. These arguments can be provided as arrays, where each entry gives the random walk standard deviation of the parameter with the corresponding index, or as a single float, which is broadcast to all parameters.
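
For instance, a per-parameter array of perturbation scales might be built as follows. This is only a sketch: we assume the entries follow the same order as the parameters in the theta dictionary, and sigmas_vec is our own name.

sigmas_vec = jnp.full(16, 0.02)            # one entry per parameter, in theta order (assumed)
sigmas_vec = sigmas_vec.at[-4:].set(0.01)  # e.g., smaller perturbations for the R entries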

Notably, Pomp methods such as mif() can update attributes other than results_history. For example, LG_obj.mif() and LG_obj.train() replace LG_obj.theta with the parameter estimate from the end of the last iteration. Furthermore, each method that takes a key as an argument stores an unused child key under LG_obj.fresh_key that later method calls can use by default when a key argument is not given. Consequently, the following sequence of method calls is valid.

key, subkey = jax.random.split(key)
LG_obj.mif(J=200, M=100, sigmas=0.02, sigmas_init=0.02, a=0.5, key=subkey)
LG_obj.pfilter(J=200, reps=20)
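
The pfilter() call above supplies no key, so it falls back on LG_obj.fresh_key. After these two calls, the history contains one entry per call:

len(LG_obj.results_history)  # 2: one entry for mif(), one for pfilter()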

One way to access the results of a method call is the results() method. It returns a tidy data frame of the results stored in LG_obj.results_history at a given index (defaulting to -1, the most recent method call). The data frame includes the log-likelihood and parameter estimates from that call.

LG_obj.results()
logLik se A11 A12 A21 A22 C11 C12 C21 C22 Q11 Q12 Q21 Q22 R11 R12 R21 R22
0 -94.006348 0.258740 0.781668 -0.134588 0.415441 0.673535 0.439900 1.386680 3.840841 0.503787 0.044602 -0.656884 0.689944 0.074715 0.108552 -0.084329 0.076245 0.078228
1 -96.566437 0.269658 0.905786 -0.200530 0.210226 0.524645 1.469434 -0.694713 0.237822 1.521901 0.045307 -0.564952 0.564663 0.029756 0.075840 -0.688339 0.700091 0.066871
2 -100.281670 0.222478 0.645964 -0.070233 0.265264 0.502032 0.556673 -1.024783 -0.760580 0.127822 0.113779 -1.149185 1.155959 0.028191 0.085995 -0.409641 0.387792 0.103658
3 -99.041809 0.272080 0.626091 -0.107587 0.186608 0.607237 0.618223 -0.157061 2.176058 1.794874 0.051851 -0.526904 0.554509 0.088692 0.158336 0.478811 -0.541392 0.048753
4 -100.141853 0.442012 0.462065 0.045177 0.229735 0.673238 0.595209 0.888843 -0.085658 -1.078547 0.077865 -1.099316 1.038090 0.112328 0.106226 -0.169931 0.153630 0.043418
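
Passing an earlier index retrieves a previous call; for example, assuming the index is given as the first argument, the mif() results are at index 0:

LG_obj.results(0)  # results from the mif() call, the first entry in results_history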

We can also use LG_obj.traces() to access the traces of the parameters and log-likelihood as a data frame. This includes log-likelihood estimates from pfilter() as well.

LG_obj.traces()
replication iteration method loglik A11 A12 A21 A22 C11 C12 C21 C22 Q11 Q12 Q21 Q22 R11 R12 R21 R22
0 0 1 mif NaN 0.963512 -0.178802 0.203705 1.069749 0.999111 0.000000 0.000000 1.054883 0.010472 9.363539e-07 0.000001 0.010878 0.102497 0.010768 0.009182 0.096302
1 1 1 mif NaN 0.966490 -0.178802 0.214531 0.982279 0.941170 0.000000 0.000000 1.077328 0.009964 1.079856e-06 0.000001 0.009380 0.109075 0.009379 0.010691 0.093214
2 2 1 mif NaN 0.967444 -0.178802 0.185848 0.989572 1.003363 0.000000 0.000000 1.003444 0.009594 9.344108e-07 0.000001 0.010617 0.100243 0.010192 0.010793 0.098955
3 3 1 mif NaN 1.015698 -0.178802 0.192220 1.073113 0.973013 0.000000 0.000000 0.916460 0.009097 9.765389e-07 0.000001 0.009699 0.101522 0.009060 0.010467 0.095675
4 4 1 mif NaN 0.977909 -0.178802 0.195180 0.928942 1.055157 0.000000 0.000000 0.939434 0.010325 1.084500e-06 0.000001 0.010019 0.108067 0.010244 0.010489 0.096365
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
505 2 101 pfilter -100.281670 0.645964 -0.070233 0.265264 0.502032 0.556673 -1.024783 -0.760580 0.127822 0.113779 -1.149185e+00 1.155959 0.028191 0.085995 -0.409641 0.387792 0.103658
506 3 101 mif -99.112450 0.626091 -0.107587 0.186608 0.607237 0.618223 -0.157061 2.176058 1.794874 0.051851 -5.269038e-01 0.554509 0.088692 0.158336 0.478811 -0.541392 0.048753
507 3 101 pfilter -99.041809 0.626091 -0.107587 0.186608 0.607237 0.618223 -0.157061 2.176058 1.794874 0.051851 -5.269038e-01 0.554509 0.088692 0.158336 0.478811 -0.541392 0.048753
508 4 101 mif -103.194794 0.462065 0.045177 0.229735 0.673238 0.595209 0.888843 -0.085658 -1.078547 0.077865 -1.099316e+00 1.038090 0.112328 0.106226 -0.169931 0.153630 0.043418
509 4 101 pfilter -100.141853 0.462065 0.045177 0.229735 0.673238 0.595209 0.888843 -0.085658 -1.078547 0.077865 -1.099316e+00 1.038090 0.112328 0.106226 -0.169931 0.153630 0.043418

510 rows × 20 columns
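
Because traces() returns an ordinary pandas data frame, standard plotting tools can be used for diagnostics. The sketch below assumes matplotlib is installed (it is not a pypomp plotting function) and draws the IF2 trace of A11 for each replication.

import matplotlib.pyplot as plt

traces = LG_obj.traces()
mif_traces = traces[traces["method"] == "mif"]
fig, ax = plt.subplots()
for rep, grp in mif_traces.groupby("replication"):
    ax.plot(grp["iteration"], grp["A11"], label=f"replication {rep}")
ax.set_xlabel("IF2 iteration")
ax.set_ylabel("A11 estimate")
ax.legend()
plt.show()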