Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions baseband/base/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def tofile(self, fh):
return fh.write(self.words.tobytes())

@classmethod
def fromdata(cls, data, header=None, bps=2):
def fromdata(cls, data, header=None, bps=2, **kwargs):
"""Encode data as a payload.

Parameters
Expand All @@ -156,6 +156,8 @@ def fromdata(cls, data, header=None, bps=2):
bps : int, optional
Bits per elementary sample, i.e., per channel and per real or
imaginary component, used if header is not given. Default: 2.
**kwargs
Any other arguments to pass on to the class initializer.
"""
sample_shape = data.shape[1:]
complex_data = data.dtype.kind == 'c'
Expand All @@ -170,16 +172,20 @@ def fromdata(cls, data, header=None, bps=2):
.format(*(('complex' if c else 'real') for c
in (header.complex_data,
complex_data))))
try:
encoder = cls._encoders[bps]
except KeyError:
raise ValueError("{0} cannot encode data with {1} bits"
.format(cls.__name__, bps))
if complex_data:
data = data.view((data.real.dtype, (2,)))
words = encoder(data).ravel().view(cls._dtype_word)
return cls(words, sample_shape=sample_shape, bps=bps,
complex_data=complex_data)
payload_nbytes = header.payload_nbytes
base_kwargs = {"header": header}
else:
base_kwargs = {
"bps": bps,
"sample_shape": sample_shape,
"complex_data": complex_data,
}
payload_nbytes = data.size * (2 if complex_data else 1) * bps // 8

n_words = payload_nbytes // cls._dtype_word.itemsize
self = cls(np.empty(n_words, cls._dtype_word), **base_kwargs, **kwargs)
self[:] = data
return self

def __array__(self, dtype=None, copy=None):
"""Interface to arrays."""
Expand Down Expand Up @@ -233,9 +239,8 @@ def _item_to_slices(self, item):
Slice such that if one decodes ``ds = self.words[words_slice]``,
``ds`` is the smallest possible array that includes all
of the requested ``item``.
data_slice : int or slice
Int or slice such that ``decode(ds)[data_slice]`` is the requested
``item``.
data_slice : tuple of slice or int
Such that ``decode(ds)[data_slice]`` is the requested ``item``.

Notes
-----
Expand All @@ -247,7 +252,7 @@ def _item_to_slices(self, item):
"""
if isinstance(item, tuple):
sample_index = item[1:]
item = item[0]
item = item[0] if item else slice(None)
else:
sample_index = ()

Expand Down Expand Up @@ -306,45 +311,41 @@ def _item_to_slices(self, item):

return words_slice, (data_slice,) + sample_index

def __getitem__(self, item=()):
decoder = self._decoders[self._coder]
if item == () or item == slice(None):
data = decoder(self.words)
if self.complex_data:
data = data.view(self.dtype)
return data.reshape(self.shape)
def _decode(self, words):
return self._decoders[self._coder](words).view(self.dtype)

words_slice, data_slice = self._item_to_slices(item)
def _encode(self, data):
try:
encoder = self._encoders[self._coder]
except KeyError:
raise ValueError(f"{self.__class__.__name__} cannot encode data "
f"with {self._coder} bits") from None
if data.dtype.kind == 'c':
data = data.view((data.real.dtype, (2,)))
return encoder(data)

return (decoder(self.words[words_slice]).view(self.dtype)
def __getitem__(self, item=()):
words_slice, data_slice = self._item_to_slices(item)
return (self._decode(self.words[words_slice])
.reshape(-1, *self.sample_shape)[data_slice])

def __setitem__(self, item, data):
if item == () or item == slice(None):
words_slice = data_slice = slice(None)
else:
words_slice, data_slice = self._item_to_slices(item)
words_slice, data_slice = self._item_to_slices(item)

data = np.asanyarray(data)
# Check if the new data spans an entire word and is correctly shaped.
# If so, skip decoding. If not, decode appropriate words and insert
# new data.
if not (data_slice == slice(None)
if not (data_slice == (slice(None),)
and data.shape[-len(self.sample_shape):] == self.sample_shape
and data.dtype.kind == self.dtype.kind):
decoder = self._decoders[self._coder]
current_data = decoder(self.words[words_slice])
if self.complex_data:
current_data = current_data.view(self.dtype)
current_data = self._decode(self.words[words_slice])
current_data.shape = (-1,) + self.sample_shape
current_data[data_slice] = data
data = current_data

if data.dtype.kind == 'c':
data = data.view((data.real.dtype, (2,)))

encoder = self._encoders[self._coder]
self.words[words_slice] = encoder(data).ravel().view(self._dtype_word)
encoded_words = self._encode(data).ravel().view(self._dtype_word)
self.words[words_slice] = encoded_words

data = property(__getitem__, doc="Full decoded payload.")

Expand Down
4 changes: 3 additions & 1 deletion baseband/base/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,10 @@ def test_payload_getitem_setitem(self, item):
payload = self.Payload.fromdata(data + 1j * data, bps=8)
sel_data = payload.data[item]
assert np.all(payload[item] == sel_data)
payload[item] = 1 - sel_data
# Check __setitem__
check = payload.data
payload[item] = 1 - sel_data
assert not np.all(payload.data == check)
check[item] = 1 - sel_data
assert np.all(payload.data == check)

Expand Down
2 changes: 1 addition & 1 deletion baseband/gsb/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class GSBPayload(PayloadBase):
8: encode_8bit}
_decoders = {4: decode_4bit,
8: decode_8bit}
_dtype_word = np.int8
_dtype_word = np.dtype("i1")

_sample_shape_maker_1thread = namedtuple('SampleShape', 'nchan')
_sample_shape_maker_nthread = namedtuple('SampleShape', 'nthread, nchan')
Expand Down
126 changes: 33 additions & 93 deletions baseband/guppi/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,131 +68,71 @@ def fromdata(cls, data, header=None, bps=8, channels_first=True):
Parameters
----------
data : `~numpy.ndarray`
Data to be encoded. The last dimension is taken as the number of
channels.
Data to be encoded. The trailing dimensions are taken as the
sample shape, normally (npol, nchan).
header : `~baseband.guppi.GUPPIHeader`, optional
If given, used to infer the ``bps`` and ``channels_first``.
bps : int, optional
Bits per elementary sample, used if ``header`` is `None`.
Bits per elementary sample, used only if ``header`` is `None`.
Default: 8.
channels_first : bool, optional
Whether encoded data should be ordered as (nchan, nsample, npol),
used if ``header`` is `None`. Default: `True`.
If `True`, encode data (nchan, nsample, npol). otherwise
as (nsample, nchan, npol). Used only if ``header`` is `None`.
Default: `True`.
"""
if header is not None:
bps = header.bps
channels_first = header.channels_first
sample_shape = data.shape[1:]
complex_data = data.dtype.kind == 'c'
try:
encoder = cls._encoders[bps]
except KeyError:
raise ValueError("{0} cannot encode data with {1} bits"
.format(cls.__name__, bps))
# If channels-first, switch to (nchan, nsample, npol); otherwise use
# (nsample, nchan, npol).
if channels_first:
data = data.transpose(2, 0, 1)
else:
data = data.transpose(0, 2, 1)
if complex_data:
data = data.view((data.real.dtype, (2,)))
words = encoder(data).ravel().view(cls._dtype_word)
return cls(words, sample_shape=sample_shape, bps=bps,
complex_data=complex_data, channels_first=channels_first)
return super().fromdata(data, header=header, bps=bps,
channels_first=channels_first)

def __len__(self):
"""Number of samples in the payload."""
return self.nbytes * 8 // self._true_bpfs

def _decode(self, words, words_slice):
if self.channels_first:
# Before decoding, reshape so channels fall along first axis.
decoded_words = super()._decode(
words.reshape(self.sample_shape.nchan, -1)[:, words_slice])
# Reshape to (nsample, nchan, npol), as expected by data_slice.
return decoded_words.T.reshape(-1, *self.sample_shape)
else:
# Transpose result to allow data_slice to assume (npol, nchan).
return (super()._decode(words[words_slice])
.reshape(-1, self.sample_shape.nchan,
self.sample_shape.npol)
.transpose(0, 2, 1))

def __getitem__(self, item=()):
# GUPPI data may be stored as (nsample, nchan, npol) or, if
# channels-first, (nchan, nsample, npol), both of which require
# reshaping to get the usual order of (nsample, npol, nchan).
decoder = self._decoders[self._coder]

# If we want to decode the entire dataset.
if item == () or item == slice(None):
data = decoder(self.words)
if self.complex_data:
data = data.view(self.dtype)
if self.channels_first:
# Reshape to (nchan, nsample, npol); transpose to usual order.
return (data.reshape(self.sample_shape.nchan, -1,
self.sample_shape.npol)
.transpose(1, 2, 0))
else:
# Reshape to (nsample, nchan, npol); transpose to usual order.
return (data.reshape(-1, self.sample_shape.nchan,
self.sample_shape.npol)
.transpose(0, 2, 1))

words_slice, data_slice = self._item_to_slices(item)

if self.channels_first:
# Reshape words so channels fall along first axis, then decode.
decoded_words = decoder(self.words.reshape(self.sample_shape.nchan,
-1)[:, words_slice])
# Reshape to (nsample, nchan, npol), then use data_slice.
return (decoded_words.view(self.dtype).T
.reshape(-1, *self.sample_shape)[data_slice])
else:
# data_slice assumes (npol, nchan), so transpose before using it.
return (decoder(self.words[words_slice]).view(self.dtype)
.reshape(-1, self.sample_shape.nchan,
self.sample_shape.npol)
.transpose(0, 2, 1)[data_slice])
return self._decode(self.words, words_slice)[data_slice]

def __setitem__(self, item, data):
if item == () or item == slice(None):
words_slice = data_slice = slice(None)
else:
words_slice, data_slice = self._item_to_slices(item)
words_slice, data_slice = self._item_to_slices(item)

data = np.asanyarray(data)
# Check if the new data spans an entire word and is correctly shaped.
# If so, skip decoding. If not, decode appropriate words and insert
# new data.
if not (data_slice == slice(None)
if not (data_slice == (slice(None),)
and data.shape[-2:] == self.sample_shape
and data.dtype.kind == self.dtype.kind):
decoder = self._decoders[self._coder]
if self.channels_first:
decoded_words = decoder(np.ascontiguousarray(
self.words.reshape(
self.sample_shape.nchan, -1)[:, words_slice]))
current_data = (decoded_words.view(self.dtype)
.T.reshape(-1, *self.sample_shape))
else:
current_data = (decoder(self.words[words_slice])
.view(self.dtype)
.reshape(-1, self.sample_shape.nchan,
self.sample_shape.npol)
.transpose(0, 2, 1))

current_data = self._decode(self.words, words_slice)
current_data[data_slice] = data
data = current_data

# Reshape before separating real and complex components.
if self.channels_first:
data = data.reshape(-1, self.sample_shape.nchan).T
else:
data = data.transpose(0, 2, 1)

# Separate real and complex components.
if data.dtype.kind == 'c':
data = data.view((data.real.dtype, (2,)))

# Select encoder.
encoder = self._encoders[self._coder]

# Reshape and encode words.
if self.channels_first:
encoded_words = (self._encode(data.transpose(2, 0, 1))
.reshape(self.sample_shape.nchan, -1)
.view(self._dtype_word))
self.words.reshape(self.sample_shape.nchan, -1)[:, words_slice] = (
encoder(data).reshape(self.sample_shape.nchan, -1)
.view(self._dtype_word))
encoded_words)
else:
self.words[words_slice] = (encoder(data.ravel())
.view(self._dtype_word))
encoded_words = (self._encode(data.transpose(0, 2, 1))
.ravel().view(self._dtype_word))
self.words[words_slice] = encoded_words

data = property(__getitem__, doc="Full decoded payload.")
Loading