Data interface

ADNNModelRepositoryDataInterfaces

Bases: ILogger

ResearchModelRepositoryDataInterface

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
class ADNNModelRepositoryDataInterfaces(ILogger):
    """
    ResearchModelRepositoryDataInterface
    """

    def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
        """
        ModelCheckpointsDataInterface
            :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
            :param data_connection_credentials: credentials string
                    - name of the AWS profile in case data_connection_source is s3; taken from the environment
                    variable AWS_PROFILE if left empty
        """
        super().__init__()
        self.tb_events_file_prefix = "events.out.tfevents"
        self.log_file_prefix = "experiment_logs_"
        self.latest_checkpoint_filename = "ckpt_latest.pth"
        self.best_checkpoint_filename = "ckpt_best.pth"

        if data_connection_location.startswith("s3"):
            assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
            self.model_repo_bucket_name = data_connection_location.split("://")[1]
            self.data_connection_source = "s3"

            if data_connection_credentials is None:
                data_connection_credentials = env_variables.AWS_PROFILE

            self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)

    @explicit_params_validation(validation_type="None")
    def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
        """
        load_all_remote_checkpoint_files
            :param model_name:
            :param model_checkpoint_local_dir:
            :return:
        """
        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")

    @explicit_params_validation(validation_type="None")
    def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
        """
        save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
            :param model_name:                  The Model Name to store in Remote Repo
            :param model_checkpoint_local_dir:  Local directory with the relevant data to upload
            :param log_file_name:               The log_file name (Created independently)
        """
        for checkpoint_file_name in [self.latest_checkpoint_filename, self.best_checkpoint_filename]:
            self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, checkpoint_file_name)

        self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
        self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)

    @explicit_params_validation(validation_type="None")
    def load_remote_checkpoints_file(
        self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
    ) -> str:
        """
        load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
            :param ckpt_source_remote_dir:               The source folder to download from
            :param ckpt_destination_local_dir:           The destination folder to save the checkpoint at
            :param ckpt_file_name:                       Filename to load from Remote Repo
            :param overwrite_local_checkpoints_file:     Whether to delete a previous local version of the same file
                                                            before downloading (relevant for cloud-stored checkpoints only)
            :return: Local path to the downloaded checkpoint file
        """
        ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

        if self.data_connection_source == "s3":
            if overwrite_local_checkpoints_file:
                # DELETE THE LOCAL VERSION ON THE MACHINE
                if os.path.exists(ckpt_file_local_full_path):
                    os.remove(ckpt_file_local_full_path)

            key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

            if not download_success:
                failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
                error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
                self._logger.error(error_msg)
                raise ModelCheckpointNotFoundException(error_msg)

        return ckpt_file_local_full_path

    @explicit_params_validation(validation_type="NoneOrEmpty")
    def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
        """
        load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
            :param model_name:
            :param model_checkpoint_dir_name:
            :param logging_type:
            :return:
        """
        if not os.path.isdir(model_checkpoint_dir_name):
            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

        # LOADS THE DATA FROM THE REMOTE REPOSITORY
        s3_bucket_path_prefix = model_name
        if logging_type == "tensorboard":
            if self.data_connection_source == "s3":
                self.s3_connector.download_keys_by_prefix(
                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
                )
        elif logging_type == "text":
            if self.data_connection_source == "s3":
                self.s3_connector.download_keys_by_prefix(
                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
                )

    @explicit_params_validation(validation_type="NoneOrEmpty")
    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
        """
        save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
            :param model_name:                      The Model Name for S3 Prefix
            :param model_checkpoint_local_dir:      Model Directory - Based on Model name
            :param checkpoints_file_name:           Filename to upload to Remote Repo
            :return: True/False for Operation Success/Failure
        """
        # LOAD THE LOCAL VERSION
        model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

        # SAVE ON THE REMOTE S3 REPOSITORY
        if self.data_connection_source == "s3":
            model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
            return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)

    @explicit_params_validation(validation_type="NoneOrEmpty")
    def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
        """
        save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
            :param model_name:                Prefix for Cloud Storage
            :param model_checkpoint_dir_name: The directory where the files are stored
        """
        if not os.path.isdir(model_checkpoint_dir_name):
            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

        for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
            if tb_events_file_name.startswith(self.tb_events_file_prefix):
                upload_success = self.save_remote_checkpoints_file(
                    model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
                )

                if not upload_success:
                    self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)

    @explicit_params_validation(validation_type="NoneOrEmpty")
    def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
        """
        __update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
            :param local_file_path: The Local file path to upload to S3
            :param s3_key_path:     The S3 path to create/update the S3 Key
        """
        if self.s3_connector.check_key_exists(s3_key_path):
            # DELETE KEY TO UPDATE THE FILE IN S3
            delete_response = self.s3_connector.delete_key(s3_key_path)
            if delete_response:
                self._logger.info("Removed previous checkpoint from S3")

        upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
        if not upload_success:
            self._logger.error("Failed to upload model checkpoint")

        return upload_success

__init__(data_connection_location='local', data_connection_credentials=None)

ModelCheckpointsDataInterface :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name' :param data_connection_credentials: credentials string - name of the AWS profile in case data_connection_source is s3; taken from the environment variable AWS_PROFILE if left empty

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
    """
    ModelCheckpointsDataInterface
        :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
        :param data_connection_credentials: credentials string
                - name of the AWS profile in case data_connection_source is s3; taken from the environment
                variable AWS_PROFILE if left empty
    """
    super().__init__()
    self.tb_events_file_prefix = "events.out.tfevents"
    self.log_file_prefix = "experiment_logs_"
    self.latest_checkpoint_filename = "ckpt_latest.pth"
    self.best_checkpoint_filename = "ckpt_best.pth"

    if data_connection_location.startswith("s3"):
        assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
        self.model_repo_bucket_name = data_connection_location.split("://")[1]
        self.data_connection_source = "s3"

        if data_connection_credentials is None:
            data_connection_credentials = env_variables.AWS_PROFILE

        self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)
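
A minimal usage sketch for the constructor; the bucket name and AWS profile below are placeholders, not values from the library:

from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces

# Connect the interface to a (hypothetical) S3 bucket. The profile name is a
# placeholder; when omitted, it is read from the AWS_PROFILE environment variable.
repo = ADNNModelRepositoryDataInterfaces(
    data_connection_location="s3://my-model-repo-bucket",
    data_connection_credentials="my-aws-profile",
)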

__update_or_upload_s3_key(local_file_path, s3_key_path)

__update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path :param local_file_path: The Local file path to upload to S3 :param s3_key_path: The S3 path to create/update the S3 Key

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="NoneOrEmpty")
def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
    """
    __update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
        :param local_file_path: The Local file path to upload to S3
        :param s3_key_path:     The S3 path to create/update the S3 Key
    """
    if self.s3_connector.check_key_exists(s3_key_path):
        # DELETE KEY TO UPDATE THE FILE IN S3
        delete_response = self.s3_connector.delete_key(s3_key_path)
        if delete_response:
            self._logger.info("Removed previous checkpoint from S3")

    upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
    if not upload_success:
        self._logger.error("Failed to upload model checkpoint")

    return upload_success

load_all_remote_log_files(model_name, model_checkpoint_local_dir)

load_all_remote_log_files - Downloads all remote log files (tensorboard and text) for a model :param model_name: The model name (S3 prefix) to download logs for :param model_checkpoint_local_dir: Local directory to download the log files into

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="None")
def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
    """
    load_all_remote_checkpoint_files
        :param model_name:
        :param model_checkpoint_local_dir:
        :return:
    """
    self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
    self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")
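
Assuming the repo instance from the constructor sketch above (placeholder names), this convenience call fetches both tensorboard and text logs at once:

repo.load_all_remote_log_files(model_name="my_model", model_checkpoint_local_dir="/tmp/checkpoints/my_model")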

load_remote_checkpoints_file(ckpt_source_remote_dir, ckpt_destination_local_dir, ckpt_file_name, overwrite_local_checkpoints_file=False)

load_remote_checkpoints_file - Loads a model's checkpoint from a local/cloud file :param ckpt_source_remote_dir: The source folder to download from :param ckpt_destination_local_dir: The destination folder to save the checkpoint at :param ckpt_file_name: Filename to load from the Remote Repo :param overwrite_local_checkpoints_file: Whether to delete a previous local version of the same file before downloading (relevant for cloud-stored checkpoints only) :return: Local path to the downloaded checkpoint file

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="None")
def load_remote_checkpoints_file(
    self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
) -> str:
    """
    load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
        :param ckpt_source_remote_dir:               The source folder to download from
        :param ckpt_destination_local_dir:           The destination folder to save the checkpoint at
        :param ckpt_file_name:                       Filename to load from Remote Repo
        :param overwrite_local_checkpoints_file:     Whether to delete a previous local version of the same file
                                                        before downloading (relevant for cloud-stored checkpoints only)
        :return: Local path to the downloaded checkpoint file
    """
    ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

    if self.data_connection_source == "s3":
        if overwrite_local_checkpoints_file:
            # DELETE THE LOCAL VERSION ON THE MACHINE
            if os.path.exists(ckpt_file_local_full_path):
                os.remove(ckpt_file_local_full_path)

        key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
        download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

        if not download_success:
            failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
            error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
            self._logger.error(error_msg)
            raise ModelCheckpointNotFoundException(error_msg)

    return ckpt_file_local_full_path
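
A hedged sketch of downloading the best checkpoint, reusing the repo instance from the constructor example; the model name and local directory are placeholders:

# Downloads s3://<bucket>/my_model/ckpt_best.pth to /tmp/checkpoints/ckpt_best.pth,
# removing any stale local copy first; raises ModelCheckpointNotFoundException on failure.
ckpt_path = repo.load_remote_checkpoints_file(
    ckpt_source_remote_dir="my_model",
    ckpt_destination_local_dir="/tmp/checkpoints",
    ckpt_file_name=repo.best_checkpoint_filename,  # "ckpt_best.pth"
    overwrite_local_checkpoints_file=True,
)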

load_remote_logging_files(model_name, model_checkpoint_dir_name, logging_type)

load_remote_logging_files - Downloads logging files (tensorboard events or text logs) from the remote repository :param model_name: The model name (S3 prefix) to download from :param model_checkpoint_dir_name: Local directory to download the files into :param logging_type: 'tensorboard' or 'text'

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="NoneOrEmpty")
def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
    """
    load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
        :param model_name:
        :param model_checkpoint_dir_name:
        :param logging_type:
        :return:
    """
    if not os.path.isdir(model_checkpoint_dir_name):
        raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

    # LOADS THE DATA FROM THE REMOTE REPOSITORY
    s3_bucket_path_prefix = model_name
    if logging_type == "tensorboard":
        if self.data_connection_source == "s3":
            self.s3_connector.download_keys_by_prefix(
                s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
            )
    elif logging_type == "text":
        if self.data_connection_source == "s3":
            self.s3_connector.download_keys_by_prefix(
                s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
            )
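
A short sketch of fetching both logging flavors for a model; the target directory must already exist, otherwise the method raises ValueError (the names below are placeholders):

import os

local_dir = "/tmp/checkpoints/my_model"  # placeholder path
os.makedirs(local_dir, exist_ok=True)    # the directory must exist before the call

# Download tensorboard event files (keys prefixed "events.out.tfevents") and text
# logs (keys prefixed "experiment_logs_") - equivalent to load_all_remote_log_files.
repo.load_remote_logging_files(model_name="my_model", model_checkpoint_dir_name=local_dir, logging_type="tensorboard")
repo.load_remote_logging_files(model_name="my_model", model_checkpoint_dir_name=local_dir, logging_type="text")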

save_all_remote_checkpoint_files(model_name, model_checkpoint_local_dir, log_file_name)

save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo :param model_name: The Model Name to store in Remote Repo :param model_checkpoint_local_dir: Local directory with the relevant data to upload :param log_file_name: The log_file name (Created independently)

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="None")
def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
    """
    save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
        :param model_name:                  The Model Name to store in Remote Repo
        :param model_checkpoint_local_dir:  Local directory with the relevant data to upload
        :param log_file_name:               The log_file name (Created independently)
    """
    for checkpoint_file_name in [self.latest_checkpoint_filename, self.best_checkpoint_filename]:
        self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, checkpoint_file_name)

    self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
    self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)
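
Assuming the same placeholder names, uploading an experiment's full set of artifacts (latest and best checkpoints, the text log, and the tensorboard events) is a single call:

# Uploads ckpt_latest.pth, ckpt_best.pth, the given log file, and every
# events.out.tfevents* file found in the local directory.
repo.save_all_remote_checkpoint_files(
    model_name="my_model",
    model_checkpoint_local_dir="/tmp/checkpoints/my_model",
    log_file_name="experiment_logs_run1.txt",  # placeholder log file name
)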

save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, checkpoints_file_name)

save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo :param model_name: The Model Name for S3 Prefix :param model_checkpoint_local_dir: Model Directory - Based on Model name :param checkpoints_file_name: Filename to upload to Remote Repo :return: True/False for Operation Success/Failure

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="NoneOrEmpty")
def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
    """
    save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
        :param model_name:                      The Model Name for S3 Prefix
        :param model_checkpoint_local_dir:      Model Directory - Based on Model name
        :param checkpoints_file_name:           Filename to upload to Remote Repo
        :return: True/False for Operation Success/Failure
    """
    # LOAD THE LOCAL VERSION
    model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

    # SAVE ON THE REMOTE S3 REPOSITORY
    if self.data_connection_source == "s3":
        model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
        return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)
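
A single file can also be uploaded on its own; under S3 the key becomes "<model_name>/<file_name>" (placeholder values below):

# Uploads /tmp/checkpoints/my_model/ckpt_best.pth to s3://<bucket>/my_model/ckpt_best.pth
# and returns True/False for success/failure.
uploaded = repo.save_remote_checkpoints_file(
    model_name="my_model",
    model_checkpoint_local_dir="/tmp/checkpoints/my_model",
    checkpoints_file_name="ckpt_best.pth",
)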

save_remote_tensorboard_event_files(model_name, model_checkpoint_dir_name)

save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely :param model_name: Prefix for Cloud Storage :param model_checkpoint_dir_name: The directory where the files are stored

Source code in src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@explicit_params_validation(validation_type="NoneOrEmpty")
def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
    """
    save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
        :param model_name:                Prefix for Cloud Storage
        :param model_checkpoint_dir_name: The directory where the files are stored
    """
    if not os.path.isdir(model_checkpoint_dir_name):
        raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

    for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
        if tb_events_file_name.startswith(self.tb_events_file_prefix):
            upload_success = self.save_remote_checkpoints_file(
                model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
            )

            if not upload_success:
                self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)
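
Tensorboard event files can also be pushed on their own (placeholder names again):

repo.save_remote_tensorboard_event_files(model_name="my_model", model_checkpoint_dir_name="/tmp/checkpoints/my_model")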

DatasetDataInterface

Source code in src/super_gradients/common/data_interface/dataset_data_interface.py
class DatasetDataInterface:
    def __init__(self, env: str, data_connection_source: str = "s3"):
        """

        :param env: str "development"/"production"
        :param data_connection_source: str "s3" for AWS by default
        """
        self.env = env
        self.s3_connector = None
        self.data_connection_source = data_connection_source

    @explicit_params_validation(validation_type="None")
    def load_remote_dataset_file(self, remote_file: str, local_dir: str, overwrite_local_dataset: bool = False) -> str:
        """

        :param remote_file: str - the name of s3 file
        :param local_dir: str - the directory to put the dataset in
        :param overwrite_local_dataset: Whether to delete the dataset dir before downloading
        :return: the local directory the dataset was extracted into
        """

        dataset_full_path = local_dir
        bucket = remote_file.split("/")[2]
        file_path = "/".join(remote_file.split("/")[3:])
        if self.data_connection_source == "s3":
            self.s3_connector = S3Connector(self.env, bucket)

            # DELETE THE LOCAL VERSION ON THE MACHINE
            if os.path.exists(dataset_full_path):
                if overwrite_local_dataset:
                    filelist = os.listdir(local_dir)
                    for f in filelist:
                        os.remove(os.path.join(local_dir, f))
                else:
                    warnings.warn("Overwrite local dataset set to False but dataset exists in the dir")
            if not os.path.exists(local_dir):
                os.mkdir(local_dir)

            local_file = self.s3_connector.download_file_by_path(file_path, local_dir)
            with zipfile.ZipFile(local_dir + "/" + local_file, "r") as zip_ref:
                zip_ref.extractall(local_dir + "/")
            os.remove(local_dir + "/" + local_file)

        return local_dir

__init__(env, data_connection_source='s3')

Parameters:

Name                    Type  Description                  Default
env                     str   "development"/"production"   required
data_connection_source  str   "s3" for AWS by default      's3'
Source code in src/super_gradients/common/data_interface/dataset_data_interface.py
def __init__(self, env: str, data_connection_source: str = "s3"):
    """

    :param env: str "development"/"production"
    :param data_connection_source: str "s3" for AWS by default
    """
    self.env = env
    self.s3_connector = None
    self.data_connection_source = data_connection_source

load_remote_dataset_file(remote_file, local_dir, overwrite_local_dataset=False)

Parameters:

Name                     Type  Description                                            Default
remote_file              str   the name of the s3 file                                required
local_dir                str   the directory to put the dataset in                    required
overwrite_local_dataset  bool  whether to delete the dataset dir before downloading   False

Returns:

Type  Description
str   the local directory the dataset was extracted into
Source code in src/super_gradients/common/data_interface/dataset_data_interface.py
@explicit_params_validation(validation_type="None")
def load_remote_dataset_file(self, remote_file: str, local_dir: str, overwrite_local_dataset: bool = False) -> str:
    """

    :param remote_file: str - the name of s3 file
    :param local_dir: str - the directory to put the dataset in
    :param overwrite_local_dataset: Whether to delete the dataset dir before downloading
    :return: the local directory the dataset was extracted into
    """

    dataset_full_path = local_dir
    bucket = remote_file.split("/")[2]
    file_path = "/".join(remote_file.split("/")[3:])
    if self.data_connection_source == "s3":
        self.s3_connector = S3Connector(self.env, bucket)

        # DELETE THE LOCAL VERSION ON THE MACHINE
        if os.path.exists(dataset_full_path):
            if overwrite_local_dataset:
                filelist = os.listdir(local_dir)
                for f in filelist:
                    os.remove(os.path.join(local_dir, f))
            else:
                warnings.warn("Overwrite local dataset set to False but dataset exists in the dir")
        if not os.path.exists(local_dir):
            os.mkdir(local_dir)

        local_file = self.s3_connector.download_file_by_path(file_path, local_dir)
        with zipfile.ZipFile(local_dir + "/" + local_file, "r") as zip_ref:
            zip_ref.extractall(local_dir + "/")
        os.remove(local_dir + "/" + local_file)

    return local_dir
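
A hedged end-to-end sketch for the dataset interface, with placeholder bucket and key names; the remote file is expected to be a zip archive, which the method downloads, extracts into local_dir and then deletes:

from super_gradients.common.data_interface.dataset_data_interface import DatasetDataInterface

# env is forwarded to S3Connector as the credentials/profile argument.
data_interface = DatasetDataInterface(env="development")

# "s3://my-datasets-bucket/my_dataset.zip" is a placeholder URI; the zip is
# downloaded into local_dir, extracted there, and the archive removed.
dataset_dir = data_interface.load_remote_dataset_file(
    remote_file="s3://my-datasets-bucket/my_dataset.zip",
    local_dir="/tmp/datasets/my_dataset",
    overwrite_local_dataset=True,
)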