Source code for torch_ecg.preprocessors.resample

""" """

from typing import Any, Optional, Union

import numpy as np
import torch

from ..utils.utils_signal_t import resample as resample_t
from .registry import PREPROCESSORS

__all__ = [
    "Resample",
]


[docs] @PREPROCESSORS.register(name="resample") @PREPROCESSORS.register() class Resample(torch.nn.Module): """Resample the signal into fixed sampling frequency or length. Parameters ---------- fs : int, optional Sampling frequency of the source signal to be resampled. dst_fs : int, optional Sampling frequency of the resampled ECG. siglen : int, optional Number of samples in the resampled ECG. inplace : bool, default False Whether to perform the resampling in-place. NOTE ---- One and only one of `fs` and `siglen` should be set. If `fs` is set, `src_fs` should also be set. TODO ---- Consider vectorized :func:`scipy.signal.resample`? """ __name__ = "Resample" def __init__( self, fs: Optional[int] = None, dst_fs: Optional[int] = None, siglen: Optional[int] = None, inplace: bool = False, **kwargs: Any, ) -> None: super().__init__() self.dst_fs = dst_fs self.fs = fs self.siglen = siglen self.inplace = inplace assert sum([bool(self.fs), bool(self.siglen)]) == 1, "one and only one of `fs` and `siglen` should be set" if self.dst_fs is not None: assert self.fs is not None, "if `dst_fs` is set, `fs` should also be set" self.scale_factor = self.dst_fs / self.fs
[docs] def forward(self, sig: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """Apply the resampling to the signal. Parameters ---------- sig : numpy.ndarray or torch.Tensor The signal to be resampled, of shape ``(..., n_leads, siglen)``. Returns ------- numpy.ndarray or torch.Tensor The resampled signal. * If `sig` is a :class:`torch.Tensor`, the output is a tensor. * If `sig` is a :class:`numpy.ndarray`, the output is a NumPy array with floating dtype. When the input dtype is floating, the same floating dtype is preserved; otherwise, ``float32`` is used. """ if isinstance(sig, torch.Tensor): return self._forward_torch(sig) else: return self._forward_numpy(sig)
def _forward_torch(self, sig: torch.Tensor) -> torch.Tensor: return resample_t( sig=sig, fs=self.fs, dst_fs=self.dst_fs, siglen=self.siglen, inplace=self.inplace, ) def _forward_numpy(self, sig: np.ndarray) -> np.ndarray: _sig = torch.as_tensor(sig, dtype=torch.float32) _sig = self._forward_torch(_sig) _out = _sig.cpu().numpy() if np.issubdtype(sig.dtype, np.floating): return _out.astype(sig.dtype) return _out