Source code for torch_ecg.augmenters.stretch_compress

""" """

from random import choice, randint
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import scipy.signal as SS
import torch
import torch.nn.functional as F
from numpy.typing import NDArray
from torch import Tensor

from ..cfg import DEFAULTS
from ..utils.misc import ReprMixin, add_docstring
from .base import Augmenter
from .registry import AUGMENTERS

# Public names of this module: the augmenter working on batched tensors
# and the offline segment generator working on NumPy arrays.
__all__ = [
    "StretchCompress",
    "StretchCompressOffline",
]


@AUGMENTERS.register(name="stretch_compress")
@AUGMENTERS.register()
class StretchCompress(Augmenter):
    """Stretch-or-compress augmenter on ECG tensors.

    Rescaling the ECGs by a factor sampled from a normal distribution
    along the time axis.

    Parameters
    ----------
    ratio : int or float, default 6
        Mean ratio of the stretch or compress.
        A value greater than 1 is interpreted as a percentage
        and divided by 100, so that the stored ratio lies in [0, 1].
        The ratio of one batch element is sampled from a normal distribution.
    prob : float, default 0.5
        Probability of the augmenter to be applied.
    inplace : bool, default True
        If True, the input ECGs will be modified inplace.
    kwargs : dict, optional
        Additional keyword arguments.

    Example
    -------
    .. code-block:: python

        sc = StretchCompress()
        sig = torch.randn((32, 12, 5000))
        labels = torch.randint(0, 2, (32, 5000, 26))
        label = torch.randint(0, 2, (32, 26), dtype=torch.float32)
        mask = torch.randint(0, 2, (32, 5000, 3), dtype=torch.float32)
        sig, label, mask = sc(sig, label, mask)

    """

    __name__ = "StretchCompress"

    def __init__(self, ratio: Union[int, float] = 6, prob: float = 0.5, inplace: bool = True, **kwargs: Any) -> None:
        super().__init__()
        self.prob = prob
        assert 0 <= self.prob <= 1, "Probability must be between 0 and 1"
        self.inplace = inplace
        self.ratio = ratio
        if self.ratio > 1:
            # values above 1 are percentages
            self.ratio = self.ratio / 100
        assert 0 <= self.ratio <= 1, "Ratio must be between 0 and 1, or between 0 and 100"

    @staticmethod
    def _resample_keep_length(x: Tensor, new_len: int, siglen: int) -> Tensor:
        """Resample ``x`` to ``new_len`` samples, then restore length ``siglen``.

        Parameters
        ----------
        x : torch.Tensor
            One batch element, of shape ``(channels, siglen)``.
        new_len : int
            Length to which `x` is linearly interpolated.
        siglen : int
            Length of the returned tensor along the last axis.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(channels, siglen)``: the central crop of the
            interpolated tensor if it got longer, otherwise the interpolated
            tensor zero-padded on both sides.

        """
        diff_len = abs(new_len - siglen)
        half_diff_len = diff_len // 2
        resampled = F.interpolate(x.unsqueeze(0), size=new_len, mode="linear", align_corners=True)
        if new_len >= siglen:
            # stretch and cut (a no-op crop when new_len == siglen)
            resampled = resampled[..., half_diff_len : siglen + half_diff_len]
        else:
            # compress and pad
            resampled = F.pad(
                resampled,
                pad=(half_diff_len, diff_len - half_diff_len),
                mode="constant",
                value=0.0,
            )
        return resampled.squeeze(0)

    def forward(self, sig: Tensor, *labels: Optional[Sequence[Tensor]], **kwargs: Any) -> Tuple[Tensor, ...]:
        """Forward method of the augmenter.

        Parameters
        ----------
        sig : torch.Tensor
            Batched ECGs to be stretched or compressed,
            of shape ``(batch, lead, siglen)``.
        labels : Sequence[torch.Tensor], optional
            Label tensors of the ECGs.
            If set, labels of ``ndim = 3``, of shape ``(batch, label_len, channels)``
            will be stretched or compressed; labels of lower dimension
            are returned untouched.
            `siglen` should be divisible by `label_len`.
        kwargs : dict, optional
            Not used, but kept for consistency with other augmenters.

        Returns
        -------
        sig : torch.Tensor
            The stretched or compressed ECG tensors.
        labels : Sequence[torch.Tensor], optional
            The stretched or compressed label tensors.

        """
        batch, lead, siglen = sig.shape
        if not self.inplace:
            sig = sig.clone()
            labels = [label.clone() for label in labels]
        else:
            # ``*labels`` arrives as an immutable tuple, while the entries are
            # re-bound below (permute / interpolate), hence a list is required
            labels = list(labels)
        if self.prob == 0:
            return (sig, *labels)

        # bring every 3-D label to shape (batch, n_classes, siglen),
        # remembering its original length (None for labels that are skipped)
        label_len = []
        for idx, label in enumerate(labels):
            if label.ndim < 3:
                label_len.append(None)
                continue
            label = label.permute(0, 2, 1)  # (batch, label_len, n_classes) -> (batch, n_classes, label_len)
            ll = label.shape[-1]
            if ll != siglen:
                label = F.interpolate(label, size=(siglen,), mode="linear", align_corners=True)
            labels[idx] = label
            label_len.append(ll)

        for batch_idx in self.get_indices(prob=self.prob, pop_size=batch):
            sign = choice([-1, 1])  # +1: stretch and cut; -1: compress and pad
            ratio = self._sample_ratio()
            new_len = int(round((1 + sign * ratio) * siglen))
            sig[batch_idx, ...] = self._resample_keep_length(sig[batch_idx, ...], new_len, siglen)
            for label in labels:
                if label.ndim < 3:
                    continue
                label[batch_idx, ...] = self._resample_keep_length(label[batch_idx, ...], new_len, siglen)

        # restore the original length and layout of the 3-D labels
        for idx, ll in enumerate(label_len):
            if ll is None:
                continue
            if ll != siglen:
                labels[idx] = F.interpolate(labels[idx], size=(ll,), mode="linear", align_corners=True)
            labels[idx] = labels[idx].permute(0, 2, 1)  # (batch, n_classes, label_len) -> (batch, label_len, n_classes)
        return (sig, *labels)

    def _sample_ratio(self) -> float:
        """Sample the ratio of stretching or compressing.

        The value is drawn from a normal distribution with mean ``self.ratio``
        and standard deviation ``0.382 * self.ratio``,
        clipped to ``[0, 2 * self.ratio]``.
        """
        return np.clip(DEFAULTS.RNG.normal(self.ratio, 0.382 * self.ratio), 0, 2 * self.ratio)

    def _generate(self, sig: Tensor, *labels: Optional[Sequence[Tensor]]) -> Union[Tuple[Tensor, ...], Tensor]:
        """NOT finished, NOT checked, alternative version of :meth:`forward`
        delegating each batch element to `_stretch_compress_one_batch_element`,
        NOT tested yet!

        Parameters
        ----------
        sig : torch.Tensor
            Batched ECGs to be stretched or compressed,
            of shape ``(batch, lead, siglen)``.
        labels : Sequence[torch.Tensor], optional
            Label tensors of the ECGs.
            If set, should be of ``ndim = 3``,
            and of shapes ``(batch, label_len, n_classes)``.
            `siglen` should be divisible by `label_len`.

        Returns
        -------
        sig : torch.Tensor
            The stretched or compressed ECG tensors.
        labels : Sequence[torch.Tensor], optional
            The stretched or compressed label tensors.
            Omitted (a bare tensor `sig` is returned) if no label is given.

        """
        batch, lead, siglen = sig.shape
        if not self.inplace:
            sig = sig.clone()
            # the labels are written to below, so they have to be copied as well
            labels = [label.clone() for label in labels]
        if self.prob == 0:
            if len(labels) == 0:
                return sig
            return (sig, *labels)
        indices = self.get_indices(prob=self.prob, pop_size=batch)
        for batch_idx in indices:
            data = _stretch_compress_one_batch_element(
                self.ratio,
                sig[batch_idx, ...].unsqueeze(0),
                *(label[batch_idx, ...].unsqueeze(0) for label in labels),
            )
            if len(labels) == 0:
                # without labels the helper returns a bare tensor
                sig[batch_idx, ...] = data
            else:
                sig[batch_idx, ...] = data[0]
                for idx, label in enumerate(data[1:]):
                    labels[idx][batch_idx, ...] = label
        if len(labels) == 0:
            return sig
        return (sig, *labels)

    def extra_repr_keys(self) -> List[str]:
        """Names of the attributes shown in the representation of the augmenter."""
        return [
            "ratio",
            "prob",
            "inplace",
        ] + super().extra_repr_keys()
def _stretch_compress_one_batch_element(
    ratio: Union[int, float], sig: Tensor, *labels: Sequence[Tensor]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    """Stretch or compress one batch element of the ECGs.

    Parameters
    ----------
    ratio : int or float
        Mean ratio of the stretch/compress; the ratio actually applied
        is sampled from a normal distribution around this value.
    sig : torch.Tensor
        The ECGs to be stretched or compressed,
        of shape ``(1, lead, siglen)``.
    labels : Sequence[torch.Tensor], optional
        Label tensors of the ECGs.
        If set, each should be of ``ndim = 3``,
        and of shape ``(1, label_len, channels)``.
        ``siglen`` should be divisible by ``label_len``.
        Labels of lower dimension are only squeezed along the first axis.

    Returns
    -------
    sig : torch.Tensor
        The stretched or compressed ECG tensor,
        of shape ``(lead, siglen)``.
    labels : Sequence[torch.Tensor], optional
        The stretched or compressed label tensors,
        of shapes ``(label_len, channels)``.
        Omitted (a bare tensor `sig` is returned) if no label is given.

    """
    siglen = sig.shape[-1]

    # channel-first copies of the labels, upsampled to the signal length;
    # ``orig_lens`` keeps the original length, with 0 marking skipped labels
    work = list(labels)
    orig_lens = []
    for pos, item in enumerate(work):
        if item.ndim < 3:
            orig_lens.append(0)
            continue
        item = item.permute(0, 2, 1)  # (1, label_len, n_classes) -> (1, n_classes, label_len)
        cur_len = item.shape[-1]
        if cur_len != siglen:
            item = F.interpolate(item, size=(siglen,), mode="linear", align_corners=True)
        work[pos] = item
        orig_lens.append(cur_len)

    # the direction is drawn before the magnitude
    sign = choice([-1, 1])
    sampled = np.clip(DEFAULTS.RNG.normal(ratio, 0.382 * ratio), 0, 2 * ratio)
    new_len = int(round((1 + sign * sampled) * siglen))
    diff_len = abs(new_len - siglen)
    left = diff_len // 2
    right = diff_len - left

    def _resize(x: Tensor) -> Tensor:
        # interpolate to ``new_len``, then get back to ``siglen`` samples
        x = F.interpolate(x, size=new_len, mode="linear", align_corners=True)
        if sign > 0:
            # stretch and cut
            return x[..., left : siglen + left]
        # compress and pad
        return F.pad(x, pad=(left, right), mode="constant", value=0.0)

    sig = _resize(sig).squeeze(0)
    work = [item if cur_len == 0 else _resize(item) for item, cur_len in zip(work, orig_lens)]

    # restore the original length and layout of the labels
    outputs = []
    for item, cur_len in zip(work, orig_lens):
        if cur_len == 0:
            outputs.append(item.squeeze(0))
            continue
        if cur_len != siglen:
            item = F.interpolate(item, size=(cur_len,), mode="linear", align_corners=True)
        outputs.append(item.squeeze(0).permute(1, 0))  # (n_classes, label_len) -> (label_len, n_classes)

    if outputs:
        return (sig, *outputs)
    return sig
class StretchCompressOffline(ReprMixin):
    """Offline stretch-or-compress augmenter.

    Stretch-or-compress augmenter on original length-varying ECG signals
    (in the form of numpy arrays), for the purpose of offline data generation.

    Parameters
    ----------
    ratio : int or float, default 6
        Mean ratio of the stretch or compress.
        A value greater than 1 is interpreted as a percentage
        and divided by 100, so that the stored ratio lies in [0, 1].
        The ratio of one segment is sampled from a normal distribution.
    prob : float, default 0.5
        Probability of the augmenter to be applied.
    overlap : float, default 0.5
        Overlap of offline generated data.
    critical_overlap : float, default 0.85
        Overlap of the critical region of the ECG.

    Example
    -------
    .. code-block:: python

        sco = StretchCompressOffline()
        seglen = 600
        sig = torch.randn((12, 60000)).numpy()
        labels = torch.ones((60000, 3)).numpy().astype(int)
        masks = torch.ones((60000, 1)).numpy().astype(int)
        segments = sco(600, sig, labels, masks, critical_points=[10000,30000])

    """

    __name__ = "StretchCompressOffline"

    def __init__(
        self,
        ratio: Union[int, float] = 6,
        prob: float = 0.5,
        overlap: float = 0.5,
        critical_overlap: float = 0.85,
    ) -> None:
        self.prob = prob
        assert 0 <= self.prob <= 1, "Probability must be between 0 and 1"
        self.ratio = ratio
        if self.ratio > 1:
            # values above 1 are percentages
            self.ratio = self.ratio / 100
        assert 0 <= self.ratio <= 1, "Ratio must be between 0 and 1, or between 0 and 100"
        self.overlap = overlap
        assert 0 <= self.overlap < 1, "Overlap ratio must be between 0 and 1 (1 not included)"
        self.critical_overlap = critical_overlap
        assert 0 <= self.critical_overlap < 1, "Critical overlap ratio must be between 0 and 1 (1 not included)"

    def generate(
        self,
        seglen: int,
        sig: NDArray,
        *labels: Sequence[NDArray],
        critical_points: Optional[Sequence[int]] = None,
    ) -> List[Tuple[Union[NDArray, int], ...]]:
        """Generate stretched or compressed segments from the ECGs.

        Parameters
        ----------
        seglen : int
            Length of the ECG segments to be generated.
        sig : numpy.ndarray
            The ECGs to generate stretched or compressed segments,
            of shape ``(lead, siglen)``.
        labels : numpy.ndarray, optional
            Labels of the ECGs, of shape ``(label_len, channels)``.
            The labels are sliced along the first axis with the same
            indices as `sig` along its last axis.
            For example, when doing segmentation,
            `channels` should be the same as the number of classes.
        critical_points : Sequence[int], optional
            Indices of the critical points of the ECG,
            usually have larger overlap by :attr:`self.critical_overlap`.

        Returns
        -------
        list
            List of generated segments, each of the form
            ``(seg, label1, label2, ..., start_idx, end_idx)``.
            Empty if the signal is shorter than `seglen`.

        """
        siglen = sig.shape[1]
        # Strides are clamped to at least 1: for a small `seglen` or an overlap
        # close to 1 they round down to 0, which would cause a division by
        # zero below, or a loop around the critical points that never advances.
        forward_len = max(1, int(round(seglen - seglen * self.overlap)))
        max_critical_step = max(1, int(round(seglen - seglen * self.critical_overlap)))
        min_critical_step = max(1, max_critical_step // 4)

        # skip those records that are too short
        if siglen < seglen:
            return []

        segments = []

        # ordinary segments with constant forward_len
        for idx in range((siglen - seglen) // forward_len + 1):
            start_idx = idx * forward_len
            new_seg = self.__generate_segment(
                seglen,
                sig,
                *labels,
                start_idx=start_idx,
            )
            segments.append(new_seg)
        # the tail segment
        if (siglen - seglen) % forward_len != 0:
            new_seg = self.__generate_segment(
                seglen,
                sig,
                *labels,
                end_idx=siglen,
            )
            segments.append(new_seg)

        # special segments around critical_points
        # with random stride in [min_critical_step, max_critical_step]
        for cp in critical_points or []:
            start_idx = max(
                0,
                cp - seglen + randint(min_critical_step, max_critical_step),
            )
            while start_idx <= min(cp - max_critical_step, siglen - seglen):
                new_seg = self.__generate_segment(
                    seglen,
                    sig,
                    *labels,
                    start_idx=start_idx,
                )
                segments.append(new_seg)
                start_idx += randint(min_critical_step, max_critical_step)
        return segments

    def __generate_segment(
        self,
        seglen: int,
        sig: NDArray,
        *labels: Sequence[NDArray],
        start_idx: Optional[int] = None,
        end_idx: Optional[int] = None,
    ) -> Tuple[Union[NDArray, int], ...]:
        """Internal function to generate a stretched or compressed segment.

        Parameters
        ----------
        seglen : int
            Length of the ECG segments to be generated.
        sig : numpy.ndarray
            ECGs to generate stretched or compressed segments,
            of shape ``(lead, siglen)``.
        labels : numpy.ndarray, optional
            Labels of the ECGs, of shape ``(label_len, channels)``.
            The labels are sliced along the first axis with the same
            indices as `sig` along its last axis.
        start_idx : int, optional
            Start index of the segment in `sig`.
        end_idx : int, optional
            End index of the segment in `sig`.
            If `start_idx` is set, then `end_idx` is ignored.
            At least one of `start_idx` and `end_idx` should be set.

        Returns
        -------
        tuple
            Generated segment of the form
            ``(seg, label1, label2, ..., start_idx, end_idx)``,
            where the indices delimit the part of `sig` the segment
            was built from.

        """
        assert not all([start_idx is None, end_idx is None]), "at least one of `start_idx` and `end_idx` should be set"

        siglen = sig.shape[1]
        ratio = self._sample_ratio()
        aug_labels = []
        if ratio != 0:
            # take a window of length `new_len` and resample it to `seglen`
            sign = choice([-1, 1])
            new_len = int(round((1 + sign * ratio) * seglen))
            if start_idx is not None:
                start_idx = min(siglen, max(0, start_idx))
                end_idx = start_idx + new_len
            else:  # end_idx is not None
                start_idx = max(0, end_idx - new_len)
                end_idx = start_idx + new_len
            if end_idx > siglen:
                # shift the window back so that it ends with the signal
                end_idx = siglen
                start_idx = max(0, end_idx - new_len)
            aug_seg = sig[..., start_idx:end_idx]
            aug_seg = SS.resample(x=aug_seg, num=seglen, axis=1)
            for lb in labels:
                dtype = lb.dtype
                # nearest-neighbour interpolation, cast back to the original dtype
                aug_labels.append(
                    F.interpolate(
                        torch.from_numpy(lb[start_idx:end_idx, ...].T.astype(np.float32)).unsqueeze(0),
                        size=seglen,
                        mode="nearest",
                    )
                    .squeeze(0)
                    .numpy()
                    .T.astype(dtype)
                )
        else:
            # no augmentation: a plain window of length `seglen`
            if start_idx is not None:
                start_idx = min(siglen, max(0, start_idx))
                end_idx = start_idx + seglen
                if end_idx > siglen:
                    end_idx = siglen
                    start_idx = end_idx - seglen
            else:  # end_idx is not None
                end_idx = min(siglen, max(0, end_idx))
                start_idx = end_idx - seglen
                if start_idx < 0:
                    start_idx = 0
                    end_idx = seglen
            aug_seg = sig[..., start_idx:end_idx]
            for lb in labels:
                aug_labels.append(lb[start_idx:end_idx, ...])
        return (aug_seg,) + tuple(aug_labels) + (start_idx, end_idx)

    def _sample_ratio(self) -> float:
        """Sample the ratio of stretching or compressing.

        Returns 0 (no augmentation) with probability ``1 - self.prob``;
        otherwise a value drawn from a normal distribution with mean
        ``self.ratio`` and standard deviation ``0.382 * self.ratio``,
        clipped to ``[0.01 * self.ratio, 2 * self.ratio]``.
        """
        if DEFAULTS.RNG.uniform() >= self.prob:
            return 0
        return np.clip(
            DEFAULTS.RNG.normal(self.ratio, 0.382 * self.ratio),
            0.01 * self.ratio,
            2 * self.ratio,
        )

    @add_docstring(generate.__doc__)
    def __call__(
        self,
        seglen: int,
        sig: NDArray,
        *labels: Sequence[NDArray],
        critical_points: Optional[Sequence[int]] = None,
    ) -> List[Tuple[NDArray, ...]]:
        return self.generate(seglen, sig, *labels, critical_points=critical_points)

    def extra_repr_keys(self) -> List[str]:
        """Names of the attributes shown in the representation of the augmenter."""
        return super().extra_repr_keys() + [
            "ratio",
            "prob",
            "overlap",
            "critical_overlap",
        ]