PyTorch Models - Hugging Face & SuperGradients
Prerequisite - Client Setup
Make sure your environment is configured as described in Installation & Setup.
Prerequisite - deci-platform-client>=5.0.0
# Quote the requirement: unquoted, the shell treats ">=5.0.0" as an output
# redirect, installing an unpinned client and writing to a file named "=5.0.0".
pip install "deci-platform-client>=5.0.0"
Version 5.0.0 or newer of deci-platform-client is required for the steps below.
Prerequisite - Libraries
# Hugging Face example: transformers provides BertForSequenceClassification
# and BertOnnxConfig (onnx is presumably needed for the ONNX export path).
pip install transformers onnx
# SuperGradients example: provides super_gradients.training.models.
pip install super-gradients
Import the model
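The Hugging Face and SuperGradients snippets on this page are alternatives: run only the set that matches your model, since both define `model`, `model_name` and `inputs_metadata`.
When importing a Hugging Face model, it is important to pass `return_dict=False`
to the `from_pretrained`
method.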
from transformers import BertForSequenceClassification

# Hugging Face example: load the BERT sequence-classification checkpoint.
model_name = "bert-base-uncased"
# return_dict=False makes the model's forward pass return a plain tuple instead
# of a ModelOutput object (presumably what the export step expects).
model = BertForSequenceClassification.from_pretrained(
    model_name,
    return_dict=False,
)
from super_gradients.training import models
from super_gradients.common.object_names import Models

# SuperGradients example: fetch YOLO-NAS-M with its COCO-pretrained weights.
model_name = Models.YOLO_NAS_M
model = models.get(
    model_name,
    pretrained_weights="coco",
)
Create inputs metadata for the model
from transformers.models import bert
import numpy as np

# Shape of each text input described to the platform
# (presumably used to build dummy inputs for export — confirm).
batch_size = 1
sequence_length = 128

# The ONNX config is used here only for its input names (and later for their
# dynamic axes); it is built from a default BertConfig.
bert_onnx_config = bert.BertOnnxConfig(config=bert.BertConfig())

# One int64 entry of shape (batch_size, sequence_length) per ONNX input name,
# deliberately leaving out "token_type_ids".
inputs_metadata = {
    name: {"dtype": np.int64, "shape": (batch_size, sequence_length)}
    for name in bert_onnx_config.inputs
    if name not in {"token_type_ids"}
}
import numpy as np

# SuperGradients example: the model is described by a single float32 input
# named "input0" of shape (1, 3, 224, 224) — presumably (batch, channels,
# height, width); confirm against the resolution the model expects.
inputs_metadata = {"input0": {"dtype": np.float32, "shape": (1, 3, 224, 224)}}
Upload the model to the platform
from deci_platform_client import DeciPlatformClient
from deci_platform_client.models import FrameworkType

# Hugging Face example: register the BERT model with the platform.
client = DeciPlatformClient()
client.register_model(
    model=model,
    name=model_name,
    framework=FrameworkType.PYTORCH,
    inputs_metadata=inputs_metadata,
    # Declare dynamic axes only for inputs that inputs_metadata actually
    # describes. "token_type_ids" was filtered out there, so passing the
    # unfiltered bert_onnx_config.inputs would reference an input that is
    # presumably absent from the exported graph (torch's ONNX exporter warns
    # about dynamic-axes keys that are not input/output names).
    dynamic_axes={
        input_name: axes
        for input_name, axes in bert_onnx_config.inputs.items()
        if input_name in inputs_metadata
    },
)
from deci_platform_client import DeciPlatformClient
from deci_platform_client.models import FrameworkType

# SuperGradients example: register the YOLO-NAS model as a PyTorch model,
# described by the single-input metadata built above (no dynamic axes).
client = DeciPlatformClient()
client.register_model(
    model=model, name=model_name,
    framework=FrameworkType.PYTORCH, inputs_metadata=inputs_metadata,
)