# Source code for pyuff_ustb.readers.lazy_arrays.base

from dataclasses import dataclass
from functools import reduce
from typing import Sequence, Tuple

import numpy as np

from pyuff_ustb.readers import Reader
from pyuff_ustb.readers.lazy_arrays.lazy_operations import (
    LazyOperation,
    LazyTranspose,
    apply_lazy_operations_on_data,
    apply_lazy_operations_on_index,
    apply_lazy_operations_on_shape,
)


@dataclass(frozen=True)
class LazyArray:
    """Read-on-demand, array-like wrapper around a ``Reader``.

    Nothing is loaded at construction time: ``shape``, ``dtype`` and indexing
    open the reader each time they are used. ``_lazy_operations`` holds
    ``LazyOperation`` objects (e.g. the ``LazyTranspose`` added by ``.T``)
    that are threaded through ``apply_lazy_operations_on_index``,
    ``apply_lazy_operations_on_shape`` and ``apply_lazy_operations_on_data``
    so that they only take effect when data is actually accessed.

    When the reader's ``"complex"`` attribute is truthy, the data is read from
    the ``"real"`` and ``"imag"`` children and combined as ``real + 1j * imag``.
    """

    _reader: Reader
    # Deferred operations, passed to the apply_lazy_operations_* helpers.
    _lazy_operations: Sequence[LazyOperation] = ()

    def __post_init__(self):
        # Fail early: every lazy access below goes through the Reader interface.
        if not isinstance(self._reader, Reader):
            raise TypeError(
                f"Expected a Reader object, got {type(self._reader)} instead."
            )

    def __repr__(self):
        return f"<LazyArray shape={self.shape} dtype={self.dtype}>"

    def _is_complex(self) -> bool:
        """Return whether the reader's ``"complex"`` attribute flags complex data."""
        # The attribute may be stored as an array (e.g. shape (1, 1)); squeeze
        # it before taking its truth value.
        return bool(np.squeeze(self._reader.attrs["complex"]))

    def __getitem__(self, k) -> np.ndarray:
        """Read the elements selected by ``k``.

        ``k`` is expressed in the (lazily transformed) index space of this
        array. It is mapped onto the stored data with
        ``apply_lazy_operations_on_index``, and the values read are then
        passed through ``apply_lazy_operations_on_data``.
        """
        if self._is_complex():
            with self._reader["real"].read() as obj:
                k = apply_lazy_operations_on_index(
                    k, obj.shape, self._lazy_operations
                )
                real = obj[k]
            # Reuses the index mapped against the real part's shape; assumes
            # the imaginary part is stored with the same shape.
            with self._reader["imag"].read() as obj:
                imag = obj[k]
            value = real + 1j * imag
        else:
            with self._reader.read() as obj:
                k = apply_lazy_operations_on_index(
                    k, obj.shape, self._lazy_operations
                )
                value = obj[k]
        return apply_lazy_operations_on_data(value, self._lazy_operations)

    @property
    def T(self):
        """Return a new ``LazyArray`` with a lazy transpose appended."""
        return LazyArray(self._reader, (*self._lazy_operations, LazyTranspose()))

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape after the lazy operations, read from the stored data's shape."""
        # For complex data the shape is taken from the "real" child.
        reader = self._reader["real"] if self._is_complex() else self._reader
        with reader.read() as obj:
            shape = obj.shape
        return apply_lazy_operations_on_shape(shape, self._lazy_operations)

    @property
    def size(self) -> int:
        """Total number of elements (1 for a 0-d array, matching NumPy)."""
        # The initializer makes the empty shape () of a 0-d array yield 1
        # instead of raising TypeError from reduce() on an empty sequence.
        return reduce(lambda x, y: x * y, self.shape, 1)

    @property
    def ndim(self) -> int:
        """Number of dimensions after the lazy operations."""
        return len(self.shape)

    @property
    def dtype(self) -> np.dtype:
        """Data type of the stored values (``np.complex128`` for complex data)."""
        if self._is_complex():
            # NOTE: Complex numbers may have a different number of bits than stated
            return np.complex128
        with self._reader.read() as obj:
            return obj.dtype

    def __len__(self) -> int:
        """Length of the first axis; raises ``TypeError`` for 0-d arrays."""
        if self.ndim == 0:
            raise TypeError("len() of unsized object")
        return self.shape[0]

    # Conversion operations
    def __array__(self, dtype=None, copy=None):
        """Materialize the data as a NumPy array (used by ``np.asarray``).

        NumPy 1.x passes ``dtype`` positionally when one is requested, and
        NumPy 2 additionally passes ``copy``. Both are accepted. ``copy`` is
        ignored because the data is always freshly read from the reader.
        """
        value = np.asarray(self[...])
        if dtype is not None:
            value = value.astype(dtype, copy=False)
        return value

    def __jax_array__(self):
        import jax.numpy as jnp

        return jnp.array(self[...])

    def __float__(self):
        return float(self[...])

    def __int__(self):
        # __int__ must return an int; returning a float makes int(lazy) fail.
        return int(self[...])

    # Math operations: materialize the data and defer to NumPy.
    def __add__(self, other):
        return self[...] + other

    def __sub__(self, other):
        return self[...] - other

    def __mul__(self, other):
        return self[...] * other

    def __truediv__(self, other):
        return self[...] / other

    def __radd__(self, other):
        return other + self[...]

    def __rsub__(self, other):
        return other - self[...]

    def __rmul__(self, other):
        return other * self[...]

    def __rtruediv__(self, other):
        return other / self[...]

    # Check for equality
    def __eq__(self, other):
        """Return a single bool: whether the data equals ``other`` (``np.array_equal``).

        Unlike NumPy arrays, this does not compare element-wise.
        """
        return np.array_equal(self[...], other)
class LazyScalar(LazyArray):
    """A ``LazyArray`` whose stored data is squeezed down to a 0-d value.

    A squeeze operation (``np.squeeze``) is placed before any caller-supplied
    ``lazy_operations``. Its shape transform maps the stored shape to ``()``
    and checks that every stored dimension is 1.
    """

    def __init__(
        self,
        reader: Reader,
        lazy_operations: Sequence[LazyOperation] = (),
    ):
        def _squeezed_shape(shape: Tuple[int, ...]):
            # Only an all-ones shape can be squeezed to a scalar. If this check
            # fails for a value that really should be scalar, it may be a bug,
            # and LazyArray should be used for that value instead.
            assert not any(dim != 1 for dim in shape), "Expected a scalar value."
            return ()

        squeeze_op = LazyOperation(np.squeeze, transform_shape=_squeezed_shape)
        operations = (squeeze_op,) + tuple(lazy_operations)
        super().__init__(reader, operations)