Source code for torch_ecg.preprocessors.baseline_remove

""" """

import warnings
from typing import Any, Union

import numpy as np
import torch

from .._preprocessors.base import preprocess_multi_lead_signal
from ..utils.utils_signal_t import baseline_removal
from .registry import PREPROCESSORS

__all__ = [
    "BaselineRemove",
]


[docs] @PREPROCESSORS.register(name="baseline_remove") @PREPROCESSORS.register() class BaselineRemove(torch.nn.Module): """Baseline removal using sliding average (median filter alternative). Parameters ---------- fs : int Sampling frequency of the ECG signal to be filtered. window1 : float, default 0.2 The smaller window size, with units in seconds. window2 : float, default 0.6 The larger window size, with units in seconds. inplace : bool, default True Whether to perform the filtering in-place. kwargs : dict, optional Other keyword arguments for :class:`torch.nn.Module`. """ __name__ = "BaselineRemove" def __init__(self, fs: int, window1: float = 0.2, window2: float = 0.6, inplace: bool = True, **kwargs: Any) -> None: super().__init__() self.fs = fs self.window1 = window1 self.window2 = window2 if self.window2 < self.window1: self.window1, self.window2 = self.window2, self.window1 warnings.warn("values of `window1` and `window2` are switched", RuntimeWarning) self.inplace = inplace
[docs] def forward(self, sig: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """Apply the preprocessor to the signal. Parameters ---------- sig : numpy.ndarray or torch.Tensor The ECG signal, of shape ``(batch, lead, siglen)`` or ``(lead, siglen)``. Returns ------- numpy.ndarray or torch.Tensor The baseline removed ECG signals, of same shape and type as `sig`. """ if isinstance(sig, torch.Tensor): return self._forward_torch(sig) else: return self._forward_numpy(sig)
def _forward_torch(self, sig: torch.Tensor) -> torch.Tensor: if not self.inplace: sig = sig.clone() return baseline_removal( sig=sig, fs=self.fs, window1=self.window1, window2=self.window2, ) def _forward_numpy(self, sig: np.ndarray) -> np.ndarray: # original implementation for numpy arrays return preprocess_multi_lead_signal( raw_sig=sig, fs=self.fs, bl_win=[self.window1, self.window2], )