{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "false-finder",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"try:\n",
" import tinygp\n",
"except ImportError:\n",
" %pip install -q tinygp"
]
},
{
"cell_type": "markdown",
"id": "polish-inquiry",
"metadata": {},
"source": [
"(intro)=\n",
"\n",
"# An Introduction to tinygp\n",
"\n",
"This tutorial provides a brief introduction to how Gaussian Processes (GPs) are implemented in `tinygp`.\n",
"We're not going to provide much of an introduction to GPs themselves, because there are already a lot of excellent resources for that, including [this text book](http://www.gaussianprocess.org/gpml/chapters/), [this blog post](https://distill.pub/2019/visual-exploration-gaussian-processes/), and many others that I'm sure you can find by Googling.\n",
"\n",
"Before we get started, it's pretty much always a good idea to [enable double precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) before doing anything with `tinygp` and GPs in `jax`, so that we end up with fewer numerical precision issues:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "450a06ab",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)"
]
},
{
"cell_type": "markdown",
"id": "251f35bf",
"metadata": {},
"source": [
"## Kernel building\n",
"\n",
"In `tinygp`, we primarily construct GP models, by specifying a \"kernel\" function defined using the building blocks in the {ref}`api-kernels`.\n",
"For example, we can define an \"exponential squared\" or \"radial basis function\" kernel using:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73df1f88",
"metadata": {},
"outputs": [],
"source": [
"from tinygp import kernels\n",
"\n",
"kernel = kernels.ExpSquared(scale=1.5)"
]
},
{
"cell_type": "markdown",
"id": "ed1ed903",
"metadata": {},
"source": [
"And we can plot its value (don't worry too much about the syntax here):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8865c55a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_kernel(kernel, **kwargs):\n",
" dx = np.linspace(0, 5, 100)\n",
" plt.plot(dx, kernel(dx, dx[:1]), **kwargs)\n",
" plt.xlabel(\"dx\")\n",
" plt.ylabel(\"k(dx)\")\n",
"\n",
"\n",
"plot_kernel(kernel)"
]
},
{
"cell_type": "markdown",
"id": "d0e13fbe",
"metadata": {},
"source": [
"This kernel on its own is not terribly expressive, so we'll usually end up adding and multiplying kernels to build the function we want.\n",
"For example, we can:\n",
"\n",
"- scale our kernel by a scalar,\n",
"- add multiple different kernels, or\n",
"- multiply kernels together\n",
"\n",
"as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4761b07e",
"metadata": {},
"outputs": [],
"source": [
"plot_kernel(kernel, label=\"original\", ls=\"dashed\")\n",
"\n",
"kernel_scaled = 4.5 * kernels.ExpSquared(scale=1.5)\n",
"plot_kernel(kernel_scaled, label=\"scaled\")\n",
"\n",
"kernel_sum = kernels.ExpSquared(scale=1.5) + 2 * kernels.Matern32(scale=2.5)\n",
"plot_kernel(kernel_sum, label=\"sum\")\n",
"\n",
"kernel_prod = 2 * kernels.ExpSquared(scale=1.5) * kernels.Cosine(scale=2.5)\n",
"plot_kernel(kernel_prod, label=\"product\")\n",
"\n",
"_ = plt.legend()"
]
},
{
"cell_type": "markdown",
"id": "062a4089",
"metadata": {},
"source": [
"For a lot of use cases, these operations will be sufficient to build the models that you need, but if not, check out the following tutorials for some more expressive examples: {ref}`kernels`, {ref}`transforms`, {ref}`geometry`, and {ref}`derivative`.\n",
"\n",
"## Sampling\n",
"\n",
"Once you have a kernel in hand, you can pass it to a {class}`tinygp.GaussianProcess` to handle most of the computations you need.\n",
"The {class}`tinygp.GaussianProcess` will also need to know the input coordinates of your data `X` (let's just make some up for now) and it takes a little parameter `diag`, which specifies extra variance to add on the diagonal.\n",
"When modeling real data, this can often be thought of as per-observation measurement uncertainty, but it may not always be obvious what to put there.\n",
"That being said, you'll probably find that if you don't use the `diag` parameter you'll get a lot of `nan`s in your results, so it's generally good to at least provide some small value for `diag`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca544e6a",
"metadata": {},
"outputs": [],
"source": [
"from tinygp import GaussianProcess\n",
"\n",
"# Let's make up some input coordinates (sorted for plotting purposes)\n",
"X = np.sort(np.random.default_rng(1).uniform(0, 10, 100))\n",
"\n",
"gp = GaussianProcess(kernel, X)"
]
},
{
"cell_type": "markdown",
"id": "4c1d803e",
"metadata": {},
"source": [
"This `gp` object now specifies a multivariate normal distribution over data points observed at `X`.\n",
"It's sometimes useful to generate samples from this distribution to see what our prior looks like.\n",
"We can do that using the {func}`GaussianProcess.sample` function, and this will be the first time we're going to need to do anything `jax`-specific (because of [how random numbers work in `jax`](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a60d78d",
"metadata": {},
"outputs": [],
"source": [
"y = gp.sample(jax.random.PRNGKey(4), shape=(5,))\n",
"plt.plot(X, y.T, color=\"k\", lw=0.5)\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"sampled observations\")\n",
"_ = plt.title(\"exponential squared kernel\")"
]
},
{
"cell_type": "markdown",
"id": "14cca5db",
"metadata": {},
"source": [
"We can also generate samples for different kernel functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92fe5086",
"metadata": {},
"outputs": [],
"source": [
"# Here we're using the product of kernels defined above\n",
"kernel_prod = 2 * kernels.ExpSquared(scale=1.5) * kernels.Cosine(scale=2.5)\n",
"gp = GaussianProcess(kernel_prod, X, diag=1e-5)\n",
"y = gp.sample(jax.random.PRNGKey(4), shape=(5,))\n",
"\n",
"plt.plot(X, y.T, color=\"k\", lw=0.5)\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"sampled observations\")\n",
"_ = plt.title(\"product of kernels\")"
]
},
{
"cell_type": "markdown",
"id": "cfa44844",
"metadata": {},
"source": [
"It is quite common in the GP literature to set the mean function for our process to zero, but that isn't always what you want.\n",
"Instead, you can set the mean to a different constant, or to a function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d56ff273",
"metadata": {},
"outputs": [],
"source": [
"# A GP with a non-zero constant mean\n",
"gp = GaussianProcess(kernel, X, diag=1e-5, mean=2.0)\n",
"y_const = gp.sample(jax.random.PRNGKey(4), shape=(5,))\n",
"\n",
"\n",
"# And a GP with a general mean function\n",
"def mean_function(x):\n",
" return 5 * jax.numpy.sin(x)\n",
"\n",
"\n",
"gp = GaussianProcess(kernel, X, diag=1e-5, mean=mean_function)\n",
"y_func = gp.sample(jax.random.PRNGKey(4), shape=(5,))\n",
"\n",
"# Plotting these samples\n",
"_, axes = plt.subplots(2, 1, sharex=True)\n",
"ax = axes[0]\n",
"ax.plot(X, y_const.T, color=\"k\", lw=0.5)\n",
"ax.axhline(2.0)\n",
"ax.set_ylabel(\"constant mean\")\n",
"\n",
"ax = axes[1]\n",
"ax.plot(X, y_func.T, color=\"k\", lw=0.5)\n",
"ax.plot(X, jax.vmap(mean_function)(X), label=\"mean\")\n",
"ax.legend()\n",
"ax.set_xlabel(\"x\")\n",
"_ = ax.set_ylabel(\"mean function\")"
]
},
{
"cell_type": "markdown",
"id": "848ecc20",
"metadata": {},
"source": [
"## Conditioning & marginalization\n",
"\n",
"When it comes time to fit data using a GP model, the key operations that `tinygp` provides are _conditioning_ and _marginalization_.\n",
"For example, you may want to fit for the parameters of your kernel model (the length scale and amplitude, for example), and a good objective to use for that process is the marginal likelihood of the process evaluated for the observed data.\n",
"In `tinygp`, this is accessed via the {func}`tinygp.GaussianProcess.log_probability` method, which takes the observed data `y` as input.\n",
"(_Aside:_ The nomenclature for this method is a little tricky to get right, and we've settled on `log_probability` in `tinygp` since it is the multivariate normal probability density, but it's important to remember that it is a function of the data, making it a \"sampling distribution\" or \"likelihood\".)\n",
"We won't actually go into details about how to use this method for fitting here (check out pretty much any of the other tutorials for examples!), but to compute the log probability for a given dataset and kernel, we would do something like the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5db9aad8",
"metadata": {},
"outputs": [],
"source": [
"# Simulate a made up dataset, as an example\n",
"random = np.random.default_rng(1)\n",
"X = np.sort(random.uniform(0, 10, 10))\n",
"y = np.sin(X) + 1e-4 * random.normal(size=X.shape)\n",
"\n",
"# Compute the log probability\n",
"kernel = 0.5 * kernels.ExpSquared(scale=1.0)\n",
"gp = GaussianProcess(kernel, X, diag=1e-4)\n",
"print(gp.log_probability(y))"
]
},
{
"cell_type": "markdown",
"id": "c7d9480b",
"metadata": {},
"source": [
"But we do want to go into more details about how `tinygp` implements conditioning because it might be a little counterintuitive at first (and it's also different in `v0.2` of `tinygp`).\n",
"To condition a GP on observed data, we use the {func}`tinygp.GaussianProcess.condition` method, and that produces a named tuple with two elements as described in {class}`tinygp.gp.ConditionResult`.\n",
"The first element `log_probability` is the same log probability as we calculated above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7a4ee73",
"metadata": {},
"outputs": [],
"source": [
"cond = gp.condition(y)\n",
"print(cond.log_probability)"
]
},
{
"cell_type": "markdown",
"id": "a191f8a6",
"metadata": {},
"source": [
"Then the second element `gp` is a new {class}`tinygp.GaussianProcess` describing the distribution at some test points (by default, the test points are the same as our inputs).\n",
"This conditioned GP operates just like our `gp` above, but its `kernel` and `mean_function` have these strange types:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2d51731",
"metadata": {},
"outputs": [],
"source": [
"type(cond.gp.kernel), type(cond.gp.mean_function)"
]
},
{
"cell_type": "markdown",
"id": "c02784ae",
"metadata": {},
"source": [
"This is cool because that means that you can use this conditioned {class}`tinygp.GaussianProcess` to do all the things you would usually do with a GP (e.g. sample from it, condition it further, etc.).\n",
"It is common to make plots like the following using these conditioned GPs (note that we're now using different test points):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0421532b",
"metadata": {},
"outputs": [],
"source": [
"X_test = np.linspace(0, 10, 100)\n",
"_, cond_gp = gp.condition(y, X_test)\n",
"\n",
"# The GP object keeps track of its mean and variance, which we can use for\n",
"# plotting confidence intervals\n",
"mu = cond_gp.mean\n",
"std = np.sqrt(cond_gp.variance)\n",
"plt.plot(X_test, mu, \"C1\", label=\"mean\")\n",
"plt.plot(X_test, mu + std, \"--C1\", label=\"1-sigma region\")\n",
"plt.plot(X_test, mu - std, \"--C1\")\n",
"\n",
"# We can also plot samples from the conditional\n",
"y_samp = cond_gp.sample(jax.random.PRNGKey(1), shape=(12,))\n",
"plt.plot(X_test, y_samp[0], \"C0\", lw=0.5, alpha=0.5, label=\"samples\")\n",
"plt.plot(X_test, y_samp[1:].T, \"C0\", lw=0.5, alpha=0.5)\n",
"\n",
"plt.plot(X, y, \".k\", label=\"data\")\n",
"plt.legend(fontsize=10)\n",
"plt.xlim(X_test.min(), X_test.max())\n",
"plt.xlabel(\"x\")\n",
"_ = plt.ylabel(\"y\")"
]
},
{
"cell_type": "markdown",
"id": "55f6fc1a",
"metadata": {},
"source": [
"## Tips\n",
"\n",
"Given the information we've covered so far, you may have just about everything you need to go on to the other tutorials, and start using `tinygp` for real.\n",
"But there were a few last things that are worth mentioning first.\n",
"\n",
"First, since `tinygp` is built on top of `jax`, it can be very useful to spend some time learning about `jax`, and in particular the [How to Think in JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html) tutorial is a great place to start.\n",
"One way that this plays out in `tinygp`, is that all the operations described in this tutorial are designed to be `jit` compiled, rather than executed directly like we've done here.\n",
"Specifically, a very common pattern that you'll see is a functional model setup like the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fcb2164",
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"\n",
"def build_gp(params):\n",
" kernel = jnp.exp(params[\"log_amp\"]) * kernels.ExpSquared(\n",
" jnp.exp(params[\"log_scale\"])\n",
" )\n",
" return GaussianProcess(kernel, X, diag=jnp.exp(params[\"log_diag\"]))\n",
"\n",
"\n",
"@jax.jit\n",
"def loss(params):\n",
" gp = build_gp(params)\n",
" return -gp.log_probability(y)\n",
"\n",
"\n",
"params = {\n",
" \"log_amp\": -0.1,\n",
" \"log_scale\": 0.0,\n",
" \"log_diag\": -1.0,\n",
"}\n",
"loss(params)"
]
},
{
"cell_type": "markdown",
"id": "2350805d",
"metadata": {},
"source": [
"This is a good setup because the `build_gp` function can be reused for conditioning the GP, but the `@jax.jit` decorator on `loss` means that `jax` can optimize out all the Python overhead introduced by instantiating the `kernel` and `gp` objects.\n",
"\n",
"Finally, the {ref}`troubleshooting` page includes some other tips for dealing with issues that you might run into, but you should also be encouraged to [open issues](https://github.com/dfm/tinygp/issues) on the `tinygp` GitHub repo if you run into other problems."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}