Introduction to pypomp

Published

April 27, 2026

This document showcases basic functions in pypomp. We demonstrate them by building the linear Gaussian model (also available as pypomp.models.LG()) from scratch.

First, we import the necessary packages.

import jax
import jax.numpy as jnp
import numpy as np
import pypomp as pp
from pypomp.types import *
import pandas as pd
import plotnine as p9

print(f"pypomp version: {pp.__version__}")
print(f"JAX version: {jax.__version__}")
pypomp version: 0.4.4.9
JAX version: 0.9.1

In pypomp, model components receive parameters as a dictionary (or ParamDict). For the linear Gaussian model, it is helpful to define a utility function to unpack these into the matrices used by the model equations.
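For reference, the model mechanics defined below correspond to the following equations. Here \(X_t = (X_1, X_2)\) is the latent state, \(Y_t = (Y_1, Y_2)\) is the observation, \(X_0\) is the state at the initial time \(t_0\), and \(A\), \(C\), \(Q\), and \(R\) are the 2×2 matrices unpacked by get_matrices():

\[
\begin{aligned}
X_0 &\sim \mathcal{N}(0,\, Q), \\
X_t \mid X_{t-1} &\sim \mathcal{N}(A X_{t-1},\, Q), \\
Y_t \mid X_t &\sim \mathcal{N}(C X_t,\, R).
\end{aligned}
\]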

def get_matrices(theta: dict[str, float]):
    """
    Unpack the theta dictionary into A, C, Q, and R matrices.
    """
    A = jnp.array([[theta["A11"], theta["A12"]], [theta["A21"], theta["A22"]]])
    C = jnp.array([[theta["C11"], theta["C12"]], [theta["C21"], theta["C22"]]])
    Q = jnp.array([[theta["Q11"], theta["Q12"]], [theta["Q21"], theta["Q22"]]])
    R = jnp.array([[theta["R11"], theta["R12"]], [theta["R21"], theta["R22"]]])

    return A, C, Q, R
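As a quick check of get_matrices(), the following sketch rebuilds the matrices from the same parameter values used later in this document and confirms two properties those values imply: \(A\) is a rotation by 0.2 radians (so \(A A^\top = I\)), and \(Q\) and \(R\) are symmetric. The function is repeated so that the snippet runs on its own; it needs only JAX, not pypomp.

```python
import jax.numpy as jnp


def get_matrices(theta):
    """Unpack the theta dictionary into A, C, Q, and R matrices."""
    A = jnp.array([[theta["A11"], theta["A12"]], [theta["A21"], theta["A22"]]])
    C = jnp.array([[theta["C11"], theta["C12"]], [theta["C21"], theta["C22"]]])
    Q = jnp.array([[theta["Q11"], theta["Q12"]], [theta["Q21"], theta["Q22"]]])
    R = jnp.array([[theta["R11"], theta["R12"]], [theta["R21"], theta["R22"]]])
    return A, C, Q, R


# Same values as the theta dictionary defined later in this document.
theta = {
    "A11": jnp.cos(0.2), "A12": -jnp.sin(0.2), "A21": jnp.sin(0.2), "A22": jnp.cos(0.2),
    "C11": 1.0, "C12": 0.0, "C21": 0.0, "C22": 1.0,
    "Q11": 0.01, "Q12": 1e-6, "Q21": 1e-6, "Q22": 0.01,
    "R11": 0.1, "R12": 0.01, "R21": 0.01, "R22": 0.1,
}
A, C, Q, R = get_matrices(theta)

# A rotation matrix satisfies A @ A.T == I (up to float32 rounding).
is_rotation = bool(jnp.allclose(A @ A.T, jnp.eye(2), atol=1e-6))
# Q and R are symmetric because their off-diagonal entries are equal.
symmetric = bool(jnp.allclose(Q, Q.T)) and bool(jnp.allclose(R, R.T))
print(A.shape, is_rotation, symmetric)
```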

Next, we define the model mechanics. Each function must accept a specific set of arguments and return output in a specific format so that it slots into the internal pypomp code. For example, pypomp needs to know the order in which to pass arguments to the model mechanics. There are two ways pypomp can determine this order:

  1. The Pomp constructor checks whether the model mechanics functions use particular argument names in a predetermined order. If they do, internal functions can safely pass arguments in that default order.

  2. Failing that, it checks whether the arguments carry particular type hints from pypomp.types, in which case it passes arguments in the order specified by those type hints.
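The first rule can be illustrated with Python's standard inspect module. The sketch below is not pypomp's implementation: has_expected_arg_order() is a hypothetical helper, and the expected names are simply those of the rproc defined below, which uses the default argument names.

```python
import inspect


def has_expected_arg_order(fn, expected_names):
    """Hypothetical helper: True if fn's parameters are exactly expected_names, in order."""
    names = list(inspect.signature(fn).parameters)
    return names == list(expected_names)


def rproc_stub(X_, theta_, key, covars=None, t=None, dt=None):
    """Stand-in with the same argument names as the rproc defined below."""
    return X_


def rproc_reordered(theta_, X_, key, covars=None, t=None, dt=None):
    """Same arguments, but theta_ and X_ are swapped."""
    return X_


expected_rproc = ("X_", "theta_", "key", "covars", "t", "dt")
print(has_expected_arg_order(rproc_stub, expected_rproc))       # True
print(has_expected_arg_order(rproc_reordered, expected_rproc))  # False
```

A function that fails this kind of name check is where a fallback such as type hints becomes useful; the annotations can be read the same way, via `inspect.signature(fn).parameters[name].annotation`.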

For demonstration, we use both the type hints and the default argument names.

Although every expected argument must be included, it is not necessary for the function to use all of them.

def rinit(
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict = None,
    t0: InitialTimeFloat = None,
):
    """Initial state process simulator for the linear Gaussian model"""
    A, C, Q, R = get_matrices(theta_)
    res = jax.random.multivariate_normal(key=key, mean=jnp.array([0, 0]), cov=Q)
    return {"X1": res[0], "X2": res[1]}


def rproc(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict = None,
    t: TimeFloat = None,
    dt: StepSizeFloat = None,
):
    """Process simulator for the linear Gaussian model"""
    A, C, Q, R = get_matrices(theta_)
    X_arr = jnp.array([X_["X1"], X_["X2"]])
    res = jax.random.multivariate_normal(key=key, mean=A @ X_arr, cov=Q)
    return {"X1": res[0], "X2": res[1]}


def dmeas(
    Y_: ObservationDict,
    X_: StateDict,
    theta_: ParamDict,
    covars: CovarDict = None,
    t: TimeFloat = None,
):
    """Measurement model distribution for the linear Gaussian model (log-likelihood)"""
    A, C, Q, R = get_matrices(theta_)
    X_arr = jnp.array([X_["X1"], X_["X2"]])
    Y_arr = jnp.array([Y_["Y1"], Y_["Y2"]])
    return jax.scipy.stats.multivariate_normal.logpdf(Y_arr, C @ X_arr, R)


def rmeas(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict = None,
    t: TimeFloat = None,
):
    """Measurement simulator for the linear Gaussian model"""
    A, C, Q, R = get_matrices(theta_)
    X_arr = jnp.array([X_["X1"], X_["X2"]])
    return jax.random.multivariate_normal(key=key, mean=C @ X_arr, cov=R)
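To see what these mechanics amount to when chained together, here is a sketch that simulates states and observations in plain JAX with jax.lax.scan and then evaluates the dmeas term for the simulated data. simulate_lg() is a hypothetical helper that mirrors rinit, rproc, and rmeas above; it is not how Pomp.simulate() is implemented.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def simulate_lg(A, C, Q, R, key, n_obs):
    """Hypothetical sketch: X_0 ~ N(0, Q), X_t ~ N(A X_{t-1}, Q), Y_t ~ N(C X_t, R)."""
    key, init_key = jax.random.split(key)
    x0 = jax.random.multivariate_normal(init_key, jnp.zeros(2), Q)  # like rinit

    def step(x_prev, step_key):
        proc_key, meas_key = jax.random.split(step_key)
        x = jax.random.multivariate_normal(proc_key, A @ x_prev, Q)  # like rproc
        y = jax.random.multivariate_normal(meas_key, C @ x, R)       # like rmeas
        return x, (x, y)

    # One fresh key per time step; scan carries the state forward.
    step_keys = jax.random.split(key, n_obs)
    _, (states, obs) = jax.lax.scan(step, x0, step_keys)
    return states, obs


angle = 0.2
A = jnp.array([[jnp.cos(angle), -jnp.sin(angle)], [jnp.sin(angle), jnp.cos(angle)]])
C = jnp.eye(2)
Q = jnp.array([[0.01, 1e-6], [1e-6, 0.01]])
R = jnp.array([[0.1, 0.01], [0.01, 0.1]])

states, obs = simulate_lg(A, C, Q, R, jax.random.key(0), n_obs=100)

# Sum of dmeas-style log densities log p(Y_t | X_t) over all time steps.
loglik_given_states = multivariate_normal.logpdf(obs, states @ C.T, R).sum()
print(states.shape, obs.shape, bool(jnp.isfinite(loglik_given_states)))
```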

The algorithms implemented in pypomp for searching the parameter space assume that the parameters are unconstrained on the real line. To handle constrained parameters, such as the diagonal elements of \(Q\) and \(R\), which must be positive, we write functions that transform the parameters to the real line (to_est) and back (from_est) and store them in a ParTrans object. The functions below log-transform the diagonal elements of all four matrices, so the diagonal elements of \(A\) and \(C\) are also kept positive.

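def to_est(theta):
    """Map parameters to the estimation scale (the unconstrained real line)."""
    new_theta = {**theta}  # copy so the input dictionary is not modified
    for name in "ACQR":
        # Log-transform the diagonal entries, which are kept positive.
        new_theta[f"{name}11"] = jnp.log(theta[f"{name}11"])
        new_theta[f"{name}22"] = jnp.log(theta[f"{name}22"])
    return new_theta


def from_est(theta):
    """Map parameters from the estimation scale back to the natural scale."""
    new_theta = {**theta}
    for name in "ACQR":
        # exp() inverts the log transform and always yields positive values.
        new_theta[f"{name}11"] = jnp.exp(theta[f"{name}11"])
        new_theta[f"{name}22"] = jnp.exp(theta[f"{name}22"])
    return new_theta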


ptrans = pp.ParTrans(to_est=to_est, from_est=from_est)
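A transformation pair is only useful if from_est() undoes to_est(). The sketch below repeats both functions so it runs standalone and checks the round trip on this document's parameter values, plus the property that motivates the log transform: any real number on the estimation scale maps back to a positive diagonal entry. The value -50.0 is an arbitrary illustration.

```python
import jax.numpy as jnp


def to_est(theta):
    new_theta = {**theta}
    for name in "ACQR":
        new_theta[f"{name}11"] = jnp.log(theta[f"{name}11"])
        new_theta[f"{name}22"] = jnp.log(theta[f"{name}22"])
    return new_theta


def from_est(theta):
    new_theta = {**theta}
    for name in "ACQR":
        new_theta[f"{name}11"] = jnp.exp(theta[f"{name}11"])
        new_theta[f"{name}22"] = jnp.exp(theta[f"{name}22"])
    return new_theta


theta = {
    "A11": jnp.cos(0.2), "A12": -jnp.sin(0.2), "A21": jnp.sin(0.2), "A22": jnp.cos(0.2),
    "C11": 1.0, "C12": 0.0, "C21": 0.0, "C22": 1.0,
    "Q11": 0.01, "Q12": 1e-6, "Q21": 1e-6, "Q22": 0.01,
    "R11": 0.1, "R12": 0.01, "R21": 0.01, "R22": 0.1,
}

est = to_est(theta)
back = from_est(est)
# Largest absolute round-trip error across all parameters (float32 precision).
max_err = max(abs(float(back[k]) - float(theta[k])) for k in theta)
# A very negative value on the estimation scale still maps to a positive Q11.
positive = float(from_est({**est, "Q11": -50.0})["Q11"]) > 0
print(max_err < 1e-6, positive)
```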

The Pomp constructor also requires model parameters, provided either as a dictionary or as a list of dictionaries. Each dictionary maps parameter names (keys) to parameter values. If a list of dictionaries is provided, methods such as pfilter() run on each parameter set. Here, we use Pomp.sample_params() to draw parameter sets from uniform distributions whose bounds are passed as a dictionary of length-2 (lower, upper) tuples. Pomp.sample_params() returns a ready-to-use list of dictionaries with the sampled parameters.

We also generate a pseudorandom number generator (PRNG) key to be used with JAX.

theta = {
    "A11": jnp.cos(0.2),
    "A12": -jnp.sin(0.2),
    "A21": jnp.sin(0.2),
    "A22": jnp.cos(0.2),
    "C11": 1.0,
    "C12": 0.0,
    "C21": 0.0,
    "C22": 1.0,
    "Q11": 0.01,
    "Q12": 1e-6,
    "Q21": 1e-6,
    "Q22": 0.01,
    "R11": 0.1,
    "R12": 0.01,
    "R21": 0.01,
    "R22": 0.1,
}
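# Bounds of +/-10% around each value, ordered as (lower, upper);
# for negative values such as A12, v * 0.9 would exceed v * 1.1.
param_bounds = {k: (v - 0.1 * abs(v), v + 0.1 * abs(v)) for k, v in theta.items()}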
n = 5  # Number of parameter sets to sample
key = jax.random.key(1)
key, subkey = jax.random.split(key)
theta_list = pp.Pomp.sample_params(param_bounds, n, subkey)
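To make the sampling step concrete, here is a minimal sketch of drawing parameter sets uniformly within per-parameter bounds using jax.random.uniform. sample_uniform_params() is a hypothetical stand-in, not the implementation of Pomp.sample_params(), and only three parameters are used to keep it short. For a negative value \(v\) such as A12, the pair (v * 0.9, v * 1.1) has its first entry above its second, so the sketch sorts each pair before sampling.

```python
import jax
import jax.numpy as jnp


def sample_uniform_params(bounds, n, key):
    """Hypothetical stand-in: n parameter dicts, each entry uniform on its bounds."""
    names = list(bounds)
    keys = jax.random.split(key, len(names))  # one independent key per parameter
    draws = {}
    for name, k in zip(names, keys):
        lo, hi = bounds[name]
        draws[name] = jax.random.uniform(k, (n,), minval=lo, maxval=hi)
    # Regroup into a list of n dictionaries, one per parameter set.
    return [{name: float(draws[name][i]) for name in names} for i in range(n)]


theta = {"A11": float(jnp.cos(0.2)), "A12": float(-jnp.sin(0.2)), "Q11": 0.01}
# Sort each pair so the lower bound comes first, even when v < 0.
bounds = {k: tuple(sorted((0.9 * v, 1.1 * v))) for k, v in theta.items()}
samples = sample_uniform_params(bounds, 5, jax.random.key(1))
print(len(samples), sorted(samples[0]))
```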

Finally, we construct the POMP model object. Observation times are passed to the Pomp constructor through the row index of the pandas.DataFrame of observations. If covariates were provided, the times at which they were observed would likewise be given by the row index of the covariate data frame. Each argument to Pomp is accessible from the object as an attribute.

We do not have real data in this example, so we generate our own. We do this by creating a dummy data frame with the desired observation times and column names, and then we generate the observations using Pomp.simulate(), which returns a new Pomp object with simulated data.

key, subkey = jax.random.split(key)
T = 100
ys = pd.DataFrame(
    np.zeros((T, 2)),
    index=range(1, T + 1),
    columns=pd.Index(["Y1", "Y2"]),
)

LG_obj_empty = pp.Pomp(
    ys=ys,
    theta=theta_list[0],
    statenames=["X1", "X2"],
    t0=0.0,
    rinit=rinit,
    rproc=rproc,
    dmeas=dmeas,
    rmeas=rmeas,
    nstep=1,
    par_trans=ptrans,
)
LG_obj = LG_obj_empty.simulate(key=subkey, as_pomp=True)
LG_obj.theta = theta_list

Unlike the R family of POMP packages, some Pomp methods, including pfilter(), mif(), and train(), store their results by modifying the object in place instead of returning new objects. Results are stored in the list LG_obj.results_history, with one element per method call. Each element records results such as the log-likelihood and, when applicable, parameter estimates, along with the inputs used for the call, so it is easy to keep track of how each result was calculated. If multiple parameter sets are supplied as a list, the method evaluates each set and stores the results for all of them.

We implement the IF2 iterated filtering algorithm as the Pomp method mif(). The standard deviations of the parameter random walk are controlled via an RWSigma object.

The RWSigma constructor takes a dictionary of standard deviations and an optional list of init_names. Parameters in init_names are perturbed only at the initial time \(t_0\), while others are perturbed at every time step. In this case, we perturb all parameters at every time step.

theta_names = LG_obj.canonical_param_names
rw_sd = pp.RWSigma(sigmas={k: 0.01 for k in theta_names})
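The role of init_names can be sketched as a single random-walk step. perturb() below is a hypothetical helper, not pypomp's internals, and it assumes perturbations are added on the estimation scale defined by the ParTrans object, as in the R pomp package. The initial-value parameter name X1_0 is also made up for illustration, since this model has no such parameter.

```python
import jax
import jax.numpy as jnp


def perturb(theta_est, sigmas, key, at_t0, init_names=()):
    """Hypothetical sketch of one random-walk perturbation of the parameters."""
    names = sorted(theta_est)
    keys = jax.random.split(key, len(names))
    out = {}
    for name, k in zip(names, keys):
        # Parameters in init_names move only at t0; all others move at every step.
        moves = at_t0 or name not in init_names
        noise = sigmas[name] * jax.random.normal(k) if moves else 0.0
        out[name] = theta_est[name] + noise
    return out


# Values on the estimation scale (Q11 is log-transformed, as in to_est()).
theta_est = {"A12": -0.2, "Q11": jnp.log(0.01), "X1_0": 0.0}
sigmas = {k: 0.01 for k in theta_est}
key = jax.random.key(0)

first = perturb(theta_est, sigmas, key, at_t0=True, init_names=("X1_0",))
later = perturb(theta_est, sigmas, key, at_t0=False, init_names=("X1_0",))
print(bool(first["X1_0"] != 0.0), later["X1_0"] == 0.0)
```

IF2 also shrinks these standard deviations across iterations (cooling), which this single-step sketch leaves out.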

Notably, Pomp methods such as mif() can update attributes other than results_history. For example, LG_obj.mif() and LG_obj.train() replace LG_obj.theta with the parameter estimates (one per parameter set) from the end of the last iteration. Furthermore, each method that takes a key as an argument stores an unused child key under LG_obj.fresh_key, which later method calls use by default when no key argument is given. Consequently, the following sequence of method calls is valid: pfilter() evaluates the estimates produced by mif(), using the stored fresh key.

key, subkey = jax.random.split(key)
LG_obj.mif(J=200, M=100, rw_sd=rw_sd, a=0.5, key=subkey)
LG_obj.pfilter(J=200, reps=20)
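The in-place results and key-chaining behavior described above can be mimicked with a toy class. KeyChainingModel is hypothetical and only illustrates the pattern: each call splits its key, uses one child, and stores the other as fresh_key, so a later call without a key still gets an unused key rather than reusing one.

```python
import jax


class KeyChainingModel:
    """Toy illustration (not pypomp's Pomp class) of in-place results and fresh keys."""

    def __init__(self):
        self.results_history = []
        self.fresh_key = None

    def run(self, name, key=None):
        if key is None:
            if self.fresh_key is None:
                raise ValueError("no key given and no fresh_key stored")
            key = self.fresh_key
        # Split once: keep one child for the next call, use the other now.
        self.fresh_key, use_key = jax.random.split(key)
        draw = float(jax.random.normal(use_key))
        # Results live on the object rather than in a return value.
        self.results_history.append({"method": name, "draw": draw})


model = KeyChainingModel()
model.run("mif", key=jax.random.key(42))
model.run("pfilter")  # no key given, so the stored fresh_key is used
print([r["method"] for r in model.results_history])  # ['mif', 'pfilter']
```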

One way to access the results of a method call is the results() method. It returns a tidy data frame of the results stored in LG_obj.results_history at the index passed to results() (by default -1, the most recent method call). For the pfilter() call above, the data frame includes the log-likelihood estimate (logLik), its standard error (se), and the parameter values for each parameter set (theta_idx).

LG_obj.results()
theta_idx logLik se A11 A12 A21 A22 C11 C12 C21 C22 Q11 Q12 Q21 Q22 R11 R12 R21 R22
0 0 -104.156914 0.489811 1.038134 -0.154051 0.227557 0.977598 1.320086 -0.171506 -0.275279 1.204020 0.019986 -0.004560 -0.016448 0.017054 0.115791 0.163259 -0.183533 0.104358
1 1 -98.217696 0.187537 1.093436 -0.180231 0.226643 0.933004 0.650459 0.173875 -0.679727 0.719599 0.018843 -0.405763 0.397696 0.011592 0.130361 -0.100313 0.150536 0.161968
2 2 -99.826312 0.545677 1.148332 -0.200436 0.247030 0.883687 0.768660 -0.375512 -0.277474 0.777397 0.011457 -0.331688 0.334211 0.013466 0.104730 -0.667116 0.699338 0.096994
3 3 -95.854480 0.156273 1.045159 -0.201832 0.174963 0.980609 0.643130 0.004079 -0.282392 0.746395 0.019933 0.400345 -0.393962 0.012683 0.112911 0.841764 -0.845988 0.162781
4 4 -115.756517 0.314190 0.857582 -0.325349 0.188333 1.170569 1.144829 1.187590 -0.500412 0.797229 0.036627 1.172909 -1.232881 0.082105 0.105251 0.142052 -0.206588 0.130150
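Because results() returns an ordinary pandas data frame, standard pandas operations apply. For example, the parameter set with the highest estimated log-likelihood can be selected with idxmax(). The sketch below rebuilds the theta_idx, logLik, and se columns from the output above so that it runs without pypomp; with the model object itself, the same lookup would be applied to LG_obj.results().

```python
import pandas as pd

# Values copied from the results() output above (parameter columns omitted).
res = pd.DataFrame(
    {
        "theta_idx": [0, 1, 2, 3, 4],
        "logLik": [-104.156914, -98.217696, -99.826312, -95.854480, -115.756517],
        "se": [0.489811, 0.187537, 0.545677, 0.156273, 0.314190],
    }
)

# Row label of the largest log-likelihood, then the full row.
best = res.loc[res["logLik"].idxmax()]
print(int(best["theta_idx"]))  # 3
```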

We can also use LG_obj.traces() to get the traces of the parameters and log-likelihood as a data frame. The traces include the log-likelihood estimates from pfilter() as well, identified by the method column. Plotting these traces is an easy way to visualize the convergence of the algorithm.

LG_obj.traces()
theta_idx iteration method logLik A11 A12 A21 A22 C11 C12 C21 C22 Q11 Q12 Q21 Q22 R11 R12 R21 R22
0 0 0 mif NaN 0.963512 -0.178802 0.203705 1.069749 0.999112 0.000000 0.000000 1.054883 0.010472 9.363539e-07 0.000001 0.010878 0.102497 0.010768 0.009182 0.096302
1 0 1 mif -158.750687 0.979602 -0.181179 0.177056 1.042791 0.940792 0.148876 -0.218030 0.942802 0.010468 6.731675e-02 -0.069008 0.009978 0.118210 -0.018776 0.044527 0.114382
2 0 2 mif -180.775375 0.961939 -0.184962 0.193551 1.042404 0.878036 0.078811 -0.067227 0.903754 0.011677 6.131419e-02 -0.062619 0.009367 0.100722 0.041346 -0.013567 0.124873
3 0 3 mif -147.994080 0.938148 -0.201689 0.199320 1.066983 0.885734 0.106574 0.031965 0.918693 0.012200 3.579927e-03 -0.003693 0.009328 0.098911 -0.005288 0.027408 0.117492
4 0 4 mif -143.895569 0.915658 -0.181497 0.220157 1.076241 0.907372 0.214076 -0.017906 0.845472 0.015432 1.067293e-01 -0.105180 0.009258 0.104234 0.079302 -0.136381 0.119326
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
505 4 97 mif -128.210876 0.840428 -0.361168 0.204942 1.194526 1.111526 1.131103 -0.485405 0.812234 0.034200 1.220573e+00 -1.299572 0.081826 0.111896 0.145000 -0.157156 0.120908
506 4 98 mif -120.703064 0.840681 -0.336580 0.204312 1.198948 1.157781 1.167410 -0.505694 0.794345 0.035234 1.208649e+00 -1.294106 0.083394 0.107634 0.153087 -0.170070 0.127514
507 4 99 mif -121.283913 0.855288 -0.333067 0.194163 1.173801 1.146324 1.187370 -0.527522 0.811167 0.037217 1.193427e+00 -1.274379 0.084293 0.103649 0.119751 -0.208256 0.129763
508 4 100 mif -123.410843 0.857582 -0.325349 0.188333 1.170569 1.144829 1.187590 -0.500412 0.797229 0.036627 1.172908e+00 -1.232882 0.082105 0.105251 0.142052 -0.206588 0.130150
509 4 100 pfilter -115.756517 0.857582 -0.325349 0.188333 1.170569 1.144829 1.187590 -0.500412 0.797229 0.036627 1.172909e+00 -1.232881 0.082105 0.105251 0.142052 -0.206588 0.130150

510 rows × 20 columns

(
    p9.ggplot(
        LG_obj.traces().reset_index(),
        p9.aes(x="iteration", y="logLik", color="factor(theta_idx)"),
    )
    + p9.geom_line()
    + p9.theme_minimal()
    + p9.labs(x="Iteration", y="Log-Likelihood", title="IF2 Convergence", color="Chain")
)
PlotnineWarning: geom_path: Removed 1 rows containing missing values.
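The warning above appears to come from the missing (NaN) log-likelihood recorded at iteration 0 of each chain in the traces() output. One option is to drop those rows before plotting, and also to keep only mif() rows so that the final pfilter() estimate is not joined to the mif() line at iteration 100. The sketch below uses a handful of rows copied from the traces() output; with the model object, the same filtering would apply to LG_obj.traces().reset_index().

```python
import numpy as np
import pandas as pd

# A few rows copied from the traces() output above (parameter columns omitted).
traces = pd.DataFrame(
    {
        "theta_idx": [0, 0, 0, 4, 4],
        "iteration": [0, 1, 2, 100, 100],
        "method": ["mif", "mif", "mif", "mif", "pfilter"],
        "logLik": [np.nan, -158.750687, -180.775375, -123.410843, -115.756517],
    }
)

# Drop rows without a log-likelihood (iteration 0 here).
loglik_traces = traces.dropna(subset=["logLik"])
# Keep only mif() iterations for the convergence line.
mif_only = loglik_traces[loglik_traces["method"] == "mif"]
print(len(loglik_traces), len(mif_only))  # 4 3
```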

traces_long = (
    LG_obj.traces()
    .reset_index()
    .melt(id_vars=["iteration", "theta_idx"], value_vars=theta_names)
)

(
    p9.ggplot(traces_long, p9.aes(x="iteration", y="value", color="factor(theta_idx)"))
    + p9.geom_line()
    + p9.facet_wrap("~variable", scales="free_y")
    + p9.theme_minimal()
    + p9.theme(figure_size=(10, 8))
    + p9.labs(x="Iteration", y="Value", title="Parameter Traces", color="Chain")
)