#!/usr/bin/env python
"""
Tools for studying cold and dense lattice QCD. This module provides:
* Subtractions (a class of functions guaranteed to integrate to 0)
* Yang-Mills action
* Various Dirac matrices
* Metropolis and HMC implementations
"""
import abc
import argparse
from dataclasses import dataclass
import logging
import sys
from typing import Callable
import equinox as eqx
from equinox import nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
# Run every JAX computation in this module on the CPU backend.
jax.config.update('jax_platform_name', 'cpu')
class SpecialUnitaryAlgebra:
    """ The su(N) Lie algebra.

    The basis is orthonormalized: the basis matrices T_a are traceless,
    Hermitian, and satisfy Tr(T_a T_b) = delta_ab.  This normalization differs
    from that of the Pauli and Gell-Mann bases (which have Tr = 2 delta_ab).

    Attributes:
        N: number of colors.
        dimension: N**2 - 1, the dimension of the algebra.
        basis: complex array of shape (dimension, N, N) holding the T_a.
        structure: complex array of shape (dimension, dimension, dimension)
            with structure[a, b, c] = Tr([T_a, T_b] T_c).  Because the basis
            is orthonormal, this equals i*f_abc when [T_a, T_b] = i f_abc T_c.
    """

    def __init__(self, N):
        """
        Args:
            N: number of colors
        """
        self.N = N
        self.dimension = N**2 - 1
        self.basis = np.zeros((self.dimension, N, N), dtype=complex)
        # The first N*(N-1)/2 components specify the real part of the upper
        # triangle. The second N*(N-1)/2 components specify the imaginary
        # part of the same triangle. That leaves (N-1) components, which
        # yield the diagonal.  The three slices below are views, so writing
        # into them fills self.basis.
        n_off = (N * (N - 1)) // 2
        basis_r = self.basis[:n_off]
        basis_i = self.basis[n_off:2 * n_off]
        basis_d = self.basis[2 * n_off:]
        # Off-diagonal (triangular) parts, all components at once.
        rows, cols = np.triu_indices(N, 1)
        k = np.arange(n_off)
        # Real (symmetric)
        basis_r[k, rows, cols] = 1. / np.sqrt(2)
        basis_r[k, cols, rows] = 1. / np.sqrt(2)
        # Imaginary (antisymmetric)
        basis_i[k, rows, cols] = 1.j / np.sqrt(2)
        basis_i[k, cols, rows] = -1.j / np.sqrt(2)
        # Diagonal parts: diag(1, ..., 1, -(i+1), 0, ..., 0), unit-normalized.
        for i in range(N - 1):
            d = np.zeros((N,))
            d[:i + 1] = 1.
            d[i + 1] = -(i + 1)
            d /= np.sqrt((i + 2) * (i + 1))
            basis_d[i][np.diag_indices(N)] = d
        # Structure constants: structure[a, b, c] = Tr([T_a, T_b] T_c).
        comm = (np.einsum('aij,bjk->abik', self.basis, self.basis)
                - np.einsum('bij,ajk->abik', self.basis, self.basis))
        self.structure = np.einsum('abij,cji->abc', comm, self.basis)

    def vector(self, H):
        """
        Obtain a vector from an element of the Lie algebra.

        Because the basis is orthonormal, component a is Tr(H T_a).

        Args:
            H: A traceless Hermitian N-by-N matrix.
        Returns:
            A vector of length N**2-1 (complex dtype; the imaginary part
            vanishes for Hermitian `H`).
        """
        return jnp.einsum('jk,akj->a', H, jnp.asarray(self.basis))

    def __call__(self, v):
        """
        Obtain an element of the Lie algebra from a flat vector.

        Args:
            v: A vector of size (N**2-1)
        Returns:
            A traceless Hermitian N-by-N matrix, sum_a v[a] T_a.
        """
        return jnp.einsum('i,iab->ab', v, self.basis)
class MLP(eqx.Module):
    """ Multi-layer perceptron.

    A chain of ``equinox.nn.Linear`` layers.  Every layer except the last one
    is followed by ``activation`` (CELU by default); the output of the last
    layer is returned unchanged.
    """
    activation: Callable = staticmethod(jax.nn.celu)
    layers: list

    def __init__(self, key, widths, zero=False):
        """
        Args:
            key: PRNG key, split into one subkey per layer.
            widths: Layer widths from input size to output size;
                ``len(widths) - 1`` linear layers are built.
            zero: If `True`, the weight and the bias of the final layer are
                replaced by zeros, so the network initially outputs zero for
                every input.
        """
        subkeys = jax.random.split(key, len(widths) - 1)
        built = [
            nn.Linear(fan_in, fan_out, key=subkey)
            for fan_in, fan_out, subkey in zip(widths[:-1], widths[1:], subkeys)
        ]
        if zero:
            final = built[-1]
            for pick in (lambda lin: lin.weight, lambda lin: lin.bias):
                final = eqx.tree_at(pick, final, replace_fn=jnp.zeros_like)
            built[-1] = final
        self.layers = built

    def __call__(self, x):
        hidden_layers, final_layer = self.layers[:-1], self.layers[-1]
        h = x
        for lin in hidden_layers:
            h = self.activation(lin(h))
        return final_layer(h)
def random_hermitian(key, N, sigma, traceless=False):
    """ Generates a random Hermitian matrix.

    The real and imaginary parts of an N-by-N matrix A are drawn independently
    from a normal distribution with standard deviation `sigma`; the Hermitian
    part (A + A^dagger)/2 is returned.

    Args:
        key: A PRNG key to consume.
        N: dimension of the matrix
        sigma: standard deviation.  It may be a traced JAX value (e.g. under
            `jit`/`vmap`), so it is not validated here.
        traceless: If `True`, the trace is projected out so the result is
            traceless.  Defaults to `False`.
    Returns:
        An N-by-N Hermitian matrix.
    """
    HR, HI = jax.random.normal(key, shape=(2, N, N)) * sigma
    A = HR + 1j * HI
    H = (A + A.conj().transpose()) / 2.
    if traceless:
        # Remove (Tr H / N) * identity; this leaves a matrix of zero trace.
        H = H - jnp.eye(N) * (jnp.trace(H) / N)
    return H
def random_unitary(key, N, sigma, special=True):
    """ Generates a random unitary matrix.

    The matrix is exp(-i H) for a random Hermitian H drawn by
    `random_hermitian`, so small `sigma` concentrates it near the identity.

    Args:
        key: A PRNG key to consume.
        N: dimension of group
        sigma: inverse concentration near the identity.  It may be a traced
            JAX value, so it is not validated here.
        special: If `True` (the default), the determinant is constrained to be
            1.
    Returns:
        A matrix in the group U(N) or SU(N).
    """
    H = random_hermitian(key, N, sigma, traceless=special)
    Hvals, Hvecs = jnp.linalg.eigh(H)
    if special:
        # det(exp(-iH)) = exp(-i * sum of eigenvalues).  Centering the
        # eigenvalues makes the determinant exactly 1 up to rounding, even if
        # the sampled H carries a residual trace.
        Hvals = Hvals - jnp.mean(Hvals)
    phases = jnp.exp(-1j * Hvals)
    # (Hvecs * phases) scales column j by phases[j], i.e. Hvecs @ diag(phases).
    return (Hvecs * phases) @ Hvecs.conj().transpose()
def haar_unitary(key, N, special=True):
    """ Generates a random unitary matrix, sampled from the Haar measure.

    Uses the QR decomposition of a complex Gaussian (Ginibre) matrix, with
    the phases of R's diagonal absorbed into Q.  This yields exactly
    Haar-distributed U(N) matrices.  For SU(N), the result is multiplied by
    the global phase det(Q)**(-1/N).  That map commutes with left
    multiplication by SU(N), so it carries Haar measure on U(N) to Haar
    measure on SU(N).

    Args:
        key: A PRNG key to consume.
        N: dimension of group
        special: If ``True`` (the default), the determinant is constrained to
            be 1.
    Returns:
        A matrix in the group U(N) or SU(N)
    """
    a, b = jax.random.normal(key, shape=(2, N, N))
    Z = (a + 1j * b) / jnp.sqrt(2.)
    Q, R = jnp.linalg.qr(Z)
    d = jnp.diag(R)
    # Fix the phase ambiguity of QR so the distribution is invariant.
    Q = Q * (d / jnp.abs(d))
    if special:
        det = jnp.linalg.det(Q)
        Q = Q * jnp.exp(-1j * jnp.angle(det) / N)
    return Q
def bootstrap(xs, ws=None, N=100, Bs=50, rng=None):
    """ Compute bootstrapped error bars.

    The samples are first grouped into `Bs` contiguous blocks of equal size
    ``len(xs) // Bs``.  Any samples left over at the end are dropped.  The
    blocks are then resampled with replacement `N` times.

    If `xs` is complex, then the real part of the returned error represents the
    error on the real part of the mean, and ditto for the imaginary part.

    Args:
        xs: Samples (1-D).
        ws: Weights, same length as `xs`; defaults to all ones.
        N: Number of resamplings to perform.
        Bs: Number of blocks to use (clamped to ``len(xs)``).
        rng: Optional source of randomness with a numpy-style ``choice``
            method (e.g. ``np.random.default_rng(seed)``).  Defaults to the
            global ``np.random`` state.
    Returns:
        An ordered pair consisting of the reweighted mean and its estimated
        standard deviation (always complex).
    Raises:
        ValueError: If `xs` is empty or `Bs` is not positive.
    """
    xs = np.asarray(xs)
    if len(xs) == 0:
        raise ValueError('bootstrap requires at least one sample')
    if Bs < 1:
        raise ValueError(f'number of blocks must be positive, got {Bs}')
    ws = np.ones_like(xs) if ws is None else np.asarray(ws)
    Bs = min(Bs, len(xs))
    B = len(xs) // Bs
    used = Bs * B
    # Block: per-block weighted sums y_b = sum(x*w) and total weights w_b.
    xb = xs[:used].reshape(Bs, B)
    wb = ws[:used].reshape(Bs, B)
    y = (xb * wb).sum(axis=1)
    w = wb.sum(axis=1)
    m = y.sum() / w.sum()
    # Regular bootstrap over blocks, all resamples at once.
    sampler = np.random if rng is None else rng
    idx = sampler.choice(Bs, size=(N, Bs))
    ms = y[idx].sum(axis=1) / w[idx].sum(axis=1)
    return m, np.std(ms.real) + 1j * np.std(ms.imag)
@dataclass
class Lattice:
    """ Lattice geometry.

    A periodic hypercubic lattice with `L` sites per side in `D` dimensions.
    Each site is a single integer whose coordinate along direction `mu` is
    the base-`L` digit of weight ``L**mu``, so direction 0 varies fastest.
    :meth:`space` treats direction 0 as time.

    Args:
        L: length of one side of the lattice
        D: number of spacetime dimensions
    """
    L: int
    D: int = 4

    def _check_direction(self, mu):
        """ Raise `ValueError` unless `mu` is one of 0, ..., D-1. """
        if not 0 <= mu < self.D:
            raise ValueError(f'invalid direction {mu}')

    def volume(self):
        """ Returns the total number of sites in the lattice.
        """
        return self.L**self.D

    def coord(self, site, mu):
        """ Compute the coordinates of a site.

        Only integer arithmetic is used, so `site` may also be an integer
        array, in which case the result is computed elementwise.

        Args:
            site: The index of the site in question.
            mu: The coordinate to compute.
        Returns:
            The `mu` coordinate of the specified site.
        Raises:
            ValueError: if `mu` isn't a valid direction
        """
        self._check_direction(mu)
        return (site % (self.L**(mu+1))) // (self.L**mu)

    def step(self, site, mu, dist):
        """ Compute the index of one site, from a starting site and an offset.

        The lattice is periodic, so walking off one edge wraps to the other.
        As with :meth:`coord`, `site` may be an integer array.

        Args:
            site: The site to start from.
            mu: The direction to walk in.
            dist: Number of steps (in the positive direction) to walk.
        Returns:
            The index of the specified site.
        Raises:
            ValueError: if `mu` isn't a valid direction.
        """
        self._check_direction(mu)
        # First identify the component of `site` in the `mu` direction.
        imu = self.coord(site, mu)
        # The index of the site that is `0` in the `mu` direction.
        base = site - (imu * self.L**mu)
        # Shift imu periodically.
        return base + ((imu + dist) % self.L) * self.L**mu

    def space(self, t):
        """
        Obtain a single spatial slice of the lattice.

        Args:
            t: The time (direction 0) coordinate; taken modulo `L`.
        Returns:
            An array of the L**(D-1) site indices whose time coordinate is
            ``t % L``, in increasing order.
        """
        return (t % self.L) + self.L * jnp.arange(self.L**(self.D-1))

    def flat(self, N):
        """ Produce a flat gauge configuration.

        Args:
            N: number of colors
        Returns:
            A rank-4 array of shape (volume, D, N, N) containing a gauge
            configuration consisting only of the identity.
        """
        return jnp.tile(jnp.eye(N), (self.volume(), self.D, 1, 1))
def plaquette(lattice, U, site, mu, nu):
    """ Compute the trace of a 1x1 plaquette.

    The links are multiplied in path order around the elementary square
    x -> x+mu -> x+mu+nu -> x+nu -> x:

        Tr[ U_mu(x) U_nu(x+mu) U_mu(x+nu)^dagger U_nu(x)^dagger ]

    This ordering is gauge invariant when links transform as
    U_mu(x) -> g(x) U_mu(x) g(x+mu)^dagger.  That is the same convention
    under which a hopping term psibar(x) U_mu(x) psi(x+mu) is gauge
    invariant.

    Args:
        lattice: The lattice geometry.
        U: A configuration, indexed as ``U[site, direction, :, :]``.
        site: The site from which the plaquette originates.
        mu: First axis.
        nu: Second axis.
    Returns:
        The trace of the plaquette in the fundamental representation.
    Raises:
        ValueError: If `mu == nu`, or if `lattice.step` rejects one of the
            directions.
    """
    if mu == nu:
        raise ValueError('plaquette directions must differ')
    # Shifted sites
    site_mu = lattice.step(site, mu, 1)
    site_nu = lattice.step(site, nu, 1)
    # Involved links, in path order.
    U1 = U[site, mu, :, :]
    U2 = U[site_mu, nu, :, :]
    U3 = U[site_nu, mu, :, :].conj().transpose()
    U4 = U[site, nu, :, :].conj().transpose()
    # Product and trace
    return (U1 @ U2 @ U3 @ U4).trace()
def action_gauge(lattice, U, g, N):
    """ Compute the gauge part of the action.

    Wilson plaquette action

        S = beta * sum_{x} sum_{mu > nu} (1 - Re Tr P_{mu nu}(x) / N),

    with beta = 2N / g**2.  For unitary links Re Tr P <= N, so S >= 0, and it
    vanishes on the flat configuration; exp(-S) therefore favors ordered
    fields.

    Args:
        lattice: The lattice geometry.
        U: A gauge configuration.
        g: The gauge coupling.
        N: Number of colors.
    Returns:
        The gauge part of the Wilson action.
    Note:
        A configuration whose shape does not match `lattice` is not detected
        here.
    """
    beta = 2 * N / g**2
    # Each unordered pair of directions once.
    planes = [(mu, nu) for mu in range(lattice.D) for nu in range(mu)]

    def action_gauge_at(site):
        S = 0.
        for mu, nu in planes:
            S += beta * (1. - plaquette(lattice, U, site, mu, nu).real / N)
        return S

    # This is a sum over sites, so list sites.
    sites = jnp.arange(lattice.volume())
    return jnp.sum(jax.vmap(action_gauge_at)(sites))
def pauli_matrices():
    """ Obtain the Pauli (sigma) matrices together with the identity.

    Returns:
        A four-element list ``[identity, sigma_x, sigma_y, sigma_z]`` of
        2-by-2 arrays.
    """
    return [
        jnp.eye(2),
        jnp.array([[0, 1], [1, 0]]),
        jnp.array([[0, -1j], [1j, 0]]),
        jnp.array([[1, 0], [0, -1]]),
    ]
def dirac_matrices():
    """ Obtain Dirac (gamma) matrices.

    gamma_0 = sigma_x (x) 1 and gamma_k = -sigma_y (x) sigma_k, where (x) is
    the Kronecker product.  All four are block off-diagonal 4-by-4 matrices.

    Returns:
        A list of 4 gamma matrices, ``[gamma0, gamma1, gamma2, gamma3]``.
    """
    ident, sx, sy, sz = pauli_matrices()
    spatial = [jnp.kron(-sy, s) for s in (sx, sy, sz)]
    return [jnp.kron(sx, ident)] + spatial
class LatticeFermions(abc.ABC):
    """ Base class for lattice fermion discretizations.

    Subclasses supply :meth:`dirac` and :meth:`density`.  The determinant,
    log-determinant and propagator are then derived from :meth:`dirac`.
    This class is abstract; instantiate one of its subclasses instead.
    """

    def __init__(self, lattice):
        """
        Args:
            lattice: The underlying geometry.
        """
        self.lattice = lattice

    @abc.abstractmethod
    def dirac(self, U):
        """ Obtain the Dirac operator.

        Args:
            U: Background gauge configuration.
        Returns:
            The Dirac matrix (inverse propagator) in the given gauge background.
        """

    @abc.abstractmethod
    def density(self, Dinv):
        """ Estimate the density.

        Args:
            Dinv: inverse Dirac operator
        Returns:
            The average density.
        """

    def det(self, U):
        """ Evaluate the determinant of the Dirac matrix.

        On reasonably sized lattices the plain determinant is likely to
        overflow; prefer :meth:`slogdet` there.

        Args:
            U: Background gauge configuration.
        Returns:
            The determinant of the Dirac matrix.
        """
        return jnp.linalg.det(self.dirac(U))

    def slogdet(self, U):
        """ Evaluate the determinant of the Dirac matrix as phase and logarithm.

        Args:
            U: Background gauge configuration.
        Returns:
            The result of ``jnp.linalg.slogdet``: a pair `(z, f)` where `z` is
            the sign (a unit-magnitude phase for complex matrices) and `f` is
            the log of the absolute value of the determinant.
        """
        return jnp.linalg.slogdet(self.dirac(U))

    def propagator(self, U):
        """ Evaluate the inverse of the Dirac matrix.

        Args:
            U: Background gauge configuration.
        Returns:
            The inverse of the Dirac matrix.
        """
        return jnp.linalg.inv(self.dirac(U))
class NaiveFermions(LatticeFermions):
    """ Naive lattice fermions, with doublers not treated.

    Only the mass term of the Dirac operator is implemented so far (the
    hopping terms are TODO), so :meth:`dirac` currently yields `mass` times
    the identity.
    """

    def __init__(self, lattice, N: int, mass: float, mu: float = 0.):
        """
        Args:
            lattice: The underlying geometry.
            N: Number of colors.
            mass: Bare fermion mass.
            mu: Chemical potential.
        """
        super().__init__(lattice)
        self.N, self.mass, self.mu = N, mass, mu
        # Euclidean gamma matrices and their product gamma5.
        self.gamma = dirac_matrices()
        g0, g1, g2, g3 = self.gamma
        self.gamma5 = g0 @ g1 @ g2 @ g3

    def dirac(self, U):
        vol = self.lattice.volume()
        # Indexed as (site, site', color, color', spin, spin').
        D = jnp.zeros((vol, vol, self.N, self.N, 4, 4))
        idx = jnp.arange(vol)
        mass_block = self.mass * jnp.eye(4)
        # Mass along the diagonal, one color at a time.
        for color in range(self.N):
            D = D.at[idx, idx, color, color].set(mass_block)
        # Time-like hopping (includes chemical potential).
        # TODO
        # Antiperiodic boundary conditions.
        # TODO
        # Spatial hoppings.
        # TODO
        # Reorder to (site, color, spin) x (site', color', spin') and flatten.
        dim = 4 * self.N * vol
        return D.transpose((0, 2, 4, 1, 3, 5)).reshape((dim, dim))

    def density(self, Dinv):
        raise NotImplementedError()
class StaggeredFermions(LatticeFermions):
    """ Kogut-Susskind staggered fermions.

    Direction 0 is time.  The Dirac matrix built by :meth:`dirac` contains:

    * the mass `m` on the diagonal;
    * time hops  +1/2 e^{+mu} s(x) U_0(x) at (x, x+0) and
      -1/2 e^{-mu} s(x) U_0(x)^dagger at (x+0, x).  Here s(x) = -1 for
      links leaving the last time slice (antiperiodic boundary conditions)
      and +1 otherwise;
    * spatial hops  +1/2 eta_i(x) U_i(x) at (x, x+i) and
      -1/2 eta_i(x) U_i(x)^dagger at (x+i, x), with staggered phases
      eta_i(x) = (-1)^(x_0 + ... + x_{i-1}).
    """

    def __init__(self, lattice, N: int, mass: float, mu: float = 0.):
        """
        Args:
            lattice: The underlying geometry.
            N: Number of colors.
            mass: Bare fermion mass.
            mu: Chemical potential.
        """
        super().__init__(lattice)
        self.N = N
        self.mass = mass
        self.mu = mu

    def dirac(self, U):
        """ Build the staggered Dirac matrix.

        Args:
            U: Gauge configuration, indexed as ``U[site, direction, :, :]``.
        Returns:
            A complex matrix of shape ``(N*V, N*V)``, with V the lattice
            volume.  Rows and columns are ordered by (site, color), color
            fastest.
        """
        lat = self.lattice
        vol = lat.volume()
        N = self.N
        sites = jnp.arange(vol)
        D = jnp.zeros((vol, vol, N, N)) + 0j
        # Mass along the diagonal.
        D = D.at[sites, sites].set(self.mass * jnp.eye(N))
        # Time-like hopping (includes chemical potential).  The sign s(x)
        # flips every term through a link that wraps from t = L-1 to t = 0,
        # which imposes antiperiodic boundary conditions for any L.  Folding
        # it into the coefficients (instead of negating matrix entries
        # afterwards) keeps interior hops intact even when L = 2.
        fwd = lat.step(sites, 0, 1)
        bc = jnp.where(lat.coord(sites, 0) == lat.L - 1, -1., 1.)[:, None, None]
        U0 = U[sites, 0, :, :]
        D = D.at[sites, fwd].add(0.5 * jnp.exp(self.mu) * bc * U0)
        D = D.at[fwd, sites].add(
            -0.5 * jnp.exp(-self.mu) * bc * U0.conj().transpose((0, 2, 1)))
        # Spatial hoppings, weighted by the staggered phases eta_i(x).
        for i in range(1, lat.D):
            fwd = lat.step(sites, i, 1)
            parity = sum(lat.coord(sites, j) for j in range(i)) % 2
            eta = (1 - 2 * parity)[:, None, None]
            Ui = U[sites, i, :, :]
            D = D.at[sites, fwd].add(0.5 * eta * Ui)
            D = D.at[fwd, sites].add(-0.5 * eta * Ui.conj().transpose((0, 2, 1)))
        D = D.transpose((0, 2, 1, 3))
        return D.reshape((N * vol, N * vol))

    def density(self, Dinv):
        """ Returns Tr(Dinv) / V.

        NOTE(review): since dD/dm is the identity, Tr(D^{-1})/V is
        (1/V) d ln det D / dm, i.e. the chiral-condensate estimator.  The
        number density would be (1/V) Tr(D^{-1} dD/dmu), which needs the
        gauge field and is not computed here -- confirm which is intended.
        """
        return Dinv.trace() / self.lattice.volume()
class WilsonFermions(LatticeFermions):
    """ Wilson fermions.

    Not implemented yet: :meth:`dirac` and :meth:`density` raise
    `NotImplementedError`.
    """

    def __init__(self, lattice, N: int, mass: float, mu: float = 0.):
        """
        Args:
            lattice: The underlying geometry.
            N: Number of colors.
            mass: Bare fermion mass.
            mu: Chemical potential.
        """
        super().__init__(lattice)
        self.N = N
        self.mass = mass
        self.mu = mu
        # Initialize Euclidean gamma matrices.
        self.gamma = dirac_matrices()
        self.gamma5 = self.gamma[0] @ self.gamma[1] @ self.gamma[2] @ self.gamma[3]

    def dirac(self, U):
        """ Raises:
            NotImplementedError: always; the Wilson operator is not written yet.
        """
        raise NotImplementedError()

    def density(self, Dinv):
        """ Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
class RootedFermions(LatticeFermions):
    """ Rooted staggered fermions.

    Placeholder: the constructor and every method raise
    `NotImplementedError`.
    """

    def __init__(self, flavors, *args):
        """
        Args:
            flavors: Number of flavors
            args: Additional arguments used to define the staggered fermions.
        """
        raise NotImplementedError

    def dirac(self, U):
        raise NotImplementedError

    def density(self, Dinv):
        raise NotImplementedError

    def det(self, U):
        raise NotImplementedError

    def slogdet(self, U):
        raise NotImplementedError
class Metropolis:
    """ Markov chain using the Metropolis algorithm.

    A proposal ``xp = propose(key, x, delta)`` is accepted with probability
    ``min(1, exp(-(Re S(xp) - Re S(x))))``.  The step size `delta` can be
    tuned with :meth:`calibrate`.
    """

    def __init__(self, x0, action, propose, key):
        """
        Args:
            x0: Initial field configuration.
            action: The action
            propose: A function for generating proposals.
            key: PRNG key for generating proposals.
        """
        self.x = x0
        self.action = action
        self.propose = propose
        self.delta = 1.
        self._key = key
        # Acceptance history (Python bools), trimmed to the last 100 entries
        # by `step`; it starts with one rejection so its length is never 0.
        self._recent = [False]

    def step(self, N=1):
        """
        Args:
            N: The number of steps to take before returning a configuration.
        """
        logging.info(f'taking {N} steps')
        self.S = self.action(self.x).real
        for _ in range(N):
            kprop, kacc, self._key = jax.random.split(self._key, 3)
            candidate = self.propose(kprop, self.x, self.delta)
            S_candidate = self.action(candidate).real
            accept = bool(jax.random.uniform(kacc) < jnp.exp(-(S_candidate - self.S)))
            if accept:
                self.x, self.S = candidate, S_candidate
            self._recent = (self._recent + [accept])[-100:]

    def calibrate(self):
        """ Calibrate the chain, tweaking `delta` until the acceptance rate
        lies between 0.3 and 0.55.
        """
        logging.info('beginning calibration')
        self.step(N=100)
        rate = self.acceptance_rate()
        while not 0.3 <= rate <= 0.55:
            logging.info(f'calibrating (acceptance rate is {rate})')
            # Smaller steps raise the acceptance rate, larger ones lower it.
            self.delta *= 0.98 if rate < 0.3 else 1.02
            self.step(N=100)
            rate = self.acceptance_rate()

    def acceptance_rate(self):
        """ Estimates the acceptance rate.

        Returns:
            The fraction of accepted proposals over the last (up to) 100 steps.
        """
        return self._recent.count(True) / len(self._recent)

    def iter(self, skip=1):
        """ An infinite iterator yielding configurations, taking `skip`
        steps between consecutive yields.
        """
        while True:
            self.step(N=skip)
            yield self.x
class HMC:
    """ Hamiltonian Monte Carlo.

    Not implemented yet: the object can be constructed, but :meth:`step` and
    :meth:`iter` raise `NotImplementedError`.  They no longer silently do
    nothing or spin forever.
    """

    def __init__(self):
        """ Create an (unconfigured) HMC sampler; takes no arguments yet. """

    def step(self, N=1):
        """ Take `N` HMC steps.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('HMC.step is not implemented')

    def iter(self):
        """ Iterate over configurations.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('HMC.iter is not implemented')
def main(args):
    """ Run the finite-density Monte Carlo configured by the parsed CLI `args`.

    Builds an action from the Wilson gauge action plus staggered fermions,
    and calibrates a Metropolis chain started from the flat configuration.
    It then either trains the subtraction (`args.train`; not yet
    implemented) or, for each configuration, prints the phase
    exp(-i Im S) followed by the observables (mean plaquette, density).
    """
    import time  # needed only for --seed-time

    if args.verbose:
        logging.basicConfig(level=logging.INFO)
    seed = args.seed
    if args.seed_time:
        seed = time.time_ns()
    mckey, subkey = jax.random.split(jax.random.PRNGKey(seed))
    N = args.colors
    lattice = Lattice(args.L)
    fermions = [StaggeredFermions(lattice, N, args.mass, args.mu)]

    @jax.jit
    def action(U):
        Sgauge = action_gauge(lattice, U, args.g, N)
        Sfermi = 0.
        for fermi in fermions:
            # -ln det D = -(ln|det D| + ln(phase)); complex in general.
            s, lndet = fermi.slogdet(U)
            Sfermi -= lndet + jnp.log(s)
        return Sgauge + Sfermi

    @jax.jit
    def observe(U):
        # Mean plaquette: sum over the D(D-1)/2 planes at each site, averaged
        # over sites (not divided by the number of planes or by N).
        def mean_plaquette_at(site):
            P = 0.
            for mu in range(lattice.D):
                for nu in range(mu):
                    P += plaquette(lattice, U, site, mu, nu)
            return P
        sites = jnp.arange(lattice.volume())
        mean_plaquette = jnp.mean(jax.vmap(mean_plaquette_at)(sites))
        # Fermion density.
        density = 0.
        for fermi in fermions:
            density += fermi.density(fermi.propagator(U))
        # Fermion correlators.
        # TODO
        return mean_plaquette, density

    @jax.jit
    def propose(key, U, delta):
        delta /= jnp.sqrt(lattice.volume())
        # We need lattice.D * lattice.volume() keys: one per link.
        keys = jax.random.split(key, lattice.D * lattice.volume())
        V = jax.vmap(lambda k: random_unitary(k, N, delta))(keys)
        V = V.reshape(U.shape)
        # Right-multiply every link by its random unitary.
        return jnp.einsum('xiab,xibc->xiac', U, V)

    # Initialize/load the subtraction.
    # NOTE(review): ExactForm is not defined in the visible source of this
    # module -- confirm it is provided elsewhere.
    subtraction = ExactForm(subkey, 1, 1, lattice.D*lattice.volume())
    if not args.init:
        try:
            subtraction = eqx.tree_deserialise_leaves(args.subtraction, subtraction)
        except FileNotFoundError:
            # No saved subtraction yet: keep the fresh initialization.
            pass

    # Prepare the Monte Carlo.
    chain = Metropolis(lattice.flat(N), action, propose, mckey)
    chain.calibrate()

    if args.train:
        # Parameters will be saved after every gradient step.
        def save():
            eqx.tree_serialise_leaves(args.subtraction, subtraction)

        # Prepare the optimizer.
        opt = optax.yogi(1e-3)
        opt_state = opt.init(eqx.filter(subtraction, eqx.is_array))

        # Stochastic loss functions
        def Seff_real():
            raise NotImplementedError()

        # Train the subtraction.
        try:
            while True:
                save()
                raise NotImplementedError()
        except KeyboardInterrupt:
            pass
    else:
        # Don't train---just measure.
        try:
            for U in chain.iter():
                S = action(U)
                obs = np.array(observe(U))
                print(np.exp(-1j*S.imag), obs)
        except KeyboardInterrupt:
            pass
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Lattice QCD at finite density',
        fromfile_prefix_chars='@',
    )
    add = parser.add_argument
    add('-N', '--colors', default=3, type=int, help='Number of colors (3)')
    add('-i', '--init', action='store_true', help='re-initialize the subtraction')
    add('-t', '--train', action='store_true', help='train the subtraction')
    # Positional arguments, in command-line order.
    for name, kind, text in (
        ('subtraction', str, 'subtraction filename'),
        ('L', int, 'Lattice size'),
        ('g', float, 'Gauge coupling'),
        ('mass', float, 'Bare fermion mass'),
        ('mu', float, 'Chemical potential'),
    ):
        add(name, type=kind, help=text)
    add('-v', '--verbose', action='store_true', help='Verbose mode')
    # A fixed seed and a time-based seed cannot be requested together.
    seed_group = parser.add_mutually_exclusive_group()
    seed_group.add_argument('--seed', type=int, default=0, help="random seed")
    seed_group.add_argument('--seed-time', action='store_true', help="seed PRNG with current time")
    args = parser.parse_args()
    main(args)