Use Python and JAX to efficiently build infinitely wide networks and gain deeper insights into your models on a finite machine.

One notorious problem with deep learning and deep neural networks (DNNs) is that they can become black boxes. Let's say that we have fitted a network with good test performance on a given classification problem. However, we are now stuck: we cannot make sense of the final weights that have been learned or adequately visualize the problem space!
Another issue arises in practice. In real-world applications of neural networks we often fall back to training ensembles of networks and using the averaged output of many models. This can be more powerful than the output of a single network, and with multiple models we can also better analyze overall performance on the given problem.
Neural Tangents aims to solve these challenges elegantly.
In this post we will explore the background and application of neural networks of infinite width. We take a look at Neural Tangents as a framework and high-level API which allows us to build such networks efficiently. We will use Neural Tangents on a simple regression problem and train both a finite and an infinite network using JAX. From this we can gain more insights into our model and provide better analytics. Neural Tangents lets us both train with gradient descent and perform inference in closed form, and we will exploit both features when looking at our model.
Firstly, how do we get to infinity and back?
A short primer on Neural Networks
Generally speaking, we build a neural network from assumptions about the underlying problem space; for example, proximity and values of pixels to detect edges and coherent elements in pictures. Usually a network's topology is tailored to a class of problems.
We then initialize the network weights, typically by sampling from a Gaussian distribution. We train the model by adapting the model parameters in order to minimize a loss function for the problem that we have previously defined. After the training is done we can test the model with a selection of metrics on an unseen test set of the data.
Now we want to bring some light into the dark: we would like analytical properties that are easier to explain. One essential question is why (deep) neural networks generalize so well even though they tend to be overparameterized [4].
What happens when we stretch Deep Neural Networks?
It turns out that when we increase the width of a regular DNN we can gain some nice insights and an overall better understanding. Very, very, (very) wide networks converge to Gaussian Processes [9]. It had been established for some time that this is true for single-layer neural networks (Ch. 4 in [8]), and it has recently been shown to hold for deeper networks as well [9].
What do we get from transforming a Neural Network into a Gaussian Process?
This allows for better analysis and gives us uncertainties over our estimates. In addition to the usual NN metrics we can now perform Bayesian inference, assess uncertainties over prior and posterior, and analyze them.
One essential cornerstone is that a Gaussian Process is associated with a kernel as its covariance function. These kernels allow for better analytical properties. The mathematics behind kernel functions is in parts better understood and more developed than the theory of regular neural networks.
Figure 1 gives a graphical overview of what we are doing: we start with a regular DNN of finite width and take the width of the hidden layer to its limit.

Two essential kernels – our gates to infinity
What connects the neural network with the fabled Gaussian Processes?
One essential assumption is that at initialization (given infinite width) a neural network is equivalent to a Gaussian Process [4]. The evolution that occurs when training the network can then be described by a kernel, as has been shown by researchers at the École Polytechnique Fédérale de Lausanne [4].
Context on kernels
For the sake of a simple introduction: we say that a kernel is a function $k(x, x')$ that relates two points $x$ and $x'$ from your data-set to each other. For this to work, you need a measurable space with some underlying properties (looking at you, RKHS).
How your pairs of points relate to each other depends on the kernel function.
We can use a kernel on our dataset to introduce relations between the data points, measure similarities, and compute other properties.
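To make this a little more concrete, here is one classic example that is not specific to Neural Tangents: the RBF (squared-exponential) kernel, which scores two points by their distance, with hyperparameters $\sigma$ (output scale) and $\ell$ (length-scale):
$$k(x, x') = \sigma^2 \exp\!\left( -\frac{\lVert x - x' \rVert^2}{2 \ell^2} \right)$$
Nearby points produce values close to $\sigma^2$, distant points produce values close to zero; the matrix of all pairwise values is exactly the covariance matrix a Gaussian Process works with.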
This unnecessarily short paragraph does not do the field and exciting topic of kernel methods justice and I recommend checking out more insightful resources on the subject [5][6].
Now, we already have a network with its architecture, activations and also pre-activations (like convolutions, pooling – you name it); we denote its output as $z(x, \theta)$, with parameters $\theta$ drawn from a Gaussian at initialization.
The NNGP kernel is simply defined as:
$$\mathcal{K}(x, x') = \mathbb{E}_{\theta}\big[\, z(x, \theta)\, z(x', \theta) \,\big]$$
and the NTK kernel:
$$\Theta(x, x') = \mathbb{E}_{\theta}\big[\, \nabla_\theta z(x, \theta)\, \nabla_\theta z(x', \theta)^{\top} \,\big]$$
In layman's terms, the latter kernel describes how the output $z$ changes with respect to the network parameters.
The relationship between kernels and tensors
Roughly speaking, there is a direct correspondence between tensor operations and kernel operations. We can therefore express relationships that come from a neural network in terms of kernel methods.
That means that for each (tensor) operation in the neural network we can find an equivalent kernel operation.
There exists a pointwise translation rule, given our two kernels, such that each tensor operation on the pre-activations induces an operation on the kernel pair:
$$z^{l+1} = \mathrm{op}\big(z^{l}\big) \quad\Longrightarrow\quad \big(\mathcal{K}^{l+1}, \Theta^{l+1}\big) = \mathrm{op}^{*}\big(\mathcal{K}^{l}, \Theta^{l}\big)$$
and from this we can transform dense layers through both of our kernels as:
$$\mathcal{K}^{l+1} = \sigma_w^2\, \mathcal{K}^{l} + \sigma_b^2, \qquad \Theta^{l+1} = \sigma_w^2\, \Theta^{l} + \mathcal{K}^{l+1}$$
where $\sigma_w^2$ and $\sigma_b^2$ are the weight and bias variances. For a rigorous derivation see Eq. (2)–(7) in [3].
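In Neural Tangents this translation rule is what the stax layers implement. A minimal sketch (the width 512 and the standard deviations are arbitrary values of mine, not from this post's model):
from neural_tangents import stax

# W_std and b_std play the roles of sigma_w and sigma_b in the dense-layer rule above;
# the returned kernel_fn applies exactly this (K, Theta) update for the layer.
init_fn, apply_fn, kernel_fn = stax.Dense(512, W_std=1.5, b_std=0.05)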
A simple convolutional example
For a convolutional neural network this could translate into:
- Layer 0 – input: $\mathcal{K}^0 = x\, x^{\top}$ and NTK: $\Theta^0 = 0$. This means that the NNGP kernel is simply the given dataset as an $N \times N$ matrix of inner products (up to normalization), and the NTK is zero, since nothing has happened yet in your network.
- Layer 0 – convolution: $\mathcal{K}^1 = \mathrm{Conv}\big(\mathcal{K}^0\big) = \sigma_w^2\, \mathcal{A}\big(\mathcal{K}^0\big) + \sigma_b^2$ and NTK: $\Theta^1 = \sigma_w^2\, \mathcal{A}\big(\Theta^0\big) + \mathcal{K}^1$. Here we see the convolution applied to the NNGP kernel through the $\mathcal{A}$ operation (see below), while the NTK operates on the newly computed NNGP kernel, taking the previous $\Theta^0$ into account.
- Layer 1 – activations: $\mathcal{K}^2 = \mathcal{T}\big(\mathcal{K}^1\big)$ and NTK: $\Theta^2 = \dot{\mathcal{T}}\big(\mathcal{K}^1\big) \odot \Theta^1$. At this stage the activation is applied, which is denoted as $\mathcal{T}$ for the NNGP kernel. From the NTK perspective, the new state of the network is the element-wise product of the activation's derivative map $\dot{\mathcal{T}}$ with the NTK accumulated thus far.
- Layer 1 – …
And so on, all the way through the network.
$\mathcal{A}$ above is a dedicated operation which describes the summation over the convolutional filter. (You can find a complete exemplary overview in Fig. 4 and Table 1 in [3].)
And $\mathcal{T}$ is defined as the expectation of the activation $\phi$ over functions drawn from the NNGP:
$$\mathcal{T}\big(\mathcal{K}\big)(x, x') = \mathbb{E}_{f \sim \mathcal{N}(0,\, \mathcal{K})}\big[\, \phi\big(f(x)\big)\, \phi\big(f(x')\big) \,\big]$$
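As a rough sketch of how such a convolutional model would be written in Neural Tangents (the layer sizes are arbitrary choices of mine, not from this example):
from neural_tangents import stax

# Each layer contributes its own (K, Theta) translation rule to the composed kernel_fn.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(32, (3, 3)),  # convolution: the A summation plus weight/bias variances
    stax.Relu(),            # activation: the T (and T-dot) operation
    stax.Flatten(),
    stax.Dense(1)
)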

Using Neural Tangents – hands-on
Now that we are familiar with some of the underlying concepts, we are going to use JAX and the Neural Tangents library. We build an exemplary network and evaluate it in detail.
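Throughout the hands-on part the snippets assume roughly the following imports (a sketch; note that newer JAX versions moved optimizers from jax.experimental to jax.example_libraries):
import jax.numpy as np                   # the snippets below use np for jax.numpy
from jax import random, grad, jit
from jax.experimental import optimizers  # on newer JAX: from jax.example_libraries import optimizers
import neural_tangents as nt
from neural_tangents import stax
from sklearn import datasets
from sklearn.model_selection import train_test_split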
Enter some regression problem
In order to get going, let's create an arbitrary regression problem from scratch using sklearn. We create a dataset with two features and the function value f, like this:
n_samples = 125
# Note: with n_features=2, at most 2 features can actually be informative (sklearn caps n_informative).
regression = datasets.make_regression(n_samples=n_samples, n_features=2, n_informative=15, noise=0.25)
data = regression[0]  # inputs, shape (n_samples, 2)
f = regression[1]     # regression targets
Sidenote: this is only a small number of samples. You are encouraged to try a bigger problem; using Google Colab I ran into memory issues with larger sample sizes (n > 750).
From this we get data that should look something like this:

We split our data-set into training and test sets for later processing. Be aware that we need an extra axis for our y-values for later computations and that we take a test set of 10% of our total data-set size.
X_train, X_test, y_train, y_test = train_test_split(data, f[:, np.newaxis], test_size=0.1, shuffle=True)
train = (X_train, y_train)
test = (X_test, y_test)
Now we create a simple neural network with one hidden layer, using JAX. Very simple in fact, since the hidden layer barely holds any neurons (see stax.Dense(5)), together with a ReLU activation function.
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(5), stax.Relu(),
stax.Dense(1)
)
Neural Tangents returns three objects from this. The kernel_fn part will serve for our analysis at infinity; it is the kernel representation of the architecture that we encountered previously. The init_fn and apply_fn correspond to networks of finite width.
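As a small sketch of the finite-width side (the PRNG seed is an arbitrary choice of mine):
key = random.PRNGKey(0)
_, params = init_fn(key, X_train.shape)  # draw finite-width parameters for this architecture
predictions = apply_fn(params, X_test)   # outputs of the randomly initialized finite network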
We can visualize before even training a network?
The first interesting feature that Neural Tangents offers is insight from the kernels themselves.
We compute the kernels for our test data like so:
kernel = kernel_fn(X_test, X_test, ('nngp', 'ntk')) # get both NNGP and NTK kernel from the test-data

This gives us two things:
One is the behaviour under Bayesian inference, represented through the NNGP kernel; we get it simply by looking at the NNGP part.
Secondly, the NTK kernel corresponds to how your network would behave after having been trained with gradient descent for an infinite amount of time.
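The returned object exposes both matrices, so we can already inspect them; a minimal sketch, assuming matplotlib for the plotting:
import matplotlib.pyplot as plt

# kernel.nngp and kernel.ntk are the covariance matrices over the test points.
fig, axes = plt.subplots(1, 2)
axes[0].imshow(kernel.nngp); axes[0].set_title('NNGP')
axes[1].imshow(kernel.ntk);  axes[1].set_title('NTK')
plt.show()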
In theory we can sample from the prior distribution. However for this specific problem it is not very informative, so we’ll skip this step. You can find a nice introduction to how prior samples look in Google’s Colab Cookbook.
Actual Inference
After this first insight, we perform inference:
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, X_train,
y_train)
ntk_mean, ntk_covariance = predict_fn(x_test=X_test, get='ntk',
compute_cov=True)
ntk_mean = np.reshape(ntk_mean, (-1,))
ntk_std = np.sqrt(np.diag(ntk_covariance))
Here we solve for training the network for an infinite amount of time, given the kernel that we have constructed. We do this by selecting the NTK kernel via predict_fn(x_test=X_test, get='ntk', ...). This is called solving gradient descent in closed form.
From this we obtain a mean and a covariance matrix; above we keep the mean as ntk_mean and the per-point standard deviation as ntk_std. Note that for this specific problem the standard deviation is very small, though we can still visualize how the means perform:
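One way to do that, as a sketch of the kind of plot shown below (not the exact code behind the figure), is to plot the predicted means with error bars against the true test targets over one feature:
import matplotlib.pyplot as plt

# Predicted mean +/- 2 standard deviations vs. ground truth, over feature 1.
plt.errorbar(X_test[:, 0], ntk_mean, yerr=2 * ntk_std, fmt='o', label='NTK mean')
plt.scatter(X_test[:, 0], np.reshape(y_test, (-1,)), marker='x', label='ground truth')
plt.xlabel('feature 1'); plt.ylabel('f'); plt.legend(); plt.show()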

Inference can be done the Bayesian way as well; we just have to substitute ntk with nngp. For the performance of the NNGP see the attachment later in this post.
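That substitution looks like this:
# Bayesian (NNGP) posterior instead of the NTK gradient-descent solution.
nngp_mean, nngp_covariance = predict_fn(x_test=X_test, get='nngp', compute_cov=True)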
Computing loss
With Neural Tangents we can also compute the loss of the network over time steps, for both training and testing.
The loss is defined as the mean squared error, or in Python: 1/n * np.mean(ys**2 - 2*mean*ys + var + mean**2, axis=1). We also provide the time steps of interest, from 0 to 1000 in steps of 0.1 (np.arange(0, 10 ** 3, 10 ** -1)):
ts = np.arange(0, 10 ** 3, 10 ** -1)
ntk_train_loss_mean = loss_fn(predict_fn, y_train, ts)
ntk_test_loss_mean = loss_fn(predict_fn, y_test, ts, X_test)
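The snippet above assumes a helper loss_fn along these lines; a sketch built from the formula just given (the exact version is in the linked Colab, and constant prefactors only rescale the curves):
def loss_fn(predict_fn, ys, ts, xs=None):
    # Infinite-width mean and covariance at every time step in ts;
    # with xs=None the prediction is made on the training inputs.
    mean, cov = predict_fn(t=ts, get='ntk', x_test=xs, compute_cov=True)
    mean = np.reshape(mean, mean.shape[:1] + (-1,))  # (len(ts), n_points)
    var = np.diagonal(cov, axis1=1, axis2=2)         # per-point predictive variance
    ys = np.reshape(ys, (1, -1))                     # broadcast targets across time steps
    # E[(y - f)^2] = y^2 - 2*y*E[f] + Var[f] + E[f]^2, averaged over the data points
    return np.mean(ys ** 2 - 2 * mean * ys + var + mean ** 2, axis=1)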

The Actual Training
To have a finite comparison point we have to actually train our network. As hyperparameters we set the learning rate to 0.1 and run training for 10000 steps. We use JAX's stochastic gradient descent optimizer:
learning_rate = 0.1
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_update = jit(opt_update)
We also need to compute the loss and its gradient in the process of training. As a loss function we use the mean squared error. Note that we also use JAX's jit compilation to make loss and grad_loss more performant.
loss = jit(lambda params, x, y: 1/(len(x))*np.mean((apply_fn(params, x) - y) ** 2))
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))
Now we run the training:
training_steps = 10000
_, params = init_fn(random.PRNGKey(1), X_train.shape)  # seed is arbitrary; fresh parameters for the finite network
train_losses, test_losses = [], []
opt_state = opt_init(params)
for i in range(training_steps):
    opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    train_losses += [loss(get_params(opt_state), *train)]
    test_losses += [loss(get_params(opt_state), *test)]
Our Outcome

Fig. 6: Predictions over the test data with the underlying truth of the test data. Good correlation of test data against predictions w.r.t. feature 1, worse performance for feature 2.
Pearson correlation coefficients:
- Training, feature x_1: r = 0.96
- Training, feature x_2: r = -0.03
- Testing, feature x_1: r = 0.97
- Testing, feature x_2: r = 0.58

From the last figure we can see that the finite network performs better after 10 steps. Though please keep in mind that this is a very simple exercise; for more complex tasks, [3] shows that the infinite networks of Neural Tangents perform at least as well as finite networks.
What makes Neural Tangents performant?
Working with covariances and kernels in a large problem domain can be very expensive. Some of the key challenges are (1) handling covariance matrices: (a) inverting them, (b) constructing them, and (c) updating them; and (2) computing the kernels: (a) for multiple kernels at once and (b) across all of the data.
Neural Tangents offers solutions to make these problems more feasible:
- leveraging the internal structure (for classification problems) to reduce the cost of covariance inversion by orders of magnitude,
- tracking only the covariance matrices that are necessary and optimizing for the properties of the convolutions, allowing for a reduced runtime and minimal memory footprint,
- treating covariance computations as 2D convolutions to make use of hardware accelerators,
- computing the NNGP and NTK kernels together, since the NTK requires the NNGP,
- allowing for batching so that the dataset does not have to be treated all at once (see the sketch below).
All of this combined makes computationally intense tasks more feasible on standard hardware.
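As an example of the last point, batching the kernel computation is a one-liner with nt.batch (the batch size here is an arbitrary choice of mine):
# Compute the kernel in chunks of 25 inputs at a time instead of all at once.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=25)
kernel = batched_kernel_fn(X_test, X_test, ('nngp', 'ntk'))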
Conclusion
We have seen what benefits can come from using infinitely wide neural nets:
- We gained additional insights from the kernels, right after specifying our network architecture,
- We computed the infinite-width network in closed form,
- We obtained a mean function and standard deviations over our dataset by doing inference.
Taking NNs to their limits ties them well together with Gaussian Processes, a classical machine learning approach. This allows us to perform inference and to observe how the data-set correlates through the given kernel(s). We have also seen how Neural Tangents makes this accessible and can give insights and improvements over conventional network structures when building artificial neural networks.
Neural Tangents is an exciting module that allows better insights into deep neural network structures and more analytical work on the outputs.
Neural Tangents is a high-level API that takes away a lot of the heavy lifting and the problems that arise when one takes a network to its limits – with Python and JAX it is even more accessible.
The makers of Neural Tangents have already announced that they are looking into more network structures in future work.
For the math and computer science enthusiast I highly recommend checking out the reference publications.
The Code to produce the plots
https://colab.research.google.com/drive/1s2QdQyS9YndXpUoG-0-EtDQdXbFYyxT9?usp=sharing
References
- [1] Fast and Easy Infinitely Wide Networks with Neural Tangents, Samuel S. Schoenholz, Google AI Blog, March 13, 2020.
- [2] Neural Tangents Cookbook in Colab.
- [3] Neural Tangents: Fast and Easy Infinite Neural Networks in Python, R. Novak, L. Xiao, S. Schoenholz et al.
- [4] Neural Tangent Kernel: Convergence and Generalization in Neural Networks, A. Jacot, F. Gabriel, C. Hongler.
- [5] Kernel Methods in Machine Learning.
- [6] Kernel Cookbook: Advice on Covariance Functions, David Duvenaud.
- [7] Wide Residual Networks, Sergey Zagoruyko, Nikos Komodakis.
- [8] Gaussian Processes for Machine Learning, Carl Edward Rasmussen and Christopher K. I. Williams, 2006.
- [9] Deep Neural Networks as Gaussian Processes, J. Lee, Y. Bahri, et al., 2018.