# Source code for torch_ecg.utils.registry

"""
Generic registry implementation for managing and building modules (backbones, models, etc.)
"""

from typing import Any, Dict, List, Optional, Union

# Public names exported by ``from torch_ecg.utils.registry import *``.
__all__ = [
    "Registry",
]


class Registry:
    """Registry for managing and building modules.

    A registry maps strings (module names) to classes, and provides a
    unified interface to instantiate modules from configurations.

    Parameters
    ----------
    name : str
        Name of the registry.

    Examples
    --------
    >>> BACKBONES = Registry("backbones")
    >>> @BACKBONES.register()
    ... class ResNet(nn.Module):
    ...     def __init__(self, depth):
    ...         super().__init__()
    ...         self.depth = depth
    >>> # Build from string
    >>> model = BACKBONES.build("ResNet", depth=50)
    >>> # Build from config dict
    >>> model = BACKBONES.build({"name": "ResNet", "depth": 101})

    """

    def __init__(self, name: str) -> None:
        self._name = name
        # Registered name -> registered class (dict keeps insertion order).
        self._module_dict: Dict[str, type] = {}

    def __len__(self) -> int:
        return len(self._module_dict)

    def __contains__(self, name: str) -> bool:
        return name in self._module_dict

    def __repr__(self) -> str:
        return f"Registry(name={self._name}, items={list(self._module_dict.keys())})"

    @property
    def name(self) -> str:
        """Name of the registry."""
        return self._name

    def register(self, name: Optional[str] = None, override: bool = False) -> Any:
        """Decorator to register a module.

        Can be used either as ``@registry.register()`` /
        ``@registry.register("alias")`` or, bare, as ``@registry.register``.

        Parameters
        ----------
        name : str, optional
            Name of the module. If not specified (or empty), the class name
            will be used.
        override : bool, default False
            Whether to override the existing module with the same name.

        Returns
        -------
        Any
            A decorator that registers and returns the decorated class, or,
            when used bare (the class is passed directly), the class itself.

        Raises
        ------
        KeyError
            When the decorator is applied and the name is already registered
            while `override` is False.

        """
        if isinstance(name, type):
            # Bare ``@registry.register`` (no parentheses): Python passes the
            # decorated class itself as ``name``. Without this branch the class
            # would be silently replaced by the inner ``_register`` function.
            return self.register(override=override)(name)

        def _register(cls: type) -> type:
            module_name = name or cls.__name__
            if not override and module_name in self._module_dict:
                raise KeyError(f"{module_name} is already registered in {self._name} registry")
            self._module_dict[module_name] = cls
            return cls

        return _register

    def get(self, name: str) -> Optional[type]:
        """Get the module class by name.

        Parameters
        ----------
        name : str
            Name of the module.

        Returns
        -------
        type or None
            The registered module class, or None if `name` is not registered.

        """
        return self._module_dict.get(name)

    def list_all(self) -> List[str]:
        """List all registered modules.

        Returns
        -------
        List[str]
            A list of all registered module names, in registration order.

        """
        return list(self._module_dict.keys())

    def build(self, config: Union[str, Dict[str, Any]], **kwargs: Any) -> Any:
        """Build a module from a configuration.

        Parameters
        ----------
        config : str or dict
            Configuration of the module. If it's a string, it should be the
            name of the registered module. If it's a dict, it must contain a
            "name" key; the remaining items are passed to the constructor.
        **kwargs : Any
            Additional keyword arguments passed to the module's constructor.
            They take precedence over items of the same name in `config`.

        Returns
        -------
        Any
            The instantiated module.

        Raises
        ------
        KeyError
            If a dict config lacks the "name" key, or the requested name is
            not registered.
        TypeError
            If `config` is neither a str nor a dict.

        """
        if isinstance(config, str):
            obj_type = config
            obj_config: Dict[str, Any] = {}
        elif isinstance(config, dict):
            # Work on a copy so popping "name" never mutates the caller's config.
            obj_config = config.copy()
            if "name" not in obj_config:
                raise KeyError(f"Configuration for {self._name} must contain a 'name' key")
            obj_type = obj_config.pop("name")
        else:
            raise TypeError(f"Config must be a str or dict, but got {type(config)}")

        obj_cls = self.get(obj_type)
        if obj_cls is None:
            raise KeyError(f"{obj_type} is not registered in the {self._name} registry")

        # Merge config from dict and extra kwargs (kwargs win on conflicts).
        final_kwargs = {**obj_config, **kwargs}
        return obj_cls(**final_kwargs)