Skip to content

Registry

create_register_decorator(registry)

Create a decorator that registers objects of a specified type (model, metric, ...).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `registry` | `Dict[str, Callable]` | Dictionary holding the registered objects (maps each registration name to the object registered under it). | required |

Returns:

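| Type | Description |
|------|-------------|
| `Callable` | The `register` function; calling it returns the decorator that registers an object. |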

Source code in src/super_gradients/common/registry/registry.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
    """
    Create a decorator that registers objects of a specified type (model, metric, ...).

    :param registry:    Dict including registered objects (maps name to object that you register)
    :return:            Register function
    """

    def _describe(obj: Callable) -> str:
        """Return a ``module.name`` label for ``obj`` that never raises (used only in error messages)."""
        module = inspect.getmodule(obj)
        # inspect.getmodule may return None, and arbitrary callables (e.g. callable instances) may lack __name__.
        module_name = module.__name__ if module is not None else "<unknown module>"
        obj_name = getattr(obj, "__name__", None) or repr(obj)
        return f"{module_name}.{obj_name}"

    def register(name: Optional[str] = None, deprecated_name: Optional[str] = None) -> Callable:
        """
        Set up a register decorator.

        :param name:            If specified, the decorated object will be registered with this name. Otherwise, the class name will be used to register.
        :param deprecated_name: If specified, the decorated object will also be registered with this name.
                                This is done on top of the `official` registration which is done by setting the `name` argument.
        :return:                Decorator that registers the callable.
        """

        def decorator(cls: Callable) -> Callable:
            """
            Register the decorated callable under its official name and, optionally, its deprecated name.

            :param cls: The object to register.
            :return:    ``cls`` unchanged, so the decorator can be stacked transparently.
            :raises Exception: If one of the names is already registered to a different object. In that case
                               the registry is left unchanged.
            """
            registration_name = name or cls.__name__
            names_to_register = [registration_name]
            if deprecated_name:
                # Deprecated objects are registered like other objects - this is meant to avoid any breaking change.
                names_to_register.append(deprecated_name)

            # Validate every name before mutating the registry, so a conflict on the deprecated name
            # cannot leave the object half-registered under its official name.
            for candidate in names_to_register:
                if candidate in registry:
                    registered_cls = registry[candidate]
                    if registered_cls != cls:
                        raise Exception(f"`{candidate}` is already registered and points to `{_describe(registered_cls)}`")

            for candidate in names_to_register:
                registry[candidate] = cls

            if deprecated_name:
                # Deprecated objects are also listed in the _DEPRECATED_KEY entry.
                # This can later be used in the factories to know if a name is deprecated and how it should be named instead.
                deprecated_registered_objects = registry.get(_DEPRECATED_KEY, {})
                deprecated_registered_objects[deprecated_name] = registration_name  # Keep the information about how it should be named.
                registry[_DEPRECATED_KEY] = deprecated_registered_objects

            return cls

        return decorator

    return register

warn_if_deprecated(name, registry)

If the name is deprecated, warn the user about it.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `name` | `str` | The name of the object to check for deprecation. | required |
| `registry` | `dict` | The registry, which may or may not include deprecated objects. | required |
Source code in src/super_gradients/common/registry/registry.py
64
65
66
67
68
69
70
71
72
def warn_if_deprecated(name: str, registry: dict):
    """If the name is deprecated, warn the user about it.

    :param name:        The name of the object that we want to check if it is deprecated.
    :param registry:    The registry that may or may not include deprecated objects.
    """
    deprecated_names = registry.get(_DEPRECATED_KEY, {})
    if name not in deprecated_names:
        return

    # A "once" filter is required, otherwise the DeprecationWarning may be filtered out and never displayed.
    # It is installed inside catch_warnings so the process-wide filter list is restored afterwards, instead of
    # permanently overriding the user's own warning configuration for every DeprecationWarning in the program.
    with warnings.catch_warnings():
        warnings.simplefilter("once", DeprecationWarning)
        warnings.warn(f"Object name `{name}` is now deprecated. Please replace it with `{deprecated_names[name]}`.", DeprecationWarning)