Source code for torch_ecg._preprocessors.resample

"""Resample the signal into fixed sampling frequency or length."""

from typing import Any, List, Optional, Tuple, Union

import scipy.signal as SS
from numpy.typing import NDArray

from ..cfg import DEFAULTS
from .base import PreProcessor

__all__ = [
    "Resample",
]


[docs] class Resample(PreProcessor): """Resample the signal into fixed sampling frequency or length. Parameters ---------- fs : int, optional Sampling frequency of the resampled ECG. siglen : int, optional Number of samples in the resampled ECG. NOTE ---- One and only one of `fs` and `siglen` should be set. Examples -------- .. code-block:: python from torch_ecg.cfg import DEFAULTS sig = DEFAULTS.RNG.randn(1000) pp = Resample(fs=500) sig, _ = pp(sig, 250) """ __name__ = "Resample" def __init__(self, fs: Optional[int] = None, siglen: Optional[int] = None, **kwargs: Any) -> None: self.fs = fs self.siglen = siglen assert sum([bool(self.fs), bool(self.siglen)]) == 1, "one and only one of `fs` and `siglen` should be set"
[docs] def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]: """Apply the preprocessor to `sig`. Parameters ---------- sig : numpy.ndarray The ECG signal, can be - 1d array, which is a single-lead ECG; - 2d array, which is a multi-lead ECG of "lead_first" format; - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. fs : float or int Sampling frequency of the ECG signal. Returns ------- rsmp_sig : :class:`numpy.ndarray` The resampled ECG signal. new_fs : :class:`int`, Sampling frequency of the resampled ECG signal. """ self._check_sig(sig) if self.fs is not None: rsmp_sig = SS.resample_poly(sig.astype(DEFAULTS.np_dtype), up=self.fs, down=fs, axis=-1) new_fs = self.fs else: # self.siglen is not None rsmp_sig = SS.resample(sig.astype(DEFAULTS.np_dtype), num=self.siglen, axis=-1) new_fs = int(round(self.siglen / sig.shape[-1] * fs)) return rsmp_sig, new_fs # type: ignore
[docs] def extra_repr_keys(self) -> List[str]: return [ "fs", "siglen", ] + super().extra_repr_keys()