Skip to content

Model Utilities

Examples

Examples of using the Model Utilities are listed at the bottom of this page (see the Examples section).

Model Utilities for Workbench models

instance_architecture(instance_name)

Get the architecture for the given instance name

Source code in src/workbench/utils/model_utils.py
def instance_architecture(instance_name: str) -> str:
    """Look up the architecture recorded for an instance name.

    Args:
        instance_name (str): Value to match against the "Instance Name" column
            of the table returned by model_instance_info().

    Returns:
        str: The "Architecture" value of the first matching row.

    Raises:
        IndexError: If no row matches instance_name (indexing an empty result).
    """
    instance_table = model_instance_info()

    # Narrow the table to the rows for this instance, then take the first match
    matching_rows = instance_table[instance_table["Instance Name"] == instance_name]
    return matching_rows["Architecture"].values[0]

load_category_mappings_from_s3(model_artifact_uri)

Download and extract category mappings from a model artifact in S3.

Parameters:

Name Type Description Default
model_artifact_uri str

S3 URI of the model artifact.

required

Returns:

Name Type Description
dict Optional[dict]

The loaded category mappings or None if not found.

Source code in src/workbench/utils/model_utils.py
def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
    """
    Fetch a model artifact tarball from S3 and read its category mappings.

    The artifact is downloaded and unpacked into a temporary directory that is
    removed before returning. Only "category_mappings.json" at the top level of
    the unpacked archive is considered.

    Args:
        model_artifact_uri (str): S3 URI of the model artifact.

    Returns:
        dict: The parsed category mappings, or None if the file is absent or
        could not be loaded (load failures are printed, not raised).
    """
    with tempfile.TemporaryDirectory() as work_dir:
        # Pull the artifact down next to where it will be unpacked
        archive_path = os.path.join(work_dir, "model.tar.gz")
        wr.s3.download(path=model_artifact_uri, local_file=archive_path)

        # Unpack using tarfile's "data" extraction filter
        with tarfile.open(archive_path, "r:gz") as archive:
            archive.extractall(path=work_dir, filter="data")

        # Only the archive's base directory is checked for the mappings file
        mappings_path = os.path.join(work_dir, "category_mappings.json")
        if not os.path.exists(mappings_path):
            return None

        loaded_mappings = None
        try:
            with open(mappings_path, "r") as f:
                loaded_mappings = json.load(f)
            print(f"Loaded category mappings from {mappings_path}")
        except Exception as e:
            # Best-effort: report the problem and fall through with whatever was loaded
            print(f"Failed to load category mappings from {mappings_path}: {e}")

        return loaded_mappings

model_instance_info()

Get the instance information for the Model

Source code in src/workbench/utils/model_utils.py
def model_instance_info() -> pd.DataFrame:
    """Build the table of instance types known to this module.

    Returns:
        pd.DataFrame: One row per instance type with the columns
        "Instance Name", "vCPUs", "Memory", "Price per Hour", "Category",
        and "Architecture".
    """
    columns = [
        "Instance Name",
        "vCPUs",
        "Memory",
        "Price per Hour",
        "Category",
        "Architecture",
    ]

    # Each tuple lines up positionally with `columns` above
    rows = [
        ("ml.t2.medium", 2, 4, 0.06, "General", "x86_64"),
        ("ml.m7i.large", 2, 8, 0.12, "General", "x86_64"),
        ("ml.c7i.large", 2, 4, 0.11, "Compute", "x86_64"),
        ("ml.c7i.xlarge", 4, 8, 0.21, "Compute", "x86_64"),
        ("ml.c7g.large", 2, 4, 0.09, "Compute", "arm64"),
        ("ml.c7g.xlarge", 4, 8, 0.17, "Compute", "arm64"),
    ]

    # Expand to one dict per instance so the frame is built from records
    records = [dict(zip(columns, row)) for row in rows]
    return pd.DataFrame(records)

proximity_model(model, prox_model_name, track_columns=None)

Create a proximity model based on the given model

Parameters:

Name Type Description Default
model Model

The model to create the proximity model from

required
prox_model_name str

The name of the proximity model to create

required
track_columns list

List of columns to track in the proximity model

None

Returns: Model: The proximity model

Source code in src/workbench/utils/model_utils.py
def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
    """Build a proximity model that mirrors an existing model's features and target

    The new model is created from the FeatureSet that the given model names as its
    input, using the "feature_space_proximity.template" custom script.

    Args:
        model (Model): The model to create the proximity model from
        prox_model_name (str): The name of the proximity model to create
        track_columns (list, optional): List of columns to track in the proximity model
    Returns:
        Model: The proximity model
    """
    from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)

    # Locate the proximity template script
    template_path = get_custom_script_path("proximity", "feature_space_proximity.template")

    # Reuse the source model's feature and target columns
    feature_columns = model.features()
    target_name = model.target()

    # The proximity model is built from the source model's input FeatureSet
    source_feature_set = FeatureSet(model.get_input())
    return source_feature_set.to_model(
        name=prox_model_name,
        model_type=ModelType.PROXIMITY,
        feature_list=feature_columns,
        target_column=target_name,
        description=f"Proximity Model for {model.uuid}",
        tags=["proximity", model.uuid],
        custom_script=template_path,
        custom_args={"track_columns": track_columns},
    )

supported_instance_types(arch='x86_64')

Get the supported instance types for the Model, filtered by architecture

Source code in src/workbench/utils/model_utils.py
def supported_instance_types(arch: str = "x86_64") -> list:
    """List the instance names whose architecture matches the one requested.

    Args:
        arch (str): Value to match against the "Architecture" column of the
            table returned by model_instance_info() (default: "x86_64").

    Returns:
        list: The "Instance Name" values of the matching rows, in table order.
    """
    instance_table = model_instance_info()

    # Boolean mask selecting only rows for the requested architecture
    arch_mask = instance_table["Architecture"] == arch
    return instance_table.loc[arch_mask, "Instance Name"].tolist()

Examples

Feature Importance

"""Example for using some Model Utilities"""
from workbench.api import Model  # Model was used below without being imported
from workbench.utils.model_utils import feature_importance

# Construct the Model by name and pass it to the feature_importance utility
model = Model("aqsol_classification")
feature_importance(model)

Output

[('mollogp', 469.0),
 ('minabsestateindex', 277.0),
 ('peoe_vsa8', 237.0),
 ('qed', 237.0),
 ('fpdensitymorgan1', 230.0),
 ('fpdensitymorgan3', 221.0),
 ('estate_vsa4', 220.0),
 ('bcut2d_logphi', 218.0),
 ('vsa_estate5', 218.0),
 ('vsa_estate4', 209.0),

Additional Resources