Baalateja Kataru
Bridging numerical relativity and automatic differentiation using JAX

Introduction

Albert Einstein's general theory of relativity is the most successful description of gravity we have to date. It underlies our current understanding of the physics of stars, galaxies, and the universe itself. The theory has made sound, testable predictions, such as gravitational waves, black holes, the big bang, and gravitational time dilation, predictions that have stood the test of time and been experimentally verified again and again.

These predictions have also driven practical, real-world applications. GPS, for example, achieves its accuracy partly by accounting for the effects of the Earth's gravity on timekeeping and clock synchronization for satellites operating in near-Earth space.

At the heart of the mathematics of general relativity is the language of tensor calculus and differential geometry, which provides a prescription for differentiating multilinear structures (tensor fields) on smooth, semi-Riemannian manifolds.

Coincidentally, tensors are also the way data is represented and modeled in AI and deep learning, providing a convenient structure for handling data of all shapes and sizes. The need to solve optimization problems efficiently and to machine precision, in order to learn features and patterns from data as well as possible, led the field of AI to largely pioneer the technique of automatic differentiation, a more robust way to compute derivatives numerically and the theory from which backpropagation is derived as a particular case. Automatic differentiation for tensors of arbitrary shapes and sizes has since been refined and implemented in day-to-day deep learning libraries such as PyTorch and JAX, where it is used primarily for building neural network architectures.

In this piece, I use JAX and automatic differentiation to explore this bridge between the mathematics of general relativity and deep learning, in an attempt to drive this synergy forward and bring the excellent methods and tools of building neural networks and making machines learn to numerical relativity and scientific computing.

What is automatic differentiation?

Automatic differentiation is a technique for computing the derivative of a function extremely efficiently and exactly, up to floating-point precision. Neither claim can be made for numerical or symbolic differentiation, which are currently the predominant methods for computing derivatives in scientific computing.

Symbolic differentiation involves the automated manipulation of symbols, respecting the rules of algebra and calculus, to derive exact expressions for derivatives. This form of differentiation usually requires a CAS (Computer Algebra System) that knows how to perform differentiation by hand using the rules of calculus and algebra. While symbolic differentiation does provide exact results by virtue of producing exact expressions for derivatives, it is the most computationally expensive form of differentiation, leaving it lacking in speed and efficiency. Examples of symbolic differentiation software include Mathematica, SageMath, and SymPy.

Numerical differentiation calculates derivatives by using the method of finite differences to approximate the limit definition of the derivative from first principles. While it is faster than symbolic differentiation, it is plagued by accuracy and stability issues stemming from floating-point arithmetic. Common pain points include numerical instability from dividing by extremely small numbers, which causes expressions to explode; intermediate floating-point round-off errors, which accumulate across expressions and lead to divergence from the true derivative; and the limits of floating-point precision and storage in modern-day classical computing.

Automatic differentiation circumvents all these issues by tracing out all the operations performed, constructing a DAG (directed acyclic graph) for the target function, and computing gradients for each variable by tracing backwards through the graph via the chain rule. This is essentially how deep neural networks are optimized, via something called backpropagation, a form of automatic differentiation. Autodiff is at the heart of modern machine learning and deep learning libraries like tensorflow, pytorch, and jax. PyTorch's torch.autograd module provides methods to run autodiff on PyTorch tensors, which is used to train neural networks in order to minimize prediction loss (error) and improve the accuracy of outputs.
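To make the contrast concrete, here is a minimal sketch (the function and numbers are my own illustrative choices) comparing jax.grad, which applies the chain rule through the traced computation, with a centered finite-difference approximation:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x**2

x0 = 1.5

# autodiff: traces the operations and applies the chain rule,
# yielding the derivative exactly (up to floating-point representation)
print(jax.grad(f)(x0))

# finite differences: accuracy hinges on the step size h, and shrinking h
# too far trades truncation error for round-off error
h = 1e-6
print((f(x0 + h) - f(x0 - h)) / (2 * h))

# analytic derivative for reference: 2x sin(x) + x^2 cos(x)
print(2 * x0 * jnp.sin(x0) + x0**2 * jnp.cos(x0))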

For a visual explanation of automatic differentiation, check out: https://www.youtube.com/watch?v=wG_nF1awSSY

To implement automatic differentiation yourself, check out this tutorial on building micrograd, a tiny autodiff engine, by the brilliant Andrej Karpathy: The spelled-out intro to neural networks and backpropagation: building micrograd. The tutorial also doubles as an excellent hands-on introduction to modern machine/deep learning and frameworks like tensorflow and pytorch, because you're essentially building a smaller version of them.

What is JAX?

jax is a high-performance machine learning and scientific computing library for multilinear algebra, automatic differentiation, and writing performant numerical code that can run on CPUs as well as accelerators such as GPUs and TPUs for substantial speedups.

In a nutshell, jax is numpy with autodiff (automatic differentiation) support. I say this because jax's API has a near 1-1 correspondence with numpy's, which makes jax essentially a drop-in replacement for numpy.
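For instance, a tiny illustrative comparison (my own example):

import numpy as np
import jax.numpy as jnp

x = np.linspace(0.0, np.pi, 5)
y = jnp.linspace(0.0, jnp.pi, 5)

# the same calls with the same semantics; only the namespace changes
print(np.sum(np.sin(x)))
print(jnp.sum(jnp.sin(y)))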

Compared to other machine learning libraries like tensorflow and pytorch, which come packed with practically everything you need to do machine learning of any sort at scale, jax embraces the "simple is beautiful" philosophy by being leaner and simpler in its design, reflected in its smaller install size (jax is around 9MB, while libraries like pytorch and tensorflow exceed 1GB when installed with CUDA support). jax can be further augmented by libraries like flax and equinox, allowing one to do more things like construct and train neural networks, nominally putting it on par with bigger libraries like tensorflow and pytorch.

What does this have to do with General Relativity?

Firstly, GR is formulated in the language of tensors and multilinear algebra. A tensor is a multidimensional array of numbers that generalizes vectors and matrices to higher dimensions. Like vectors and matrices, we can do calculus on tensors, and that is a large part of the calculations we do in GR (computing Christoffel symbols, etc.).

Tensors are also the language of neural networks and deep learning. Input and output data, and parameters such as a neural network's weights and biases, are stored as tensors to allow flexibility in representing data of many different shapes and sizes. We also need to be able to compute derivatives of tensors and do calculus on them in order to train neural networks. As one might gather, this means there is a natural intersection between GR and deep learning in the language they use to represent and manipulate data. Of course, conceptual differences abound and need to be kept in mind when making sense of this synergy. For instance, the mathematics of general relativity (and tensor calculus in general) defines an indexed quantity as a tensor based on its transformation properties, whereas any indexed quantity (e.g., the Christoffel symbols) qualifies as a tensor in deep learning. However, such details are technicalities that can be accounted for by exercising adequate oversight.

The rapid developments in AI over the past few decades have led to the creation of efficient and accurate methods and tools for tensor algebra and calculus. Automatic differentiation and jax are just two of many such methods and tools, but prominent ones due to their potential applications to physics. However, I have noticed a gaping lack of utilization of these tools and methods, built and refined for AI/ML, to solve computational problems and improve computations in modern-day physics. I partly attribute this to the novelty of the tools themselves, and I'm sure that given enough time and effort, they will see widespread adoption across the many subfields of physics.

In fact, such adoption has already begun with automatic differentiation finding applications in fluid dynamics and differential equation solving. From my search of the interweb and arxiv, however, the field of general relativity, and specifically numerical relativity, which deals with numerically computing quantities and solving equations of Einstein's general theory of relativity, has had little to no such adoption, which I believe it is time to change.

What are we doing?

In this exercise, I will use jax to demonstrate the power of autodiff by exploring its potential and applicability to the field of numerical relativity.

Given a metric and some coordinates, we will explore how to compute quantities built from derivatives of the metric tensor, beginning with the Christoffel symbols, and other relevant tensors and quantities, such as the Riemann and Ricci tensors, the Ricci scalar curvature, and the Kretschmann invariant, in order to finally compute the Einstein tensor and the stress-energy-momentum tensor.

We will make use of modern Python features such as type hinting, introduced in Python 3.5 and refined in subsequent versions, which allows us to explicitly specify the data types of the variables in our code for better readability and type safety, and decorators, which are higher-order functions that modify/augment the behavior of the functions given to them as inputs, transforming said functions' inputs/outputs for a specific purpose, among other things.

Note: we will be working exclusively in SI units and not natural units of any sort for ease of interpreting calculations when deriving numerical values of quantities.

Setup

We begin by importing dependencies from the typing module for type hinting support and importing jax along with its numpy variant, jax.numpy.

We configure jax to force all floating point numbers to be in 64-bit precision for higher accuracy in our results.

We also define a utility decorator, close_to_zero, that rounds off values in our tensors that are close to zero, below a certain arbitrarily chosen tolerance, in order to reduce floating-point errors that arise from numerical precision issues compounding across intermediate arithmetic steps.

from typing import Callable

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

TOLERANCE = 1e-8

def close_to_zero(func):
    def wrapper(*args, **kwargs) -> jnp.ndarray:
        result: jnp.ndarray = func(*args, **kwargs)

        return jnp.where(jnp.abs(result) < TOLERANCE, 0.0, result)

    return wrapper
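As a quick illustration of the decorator in action (noisy_values is a made-up example, not part of the main code):

@close_to_zero
def noisy_values() -> jnp.ndarray:
    # pretend these values came out of a long chain of floating-point arithmetic
    return jnp.array([1.0, 3e-12, -2.0, -7e-15])

print(noisy_values())  # [ 1.  0. -2.  0.]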

Defining common metrics

We define some common metrics that are used as trivial examples for the rest of the tutorial.

# while the minkowski metric is a constant tensor that is not coordinate dependent,
# we still need to take in some "dummy" coordinates in order to make this function play nice with JAX's autodiff mechanism
def minkowski_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    """Returns the Minkowski metric in float64 precision with the (-1, 1, 1, 1) metric signature"""
    return jnp.diag(jnp.array([-1, 1, 1, 1], dtype=jnp.float64))

# this is the flat Euclidean 3-space metric in spherical polar coordinates (r, theta, phi)
@close_to_zero
def spherical_polar_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    r, theta, phi = coordinates
    return jnp.diag(jnp.array([1, r**2, r**2 * jnp.sin(theta)**2], dtype=jnp.float64))
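As a quick sanity check of the setup (an illustrative snippet of my own): since the Minkowski metric is constant, differentiating it with respect to the coordinates using forward-mode autodiff should produce an all-zero tensor.

dummy_coordinates = jnp.zeros(4, dtype=jnp.float64)
print(jnp.allclose(jax.jacfwd(minkowski_metric)(dummy_coordinates), 0.0))  # True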

The Christoffel symbols

Given a metric g_{ij}, the Christoffel symbols of the second kind (also called the affine connection, or the connection coefficients) are defined as derivatives of the metric contracted with the inverse metric tensor g^{ij}:

\Gamma_{kl}^{j} = \frac{1}{2} g^{jm} \left( \frac{\partial g_{mk}}{\partial x^l} + \frac{\partial g_{lm}}{\partial x^k} - \frac{\partial g_{kl}}{\partial x^m} \right)

Note: This definition and subsequent ones are taken from Mathematical Methods for Students of Physics and Related Fields by Sadri Hassani and uses all Roman indices. We will be doing the same and using all Roman indices throughout this tutorial, even when dealing with spacetime quantities and not just Cartesian quantities, because of their convenience when it comes to specifying them in the code. Greek symbols are non-trivial when it comes to their underlying Unicode representation, and that's a complexity I want to avoid for now.

The Christoffel symbols represent the gravitational field in a given spacetime, being derivatives of the metric, which itself is analogous to the gravitational potential in Newtonian gravitation.

Now, we define a Python function to compute the Christoffel symbols. The function takes in the coordinates at which to compute the Christoffel symbols and the metric function to compute them for; its implementation is broadly as follows:

  1. We evaluate the metric function to get the metric at the given coordinates and calculate the inverse metric tensor along the way using jnp.linalg.inv.

  2. We use jax.jacfwd to compute the "Jacobian" of the metric, i.e., its partial derivatives with respect to the given coordinates, using forward-mode automatic differentiation. This is \frac{\partial g_{kl}}{\partial x^m}, also denoted g_{kl,m} (the comma denotes a partial derivative; a semicolon is conventionally reserved for covariant derivatives). I write "Jacobian" in quotes because, rigorously speaking, the Jacobian is defined as the matrix of partial derivatives of a vector-valued function with respect to its inputs, whereas what we have is a tensor-valued function in the form of the metric. I am not aware of a standard name for the higher-rank tensor of derivatives of a tensor-valued function, hence my loose usage of the word "Jacobian" here.

  3. We use jnp.einsum, an extremely convenient and performant subroutine that manipulates computational tensors and other indexed objects using standard index notation, to compute the Christoffel symbols according to the equation given above.
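If jnp.einsum is unfamiliar, here is a tiny illustrative primer (my own examples) on the index notation it accepts; repeated indices are summed over, and the output indices determine the result's shape:

A = jnp.arange(9.0).reshape(3, 3)

print(jnp.einsum('ii ->', A))            # trace: sum over the repeated index i
print(jnp.einsum('ij -> ji', A))         # transpose: reorder the output indices
print(jnp.einsum('ij, jk -> ik', A, A))  # matrix product: contract over j

With that, the implementation: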

@close_to_zero # this ensures that any values of our Christoffel symbols are rounded off if they're close to 0
def christoffel_symbols(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    # evaluate the metric tensor at the coordinates
    g = metric(coordinates)
    # compute the inverse metric tensor
    g_inv = jnp.linalg.inv(g)
    # obtain and evaluate the "jacobian" of the metric tensor at the coordinates
    jacobian = jax.jacfwd(metric)(coordinates) # this is g_{kl,m}, i.e. d g_kl / d x^m

    return 0.5 * jnp.einsum('jm, klm -> jkl', g_inv, jnp.einsum('klm -> mkl', jacobian) + jnp.einsum('klm -> lmk', jacobian) - jacobian)
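As a spot check (an illustrative snippet of my own): for the flat metric in spherical polar coordinates, the textbook result is \Gamma^r_{\theta\theta} = -r, so at r = 5 we expect the [0, 1, 1] component to be -5.

coords = jnp.array([5.0, jnp.pi / 3, jnp.pi / 2], dtype=jnp.float64)
gammas = christoffel_symbols(coords, spherical_polar_metric)
print(gammas[0, 1, 1])  # -5.0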

The Torsion Tensor

The torsion tensor is defined as the antisymmetric part of a general affine connection \Gamma^l_{hk}:

\Gamma^l_{hk} - \Gamma^l_{kh}

If the torsion tensor vanishes in one coordinate system, then it vanishes in all coordinate systems (the zero tensor is zero in all coordinate systems). Therefore, the torsion tensor of a general affine connection is zero if and only if the connection is symmetric, i.e.,

\Gamma^l_{hk} = \Gamma^l_{kh}

Since we are dealing only with Christoffel symbols of the second kind, which form the unique symmetric affine connection derived from the metric tensor, all the torsion tensors we compute in this exploration should be zero, and the following subroutine will be used to verify that.

@close_to_zero
def torsion_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    christoffels = christoffel_symbols(coordinates, metric)

    return christoffels - jnp.einsum('ijk -> ikj', christoffels)

The Riemann Curvature Tensor and the Kretschmann Invariant

The Riemann tensor R^j_{klm} encodes the intrinsic curvature of the Riemannian (or semi-Riemannian) manifold produced by any given metric. It is defined in terms of derivatives of the Christoffel symbols \Gamma^j_{kl} as:

R^j_{klm} = \partial_m \Gamma^j_{kl} - \partial_l \Gamma^j_{km} + \Gamma^j_{rm} \Gamma^r_{kl} - \Gamma^j_{rl} \Gamma^r_{km}

We define a Python function that:

  1. Uses the previous christoffel_symbols function to obtain the Christoffel symbols for a given set of metric and coordinates.
  2. Computes the "jacobian" of the Christoffel symbols to obtain \Gamma^j_{kl,m}.
  3. Manipulates this "jacobian" tensor and products of the Christoffel symbols using jnp.einsum to obtain the Riemann tensor.

We also define a Python function to compute the Kretschmann invariant from the Riemann tensor, a scalar that is used to look for true physical singularities (gravitational singularities) in certain manifolds independent of the choice of coordinates:

R^{jklm} R_{jklm}

We will use the Kretschmann invariant later in this tutorial, with the Schwarzschild metric as a case study, to verify that the Riemann tensor implementation below is correct.

@close_to_zero
def riemann_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    christoffels = christoffel_symbols(coordinates, metric)
    jacobian = jax.jacfwd(christoffel_symbols)(coordinates, metric) # computes Gamma^j_{kl,m}

    return jacobian - jnp.einsum('jklm -> jkml', jacobian) + jnp.einsum('jrm, rkl -> jklm', christoffels, christoffels) - jnp.einsum('jrl, rkm -> jklm', christoffels, christoffels)

@close_to_zero
def kretschmann_invariant(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    riemann = riemann_tensor(coordinates, metric)

    g = metric(coordinates)
    g_inv = jnp.linalg.inv(g)

    riemann_upper = jnp.einsum('pj, qk, rl, ijkl -> ipqr', g_inv, g_inv, g_inv, riemann) # computes R^{jklm} by contracting with three inverse metric tensors
    riemann_lower = jnp.einsum('pi, ijkl -> pjkl', g, riemann) # computes R_{jklm} by contracting with one metric tensor

    return jnp.einsum('ijkl, ijkl ->', riemann_upper, riemann_lower)

The Ricci tensor and the Ricci scalar curvature

The Ricci tensor R_{kl} is another curvature-related quantity, defined as a trace of the Riemann tensor R^j_{klm}. The Ricci tensor is obtained by contracting the only contravariant index of the Riemann tensor with its last covariant index:

R_{kl} = R^j_{klj}

Physically, the Ricci tensor encodes information about how volumes change in the presence of tidal forces.

By raising one of the Ricci tensor's indices and contracting, we obtain the Ricci scalar curvature:

R = R^l_l = g^{kl} R_{kl}

We implement Python subroutines to do this computationally using jnp.einsum again to manipulate indices and perform contractions.

@close_to_zero
def ricci_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    riemann = riemann_tensor(coordinates, metric)

    return jnp.einsum('jklj -> kl', riemann) # contracting the first and last indices

@close_to_zero
def ricci_scalar(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    g = metric(coordinates)
    g_inv = jnp.linalg.inv(g)
    ricci = ricci_tensor(coordinates, metric)

    return jnp.einsum('kl, kl -> ', g_inv, ricci) # trace of the ricci tensor
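As a quick check that these subroutines capture genuine curvature (an illustrative example of my own; the metric helper below is not part of the main code): the unit 2-sphere, a genuinely curved two-dimensional surface with metric g_{ij} = \text{diag}(1, \sin^2(\theta)), has a Ricci scalar of exactly 2, and since our machinery is dimension-agnostic we can feed it a 2D metric directly.

@close_to_zero
def unit_two_sphere_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    theta, phi = coordinates
    return jnp.diag(jnp.array([1.0, jnp.sin(theta)**2], dtype=jnp.float64))

coords = jnp.array([jnp.pi / 3, 0.0], dtype=jnp.float64)
print(ricci_scalar(coords, unit_two_sphere_metric))  # 2.0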

The Einstein tensor and the stress-energy-momentum tensor

The Einstein tensor, the crown jewel of the general theory of relativity, packages the curvature of a spacetime manifold into the combination that appears in the field equations. It is defined in terms of the Ricci tensor, the metric tensor, and the Ricci scalar curvature as:

G_{ij} \equiv R_{ij} - \frac{1}{2} g_{ij} R

It forms the left-hand side of the Einstein Field Equations (EFEs), a set of 16 coupled partial differential equations (of which only 10 are independent, since the tensors involved are symmetric) that relate the curvature of a spacetime manifold to the mass-energy content within it:

G_{ij} = \frac{8 \pi G}{c^4} T_{ij}

The right-hand side T_{ij} is the stress-energy-momentum tensor, which encodes information about all the mass-energy present in the spacetime manifold.

We write Python functions that call the subroutines implemented before to compute the Einstein tensor and the stress-energy-momentum tensor using the equations just described.

@close_to_zero
def einstein_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    g = metric(coordinates)
    rt = ricci_tensor(coordinates, metric)
    rs = ricci_scalar(coordinates, metric)

    return rt - 0.5 * g * rs

@close_to_zero
def stress_energy_momentum_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    G = einstein_tensor(coordinates, metric)

    kappa = (8 * jnp.pi * 6.67e-11) / ((299792458)**4) # Einstein's gravitational constant: 8*pi*G / c^4

    return G / kappa

Case study 1: flat space in spherical polar coordinates

Using all of the Python functions we have defined before, let's perform calculations in float64 precision for the metric of flat Euclidean 3-space in spherical polar coordinates (r, \theta, \phi), given as:

g_{ij} = \text{diag}(1, r^2, r^2 \sin^2(\theta))

We use the following arbitrarily chosen coordinate values for the calculations:

r = 5
\theta = \pi/3
\phi = \pi/2

coordinates = jnp.array([5, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)
metric = spherical_polar_metric

print(f"Christoffel symbols: {christoffel_symbols(coordinates, metric)}")
print(f"Torsion tensor: {torsion_tensor(coordinates, metric)}")
print(f"Riemann tensor: {riemann_tensor(coordinates, metric)}")
print(f"Ricci tensor: {ricci_tensor(coordinates, metric)}")
print(f"Ricci scalar: {ricci_scalar(coordinates, metric)}")
print(f"Einstein tensor: {einstein_tensor(coordinates, metric)}")
print(f"Stress-energy-momentum tensor: {stress_energy_momentum_tensor(coordinates, metric)}")
print(f"Kretschmann invariant: {kretschmann_invariant(coordinates, metric)}")

Running this, we get

Christoffel symbols: [[[ 0.          0.          0.        ]
  [ 0.         -5.          0.        ]
  [ 0.          0.         -3.75      ]]

 [[ 0.          0.2         0.        ]
  [ 0.2         0.          0.        ]
  [ 0.          0.         -0.4330127 ]]

 [[ 0.          0.          0.2       ]
  [ 0.          0.          0.57735027]
  [ 0.2         0.57735027  0.        ]]]
Torsion tensor: [[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
Riemann tensor: [[[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]]
Ricci tensor: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Ricci scalar: 0.0
Einstein tensor: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Stress-energy-momentum tensor: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Kretschmann invariant: 0.0

We obtain all of the Christoffel symbols for this metric. Furthermore, as expected, the Riemann tensor, Ricci tensor, Ricci scalar, Einstein tensor, and stress-energy-momentum tensor all vanish, since this metric describes flat space with no curvature and no mass-energy content. The torsion tensor is also zero, as expected from the symmetric nature of the Christoffel symbols.

Case study 2: the Schwarzschild metric

Now we come to a more interesting case study: the Schwarzschild metric, which describes the spacetime of an uncharged, non-rotating, spherically symmetric body in vacuum. The Schwarzschild metric in traditional spherical polar spacetime coordinates (t, r, \theta, \phi) is given as:

g_{ij} = \text{diag}\left( -\left(1 - \frac{r_s}{r}\right) c^2, \; \frac{1}{1 - \frac{r_s}{r}}, \; r^2, \; r^2 \sin^2(\theta) \right)

where r_s is the Schwarzschild radius of the massive body, a length scale related to its mass M by:

r_s = \frac{2 G M}{c^2}

We do exactly the same as in the previous case study, computing all the quantities in float64 precision for a body of ~4.3 million solar masses, using the following arbitrarily chosen spherical polar spacetime coordinates. (Note that for this mass the Schwarzschild radius is about 1.3 x 10^10 m, so r = 3000 lies deep inside the horizon; the formulas still apply there.)

t = 3600
r = 3000
\theta = \pi/3
\phi = \pi/2

G = 6.67e-11
c = 299792458.0

M = 4.297e+6 * 1.989e+30 # 4.3 million solar masses, mass of Sgr A*

# schwarzschild radius
rs = (2 * G * M) / c**2

@close_to_zero
def schwarzschild_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    t, r, theta, phi = coordinates

    return jnp.diag(jnp.array([-(1 - (rs / r)) * c**2, 1/(1 - (rs/r)), r**2, r**2 * jnp.sin(theta)**2], dtype=jnp.float64))

coordinates = jnp.array([3600, 3000, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)
metric = schwarzschild_metric

print(f"Christoffel symbols: {christoffel_symbols(coordinates, metric)}")
print(f"Torsion tensor: {torsion_tensor(coordinates, metric)}")
print(f"Riemann tensor: {riemann_tensor(coordinates, metric)}")
print(f"Ricci tensor: {ricci_tensor(coordinates, metric)}")
print(f"Ricci scalar: {ricci_scalar(coordinates, metric)}")
print(f"Einstein tensor: {einstein_tensor(coordinates, metric)}")
print(f"Stress-energy-momentum tensor: {stress_energy_momentum_tensor(coordinates, metric)}")
print(f"Kretschmann invariant: {kretschmann_invariant(coordinates, metric)}")

Running this outputs

Christoffel symbols: [[[ 0.00000000e+00 -1.66666706e-04  0.00000000e+00  0.00000000e+00]
  [-1.66666706e-04  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

 [[-2.67840757e+26  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  1.66666706e-04  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  1.26857006e+10  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  9.51427546e+09]]

 [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  3.33333333e-04  0.00000000e+00]
  [ 0.00000000e+00  3.33333333e-04  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -4.33012702e-01]]

 [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  3.33333333e-04]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  5.77350269e-01]
  [ 0.00000000e+00  3.33333333e-04  5.77350269e-01  0.00000000e+00]]]
Torsion tensor: [[[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
Riemann tensor: [[[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  1.11111137e-07  0.00000000e+00  0.00000000e+00]
   [-1.11111137e-07  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  2.11428394e+06  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [-2.11428394e+06  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.58571295e+06]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [-1.58571295e+06  0.00000000e+00  0.00000000e+00  0.00000000e+00]]]


 [[[ 0.00000000e+00  1.78560505e+23  0.00000000e+00  0.00000000e+00]
   [-1.78560505e+23  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  2.11428394e+06  0.00000000e+00]
   [ 0.00000000e+00 -2.11428394e+06  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.58571295e+06]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -1.90734863e-06]
   [ 0.00000000e+00 -1.58571295e+06  1.90734863e-06  0.00000000e+00]]]


 [[[ 0.00000000e+00  0.00000000e+00 -8.92802525e+22  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 8.92802525e+22  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  5.55555687e-08  0.00000000e+00]
   [ 0.00000000e+00 -5.55555687e-08  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -3.17142590e+06]
   [ 0.00000000e+00  0.00000000e+00  3.17142590e+06  0.00000000e+00]]]


 [[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -8.92802525e+22]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 8.92802525e+22  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  5.55555687e-08]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00 -5.55555687e-08  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  4.22856787e+06]
   [ 0.00000000e+00  0.00000000e+00 -4.22856787e+06  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]]]
Ricci tensor: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Ricci scalar: 0.0
Einstein tensor: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Stress-energy-momentum tensor: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Kretschmann invariant: 2.649005370647907

Like before, we've obtained the Christoffel symbols for the Schwarzschild metric, and the torsion tensor is zero as expected.

Furthermore, since the Schwarzschild metric is a vacuum solution to the Einstein Field Equations, the Ricci tensor, Ricci scalar, Einstein tensor, and stress-energy-momentum tensor are all zero.

However, observe that the Riemann tensor is not zero; in fact, some of its components are quite large in magnitude. We also obtain a value for the Kretschmann invariant from the Riemann tensor's contraction. To verify this, and thereby check whether the components of our calculated Riemann tensor are correct, we can compute the Kretschmann invariant directly using a relatively simple closed-form scalar formula:

R^{jklm} R_{jklm} = \frac{48 G^2 M^2}{c^4 r^6}

G = 6.67e-11
M = 4.297e+6 * 1.989e+30 # 4.3 million solar masses, mass of Sgr A*
c = 299792458.0

r = 3000

kr = (48 * G**2 * M**2) / (c**4 * r**6)

print("Kretschmann invariant is", kr)

Running this, we get the output

Kretschmann invariant is 2.649005370647906

This matches the result we obtained by contracting the Riemann tensor with itself to nearly 15 decimal places (the two values differ only in the last digit)!

Conclusion

In this exploration, we've witnessed first-hand the power of automatic differentiation, computing the derivatives involved in all our quantities accurately, limited only by floating-point precision.

We saw the effectiveness and simplicity of JAX's jax.jacfwd to compute Jacobians/derivatives of tensors in forward-mode automatic differentiation.

We've also seen the benefits of JAX's jnp.einsum with its ability to do index manipulations and perform calculations at a high level, allowing one to do tensor calculus operations in a convenient and efficient manner. The language of tensors and tensor calculus is pervasive in other topics and fields of theoretical physics, such as covariant electromagnetism and quantum field theory. At the very least, it would be an interesting exercise to explore a similar approach leveraging modern AI/ML libraries and frameworks in order to perform tensor computations for calculations arising in these fields.

JAX's capabilities further allow us to parallelize all of this code and run it on accelerators such as GPUs and TPUs with no changes to the original code, speeding up heavy numerical computations and simulations that require parallelism, speed, and efficiency. We have not explored that aspect in this demonstration, and all computations have been done on the CPU since none of our use cases required accelerators, but the capabilities exist and can be called upon if a specific use case demands them.
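As a taste, here is a minimal sketch (my own, assuming the functions defined above) of how jax.vmap and jax.jit could batch and compile the Christoffel computation over many coordinate points:

from functools import partial

# vectorize over a batch of coordinate points, then JIT-compile the whole thing
batched_christoffels = jax.jit(
    jax.vmap(partial(christoffel_symbols, metric=spherical_polar_metric))
)

points = jnp.stack([
    jnp.array([5.0, jnp.pi / 3, jnp.pi / 2]),
    jnp.array([2.0, jnp.pi / 4, 0.0]),
])

print(batched_christoffels(points).shape)  # (2, 3, 3, 3)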

The explanations, scripts, and results of this exploration are also conveniently documented in this Jupyter notebook hosted on Google Colab, which you can use to reproduce and build upon the ideas presented here.

Future work

Some ideas and extensions in this general direction that are worth exploring:

  • Perform the same calculations for other prominent metrics such as the Kerr and Kerr-Newman metrics
  • Write functions to compute the Weyl tensor and the Weyl invariant, other important curvature-related quantities.
  • Provide alternative but equivalent PyTorch and Tensorflow implementations to the JAX implementation here.
  • Use the @jax.jit decorator on relevant subroutines for using JIT compilation in order to speed up computations.
  • Look into jax.jacrev for reverse-mode automatic differentiation and try to combine it with jax.jacfwd for optimal speed and accuracy.
  • Configure this code to run on GPUs and other accelerators.
  • Train a neural network to parameterize the metric tensor and solve an optimization problem in order to find the metric tensor components from data.
  • Compute the Christoffel symbols for a metric and solve the geodesic equation to find the equations of motion, using automatic-differentiation-powered differential equation solvers such as diffrax
  • Extend this methodology to other aspects and calculations of numerical relativity, specifically those tackled by popular general relativity Python libraries such as EinsteinPy
