diff --git a/baseband/base/payload.py b/baseband/base/payload.py index 8556a039..4d8d5cd9 100644 --- a/baseband/base/payload.py +++ b/baseband/base/payload.py @@ -143,7 +143,7 @@ def tofile(self, fh): return fh.write(self.words.tobytes()) @classmethod - def fromdata(cls, data, header=None, bps=2): + def fromdata(cls, data, header=None, bps=2, **kwargs): """Encode data as a payload. Parameters @@ -156,6 +156,8 @@ def fromdata(cls, data, header=None, bps=2): bps : int, optional Bits per elementary sample, i.e., per channel and per real or imaginary component, used if header is not given. Default: 2. + **kwargs + Any other arguments to pass on to the class initializer. """ sample_shape = data.shape[1:] complex_data = data.dtype.kind == 'c' @@ -170,16 +172,20 @@ def fromdata(cls, data, header=None, bps=2): .format(*(('complex' if c else 'real') for c in (header.complex_data, complex_data)))) - try: - encoder = cls._encoders[bps] - except KeyError: - raise ValueError("{0} cannot encode data with {1} bits" - .format(cls.__name__, bps)) - if complex_data: - data = data.view((data.real.dtype, (2,))) - words = encoder(data).ravel().view(cls._dtype_word) - return cls(words, sample_shape=sample_shape, bps=bps, - complex_data=complex_data) + payload_nbytes = header.payload_nbytes + base_kwargs = {"header": header} + else: + base_kwargs = { + "bps": bps, + "sample_shape": sample_shape, + "complex_data": complex_data, + } + payload_nbytes = data.size * (2 if complex_data else 1) * bps // 8 + + n_words = payload_nbytes // cls._dtype_word.itemsize + self = cls(np.empty(n_words, cls._dtype_word), **base_kwargs, **kwargs) + self[:] = data + return self def __array__(self, dtype=None, copy=None): """Interface to arrays.""" @@ -233,9 +239,8 @@ def _item_to_slices(self, item): Slice such that if one decodes ``ds = self.words[words_slice]``, ``ds`` is the smallest possible array that includes all of the requested ``item``. - data_slice : int or slice - Int or slice such that ``decode(ds)[data_slice]`` is the requested - ``item``. + data_slice : tuple of slice or int + Such that ``decode(ds)[data_slice]`` is the requested ``item``. Notes ----- @@ -247,7 +252,7 @@ def _item_to_slices(self, item): """ if isinstance(item, tuple): sample_index = item[1:] - item = item[0] + item = item[0] if item else slice(None) else: sample_index = () @@ -306,45 +311,41 @@ def _item_to_slices(self, item): return words_slice, (data_slice,) + sample_index - def __getitem__(self, item=()): - decoder = self._decoders[self._coder] - if item == () or item == slice(None): - data = decoder(self.words) - if self.complex_data: - data = data.view(self.dtype) - return data.reshape(self.shape) + def _decode(self, words): + return self._decoders[self._coder](words).view(self.dtype) - words_slice, data_slice = self._item_to_slices(item) + def _encode(self, data): + try: + encoder = self._encoders[self._coder] + except KeyError: + raise ValueError(f"{self.__class__.__name__} cannot encode data " + f"with {self._coder} bits") from None + if data.dtype.kind == 'c': + data = data.view((data.real.dtype, (2,))) + return encoder(data) - return (decoder(self.words[words_slice]).view(self.dtype) + def __getitem__(self, item=()): + words_slice, data_slice = self._item_to_slices(item) + return (self._decode(self.words[words_slice]) .reshape(-1, *self.sample_shape)[data_slice]) def __setitem__(self, item, data): - if item == () or item == slice(None): - words_slice = data_slice = slice(None) - else: - words_slice, data_slice = self._item_to_slices(item) + words_slice, data_slice = self._item_to_slices(item) data = np.asanyarray(data) # Check if the new data spans an entire word and is correctly shaped. # If so, skip decoding. If not, decode appropriate words and insert # new data. - if not (data_slice == slice(None) + if not (data_slice == (slice(None),) and data.shape[-len(self.sample_shape):] == self.sample_shape and data.dtype.kind == self.dtype.kind): - decoder = self._decoders[self._coder] - current_data = decoder(self.words[words_slice]) - if self.complex_data: - current_data = current_data.view(self.dtype) + current_data = self._decode(self.words[words_slice]) current_data.shape = (-1,) + self.sample_shape current_data[data_slice] = data data = current_data - if data.dtype.kind == 'c': - data = data.view((data.real.dtype, (2,))) - - encoder = self._encoders[self._coder] - self.words[words_slice] = encoder(data).ravel().view(self._dtype_word) + encoded_words = self._encode(data).ravel().view(self._dtype_word) + self.words[words_slice] = encoded_words data = property(__getitem__, doc="Full decoded payload.") diff --git a/baseband/base/tests/test_base.py b/baseband/base/tests/test_base.py index 638da143..b4107efd 100644 --- a/baseband/base/tests/test_base.py +++ b/baseband/base/tests/test_base.py @@ -282,8 +282,10 @@ def test_payload_getitem_setitem(self, item): payload = self.Payload.fromdata(data + 1j * data, bps=8) sel_data = payload.data[item] assert np.all(payload[item] == sel_data) - payload[item] = 1 - sel_data + # Check __setitem__ check = payload.data + payload[item] = 1 - sel_data + assert not np.all(payload.data == check) check[item] = 1 - sel_data assert np.all(payload.data == check) diff --git a/baseband/gsb/payload.py b/baseband/gsb/payload.py index 55b90bf1..efb23e9e 100644 --- a/baseband/gsb/payload.py +++ b/baseband/gsb/payload.py @@ -73,7 +73,7 @@ class GSBPayload(PayloadBase): 8: encode_8bit} _decoders = {4: decode_4bit, 8: decode_8bit} - _dtype_word = np.int8 + _dtype_word = np.dtype("i1") _sample_shape_maker_1thread = namedtuple('SampleShape', 'nchan') _sample_shape_maker_nthread = namedtuple('SampleShape', 'nthread, nchan') diff --git a/baseband/guppi/payload.py b/baseband/guppi/payload.py index a49bc73f..f15f054f 100644 --- a/baseband/guppi/payload.py +++ b/baseband/guppi/payload.py @@ -68,131 +68,71 @@ def fromdata(cls, data, header=None, bps=8, channels_first=True): Parameters ---------- data : `~numpy.ndarray` - Data to be encoded. The last dimension is taken as the number of - channels. + Data to be encoded. The trailing dimensions are taken as the + sample shape, normally (npol, nchan). header : `~baseband.guppi.GUPPIHeader`, optional If given, used to infer the ``bps`` and ``channels_first``. bps : int, optional - Bits per elementary sample, used if ``header`` is `None`. + Bits per elementary sample, used only if ``header`` is `None`. Default: 8. channels_first : bool, optional - Whether encoded data should be ordered as (nchan, nsample, npol), - used if ``header`` is `None`. Default: `True`. + If `True`, encode data (nchan, nsample, npol). otherwise + as (nsample, nchan, npol). Used only if ``header`` is `None`. + Default: `True`. """ - if header is not None: - bps = header.bps - channels_first = header.channels_first - sample_shape = data.shape[1:] - complex_data = data.dtype.kind == 'c' - try: - encoder = cls._encoders[bps] - except KeyError: - raise ValueError("{0} cannot encode data with {1} bits" - .format(cls.__name__, bps)) - # If channels-first, switch to (nchan, nsample, npol); otherwise use - # (nsample, nchan, npol). - if channels_first: - data = data.transpose(2, 0, 1) - else: - data = data.transpose(0, 2, 1) - if complex_data: - data = data.view((data.real.dtype, (2,))) - words = encoder(data).ravel().view(cls._dtype_word) - return cls(words, sample_shape=sample_shape, bps=bps, - complex_data=complex_data, channels_first=channels_first) + return super().fromdata(data, header=header, bps=bps, + channels_first=channels_first) def __len__(self): """Number of samples in the payload.""" return self.nbytes * 8 // self._true_bpfs + def _decode(self, words, words_slice): + if self.channels_first: + # Before decoding, reshape so channels fall along first axis. + decoded_words = super()._decode( + words.reshape(self.sample_shape.nchan, -1)[:, words_slice]) + # Reshape to (nsample, nchan, npol), as expected by data_slice. + return decoded_words.T.reshape(-1, *self.sample_shape) + else: + # Transpose result to allow data_slice to assume (npol, nchan). + return (super()._decode(words[words_slice]) + .reshape(-1, self.sample_shape.nchan, + self.sample_shape.npol) + .transpose(0, 2, 1)) + def __getitem__(self, item=()): # GUPPI data may be stored as (nsample, nchan, npol) or, if # channels-first, (nchan, nsample, npol), both of which require # reshaping to get the usual order of (nsample, npol, nchan). - decoder = self._decoders[self._coder] - - # If we want to decode the entire dataset. - if item == () or item == slice(None): - data = decoder(self.words) - if self.complex_data: - data = data.view(self.dtype) - if self.channels_first: - # Reshape to (nchan, nsample, npol); transpose to usual order. - return (data.reshape(self.sample_shape.nchan, -1, - self.sample_shape.npol) - .transpose(1, 2, 0)) - else: - # Reshape to (nsample, nchan, npol); transpose to usual order. - return (data.reshape(-1, self.sample_shape.nchan, - self.sample_shape.npol) - .transpose(0, 2, 1)) words_slice, data_slice = self._item_to_slices(item) - - if self.channels_first: - # Reshape words so channels fall along first axis, then decode. - decoded_words = decoder(self.words.reshape(self.sample_shape.nchan, - -1)[:, words_slice]) - # Reshape to (nsample, nchan, npol), then use data_slice. - return (decoded_words.view(self.dtype).T - .reshape(-1, *self.sample_shape)[data_slice]) - else: - # data_slice assumes (npol, nchan), so transpose before using it. - return (decoder(self.words[words_slice]).view(self.dtype) - .reshape(-1, self.sample_shape.nchan, - self.sample_shape.npol) - .transpose(0, 2, 1)[data_slice]) + return self._decode(self.words, words_slice)[data_slice] def __setitem__(self, item, data): - if item == () or item == slice(None): - words_slice = data_slice = slice(None) - else: - words_slice, data_slice = self._item_to_slices(item) + words_slice, data_slice = self._item_to_slices(item) data = np.asanyarray(data) # Check if the new data spans an entire word and is correctly shaped. # If so, skip decoding. If not, decode appropriate words and insert # new data. - if not (data_slice == slice(None) + if not (data_slice == (slice(None),) and data.shape[-2:] == self.sample_shape and data.dtype.kind == self.dtype.kind): - decoder = self._decoders[self._coder] - if self.channels_first: - decoded_words = decoder(np.ascontiguousarray( - self.words.reshape( - self.sample_shape.nchan, -1)[:, words_slice])) - current_data = (decoded_words.view(self.dtype) - .T.reshape(-1, *self.sample_shape)) - else: - current_data = (decoder(self.words[words_slice]) - .view(self.dtype) - .reshape(-1, self.sample_shape.nchan, - self.sample_shape.npol) - .transpose(0, 2, 1)) - + current_data = self._decode(self.words, words_slice) current_data[data_slice] = data data = current_data # Reshape before separating real and complex components. if self.channels_first: - data = data.reshape(-1, self.sample_shape.nchan).T - else: - data = data.transpose(0, 2, 1) - - # Separate real and complex components. - if data.dtype.kind == 'c': - data = data.view((data.real.dtype, (2,))) - - # Select encoder. - encoder = self._encoders[self._coder] - - # Reshape and encode words. - if self.channels_first: + encoded_words = (self._encode(data.transpose(2, 0, 1)) + .reshape(self.sample_shape.nchan, -1) + .view(self._dtype_word)) self.words.reshape(self.sample_shape.nchan, -1)[:, words_slice] = ( - encoder(data).reshape(self.sample_shape.nchan, -1) - .view(self._dtype_word)) + encoded_words) else: - self.words[words_slice] = (encoder(data.ravel()) - .view(self._dtype_word)) + encoded_words = (self._encode(data.transpose(0, 2, 1)) + .ravel().view(self._dtype_word)) + self.words[words_slice] = encoded_words data = property(__getitem__, doc="Full decoded payload.")