# Source code for hera_sim.visibilities.matvis

"""Wrapper for matvis visibility simulator."""

from __future__ import annotations

import itertools
import logging

import numpy as np
from astropy.time import Time

try:
    from matvis import HAVE_GPU, __version__, cpu
    HAVE_MATVIS = True
    if HAVE_GPU:
        from matvis import gpu
except ImportError:  # pragma: no cover
    HAVE_GPU = False
    HAVE_MATVIS = False
    __version__ = None
    cpu = None
    gpu = None

from pyuvdata import BeamInterface, UVData
from pyuvdata import utils as uvutils

from .simulators import ModelData, VisibilitySimulator

# Module-level logger; MatVis reports its progress (validation checks,
# per-frequency simulation, visibility re-ordering) through it.
logger = logging.getLogger(__name__)


class MatVis(VisibilitySimulator):
    """
    matvis visibility simulator.

    This is a fast, matrix-based visibility simulator.

    Parameters
    ----------
    precision : int, optional
        Which precision level to use for floats and complex numbers.
        Allowed values:

        - 1: float32, complex64
        - 2: float64, complex128
    use_gpu : bool, optional
        Whether to use the GPU version of matvis or not. Default: False.
    mpi_comm : MPI communicator, optional
        MPI communicator, for parallelization. Frequencies are distributed
        round-robin over the ranks and the result is reduced onto rank 0.
        Not supported together with ``use_gpu`` when there is more than one
        rank.
    check_antenna_conjugation : bool, optional
        Whether to check the antenna conjugation. Default is True. It can be
        safely skipped if the data_model was created from a config file.
    **kwargs
        Stored on the instance and forwarded verbatim to the matvis
        ``simulate`` function on every call. They must not repeat the
        arguments this class sets itself (``antpos``, ``freq``, ``times``,
        ``polarized``, ``precision``, ...).

    Notes
    -----
    Whether a polarized simulation is performed is decided from the UVData
    object: it is unpolarized only when the UVData object holds exactly one
    polarization and that polarization is ``xx`` or ``yy``. Otherwise all feed
    combinations are simulated. Only the polarizations present on the UVData
    object are stored, and the other simulated polarizations are discarded.

    The "unpolarized" output of ``matvis`` is taken to be the single feed
    named by the UVData polarization (e.g. XX). This corresponds to whatever
    the UVData object considers to be the x-direction.
    """

    conjugation_convention = "ant1<ant2"
    time_ordering = "time"
    diffuse_ability = False
    __version__ = __version__

    def __init__(
        self,
        precision: int = 2,
        use_gpu: bool = False,
        mpi_comm=None,
        check_antenna_conjugation: bool = True,
        **kwargs,
    ):
        if not HAVE_MATVIS:
            raise ImportError(
                "matvis is not installed. Please install matvis to use MatVis."
            )

        assert precision in {1, 2}
        self._precision = precision
        if precision == 1:
            self._real_dtype = np.float32
            self._complex_dtype = np.complex64
        else:
            self._real_dtype = float
            self._complex_dtype = complex

        if use_gpu and mpi_comm is not None and mpi_comm.Get_size() > 1:
            raise RuntimeError("MPI is not yet supported in GPU mode")

        if use_gpu and not HAVE_GPU:
            raise ImportError(
                "GPU acceleration requires installing with `pip install hera_sim[gpu]`."
            )

        self._matvis = gpu.simulate if use_gpu else cpu.simulate
        self.use_gpu = use_gpu
        self.mpi_comm = mpi_comm
        self.check_antenna_conjugation = check_antenna_conjugation
        self._functions_to_profile = (self._matvis,)
        self.kwargs = kwargs

    def validate(self, data_model: ModelData):
        """Check that ``data_model`` is in a form matvis can simulate.

        Raises
        ------
        ValueError
            If not every antenna pair (autos included) is present, if the
            baseline-time axis is not rectangular, if (with
            ``check_antenna_conjugation``) an antenna pair appears in both
            conjugations, or if the UVData object has non-linear polarizations.
        AssertionError
            If a polarized simulation is needed but the first beam does not
            have exactly two feeds.
        """
        uvdata = data_model.uvdata

        # N(N-1)/2 unique cross-correlations + N autocorrelations.
        if uvdata.Nbls != data_model.n_ant * (data_model.n_ant + 1) // 2:
            raise ValueError(
                "MatVis requires using every pair of antennas, "
                "but the UVData object does not comply."
            )

        logger.info("Checking baseline-time axis shape")
        if not uvdata.blts_are_rectangular:
            raise ValueError("MatVis requires that every baseline uses the same LSTS.")

        if self.check_antenna_conjugation:
            logger.info("Checking antenna conjugation")
            # Every pair returned by get_antpairs() is present in the data, so
            # a conjugation conflict exists exactly when the reversed pair is
            # also present. A set lookup replaces one antpair2ind() search
            # per baseline, which made this check very slow.
            present = set(uvdata.get_antpairs())
            if any((aj, ai) in present for ai, aj in present if ai != aj):
                raise ValueError(
                    "MatVis requires that baselines be in a conjugation in which "
                    "antenna order doesn't change with time!"
                )

        beam_interface = data_model.beams[0]  # Representative beam

        # Now check that we only have linear polarizations (don't allow pseudo-stokes)
        if any(pol not in [-5, -6, -7, -8] for pol in uvdata.polarization_array):
            raise ValueError(
                """
                While UVData allows non-linear polarizations, they are not suitable
                for generating simulations. Please convert your UVData object to use
                linear polarizations before simulating (and convert back to
                other polarizations afterwards if necessary).
                """
            )

        do_pol = self._check_if_polarized(data_model)
        if do_pol:
            assert beam_interface.beam.Nfeeds == 2

    def estimate_memory(self, data_model: ModelData) -> float:
        """
        Estimates the memory usage of the model.

        Parameters
        ----------
        data_model : ModelData
            The model data.

        Returns
        -------
        float
            Estimated memory usage in GB.
        """
        bm = data_model.beams[0]
        # Beams may be wrapped (e.g. in a BeamInterface); look at the raw beam.
        beam_obj = getattr(bm, "beam", bm)

        nt = len(data_model.lsts)
        nax = getattr(beam_obj, "Naxes_vec", 1)
        nfd = getattr(beam_obj, "Nfeeds", 1)
        nant = len(data_model.uvdata.get_ants())
        nsrc = len(data_model.sky_model.ra)
        nbeam = len(data_model.beams)
        nf = len(data_model.freqs)

        try:
            nbmpix = beam_obj.data_array[..., 0, :].size
        except AttributeError:
            # Beams without a data_array (e.g. analytic) store no pixels.
            nbmpix = 0

        all_floats = (
            nf * nt * nfd**2 * nant**2  # visibilities
            + nant * nsrc * nax * nfd / 2  # per-antenna vis
            + nf * nbeam * nbmpix  # raw beam
            + nax * nfd * nbeam * nsrc / 2  # interpolated beam
            + 3 * nant  # antenna positions
            + nsrc * nf  # source fluxes
            + nt * 9  # rotation matrices
            + 3 * nsrc  # source positions (topo)
            + 3 * nsrc  # source positions (eq)
            + nant * nsrc / 2  # tau
        )

        # precision 1 -> 4-byte floats, precision 2 -> 8-byte floats.
        return all_floats * self._precision * 4 / 1024**3

    def _check_if_polarized(self, data_model: ModelData) -> bool:
        """Return False only if the UVData object holds a single xx or yy pol."""
        p = data_model.uvdata.polarization_array
        # We only do a non-polarized simulation if UVData has only XX or YY polarization
        return len(p) != 1 or uvutils.polnum2str(p[0]) not in ["xx", "yy"]

    @staticmethod
    def get_feed(uvdata) -> str:
        """Get the feed to use from the beam, given the UVData object.

        Only applies for an *unpolarized* simulation (for a polarized sim,
        all feeds are used).
        """
        return uvutils.polnum2str(uvdata.polarization_array[0])[0]

    def simulate(self, data_model):
        """
        Call the matvis ``simulate`` function to compute the visibilities.

        Returns
        -------
        array_like of self._complex_dtype, or 0
            Visibilities, with shape ``data_model.uvdata.data_array.shape``.
            0 is returned when the visibilities were written directly into
            ``data_model.uvdata.data_array`` (the caller adds the return value
            to that array), and on non-root MPI ranks.
        """
        polarized = self._check_if_polarized(data_model)
        uvdata = data_model.uvdata

        # Setup MPI info if enabled
        if self.mpi_comm is not None:
            myid = self.mpi_comm.Get_rank()
            nproc = self.mpi_comm.Get_size()

        # The following are antenna positions in the order that they are
        # in the uvdata.data_array
        active_antpos, ant_list = uvdata.get_enu_data_ants()
        num2name = dict(
            zip(uvdata.telescope.antenna_numbers, uvdata.telescope.antenna_names)
        )
        beam_ids = np.array([data_model.beam_ids[num2name[i]] for i in ant_list])

        # Get all the polarizations required to be simulated.
        req_pols = self._get_req_pols(
            uvdata, data_model.beams[0], polarized=polarized
        )

        # Empty visibility array
        if np.all(uvdata.data_array == 0):
            # Here, we don't make new memory, because that is just a whole extra copy
            # of the largest array in the calculation. Instead we fill the data_array
            # directly.
            visfull = uvdata.data_array
        else:
            visfull = np.zeros_like(uvdata.data_array, dtype=self._complex_dtype)

        # Index of each antenna number within active_antpos. A dict lookup
        # replaces list.index(), which made this O(Nbls * Nants).
        ant_index = {ant: idx for idx, ant in enumerate(ant_list.tolist())}
        antpairs = np.array(
            [[ant_index[a], ant_index[b]] for a, b in uvdata.get_antpairs()]
        )

        # Loop-invariant: build the astropy Time object once.
        times = Time(data_model.times, format="jd")

        for i, freq in enumerate(data_model.freqs):
            # Divide tasks between MPI workers if needed
            if self.mpi_comm is not None and i % nproc != myid:
                continue

            logger.info(f"Simulating Frequency {i + 1}/{len(data_model.freqs)}")

            # Call matvis function to simulate visibilities
            vis = self._matvis(
                antpos=active_antpos,
                freq=freq,
                times=times,
                skycoords=data_model.sky_model.skycoord,
                telescope_loc=uvdata.telescope.location,
                I_sky=data_model.sky_model.stokes[0, i].to("Jy").value,
                beam_list=data_model.beams,
                beam_idx=beam_ids,
                beam_spline_opts=data_model.beams.spline_interp_opts,
                precision=self._precision,
                polarized=polarized,
                antpairs=antpairs,
                **self.kwargs,
            )

            logger.info("... re-ordering visibilities...")
            self._reorder_vis(req_pols, uvdata, visfull[:, i], vis, polarized)

        filled_in_place = visfull is uvdata.data_array

        if self.mpi_comm is None:
            # When we were filling the data array the whole time, return zero,
            # because the result is added to data_array in the wrapper
            # simulate() function.
            return 0 if filled_in_place else visfull

        # Reduce visfull array in MPI mode.
        reduced = self._reduce_mpi(visfull, myid)
        if filled_in_place and myid == 0:
            # data_array already holds this rank's own channels, and `reduced`
            # holds every rank's channels (including this rank's). Returning
            # `reduced` would make the wrapper add this rank's channels a
            # second time, so store the full result in place and return 0.
            uvdata.data_array[...] = reduced
            return 0
        return reduced

    def _reorder_vis(
        self,
        req_pols: list[tuple[int, int]],
        uvdata: UVData,
        visfull: np.ndarray,
        vis: np.ndarray,
        polarized: bool,
    ):
        """Copy one frequency of matvis output into ``visfull`` (in place).

        ``vis`` holds baselines in ``uvdata.get_antpairs()`` order along axis 1,
        with trailing (feed, feed) axes when ``polarized``. ``visfull`` is the
        (Nblts, Npols) slice of the output for this frequency.
        """
        if (
            uvdata.blts_are_rectangular
            and not uvdata.time_axis_faster_than_bls
            and sorted(req_pols) == req_pols
            # A sorted, *complete* set of feed pairs flattens to the UVData pol
            # order. A sorted subset (e.g. [xx, yy] of a 2x2 polarized result)
            # cannot be reshaped to (Nblts, Npols) and must use the loop below.
            and vis.size == uvdata.Nblts * uvdata.Npols
        ):
            logger.info("Using direct setting of data without reordering")
            # This is the best case scenario -- no need to reorder anything.
            # It is also MUCH MUCH faster!
            # NOTE(review): this assumes the baseline order within each time
            # block of the data matches uvdata.get_antpairs() -- confirm.
            visfull[:] = vis.reshape((uvdata.Nblts, uvdata.Npols))
            return

        logger.info(
            f"Reordering baselines. Pols sorted: {sorted(req_pols) == req_pols}. "
            f"Pols = {req_pols}. blt_order = {uvdata.blt_order}"
        )
        for i, (ant1, ant2) in enumerate(uvdata.get_antpairs()):
            # get all blt indices corresponding to this antpair
            indx = uvdata.antpair2ind(ant1, ant2)
            vis_here = vis[:, i]
            if polarized:
                for p, (p1, p2) in enumerate(req_pols):
                    visfull[indx, p] = vis_here[:, p1, p2]
            else:
                visfull[indx, 0] = vis_here

    @staticmethod
    def _get_req_pols(
        uvdata: UVData, uvbeam: BeamInterface, polarized: bool
    ) -> list[tuple[int, int]]:
        """Map each UVData polarization to a (feed1, feed2) index pair of the beam.

        Raises
        ------
        ValueError
            If a UVData polarization cannot be formed from the beam's feeds
            (polarized case), or the required feed is missing (unpolarized).
        """
        beam_obj = uvbeam.beam
        feeds = uvbeam.feed_array
        if isinstance(feeds, np.ndarray):
            feeds = feeds.tolist()  # convert to list if necessary
        feed_set = {str(feed).lower() for feed in feeds}

        # In order to get all 4 visibility polarizations for a dual feed system
        vispols = set()
        for p1, p2 in itertools.combinations_with_replacement(feeds, 2):
            vispols.add(p1 + p2)
            vispols.add(p2 + p1)
        avail_pols = {
            vispol: (feeds.index(vispol[0]), feeds.index(vispol[1]))
            for vispol in vispols
        }

        # Get the mapping from uvdata pols to uvbeam pols
        if feed_set.issubset({"x", "y"}):
            x_orientation = None
        else:
            x_orientation = beam_obj.get_x_orientation_from_feeds()

        uvdata_pols = [
            uvutils.polnum2str(polnum, x_orientation)
            for polnum in uvdata.polarization_array
        ]

        if not polarized:
            # Use the x_orientation-aware name so beams with e.g. "e"/"n" feeds
            # are matched; for "x"/"y" feeds this equals MatVis.get_feed(uvdata).
            feed = uvdata_pols[0][0]
            return [(feeds.index(feed), feeds.index(feed))]

        if any(pol not in avail_pols for pol in uvdata_pols):
            raise ValueError(
                "Not all polarizations in UVData object are in your beam. "
                f"UVData polarizations = {uvdata_pols}. "
                f"UVBeam polarizations = {list(avail_pols.keys())}"
            )

        return [avail_pols[pol] for pol in uvdata_pols]

    def _reduce_mpi(self, visfull, myid):  # pragma: no cover
        """Sum ``visfull`` over all MPI ranks onto rank 0; other ranks get 0."""
        from mpi4py.MPI import SUM

        _visfull = np.zeros(visfull.shape, dtype=visfull.dtype)
        self.mpi_comm.Reduce(visfull, _visfull, op=SUM, root=0)
        if myid == 0:
            return _visfull
        else:
            return 0  # workers return 0

    def compress_data_model(self, data_model: ModelData):
        """Drop the UVData arrays that matvis does not need.

        Zeroes ``uvw_array`` and collapses ``integration_time`` to its first
        value. ``restore_data_model`` rebuilds both.
        """
        data_model.uvdata.uvw_array = 0
        data_model.uvdata.integration_time = data_model.uvdata.integration_time.item(0)

    def restore_data_model(self, data_model: ModelData):
        """Undo ``compress_data_model``.

        Re-expands ``integration_time`` to one entry per baseline-time and
        recomputes the uvws from the antenna positions.
        """
        uv_obj = data_model.uvdata
        uv_obj.integration_time = np.repeat(
            uv_obj.integration_time, uv_obj.Nbls * uv_obj.Ntimes
        )
        uv_obj.set_uvws_from_antenna_positions()