"""Wrapper for matvis visibility simulator."""
from __future__ import annotations
import logging
import numpy as np
from pyuvdata import BeamInterface
from pyuvdata import utils as uvutils
from ..utils import get_antpos_dict
from .matvis import MatVis
from .simulators import ModelData, VisibilitySimulator
logger = logging.getLogger(__name__)
try:
import fftvis
from matvis.core.beams import prepare_beam_unpolarized
HAVE_FFTVIS = True
except ImportError: # pragma: no cover
HAVE_FFTVIS = False
fftvis = None
[docs]
class FFTVis(VisibilitySimulator):
"""
fftvis visibility simulator.
This is a fast visibility simulator based on the Flatiron Non-Uniform Fast Fourier
Transform (https://github.com/flatironinstitute/finufft). This class calls the
fftvis package (https://github.com/tyler-a-cox/fftvis) which utilizes the finufft
algorithm to evaluate the measurement equation by gridding and fourier transforming
an input sky model. The simulated visibilities agree with matvis to high precision,
and are often computed more quickly than matvis. FFTVis is particularly well-suited
for simulations with compact arrays with large numbers of antennas, and sky models
with many sources.
Parameters
----------
precision : int, optional
Which precision level to use for floats and complex numbers.
Allowed values:
- 1: float32, complex64
- 2: float64, complex128
mpi_comm : MPI communicator
MPI communicator, for parallelization.
check_antenna_conjugation
Whether to check the antenna conjugation. Default is True. This is a fairly
heavy operation if there are many antennas and/or many times, and can be
safely ignored if the data_model was created from a config file.
**kwargs
Passed through to `:func:fftvis.SimulationEngine.simulate` function.
"""
conjugation_convention = "ant1<ant2"
time_ordering = "time"
_functions_to_profile = (
fftvis.CPUSimulationEngine.simulate if HAVE_FFTVIS else None,
)
diffuse_ability = False
__version__ = "1.0.0" # Fill in the version number here
def __init__(
self,
*,
precision: int = 2,
mpi_comm=None,
check_antenna_conjugation: bool = True,
**kwargs,
):
if not HAVE_FFTVIS:
raise ImportError(
"fftvis is not installed. Please install fftvis to use FFTVis."
)
assert precision in {1, 2}
self._precision = precision
if precision == 1:
self._real_dtype = np.float32
self._complex_dtype = np.complex64
else:
self._real_dtype = float
self._complex_dtype = complex
self.mpi_comm = mpi_comm
self.check_antenna_conjugation = check_antenna_conjugation
self.kwargs = kwargs
def _check_if_polarized(self, data_model: ModelData) -> bool:
p = data_model.uvdata.polarization_array
# We only do a non-polarized simulation if UVData has only XX or YY polarization
return len(p) != 1 or uvutils.polnum2str(p[0]) not in ["xx", "yy"]
[docs]
def validate(self, data_model: ModelData):
"""Checks for correct input format."""
logger.info("Checking baseline-time axis shape")
if not data_model.uvdata.blts_are_rectangular:
raise ValueError("FFTVis requires that every baseline uses the same LSTS.")
if self.check_antenna_conjugation:
logger.info("Checking antenna conjugation")
# TODO: the following is extremely slow. If possible, it would be good to
# find a better way to do it.
if any(
data_model.uvdata.antpair2ind(ai, aj) is not None
and data_model.uvdata.antpair2ind(aj, ai) is not None
for ai, aj in data_model.uvdata.get_antpairs()
if ai != aj
):
raise ValueError(
"FFTVis requires that baselines be in a conjugation in which "
"antenna order doesn't change with time!"
)
beam_interface = data_model.beams[0] # Representative beam
uvdata = data_model.uvdata
# Now check that we only have linear polarizations (don't allow pseudo-stokes)
if any(pol not in [-5, -6, -7, -8] for pol in uvdata.polarization_array):
raise ValueError(
"""
While UVData allows non-linear polarizations, they are not suitable
for generating simulations. Please convert your UVData object to use
linear polarizations before simulating (and convert back to
other polarizations afterwards if necessary).
"""
)
if self._check_if_polarized(data_model) and beam_interface.Nfeeds != 2:
raise ValueError(
"FFTVis requires that the beams have two feeds if simulating polarized "
"visibilities."
)
[docs]
def estimate_memory(self, data_model: ModelData) -> float:
"""
Estimates the memory usage of the model.
Parameters
----------
data_model : ModelData
The model data.
Returns
-------
float
Estimated memory usage in GB.
"""
bm: BeamInterface = data_model.beams[0]
nt = len(data_model.lsts)
nax = 2 if bm.beam_type=="efield" else 1
nfd = bm.Nfeeds
nant = data_model.uvdata.Nants_data
nsrc = len(data_model.sky_model.ra)
nbeam = len(data_model.beams)
nf = len(data_model.freqs)
# Estimate size of the FFT grid used to compute the visibilities
active_antpos_array, _ = data_model.uvdata.get_enu_data_ants()
# Estimate the size of the grid used to compute the visibilities
max_blx, max_bly, _ = np.abs(
active_antpos_array.max(axis=0) - active_antpos_array.min(axis=0)
)
avg_freq = np.mean(data_model.freqs)
n_gridx = int(8 * avg_freq * max_blx / 3e8) # number of grid points in u/l axis
n_gridy = int(8 * avg_freq * max_bly / 3e8) # number of grid points in v/m axis
try:
nbmpix = bm.beam.data_array[..., 0, :].size
except AttributeError:
nbmpix = 0
all_floats = (
nf * nt * nfd**2 * nant**2 # visibilities
+ n_gridx * n_gridy # FFT grid
+ nf * nbeam * nbmpix # raw beam
+ nax * nfd * nbeam * nsrc / 2 # interpolated beam
+ 3 * nant # antenna positions
+ nsrc * nf # source fluxes
+ nt * 9 # rotation matrices
+ 3 * nsrc
+ 3 * nsrc # source positions (topo and eq)
)
return all_floats * self._precision * 4 / 1024**3
[docs]
def get_feed(self, uvdata) -> str:
"""Get the feed to use from the beam, given the UVData object.
Only applies for an *unpolarized* simulation (for a polarized sim, all feeds
are used).
"""
return uvutils.polnum2str(uvdata.polarization_array[0])[0]
@staticmethod
def _get_req_pols(uvdata, uvbeam, polarized: bool) -> list[tuple[int, int]]:
return MatVis._get_req_pols(uvdata, uvbeam, polarized)
[docs]
def simulate(self, data_model):
"""
Calls :func:`fftvis` to perform the visibility calculation.
Returns
-------
array_like of self._complex_dtype
Visibilities. Shape=self.uvdata.data_array.shape.
"""
polarized = self._check_if_polarized(data_model)
feed = self.get_feed(data_model.uvdata)
# Setup MPI info if enabled
if self.mpi_comm is not None:
myid = self.mpi_comm.Get_rank()
nproc = self.mpi_comm.Get_size()
ra, dec = data_model.sky_model.ra.rad, data_model.sky_model.dec.rad
active_antpos = get_antpos_dict(data_model.uvdata, data_ants=True)
num2name = {
i: nm for i, nm in zip(
data_model.uvdata.telescope.antenna_numbers,
data_model.uvdata.telescope.antenna_names
)
}
# since pyuvdata v3, get_antpairs always returns antpairs in the right order.
antpairs = data_model.uvdata.get_antpairs()
# Get pixelized beams if required
logger.info("Preparing Beams...")
if not polarized:
beams = [
prepare_beam_unpolarized(beam, use_feed=feed)
for beam in data_model.beams
]
else:
beams = data_model.beams
beam_ids = [data_model.beam_ids[num2name[i]] for i in active_antpos.keys()]
# Get all the polarizations required to be simulated.
req_pols = self._get_req_pols(data_model.uvdata, beams[0], polarized=polarized)
# Empty visibility array
if np.all(data_model.uvdata.data_array == 0):
# Here, we don't make new memory, because that is just a whole extra copy
# of the largest array in the calculation. Instead we fill the data_array
# directly.
visfull = data_model.uvdata.data_array
else:
visfull = np.zeros_like(
data_model.uvdata.data_array, dtype=self._complex_dtype
)
for i, freq in enumerate(data_model.freqs):
# Divide tasks between MPI workers if needed
if self.mpi_comm is not None and i % nproc != myid:
continue
logger.info(f"Simulating Frequency {i + 1}/{len(data_model.freqs)}")
# Call fftvis function to simulate visibilities
vis = fftvis.CPUSimulationEngine().simulate(
ants=active_antpos,
freqs=np.array([freq]),
ra=ra,
dec=dec,
times=data_model.times,
telescope_loc=data_model.uvdata.telescope.location,
beam_list=beams,
beam_idx=beam_ids,
fluxes=data_model.sky_model.stokes[0, [i]].to("Jy").value.T,
beam_spline_opts=data_model.beams.spline_interp_opts,
precision=self._precision,
polarized=polarized,
baselines=antpairs,
**self.kwargs,
)[0]
logger.info("... re-ordering visibilities...")
self._reorder_vis(
req_pols, data_model.uvdata, visfull[:, i], vis, antpairs, polarized
)
# Reduce visfull array if in MPI mode
if self.mpi_comm is not None:
visfull = self._reduce_mpi(visfull, myid)
if visfull is data_model.uvdata.data_array:
# In the case that we were just fulling up the data array the whole time,
# we return zero, because this will be added to the data_array in the
# wrapper simulate() function.
return 0
else:
return visfull
def _reorder_vis(self, req_pols, uvdata, visfull, vis, antpairs, polarized):
if polarized:
if uvdata.time_axis_faster_than_bls:
vis = vis.transpose((3, 0, 1, 2))
else:
vis = vis.transpose((0, 3, 1, 2))
if polarized:
for p, (p1, p2) in enumerate(req_pols):
visfull[:, p] = vis[..., p1, p2].reshape(-1)
else:
visfull[:, 0] = vis.reshape(-1)