
PyTorch Models - Hugging Face & SuperGradients

Prerequisite - Client Setup

Make sure your environment is configured as described in Installation & Setup.

Prerequisite - deci-platform-client>=5.0.0

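The following steps require deci-platform-client 5.0.0 or later. Quote the requirement so the shell does not treat > as an output redirect:

pip install "deci-platform-client>=5.0.0"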
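To confirm from Python that the installed client meets this minimum, a rough check can read the installed version through the standard library's importlib.metadata. This is a sketch: the helper names are hypothetical, and the comparison only looks at the leading numeric components, so a pre-release such as 5.0.0rc1 counts as 5.0.0. If strict PEP 440 ordering matters, packaging.version.Version is the usual tool.

```python
import re
from importlib import metadata


def version_tuple(version: str, width: int = 3) -> tuple:
    """Turn '5.2.1' (or '5.2.1rc1') into (5, 2, 1), zero-padded to `width`."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first component without leading digits
        parts.append(int(match.group()))
    parts += [0] * (width - len(parts))  # no-op if parts is already long enough
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "5.0.0") -> bool:
    """Compare two version strings after padding them to the same length."""
    width = max(len(installed.split(".")), len(minimum.split(".")))
    return version_tuple(installed, width) >= version_tuple(minimum, width)


def check_deci_client(minimum: str = "5.0.0") -> bool:
    """Return True if deci-platform-client is installed at `minimum` or later."""
    try:
        installed = metadata.version("deci-platform-client")
    except metadata.PackageNotFoundError:
        return False  # not installed at all
    return meets_minimum(installed, minimum)


print(check_deci_client())
```

check_deci_client() returns False both when the package is missing and when it is older than 5.0.0; in either case, rerun the pip command above.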

Prerequisite - Libraries
Hugging Face:
pip install transformers onnx

SuperGradients:
pip install super-gradients

Import the model

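For the Hugging Face model, pass return_dict=False to from_pretrained. The model then returns plain tuples instead of ModelOutput objects, which keeps its outputs compatible with the ONNX conversion performed on upload.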

Hugging Face:

from transformers import BertForSequenceClassification

model_name = "bert-base-uncased"
# return_dict=False makes forward() return tuples instead of ModelOutput objects
model = BertForSequenceClassification.from_pretrained(model_name, return_dict=False)

SuperGradients:

from super_gradients.training import models
from super_gradients.common.object_names import Models

model_name = Models.YOLO_NAS_M
model = models.get(model_name, pretrained_weights="coco")

Create inputs metadata for the model

Hugging Face:

from transformers.models import bert
import numpy as np

batch_size = 1
sequence_length = 128
bert_onnx_config = bert.BertOnnxConfig(config=bert.BertConfig())
# One fixed-shape int64 entry per BERT input name, excluding token_type_ids
inputs_metadata = {
    input_name: {"dtype": np.int64, "shape": (batch_size, sequence_length)}
    for input_name in bert_onnx_config.inputs.keys()
    if input_name != "token_type_ids"
}
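The Hugging Face comprehension builds one entry per model input and skips token_type_ids. The dependency-free sketch below mirrors that logic so the resulting keys and shapes are easy to inspect. It assumes the default-task BertOnnxConfig exposes the inputs input_ids, attention_mask and token_type_ids, and it uses the string "int64" in place of np.int64; text_inputs_metadata is a hypothetical helper, not part of any library.

```python
# Assumed to match the keys of BertOnnxConfig(...).inputs for the default task.
BERT_INPUT_NAMES = ["input_ids", "attention_mask", "token_type_ids"]


def text_inputs_metadata(input_names, batch_size=1, sequence_length=128,
                         dtype="int64", exclude=("token_type_ids",)):
    """Mirror the comprehension above: one fixed-shape entry per kept input."""
    return {
        name: {"dtype": dtype, "shape": (batch_size, sequence_length)}
        for name in input_names
        if name not in exclude
    }


bert_metadata = text_inputs_metadata(BERT_INPUT_NAMES)
print(list(bert_metadata))                  # ['input_ids', 'attention_mask']
print(bert_metadata["input_ids"]["shape"])  # (1, 128)
```

Because dicts keep insertion order, the remaining inputs stay in the order the ONNX config lists them.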
SuperGradients:

import numpy as np

# One float32 image batch in (batch, channels, height, width) layout
inputs_metadata = {
    "input0": {
        "dtype": np.float32,
        "shape": (1, 3, 224, 224),
    }
}
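Each inputs_metadata entry fixes both a dtype and a shape, so the size of one input tensor follows directly: the product of the dimensions times the bytes per element. The sketch below does that arithmetic for the metadata on this page, using dtype names as strings and a small hand-written itemsize table (an assumption covering only the two dtypes used here; with NumPy, np.dtype(np.float32).itemsize gives the same 4).

```python
from math import prod

# Bytes per element for the dtypes used on this page (assumed subset).
ITEMSIZE = {"int64": 8, "float32": 4}


def input_nbytes(inputs_metadata: dict) -> dict:
    """Return the size in bytes of one tensor for each declared input."""
    return {
        name: prod(spec["shape"]) * ITEMSIZE[spec["dtype"]]
        for name, spec in inputs_metadata.items()
    }


image_inputs = {"input0": {"dtype": "float32", "shape": (1, 3, 224, 224)}}
text_inputs = {
    "input_ids": {"dtype": "int64", "shape": (1, 128)},
    "attention_mask": {"dtype": "int64", "shape": (1, 128)},
}

print(input_nbytes(image_inputs))  # {'input0': 602112}
print(input_nbytes(text_inputs))   # {'input_ids': 1024, 'attention_mask': 1024}
```

The image input is 1 × 3 × 224 × 224 = 150,528 float32 values, or 602,112 bytes; each BERT input is 128 int64 values, or 1,024 bytes.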

Upload the model to the platform

Hugging Face:

from deci_platform_client import DeciPlatformClient
from deci_platform_client.models import FrameworkType

client = DeciPlatformClient()

client.register_model(
    model=model,
    name=model_name,
    framework=FrameworkType.PYTORCH,
    inputs_metadata=inputs_metadata,
    # Mark the batch and sequence dimensions as dynamic, using the ONNX config's axis names
    dynamic_axes=bert_onnx_config.inputs,
)
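In the Hugging Face call, dynamic_axes comes from bert_onnx_config.inputs, which still lists token_type_ids even though inputs_metadata omits it. If you want the two to agree, one option is to pass only the entries for declared inputs. The sketch assumes dynamic_axes uses the {input_name: {axis_index: axis_name}} layout of the torch.onnx.export parameter of the same name; matching_dynamic_axes is a hypothetical helper, not part of deci-platform-client, and the hard-coded axes dict stands in for bert_onnx_config.inputs.

```python
# Stand-in for bert_onnx_config.inputs (assumed layout: {input: {axis: name}}).
bert_dynamic_axes = {
    "input_ids": {0: "batch", 1: "sequence"},
    "attention_mask": {0: "batch", 1: "sequence"},
    "token_type_ids": {0: "batch", 1: "sequence"},
}

# Same keys as the Hugging Face inputs_metadata (dtype shown as a string here).
inputs_metadata = {
    "input_ids": {"dtype": "int64", "shape": (1, 128)},
    "attention_mask": {"dtype": "int64", "shape": (1, 128)},
}


def matching_dynamic_axes(dynamic_axes, inputs_metadata):
    """Keep only the dynamic-axis entries for inputs that are actually declared."""
    return {name: axes for name, axes in dynamic_axes.items() if name in inputs_metadata}


filtered = matching_dynamic_axes(bert_dynamic_axes, inputs_metadata)
print(sorted(filtered))  # ['attention_mask', 'input_ids']
```

With the real objects this would read dynamic_axes=matching_dynamic_axes(bert_onnx_config.inputs, inputs_metadata), since the ONNX config's inputs is a mapping with an .items() method.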
SuperGradients:

from deci_platform_client import DeciPlatformClient
from deci_platform_client.models import FrameworkType

client = DeciPlatformClient()

client.register_model(
    model=model,
    name=model_name,
    framework=FrameworkType.PYTORCH,
    inputs_metadata=inputs_metadata,
)
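Both register_model calls depend on inputs_metadata being well formed. A hypothetical pre-flight check can catch obvious mistakes locally before uploading. It assumes, based only on the examples on this page, that each entry needs a "dtype" key and a non-empty tuple or list of positive integer dimensions; the platform's actual validation rules are not described here, and check_inputs_metadata is not part of deci-platform-client.

```python
def check_inputs_metadata(inputs_metadata):
    """Return a list of problems found in an inputs_metadata dict (empty if none)."""
    problems = []
    if not inputs_metadata:
        problems.append("inputs_metadata is empty")
    for name, spec in inputs_metadata.items():
        if "dtype" not in spec:
            problems.append(f"{name}: missing 'dtype'")
        shape = spec.get("shape")
        if not isinstance(shape, (tuple, list)) or not shape:
            problems.append(f"{name}: 'shape' must be a non-empty tuple or list")
        elif not all(isinstance(dim, int) and dim > 0 for dim in shape):
            problems.append(f"{name}: every dimension must be a positive int")
    return problems


image_inputs = {"input0": {"dtype": "float32", "shape": (1, 3, 224, 224)}}
print(check_inputs_metadata(image_inputs))                # []
print(check_inputs_metadata({"x": {"shape": (1, 0)}}))   # two problems reported
```

Returning a list rather than raising on the first error reports every problem in one pass, which is convenient when several inputs are declared.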