Skip to content

Model to Endpoint

API Classes

For most users the API Classes will provide all the general functionality to create a full AWS ML Pipeline

ModelToEndpoint: Deploy an Endpoint for a Model

FIXME: Investigate using V3's ModelBuilder for deployment instead of manually creating Model/EndpointConfig/Endpoint resources. ModelBuilder handles inference code bundling, SAGEMAKER_PROGRAM env var, and model artifact repacking automatically. May eliminate the need for our inference-metadata.json workaround. Need to verify it works with our custom inference Docker images (which use their own main.py, not SageMaker's built-in serving stack). See https://sagemaker.readthedocs.io/en/stable/inference/index.html

ModelToEndpoint

Bases: Transform

ModelToEndpoint: Deploy an Endpoint for a Model

Common Usage
to_endpoint = ModelToEndpoint(model_name, endpoint_name)
to_endpoint.set_output_tags(["aqsol", "public", "whatever"])
to_endpoint.transform()
Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
class ModelToEndpoint(Transform):
    """ModelToEndpoint: Deploy an Endpoint for a Model

    Common Usage:
        ```python
        to_endpoint = ModelToEndpoint(model_name, endpoint_name)
        to_endpoint.set_output_tags(["aqsol", "public", "whatever"])
        to_endpoint.transform()
        ```
    """

    def __init__(
        self,
        model_name: str,
        endpoint_name: str,
        serverless: bool = True,
        instance: str = None,
        async_endpoint: bool = False,
    ):
        """ModelToEndpoint Initialization
        Args:
            model_name(str): The Name of the input Model
            endpoint_name(str): The Name of the output Endpoint
            serverless(bool): Deploy the Endpoint in serverless mode (default: True)
            instance(str): The instance type for Realtime Endpoints (default: None = auto-select)
            async_endpoint(bool): Deploy as an async endpoint (default: False). Async
                endpoints support up to 15-minute invocations and use S3 for I/O.
                Incompatible with serverless — if both are True, serverless is forced off.
        """
        # Make sure the endpoint_name is a valid name
        # (endpoint names become AWS resource names, hence the "-" delimiter)
        Artifact.is_name_valid(endpoint_name, delimiter="-", lower_case=False)

        # Call superclass init
        super().__init__(model_name, endpoint_name)

        # Async endpoints are always realtime (not serverless)
        # We resolve the conflict here (with a warning) instead of raising,
        # so callers that pass the default serverless=True still succeed.
        if async_endpoint and serverless:
            self.log.warning("Async endpoints are not compatible with serverless. Forcing serverless=False.")
            serverless = False

        # Set up all my instance attributes
        self.serverless = serverless
        self.instance = instance
        self.async_endpoint = async_endpoint
        self.input_type = TransformInput.MODEL
        self.output_type = TransformOutput.ENDPOINT

    def transform_impl(self, **kwargs):
        """Deploy an Endpoint for a Model"""

        # Delete endpoint (if it already exists)
        # This makes the transform idempotent: re-running replaces the endpoint.
        EndpointCore.managed_delete(self.output_name)

        # Get the Model Package ARN for our input model
        workbench_model = ModelCore(self.input_name)

        # Deploy the model (kwargs are forwarded to _deploy_model:
        # mem_size, max_concurrency, data_capture, capture_percentage)
        self._deploy_model(workbench_model, **kwargs)

        # Add this endpoint to the set of registered endpoints for the model
        workbench_model.register_endpoint(self.output_name)

        # This ensures that the endpoint is ready for use
        time.sleep(5)  # We wait for AWS Lag
        end = EndpointCore(self.output_name)
        self.log.important(f"Endpoint {end.name} is ready for use")

    def _deploy_model(
        self,
        workbench_model: ModelCore,
        mem_size: int = 2048,
        max_concurrency: int = 5,
        data_capture: bool = False,
        capture_percentage: int = 100,
    ):
        """Internal Method: Deploy the Model

        Args:
            workbench_model(ModelCore): The Workbench ModelCore object to deploy
            mem_size(int): Memory size for serverless deployment
            max_concurrency(int): Max concurrency for serverless deployment
            data_capture(bool): Enable data capture during deployment
            capture_percentage(int): Percentage of data to capture. Defaults to 100.
        """
        # Grab the specified Model Package ARN and inference image
        model_package_arn = workbench_model.model_package_arn()
        inference_image = workbench_model.container_image()
        self.log.important(f"Deploying Model Package: {self.input_name} with Inference Image: {inference_image}")

        # Get the metadata/tags to push into AWS
        # get_aws_tags() returns dicts; SageMaker V3 resource classes want Tag shapes.
        aws_tags = self.get_aws_tags()
        sagemaker_tags = [Tag(key=t["key"], value=t["value"]) for t in aws_tags]

        # Check the model framework for resource requirements
        # NOTE: local import — presumably to avoid a circular import at module load; confirm.
        from workbench.api import ModelFramework

        self.log.info(f"Model Framework: {workbench_model.model_framework}")
        needs_more_resources = workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]

        # Determine serverless config and instance type
        serverless_config = None
        if self.serverless:
            # For PyTorch or ChemProp we need at least 4GB of memory
            # (caller-provided mem_size is only ever raised, never lowered)
            if needs_more_resources and mem_size < 4096:
                self.log.important(f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)")
                mem_size = 4096
            serverless_config = ProductionVariantServerlessConfig(
                memory_size_in_mb=mem_size,
                max_concurrency=max_concurrency,
            )
            instance_type = None  # Not used for serverless
            self.log.important(f"Serverless Config: Memory={mem_size}MB, MaxConcurrency={max_concurrency}")
        else:
            # For realtime endpoints, use explicit instance if provided, otherwise auto-select
            if self.instance:
                instance_type = self.instance
                self.log.important(f"Realtime Endpoint: Using specified instance type: {instance_type}")
            elif needs_more_resources:
                instance_type = "ml.c7i.large"
                self.log.important(f"{workbench_model.model_framework} needs more resources (using {instance_type})")
            else:
                # NOTE(review): ml.t2.medium is a legacy instance family — consider ml.t3.medium
                instance_type = "ml.t2.medium"
                self.log.important(f"Realtime Endpoint: Instance Type={instance_type}")

        # Configure data capture if requested (and not serverless)
        data_capture_config = None
        if data_capture and not self.serverless:
            # Set up the S3 path for data capture
            base_endpoint_path = f"{workbench_model.endpoints_s3_path}/{self.output_name}"
            data_capture_path = f"{base_endpoint_path}/data_capture"
            self.log.important(f"Configuring Data Capture --> {data_capture_path}")
            data_capture_config = DataCaptureConfigShape(
                enable_capture=True,
                initial_sampling_percentage=capture_percentage,
                destination_s3_uri=data_capture_path,
                capture_options=[
                    # Capture both request and response payloads
                    CaptureOption(capture_mode="Input"),
                    CaptureOption(capture_mode="Output"),
                ],
            )
        elif data_capture and self.serverless:
            # AWS does not support data capture on serverless endpoints — warn, don't fail
            self.log.warning(
                "Data capture is not supported for serverless endpoints. Skipping data capture configuration."
            )

        # Deploy the Endpoint using V3 Resource Classes
        self.log.important(f"Deploying the Endpoint {self.output_name}...")
        try:
            self._create_endpoint_resources(
                model_package_arn=model_package_arn,
                serverless_config=serverless_config,
                instance_type=instance_type,
                data_capture_config=data_capture_config,
                tags=sagemaker_tags,
            )
        except ClientError as e:
            # Check if this is the "endpoint config already exists" error
            # (a leftover config can survive even after managed_delete of the endpoint)
            if "Cannot create already existing endpoint configuration" in str(e):
                self.log.warning("Endpoint config already exists, deleting and retrying...")
                EndpointConfig.get(self.output_name, session=self.boto3_session).delete()
                # Retry (the Model-already-exists case is handled inside _create_endpoint_resources)
                self._create_endpoint_resources(
                    model_package_arn=model_package_arn,
                    serverless_config=serverless_config,
                    instance_type=instance_type,
                    data_capture_config=data_capture_config,
                    tags=sagemaker_tags,
                )
            else:
                raise

    def _create_endpoint_resources(
        self,
        model_package_arn: str,
        serverless_config=None,
        instance_type: str = None,
        data_capture_config=None,
        tags=None,
    ):
        """Internal: Create the SageMaker Model, EndpointConfig, and Endpoint resources.

        Args:
            model_package_arn (str): The model package ARN to deploy
            serverless_config: ServerlessConfig for serverless deployments
            instance_type (str): Instance type for realtime deployments
            data_capture_config: Data capture configuration
            tags: List of Tag objects
        """
        # All three resources (Model, EndpointConfig, Endpoint) share the output name
        model_name = self.output_name
        config_name = self.output_name

        # Step 1: Create the SageMaker Model from the Model Package
        container = ContainerDefinition(model_package_name=model_package_arn)
        try:
            SagemakerModel.create(
                model_name=model_name,
                primary_container=container,
                execution_role_arn=self.workbench_role_arn,
                tags=tags,
                session=self.boto3_session,
            )
        except ClientError as e:
            # A stale Model may point at an old model package — delete and recreate
            if "Cannot create already existing model" in str(e):
                self.log.warning("Model already exists, deleting and recreating...")
                SagemakerModel.get(model_name, session=self.boto3_session).delete()
                SagemakerModel.create(
                    model_name=model_name,
                    primary_container=container,
                    execution_role_arn=self.workbench_role_arn,
                    tags=tags,
                    session=self.boto3_session,
                )
            else:
                raise

        # Step 2: Create the EndpointConfig
        production_variant = ProductionVariant(
            variant_name="AllTraffic",
            model_name=model_name,
            initial_variant_weight=1.0,
        )
        if serverless_config:
            production_variant.serverless_config = serverless_config
        else:
            # Realtime: single instance, generous startup window for large containers
            production_variant.initial_instance_count = 1
            production_variant.instance_type = instance_type
            production_variant.container_startup_health_check_timeout_in_seconds = 300

        # Build async inference config if requested
        async_inference_config = None
        if self.async_endpoint:
            # NOTE(review): uses self.endpoints_s3_path here, but _deploy_model uses
            # workbench_model.endpoints_s3_path for data capture — confirm both resolve
            # to the same S3 prefix.
            base_path = f"{self.endpoints_s3_path}/{self.output_name}"
            async_inference_config = AsyncInferenceConfig(
                output_config=AsyncInferenceOutputConfig(
                    s3_output_path=f"{base_path}/async-output",
                    s3_failure_path=f"{base_path}/async-failures",
                ),
            )
            self.log.important(f"Async Endpoint Config: output → {base_path}/async-output")

        EndpointConfig.create(
            endpoint_config_name=config_name,
            production_variants=[production_variant],
            async_inference_config=async_inference_config,
            data_capture_config=data_capture_config,
            tags=tags,
            session=self.boto3_session,
        )

        # Step 3: Create the Endpoint and wait for it to be InService
        # (blocking wait — deployment can take several minutes)
        endpoint = SagemakerEndpoint.create(
            endpoint_name=self.output_name,
            endpoint_config_name=config_name,
            tags=tags,
            session=self.boto3_session,
        )
        endpoint.wait_for_status("InService")

        # For async endpoints, register a scale-to-zero auto-scaling policy.
        # This must be done after the endpoint is InService — AWS doesn't
        # allow managed instance scaling on the ProductionVariant for async configs.
        if self.async_endpoint:
            register_autoscaling(self.boto3_session, self.output_name)

    def post_transform(self, **kwargs):
        """Post-Transform: Calling onboard() for the Endpoint"""
        self.log.info("Post-Transform: Calling onboard() for the Endpoint...")

        # Onboard the Endpoint (pulls metadata/schema from the input model)
        output_endpoint = EndpointCore(self.output_name)
        output_endpoint.onboard_with_args(input_model=self.input_name)
__init__(model_name, endpoint_name, serverless=True, instance=None, async_endpoint=False)

ModelToEndpoint Initialization.

Args:
    model_name (str): The Name of the input Model
    endpoint_name (str): The Name of the output Endpoint
    serverless (bool): Deploy the Endpoint in serverless mode (default: True)
    instance (str): The instance type for Realtime Endpoints (default: None = auto-select)
    async_endpoint (bool): Deploy as an async endpoint (default: False). Async endpoints support up to 15-minute invocations and use S3 for I/O. Incompatible with serverless — if both are True, serverless is forced off.

Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
def __init__(
    self,
    model_name: str,
    endpoint_name: str,
    serverless: bool = True,
    instance: str = None,
    async_endpoint: bool = False,
):
    """Initialize a ModelToEndpoint transform.

    Args:
        model_name(str): The Name of the input Model
        endpoint_name(str): The Name of the output Endpoint
        serverless(bool): Deploy the Endpoint in serverless mode (default: True)
        instance(str): The instance type for Realtime Endpoints (default: None = auto-select)
        async_endpoint(bool): Deploy as an async endpoint (default: False). Async
            endpoints support up to 15-minute invocations and use S3 for I/O.
            Incompatible with serverless — if both are True, serverless is forced off.
    """
    # Validate the endpoint name before doing anything else
    Artifact.is_name_valid(endpoint_name, delimiter="-", lower_case=False)

    # Superclass wiring (input = model, output = endpoint)
    super().__init__(model_name, endpoint_name)

    # Serverless and async are mutually exclusive; async wins, with a warning
    if serverless and async_endpoint:
        self.log.warning("Async endpoints are not compatible with serverless. Forcing serverless=False.")
        serverless = False

    # Record the deployment options and transform I/O types
    self.input_type = TransformInput.MODEL
    self.output_type = TransformOutput.ENDPOINT
    self.serverless = serverless
    self.instance = instance
    self.async_endpoint = async_endpoint

post_transform(**kwargs)

Post-Transform: Calling onboard() for the Endpoint

Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
def post_transform(self, **kwargs):
    """Post-Transform: onboard the freshly deployed Endpoint."""
    self.log.info("Post-Transform: Calling onboard() for the Endpoint...")

    # Onboard the new endpoint, linking it back to the input model
    EndpointCore(self.output_name).onboard_with_args(input_model=self.input_name)

transform_impl(**kwargs)

Deploy an Endpoint for a Model

Source code in src/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
def transform_impl(self, **kwargs):
    """Deploy an Endpoint for a Model"""

    # Remove any pre-existing endpoint so the deploy is a clean replace
    EndpointCore.managed_delete(self.output_name)

    # Look up the input model and deploy it (kwargs forwarded to _deploy_model)
    model = ModelCore(self.input_name)
    self._deploy_model(model, **kwargs)

    # Record this endpoint on the model's registered-endpoint list
    model.register_endpoint(self.output_name)

    # Brief pause for AWS eventual consistency, then confirm readiness
    time.sleep(5)
    deployed = EndpointCore(self.output_name)
    self.log.important(f"Endpoint {deployed.name} is ready for use")