Comparison With george
try:
    import tinygp
except ImportError:
    !pip install -q tinygp
try:
    import george
except ImportError:
    !pip install -q george
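import jax

# Enable double precision; jax.config.update also works in JAX versions
# where the jax.config submodule import is no longer available
jax.config.update("jax_enable_x64", True)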
One of the design decisions for tinygp was to provide a high-level API similar to the one provided by the george GP library. This was partly because I (as the lead developer of george) wanted to ease users’ transitions away from george to something more modern (like tinygp). I also quite like the george API and don’t think that other tools offer anything similar.
The defining feature of this design is that tinygp does not include built-in implementations of inference algorithms. Instead, it provides an expressive model-building interface that makes it easy to experiment with different kernels while still integrating with your favorite inference engine.
In this document, we compare the interfaces and computational performance of george and tinygp for constructing kernel models and evaluating the GP marginal likelihood.
Since tinygp supports GPU acceleration, we have executed this notebook on a machine with the following GPU:
!nvidia-smi --query-gpu=gpu_name --format=csv
name
NVIDIA A100-PCIE-40GB
By default, the CPU versions of both george and tinygp use parallelized linear algebra libraries to take advantage of multiple CPU threads. To make the benchmarks more reproducible, however, we’ll disable this parallelization for the remainder of this notebook:
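import os

# Ask OpenMP-based linear algebra libraries to use a single thread
os.environ["OMP_NUM_THREADS"] = "1"
# Disable multithreaded Eigen in XLA's CPU backend (used by jax)
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)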
Next, we generate some simulated data and define functions for computing the GP log likelihood using george and tinygp (with separate CPU and GPU versions).
As mentioned above, the syntax of these functions is quite similar, but there are a few differences. Most notably, the units of the “metric” or “length scale” parameter in the kernel are different: it is a squared length in george and an unsquared length in tinygp. Also, the gp.compute method no longer exists in tinygp, since it would be less compatible with jax’s preference for pure functional programming.
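To make the difference in units concrete, here is a small NumPy sketch of the Matern-3/2 profile written both ways. It assumes the standard Matern-3/2 form \((1 + \sqrt{3}\,r/\ell)\,e^{-\sqrt{3}\,r/\ell}\) and george’s convention of dividing the squared distance by the metric; the function names are hypothetical and neither library is called.

```python
import numpy as np


def matern32_length_scale(r, ell):
    """Matern-3/2 profile with a length scale ell (the tinygp convention)."""
    arg = np.sqrt(3.0) * np.abs(r) / ell
    return (1.0 + arg) * np.exp(-arg)


def matern32_metric(r, metric):
    """Same profile with a squared length scale, or "metric" (the george convention)."""
    # The squared distance is divided by the metric before taking the square root
    arg = np.sqrt(3.0 * r**2 / metric)
    return (1.0 + arg) * np.exp(-arg)


rho = 2.5
r = np.linspace(0.0, 10.0, 101)

# Passing rho to one form and rho**2 to the other gives the same kernel...
same = np.allclose(matern32_length_scale(r, rho), matern32_metric(r, rho**2))
# ...while passing rho to both does not
mixed_up = np.allclose(matern32_length_scale(r, rho), matern32_metric(r, rho))
print(same, mixed_up)  # True False
```

This is why george_loglike passes rho ** 2 to its kernel while tinygp_loglike passes rho.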
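from functools import partial

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import george
import tinygp

sigma = 1.5
rho = 2.5
jitter = 0.1

random = np.random.default_rng(49382)
x = np.sort(random.uniform(0, 10, 20_000))
y = np.sin(x) + jitter * random.normal(0, 1, len(x))


def george_loglike(x, y, **kwargs):
    # george's Matern32Kernel takes the squared length scale (its "metric")
    kernel = sigma ** 2 * george.kernels.Matern32Kernel(rho ** 2)
    gp = george.GP(kernel, **kwargs)
    # jitter is passed as yerr, so jitter**2 is added to the covariance diagonal
    gp.compute(x, jitter)
    return gp.log_likelihood(y)


def tinygp_loglike(x, y):
    # tinygp's Matern32 takes the length scale itself; diag is the noise variance
    kernel = sigma ** 2 * tinygp.kernels.Matern32(rho)
    gp = tinygp.GaussianProcess(kernel, x, diag=jitter ** 2)
    # log_probability returns just the log likelihood; in recent tinygp versions,
    # gp.condition(y) returns a (log_probability, conditioned_gp) pair instead
    return gp.log_probability(y)


# george with its approximate HODLR solver
hodlr_loglike = partial(
    george_loglike, solver=george.solvers.HODLRSolver, tol=0.5
)

# The same tinygp function, compiled separately for each backend
tinygp_loglike_cpu = jax.jit(tinygp_loglike, backend="cpu")
tinygp_loglike_gpu = jax.jit(tinygp_loglike, backend="gpu")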
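Both exact (non-HODLR) computations above evaluate the same quantity, the zero-mean GP log marginal likelihood \(\log p(y) = -\tfrac{1}{2} y^\top K^{-1} y - \tfrac{1}{2}\log\det K - \tfrac{N}{2}\log 2\pi\), where \(K\) is the kernel matrix with jitter**2 added to its diagonal. As a point of reference, here is a plain NumPy sketch of that computation using a Cholesky factorization \(K = L L^\top\). It is not taken from either library, and the helper names are hypothetical.

```python
import numpy as np


def matern32_matrix(x1, x2, sigma, rho):
    """Matern-3/2 covariance with amplitude sigma and length scale rho."""
    arg = np.sqrt(3.0) * np.abs(x1[:, None] - x2[None, :]) / rho
    return sigma**2 * (1.0 + arg) * np.exp(-arg)


def cholesky_loglike(x, y, sigma, rho, jitter):
    """Exact zero-mean GP log marginal likelihood using a Cholesky factorization."""
    K = matern32_matrix(x, x, sigma, rho)
    K[np.diag_indices_from(K)] += jitter**2  # white-noise variance on the diagonal
    L = np.linalg.cholesky(K)  # K = L @ L.T; this O(N^3) step dominates for large N
    alpha = np.linalg.solve(L, y)  # alpha = L^{-1} y, so alpha @ alpha = y^T K^{-1} y
    return (
        -0.5 * alpha @ alpha
        - np.sum(np.log(np.diag(L)))  # equals 0.5 * log det K
        - 0.5 * len(x) * np.log(2.0 * np.pi)
    )


rng = np.random.default_rng(0)
x_demo = np.sort(rng.uniform(0, 10, 200))
y_demo = np.sin(x_demo) + 0.1 * rng.normal(0, 1, len(x_demo))
print(cholesky_loglike(x_demo, y_demo, sigma=1.5, rho=2.5, jitter=0.1))
```

Using np.linalg.solve on the triangular factor keeps the sketch NumPy-only; scipy.linalg.solve_triangular would exploit the triangular structure and be cheaper.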
Now we benchmark the cost of evaluating the log likelihood with each of these methods:
ns = [10, 20, 100, 200, 1_000, 2_000, 10_000, len(x)]
george_time = []
hodlr_time = []
cpu_time = []
gpu_time = []

for n in ns:
    print(f"\nN = {n}:")

    args = x[:n], y[:n]
    gpu_args = jax.device_put(x[:n]), jax.device_put(y[:n])

    # The exact CPU solvers are only benchmarked for smaller datasets
    if n < 10_000:
        results = %timeit -o george_loglike(*args)
        george_time.append(results.average)

        # Call once first so that JIT compilation is excluded from the timing
        tinygp_loglike_cpu(*args).block_until_ready()
        results = %timeit -o tinygp_loglike_cpu(*args).block_until_ready()
        cpu_time.append(results.average)

    results = %timeit -o hodlr_loglike(*args)
    hodlr_time.append(results.average)

    tinygp_loglike_gpu(*gpu_args).block_until_ready()
    results = %timeit -o tinygp_loglike_gpu(*gpu_args).block_until_ready()
    gpu_time.append(results.average)
N = 10:
387 µs ± 6.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14.2 µs ± 14 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
304 µs ± 6.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
73.5 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
N = 20:
403 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
23.9 µs ± 78.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
318 µs ± 6.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
75.6 µs ± 86.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
N = 100:
781 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
398 µs ± 701 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
592 µs ± 3.78 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
180 µs ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
N = 200:
2.42 ms ± 5.15 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.01 ms ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
919 µs ± 674 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
266 µs ± 222 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
N = 1000:
198 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
155 ms ± 429 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.42 ms ± 7.32 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.54 ms ± 1.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
N = 2000:
1.56 s ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.18 s ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.81 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.65 ms ± 1.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
N = 10000:
59.5 ms ± 228 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
49.1 ms ± 55.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
N = 20000:
126 ms ± 519 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
263 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
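Two details of the benchmark loop are easy to miss: each jitted tinygp function is called once before it is timed, because that first call triggers JIT compilation, and every timed call ends in .block_until_ready(), because JAX dispatches work asynchronously and would otherwise return before the computation finishes. Outside IPython, where %timeit is unavailable, the same pattern can be sketched with the standard timeit module. The benchmark helper and its sync argument are hypothetical names for illustration, not part of george, tinygp, or JAX.

```python
import timeit

import numpy as np


def benchmark(fn, *args, sync=lambda result: result, number=10, repeat=7):
    """Mean runtime per call in seconds, after one untimed warm-up call.

    sync is applied to each result; for a JAX function, pass
    sync=lambda r: r.block_until_ready() so the timer waits for the computation.
    """
    sync(fn(*args))  # warm-up: for a jax.jit function this is where compilation happens
    timer = timeit.Timer(lambda: sync(fn(*args)))
    totals = timer.repeat(repeat=repeat, number=number)  # total seconds per run
    return float(np.mean(totals)) / number


# Usage with a plain NumPy function (no synchronization needed)
A = np.eye(200) + 0.01 * np.ones((200, 200))
print(f"{benchmark(np.linalg.cholesky, A):.2e} s per call")
```

In the notebook, the GPU measurement would then correspond to benchmark(tinygp_loglike_gpu, *gpu_args, sync=lambda r: r.block_until_ready()).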
In the plot of this benchmark, you’ll notice several features:
- For very small datasets, the tinygp CPU implementation is significantly faster than any of the other implementations. This is because jax.jit removes a lot of the Python overhead that is encountered when chaining numpy functions.
- For medium to large datasets, tinygp is generally faster than george, with the GPU version seeing a significant advantage.
- The CPU implementations approach the expected asymptotic complexity of \(\mathcal{O}(N^3)\) only for the largest values of \(N\). This is probably because, at smaller \(N\), the runtime is dominated by memory allocation overhead or other operations that scale better than the Cholesky factorization.
- The approximate “HODLR” solver from george outperforms the GPU-enabled tinygp exact solver, but only for very large datasets, and it’s important to note that the HODLR method does not scale well to larger input dimensions. Any existing or future approximate solvers like this that are implemented in jax could easily be used in conjunction with tinygp, but none have been implemented yet.
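One way to check the claim about asymptotic scaling is to compute the local log-log slope between consecutive problem sizes, since a runtime proportional to \(N^p\) appears as a straight line of slope \(p\) in the plot. The sketch below does this for the tinygp CPU timings printed above (the second line of each block of output, following the order of the %timeit calls in the loop). The variable names are new, chosen so as not to overwrite ns and cpu_time.

```python
import numpy as np

# tinygp CPU timings from the benchmark output above, converted to seconds
reported_ns = np.array([10, 20, 100, 200, 1_000, 2_000])
reported_cpu_time = np.array([14.2e-6, 23.9e-6, 398e-6, 2.01e-3, 155e-3, 1.18])

# If runtime ~ N**p, then p = d log(runtime) / d log(N) between neighboring sizes
slopes = np.diff(np.log(reported_cpu_time)) / np.diff(np.log(reported_ns))
for n0, n1, p in zip(reported_ns[:-1], reported_ns[1:], slopes):
    print(f"N = {n0:>5} -> {n1:>5}: local exponent ~ {p:.2f}")
```

With these numbers, the local exponent rises steadily, from roughly 0.75 between N = 10 and N = 20 to roughly 2.9 between N = 1000 and N = 2000, consistent with the cubic scaling only emerging at the largest sizes.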
plt.loglog(ns[: len(george_time)], george_time, "o-", label="george (basic)")
plt.loglog(ns, hodlr_time, "o-", label="george (HODLR)")
plt.loglog(ns[: len(cpu_time)], cpu_time, "o-", label="tinygp (CPU)")
plt.loglog(ns, gpu_time, "o-", label="tinygp (GPU)")
ylim = plt.ylim()
# Reference O(N^3) curve, anchored to the largest CPU timing
plt.loglog(
    ns,
    0.5 * np.array(ns) ** 3 / ns[len(cpu_time) - 1] ** 3 * cpu_time[-1],
    ":k",
    label="O($N^3$)",
)
plt.ylim(ylim)
plt.legend()
plt.xlabel("number of data points")
plt.ylabel("runtime [s]");