
Model to Endpoint

API Classes

For most users, the API Classes provide all the general functionality needed to create a full AWS ML Pipeline.

ModelToEndpoint: Deploy an Endpoint for a Model

ModelToEndpoint

Bases: Transform

ModelToEndpoint: Deploy an Endpoint for a Model

Common Usage

```python
to_endpoint = ModelToEndpoint(model_name, endpoint_name)
to_endpoint.set_output_tags(["aqsol", "public", "whatever"])
to_endpoint.transform()
```
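The constructor also supports realtime (non-serverless) deployments via the `serverless` and `instance` parameters. A minimal sketch; the model and endpoint names are placeholders and `ml.m5.large` is just an illustrative SageMaker instance type:

```python
# Hedged sketch: deploying a realtime (non-serverless) endpoint.
# "my-model" and "my-endpoint" are placeholder names.
to_endpoint = ModelToEndpoint("my-model", "my-endpoint", serverless=False, instance="ml.m5.large")
to_endpoint.set_output_tags(["aqsol", "public"])
to_endpoint.transform()
```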
Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
class ModelToEndpoint(Transform):
    """ModelToEndpoint: Deploy an Endpoint for a Model

    Common Usage:
        ```python
        to_endpoint = ModelToEndpoint(model_name, endpoint_name)
        to_endpoint.set_output_tags(["aqsol", "public", "whatever"])
        to_endpoint.transform()
        ```
    """

    def __init__(self, model_name: str, endpoint_name: str, serverless: bool = True, instance: str = "ml.t2.medium"):
        """ModelToEndpoint Initialization
        Args:
            model_name(str): The Name of the input Model
            endpoint_name(str): The Name of the output Endpoint
            serverless(bool): Deploy the Endpoint in serverless mode (default: True)
            instance(str): The instance type to use for the Endpoint (default: "ml.t2.medium")
        """
        # Make sure the endpoint_name is a valid name
        Artifact.is_name_valid(endpoint_name, delimiter="-", lower_case=False)

        # Call superclass init
        super().__init__(model_name, endpoint_name)

        # Set up all my instance attributes
        self.serverless = serverless
        self.instance_type = "serverless" if serverless else instance
        self.input_type = TransformInput.MODEL
        self.output_type = TransformOutput.ENDPOINT

    def transform_impl(self, **kwargs):
        """Deploy an Endpoint for a Model"""

        # Delete endpoint (if it already exists)
        EndpointCore.managed_delete(self.output_name)

        # Create a ModelCore object for our input model
        workbench_model = ModelCore(self.input_name)

        # Deploy the model
        self._deploy_model(workbench_model, **kwargs)

        # Add this endpoint to the set of registered endpoints for the model
        workbench_model.register_endpoint(self.output_name)

        # Give AWS a moment to catch up, then confirm the endpoint is ready
        time.sleep(5)  # Allow for AWS propagation lag
        end = EndpointCore(self.output_name)
        self.log.important(f"Endpoint {end.name} is ready for use")

    def _deploy_model(
        self,
        workbench_model: ModelCore,
        mem_size: int = 2048,
        max_concurrency: int = 5,
        data_capture: bool = False,
        capture_percentage: int = 100,
    ):
        """Internal Method: Deploy the Model

        Args:
            workbench_model(ModelCore): The Workbench ModelCore object to deploy
            mem_size(int): Memory size for serverless deployment
            max_concurrency(int): Max concurrency for serverless deployment
            data_capture(bool): Enable data capture during deployment
            capture_percentage(int): Percentage of data to capture. Defaults to 100.
        """
        # Grab the specified Model Package
        model_package_arn = workbench_model.model_package_arn()
        model_package = ModelPackage(
            role=self.workbench_role_arn,
            model_package_arn=model_package_arn,
            sagemaker_session=self.sm_session,
        )

        # Log the image that will be used for deployment
        inference_image = self.sm_client.describe_model_package(ModelPackageName=model_package_arn)[
            "InferenceSpecification"
        ]["Containers"][0]["Image"]
        self.log.important(f"Deploying Model Package: {self.input_name} with Inference Image: {inference_image}")

        # Get the metadata/tags to push into AWS
        aws_tags = self.get_aws_tags()

        # Is this a serverless deployment?
        serverless_config = None
        if self.serverless:
            serverless_config = ServerlessInferenceConfig(
                memory_size_in_mb=mem_size,
                max_concurrency=max_concurrency,
            )

        # Configure data capture if requested (and not serverless)
        data_capture_config = None
        if data_capture and not self.serverless:
            # Set up the S3 path for data capture
            base_endpoint_path = f"{workbench_model.endpoints_s3_path}/{self.output_name}"
            data_capture_path = f"{base_endpoint_path}/data_capture"
            self.log.important(f"Configuring Data Capture --> {data_capture_path}")
            data_capture_config = DataCaptureConfig(
                enable_capture=True,
                sampling_percentage=capture_percentage,
                destination_s3_uri=data_capture_path,
            )
        elif data_capture and self.serverless:
            self.log.warning(
                "Data capture is not supported for serverless endpoints. Skipping data capture configuration."
            )

        # Deploy the Endpoint
        self.log.important(f"Deploying the Endpoint {self.output_name}...")
        model_package.deploy(
            initial_instance_count=1,
            instance_type=self.instance_type,
            serverless_inference_config=serverless_config,
            endpoint_name=self.output_name,
            serializer=CSVSerializer(),
            deserializer=CSVDeserializer(),
            data_capture_config=data_capture_config,
            tags=aws_tags,
        )

    def post_transform(self, **kwargs):
        """Post-Transform: Calling onboard() for the Endpoint"""
        self.log.info("Post-Transform: Calling onboard() for the Endpoint...")

        # Onboard the Endpoint
        output_endpoint = EndpointCore(self.output_name)
        output_endpoint.onboard_with_args(input_model=self.input_name)
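Note that `transform_impl()` forwards its keyword arguments to `_deploy_model()`, so the serverless memory size and concurrency can in principle be tuned at transform time. A minimal sketch, assuming the base `Transform.transform()` passes its kwargs through to `transform_impl()`:

```python
# Hedged sketch: tuning the serverless deployment at transform time.
# ASSUMPTION: Transform.transform() forwards **kwargs to transform_impl(),
# which hands them to _deploy_model() (see the source above).
to_endpoint = ModelToEndpoint("my-model", "my-endpoint")  # serverless by default
to_endpoint.set_output_tags(["aqsol", "public"])
to_endpoint.transform(mem_size=4096, max_concurrency=10)
```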

__init__(model_name, endpoint_name, serverless=True, instance='ml.t2.medium')

ModelToEndpoint Initialization

Args:
    model_name(str): The Name of the input Model
    endpoint_name(str): The Name of the output Endpoint
    serverless(bool): Deploy the Endpoint in serverless mode (default: True)
    instance(str): The instance type to use for the Endpoint (default: "ml.t2.medium")

Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
def __init__(self, model_name: str, endpoint_name: str, serverless: bool = True, instance: str = "ml.t2.medium"):
    """ModelToEndpoint Initialization
    Args:
        model_name(str): The Name of the input Model
        endpoint_name(str): The Name of the output Endpoint
        serverless(bool): Deploy the Endpoint in serverless mode (default: True)
        instance(str): The instance type to use for the Endpoint (default: "ml.t2.medium")
    """
    # Make sure the endpoint_name is a valid name
    Artifact.is_name_valid(endpoint_name, delimiter="-", lower_case=False)

    # Call superclass init
    super().__init__(model_name, endpoint_name)

    # Set up all my instance attributes
    self.serverless = serverless
    self.instance_type = "serverless" if serverless else instance
    self.input_type = TransformInput.MODEL
    self.output_type = TransformOutput.ENDPOINT

post_transform(**kwargs)

Post-Transform: Calling onboard() for the Endpoint

Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
def post_transform(self, **kwargs):
    """Post-Transform: Calling onboard() for the Endpoint"""
    self.log.info("Post-Transform: Calling onboard() for the Endpoint...")

    # Onboard the Endpoint
    output_endpoint = EndpointCore(self.output_name)
    output_endpoint.onboard_with_args(input_model=self.input_name)
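If the onboarding step ever needs to be re-run by hand, the same call can be made directly on an EndpointCore object. A minimal sketch using only the calls shown in the source above; the import path is an assumption based on the source tree layout:

```python
# Hedged sketch: manually re-running the onboarding step.
# ASSUMPTION: the EndpointCore import path mirrors the source tree layout.
from workbench.core.artifacts.endpoint_core import EndpointCore

endpoint = EndpointCore("my-endpoint")              # placeholder endpoint name
endpoint.onboard_with_args(input_model="my-model")  # placeholder model name
```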

transform_impl(**kwargs)

Deploy an Endpoint for a Model

Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
def transform_impl(self, **kwargs):
    """Deploy an Endpoint for a Model"""

    # Delete endpoint (if it already exists)
    EndpointCore.managed_delete(self.output_name)

    # Create a ModelCore object for our input model
    workbench_model = ModelCore(self.input_name)

    # Deploy the model
    self._deploy_model(workbench_model, **kwargs)

    # Add this endpoint to the set of registered endpoints for the model
    workbench_model.register_endpoint(self.output_name)

    # Give AWS a moment to catch up, then confirm the endpoint is ready
    time.sleep(5)  # Allow for AWS propagation lag
    end = EndpointCore(self.output_name)
    self.log.important(f"Endpoint {end.name} is ready for use")
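`_deploy_model()` also accepts `data_capture` and `capture_percentage`, but only honors them for realtime endpoints; serverless deployments log a warning and skip capture. A minimal sketch, again assuming `transform()` forwards its kwargs to `transform_impl()`:

```python
# Hedged sketch: enabling data capture on a realtime endpoint.
# ASSUMPTION: Transform.transform() forwards **kwargs to transform_impl().
# Data capture is skipped for serverless endpoints (see _deploy_model above).
to_endpoint = ModelToEndpoint("my-model", "my-endpoint", serverless=False, instance="ml.m5.large")
to_endpoint.set_output_tags(["aqsol", "public"])
to_endpoint.transform(data_capture=True, capture_percentage=50)
```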