""" """
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from ..cfg import DEFAULTS
from .base import Augmenter
from .registry import AUGMENTERS
__all__ = [
"RandomFlip",
]
[docs]
@AUGMENTERS.register(name="random_flip")
@AUGMENTERS.register()
class RandomFlip(Augmenter):
"""Randomly flip the ECGs along the voltage axis.
Parameters
----------
fs : int, optional
Sampling frequency of the ECGs to be augmented
per_channel : bool, default True
Whether to flip each channel independently.
prob : float or Sequence[float], default ``[0.4, 0.2]``
Probability of performing flip,
the first probality is for the batch dimension,
the second probability is for the lead dimension.
inplace : bool, default True
If True, ECG signal tensors will be modified inplace.
kwargs : dict, optional
Additional keyword arguments.
Examples
--------
.. code-block:: python
rf = RandomFlip()
sig = torch.randn(32, 12, 5000)
sig, _ = rf(sig, None)
"""
__name__ = "RandomFlip"
def __init__(
self,
fs: Optional[int] = None,
per_channel: bool = True,
prob: Union[Sequence[float], float] = [0.4, 0.2],
inplace: bool = True,
**kwargs: Any,
) -> None:
super().__init__()
self.fs = fs
self.per_channel = per_channel
self.inplace = inplace
self.prob = prob
if isinstance(self.prob, (float, int)):
self.prob = np.array([self.prob, self.prob], dtype=DEFAULTS.np_dtype)
else:
self.prob = np.array(self.prob, dtype=DEFAULTS.np_dtype)
assert (self.prob >= 0).all() and (self.prob <= 1).all(), "Probability must be between 0 and 1"
[docs]
def forward(
self, sig: Tensor, label: Optional[Tensor], *extra_tensors: Sequence[Tensor], **kwargs: Any
) -> Tuple[Tensor, ...]:
"""Forward function of the RandomFlip augmenter.
Parameters
----------
sig : torch.Tensor
The ECGs to be augmented, of shape ``(batch, lead, siglen)``.
label : torch.Tensor, optional
Label tensor of the ECGs.
Not used, but kept for consistency with other augmenters.
extra_tensors : Sequence[torch.Tensor], optional
Not used, but kept for consistency with other augmenters.
kwargs : dict, optional
Additional keyword arguments.
Not used, but kept for consistency with other augmenters.
Returns
-------
sig : torch.Tensor
The augmented ECGs.
label : torch.Tensor
The label tensor of the augmented ECGs, unchanged.
extra_tensors : Sequence[torch.Tensor], optional
Unchanged extra tensors.
"""
batch, lead, siglen = sig.shape
if not self.inplace:
sig = sig.clone()
if self.prob[0] == 0:
return (sig, label, *extra_tensors)
if self.per_channel:
flip = torch.ones((batch, lead, 1), dtype=sig.dtype, device=sig.device)
for i in self.get_indices(prob=self.prob[0], pop_size=batch):
flip[i, self.get_indices(prob=self.prob[1], pop_size=lead), ...] = -1
sig = sig.mul_(flip)
else:
flip = torch.ones((batch, 1, 1), dtype=sig.dtype, device=sig.device)
flip[self.get_indices(prob=self.prob[0], pop_size=batch), ...] = -1
sig = sig.mul_(flip)
return (sig, label, *extra_tensors)
[docs]
def extra_repr_keys(self) -> List[str]:
return [
"per_channel",
"prob",
"inplace",
] + super().extra_repr_keys()