# Source code for the ``qcd`` module.

#!/usr/bin/env python

"""
Tools for studying cold and dense lattice QCD. This module provides:

* Subtractions (a class of functions guaranteed to integrate to 0)
* Yang-Mills action
* Dirac (gamma) matrices and lattice Dirac operators
* A Metropolis sampler (an HMC class exists but is not yet implemented)
"""

import abc
import argparse
from dataclasses import dataclass
import logging
import sys
from typing import Callable

import equinox as eqx
from equinox import nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

# Pin JAX to the CPU backend. This is a module-level statement, so it runs at
# import time and affects every program that imports this module, not only the
# command-line entry point.
jax.config.update('jax_platform_name', 'cpu')

class SpecialUnitaryAlgebra:
    """
    The su(N) Lie algebra, represented by an explicit basis of traceless
    Hermitian N-by-N matrices.

    The basis is orthonormal under the trace inner product,
    ``tr(T_a T_b) = delta_ab``. Its normalization therefore differs from that
    of the Pauli and Gell-Mann bases.

    Attributes:
        N: number of colors.
        dimension: ``N**2 - 1``, the number of generators.
        basis: complex array of shape ``(N**2-1, N, N)`` holding the generators.
        structure: complex array of shape ``(dimension, dimension, dimension)``
            with entries ``tr([T_a, T_b] T_c)``.
    """

    def __init__(self, N):
        """
        Args:
            N: number of colors
        """
        self.N = N
        self.dimension = N**2 - 1
        # Layout of the basis:
        # - the first N(N-1)/2 generators are real symmetric and off-diagonal;
        # - the next N(N-1)/2 are imaginary antisymmetric, on the same
        #   upper-triangle entries;
        # - the remaining N-1 are diagonal.
        n_off = (N * (N - 1)) // 2
        self.basis = np.zeros((self.dimension, N, N)) + 0j

        rows, cols = np.triu_indices(N, 1)
        for k, (r, c) in enumerate(zip(rows, cols)):
            # Real (symmetric) generator.
            self.basis[k, r, c] = 1./np.sqrt(2)
            self.basis[k, c, r] = 1./np.sqrt(2)
            # Imaginary (antisymmetric) generator.
            self.basis[n_off + k, r, c] = 1.j/np.sqrt(2)
            self.basis[n_off + k, c, r] = -1.j/np.sqrt(2)

        for k in range(N - 1):
            # diag(1, ..., 1, -(k+1), 0, ..., 0), scaled to unit norm.
            entries = np.zeros((N,))
            entries[:k + 1] = 1.
            entries[k + 1] = -(k + 1)
            entries /= np.sqrt((k + 2) * (k + 1))
            self.basis[2 * n_off + k][np.diag_indices(N)] = entries

        # Structure constants: structure[a, b, c] = tr([T_a, T_b] T_c).
        generators = self.basis
        d = self.dimension
        self.structure = np.zeros((d, d, d)) + 0j
        for a in range(d):
            for b in range(d):
                commutator = generators[a] @ generators[b] - generators[b] @ generators[a]
                for c in range(d):
                    self.structure[a, b, c] = np.trace(commutator @ generators[c])

    def vector(self, H):
        """
        Obtain the component vector of an element of the Lie algebra.

        Because the basis is orthonormal, component ``a`` is ``tr(H T_a)``.

        Args:
            H: A traceless Hermitian N-by-N matrix.

        Returns:
            A length ``N**2-1`` array of components. The dtype is complex; for
            Hermitian ``H`` the imaginary parts vanish up to rounding.
        """
        generators = jnp.array(self.basis)
        return jax.vmap(lambda T: jnp.trace(H @ T))(generators)

    def __call__(self, v):
        """
        Obtain an element of the Lie algebra from a flat vector.

        Args:
            v: A vector of size (N**2-1)

        Returns:
            The traceless Hermitian N-by-N matrix ``sum_a v[a] T_a``.
        """
        return jnp.tensordot(v, self.basis, axes=1)
class MLP(eqx.Module):
    """
    Multi-layer perceptron.

    A stack of ``nn.Linear`` layers with ``activation`` applied between
    consecutive layers. No activation is applied after the final layer.
    """
    activation: Callable = staticmethod(jax.nn.celu)
    layers: list

    def __init__(self, key, widths, zero=False):
        """
        Args:
            key: PRNG key used to initialize the layers.
            widths: Layer widths, input first and output last. One linear
                layer is created per consecutive pair of widths.
            zero: If ``True``, the final layer's weight and bias are replaced
                by zeros, so the network initially outputs zeros.
        """
        n_layers = len(widths) - 1
        keys = jax.random.split(key, n_layers)
        layers = [nn.Linear(n_in, n_out, key=k)
                  for n_in, n_out, k in zip(widths[:-1], widths[1:], keys)]
        if zero:
            # Zero both the weight and the bias of the output layer in one pass.
            layers[-1] = eqx.tree_at(lambda l: (l.weight, l.bias), layers[-1],
                                     replace_fn=jnp.zeros_like)
        self.layers = layers

    def __call__(self, x):
        """Apply the network to ``x``."""
        hidden_layers, output_layer = self.layers[:-1], self.layers[-1]
        for layer in hidden_layers:
            x = self.activation(layer(x))
        return output_layer(x)
class ExactForm(eqx.Module):
    """
    An exact form on the surface of SU(N)^K.

    A neural network defines a vector field on K copies of SU(N). The field has
    one su(N) component per generator of each copy. Calling the module returns
    the divergence of that field, taken along the directions
    ``U -> U exp(i eps T_a)``. These left-invariant directions are
    divergence-free with respect to the Haar measure, so the result is a
    total divergence.
    """
    K: int
    N: int
    algebra: SpecialUnitaryAlgebra
    mlp: MLP

    def __init__(self, key, wscale, hidden, K, N=3, zero=True):
        """
        Args:
            key: Initial PRNG key.
            wscale: Scaling of the width.
            hidden: Number of hidden layers.
            K: Number of SU(N) degrees of freedom.
            N: Number of colors.
            zero: If ``True``, initialize the network so that the output is
                always 0.
        """
        self.K = K
        self.N = N
        self.algebra = SpecialUnitaryAlgebra(N)
        w = wscale * K * N**2
        # Input: real and imaginary parts of all K matrices.
        # Output: one su(N) component per degree of freedom, i.e. K*(N**2-1)
        # numbers. __call__ reads one output for every coordinate of the
        # K*(N**2-1)-dimensional tangent space. (A width of only N**2-1 would
        # make those reads clamp out of range. Parameters serialized with the
        # old width will not match this shape.)
        widths = [2*K*N**2] + [w]*hidden + [K*(N**2-1)]
        self.mlp = MLP(key, widths, zero=zero)

    def vector(self, U):
        """
        The vector field underlying the exact form.

        Args:
            U: Point at which to evaluate the vector field.

        Returns:
            A length ``K*(N**2-1)`` vector of su(N) components.
        """
        x = jnp.concatenate([U.flatten().real, U.flatten().imag])
        return self.mlp(x)

    def __call__(self, U):
        """
        Compute an exact form.

        Args:
            U: Point at which to evaluate the form. Assumes U holds K N-by-N
                matrices in its trailing two axes, e.g. shape (K, N, N) or
                (V, D, N, N) with V*D == K. TODO: confirm against callers.

        Returns:
            The divergence of a vector field.
        """
        # The dimension of the vector field.
        D = self.algebra.dimension*self.K

        # Allow self.vector to be differentiated.
        def vec(x):
            # x holds K elements of the Lie algebra. Turn them into N-by-N
            # generators laid out like U (row-major, matching U's flattening).
            H = jax.vmap(self.algebra)(x.reshape((self.K, self.algebra.dimension)))
            H = H.reshape(U.shape)
            # First-order move along the group: U exp(iH) ~ U + i U H.
            # (U + i U H U is not a tangent direction of SU(N) at U.)
            V = U + 1j*U@H
            return self.vector(V)

        # The commented-out line below is a slower version. We should keep it
        # around as documentation of what the faster code is really doing.
        #return jnp.sum(jax.jacfwd(vec)(jnp.zeros(D)).diagonal())

        # Evaluate the ith component of the divergence: d vec_i / d x_i at 0.
        def div_part(i):
            def vec_component(x_i):
                x = jnp.zeros(D)
                x = x.at[i].set(x_i)
                return vec(x)[i]
            return jax.grad(vec_component)(0.)

        # Compute and return the divergence.
        return jnp.sum(jax.vmap(div_part)(jnp.arange(D)))
class ScaledExactForm(eqx.Module):
    """
    An exact form modified by a scalar.

    The primary purpose of this class is to have a version of ``ExactForm``
    that performs well when the Boltzmann factor is exponentially small (or
    large). For a vector field v (from :meth:`vector`), calling the module
    returns ``f = div v - v . dS``, so that ``f * exp(-S) = div(v exp(-S))``.
    """
    # All attributes assigned in __init__ must be declared as equinox fields.
    action: Callable
    K: int
    N: int
    algebra: SpecialUnitaryAlgebra
    mlp: MLP

    def __init__(self, action, key, wscale, hidden, K, N=3, zero=True):
        """
        Args:
            action: Action defining the scaling. Called on a configuration
                and may return a complex value.
            key: Initial PRNG key.
            wscale: Scaling of the width.
            hidden: Number of hidden layers.
            K: Number of SU(N) degrees of freedom.
            N: Number of colors.
            zero: If ``True``, initialize the network so that the output is
                always 0.
        """
        self.action = action
        self.K = K
        self.N = N
        self.algebra = SpecialUnitaryAlgebra(N)
        w = wscale * K * N**2
        # Input: real and imaginary parts of all K matrices.
        # Output: one su(N) component per tangent direction.
        widths = [2*K*N**2] + [w]*hidden + [K*(N**2-1)]
        self.mlp = MLP(key, widths, zero=zero)

    def vector(self, U):
        """
        The (unscaled) vector field underlying the exact form.

        Args:
            U: Point at which to evaluate the vector field.

        Returns:
            A length ``K*(N**2-1)`` vector of su(N) components.
        """
        x = jnp.concatenate([U.flatten().real, U.flatten().imag])
        return self.mlp(x)

    def __call__(self, U):
        """
        Compute an exact form.

        Args:
            U: Point at which to evaluate the form. Assumes U holds K N-by-N
                matrices in its trailing two axes. TODO: confirm against callers.

        Returns:
            A function f that, when multiplied by exp(-action), yields the
            divergence of a vector field.
        """
        # The dimension of the vector field.
        D = self.algebra.dimension*self.K

        def perturb(x):
            # x holds K elements of the Lie algebra. Move to first order along
            # the group, U exp(iH) ~ U + i U H.
            H = jax.vmap(self.algebra)(x.reshape((self.K, self.algebra.dimension)))
            return U + 1j*U@H.reshape(U.shape)

        # Allow self.vector to be differentiated.
        def vec(x):
            return self.vector(perturb(x))

        # Get the vector at U.
        v = self.vector(U)

        # Evaluate the ith component of the divergence.
        def div_part(i):
            def vec_component(x_i):
                x = jnp.zeros(D)
                x = x.at[i].set(x_i)
                return vec(x)[i]
            return jax.grad(vec_component)(0.)

        # There are two terms. The first is the divergence of the vector. The
        # second (with a sign flip) is the vector dotted with the derivatives
        # of the action along the same D generator directions.
        div = jnp.sum(jax.vmap(div_part)(jnp.arange(D)))
        # jacfwd rather than grad: the action may be complex-valued.
        dS = jax.jacfwd(lambda x: self.action(perturb(x)))(jnp.zeros(D))
        return div - jnp.dot(v, dS)
def random_hermitian(key, N, sigma, traceless=False):
    """
    Generates a random Hermitian matrix.

    The real and imaginary parts of an N-by-N matrix are drawn independently
    from a normal distribution with standard deviation ``sigma``. The matrix is
    then symmetrized as ``(H + H^dagger) / 2``.

    Args:
        key: A PRNG key to consume.
        N: dimension of the matrix
        sigma: standard deviation of the entries before symmetrization. It is
            not validated, so it may be a traced JAX value.
        traceless: If `True`, the trace part is projected out so the result is
            traceless. Defaults to `False`.

    Returns:
        An N-by-N Hermitian matrix.
    """
    HR, HI = jax.random.normal(key, shape=(2, N, N))*sigma
    H = HR + 1j*HI
    H = (H + H.conj().transpose()) / 2.
    if traceless:
        # Remove (tr H / N) * identity. Subtracting tr(H) * identity would
        # leave a trace of (1 - N) tr H.
        H = H - jnp.eye(N)*(H.trace()/N)
    return H
def random_unitary(key, N, sigma, special=True):
    """
    Generates a random unitary matrix as ``exp(-i H)`` for a random Hermitian
    generator ``H``.

    Args:
        key: A PRNG key to consume.
        N: dimension of group
        sigma: spread of the Hermitian generator (inverse concentration near
            the identity). It is not validated.
        special: If `True` (the default), the generator is requested
            traceless, so that the exponential has unit determinant.

    Returns:
        A matrix in the group U(N) or SU(N).
    """
    generator = random_hermitian(key, N, sigma, traceless=special)
    # Exponentiate through the eigendecomposition: V diag(e^{-i w}) V^dagger.
    eigvals, eigvecs = jnp.linalg.eigh(generator)
    phases = jnp.exp(-1j*eigvals)
    return eigvecs @ jnp.diag(phases) @ eigvecs.conj().transpose()
def haar_unitary(key, N, special=True):
    """
    Generates a random unitary matrix, sampled from the Haar measure.

    Uses the QR decomposition of a complex Ginibre matrix, with the phases of
    R's diagonal absorbed into Q (Mezzadri's construction). Exponentiating a
    very wide random Hermitian matrix does not give Haar samples: its wrapped
    eigenphases are nearly independent, whereas Haar eigenphases repel.

    For ``special=True`` the U(N) sample is divided by an N-th root of its
    determinant. That map commutes with left multiplication by SU(N), so the
    result is Haar distributed on SU(N).

    Args:
        key: A PRNG key to consume.
        N: dimension of group
        special: If ``True`` (the default), the determinant is constrained to be 1.

    Returns:
        A matrix in the group U(N) or SU(N)
    """
    key_re, key_im = jax.random.split(key)
    Z = (jax.random.normal(key_re, (N, N))
         + 1j*jax.random.normal(key_im, (N, N))) / jnp.sqrt(2.)
    Q, R = jnp.linalg.qr(Z)
    d = jnp.diagonal(R)
    # Q @ diag(d/|d|): makes the decomposition unique, and hence Q Haar distributed.
    Q = Q * (d / jnp.abs(d))
    if special:
        Q = Q * jnp.exp(-1j*jnp.angle(jnp.linalg.det(Q))/N)
    return Q
def bootstrap(xs, ws=None, N=100, Bs=50, rng=None):
    """
    Compute bootstrapped error bars.

    The samples are first grouped into ``Bs`` contiguous blocks of equal size.
    Any leftover samples at the end are dropped. Each block is reduced to its
    weighted mean and total weight, and the blocks are then resampled with
    replacement ``N`` times.

    If `xs` is complex, then the real part of the returned error represents the
    error on the real part of the mean, and ditto for the imaginary part.

    Args:
        xs: Samples (1-D sequence or array; lists are accepted).
        ws: Weights, same length as `xs`. Defaults to all ones.
        N: Number of resamplings to perform.
        Bs: Number of blocks to use (capped at the number of samples).
        rng: Source of randomness with a ``choice`` method, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state.

    Returns:
        An ordered pair consisting of the reweighted mean and its estimated
        standard deviation (``std(real) + 1j*std(imag)`` of the resampled means).

    Raises:
        ValueError: If `xs` is empty, `Bs` < 1, or `ws` has a different length.
    """
    xs = np.asarray(xs)
    if len(xs) == 0:
        raise ValueError('bootstrap requires at least one sample')
    ws = np.ones(len(xs)) if ws is None else np.asarray(ws)
    if len(ws) != len(xs):
        raise ValueError('xs and ws must have the same length')
    if Bs < 1:
        raise ValueError('Bs must be at least 1')
    if rng is None:
        rng = np.random

    # Block.
    Bs = min(Bs, len(xs))
    B = len(xs)//Bs
    xb = xs[:Bs*B].reshape(Bs, B)
    wb = ws[:Bs*B].reshape(Bs, B)
    w = wb.sum(axis=1)
    x = (xb*wb).sum(axis=1) / w

    # Regular bootstrap over blocks: each row of `s` is one resampling.
    y = x * w
    m = y.sum() / w.sum()
    s = rng.choice(Bs, size=(N, Bs))
    ms = y[s].sum(axis=1) / w[s].sum(axis=1)
    return m, np.std(ms.real) + 1j*np.std(ms.imag)
@dataclass
class Lattice:
    """
    Lattice geometry: a hypercubic lattice of side ``L`` in ``D`` dimensions,
    with periodic wrapping in :meth:`step`.

    Sites are numbered lexicographically with direction 0 varying fastest. The
    ``mu`` coordinate of site ``s`` is ``(s // L**mu) % L``.

    Args:
        L: length of one side of the lattice
        D: number of spacetime dimensions
    """
    L: int
    D: int = 4

    def volume(self):
        """
        Returns the total number of sites in the lattice.
        """
        return self.L**self.D

    def coord(self, site, mu):
        """
        Compute the coordinates of a site.

        Args:
            site: The index of the site in question.
            mu: The coordinate to compute.

        Returns:
            The `mu` coordinate of the specified site.

        Raises:
            ValueError: if `mu` isn't a valid direction
        """
        if not mu < self.D:
            raise ValueError(f'invalid direction {mu}')
        return (site%(self.L**(mu+1))) // (self.L**mu)

    def step(self, site, mu, dist):
        """
        Compute the index of one site, from a starting site and an offset.

        Args:
            site: The site to start from.
            mu: The direction to walk in.
            dist: Number of steps (in the positive direction) to walk.

        Returns:
            The index of the specified site (wrapping periodically).

        Raises:
            ValueError: if `mu` isn't a valid direction.
        """
        if not mu < self.D:
            raise ValueError(f'invalid direction {mu}')
        # First identify the component of `site` in the `mu` direction.
        imu = self.coord(site, mu)
        # The index of the site that is `0` in the `mu` direction.
        base = site - (imu * self.L**mu)
        # Shift imu, wrapping around the lattice.
        imu += dist
        imu %= self.L
        return base + (imu * self.L**mu)

    def space(self, t):
        """
        Obtain a single spatial slice of the lattice.

        Args:
            t: The time (direction-0) coordinate, taken modulo ``L``.

        Returns:
            An array of the ``L**(D-1)`` site indices whose time coordinate is
            ``t``, in increasing order.
        """
        # Direction 0 varies fastest, so the sites at time t are t, t+L, t+2L, ...
        return t % self.L + self.L * jnp.arange(self.L**(self.D-1))

    def flat(self, N):
        """
        Produce a flat gauge configuration.

        Args:
            N: number of colors

        Returns:
            A rank-4 array of shape (volume, D, N, N) containing a gauge
            configuration consisting only of the identity.
        """
        return jnp.tile(jnp.eye(N),(self.volume(),self.D,1,1))
def plaquette(lattice, U, site, mu, nu):
    """
    Compute the trace of a 1x1 plaquette.

    The plaquette is the ordered product around the elementary square
    x -> x+mu -> x+mu+nu -> x+nu -> x::

        P = U_mu(x) U_nu(x+mu) U_mu(x+nu)^dag U_nu(x)^dag

    Its trace is invariant under ``U_mu(x) -> g(x) U_mu(x) g(x+mu)^dag``. That
    is the transformation under which hopping terms of the form
    ``D[x, x+mu] ~ U_mu(x)`` are covariant.

    Args:
        lattice: The lattice geometry.
        U: A configuration, indexed as ``U[site, direction, :, :]``.
        site: The site from which the plaquette originates.
        mu: First axis.
        nu: Second axis.

    Returns:
        The trace of the plaquette in the fundamental representation.

    Raises:
        ValueError: If `mu == nu`.
    """
    if mu == nu:
        raise ValueError('plaquette directions must differ')
    # Shifted sites
    site_mu = lattice.step(site,mu,1)
    site_nu = lattice.step(site,nu,1)
    # Involved links
    U1 = U[site,mu,:,:]                          # U_mu(x)
    U2 = U[site_mu,nu,:,:]                       # U_nu(x+mu)
    U3 = U[site_nu,mu,:,:].conj().transpose()    # U_mu(x+nu)^dag
    U4 = U[site,nu,:,:].conj().transpose()       # U_nu(x)^dag
    # Path-ordered product and trace. The reversed order U4 U3 U2 U1 is not
    # gauge invariant under the convention above.
    return (U1 @ U2 @ U3 @ U4).trace()
def action_gauge(lattice, U, g, N):
    """
    Compute the gauge part of the action (Wilson plaquette action).

    With the standard Wilson normalization::

        S = (beta / N) * sum_x sum_{mu > nu} (N - Re tr P_{mu nu}(x)),
        beta = 2 N / g**2

    For unitary links ``Re tr P <= N``, so S is non-negative and vanishes on
    the flat configuration. The Boltzmann weight exp(-S) therefore favors
    smooth configurations.

    Args:
        lattice: The lattice geometry.
        U: A gauge configuration.
        g: The gauge coupling.
        N: Number of colors.

    Returns:
        The gauge part of the Wilson action.
    """
    beta = 2*N / g**2

    def action_gauge_at(site):
        S = 0.
        for mu in range(lattice.D):
            # nu < mu: each plane is counted once.
            for nu in range(mu):
                S += (beta/N)*(N - plaquette(lattice, U, site, mu, nu).real)
        return S

    # This is a sum over sites, so list sites.
    sites = jnp.arange(lattice.volume())
    return jnp.sum(jax.vmap(action_gauge_at)(sites))
def pauli_matrices():
    """
    Obtain Pauli (sigma) matrices.

    Returns:
        A list ``[identity, sigma_x, sigma_y, sigma_z]`` of 2-by-2 arrays.
    """
    identity = jnp.eye(2)
    sigma_x = jnp.array([[0, 1],
                         [1, 0]])
    sigma_y = jnp.array([[0, -1j],
                         [1j, 0]])
    sigma_z = jnp.array([[1, 0],
                         [0, -1]])
    return [identity, sigma_x, sigma_y, sigma_z]
def dirac_matrices():
    """
    Obtain Dirac (gamma) matrices.

    Built from Kronecker products of Pauli matrices:
    ``gamma0 = sigma_x (x) 1`` and ``gamma_k = (-sigma_y) (x) sigma_k``.

    Returns:
        A list of 4 gamma matrices, each 4-by-4.
    """
    one, sx, sy, sz = pauli_matrices()
    minus_sy = -sy
    spatial = [jnp.kron(minus_sy, s) for s in (sx, sy, sz)]
    return [jnp.kron(sx, one)] + spatial
class LatticeFermions(abc.ABC):
    """
    Common interface for lattice fermion discretizations.

    This is an abstract base class and cannot be instantiated directly.
    Subclasses provide :meth:`dirac` and :meth:`density`, and inherit
    determinant and propagator helpers built on :meth:`dirac`.
    """

    def __init__(self, lattice):
        """
        Args:
            lattice: The underlying geometry.
        """
        self.lattice = lattice

    @abc.abstractmethod
    def dirac(self, U):
        """
        Obtain the Dirac operator.

        Args:
            U: Background gauge configuration.

        Returns:
            The Dirac matrix (inverse propagator) in the given gauge background.
        """

    @abc.abstractmethod
    def density(self, Dinv):
        """
        Estimate the density.

        Args:
            Dinv: inverse Dirac operator

        Returns:
            The average density.
        """

    def det(self, U):
        """
        Evaluate the determinant of the Dirac matrix.

        On reasonably sized lattices this is likely to overflow; prefer
        :meth:`slogdet`.

        Args:
            U: Background gauge configuration.

        Returns:
            A complex number with the determinant of the Dirac matrix.
        """
        return jnp.linalg.det(self.dirac(U))

    def slogdet(self, U):
        """
        Evaluate the determinant of the Dirac matrix as a phase and a logarithm.

        Args:
            U: Background gauge configuration.

        Returns:
            A pair `(z,f)`, with `z` a unit-magnitude complex number giving the
            phase of the determinant, and `f` the real part of the logarithm
            of the determinant.
        """
        return jnp.linalg.slogdet(self.dirac(U))

    def propagator(self, U):
        """
        Evaluate the inverse of the Dirac matrix.

        Args:
            U: Background gauge configuration.

        Returns:
            The inverse of the Dirac matrix.
        """
        return jnp.linalg.inv(self.dirac(U))
class NaiveFermions(LatticeFermions):
    """
    Naive lattice fermions, with doublers not treated.

    Only the mass term of the Dirac operator is implemented so far.
    """

    def __init__(self, lattice, N:int, mass:float, mu:float=0.):
        """
        Args:
            lattice: The underlying geometry.
            N: Number of colors.
            mass: Bare fermion mass.
            mu: Chemical potential.
        """
        super().__init__(lattice)
        self.N = N
        self.mass = mass
        self.mu = mu
        # Euclidean gamma matrices, and gamma5 = gamma0 gamma1 gamma2 gamma3.
        self.gamma = dirac_matrices()
        g0, g1, g2, g3 = self.gamma
        self.gamma5 = g0 @ g1 @ g2 @ g3

    def dirac(self, U):
        """
        Build the (incomplete) naive Dirac matrix.

        Args:
            U: Background gauge configuration (not used yet).

        Returns:
            A square matrix of size 4*N*V, with rows and columns ordered as
            (site, color, spin). Only the mass term is filled in.
        """
        V = self.lattice.volume()
        sites = jnp.arange(V)
        D = jnp.zeros((V, V, self.N, self.N, 4, 4))

        # Mass along the diagonal: one 4x4 spin block per (site, color).
        mass_block = self.mass*jnp.eye(4)
        for color in range(self.N):
            D = D.at[sites, sites, color, color].set(mass_block)

        # TODO: time-like hopping (includes chemical potential).
        # TODO: antiperiodic boundary conditions.
        # TODO: spatial hoppings.

        size = 4*self.N*V
        return D.transpose((0, 2, 4, 1, 3, 5)).reshape((size, size))

    def density(self, Dinv):
        raise NotImplementedError()
class StaggeredFermions(LatticeFermions):
    """
    Kogut-Susskind staggered fermions.

    Direction 0 is Euclidean time. Temporal hoppings carry the chemical
    potential and antiperiodic boundary conditions.
    """

    def __init__(self, lattice, N:int, mass:float, mu:float=0.):
        """
        Args:
            lattice: The underlying geometry.
            N: Number of colors.
            mass: Bare fermion mass.
            mu: Chemical potential.
        """
        super().__init__(lattice)
        self.N = N
        self.mass = mass
        self.mu = mu

    def dirac(self, U):
        """
        Build the staggered Dirac matrix in a gauge background.

        ``D[x, x] = m``. Each link from x to x+nu contributes two entries:

        - ``+1/2 eta_nu(x) c_nu(x) U_nu(x)`` at ``D[x, x+nu]``;
        - ``-1/2 eta_nu(x) c'_nu(x) U_nu(x)^dag`` at ``D[x+nu, x]``.

        Here ``eta_nu(x) = (-1)**(x_0 + ... + x_{nu-1})``. For spatial nu,
        ``c = c' = 1``. For nu = 0, ``c = e^{+mu}`` and ``c' = e^{-mu}``, with an
        extra minus sign on the links from t = L-1 to t = 0 (antiperiodic in time).

        Args:
            U: Background gauge configuration, indexed as U[site, direction, :, :].

        Returns:
            A square matrix of size N*V, with rows and columns ordered as
            (site, color).
        """
        lat = self.lattice
        V = lat.volume()
        sites = jnp.arange(V)
        D = jnp.zeros((V, V, self.N, self.N)) + 0j

        # Mass along the diagonal.
        D = D.at[sites, sites].set(self.mass*jnp.eye(self.N))

        # Hoppings. eta accumulates (-1)**x_j over the directions j < nu.
        eta = jnp.ones(V)
        for nu in range(lat.D):
            fwd = jax.vmap(lambda x: lat.step(x, nu, 1))(sites)
            coord_nu = jax.vmap(lambda x: lat.coord(x, nu))(sites)
            w_fwd = 0.5*eta
            w_bwd = -0.5*eta
            if nu == 0:
                # Chemical potential, and a sign flip on links that wrap
                # around in time (both hopping directions).
                bc = jnp.where(coord_nu == lat.L - 1, -1., 1.)
                w_fwd = w_fwd*bc*jnp.exp(self.mu)
                w_bwd = w_bwd*bc*jnp.exp(-self.mu)
            link = U[sites, nu, :, :]
            D = D.at[sites, fwd].add(w_fwd[:, None, None]*link)
            D = D.at[fwd, sites].add(w_bwd[:, None, None]*link.conj().transpose((0, 2, 1)))
            eta = eta*(1 - 2*(coord_nu % 2))

        D = D.transpose((0,2,1,3))
        D = D.reshape((self.N*V, self.N*V))
        return D

    def density(self, Dinv):
        """
        Estimate the density as tr(Dinv) / V.

        NOTE(review): the mass enters D as m * identity, so tr(D^-1) is
        d(log det D)/dm, the chiral-condensate estimator, rather than
        d(log det D)/d(mu). A number density would need dD/dmu, which this
        signature does not receive. Behavior is kept as-is pending confirmation.
        """
        return Dinv.trace()/self.lattice.volume()
class WilsonFermions(LatticeFermions):
    """
    Wilson fermions.

    Not implemented yet: :meth:`dirac` and :meth:`density` raise
    ``NotImplementedError``.
    """

    def __init__(self, lattice, N:int, mass:float, mu:float=0.):
        """
        Args:
            lattice: The underlying geometry.
            N: Number of colors.
            mass: Bare fermion mass.
            mu: Chemical potential.
        """
        super().__init__(lattice)
        self.N = N
        self.mass = mass
        self.mu = mu
        # Initialize Euclidean gamma matrices.
        self.gamma = dirac_matrices()
        self.gamma5 = self.gamma[0] @ self.gamma[1] @ self.gamma[2] @ self.gamma[3]

    def dirac(self, U):
        # The unreachable placeholder code that used to follow this raise
        # has been removed.
        raise NotImplementedError()

    def density(self, Dinv):
        raise NotImplementedError()
class RootedFermions(LatticeFermions):
    """
    Rooted staggered fermions.

    Placeholder only: the constructor and every method raise
    ``NotImplementedError``.
    """

    def __init__(self, flavors, *args):
        """
        Args:
            flavors: Number of flavors
            args: Additional arguments used to define the staggered fermions.
        """
        raise NotImplementedError()

    def dirac(self, U):
        """Not implemented."""
        raise NotImplementedError()

    def density(self, Dinv):
        """Not implemented."""
        raise NotImplementedError()

    def det(self, U):
        """Not implemented."""
        raise NotImplementedError()

    def slogdet(self, U):
        """Not implemented."""
        raise NotImplementedError()
class Metropolis:
    """
    Markov chain using the Metropolis algorithm.

    Proposals come from a user-supplied function. A proposal is accepted with
    probability ``min(1, exp(-(S' - S)))``, using the real part of the action.
    """

    def __init__(self, x0, action, propose, key):
        """
        Args:
            x0: Initial field configuration.
            action: The action
            propose: A function ``propose(key, x, delta)`` returning a
                proposed configuration.
            key: PRNG key for generating proposals.
        """
        self.x = x0
        self.action = action
        self.propose = propose
        self.delta = 1.
        self._key = key
        # Acceptance history (most recent last), trimmed to 100 entries.
        self._recent = [False]

    def step(self, N=1):
        """
        Advance the chain.

        Args:
            N: The number of steps to take before returning a configuration.
        """
        logging.info(f'taking {N} steps')
        self.S = self.action(self.x).real
        for _ in range(N):
            kprop, kacc, self._key = jax.random.split(self._key, 3)
            candidate = self.propose(kprop, self.x, self.delta)
            S_candidate = self.action(candidate).real
            accepted = bool(jax.random.uniform(kacc) < jnp.exp(-(S_candidate - self.S)))
            if accepted:
                self.x = candidate
                self.S = S_candidate
            self._recent.append(accepted)
            self._recent = self._recent[-100:]

    def calibrate(self):
        """
        Calibrate the chain, tweaking `delta` until the acceptance rate lies
        between 0.3 and 0.55.
        """
        logging.info('beginning calibration')
        self.step(N=100)
        while not 0.3 <= self.acceptance_rate() <= 0.55:
            rate = self.acceptance_rate()
            logging.info(f'calibrating (acceptance rate is {rate})')
            # Smaller steps raise the acceptance rate; larger steps lower it.
            if rate < 0.3:
                self.delta *= 0.98
            elif rate > 0.55:
                self.delta *= 1.02
            self.step(N=100)

    def acceptance_rate(self):
        """
        Estimates the acceptance rate.

        Returns:
            The acceptance rate over the last 100 steps.
        """
        return self._recent.count(True) / len(self._recent)

    def iter(self, skip=1):
        """
        An infinite iterator yielding configurations, taking `skip` steps
        between yields.
        """
        while True:
            self.step(N=skip)
            yield self.x
class HMC:
    """
    Hamiltonian Monte Carlo.

    Not implemented yet. :meth:`step` and :meth:`iter` raise
    ``NotImplementedError``, rather than silently doing nothing or (in the
    case of ``iter``) spinning forever in an empty loop.
    """

    def __init__(self):
        """
        Create an (as yet unconfigured) HMC sampler. Takes no arguments.
        """

    def step(self, N=1):
        """
        Take `N` HMC steps (not implemented).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('HMC.step is not implemented yet')

    def iter(self):
        """
        Iterate over configurations (not implemented).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('HMC.iter is not implemented yet')
def main(args):
    """
    Command-line driver.

    Samples gauge configurations with Metropolis, then either prints
    observables or enters the (not yet implemented) training loop for the
    subtraction.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` block).
    """
    import time  # only needed for --seed-time; not imported at module level

    if args.verbose:
        logging.basicConfig(level=logging.INFO)
    seed = args.seed
    if args.seed_time:
        # Keep the seed inside the int32 range so PRNGKey accepts it
        # whether or not 64-bit mode is enabled.
        seed = time.time_ns() % 2**31
    mckey, subkey = jax.random.split(jax.random.PRNGKey(seed))
    N = args.colors
    lattice = Lattice(args.L)
    fermions = [StaggeredFermions(lattice, N, args.mass, args.mu)]

    @jax.jit
    def action(U):
        Sgauge = action_gauge(lattice, U, args.g, N)
        Sfermi = 0.
        for fermi in fermions:
            # det D = s * exp(lndet); the fermion weight det D enters the
            # action as -log det D.
            s, lndet = fermi.slogdet(U)
            Sfermi -= lndet+jnp.log(s)
        return Sgauge + Sfermi

    @jax.jit
    def observe(U):
        # Mean plaquette.
        def mean_plaquette_at(site):
            P = 0.
            for mu in range(lattice.D):
                for nu in range(mu):
                    P += plaquette(lattice, U, site, mu, nu)
            return P
        sites = jnp.arange(lattice.volume())
        mean_plaquette = jnp.mean(jax.vmap(mean_plaquette_at)(sites))
        # Fermion density.
        density = 0.
        for fermi in fermions:
            density += fermi.density(fermi.propagator(U))
        # Fermion correlators.
        # TODO
        return mean_plaquette, density

    @jax.jit
    def propose(key, U, delta):
        delta /= jnp.sqrt(lattice.volume())
        # One random matrix per link: lattice.D * lattice.volume() keys.
        keys = jax.random.split(key, lattice.D * lattice.volume())
        V = jax.vmap(lambda k: random_unitary(k, N, delta))(keys)
        V = V.reshape(U.shape)
        return jnp.einsum('xiab,xibc->xiac', U, V)

    # Initialize/load the subtraction: one SU(N) degree of freedom per link.
    # Pass N so the network matches the number of colors actually in use.
    subtraction = ExactForm(subkey, 1, 1, lattice.D*lattice.volume(), N=N)
    if not args.init:
        try:
            subtraction = eqx.tree_deserialise_leaves(args.subtraction, subtraction)
        except FileNotFoundError:
            pass  # No saved subtraction yet; keep the fresh initialization.

    # Prepare the Monte Carlo.
    chain = Metropolis(lattice.flat(N), action, propose, mckey)
    chain.calibrate()

    if args.train:
        # Parameters will be saved after every gradient step.
        def save():
            eqx.tree_serialise_leaves(args.subtraction, subtraction)

        # Prepare the optimizer.
        opt = optax.yogi(1e-3)
        opt_state = opt.init(eqx.filter(subtraction, eqx.is_array))

        # Stochastic loss functions
        def Seff_real():
            raise NotImplementedError()

        # Train the subtraction. Not implemented: this saves once and then
        # raises NotImplementedError.
        try:
            while True:
                save()
                raise NotImplementedError()
        except KeyboardInterrupt:
            pass
    else:
        # Don't train---just measure.
        try:
            for U in chain.iter():
                S = action(U)
                obs = np.array(observe(U))
                print(np.exp(-1j*S.imag), obs)
        except KeyboardInterrupt:
            pass
if __name__ == '__main__':
    # Build the command-line interface and hand the parsed arguments to main().
    parser = argparse.ArgumentParser(description='Lattice QCD at finite density',
                                     fromfile_prefix_chars='@')
    parser.add_argument('-N', '--colors', default=3, type=int,
                        help='Number of colors (3)')
    parser.add_argument('-i', '--init', action='store_true',
                        help='re-initialize the subtraction')
    parser.add_argument('-t', '--train', action='store_true',
                        help='train the subtraction')
    # Positional arguments, in command-line order.
    for arg_name, arg_type, arg_help in (
        ('subtraction', str, 'subtraction filename'),
        ('L', int, 'Lattice size'),
        ('g', float, 'Gauge coupling'),
        ('mass', float, 'Bare fermion mass'),
        ('mu', float, 'Chemical potential'),
    ):
        parser.add_argument(arg_name, type=arg_type, help=arg_help)
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Verbose mode')
    # A fixed seed and a time-based seed are mutually exclusive.
    seed_group = parser.add_mutually_exclusive_group()
    seed_group.add_argument('--seed', type=int, default=0, help="random seed")
    seed_group.add_argument('--seed-time', action='store_true',
                            help="seed PRNG with current time")
    args = parser.parse_args()
    main(args)