Source code for hera_sim.utils

"""Utility module."""

from __future__ import annotations

import warnings
from collections.abc import Sequence

import numpy as np
import pyuvdata.utils as uvutils
from astropy import constants, units
from astropy.coordinates import Longitude
from pyuvdata import UVData
from scipy.interpolate import RectBivariateSpline

from .interpolators import Beam

try:
    import numba

    HAVE_NUMBA = True
except ImportError:
    HAVE_NUMBA = False

[docs]def get_antpos_dict( uvd: UVData, *, data_ants: bool = False, frame: Literal[ecef, enu] = "enu", ) -> dict[int, np.ndarray]: """ Get a dictionary of antenna positions from a UVData object. Parameters ---------- uvd UVData object to get antenna positions from. data_ants If True, return only the antennas with data. Otherwise, return all antennas. Returns ------- antpos_dict Dictionary of antenna positions in ENU coordinates. """ if frame not in ["ecef", "enu"]: raise ValueError("frame must be 'ecef' or 'enu'") if data_ants: ants = uvd.get_ants() else: ants = uvd.telescope.antenna_numbers if frame == "ecef": antpos = uvd.telescope.antenna_positions else: antpos = uvd.telescope.get_enu_antpos() antnums = list(uvd.telescope.antenna_numbers) antpos = [ antpos[antnums.index(i)] for i in ants ] return dict(zip(ants, antpos))
def _get_bl_len_vec(bl_len_ns: float | np.ndarray) -> np.ndarray: """ Convert a baseline length in a variety of formats to a standard length-3 vector. Parameters ---------- bl_len_ns The baseline length in nanosec (i.e. 1e9 * metres / c). If scalar, interpreted as E-W length, if len(2), interpreted as EW and NS length, otherwise the full [EW, NS, Z] length. Unspecified dimensions are assumed to be zero. Returns ------- bl_vec A length-3 array. The full [EW, NS, Z] baseline vector. """ if np.isscalar(bl_len_ns): return np.array([bl_len_ns, 0, 0]) elif len(bl_len_ns) <= 3: # make a length-3 array return np.pad(bl_len_ns, pad_width=3 - len(bl_len_ns), mode="constant")[-3:] return bl_len_ns
[docs]def get_bl_len_magnitude(bl_len_ns: float | np.ndarray | Sequence) -> float: """ Get the magnitude of the length of the given baseline. Parameters ---------- bl_len_ns The baseline length in nanosec (i.e. 1e9 * metres / c). If scalar, interpreted as E-W length, if len(2), interpreted as EW and NS length, otherwise the full [EW, NS, Z] length. Unspecified dimensions are assumed to be zero. Returns ------- mag The magnitude of the baseline length. """ bl_len_ns = _get_bl_len_vec(bl_len_ns) return np.sqrt(np.sum(bl_len_ns**2))
[docs]def gen_delay_filter( freqs: np.ndarray, bl_len_ns: float | np.ndarray | Sequence, standoff: float = 0.0, delay_filter_type: str | None = "gauss", min_delay: float | None = None, max_delay: float | None = None, normalize: float | None = None, ) -> np.ndarray: """ Generate a delay filter in delay space. Parameters ---------- freqs Frequency array [GHz] bl_len_ns The baseline length in nanosec (i.e. 1e9 * metres / c). If scalar, interpreted as E-W length, if len(2), interpreted as EW and NS length, otherwise the full [EW, NS, Z] length. Unspecified dimensions are assumed to be zero. standoff Supra-horizon buffer [nanosec] delay_filter_type Options are ``['gauss', 'trunc_gauss', 'tophat', 'none']``. This sets the filter profile. ``gauss`` has a 1-sigma as horizon (+ standoff) divided by four, ``trunc_gauss`` is same but truncated above 1-sigma. ``'none'`` means filter is identically one. min_delay Minimum absolute delay of filter max_delay Maximum absolute delay of filter normalize If set, will normalize the filter such that the power of the output matches the power of the input times the normalization factor. If not set, the filter merely has a maximum of unity. Returns ------- delay_filter Delay filter in delay space (1D) """ # setup delays = np.fft.fftfreq(freqs.size, freqs[1] - freqs[0]) if isinstance(bl_len_ns, np.ndarray): bl_len_ns = np.linalg.norm(bl_len_ns) # add standoff: four sigma is horizon one_sigma = (bl_len_ns + standoff) / 4.0 # create filter if delay_filter_type in [None, "none", "None"]: delay_filter = np.ones_like(delays) elif delay_filter_type in ["gauss", "trunc_gauss"]: delay_filter = np.exp(-0.5 * (delays / one_sigma) ** 2) if delay_filter_type == "trunc_gauss": delay_filter[np.abs(delays) > (one_sigma * 4)] = 0.0 elif delay_filter_type == "tophat": delay_filter = np.ones_like(delays) delay_filter[np.abs(delays) > (one_sigma * 4)] = 0.0 else: raise ValueError(f"Didn't recognize filter_type {delay_filter_type}") # set bounds if min_delay is not None: delay_filter[np.abs(delays) < min_delay] = 0.0 if max_delay is not None: delay_filter[np.abs(delays) > max_delay] = 0.0 # normalize if normalize is not None and np.any(delay_filter): norm = normalize / np.sqrt(np.sum(delay_filter**2)) delay_filter *= norm * np.sqrt(len(delay_filter)) return delay_filter
[docs]def rough_delay_filter( data: np.ndarray, freqs: np.ndarray | None = None, bl_len_ns: np.ndarray | None = None, *, delay_filter: np.ndarray | None = None, **kwargs, ) -> np.ndarray: """ A rough low-pass delay filter of data array along last axis. Parameters ---------- data Data to be filtered along last axis freqs Frequencies of the filter [GHz] bl_len_ns The baseline length (see :func:`gen_delay_filter`). delay_filter The pre-computed filter to use. A filter can be created on-the-fly by passing kwargs. **kwargs Passed to :func:`gen_delay_filter`. Returns ------- filt_data Filtered data array (same shape as ``data``). """ # fft data across last axis dfft = np.fft.fft(data, axis=-1) # get delay filter if delay_filter is None: if freqs is None: raise ValueError( "If you don't provide a pre-computed delay filter, you must " "provide freqs" ) if bl_len_ns is None: raise ValueError( "If you don't provide a pre-computed delay filter, you must provide " "bl_len_ns" ) delay_filter = gen_delay_filter(freqs=freqs, bl_len_ns=bl_len_ns, **kwargs) # apply filtering and fft back filt_data = np.fft.ifft(dfft * delay_filter, axis=-1) return filt_data
[docs]def gen_fringe_filter( lsts: np.ndarray, freqs: np.ndarray, ew_bl_len_ns: float, fringe_filter_type: str | None = "tophat", **filter_kwargs, ) -> np.ndarray: """ Generate a fringe rate filter in fringe-rate & freq space. Parameters ---------- lsts lst array [radians] freqs Frequency array [GHz] ew_bl_len_ns Projected East-West baseline length [nanosec] fringe_filter_type Options ``['tophat', 'gauss', 'custom', 'none']`` **filter_kwargs These are specific to each ``fringe_filter_type``. For ``filter_type == 'gauss'``: * **fr_width** (float or array): Sets gaussian width in fringe-rate [Hz] For ``filter_type == 'custom'``: * **FR_filter** (ndarray): shape (Nfrates, Nfreqs) with custom filter (must be fftshifted, see below) * **FR_frates** (ndarray): array of FR_filter fringe rates [Hz] (must be monotonically increasing) * **FR_freqs** (ndarray): array of FR_filter freqs [GHz] Returns ------- fringe_filter 2D array in fringe-rate & freq space Notes ----- If ``filter_type == 'tophat'`` filter is a tophat out to max fringe-rate set by ew_bl_len_ns. If ``filter_type == 'gauss'`` filter is a Gaussian centered on max fringe-rate with width set by kwarg fr_width in Hz If ``filter_type == 'custom'`` filter is a custom 2D (Nfrates, Nfreqs) filter fed as 'FR_filter' its fringe-rate array is fed as "FR_frates" in Hz, its freq array is fed as "FR_freqs" in GHz. Note that input ``FR_filter`` must be fft-shifted along axis 0, but output filter is ``ifftshift``-ed back along axis 0. If ``filter_type == 'none'`` fringe filter is identically one. """ # setup times = lsts / (2 * np.pi) * units.sday.to("s") fringe_rates = np.fft.fftfreq(times.size, times[1] - times[0]) if fringe_filter_type in [None, "none", "None"]: fringe_filter = np.ones((len(times), len(freqs)), dtype=float) elif fringe_filter_type == "tophat": fr_max = np.repeat( calc_max_fringe_rate(freqs, ew_bl_len_ns)[None, :], len(lsts), axis=0 ) fringe_rates = np.repeat(fringe_rates[:, None], len(freqs), axis=1) fringe_filter = np.where(np.abs(fringe_rates) <= np.abs(fr_max), 1.0, 0) elif fringe_filter_type == "gauss": assert ( "fr_width" in filter_kwargs ), "If filter_type=='gauss' must feed fr_width kwarg" fr_max = np.repeat( calc_max_fringe_rate(freqs, ew_bl_len_ns)[None, :], len(lsts), axis=0 ) fringe_rates = np.repeat(fringe_rates[:, None], len(freqs), axis=1) fringe_filter = np.exp( -0.5 * ((fringe_rates - fr_max) / filter_kwargs["fr_width"]) ** 2 ) elif fringe_filter_type == "custom": assert ( "FR_filter" in filter_kwargs ), "If filter_type=='custom', must feed 2D FR_filter array" assert ( "FR_frates" in filter_kwargs ), "If filter_type=='custom', must feed 1D FR_frates array" assert ( "FR_freqs" in filter_kwargs ), "If filter_type=='custom', must feed 1D FR_freqs array" # interpolate FR_filter at fringe_rates and fqs mdl = RectBivariateSpline( filter_kwargs["FR_frates"], filter_kwargs["FR_freqs"], filter_kwargs["FR_filter"], kx=3, ky=3, ) fringe_filter = np.fft.ifftshift( mdl(np.fft.fftshift(fringe_rates), freqs), axes=0 ) # set things close to zero to zero fringe_filter[np.isclose(fringe_filter, 0.0)] = 0.0 else: raise ValueError(f"filter_type {fringe_filter_type} not recognized") return fringe_filter
[docs]def rough_fringe_filter( data: np.ndarray, lsts: np.ndarray | None = None, freqs: np.ndarray | None = None, ew_bl_len_ns: float | None = None, *, fringe_filter: np.ndarray | None = None, **kwargs, ) -> np.ndarray: """ A rough fringe rate filter of data along zeroth axis. Parameters ---------- data data to filter along zeroth axis fringe_filter A pre-computed fringe-filter to use. Computed on the fly if not given. **kwargs Passed to :func:`gen_fringe_filter` to compute the fringe filter on the fly (if necessary). If so, at least ``lsts``, ``freqs``, and ``ew_bl_len_ns`` are required. Returns ------- filt_data Filtered data (same shape as ``data``). """ # fft data along zeroth axis dfft = np.fft.fft(data, axis=0) # get filter if fringe_filter is None: if any(k is None for k in [lsts, freqs, ew_bl_len_ns]): raise ValueError( "Must provide 'lsts', 'freqs' and 'ew_bl_len_ns' if fringe_filter not " "given." ) fringe_filter = gen_fringe_filter( freqs=freqs, lsts=lsts, ew_bl_len_ns=ew_bl_len_ns, **kwargs ) # apply filter filt_data = np.fft.ifft(dfft * fringe_filter, axis=0) return filt_data
[docs]def calc_max_fringe_rate(fqs: np.ndarray, ew_bl_len_ns: float) -> np.ndarray: """ Calculate the max fringe-rate seen by an East-West baseline. Parameters ---------- fqs Frequency array [GHz] ew_bl_len_ns (float): projected East-West baseline length [ns] ew_bl_len_ns The EW baseline length, in nanosec. Returns ------- fr_max Maximum fringe rate [Hz] """ bl_wavelen = fqs * ew_bl_len_ns return 2 * np.pi / units.sday.to("s") * bl_wavelen
[docs]def compute_ha(lsts: np.ndarray, ra: float) -> np.ndarray: """ Compute hour angle from local sidereal time and right ascension. Parameters ---------- lsts Local sidereal times of the observation to be generated [radians]. Shape=(NTIMES,) ra The right ascension of a point source [radians]. Returns ------- ha Hour angle corresponding to the provide ra and times. Shape=(NTIMES,) """ ha = lsts - ra ha = np.where(ha > np.pi, ha - 2 * np.pi, ha) ha = np.where(ha < -np.pi, ha + 2 * np.pi, ha) return ha
[docs]def wrap2pipi(a): """ Wrap values of an array to [-π; +π] modulo 2π. Parameters ---------- a: array_like Array of values to be wrapped to [-π; +π]. Returns ------- res: array_like Array of 'a' values wrapped to [-π; +π]. """ # np.fmod(~, 2π) outputs values in [0; 2π] or [-2π; 0] res = np.fmod(a, 2 * np.pi) # wrap [π; 2π] to [-π; 0]... res[np.where(res > np.pi)] -= 2 * np.pi # ... and [-2π; -π] to [0; π] res[np.where(res < -np.pi)] += 2 * np.pi return res
[docs]def gen_white_noise( size: int | tuple[int] = 1, rng: np.random.Generator | None = None ) -> np.ndarray: """Produce complex Gaussian noise with unity variance. Parameters ---------- size Shape of output array. Can be an integer if a single dimension is required, otherwise a tuple of ints. rng Random number generator. Returns ------- noise White noise realization with specified shape. """ # Split power evenly between real and imaginary components. std = 1 / np.sqrt(2) args = dict(scale=std, size=size) # Create a random number generator if needed, then generate noise. rng = rng or np.random.default_rng() return rng.normal(**args) + 1j * rng.normal(**args)
[docs]def jansky_to_kelvin(freqs: np.ndarray, omega_p: Beam | np.ndarray) -> np.ndarray: """Return Kelvin -> Jy conversion as a function of frequency. Parameters ---------- freqs Frequencies for which to calculate the conversion. Units of GHz. omega_p Beam area as a function of frequency. Must have the same shape as ``freqs`` if an ndarray. Otherwise, must be an interpolation object which converts frequencies (in GHz) to beam size. Returns ------- Jy_to_K Array for converting Jy to K, same shape as ``freqs``. """ # get actual values of omega_p if it's an interpolation object if callable(omega_p): omega_p = omega_p(freqs) wavelengths = constants.c.value / (freqs * 1e9) # meters # The factor of 1e-26 converts from Jy to W/m^2/Hz. return 1e-26 * wavelengths**2 / (2 * constants.k_B.value * omega_p)
[docs]def Jy2T(freqs, omega_p): """Convert Janskys to Kelvin. Deprecated in v1.0.0. Will be removed in v1.1.0 """ warnings.warn( "This function has been deprecated. Please use `jansky_to_kelvin` instead.", stacklevel=1, ) return jansky_to_kelvin(freqs, omega_p)
def _listify(x): """Ensure a scalar/list is returned as a list. Taken from https://stackoverflow.com/a/1416677/1467820 Copied from the pre-v1 hera_sim.rfi module. """ try: basestring except NameError: basestring = (str, bytes) if isinstance(x, basestring): return [x] else: try: iter(x) except TypeError: return [x] else: return list(x)
[docs]def reshape_vis( vis: np.ndarray, ant_1_array: np.ndarray, ant_2_array: np.ndarray, pol_array: np.ndarray, antenna_numbers: np.ndarray, n_times: int, n_freqs: int, n_ants: int, n_pols: int, invert: bool = False, use_numba: bool = True, ) -> np.ndarray: """Reshaping helper for mutual coupling sims. The mutual coupling simulations take as input, and return, a data array with shape ``(Nblts, Nfreqs, Npols)``, but perform matrix multiplications on the data array reshaped to ``(Ntimes, Nfreqs, 2*Nants, 2*Nants)``. This function performs the reshaping between the matrix multiply shape and the input/output array shapes. Parameters ---------- vis Input data array. ant_1_array Array specifying the first antenna in each baseline. ant_2_array Array specifying the second antenna in each baseline. pol_array Array specifying the observed polarizations via polarization numbers. antenna_numbers Array specifying all of the antennas to include in the reshaped data. n_times Number of integrations in the data. n_freqs Number of frequency channels in the data. n_ants Number of antennas. n_pols Number of polarizations in the data. invert Whether to reshape to :class:`pyuvdata.UVData`'s data array shape. use_numba Whether to use ``numba`` to speed up the reshaping. Returns ------- reshaped_vis Input data reshaped to desired shape. """ if invert: out = np.zeros((ant_1_array.size, n_freqs, n_pols), dtype=complex) else: out = np.zeros((n_times, n_freqs, 2 * n_ants, 2 * n_ants), dtype=complex) # If we have numba, then this is a bit faster. if HAVE_NUMBA and use_numba: # pragma: no cover if invert: fnc = jit_reshape_vis_invert else: fnc = jit_reshape_vis fnc( vis=vis, out=out, ant_1_array=ant_1_array, ant_2_array=ant_2_array, pol_array=pol_array, antenna_numbers=antenna_numbers, ) return out # We don't have numba, so we need to do this a bit more slowly. pol_slices = {"x": slice(None, None, 2), "y": slice(1, None, 2)} polnum2str = {pol: uvutils.polnum2str(pol) for pol in pol_array} for i, ai in enumerate(antenna_numbers): for j, aj in enumerate(antenna_numbers[i:]): j += i uvd_inds = np.argwhere((ant_1_array == ai) & (ant_2_array == aj)).flatten() flipped = uvd_inds.size == 0 ii, jj = i, j if flipped: uvd_inds = np.argwhere( (ant_2_array == ai) & (ant_1_array == aj) ).flatten() ii, jj = j, i if uvd_inds.size == 0: continue for k, pol in enumerate(pol_array): p1, p2 = polnum2str[pol] if flipped: p1, p2 = p2, p1 sl1, sl2 = (pol_slices[p.lower()] for p in (p1, p2)) # NOTE: this is hard-coded to use the new-style UVData shapes! if invert: # Going back to UVData shape out[uvd_inds, :, k] = vis[:, :, sl1, sl2][:, :, ii, jj] else: # Changing from UVData shape out[:, :, sl1, sl2][:, :, ii, jj] = vis[uvd_inds, :, k] out[:, :, sl2, sl1][:, :, jj, ii] = np.conj(vis[uvd_inds, :, k]) return out
[docs]def matmul(left: np.ndarray, right: np.ndarray, use_numba: bool = False) -> np.ndarray: """Helper function for matrix multiplies used in mutual coupling sims. The :class:`~sigchain.MutualCoupling` class performs two matrix multiplications of arrays with shapes ``(1, Nfreqs, 2*Nant, 2*Nant)`` and ``(Ntimes, Nfreqs, 2*Nant, 2*Nant)``. Typically the number of antennas is much less than the number of frequency channels, so the parallelization used by ``numpy``'s matrix multiplication routine tends to be sub-optimal. This routine--when used with ``numba``--produces a substantial speedup in matrix multiplication for typical HERA-sized problems. Parameters ---------- left, right Input arrays to perform matrix multiplication left @ right. use_numba Whether to use ``numba`` to speed up the matrix multiplication. Returns ------- prod Product of the matrix multiplication left @ right. Notes ----- """ if HAVE_NUMBA and use_numba: if left.shape[0] == 1: return _left_matmul(left, right) elif right.shape[0] == 1: return _right_matmul(left, right) elif left.shape == right.shape: return _matmul(left, right) else: raise ValueError("Inputs cannot be broadcast to a common shape.") else: return left @ right
[docs]def find_baseline_orientations( antenna_numbers: np.ndarray, enu_antpos: np.ndarray ) -> dict[tuple[int, int], float]: """Find the orientation of each redundant baseline group. Parameters ---------- antenna_numbers Array containing antenna numbers corresponding to the provided antenna positions. enu_antpos ``(Nants,3)`` array containing the antenna positions in a local topocentric frame with basis (east, north, up). Returns ------- antpair2angle Dictionary mapping antenna pairs ``(ai,aj)`` to baseline orientations. Orientations are defined on [0,2pi). """ groups, baselines = uvutils.redundancy.get_antenna_redundancies( antenna_numbers, enu_antpos, include_autos=False )[:2] antpair2angle = {} for group, (e, n, _u) in zip(groups, baselines): angle = Longitude(np.arctan2(n, e) * units.rad).value conj_angle = Longitude((angle + np.pi) * units.rad).value for blnum in group: ai, aj = uvutils.baseline_to_antnums( blnum, Nants_telescope=antenna_numbers.size ) antpair2angle[(ai, aj)] = angle antpair2angle[(aj, ai)] = conj_angle return antpair2angle
[docs]def tanh_window(x, x_min=None, x_max=None, scale_low=1, scale_high=1): if x_min is None and x_max is None: warnings.warn( "Insufficient information provided; you must provide either x_min or " "x_max. Returning uniform window.", stacklevel=1, ) return np.ones(x.size) window = np.ones(x.size) if x_min is not None: window *= 0.5 * (1 + np.tanh((x - x_min) / scale_low)) if x_max is not None: window *= 0.5 * (1 + np.tanh((x_max - x) / scale_high)) return window
# Just some numba-fied helpful functions. # Note that coverage can't see that these are run without disabling JIT, # which kind of defeats the purpose of testing it. if HAVE_NUMBA: # pragma: no cover
[docs] @numba.njit def jit_reshape_vis(vis, out, ant_1_array, ant_2_array, pol_array, antenna_numbers): """JIT-accelerated reshaping function. See :func:`~reshape_vis` for parameter information. """ # This is basically the same as the non-numba reshape function, # but it's not as pretty. x_sl = slice(None, None, 2) y_sl = slice(1, None, 2) for i, ai in enumerate(antenna_numbers): for j, aj in enumerate(antenna_numbers[i:]): j += i uvd_inds = (ant_1_array == ai) & (ant_2_array == aj) flipped = False ii, jj = i, j if np.all(~uvd_inds): uvd_inds = (ant_2_array == ai) & (ant_1_array == aj) flipped = True ii, jj = j, i # Don't do anything if this baseline isn't present. if np.all(~uvd_inds): continue uvd_inds = np.argwhere(uvd_inds).flatten() for k, pol in enumerate(pol_array): if pol == -5: p1, p2 = x_sl, x_sl elif pol == -6: p1, p2 = y_sl, y_sl elif pol == -7: p1, p2 = x_sl, y_sl else: p1, p2 = y_sl, x_sl if flipped: p1, p2 = p2, p1 _p = out[:, :, p1, p2] for tidx, uvd_ind in enumerate(uvd_inds): _p[tidx, :, ii, jj] = vis[uvd_ind, :, k] _p[tidx, :, jj, ii] = np.conj(vis[uvd_ind, :, k]) return out
[docs] @numba.njit def jit_reshape_vis_invert( vis, out, ant_1_array, ant_2_array, pol_array, antenna_numbers ): """JIT-accelerated reshaping function. See :func:`~reshape_vis` for parameter information. """ # This is basically the same as the non-numba reshape function, # but it's not as pretty. x_sl = slice(None, None, 2) y_sl = slice(1, None, 2) for i, ai in enumerate(antenna_numbers): for j, aj in enumerate(antenna_numbers[i:]): j += i uvd_inds = (ant_1_array == ai) & (ant_2_array == aj) flipped = False ii, jj = i, j if np.all(~uvd_inds): uvd_inds = (ant_2_array == ai) & (ant_1_array == aj) flipped = True ii, jj = j, i # Don't do anything if this baseline isn't present. if np.all(~uvd_inds): continue uvd_inds = np.argwhere(uvd_inds).flatten() for k, pol in enumerate(pol_array): if pol == -5: p1, p2 = x_sl, x_sl elif pol == -6: p1, p2 = y_sl, y_sl elif pol == -7: p1, p2 = x_sl, y_sl else: p1, p2 = y_sl, x_sl if flipped: p1, p2 = p2, p1 # NOTE: This is hard-coded to use new-style UVData arrays! # Go back to UVData shape _p = vis[:, :, p1, p2] for tidx, uvd_ind in enumerate(uvd_inds): out[uvd_ind, :, k] = _p[tidx, :, ii, jj] tidx += 1 return out
@numba.njit def _left_matmul(left, right): """JIT-accelerated matrix multiplication. This multiply assumes the zeroth axis of the ``left`` array is length 1. """ out = np.zeros_like(right) for i in range(out.shape[0]): for j in range(out.shape[1]): out[i, j] = left[0, j] @ right[i, j] return out @numba.njit def _right_matmul(left, right): """JIT-accelerated matrix multiplication. This multiply assumes the zeroth axis of the ``right`` array is length 1. """ out = np.zeros_like(left) for i in range(out.shape[0]): for j in range(out.shape[1]): out[i, j] = left[i, j] @ right[0, j] return out @numba.njit def _matmul(left, right): """JIT-accelerated matrix multiplication. This multiply assumes both arrays have the same shape. It should only provide a speedup over ``numpy``'s matrix multiplication for cases where the first two axes of the input arrays are much larger than the last two axes. """ out = np.zeros_like(left) for i in range(out.shape[0]): for j in range(out.shape[1]): out[i, j] = left[i, j] @ right[i, j] return out