Tutorial: Kernel Transforms
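try:
    import tinygp
except ImportError:
    !pip install -q tinygp

try:
    import flax
except ImportError:
    !pip install -q flax

try:
    import optax
except ImportError:
    !pip install -q optax

# Enable double precision; `jax.config.update` replaces the older
# `from jax.config import config` import, which recent JAX releases removed
import jax

jax.config.update("jax_enable_x64", True)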
tinygp is designed to make it easy to implement new kernels (see Tutorial: Custom Kernels & Pytree Data for an example), but one particular set of customizations that tinygp supports through a high-level interface is coordinate transforms.
The basic idea here is that you may want to pass your input coordinates through a linear or non-linear transformation before evaluating one of the standard kernels in that transformed space.
This is particularly useful for multivariate inputs where, for example, you may want to capture the different units, or prior covariances between dimensions.
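To make the idea concrete, here is a minimal NumPy sketch of a transformed kernel. It does not use tinygp; the helper names `matern32` and `transformed`, and the example units, are hypothetical and chosen only for illustration.

```python
import numpy as np


def matern32(x1, x2, scale=1.0):
    """Matern-3/2 kernel between two coordinate vectors (textbook form)."""
    r = np.sqrt(np.sum((x1 - x2) ** 2)) / scale
    arg = np.sqrt(3.0) * r
    return (1.0 + arg) * np.exp(-arg)


def transformed(transform, kernel):
    """Return a kernel that evaluates `kernel` on transformed coordinates."""

    def wrapped(x1, x2):
        return kernel(transform(x1), transform(x2))

    return wrapped


# Hypothetical inputs with different units per dimension, for example
# a time in days and a wavelength in nanometers
units = np.array([10.0, 500.0])
kernel = transformed(lambda x: x / units, matern32)

xa = np.array([0.0, 400.0])
xb = np.array([5.0, 650.0])
k_ab = kernel(xa, xb)
```

Because the transform is applied to both arguments before the base kernel is evaluated, the result is still a valid kernel; only the notion of distance between inputs changes.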
tinygp has two built-in transforms for multivariate inputs (tinygp.transforms.Affine and tinygp.transforms.Subspace), and supports much more flexible transforms via the tinygp.transforms.Transform base class.
Take a look at the docstrings for tinygp.transforms.Affine and tinygp.transforms.Subspace for more information about how to use them, and continue to the example below for a more detailed walkthrough of constructing a custom tinygp.transforms.Transform.
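As a rough illustration of what these two kinds of transform do conceptually, the following NumPy sketch rescales coordinates and selects a subset of input dimensions. This is not tinygp's implementation: the helpers `affine` and `subspace` are hypothetical, and the conventions used here (dividing by the scale, indexing a single axis) are assumptions, so the docstrings remain the authority on the actual behavior.

```python
import numpy as np


def affine(scale):
    """Divide the coordinates by `scale` (a scalar or one value per dimension)."""
    return lambda x: np.asarray(x) / scale


def subspace(axis):
    """Keep only the coordinate(s) at index `axis`."""
    return lambda x: np.atleast_1d(np.asarray(x)[axis])


def sq_exp(x1, x2):
    """Squared-exponential kernel with unit length scale."""
    return np.exp(-0.5 * np.sum((x1 - x2) ** 2))


x1 = np.array([0.0, 3.0])
x2 = np.array([1.0, -2.0])

# Per-dimension length scales via an affine rescaling (assumed convention)
f = affine(np.array([1.0, 5.0]))
k_affine = sq_exp(f(x1), f(x2))

# A kernel that only depends on the first input dimension
g = subspace(0)
k_sub = sq_exp(g(x1), g(x2))
```

In this sketch the second dimension contributes less to the distance after rescaling, and is ignored entirely after the subspace selection.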
Example: Deep kernel learning
The Deep Kernel Learning model is an example of a more complicated kernel transform, and since tinygp integrates well with libraries like flax (see Tutorial: Modeling Frameworks), the implementation of such a model is fairly straightforward.
To demonstrate, let’s start by sampling a simulated dataset from a step function, a discontinuous signal that a GP with a standard stationary kernel would typically struggle to capture:
import numpy as np
import matplotlib.pyplot as plt
random = np.random.default_rng(567)
noise = 0.1
x = np.sort(random.uniform(-1, 1, 100))
y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
t = np.linspace(-1.5, 1.5, 500)
plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
Then we will fit this dataset using a model similar to the one described in Optimization with flax & optax, except that our kernel will include a custom tinygp.transforms.Transform that passes the input coordinates through a (small) neural network before handing them to a tinygp.kernels.Matern32 kernel.
Otherwise, the model and optimization procedure are similar to the ones used in that tutorial.
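To build some intuition for why warping the inputs can help with a step function, the NumPy sketch below compares a Matern-3/2 kernel on the raw inputs with the same kernel on warped inputs. The `warp` function is a hand-picked stand-in for a trained network (an assumption made purely for illustration, not something learned from data), and `matern32` is a textbook implementation rather than tinygp's.

```python
import numpy as np


def matern32(x1, x2, rho=1.0):
    """Matern-3/2 kernel matrix between two sets of 1D coordinates."""
    r = np.abs(x1[:, None] - x2[None, :]) / rho
    arg = np.sqrt(3.0) * r
    return (1.0 + arg) * np.exp(-arg)


def warp(x):
    """Hand-picked stand-in for a trained network: a steep sigmoid-like map."""
    return 3.0 * np.tanh(50.0 * x)


x = np.array([-0.5, -0.05, 0.05, 0.5])

k_plain = matern32(x, x)
k_warped = matern32(warp(x), warp(x))

# Correlation between the two points straddling the step at x = 0
across_plain = k_plain[1, 2]
across_warped = k_warped[1, 2]

# Correlation between two points on the same side of the step
same_side_warped = k_warped[2, 3]
```

With the raw inputs, the two points on either side of the step are strongly correlated, which forces a smooth transition. After the warp, points on opposite sides are nearly uncorrelated while points on the same side remain strongly correlated, which is the kind of structure the network in the model can learn.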
import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess


# Define a small neural network used to non-linearly transform the input data
# in our model
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=15)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class GPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(
            jnp.exp(log_rho)
        )

        # Define a custom transform to pass the input coordinates through our
        # `Transformer` network from above
        transform = Transformer()
        kernel = transforms.Transform(transform, base_kernel)

        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(
            kernel, x[:, None], diag=noise ** 2 + jnp.exp(2 * log_jitter)
        )
        return -gp.condition(y), gp.predict(y, t[:, None], return_var=True)


# Define and train the model
def loss(params):
    return model.apply(params, x, y, t)[0]


model = GPLoss()
params = model.init(jax.random.PRNGKey(1234), x, y, t)
tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
for i in range(1000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
# Plot the results and compare to the true model
plt.figure()
mu, var = model.apply(params, x, y, t)[1]
plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.plot(t, mu)
plt.fill_between(
    t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label="model"
)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
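For reference, the two kinds of quantity that the model returns (a negative log likelihood, and a predictive mean and variance) can be sketched in plain NumPy using the standard zero-mean GP formulas. This is a sketch under stated assumptions rather than a description of tinygp's internals: the function names are hypothetical, the kernel is applied to the raw inputs with no transform, and the variance here is that of the noise-free latent function.

```python
import numpy as np


def matern32(x1, x2, sigma=1.0, rho=1.0):
    """Matern-3/2 kernel matrix between two sets of 1D coordinates."""
    arg = np.sqrt(3.0) * np.abs(x1[:, None] - x2[None, :]) / rho
    return sigma**2 * (1.0 + arg) * np.exp(-arg)


def gp_nll_and_predict(x, y, t, diag):
    """Negative log likelihood plus predictive mean and variance at `t`."""
    n = len(x)
    K = matern32(x, x) + diag * np.eye(n)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, y)  # L^{-1} y

    # 0.5 * y^T K^{-1} y + 0.5 * log|K| + 0.5 * n * log(2 pi)
    nll = (
        0.5 * alpha @ alpha
        + np.sum(np.log(np.diag(L)))
        + 0.5 * n * np.log(2 * np.pi)
    )

    Ks = matern32(x, t)  # cross covariance, shape (n_data, n_test)
    v = np.linalg.solve(L, Ks)  # L^{-1} K_*
    mu = v.T @ alpha
    var = np.diag(matern32(t, t)) - np.sum(v**2, axis=0)
    return nll, mu, var


# Simulated step-function data in the spirit of the example above
rng = np.random.default_rng(0)
x_demo = np.sort(rng.uniform(-1, 1, 20))
y_demo = 2 * (x_demo > 0) - 1 + rng.normal(0.0, 0.1, len(x_demo))
t_demo = np.linspace(-1, 1, 5)
nll, mu_demo, var_demo = gp_nll_and_predict(x_demo, y_demo, t_demo, diag=0.1**2)
```

In the deep kernel model, the same formulas apply with the kernel evaluated on the network outputs instead of the raw coordinates, and the network weights are optimized together with the kernel parameters.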