"""Base class for preprocessors."""
from abc import ABC, abstractmethod
from itertools import repeat
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
from biosppy.signals.tools import filter_signal
from numpy.typing import NDArray
from scipy.ndimage import median_filter
from ..cfg import DEFAULTS
from ..utils.misc import ReprMixin, add_docstring
from ..utils.utils_signal import butter_bandpass_filter
# `scipy.ndimage.median_filter` is used for baseline removal instead of
# `scipy.signal.medfilt`; see https://github.com/scipy/scipy/issues/9680
# Public API: the abstract preprocessor base class and the two functional helpers.
__all__ = ["PreProcessor", "preprocess_multi_lead_signal", "preprocess_single_lead_signal"]
class PreProcessor(ReprMixin, ABC):
    """Base class for preprocessors.

    Subclasses must implement :meth:`apply`. Instances are callable;
    calling an instance forwards to :meth:`apply`.
    """

    __name__ = "PreProcessor"

    @abstractmethod
    def apply(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]:
        """Apply the preprocessor to `sig`.

        Parameters
        ----------
        sig : numpy.ndarray
            The ECG signal, can be
            - 1d array, which is a single-lead ECG;
            - 2d array, which is a multi-lead ECG of "lead_first" format;
            - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
        fs : int or float
            Sampling frequency of the ECG signal.

        Returns
        -------
        processed_sig : numpy.ndarray
            The preprocessed ECG signal.
        fs : int or float
            Sampling frequency associated with the returned signal.
        """
        raise NotImplementedError

    @add_docstring(apply.__doc__)  # type: ignore
    def __call__(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]:
        """alias of :meth:`self.apply`."""
        return self.apply(sig, fs)

    def _check_sig(self, sig: NDArray) -> None:
        """Check validity of the signal.

        Parameters
        ----------
        sig : numpy.ndarray
            The ECG signal, can be
            - 1d array, which is a single-lead ECG;
            - 2d array, which is a multi-lead ECG of "lead_first" format;
            - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.

        Raises
        ------
        ValueError
            If `sig` is not a 1d, 2d or 3d array.
        """
        if sig.ndim not in (1, 2, 3):
            # the pieces are joined by implicit concatenation, so each one
            # must carry its own trailing space
            raise ValueError(
                "Invalid input ECG signal. Should be "
                "1d array, which is a single-lead ECG; "
                "or 2d array, which is a multi-lead ECG of `lead_first` format; "
                "or 3d array, which is a tensor of several ECGs, of shape (batch, lead, siglen)."
            )
def preprocess_multi_lead_signal(
    raw_sig: NDArray,
    fs: Union[int, float],
    sig_fmt: Literal["channel_first", "lead_first", "channel_last", "lead_last"] = "channel_first",
    bl_win: Optional[List[Union[int, float]]] = None,
    band_fs: Optional[List[Union[int, float]]] = None,
    filter_type: Literal["butter", "fir"] = "butter",
    filter_order: Optional[int] = None,
) -> NDArray:
    """Perform preprocessing for multi-lead ECG signal (with units in mV).

    Preprocessing may include baseline removal (two cascaded median filters)
    and bandpass (or lowpass / highpass) filtering.
    Also works for single-lead ECG signal (setting ``sig_fmt="channel_first"``
    and passing an array of shape ``(1, siglen)``).

    Parameters
    ----------
    raw_sig : numpy.ndarray
        The raw ECG signal, with units in mV.
        A 2d array, or a 3d array with an extra leading batch dimension.
    fs : int or float
        Sampling frequency of `raw_sig`.
    sig_fmt : str, default "channel_first"
        Format of the multi-lead ECG signal,
        "channel_last" (alias "lead_last"), i.e. ``(siglen, lead)`` or
        ``(batch, siglen, lead)``, or
        "channel_first" (alias "lead_first"), i.e. ``(lead, siglen)`` or
        ``(batch, lead, siglen)``.
    bl_win : List[Union[int, float]], optional
        Window (units in second) of baseline removal
        using :func:`~scipy.ndimage.median_filter`,
        the first is the shorter one, the second the longer one,
        a typical pair is ``[0.2, 0.6]``.
        If is None or empty, baseline removal will not be performed.
    band_fs : List[Union[int, float]], optional
        Frequency band of the bandpass filter,
        a typical pair is ``[0.5, 45]``.
        Be careful when detecting paced rhythm.
        If ``band_fs[0] <= 0``, a lowpass filter is used;
        if ``band_fs[1]`` is at or above the Nyquist frequency, a highpass filter is used.
        If is None or empty, bandpass filtering will not be performed.
    filter_type : {"butter", "fir"}, default "butter"
        Type of the bandpass filter.
    filter_order : int, optional
        Order of the bandpass filter. If None (or 0), defaults to
        ``int(0.2 * fs)`` for the FIR filter and ``round(0.01 * fs)``
        for the Butterworth filter.

    Returns
    -------
    filtered_ecg : numpy.ndarray
        The array of the processed ECG signal.
        The format (axis layout) of the signal is kept the same with the
        original signal, i.e. `sig_fmt`, including any batch dimension.

    Raises
    ------
    AssertionError
        If `raw_sig` is not 2d/3d, `sig_fmt` is unsupported,
        or `band_fs` is not a valid frequency band.
    ValueError
        If `filter_type` is unsupported.
    """
    raw_sig = np.asarray(raw_sig)
    assert raw_sig.ndim in [2, 3], "multi-lead signal should be 2d or 3d array"
    assert sig_fmt.lower() in [
        "channel_first",
        "lead_first",
        "channel_last",
        "lead_last",
    ], f"multi-lead signal format `{sig_fmt}` not supported"
    is_channel_last = sig_fmt.lower() in ["channel_last", "lead_last"]
    if is_channel_last:
        # (..., siglen, lead) -> (..., lead, siglen); a leading batch axis (if any)
        # stays in place, so everything below works along the last (time) axis
        filtered_ecg = np.moveaxis(raw_sig, -2, -1).astype(DEFAULTS.np_dtype)  # type: ignore
    else:
        filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype)  # type: ignore

    # remove baseline: median-filter along the time axis only (size 1 elsewhere)
    if bl_win:
        window1, window2 = list(repeat(1, filtered_ecg.ndim)), list(repeat(1, filtered_ecg.ndim))
        window1[-1] = 2 * (int(bl_win[0] * fs) // 2) + 1  # window size must be odd
        window2[-1] = 2 * (int(bl_win[1] * fs) // 2) + 1
        baseline = median_filter(filtered_ecg, size=window1, mode="nearest")
        baseline = median_filter(baseline, size=window2, mode="nearest")
        filtered_ecg = filtered_ecg - baseline

    # filter signal
    if band_fs:
        assert band_fs[0] < band_fs[1], "Invalid frequency band"
        nyq = 0.5 * fs
        # degrade to a one-sided filter when one edge of the band is out of range
        if band_fs[0] <= 0 and band_fs[1] < nyq:
            band = "lowpass"
            frequency = band_fs[1]
        elif band_fs[1] >= nyq and band_fs[0] > 0:
            band = "highpass"
            frequency = band_fs[0]
        elif band_fs[0] > 0 and band_fs[1] < nyq:
            band = "bandpass"
            frequency = band_fs
        else:
            raise AssertionError("Invalid frequency band")
        if filter_type.lower() == "fir":
            filtered_ecg = filter_signal(
                signal=filtered_ecg,
                ftype="FIR",
                band=band,
                order=filter_order or int(0.2 * fs),
                sampling_rate=fs,
                frequency=frequency,
            )["signal"]
        elif filter_type.lower() == "butter":
            # NOTE(review): assumes butter_bandpass_filter filters along the last
            # axis and handles out-of-range band edges itself -- confirm
            filtered_ecg = butter_bandpass_filter(
                data=filtered_ecg,
                lowcut=band_fs[0],
                highcut=band_fs[1],
                fs=fs,
                order=filter_order or round(0.01 * fs),  # better be determined by the `buttord`
            )
        else:
            raise ValueError(f"Unsupported filter type `{filter_type}`")

    if is_channel_last:
        # undo the axis move above; ``.T`` would reverse *all* axes and turn a
        # batched ``(batch, lead, siglen)`` array into ``(siglen, lead, batch)``
        filtered_ecg = np.moveaxis(filtered_ecg, -1, -2)
    return filtered_ecg
def preprocess_single_lead_signal(
    raw_sig: NDArray,
    fs: Union[int, float],
    bl_win: Optional[List[Union[int, float]]] = None,
    band_fs: Optional[List[Union[int, float]]] = None,
    filter_type: Literal["butter", "fir"] = "butter",
    filter_order: Optional[int] = None,
) -> NDArray:
    """Perform preprocessing for single lead ECG signal (with units in mV).

    Preprocessing may include baseline removal (two cascaded median filters)
    and bandpass (or lowpass / highpass) filtering.

    Parameters
    ----------
    raw_sig : numpy.ndarray
        Raw ECG signal (1d array), with units in mV.
    fs : int or float
        Sampling frequency of `raw_sig`.
    bl_win : list (of 2 int or float), optional
        Window (units in second) of baseline removal
        using :func:`~scipy.ndimage.median_filter`,
        the first is the shorter one, the second the longer one,
        a typical pair is ``[0.2, 0.6]``.
        If is None or empty, baseline removal will not be performed.
    band_fs : list of int or float, optional
        Frequency band of the bandpass filter,
        a typical pair is ``[0.5, 45]``.
        Be careful when detecting paced rhythm.
        If ``band_fs[0] <= 0``, a lowpass filter is used;
        if ``band_fs[1]`` is at or above the Nyquist frequency, a highpass filter is used.
        If is None or empty, bandpass filtering will not be performed.
    filter_type : {"butter", "fir"}, default "butter"
        Type of the bandpass filter.
    filter_order : int, optional
        Order of the bandpass filter. If None (or 0), defaults to
        ``int(0.3 * fs)`` for the FIR filter and ``round(0.01 * fs)``
        for the Butterworth filter.

    Returns
    -------
    filtered_ecg : numpy.ndarray
        The array of the processed ECG signal.

    Raises
    ------
    AssertionError
        If `raw_sig` is not 1d or `band_fs` is not a valid frequency band.
    ValueError
        If `filter_type` is unsupported.

    NOTE
    ----
    By default a Butterworth filter
    (:meth:`~torch_ecg.utils.butter_bandpass_filter`) is used;
    set ``filter_type="fir"`` to use an FIR filter instead.
    """
    filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype)  # type: ignore
    assert filtered_ecg.ndim == 1, "single-lead signal should be 1d array"

    # remove baseline
    if bl_win:
        window1 = 2 * (int(bl_win[0] * fs) // 2) + 1  # window size must be odd
        window2 = 2 * (int(bl_win[1] * fs) // 2) + 1
        baseline = median_filter(filtered_ecg, size=window1, mode="nearest")
        baseline = median_filter(baseline, size=window2, mode="nearest")
        filtered_ecg = filtered_ecg - baseline

    # filter signal
    if band_fs:
        assert band_fs[0] < band_fs[1], "Invalid frequency band"
        nyq = 0.5 * fs
        # degrade to a one-sided filter when one edge of the band is out of range
        if band_fs[0] <= 0 and band_fs[1] < nyq:
            band = "lowpass"
            frequency = band_fs[1]
        elif band_fs[1] >= nyq and band_fs[0] > 0:
            band = "highpass"
            frequency = band_fs[0]
        elif band_fs[0] > 0 and band_fs[1] < nyq:
            band = "bandpass"
            frequency = band_fs
        else:
            raise AssertionError("Invalid frequency band")
        if filter_type.lower() == "fir":
            filtered_ecg = filter_signal(
                signal=filtered_ecg,
                ftype="FIR",
                band=band,
                # honour a user-supplied order (it was previously ignored for FIR)
                order=filter_order or int(0.3 * fs),
                sampling_rate=fs,
                frequency=frequency,
            )["signal"]
        elif filter_type.lower() == "butter":
            filtered_ecg = butter_bandpass_filter(
                data=filtered_ecg,
                lowcut=band_fs[0],
                highcut=band_fs[1],
                fs=fs,
                order=filter_order or round(0.01 * fs),  # better be determined by the `buttord`
            )
        else:
            raise ValueError(f"Unsupported filter type `{filter_type}`")
    return filtered_ecg