Source code for torch_ecg.preprocessors.bandpass
""" """
from typing import Any, Optional, Union
import numpy as np
import torch
from .._preprocessors.base import preprocess_multi_lead_signal
from ..utils.utils_signal_t import bandpass_filter
from .registry import PREPROCESSORS
__all__ = [
"BandPass",
]
[docs]
@PREPROCESSORS.register(name="bandpass")
@PREPROCESSORS.register()
class BandPass(torch.nn.Module):
    """Bandpass filtering preprocessor.

    The input type selects the implementation: a :class:`torch.Tensor` is
    filtered with ``bandpass_filter``, anything else is handed to
    ``preprocess_multi_lead_signal``.

    Parameters
    ----------
    fs : int
        Sampling frequency of the ECG signal to be filtered.
    lowcut : float, optional
        Low cutoff frequency, by default 0.5.
        May be ``None`` as long as `highcut` is set.
    highcut : float, optional
        High cutoff frequency, by default 45.
        May be ``None`` as long as `lowcut` is set.
    inplace : bool, default True
        Whether to perform the filtering in-place.
        Only read for :class:`torch.Tensor` inputs: when ``False``,
        the tensor is cloned before being filtered.
    kwargs : dict, optional
        Extra keyword arguments. They are accepted but not used
        by this class (they are not forwarded to :class:`torch.nn.Module`).

    Raises
    ------
    AssertionError
        If both `lowcut` and `highcut` are ``None``.

    """

    __name__ = "BandPass"

    def __init__(
        self,
        fs: int,
        lowcut: Optional[float] = 0.5,
        highcut: Optional[float] = 45,
        inplace: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.fs = fs
        self.lowcut = lowcut
        self.highcut = highcut
        if self.lowcut is None and self.highcut is None:
            # Raised explicitly instead of via ``assert`` so that the check
            # is still performed when Python runs with ``-O``; the exception
            # type and message are the same as the former assertion.
            raise AssertionError("At least one of lowcut and highcut should be set")
        self.inplace = inplace

    def forward(self, sig: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """Apply the preprocessor to the signal.

        Parameters
        ----------
        sig : numpy.ndarray or torch.Tensor
            The ECG signal,
            of shape ``(batch, lead, siglen)`` or ``(lead, siglen)``.

        Returns
        -------
        numpy.ndarray or torch.Tensor
            The bandpass filtered ECG signal,
            of same shape and type as `sig`.

        """
        # Tensors go through the torch implementation; every other input
        # is treated as a numpy array.
        if isinstance(sig, torch.Tensor):
            return self._forward_torch(sig)
        return self._forward_numpy(sig)

    def _forward_torch(self, sig: torch.Tensor) -> torch.Tensor:
        """Filter a tensor with ``bandpass_filter``.

        When `inplace` is ``False`` the filter is applied to a clone of
        `sig` rather than to `sig` itself.
        """
        if not self.inplace:
            sig = sig.clone()
        return bandpass_filter(
            sig=sig,
            fs=self.fs,
            lowcut=self.lowcut,
            highcut=self.highcut,
        )

    def _forward_numpy(self, sig: np.ndarray) -> np.ndarray:
        """Filter a numpy array with ``preprocess_multi_lead_signal``.

        The cutoffs are passed as ``band_fs=[lowcut, highcut]``.
        The `inplace` setting is not consulted on this path.
        """
        # original implementation for numpy arrays
        return preprocess_multi_lead_signal(
            raw_sig=sig,
            fs=self.fs,
            band_fs=[self.lowcut, self.highcut],
        )