Tutorial: Kernel Transforms

try:
    import tinygp
except ImportError:
    !pip install -q tinygp

try:
    import flax
except ImportError:
    !pip install -q flax
    
try:
    import optax
except ImportError:
    !pip install -q optax
    
import jax

jax.config.update("jax_enable_x64", True)

tinygp is designed to make it easy to implement new kernels (see Tutorial: Custom Kernels & Pytree Data for an example), but one particular set of customizations that tinygp supports with a high-level interface is coordinate transforms. The basic idea is that you may want to pass your input coordinates through a linear or non-linear transformation before evaluating one of the standard kernels in that transformed space. This is particularly useful for multivariate inputs where, for example, the dimensions have different units, or where you have prior knowledge about the covariances between dimensions.

tinygp has two types of built-in transforms for multivariate inputs (tinygp.transforms.Affine and tinygp.transforms.Subspace), and it supports much more flexible transforms via the tinygp.transforms.Transform base class. Take a look at the docstrings for tinygp.transforms.Affine and tinygp.transforms.Subspace for more details about how to use them, and read on for a worked example of constructing a custom tinygp.transforms.Transform.
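To make this concrete, here is a minimal sketch of how these transforms compose with a base kernel: a transformed kernel evaluates the base kernel on transformed inputs, so that k'(x, y) = k(f(x), f(y)). The length scale, the per-dimension scales, and the Subspace(axis, kernel) argument order below are illustrative assumptions based on the docstrings, not values used later in this tutorial:

import jax.numpy as jnp
from tinygp import kernels, transforms

# A base kernel that will be evaluated in the transformed space
base_kernel = kernels.Matern32(1.0)

# Restrict the kernel to act only on the first input dimension
subspace_kernel = transforms.Subspace(0, base_kernel)

# Rescale each input dimension before evaluating the kernel; Transform
# accepts any callable, so this also covers simple affine rescalings
scales = jnp.array([1.0, 0.5, 10.0])  # hypothetical per-dimension scales
scaled_kernel = transforms.Transform(lambda X: X / scales, base_kernel)

# Both objects behave like ordinary tinygp kernels
X = jnp.linspace(0.0, 1.0, 15).reshape(5, 3)
print(subspace_kernel(X, X).shape)  # (5, 5)
print(scaled_kernel(X, X).shape)  # (5, 5)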

Example: Deep kernel learning

The Deep Kernel Learning model is an example of a more complicated kernel transform, and since tinygp integrates well with libraries like flax (see Tutorial: Modeling Frameworks), the implementation of such a model is fairly straightforward. To demonstrate, let’s start by sampling a simulated dataset from a step function, a target that a GP with a standard stationary kernel would typically struggle to fit:

import numpy as np
import matplotlib.pyplot as plt

random = np.random.default_rng(567)

noise = 0.1

x = np.sort(random.uniform(-1, 1, 100))
y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
t = np.linspace(-1.5, 1.5, 500)

plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
[Figure: the simulated data (black points) plotted over the true step function (black line).]

Then we will fit this dataset using a model similar to the one described in Optimization with flax & optax, except that our kernel will include a custom tinygp.transforms.Transform that passes the input coordinates through a (small) neural network before they are fed to a tinygp.kernels.Matern32 kernel. Otherwise, the model and optimization procedure follow the approach used in that tutorial.

import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess


# Define a small neural network used to non-linearly transform the input data in our model
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=15)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class GPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(
            jnp.exp(log_rho)
        )

        # Define a custom transform to pass the input coordinates through our `Transformer`
        # network from above
        transform = Transformer()
        kernel = transforms.Transform(transform, base_kernel)

        # Evaluate and return the GP negative log likelihood as usual, along
        # with the conditional mean and variance at the test points for plotting
        gp = GaussianProcess(
            kernel, x[:, None], diag=noise ** 2 + jnp.exp(2 * log_jitter)
        )
        return -gp.condition(y), gp.predict(y, t[:, None], return_var=True)


# Define and train the model
def loss(params):
    return model.apply(params, x, y, t)[0]


model = GPLoss()
params = model.init(jax.random.PRNGKey(1234), x, y, t)
tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
for i in range(1000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

# Plot the results and compare to the true model
plt.figure()
mu, var = model.apply(params, x, y, t)[1]
plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.plot(t, mu)
plt.fill_between(
    t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label="model"
)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
[Figure: the trained deep kernel GP prediction (mean with one-sigma band) compared to the data and the true step function.]
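As a sanity check on what the model has learned, it can be instructive to look at the coordinate warping applied by the network. The following sketch evaluates just the Transformer sub-module on the test grid; it assumes flax’s default auto-naming, under which this sub-module’s parameters live at "Transformer_0" in the parameter tree:

# Apply only the (trained) Transformer network to the test points to
# visualize the learned 1D coordinate warping
warped = Transformer().apply(
    {"params": params["params"]["Transformer_0"]}, t[:, None]
)

plt.figure()
plt.plot(t, warped[:, 0], "k")
plt.xlabel("x")
plt.ylabel("transformed coordinate")

A sharp feature in this warping near x = 0 would indicate that the network has learned to stretch the inputs around the location of the step, which is what allows the stationary Matern-3/2 kernel to fit the discontinuity.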