try:
    import tinygp
except ImportError:
    !pip install -q tinygp

try:
    import flax
except ImportError:
    !pip install -q flax
    
try:
    import optax
except ImportError:
    !pip install -q optax

Using flax

In this section, we define a tinygp GaussianProcess model as a flax linen module and fit its hyperparameters to simulated data with optax.

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

t = np.sort(
    np.append(
        np.random.uniform(0, 3.8, 28),
        np.random.uniform(5.5, 10, 18),
    )
)
yerr = np.random.uniform(0.08, 0.22, len(t))
y = (
    0.2 * (t - 5)
    + np.sin(3 * t + 0.1 * (t - 5) ** 2)
    + yerr * np.random.randn(len(t))
)

true_t = np.linspace(0, 10, 100)
true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2)

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
_ = plt.title("simulated data")
[figure: simulated data]
import jax
import jax.numpy as jnp

# Enable double precision for the GP calculations.
jax.config.update("jax_enable_x64", True)

import flax.linen as nn
from flax.linen.initializers import zeros

import optax

from tinygp import kernels, GaussianProcess


class GPModule(nn.Module):
    @nn.compact
    def __call__(self, x, yerr, y, t):
        # Constant mean and a white-noise "jitter" term, both learned.
        mean = self.param("mean", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())

        # Quasi-periodic component: an exponential-squared envelope times a cosine.
        log_sigma1 = self.param("log_sigma1", zeros, ())
        log_rho1 = self.param("log_rho1", zeros, ())
        log_tau = self.param("log_tau", zeros, ())
        kernel1 = (
            jnp.exp(2 * log_sigma1)
            * kernels.ExpSquared(jnp.exp(log_tau))
            * kernels.Cosine(jnp.exp(log_rho1))
        )

        # Non-periodic Matern-3/2 component.
        log_sigma2 = self.param("log_sigma2", zeros, ())
        log_rho2 = self.param("log_rho2", zeros, ())
        kernel2 = jnp.exp(2 * log_sigma2) * kernels.Matern32(jnp.exp(log_rho2))

        kernel = kernel1 + kernel2
        gp = GaussianProcess(
            kernel, x, diag=yerr ** 2 + jnp.exp(log_jitter), mean=mean
        )

        # Negative marginal log likelihood and the predicted mean on the test grid.
        loss = -gp.condition(y)
        pred = gp.predict(y, t)

        return loss, pred


model = GPModule()


def loss(params):
    return model.apply(params, t, yerr, y, true_t)[0]


params = model.init(jax.random.PRNGKey(0), t, yerr, y, true_t)
tx = optax.sgd(learning_rate=3e-3)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

losses = []
for i in range(1001):
    loss_val, grads = loss_grad_fn(params)
    losses.append(loss_val)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print("Loss step {}: ".format(i), loss_val)

plt.plot(losses)
plt.ylabel("negative log likelihood")
_ = plt.xlabel("step number")
Loss step 0:  64.37859795274213
Loss step 100:  16.314078068928787
Loss step 200:  9.606068522548224
Loss step 300:  7.9649624466212465
Loss step 400:  7.272120333536339
Loss step 500:  6.904669025553208
Loss step 600:  6.683517806967696
Loss step 700:  6.5390558465441195
Loss step 800:  6.439063632618968
Loss step 900:  6.3667711740680595
Loss step 1000:  6.312679622437152
[figure: negative log likelihood vs. step number]
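The loop above uses plain SGD with a fixed learning rate. As a hedged aside (not part of the original tutorial), any other optax gradient transformation can be dropped into the same loop; switching to Adam, for example, only changes how tx and its state are built:

tx = optax.adam(learning_rate=1e-2)  # illustrative learning rate, not tuned here
opt_state = tx.init(params)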
pred = model.apply(params, t, yerr, y, true_t)[1]

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3, label="truth")
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.plot(true_t, pred, label="max. like. model")
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
plt.legend()
_ = plt.title("simulated data")
[figure: maximum likelihood model over the simulated data]
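As a sketch of how to inspect the fit (an addition, not part of the original tutorial), the optimized hyperparameters live in the flax parameter pytree under the default "params" collection and can be mapped back to their natural scale:

fitted = params["params"]
print("sigma1 =", float(jnp.exp(fitted["log_sigma1"])))
print("tau    =", float(jnp.exp(fitted["log_tau"])))
print("rho1   =", float(jnp.exp(fitted["log_rho1"])))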

Deep kernel learning with flax

In this section, we pass the inputs through a small flax neural network before feeding them to the GP kernel, learning the network weights and the GP hyperparameters jointly. The target is a step function, which is difficult for a stationary kernel on the raw inputs.

import numpy as np
import matplotlib.pyplot as plt

random = np.random.default_rng(567)

noise = 0.1

x = np.sort(random.uniform(-1, 1, 200))
y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
t = np.linspace(-1.5, 1.5, 500)

plt.plot(x, y, ".k")
plt.plot(t, 2 * (t > 0) - 1)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3);
[figure: step-function data]
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros

from tinygp import kernels, GaussianProcess


class FeatureExtractor(nn.Module):
    """A small MLP that maps each scalar input to a single learned feature."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=100)(x)
        x = nn.relu(x)
        x = nn.Dense(features=20)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class GPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Warp the training and test inputs through the same network.
        extr = FeatureExtractor()
        x = extr(x[:, None])
        t = extr(t[:, None])

        # Rescale the warped coordinates to [0, 1] using the training range.
        xmin = jnp.min(x, axis=0, keepdims=True)
        xmax = jnp.max(x, axis=0, keepdims=True)
        x = (x - xmin) / (xmax - xmin)
        t = (t - xmin) / (xmax - xmin)

        # GP hyperparameters, learned jointly with the network weights.
        mean = self.param("mean", zeros, ())
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, (x.shape[1],))
        log_jitter = self.param("log_jitter", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(
            jnp.exp(2 * log_rho)
        )

        gp = GaussianProcess(
            kernel, x, diag=noise ** 2 + jnp.exp(2 * log_jitter), mean=mean
        )
        return -gp.condition(y), gp.predict(y, t, return_var=True), (x, t)
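Before optimizing, it can help to sanity check the feature extractor on its own. This snippet is an addition to the tutorial: it initializes a standalone FeatureExtractor and confirms that the 200 scalar inputs are each mapped to a single feature.

fx = FeatureExtractor()
fx_params = fx.init(jax.random.PRNGKey(1), x[:, None])
print(fx.apply(fx_params, x[:, None]).shape)  # -> (200, 1)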
import optax

model = GPLoss()


def loss(params):
    return model.apply(params, x, y, t)[0]


params = model.init(jax.random.PRNGKey(0), x, y, t)
tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

for i in range(1001):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print("Loss step {}: ".format(i), loss_val)
Loss step 0:  209.87165558654118
Loss step 100:  -83.38566874678
Loss step 200:  -118.3053240725421
Loss step 300:  -140.77867793725787
Loss step 400:  -155.54681194000665
Loss step 500:  -157.31810065364692
Loss step 600:  -159.0145064761331
Loss step 700:  -159.83063787359794
Loss step 800:  -160.11920800737445
Loss step 900:  -160.1523040035471
Loss step 1000:  -160.17908016407202
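Because the GP hyperparameters are ordinary flax parameters, they sit in the same pytree as the network weights. As an optional check (not part of the original tutorial), the structure of the fitted parameters can be printed with a tree map over shapes:

print(jax.tree_util.tree_map(jnp.shape, params["params"]))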
mu, var = model.apply(params, x, y, t)[1]
plt.plot(x, y, ".k")
plt.plot(t, mu)
plt.fill_between(t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3);
[figure: GP prediction with one-sigma band]
xp, tp = model.apply(params, x, y, t)[2]

plt.plot(t, tp)
plt.xlabel("x")
plt.ylabel("warped x")
plt.xlim(-1.5, 1.5);
[figure: learned input warping]
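The warped training inputs xp are returned as well. As an optional extra (not in the original figure), they can be overplotted to show where the observed points land in the warped coordinate:

plt.plot(t, tp, label="test inputs")
plt.plot(x, xp, ".k", label="training inputs")
plt.xlabel("x")
plt.ylabel("warped x")
plt.legend()
plt.xlim(-1.5, 1.5);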