Dataloaders

BaseDataloaderAdapterFactory

Bases: ABC

Factory class, responsible for adapting datasets/dataloaders to seamlessly work with SG format.

Source code in src/super_gradients/training/dataloaders/adapters.py
class BaseDataloaderAdapterFactory(ABC):
    """Factory class, responsible for adapting datasets/dataloaders to seamlessly work with SG format."""

    @classmethod
    def from_dataset(
        cls,
        dataset: torch.utils.data.Dataset,
        config: Optional[DataConfig] = None,
        config_path: Optional[str] = None,
        collate_fn: Optional[callable] = None,
        **dataloader_kwargs,
    ) -> torch.utils.data.DataLoader:
        """Wrap a DataLoader to adapt its output to fit SuperGradients format for the specific task.

        :param dataset:         Dataset to adapt.
        :param config:          Adapter configuration. Use this if you want to explicitly set some specific params of your dataset.
                                Mutually exclusive with `config_path`.
        :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                                Mutually exclusive with `config`.
        :param collate_fn:      Collate function to use. If None, the default PyTorch collate function will be used.

        :return:                Adapted DataLoader.
        """

        dataloader = torch.utils.data.DataLoader(dataset=dataset, **dataloader_kwargs)

        # `AdapterCollateFNClass` depends on the task; it is the collate function adapter for that specific task.
        AdapterCollateFNClass = cls._get_collate_fn_class()
        adapter_collate = AdapterCollateFNClass(base_collate_fn=collate_fn, config=config, config_path=config_path)

        _maybe_setup_adapter(adapter=adapter_collate.adapter, data=dataset)
        dataloader.collate_fn = adapter_collate
        return dataloader

    @classmethod
    def from_dataloader(
        cls,
        dataloader: torch.utils.data.DataLoader,
        config: Optional[DataConfig] = None,
        config_path: Optional[str] = None,
    ) -> torch.utils.data.DataLoader:
        """Wrap a DataLoader to adapt its output to fit SuperGradients format for the specific task.

        :param dataloader:      DataLoader to adapt.
        :param config:          Adapter configuration. Use this if you want to explicitly set some specific params of your dataset.
                                Mutually exclusive with `config_path`.
        :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                                Mutually exclusive with `config`.

        :return:                Adapted DataLoader.
        """

        # `AdapterCollateFNClass` depends on the tasks, but just represents the collate function adapter for that specific task.
        AdapterCollateFNClass = cls._get_collate_fn_class()
        adapter_collate = AdapterCollateFNClass(base_collate_fn=dataloader.collate_fn, config=config, config_path=config_path)

        _maybe_setup_adapter(adapter=adapter_collate.adapter, data=dataloader)
        dataloader.collate_fn = adapter_collate
        return dataloader

    @classmethod
    @abstractmethod
    def _get_collate_fn_class(cls) -> type:
        """
        Returns the specific Collate Function class for this type of task.

        :return: Collate Function class specific to the task.
        """
        pass
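
Concrete factories only implement `_get_collate_fn_class`; the base class handles building the DataLoader, wiring in the adapter collate function, and triggering setup. A minimal sketch of the subclassing pattern (`MyTaskAdapterCollateFN` is a hypothetical collate adapter class, used here only for illustration):

```python
# Hypothetical task-specific factory, assuming `MyTaskAdapterCollateFN` follows the
# (base_collate_fn, config, config_path) constructor contract used by the base class.
class MyTaskDataloaderAdapterFactory(BaseDataloaderAdapterFactory):
    @classmethod
    def _get_collate_fn_class(cls) -> type:
        return MyTaskAdapterCollateFN
```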

from_dataloader(dataloader, config=None, config_path=None) classmethod

Wrap a DataLoader to adapt its output to fit SuperGradients format for the specific task.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| dataloader | torch.utils.data.DataLoader | DataLoader to adapt. | required |
| config | Optional[DataConfig] | Adapter configuration. Use this if you want to explicitly set some specific params of your dataset. Mutually exclusive with config_path. | None |
| config_path | Optional[str] | Adapter cache path. Use this if you want to load and/or save the adapter config from a local path. Mutually exclusive with config. | None |

Returns:

| Type | Description |
| ---- | ----------- |
| torch.utils.data.DataLoader | Adapted DataLoader. |

Source code in src/super_gradients/training/dataloaders/adapters.py
@classmethod
def from_dataloader(
    cls,
    dataloader: torch.utils.data.DataLoader,
    config: Optional[DataConfig] = None,
    config_path: Optional[str] = None,
) -> torch.utils.data.DataLoader:
    """Wrap a DataLoader to adapt its output to fit SuperGradients format for the specific task.

    :param dataloader:      DataLoader to adapt.
    :param config:          Adapter configuration. Use this if you want to explicitly set some specific params of your dataset.
                            Mutually exclusive with `config_path`.
    :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                            Mutually exclusive with `config`.

    :return:                Adapted DataLoader.
    """

    # `AdapterCollateFNClass` depends on the task; it is the collate function adapter for that specific task.
    AdapterCollateFNClass = cls._get_collate_fn_class()
    adapter_collate = AdapterCollateFNClass(base_collate_fn=dataloader.collate_fn, config=config, config_path=config_path)

    _maybe_setup_adapter(adapter=adapter_collate.adapter, data=dataloader)
    dataloader.collate_fn = adapter_collate
    return dataloader
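
For example, an existing DataLoader can be adapted in place with the DetectionDataloaderAdapterFactory documented below (a sketch; `my_detection_dataset` and the cache path are placeholders):

```python
import torch.utils.data

loader = torch.utils.data.DataLoader(my_detection_dataset, batch_size=16, shuffle=True)

# The same DataLoader object is returned, with its collate_fn replaced by the adapter.
loader = DetectionDataloaderAdapterFactory.from_dataloader(
    dataloader=loader,
    config_path="/local/path/adapter_config.json",
)
```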

from_dataset(dataset, config=None, config_path=None, collate_fn=None, **dataloader_kwargs) classmethod

Wrap a Dataset into a DataLoader whose output fits the SuperGradients format for the specific task.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| dataset | torch.utils.data.Dataset | Dataset to adapt. | required |
| config | Optional[DataConfig] | Adapter configuration. Use this if you want to explicitly set some specific params of your dataset. Mutually exclusive with config_path. | None |
| config_path | Optional[str] | Adapter cache path. Use this if you want to load and/or save the adapter config from a local path. Mutually exclusive with config. | None |
| collate_fn | Optional[callable] | Collate function to use. If None, the default PyTorch collate function will be used. | None |

Returns:

| Type | Description |
| ---- | ----------- |
| torch.utils.data.DataLoader | Adapted DataLoader. |

Source code in src/super_gradients/training/dataloaders/adapters.py
@classmethod
def from_dataset(
    cls,
    dataset: torch.utils.data.Dataset,
    config: Optional[DataConfig] = None,
    config_path: Optional[str] = None,
    collate_fn: Optional[callable] = None,
    **dataloader_kwargs,
) -> torch.utils.data.DataLoader:
    """Wrap a DataLoader to adapt its output to fit SuperGradients format for the specific task.

    :param dataset:         Dataset to adapt.
    :param config:          Adapter configuration. Use this if you want to explicitly set some specific params of your dataset.
                            Mutually exclusive with `config_path`.
    :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                            Mutually exclusive with `config`.
    :param collate_fn:      Collate function to use. If None, the default PyTorch collate function will be used.

    :return:                Adapted DataLoader.
    """

    dataloader = torch.utils.data.DataLoader(dataset=dataset, **dataloader_kwargs)

    # `AdapterCollateFNClass` depends on the task; it is the collate function adapter for that specific task.
    AdapterCollateFNClass = cls._get_collate_fn_class()
    adapter_collate = AdapterCollateFNClass(base_collate_fn=collate_fn, config=config, config_path=config_path)

    _maybe_setup_adapter(adapter=adapter_collate.adapter, data=dataset)
    dataloader.collate_fn = adapter_collate
    return dataloader
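
Equivalently, `from_dataset` builds the DataLoader itself, forwarding any extra keyword arguments to torch.utils.data.DataLoader (a sketch with placeholder names):

```python
loader = DetectionDataloaderAdapterFactory.from_dataset(
    dataset=my_detection_dataset,                   # placeholder dataset
    config_path="/local/path/adapter_config.json",  # where the adapter config is loaded/saved
    batch_size=16,                                  # forwarded to torch.utils.data.DataLoader
    num_workers=4,
)
```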

DetectionDataloaderAdapterFactory

Bases: BaseDataloaderAdapterFactory

Factory class, responsible for adapting datasets/dataloaders to seamlessly work with SG YOLOX, YOLONAS and PPYOLOE

Source code in src/super_gradients/training/dataloaders/adapters.py
class DetectionDataloaderAdapterFactory(BaseDataloaderAdapterFactory):
    """Factory class, responsible for adapting datasets/dataloaders to seamlessly work with SG YOLOX, YOLONAS and PPYOLOE"""

    @classmethod
    def from_dataset(
        cls,
        dataset: torch.utils.data.Dataset,
        config: Optional[DataConfig] = None,
        config_path: Optional[str] = None,
        **dataloader_kwargs,
    ) -> torch.utils.data.DataLoader:
        return super().from_dataset(
            dataset=dataset,
            config=config,
            config_path=config_path,
        collate_fn=DetectionCollateFN(),
            **dataloader_kwargs,
        )

    @classmethod
    def _get_collate_fn_class(cls) -> type:
        return DetectionDatasetAdapterCollateFN
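
Since the adapter config can be saved to and loaded from `config_path`, the configuration resolved for the train loader can be reused for validation (a sketch; the dataset names and path are placeholders):

```python
cache = "/local/path/adapter_config.json"

# The first call resolves (and caches) the adapter config; the second loads it
# from `cache` instead of resolving it again.
train_loader = DetectionDataloaderAdapterFactory.from_dataset(dataset=train_set, config_path=cache, batch_size=16)
val_loader = DetectionDataloaderAdapterFactory.from_dataset(dataset=val_set, config_path=cache, batch_size=16)
```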

maybe_setup_dataloader_adapter(dataloader)

If the dataloader's collate function is an adapter that requires setup, set it up; otherwise skip.

Source code in src/super_gradients/training/dataloaders/adapters.py
def maybe_setup_dataloader_adapter(dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader:
    """If the dataloader collate function is an adapter, and requires to be set up, do it. Otherwise skip."""
    collate_fn = dataloader.collate_fn
    if isinstance(collate_fn, BaseDatasetAdapterCollateFN):
        if collate_fn.adapter.data_config.is_batch:
            # Enforce a first execution with 0 workers. This is required because Python's `input` is not compatible with multiprocessing (i.e. num_workers > 0).
            # Therefore we want to make sure to ask the questions with 0 workers.
            dataloader.num_workers, _num_workers = 0, dataloader.num_workers
            _maybe_setup_adapter(adapter=collate_fn.adapter, data=dataloader)
            dataloader.num_workers = _num_workers
        else:
            _maybe_setup_adapter(adapter=collate_fn.adapter, data=dataloader.dataset)
    return dataloader
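
get() and get_data_loader() below call this hook automatically; a manually built DataLoader can opt in the same way (a sketch; `adapter_collate` is assumed to be an instance of a BaseDatasetAdapterCollateFN subclass):

```python
loader = torch.utils.data.DataLoader(my_dataset, batch_size=16, collate_fn=adapter_collate)
loader = maybe_setup_dataloader_adapter(dataloader=loader)  # no-op for non-adapter collate functions
```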

get(name=None, dataset_params=None, dataloader_params=None, dataset=None)

Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| name | str | dataset name in ALL_DATALOADERS. | None |
| dataset_params | Dict | dataset params that override the yaml configured defaults, then passed to the `dataset_cls.__init__`. | None |
| dataloader_params | Dict | DataLoader params that override the yaml configured defaults, then passed to the `DataLoader.__init__`. | None |
| dataset | torch.utils.data.Dataset | torch.utils.data.Dataset to be used instead of passing "name" (i.e. for external dataset objects). | None |

Returns:

| Type | Description |
| ---- | ----------- |
| DataLoader | Initialized DataLoader. |

Source code in src/super_gradients/training/dataloaders/dataloaders.py
def get(name: str = None, dataset_params: Dict = None, dataloader_params: Dict = None, dataset: torch.utils.data.Dataset = None) -> DataLoader:
    """
    Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.

    :param name: dataset name in ALL_DATALOADERS.
    :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.
    :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
    :param dataset: torch.utils.data.Dataset to be used instead of passing "name" (i.e. for external dataset objects).
    :return: initialized DataLoader.
    """
    if dataset is not None:
        if name or dataset_params:
            raise ValueError("'name' and 'dataset_params' cannot be passed with initialized dataset.")

    dataset_str = get_param(dataloader_params, "dataset")

    if dataset_str:
        if name or dataset:
            raise ValueError("'name' and 'datasets' cannot be passed when 'dataset' arg dataloader_params is used as well.")
        if dataset_params is not None:
            dataset = DatasetsFactory().get(conf={dataset_str: dataset_params})
        else:
            dataset = DatasetsFactory().get(conf=dataset_str)
        _ = dataloader_params.pop("dataset")

    if dataset is not None:
        dataloader_params = _process_sampler_params(dataloader_params, dataset, {})
        dataloader_params = _process_collate_fn_params(dataloader_params)

        dataloader = DataLoader(dataset=dataset, **dataloader_params)

        dataloader.dataloader_params = dataloader_params
        if not hasattr(dataset, "dataset_params"):
            dataset.dataset_params = dataset_params

    elif name not in ALL_DATALOADERS.keys():
        raise ValueError("Unsupported dataloader: " + str(name))
    else:
        dataloader_cls = ALL_DATALOADERS[name]
        dataloader = dataloader_cls(dataset_params=dataset_params, dataloader_params=dataloader_params)

    maybe_setup_dataloader_adapter(dataloader=dataloader)
    return dataloader
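
A typical call selects a registered dataloader by name and overrides a few recipe defaults (the registry name and override keys are illustrative; consult ALL_DATALOADERS for the names available in your version):

```python
from super_gradients.training import dataloaders

train_loader = dataloaders.get(
    name="coco2017_train",                                   # illustrative registry name
    dataset_params={"data_dir": "/data/coco"},               # overrides the recipe default
    dataloader_params={"batch_size": 16, "num_workers": 4},
)
```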

get_data_loader(config_name, dataset_cls, train, dataset_params=None, dataloader_params=None)

Create a DataLoader, taking defaults from yaml files in src/super_gradients/recipes.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| config_name | str | yaml config filename of dataset_params in recipes (for example coco_detection_dataset_params). | required |
| dataset_cls | object | torch dataset uninitialized class. | required |
| train | bool | controls whether to take cfg.train_dataset_params or cfg.valid_dataset_params as defaults for the dataset constructor, and cfg.train_dataloader_params or cfg.valid_dataloader_params as defaults for the DataLoader constructor. | required |
| dataset_params | Mapping | dataset params that override the yaml configured defaults, then passed to the `dataset_cls.__init__`. | None |
| dataloader_params | Mapping | DataLoader params that override the yaml configured defaults, then passed to the `DataLoader.__init__`. | None |

Returns:

| Type | Description |
| ---- | ----------- |
| DataLoader | Initialized DataLoader. |

Source code in src/super_gradients/training/dataloaders/dataloaders.py
def get_data_loader(config_name: str, dataset_cls: object, train: bool, dataset_params: Mapping = None, dataloader_params: Mapping = None) -> DataLoader:
    """
    Create a DataLoader, taking defaults from yaml files in src/super_gradients/recipes.

    :param config_name: yaml config filename of dataset_params in recipes (for example coco_detection_dataset_params).
    :param dataset_cls: torch dataset uninitialized class.
    :param train: controls whether to take
        cfg.train_dataset_params or cfg.valid_dataset_params as defaults for the dataset constructor
     and
        cfg.train_dataloader_params or cfg.valid_dataloader_params as defaults for the DataLoader constructor.

    :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.
    :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__
    :return: DataLoader
    """
    if dataloader_params is None:
        dataloader_params = dict()
    if dataset_params is None:
        dataset_params = dict()

    cfg = load_dataset_params(config_name=config_name)

    dataset_params = _process_dataset_params(cfg, dataset_params, train)

    local_rank = get_local_rank()
    with wait_for_the_master(local_rank):
        dataset = dataset_cls(**dataset_params)
        if not hasattr(dataset, "dataset_params"):
            dataset.dataset_params = dataset_params

    dataloader_params = _process_dataloader_params(cfg, dataloader_params, dataset, train)

    # Ensure there is no dataset in dataloader_params (Could be there if the user provided dataset class name)
    if "dataset" in dataloader_params:
        _ = dataloader_params.pop("dataset")

    dataloader = DataLoader(dataset=dataset, **dataloader_params)
    dataloader.dataloader_params = dataloader_params

    maybe_setup_dataloader_adapter(dataloader=dataloader)
    return dataloader
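
For instance, a loader can be built straight from a recipe config (the config name and dataset class below are illustrative):

```python
from super_gradients.training.dataloaders.dataloaders import get_data_loader
from super_gradients.training.datasets import COCODetectionDataset  # illustrative dataset class

train_loader = get_data_loader(
    config_name="coco_detection_dataset_params",  # yaml in src/super_gradients/recipes
    dataset_cls=COCODetectionDataset,
    train=True,
    dataloader_params={"batch_size": 8},
)
```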