
ModelCore

API Classes

Found a method here you want to use? The API Classes have method pass-through, so just call the method on the Model API Class and voilà, it works the same.
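
For example, a minimal sketch of the pass-through (the import path workbench.api.Model and the model name are assumptions for illustration):

from workbench.api import Model

model = Model("abalone-regression")   # hypothetical model name
model.list_inference_runs()           # a ModelCore method, callable via the Model API class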

ModelCore: Workbench ModelCore Class

InferenceImage

Class for retrieving locked Scikit-Learn inference images

Source code in src/workbench/core/artifacts/model_core.py
class InferenceImage:
    """Class for retrieving locked Scikit-Learn inference images"""

    image_uris = {
        ("us-east-1", "sklearn", "1.2.1"): (
            "683313688378.dkr.ecr.us-east-1.amazonaws.com/"
            "sagemaker-scikit-learn@sha256:ed242e33af079f334972acd2a7ddf74d13310d3c9a0ef3a0e9b0429ccc104dcd"
        ),
        ("us-east-2", "sklearn", "1.2.1"): (
            "257758044811.dkr.ecr.us-east-2.amazonaws.com/"
            "sagemaker-scikit-learn@sha256:ed242e33af079f334972acd2a7ddf74d13310d3c9a0ef3a0e9b0429ccc104dcd"
        ),
        ("us-west-1", "sklearn", "1.2.1"): (
            "746614075791.dkr.ecr.us-west-1.amazonaws.com/"
            "sagemaker-scikit-learn@sha256:ed242e33af079f334972acd2a7ddf74d13310d3c9a0ef3a0e9b0429ccc104dcd"
        ),
        ("us-west-2", "sklearn", "1.2.1"): (
            "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
            "sagemaker-scikit-learn@sha256:ed242e33af079f334972acd2a7ddf74d13310d3c9a0ef3a0e9b0429ccc104dcd"
        ),
    }

    @classmethod
    def get_image_uri(cls, region, framework, version):
        key = (region, framework, version)
        if key in cls.image_uris:
            return cls.image_uris[key]
        else:
            raise ValueError(
                f"No matching image found for region: {region}, framework: {framework}, version: {version}"
            )
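
A quick sketch of looking up a locked image URI; an unknown (region, framework, version) tuple raises a ValueError:

image_uri = InferenceImage.get_image_uri("us-east-1", "sklearn", "1.2.1")
print(image_uri)

try:
    InferenceImage.get_image_uri("eu-west-1", "sklearn", "1.2.1")  # not in the table
except ValueError as e:
    print(e)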

ModelCore

Bases: Artifact

ModelCore: Workbench ModelCore Class

Common Usage
my_model = ModelCore(model_uuid)
my_model.summary()
my_model.details()
Source code in src/workbench/core/artifacts/model_core.py
class ModelCore(Artifact):
    """ModelCore: Workbench ModelCore Class

    Common Usage:
        ```python
        my_model = ModelCore(model_uuid)
        my_model.summary()
        my_model.details()
        ```
    """

    def __init__(self, model_uuid: str, model_type: ModelType = None, **kwargs):
        """ModelCore Initialization
        Args:
            model_uuid (str): Name of Model in Workbench.
            model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
            **kwargs: Additional keyword arguments
        """

        # Make sure the model name is valid
        self.is_name_valid(model_uuid, delimiter="-", lower_case=False)

        # Call SuperClass Initialization
        super().__init__(model_uuid, **kwargs)

        # Initialize our class attributes
        self.latest_model = None
        self.model_type = ModelType.UNKNOWN
        self.model_training_path = None
        self.endpoint_inference_path = None

        # Grab a Cloud Platform Meta object and pull information for this Model
        self.model_name = model_uuid
        self.model_meta = self.meta.model(self.model_name)
        if self.model_meta is None:
            self.log.warning(f"Could not find model {self.model_name} within current visibility scope")
            return
        else:
            # Is this a model package group without any models?
            if len(self.model_meta["ModelPackageList"]) == 0:
                self.log.warning(f"Model Group {self.model_name} has no Model Packages!")
                self.latest_model = None
                self.add_health_tag("model_not_found")
                return
            try:
                self.latest_model = self.model_meta["ModelPackageList"][0]
                self.description = self.latest_model.get("ModelPackageDescription", "-")
                self.training_job_name = self._extract_training_job_name()
                if model_type:
                    self._set_model_type(model_type)
                else:
                    self.model_type = self._get_model_type()
            except (IndexError, KeyError):
                self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
                return

        # Set the Model Training S3 Path
        self.model_training_path = self.models_s3_path + "/training/" + self.model_name

        # Get our Endpoint Inference Path (might be None)
        self.endpoint_inference_path = self.get_endpoint_inference_path()

        # Call SuperClass Post Initialization
        super().__post_init__()

        # All done
        self.log.info(f"Model Initialized: {self.model_name}")

    def refresh_meta(self):
        """Refresh the Artifact's metadata"""
        self.model_meta = self.meta.model(self.model_name)
        self.latest_model = self.model_meta["ModelPackageList"][0]
        self.description = self.latest_model.get("ModelPackageDescription", "-")
        self.training_job_name = self._extract_training_job_name()

    def exists(self) -> bool:
        """Does the model metadata exist in the AWS Metadata?"""
        if self.model_meta is None:
            self.log.info(f"Model {self.model_name} not found in AWS Metadata!")
            return False
        return True

    def health_check(self) -> list[str]:
        """Perform a health check on this model
        Returns:
            list[str]: List of health issues
        """
        # Call the base class health check
        health_issues = super().health_check()

        # Check if the model exists
        if self.latest_model is None:
            health_issues.append("model_not_found")

        # Model Type
        if self._get_model_type() == ModelType.UNKNOWN:
            health_issues.append("model_type_unknown")
        else:
            self.remove_health_tag("model_type_unknown")

        # Model Performance Metrics
        needs_metrics = self.model_type in {ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR, ModelType.CLASSIFIER}
        if needs_metrics and self.get_inference_metrics() is None:
            health_issues.append("metrics_needed")
        else:
            self.remove_health_tag("metrics_needed")

        # Endpoint
        if not self.endpoints():
            health_issues.append("no_endpoint")
        else:
            self.remove_health_tag("no_endpoint")
        return health_issues

    def latest_model_object(self) -> SagemakerModel:
        """Return the latest AWS Sagemaker Model object for this Workbench Model

        Returns:
           sagemaker.model.Model: AWS Sagemaker Model object
        """
        return SagemakerModel(
            model_data=self.model_package_arn(), sagemaker_session=self.sm_session, image_uri=self.container_image()
        )

    def list_inference_runs(self) -> list[str]:
        """List the inference runs for this model

        Returns:
            list[str]: List of inference runs
        """

        # Check if we have a model (if not return empty list)
        if self.latest_model is None:
            return []

        # Check if we have model training metrics in our metadata
        have_model_training = bool(self.workbench_meta().get("workbench_training_metrics"))

        # Now grab the list of directories from our inference path
        inference_runs = []
        if self.endpoint_inference_path:
            directories = wr.s3.list_directories(path=self.endpoint_inference_path + "/")
            inference_runs = [urlparse(directory).path.split("/")[-2] for directory in directories]

        # We're going to add the model training to the end of the list
        if have_model_training:
            inference_runs.append("model_training")
        return inference_runs

    def delete_inference_run(self, inference_run_uuid: str):
        """Delete the inference run for this model

        Args:
            inference_run_uuid (str): UUID of the inference run
        """
        if inference_run_uuid == "model_training":
            self.log.warning("Cannot delete model training data!")
            return

        if self.endpoint_inference_path:
            full_path = f"{self.endpoint_inference_path}/{inference_run_uuid}"
            # Check if there are any objects at the path
            if wr.s3.list_objects(full_path):
                wr.s3.delete_objects(path=full_path)
                self.log.important(f"Deleted inference run {inference_run_uuid} for {self.model_name}")
            else:
                self.log.warning(f"Inference run {inference_run_uuid} not found for {self.model_name}!")
        else:
            self.log.warning(f"No inference data found for {self.model_name}!")

    def get_inference_metrics(self, capture_uuid: str = "latest") -> Union[pd.DataFrame, None]:
        """Retrieve the inference performance metrics for this model

        Args:
            capture_uuid (str, optional): Specific capture_uuid or "model_training" (default: "latest")
        Returns:
            pd.DataFrame: DataFrame of the Model Metrics

        Note:
            If a capture_uuid isn't specified this will try to return something reasonable
        """
        # Try to get the auto_capture 'training_holdout' or the training
        if capture_uuid == "latest":
            metrics_df = self.get_inference_metrics("auto_inference")
            return metrics_df if metrics_df is not None else self.get_inference_metrics("model_training")

        # Grab the metrics captured during model training (could return None)
        if capture_uuid == "model_training":
            # Sanity check the workbench metadata
            if self.workbench_meta() is None:
                error_msg = f"Model {self.model_name} has no workbench_meta(). Either onboard() or delete this model!"
                self.log.critical(error_msg)
                raise ValueError(error_msg)

            metrics = self.workbench_meta().get("workbench_training_metrics")
            return pd.DataFrame.from_dict(metrics) if metrics else None

        else:  # Specific capture_uuid (could return None)
            s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_metrics.csv"
            metrics = pull_s3_data(s3_path, embedded_index=True)
            if metrics is not None:
                return metrics
            else:
                self.log.warning(f"Performance metrics {capture_uuid} not found for {self.model_name}!")
                return None

    def confusion_matrix(self, capture_uuid: str = "latest") -> Union[pd.DataFrame, None]:
        """Retrieve the confusion_matrix for this model

        Args:
            capture_uuid (str, optional): Specific capture_uuid or "model_training" (default: "latest")
        Returns:
            pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
        """

        # Sanity check the workbench metadata
        if self.workbench_meta() is None:
            error_msg = f"Model {self.model_name} has no workbench_meta(). Either onboard() or delete this model!"
            self.log.critical(error_msg)
            raise ValueError(error_msg)

        # Grab the metrics from the Workbench Metadata (try inference first, then training)
        if capture_uuid == "latest":
            cm = self.confusion_matrix("auto_inference")
            return cm if cm is not None else self.confusion_matrix("model_training")

        # Grab the confusion matrix captured during model training (could return None)
        if capture_uuid == "model_training":
            cm = self.workbench_meta().get("workbench_training_cm")
            return pd.DataFrame.from_dict(cm) if cm else None

        else:  # Specific capture_uuid
            s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_cm.csv"
            cm = pull_s3_data(s3_path, embedded_index=True)
            if cm is not None:
                return cm
            else:
                self.log.warning(f"Confusion Matrix {capture_uuid} not found for {self.model_name}!")
                return None

    def set_input(self, input: str, force: bool = False):
        """Override: Set the input data for this artifact

        Args:
            input (str): Name of input for this artifact
            force (bool, optional): Force the input to be set (default: False)
        Note:
            By default this is not allowed for Models; pass force=True to override
        """
        if not force:
            self.log.warning(f"Model {self.uuid}: Does not allow manual override of the input!")
            return

        # Okay we're going to allow this to be set
        self.log.important(f"{self.uuid}: Setting input to {input}...")
        self.log.important("Be careful with this! It breaks automatic provenance of the artifact!")
        self.upsert_workbench_meta({"workbench_input": input})

    def size(self) -> float:
        """Return the size of this data in MegaBytes"""
        return 0.0

    def aws_meta(self) -> dict:
        """Get ALL the AWS metadata for this artifact"""
        return self.model_meta

    def arn(self) -> str:
        """AWS ARN (Amazon Resource Name) for the Model Package Group"""
        return self.group_arn()

    def group_arn(self) -> Union[str, None]:
        """AWS ARN (Amazon Resource Name) for the Model Package Group"""
        return self.model_meta["ModelPackageGroupArn"] if self.model_meta else None

    def model_package_arn(self) -> Union[str, None]:
        """AWS ARN (Amazon Resource Name) for the Latest Model Package (within the Group)"""
        if self.latest_model is None:
            return None
        return self.latest_model["ModelPackageArn"]

    def container_info(self) -> Union[dict, None]:
        """Container Info for the Latest Model Package"""
        return self.latest_model["InferenceSpecification"]["Containers"][0] if self.latest_model else None

    def container_image(self) -> str:
        """Container Image for the Latest Model Package"""
        return self.container_info()["Image"]

    def aws_url(self):
        """The AWS URL for looking at/querying this model"""
        return f"https://{self.aws_region}.console.aws.amazon.com/athena/home"

    def created(self) -> datetime:
        """Return the datetime when this artifact was created"""
        if self.latest_model is None:
            return "-"
        return self.latest_model["CreationTime"]

    def modified(self) -> datetime:
        """Return the datetime when this artifact was last modified"""
        if self.latest_model is None:
            return "-"
        return self.latest_model["CreationTime"]

    def hash(self) -> Optional[str]:
        """Return the hash for this artifact

        Returns:
            Optional[str]: The hash for this artifact
        """
        model_url = self.get_model_data_url()
        return compute_s3_object_hash(model_url, self.boto3_session)

    def register_endpoint(self, endpoint_name: str):
        """Add this endpoint to the set of registered endpoints for the model

        Args:
            endpoint_name (str): Name of the endpoint
        """
        self.log.important(f"Registering Endpoint {endpoint_name} with Model {self.uuid}...")
        registered_endpoints = set(self.workbench_meta().get("workbench_registered_endpoints", []))
        registered_endpoints.add(endpoint_name)
        self.upsert_workbench_meta({"workbench_registered_endpoints": list(registered_endpoints)})

        # Remove any health tags
        self.remove_health_tag("no_endpoint")

        # A new endpoint means we need to refresh our inference path
        time.sleep(2)  # Give the AWS Metadata a chance to update
        self.endpoint_inference_path = self.get_endpoint_inference_path()

    def remove_endpoint(self, endpoint_name: str):
        """Remove this endpoint from the set of registered endpoints for the model

        Args:
            endpoint_name (str): Name of the endpoint
        """
        self.log.important(f"Removing Endpoint {endpoint_name} from Model {self.uuid}...")
        registered_endpoints = set(self.workbench_meta().get("workbench_registered_endpoints", []))
        registered_endpoints.discard(endpoint_name)
        self.upsert_workbench_meta({"workbench_registered_endpoints": list(registered_endpoints)})

        # If we have NO endpoints, then set a health tag
        if not registered_endpoints:
            self.add_health_tag("no_endpoint")
            self.details(recompute=True)

        # Give the AWS Metadata a chance to update
        time.sleep(2)

    def endpoints(self) -> list[str]:
        """Get the list of registered endpoints for this Model

        Returns:
            list[str]: List of registered endpoints
        """
        return self.workbench_meta().get("workbench_registered_endpoints", [])

    def get_endpoint_inference_path(self) -> Union[str, None]:
        """Get the S3 Path for the Inference Data

        Returns:
            str: S3 Path for the Inference Data (or None if not found)
        """

        # Look for any Registered Endpoints
        registered_endpoints = self.workbench_meta().get("workbench_registered_endpoints")

        # Note: We may have 0 to N endpoints, so we find the one with the most recent artifacts
        if registered_endpoints:
            endpoint_inference_base = self.endpoints_s3_path + "/inference/"
            endpoint_inference_paths = [endpoint_inference_base + e for e in registered_endpoints]
            inference_path = newest_path(endpoint_inference_paths, self.sm_session)
            if inference_path is None:
                self.log.important(f"No inference data found for {self.model_name}!")
                self.log.important(f"Returning default inference path for {registered_endpoints[0]}...")
                self.log.important(f"{endpoint_inference_paths[0]}")
                return endpoint_inference_paths[0]
            else:
                return inference_path
        else:
            self.log.warning(f"No registered endpoints found for {self.model_name}!")
            return None

    def set_target(self, target_column: str):
        """Set the target for this Model

        Args:
            target_column (str): Target column for this Model
        """
        self.upsert_workbench_meta({"workbench_model_target": target_column})

    def set_features(self, feature_columns: list[str]):
        """Set the features for this Model

        Args:
            feature_columns (list[str]): List of feature columns
        """
        self.upsert_workbench_meta({"workbench_model_features": feature_columns})

    def target(self) -> Union[str, None]:
        """Return the target for this Model (if supervised, else None)

        Returns:
            str: Target column for this Model (if supervised, else None)
        """
        return self.workbench_meta().get("workbench_model_target")  # Returns None if not found

    def features(self) -> Union[list[str], None]:
        """Return a list of features used for this Model

        Returns:
            list[str]: List of features used for this Model
        """
        return self.workbench_meta().get("workbench_model_features")  # Returns None if not found

    def class_labels(self) -> Union[list[str], None]:
        """Return the class labels for this Model (if it's a classifier)

        Returns:
            list[str]: List of class labels
        """
        if self.model_type == ModelType.CLASSIFIER:
            return self.workbench_meta().get("class_labels")  # Returns None if not found
        else:
            return None

    def set_class_labels(self, labels: list[str]):
        """Return the class labels for this Model (if it's a classifier)

        Args:
            labels (list[str]): List of class labels
        """
        if self.model_type == ModelType.CLASSIFIER:
            self.upsert_workbench_meta({"class_labels": labels})
        else:
            self.log.error(f"Model {self.model_name} is not a classifier!")

    def details(self, recompute=False) -> dict:
        """Additional Details about this Model
        Args:
            recompute (bool, optional): Recompute the details (default: False)
        Returns:
            dict: Dictionary of details about this Model
        """
        self.log.info("Computing Model Details...")
        details = self.summary()
        details["pipeline"] = self.get_pipeline()
        details["model_type"] = self.model_type.value
        details["model_package_group_arn"] = self.group_arn()
        details["model_package_arn"] = self.model_package_arn()

        # Sanity check that we have models in the group
        if self.latest_model is None:
            self.log.warning(f"Model Package Group {self.model_name} has no models!")
            return details

        # Grab the Model Details
        details["description"] = self.latest_model.get("ModelPackageDescription", "-")
        details["version"] = self.latest_model["ModelPackageVersion"]
        details["status"] = self.latest_model["ModelPackageStatus"]
        details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
        details["image"] = self.container_image().split("/")[-1]  # Shorten the image uri

        # Grab the inference and container info
        inference_spec = self.latest_model["InferenceSpecification"]
        container_info = self.container_info()
        details["framework"] = container_info.get("Framework", "unknown")
        details["framework_version"] = container_info.get("FrameworkVersion", "unknown")
        details["inference_types"] = inference_spec["SupportedRealtimeInferenceInstanceTypes"]
        details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
        details["content_types"] = inference_spec["SupportedContentTypes"]
        details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
        details["model_metrics"] = self.get_inference_metrics()
        if self.model_type == ModelType.CLASSIFIER:
            details["confusion_matrix"] = self.confusion_matrix()
            details["predictions"] = None
        elif self.model_type in [ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]:
            details["confusion_matrix"] = None
            details["predictions"] = self.get_inference_predictions()
        else:
            details["confusion_matrix"] = None
            details["predictions"] = None

        # Grab the inference metadata
        details["inference_meta"] = self.get_inference_metadata()

        # Return the details
        return details

    # Pipeline for this model
    def get_pipeline(self) -> str:
        """Get the pipeline for this model"""
        return self.workbench_meta().get("workbench_pipeline")

    def set_pipeline(self, pipeline: str):
        """Set the pipeline for this model

        Args:
            pipeline (str): Pipeline that was used to create this model
        """
        self.upsert_workbench_meta({"workbench_pipeline": pipeline})

    def expected_meta(self) -> list[str]:
        """Metadata we expect to see for this Model when it's ready
        Returns:
            list[str]: List of expected metadata keys
        """
        # Our current list of expected metadata, we can add to this as needed
        return ["workbench_status", "workbench_training_metrics", "workbench_training_cm"]

    def is_model_unknown(self) -> bool:
        """Is the Model Type unknown?"""
        return self.model_type == ModelType.UNKNOWN

    def _determine_model_type(self):
        """Internal: Determine the Model Type"""
        model_type = input("Model Type? (classifier, regressor, quantile_regressor, unsupervised, transformer): ")
        if model_type == "classifier":
            self._set_model_type(ModelType.CLASSIFIER)
        elif model_type == "regressor":
            self._set_model_type(ModelType.REGRESSOR)
        elif model_type == "quantile_regressor":
            self._set_model_type(ModelType.QUANTILE_REGRESSOR)
        elif model_type == "unsupervised":
            self._set_model_type(ModelType.UNSUPERVISED)
        elif model_type == "transformer":
            self._set_model_type(ModelType.TRANSFORMER)
        else:
            self.log.warning(f"Unknown Model Type {model_type}!")
            self._set_model_type(ModelType.UNKNOWN)

    def onboard(self, ask_everything=False) -> bool:
        """This is an interactive method that will onboard the Model (make it ready)

        Args:
            ask_everything (bool, optional): Ask for all the details. Defaults to False.

        Returns:
            bool: True if the Model is successfully onboarded, False otherwise
        """
        # Set the status to onboarding
        self.set_status("onboarding")

        # Determine the Model Type
        while self.is_model_unknown():
            self._determine_model_type()

        # Is our input data set?
        if self.get_input() in ["", "unknown"] or ask_everything:
            input_data = input("Input Data?: ")
            if input_data not in ["None", "none", "", "unknown"]:
                self.set_input(input_data)

        # Determine the Target Column (can be None)
        target_column = self.target()
        if target_column is None or ask_everything:
            target_column = input("Target Column? (for unsupervised/transformer just type None): ")
            if target_column in ["None", "none", ""]:
                target_column = None

        # Determine the Feature Columns
        feature_columns = self.features()
        if feature_columns is None or ask_everything:
            feature_columns = input("Feature Columns? (use commas): ")
            feature_columns = [e.strip() for e in feature_columns.split(",")]
            if feature_columns in [["None"], ["none"], [""]]:
                feature_columns = None

        # Registered Endpoints?
        endpoints = self.endpoints()
        if not endpoints or ask_everything:
            endpoints = input("Register Endpoints? (use commas for multiple): ")
            endpoints = [e.strip() for e in endpoints.split(",")]
            if endpoints in [["None"], ["none"], [""]]:
                endpoints = None

        # Model Owner?
        owner = self.get_owner()
        if owner in [None, "unknown"] or ask_everything:
            owner = input("Model Owner: ")
            if owner in ["None", "none", ""]:
                owner = "unknown"

        # Model Class Labels (if it's a classifier)
        if self.model_type == ModelType.CLASSIFIER:
            class_labels = self.class_labels()
            if class_labels is None or ask_everything:
                class_labels = input("Class Labels? (use commas): ")
                class_labels = [e.strip() for e in class_labels.split(",")]
                if class_labels in [["None"], ["none"], [""]]:
                    class_labels = None
            self.set_class_labels(class_labels)

        # Now that we have all the details, let's onboard the Model with all the args
        return self.onboard_with_args(self.model_type, target_column, feature_columns, endpoints, owner)

    def onboard_with_args(
        self,
        model_type: ModelType,
        target_column: str = None,
        feature_list: list = None,
        endpoints: list = None,
        owner: str = None,
    ) -> bool:
        """Onboard the Model with the given arguments

        Args:
            model_type (ModelType): Model Type
            target_column (str): Target Column
            feature_list (list): List of Feature Columns
            endpoints (list, optional): List of Endpoints. Defaults to None.
            owner (str, optional): Model Owner. Defaults to None.
        Returns:
            bool: True if the Model is successfully onboarded, False otherwise
        """
        # Set the status to onboarding
        self.set_status("onboarding")

        # Set All the Details
        self._set_model_type(model_type)
        if target_column:
            self.set_target(target_column)
        if feature_list:
            self.set_features(feature_list)
        if endpoints:
            for endpoint in endpoints:
                self.register_endpoint(endpoint)
        if owner:
            self.set_owner(owner)

        # Load the training metrics and inference metrics
        self._load_training_metrics()
        self._load_inference_metrics()

        # Remove the needs_onboard tag
        self.remove_health_tag("needs_onboard")
        self.set_status("ready")

        # Run a health check and refresh the meta
        time.sleep(2)  # Give the AWS Metadata a chance to update
        self.health_check()
        self.refresh_meta()
        self.details(recompute=True)
        return True

    def get_model_data_url(self) -> Optional[str]:
        """Retrieve the ModelDataUrl from the model's AWS metadata.

        Returns:
            Optional[str]: The ModelDataUrl if available, otherwise None.
        """
        meta = self.aws_meta()
        try:
            return meta["ModelPackageList"][0]["InferenceSpecification"]["Containers"][0]["ModelDataUrl"]
        except (KeyError, IndexError, TypeError):
            return None

    def delete(self):
        """Delete the Model Packages and the Model Group"""
        if not self.exists():
            self.log.warning(f"Trying to delete an Model that doesn't exist: {self.uuid}")

        # Call the Class Method to delete the Model Group
        ModelCore.managed_delete(model_group_name=self.uuid)

    @classmethod
    def managed_delete(cls, model_group_name: str):
        """Delete the Model Packages, Model Group, and S3 Storage Objects

        Args:
            model_group_name (str): The name of the Model Group to delete
        """
        # Check if the model group exists in SageMaker
        try:
            cls.sm_client.describe_model_package_group(ModelPackageGroupName=model_group_name)
        except ClientError as e:
            if e.response["Error"]["Code"] in ["ValidationException", "ResourceNotFound"]:
                cls.log.info(f"Model Group {model_group_name} not found!")
                return
            else:
                raise  # Re-raise unexpected errors

        # Delete Model Packages within the Model Group
        try:
            paginator = cls.sm_client.get_paginator("list_model_packages")
            for page in paginator.paginate(ModelPackageGroupName=model_group_name):
                for model_package in page["ModelPackageSummaryList"]:
                    package_arn = model_package["ModelPackageArn"]
                    cls.log.info(f"Deleting Model Package {package_arn}...")
                    cls.sm_client.delete_model_package(ModelPackageName=package_arn)
        except ClientError as e:
            cls.log.error(f"Error while deleting model packages: {e}")
            raise

        # Delete the Model Package Group
        cls.log.info(f"Deleting Model Group {model_group_name}...")
        cls.sm_client.delete_model_package_group(ModelPackageGroupName=model_group_name)

        # Delete S3 training artifacts
        s3_delete_path = f"{cls.models_s3_path}/training/{model_group_name}/"
        cls.log.info(f"Deleting S3 Objects at {s3_delete_path}...")
        wr.s3.delete_objects(s3_delete_path, boto3_session=cls.boto3_session)

        # Delete any dataframes that were stored in the Dataframe Cache
        cls.log.info("Deleting Dataframe Cache...")
        cls.df_cache.delete_recursive(model_group_name)

    def _set_model_type(self, model_type: ModelType):
        """Internal: Set the Model Type for this Model"""
        self.model_type = model_type
        self.upsert_workbench_meta({"workbench_model_type": self.model_type.value})
        self.remove_health_tag("model_type_unknown")

    def _get_model_type(self) -> ModelType:
        """Internal: Query the Workbench Metadata to get the model type
        Returns:
            ModelType: The ModelType of this Model
        Notes:
            This is an internal method that should not be called directly
            Use the model_type attribute instead
        """
        model_type = self.workbench_meta().get("workbench_model_type")
        try:
            return ModelType(model_type)
        except ValueError:
            self.log.warning(f"Could not determine model type for {self.model_name}!")
            return ModelType.UNKNOWN

    def _load_training_metrics(self):
        """Internal: Retrieve the training metrics and Confusion Matrix for this model
                     and load the data into the Workbench Metadata

        Notes:
            This may or may not exist based on whether we have access to TrainingJobAnalytics
        """
        try:
            df = TrainingJobAnalytics(training_job_name=self.training_job_name).dataframe()
            if df.empty:
                self.log.important(f"No training job metrics found for {self.training_job_name}")
                self.upsert_workbench_meta({"workbench_training_metrics": None, "workbench_training_cm": None})
                return
            if self.model_type in [ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]:
                if "timestamp" in df.columns:
                    df = df.drop(columns=["timestamp"])

                # We're going to pivot the DataFrame to get the desired structure
                reg_metrics_df = df.set_index("metric_name").T

                # Store and return the metrics in the Workbench Metadata
                self.upsert_workbench_meta(
                    {"workbench_training_metrics": reg_metrics_df.to_dict(), "workbench_training_cm": None}
                )
                return

        except (KeyError, botocore.exceptions.ClientError):
            self.log.important(f"No training job metrics found for {self.training_job_name}")
            # Store and return the metrics in the Workbench Metadata
            self.upsert_workbench_meta({"workbench_training_metrics": None, "workbench_training_cm": None})
            return

        # We need additional processing for classification metrics
        if self.model_type == ModelType.CLASSIFIER:
            metrics_df, cm_df = self._process_classification_metrics(df)

            # Store and return the metrics in the Workbench Metadata
            self.upsert_workbench_meta(
                {"workbench_training_metrics": metrics_df.to_dict(), "workbench_training_cm": cm_df.to_dict()}
            )

    def _load_inference_metrics(self, capture_uuid: str = "auto_inference"):
        """Internal: Retrieve the inference model metrics for this model
                     and load the data into the Workbench Metadata

        Args:
            capture_uuid (str, optional): A specific capture_uuid (default: "auto_inference")
        Notes:
            This may or may not exist based on whether an Endpoint ran Inference
        """
        s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_metrics.csv"
        inference_metrics = pull_s3_data(s3_path)

        # Store data into the Workbench Metadata
        metrics_storage = None if inference_metrics is None else inference_metrics.to_dict("records")
        self.upsert_workbench_meta({"workbench_inference_metrics": metrics_storage})

    def get_inference_metadata(self, capture_uuid: str = "auto_inference") -> Union[pd.DataFrame, None]:
        """Retrieve the inference metadata for this model

        Args:
            capture_uuid (str, optional): A specific capture_uuid (default: "auto_inference")

        Returns:
            pd.DataFrame: DataFrame of the inference metadata (might be None)
        Notes:
            Basically, metadata about when Endpoint inference was run: the name of the dataset, its MD5 hash, etc.
        """
        # Sanity check the inference path (which may or may not exist)
        if self.endpoint_inference_path is None:
            return None

        # Check for model_training capture_uuid
        if capture_uuid == "model_training":
            # Create a DataFrame with the training metadata
            meta_df = pd.DataFrame(
                [
                    {
                        "name": "AWS Training Capture",
                        "data_hash": "N/A",
                        "num_rows": "-",
                        "description": "-",
                    }
                ]
            )
            return meta_df

        # Pull the inference metadata
        try:
            s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_meta.json"
            return wr.s3.read_json(s3_path)
        except NoFilesFound:
            self.log.info(f"Could not find model inference meta at {s3_path}...")
            return None

    def get_inference_predictions(self, capture_uuid: str = "auto_inference") -> Union[pd.DataFrame, None]:
        """Retrieve the captured prediction results for this model

        Args:
            capture_uuid (str, optional): Specific capture_uuid (default: "auto_inference")

        Returns:
            pd.DataFrame: DataFrame of the Captured Predictions (might be None)
        """
        self.log.important(f"Grabbing {capture_uuid} predictions for {self.model_name}...")

        # Sanity check that the model should have predictions
        has_predictions = self.model_type in [ModelType.CLASSIFIER, ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]
        if not has_predictions:
            self.log.warning(f"No Predictions for {self.model_name}...")
            return None

        # Special case for model_training
        if capture_uuid == "model_training":
            return self._get_validation_predictions()

        # Construct the S3 path for the Inference Predictions
        s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_predictions.csv"
        return pull_s3_data(s3_path)

    def _get_validation_predictions(self) -> Union[pd.DataFrame, None]:
        """Internal: Retrieve the captured prediction results for this model

        Returns:
            pd.DataFrame: DataFrame of the Captured Validation Predictions (might be None)
        """
        # Sanity check the training path (which may or may not exist)
        if self.model_training_path is None:
            self.log.warning(f"No Validation Predictions for {self.model_name}...")
            return None
        self.log.important(f"Grabbing Validation Predictions for {self.model_name}...")
        s3_path = f"{self.model_training_path}/validation_predictions.csv"
        df = pull_s3_data(s3_path)
        return df

    def _extract_training_job_name(self) -> Union[str, None]:
        """Internal: Extract the training job name from the ModelDataUrl"""
        model_data_url = None  # Defined up front so the except clause can safely reference it
        try:
            model_data_url = self.container_info()["ModelDataUrl"]
            parsed_url = urllib.parse.urlparse(model_data_url)
            training_job_name = parsed_url.path.lstrip("/").split("/")[0]
            return training_job_name
        except (KeyError, TypeError):
            self.log.warning(f"Could not extract training job name from {model_data_url}")
            return None

    @staticmethod
    def _process_classification_metrics(df: pd.DataFrame) -> (pd.DataFrame, pd.DataFrame):
        """Internal: Process classification metrics into a more reasonable format
        Args:
            df (pd.DataFrame): DataFrame of training metrics
        Returns:
            (pd.DataFrame, pd.DataFrame): Tuple of DataFrames. Metrics and confusion matrix
        """
        # Split into two DataFrames based on 'metric_name'
        metrics_df = df[df["metric_name"].str.startswith("Metrics:")].copy()
        cm_df = df[df["metric_name"].str.startswith("ConfusionMatrix:")].copy()

        # Split the 'metric_name' into different parts
        metrics_df["class"] = metrics_df["metric_name"].str.split(":").str[1]
        metrics_df["metric_type"] = metrics_df["metric_name"].str.split(":").str[2]

        # Pivot the DataFrame to get the desired structure
        metrics_df = metrics_df.pivot(index="class", columns="metric_type", values="value").reset_index()
        metrics_df = metrics_df.rename_axis(None, axis=1)

        # Now process the confusion matrix
        cm_df["row_class"] = cm_df["metric_name"].str.split(":").str[1]
        cm_df["col_class"] = cm_df["metric_name"].str.split(":").str[2]

        # Pivot the DataFrame to create a form suitable for the heatmap
        cm_df = cm_df.pivot(index="row_class", columns="col_class", values="value")

        # Convert the values in cm_df to integers
        cm_df = cm_df.astype(int)

        return metrics_df, cm_df

    def shapley_values(self, capture_uuid: str = "auto_inference") -> Union[list[pd.DataFrame], pd.DataFrame, None]:
        """Retrieve the Shapely values for this model

        Args:
            capture_uuid (str, optional): Specific capture_uuid (default: "auto_inference")

        Returns:
            pd.DataFrame: Dataframe(s) of the shapley values or None if not found

        Notes:
            This may or may not exist based on whether an Endpoint ran Shapley
        """

        # Sanity check the inference path (which may or may not exist)
        if self.endpoint_inference_path is None:
            return None

        # Construct the S3 path for the Shapley values
        shapley_s3_path = f"{self.endpoint_inference_path}/{capture_uuid}"

        # Multiple CSV if classifier
        if self.model_type == ModelType.CLASSIFIER:
            # CSVs for shap values are indexed by prediction class
            # Because we don't know how many classes there are, we need to search through
            # a list of S3 objects in the parent folder
            s3_paths = wr.s3.list_objects(shapley_s3_path)
            return [pull_s3_data(f) for f in s3_paths if "inference_shap_values" in f]

        # One CSV if regressor
        if self.model_type in [ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]:
            s3_path = f"{shapley_s3_path}/inference_shap_values.csv"
            return pull_s3_data(s3_path)
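
Putting a few of the methods above together, a minimal end-to-end sketch (the model name is a hypothetical placeholder):

model = ModelCore("abalone-regression")       # hypothetical model name
if model.exists():
    print(model.health_check())               # e.g. ['no_endpoint']
    print(model.list_inference_runs())        # e.g. ['auto_inference', 'model_training']
    metrics_df = model.get_inference_metrics()  # pd.DataFrame or None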

__init__(model_uuid, model_type=None, **kwargs)

ModelCore Initialization

Parameters:

    model_uuid (str): Name of Model in Workbench.
    model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
    **kwargs: Additional keyword arguments

Source code in src/workbench/core/artifacts/model_core.py
def __init__(self, model_uuid: str, model_type: ModelType = None, **kwargs):
    """ModelCore Initialization
    Args:
        model_uuid (str): Name of Model in Workbench.
        model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
        **kwargs: Additional keyword arguments
    """

    # Make sure the model name is valid
    self.is_name_valid(model_uuid, delimiter="-", lower_case=False)

    # Call SuperClass Initialization
    super().__init__(model_uuid, **kwargs)

    # Initialize our class attributes
    self.latest_model = None
    self.model_type = ModelType.UNKNOWN
    self.model_training_path = None
    self.endpoint_inference_path = None

    # Grab a Cloud Platform Meta object and pull information for this Model
    self.model_name = model_uuid
    self.model_meta = self.meta.model(self.model_name)
    if self.model_meta is None:
        self.log.warning(f"Could not find model {self.model_name} within current visibility scope")
        return
    else:
        # Is this a model package group without any models?
        if len(self.model_meta["ModelPackageList"]) == 0:
            self.log.warning(f"Model Group {self.model_name} has no Model Packages!")
            self.latest_model = None
            self.add_health_tag("model_not_found")
            return
        try:
            self.latest_model = self.model_meta["ModelPackageList"][0]
            self.description = self.latest_model.get("ModelPackageDescription", "-")
            self.training_job_name = self._extract_training_job_name()
            if model_type:
                self._set_model_type(model_type)
            else:
                self.model_type = self._get_model_type()
        except (IndexError, KeyError):
            self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
            return

    # Set the Model Training S3 Path
    self.model_training_path = self.models_s3_path + "/training/" + self.model_name

    # Get our Endpoint Inference Path (might be None)
    self.endpoint_inference_path = self.get_endpoint_inference_path()

    # Call SuperClass Post Initialization
    super().__post_init__()

    # All done
    self.log.info(f"Model Initialized: {self.model_name}")

arn()

AWS ARN (Amazon Resource Name) for the Model Package Group

Source code in src/workbench/core/artifacts/model_core.py
def arn(self) -> str:
    """AWS ARN (Amazon Resource Name) for the Model Package Group"""
    return self.group_arn()

aws_meta()

Get ALL the AWS metadata for this artifact

Source code in src/workbench/core/artifacts/model_core.py
def aws_meta(self) -> dict:
    """Get ALL the AWS metadata for this artifact"""
    return self.model_meta

aws_url()

The AWS URL for looking at/querying this model

Source code in src/workbench/core/artifacts/model_core.py
def aws_url(self):
    """The AWS URL for looking at/querying this model"""
    return f"https://{self.aws_region}.console.aws.amazon.com/athena/home"

class_labels()

Return the class labels for this Model (if it's a classifier)

Returns:

    Union[list[str], None]: List of class labels (None if not a classifier)

Source code in src/workbench/core/artifacts/model_core.py
def class_labels(self) -> Union[list[str], None]:
    """Return the class labels for this Model (if it's a classifier)

    Returns:
        list[str]: List of class labels
    """
    if self.model_type == ModelType.CLASSIFIER:
        return self.workbench_meta().get("class_labels")  # Returns None if not found
    else:
        return None
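
A minimal sketch pairing class_labels() with set_class_labels() (the model name and labels are hypothetical; this only applies to classifiers):

model = ModelCore("wine-classification")  # hypothetical classifier
model.set_class_labels(["low", "medium", "high"])
print(model.class_labels())  # ['low', 'medium', 'high']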

confusion_matrix(capture_uuid='latest')

Retrieve the confusion_matrix for this model

Parameters:

    capture_uuid (str, optional): Specific capture_uuid or "model_training" (default: "latest")

Returns:

    pd.DataFrame: DataFrame of the Confusion Matrix (might be None)

Source code in src/workbench/core/artifacts/model_core.py
def confusion_matrix(self, capture_uuid: str = "latest") -> Union[pd.DataFrame, None]:
    """Retrieve the confusion_matrix for this model

    Args:
        capture_uuid (str, optional): Specific capture_uuid or "model_training" (default: "latest")
    Returns:
        pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
    """

    # Sanity check the workbench metadata
    if self.workbench_meta() is None:
        error_msg = f"Model {self.model_name} has no workbench_meta(). Either onboard() or delete this model!"
        self.log.critical(error_msg)
        raise ValueError(error_msg)

    # Grab the metrics from the Workbench Metadata (try inference first, then training)
    if capture_uuid == "latest":
        cm = self.confusion_matrix("auto_inference")
        return cm if cm is not None else self.confusion_matrix("model_training")

    # Grab the confusion matrix captured during model training (could return None)
    if capture_uuid == "model_training":
        cm = self.workbench_meta().get("workbench_training_cm")
        return pd.DataFrame.from_dict(cm) if cm else None

    else:  # Specific capture_uuid
        s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_cm.csv"
        cm = pull_s3_data(s3_path, embedded_index=True)
        if cm is not None:
            return cm
        else:
            self.log.warning(f"Confusion Matrix {capture_uuid} not found for {self.model_name}!")
            return None
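
A minimal sketch (the model name is hypothetical; the default "latest" tries "auto_inference" first, then "model_training"):

model = ModelCore("wine-classification")  # hypothetical classifier
cm_df = model.confusion_matrix()
if cm_df is not None:
    print(cm_df)  # rows and columns indexed by class label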

container_image()

Container Image for the Latest Model Package

Source code in src/workbench/core/artifacts/model_core.py
def container_image(self) -> str:
    """Container Image for the Latest Model Package"""
    return self.container_info()["Image"]

container_info()

Container Info for the Latest Model Package

Source code in src/workbench/core/artifacts/model_core.py
def container_info(self) -> Union[dict, None]:
    """Container Info for the Latest Model Package"""
    return self.latest_model["InferenceSpecification"]["Containers"][0] if self.latest_model else None

created()

Return the datetime when this artifact was created

Source code in src/workbench/core/artifacts/model_core.py
def created(self) -> datetime:
    """Return the datetime when this artifact was created"""
    if self.latest_model is None:
        return "-"
    return self.latest_model["CreationTime"]

delete()

Delete the Model Packages and the Model Group

Source code in src/workbench/core/artifacts/model_core.py
def delete(self):
    """Delete the Model Packages and the Model Group"""
    if not self.exists():
        self.log.warning(f"Trying to delete an Model that doesn't exist: {self.uuid}")

    # Call the Class Method to delete the Model Group
    ModelCore.managed_delete(model_group_name=self.uuid)
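
A short sketch of both deletion paths (the model group name is hypothetical); managed_delete() simply logs and returns if the group doesn't exist:

ModelCore("old-model").delete()                          # instance method
ModelCore.managed_delete(model_group_name="old-model")   # class method, no instance needed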

delete_inference_run(inference_run_uuid)

Delete the inference run for this model

Parameters:

    inference_run_uuid (str): UUID of the inference run (required)
Source code in src/workbench/core/artifacts/model_core.py
def delete_inference_run(self, inference_run_uuid: str):
    """Delete the inference run for this model

    Args:
        inference_run_uuid (str): UUID of the inference run
    """
    if inference_run_uuid == "model_training":
        self.log.warning("Cannot delete model training data!")
        return

    if self.endpoint_inference_path:
        full_path = f"{self.endpoint_inference_path}/{inference_run_uuid}"
        # Check if there are any objects at the path
        if wr.s3.list_objects(full_path):
            wr.s3.delete_objects(path=full_path)
            self.log.important(f"Deleted inference run {inference_run_uuid} for {self.model_name}")
        else:
            self.log.warning(f"Inference run {inference_run_uuid} not found for {self.model_name}!")
    else:
        self.log.warning(f"No inference data found for {self.model_name}!")

details(recompute=False)

Additional Details about this Model

Parameters:

    recompute (bool, optional): Recompute the details (default: False)

Returns:

    dict: Dictionary of details about this Model

Source code in src/workbench/core/artifacts/model_core.py
def details(self, recompute=False) -> dict:
    """Additional Details about this Model
    Args:
        recompute (bool, optional): Recompute the details (default: False)
    Returns:
        dict: Dictionary of details about this Model
    """
    self.log.info("Computing Model Details...")
    details = self.summary()
    details["pipeline"] = self.get_pipeline()
    details["model_type"] = self.model_type.value
    details["model_package_group_arn"] = self.group_arn()
    details["model_package_arn"] = self.model_package_arn()

    # Sanity check that we have models in the group
    if self.latest_model is None:
        self.log.warning(f"Model Package Group {self.model_name} has no models!")
        return details

    # Grab the Model Details
    details["description"] = self.latest_model.get("ModelPackageDescription", "-")
    details["version"] = self.latest_model["ModelPackageVersion"]
    details["status"] = self.latest_model["ModelPackageStatus"]
    details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
    details["image"] = self.container_image().split("/")[-1]  # Shorten the image uri

    # Grab the inference and container info
    inference_spec = self.latest_model["InferenceSpecification"]
    container_info = self.container_info()
    details["framework"] = container_info.get("Framework", "unknown")
    details["framework_version"] = container_info.get("FrameworkVersion", "unknown")
    details["inference_types"] = inference_spec["SupportedRealtimeInferenceInstanceTypes"]
    details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
    details["content_types"] = inference_spec["SupportedContentTypes"]
    details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
    details["model_metrics"] = self.get_inference_metrics()
    if self.model_type == ModelType.CLASSIFIER:
        details["confusion_matrix"] = self.confusion_matrix()
        details["predictions"] = None
    elif self.model_type in [ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]:
        details["confusion_matrix"] = None
        details["predictions"] = self.get_inference_predictions()
    else:
        details["confusion_matrix"] = None
        details["predictions"] = None

    # Grab the inference metadata
    details["inference_meta"] = self.get_inference_metadata()

    # Return the details
    return details
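
Example: a short sketch that pulls a few of the keys populated above (the model name is hypothetical; which keys are filled in depends on the model type):

from workbench.core.artifacts.model_core import ModelCore

details = ModelCore("my-model").details()
print(details["version"], details["status"], details["approval_status"])
print(details["framework"], details["framework_version"])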

endpoints()

Get the list of registered endpoints for this Model

Returns:

    list[str]: List of registered endpoints

Source code in src/workbench/core/artifacts/model_core.py
def endpoints(self) -> list[str]:
    """Get the list of registered endpoints for this Model

    Returns:
        list[str]: List of registered endpoints
    """
    return self.workbench_meta().get("workbench_registered_endpoints", [])

exists()

Does the model metadata exist in the AWS Metadata?

Source code in src/workbench/core/artifacts/model_core.py
def exists(self) -> bool:
    """Does the model metadata exist in the AWS Metadata?"""
    if self.model_meta is None:
        self.log.info(f"Model {self.model_name} not found in AWS Metadata!")
        return False
    return True

expected_meta()

Metadata we expect to see for this Model when it's ready

Returns:

    list[str]: List of expected metadata keys

Source code in src/workbench/core/artifacts/model_core.py
def expected_meta(self) -> list[str]:
    """Metadata we expect to see for this Model when it's ready
    Returns:
        list[str]: List of expected metadata keys
    """
    # Our current list of expected metadata, we can add to this as needed
    return ["workbench_status", "workbench_training_metrics", "workbench_training_cm"]

features()

Return a list of features used for this Model

Returns:

    Union[list[str], None]: List of features used for this Model (None if not found)

Source code in src/workbench/core/artifacts/model_core.py
def features(self) -> Union[list[str], None]:
    """Return a list of features used for this Model

    Returns:
        list[str]: List of features used for this Model
    """
    return self.workbench_meta().get("workbench_model_features")  # Returns None if not found

get_endpoint_inference_path()

Get the S3 Path for the Inference Data

Returns:

    Union[str, None]: S3 Path for the Inference Data (or None if not found)

Source code in src/workbench/core/artifacts/model_core.py
def get_endpoint_inference_path(self) -> Union[str, None]:
    """Get the S3 Path for the Inference Data

    Returns:
        str: S3 Path for the Inference Data (or None if not found)
    """

    # Look for any Registered Endpoints
    registered_endpoints = self.workbench_meta().get("workbench_registered_endpoints")

    # Note: We may have 0 to N endpoints, so we find the one with the most recent artifacts
    if registered_endpoints:
        endpoint_inference_base = self.endpoints_s3_path + "/inference/"
        endpoint_inference_paths = [endpoint_inference_base + e for e in registered_endpoints]
        inference_path = newest_path(endpoint_inference_paths, self.sm_session)
        if inference_path is None:
            self.log.important(f"No inference data found for {self.model_name}!")
            self.log.important(f"Returning default inference path for {registered_endpoints[0]}...")
            self.log.important(f"{endpoint_inference_paths[0]}")
            return endpoint_inference_paths[0]
        else:
            return inference_path
    else:
        self.log.warning(f"No registered endpoints found for {self.model_name}!")
        return None

get_inference_metadata(capture_uuid='auto_inference')

Retrieve the inference metadata for this model

Parameters:

    capture_uuid (str): A specific capture_uuid (default: "auto_inference")

Returns:

    Union[pd.DataFrame, None]: DataFrame of the inference metadata (might be None)

Notes: Captured when endpoint inference was run: the dataset name, its MD5 hash, etc.

Source code in src/workbench/core/artifacts/model_core.py
def get_inference_metadata(self, capture_uuid: str = "auto_inference") -> Union[pd.DataFrame, None]:
    """Retrieve the inference metadata for this model

    Args:
        capture_uuid (str, optional): A specific capture_uuid (default: "auto_inference")

    Returns:
        pd.DataFrame: DataFrame of the inference metadata (might be None)
    Notes:
        Captured when endpoint inference was run: the dataset name, its MD5 hash, etc.
    """
    # Sanity check the inference path (which may or may not exist)
    if self.endpoint_inference_path is None:
        return None

    # Check for model_training capture_uuid
    if capture_uuid == "model_training":
        # Create a DataFrame with the training metadata
        meta_df = pd.DataFrame(
            [
                {
                    "name": "AWS Training Capture",
                    "data_hash": "N/A",
                    "num_rows": "-",
                    "description": "-",
                }
            ]
        )
        return meta_df

    # Pull the inference metadata
    try:
        s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_meta.json"
        return wr.s3.read_json(s3_path)
    except NoFilesFound:
        self.log.info(f"Could not find model inference meta at {s3_path}...")
        return None

get_inference_metrics(capture_uuid='latest')

Retrieve the inference performance metrics for this model

Parameters:

    capture_uuid (str): Specific capture_uuid or "model_training" (default: "latest")

Returns:

    Union[pd.DataFrame, None]: DataFrame of the Model Metrics (might be None)

Note:

    If a capture_uuid isn't specified, this will try to return something reasonable

Source code in src/workbench/core/artifacts/model_core.py
def get_inference_metrics(self, capture_uuid: str = "latest") -> Union[pd.DataFrame, None]:
    """Retrieve the inference performance metrics for this model

    Args:
        capture_uuid (str, optional): Specific capture_uuid or "model_training" (default: "latest")
    Returns:
        pd.DataFrame: DataFrame of the Model Metrics

    Note:
        If a capture_uuid isn't specified this will try to return something reasonable
    """
    # Try the auto_inference metrics first, then fall back to the model training metrics
    if capture_uuid == "latest":
        metrics_df = self.get_inference_metrics("auto_inference")
        return metrics_df if metrics_df is not None else self.get_inference_metrics("model_training")

    # Grab the metrics captured during model training (could return None)
    if capture_uuid == "model_training":
        # Sanity check the workbench metadata
        if self.workbench_meta() is None:
            error_msg = f"Model {self.model_name} has no workbench_meta(). Either onboard() or delete this model!"
            self.log.critical(error_msg)
            raise ValueError(error_msg)

        metrics = self.workbench_meta().get("workbench_training_metrics")
        return pd.DataFrame.from_dict(metrics) if metrics else None

    else:  # Specific capture_uuid (could return None)
        s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_metrics.csv"
        metrics = pull_s3_data(s3_path, embedded_index=True)
        if metrics is not None:
            return metrics
        else:
            self.log.warning(f"Performance metrics {capture_uuid} not found for {self.model_name}!")
            return None
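
Example: a minimal sketch of the fallback behavior described above ("latest" tries auto_inference, then model_training; the model name is hypothetical):

from workbench.core.artifacts.model_core import ModelCore

my_model = ModelCore("my-model")
metrics_df = my_model.get_inference_metrics()  # capture_uuid defaults to "latest"
if metrics_df is not None:
    print(metrics_df)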

get_inference_predictions(capture_uuid='auto_inference')

Retrieve the captured prediction results for this model

Parameters:

    capture_uuid (str): Specific capture_uuid (default: "auto_inference")

Returns:

    Union[pd.DataFrame, None]: DataFrame of the Captured Predictions (might be None)

Source code in src/workbench/core/artifacts/model_core.py
def get_inference_predictions(self, capture_uuid: str = "auto_inference") -> Union[pd.DataFrame, None]:
    """Retrieve the captured prediction results for this model

    Args:
        capture_uuid (str, optional): Specific capture_uuid (default: "auto_inference")

    Returns:
        pd.DataFrame: DataFrame of the Captured Predictions (might be None)
    """
    self.log.important(f"Grabbing {capture_uuid} predictions for {self.model_name}...")

    # Sanity check that the model should have predictions
    has_predictions = self.model_type in [ModelType.CLASSIFIER, ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]
    if not has_predictions:
        self.log.warning(f"No Predictions for {self.model_name}...")
        return None

    # Special case for model_training
    if capture_uuid == "model_training":
        return self._get_validation_predictions()

    # Construct the S3 path for the Inference Predictions
    s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_predictions.csv"
    return pull_s3_data(s3_path)
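
Example: a hedged sketch (the model name is hypothetical; returns None for model types without predictions):

from workbench.core.artifacts.model_core import ModelCore

preds = ModelCore("my-model").get_inference_predictions()
if preds is not None:
    print(preds.head())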

get_model_data_url()

Retrieve the ModelDataUrl from the model's AWS metadata.

Returns:

    Optional[str]: The ModelDataUrl if available, otherwise None.

Source code in src/workbench/core/artifacts/model_core.py
def get_model_data_url(self) -> Optional[str]:
    """Retrieve the ModelDataUrl from the model's AWS metadata.

    Returns:
        Optional[str]: The ModelDataUrl if available, otherwise None.
    """
    meta = self.aws_meta()
    try:
        return meta["ModelPackageList"][0]["InferenceSpecification"]["Containers"][0]["ModelDataUrl"]
    except (KeyError, IndexError, TypeError):
        return None

get_pipeline()

Get the pipeline for this model

Source code in src/workbench/core/artifacts/model_core.py
def get_pipeline(self) -> str:
    """Get the pipeline for this model"""
    return self.workbench_meta().get("workbench_pipeline")

group_arn()

AWS ARN (Amazon Resource Name) for the Model Package Group

Source code in src/workbench/core/artifacts/model_core.py
def group_arn(self) -> Union[str, None]:
    """AWS ARN (Amazon Resource Name) for the Model Package Group"""
    return self.model_meta["ModelPackageGroupArn"] if self.model_meta else None

hash()

Return the hash for this artifact

Returns:

    Optional[str]: The hash for this artifact

Source code in src/workbench/core/artifacts/model_core.py
def hash(self) -> Optional[str]:
    """Return the hash for this artifact

    Returns:
        Optional[str]: The hash for this artifact
    """
    model_url = self.get_model_data_url()
    return compute_s3_object_hash(model_url, self.boto3_session)

health_check()

Perform a health check on this model

Returns:

    list[str]: List of health issues

Source code in src/workbench/core/artifacts/model_core.py
def health_check(self) -> list[str]:
    """Perform a health check on this model
    Returns:
        list[str]: List of health issues
    """
    # Call the base class health check
    health_issues = super().health_check()

    # Check if the model exists
    if self.latest_model is None:
        health_issues.append("model_not_found")

    # Model Type
    if self._get_model_type() == ModelType.UNKNOWN:
        health_issues.append("model_type_unknown")
    else:
        self.remove_health_tag("model_type_unknown")

    # Model Performance Metrics
    needs_metrics = self.model_type in {ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR, ModelType.CLASSIFIER}
    if needs_metrics and self.get_inference_metrics() is None:
        health_issues.append("metrics_needed")
    else:
        self.remove_health_tag("metrics_needed")

    # Endpoint
    if not self.endpoints():
        health_issues.append("no_endpoint")
    else:
        self.remove_health_tag("no_endpoint")
    return health_issues
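
Example: a minimal sketch that surfaces the issue strings produced above (the model name is hypothetical):

from workbench.core.artifacts.model_core import ModelCore

issues = ModelCore("my-model").health_check()
if issues:
    print(f"Health issues: {issues}")  # e.g. ["metrics_needed", "no_endpoint"]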

is_model_unknown()

Is the Model Type unknown?

Source code in src/workbench/core/artifacts/model_core.py
def is_model_unknown(self) -> bool:
    """Is the Model Type unknown?"""
    return self.model_type == ModelType.UNKNOWN

latest_model_object()

Return the latest AWS Sagemaker Model object for this Workbench Model

Returns:

    sagemaker.model.Model: AWS Sagemaker Model object

Source code in src/workbench/core/artifacts/model_core.py
def latest_model_object(self) -> SagemakerModel:
    """Return the latest AWS Sagemaker Model object for this Workbench Model

    Returns:
       sagemaker.model.Model: AWS Sagemaker Model object
    """
    return SagemakerModel(
        model_data=self.model_package_arn(), sagemaker_session=self.sm_session, image_uri=self.container_image()
    )

list_inference_runs()

List the inference runs for this model

Returns:

    list[str]: List of inference runs

Source code in src/workbench/core/artifacts/model_core.py
def list_inference_runs(self) -> list[str]:
    """List the inference runs for this model

    Returns:
        list[str]: List of inference runs
    """

    # Check if we have a model (if not return empty list)
    if self.latest_model is None:
        return []

    # Check if we have model training metrics in our metadata
    have_model_training = bool(self.workbench_meta().get("workbench_training_metrics"))

    # Now grab the list of directories from our inference path
    inference_runs = []
    if self.endpoint_inference_path:
        directories = wr.s3.list_directories(path=self.endpoint_inference_path + "/")
        inference_runs = [urlparse(directory).path.split("/")[-2] for directory in directories]

    # We're going to add the model training to the end of the list
    if have_model_training:
        inference_runs.append("model_training")
    return inference_runs
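
Example: a short sketch that pairs each run with its metrics (the model name is hypothetical; "model_training" appears last when training metrics exist, as noted above):

from workbench.core.artifacts.model_core import ModelCore

my_model = ModelCore("my-model")
for run in my_model.list_inference_runs():
    print(run)
    print(my_model.get_inference_metrics(run))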

managed_delete(model_group_name) classmethod

Delete the Model Packages, Model Group, and S3 Storage Objects

Parameters:

    model_group_name (str): The name of the Model Group to delete (required)
Source code in src/workbench/core/artifacts/model_core.py
@classmethod
def managed_delete(cls, model_group_name: str):
    """Delete the Model Packages, Model Group, and S3 Storage Objects

    Args:
        model_group_name (str): The name of the Model Group to delete
    """
    # Check if the model group exists in SageMaker
    try:
        cls.sm_client.describe_model_package_group(ModelPackageGroupName=model_group_name)
    except ClientError as e:
        if e.response["Error"]["Code"] in ["ValidationException", "ResourceNotFound"]:
            cls.log.info(f"Model Group {model_group_name} not found!")
            return
        else:
            raise  # Re-raise unexpected errors

    # Delete Model Packages within the Model Group
    try:
        paginator = cls.sm_client.get_paginator("list_model_packages")
        for page in paginator.paginate(ModelPackageGroupName=model_group_name):
            for model_package in page["ModelPackageSummaryList"]:
                package_arn = model_package["ModelPackageArn"]
                cls.log.info(f"Deleting Model Package {package_arn}...")
                cls.sm_client.delete_model_package(ModelPackageName=package_arn)
    except ClientError as e:
        cls.log.error(f"Error while deleting model packages: {e}")
        raise

    # Delete the Model Package Group
    cls.log.info(f"Deleting Model Group {model_group_name}...")
    cls.sm_client.delete_model_package_group(ModelPackageGroupName=model_group_name)

    # Delete S3 training artifacts
    s3_delete_path = f"{cls.models_s3_path}/training/{model_group_name}/"
    cls.log.info(f"Deleting S3 Objects at {s3_delete_path}...")
    wr.s3.delete_objects(s3_delete_path, boto3_session=cls.boto3_session)

    # Delete any dataframes that were stored in the Dataframe Cache
    cls.log.info("Deleting Dataframe Cache...")
    cls.df_cache.delete_recursive(model_group_name)
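
Example: because this is a classmethod, a Model Group can be cleaned up without constructing a ModelCore instance first (the group name is hypothetical):

from workbench.core.artifacts.model_core import ModelCore

ModelCore.managed_delete("my-model-group")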

model_package_arn()

AWS ARN (Amazon Resource Name) for the Latest Model Package (within the Group)

Source code in src/workbench/core/artifacts/model_core.py
def model_package_arn(self) -> Union[str, None]:
    """AWS ARN (Amazon Resource Name) for the Latest Model Package (within the Group)"""
    if self.latest_model is None:
        return None
    return self.latest_model["ModelPackageArn"]

modified()

Return the datetime when this artifact was last modified

Source code in src/workbench/core/artifacts/model_core.py
def modified(self) -> datetime:
    """Return the datetime when this artifact was last modified"""
    if self.latest_model is None:
        return "-"
    return self.latest_model["CreationTime"]

onboard(ask_everything=False)

This is an interactive method that will onboard the Model (make it ready)

Parameters:

    ask_everything (bool): Ask for all the details (default: False)

Returns:

    bool: True if the Model is successfully onboarded, False otherwise

Source code in src/workbench/core/artifacts/model_core.py
def onboard(self, ask_everything=False) -> bool:
    """This is an interactive method that will onboard the Model (make it ready)

    Args:
        ask_everything (bool, optional): Ask for all the details. Defaults to False.

    Returns:
        bool: True if the Model is successfully onboarded, False otherwise
    """
    # Set the status to onboarding
    self.set_status("onboarding")

    # Determine the Model Type
    while self.is_model_unknown():
        self._determine_model_type()

    # Is our input data set?
    if self.get_input() in ["", "unknown"] or ask_everything:
        input_data = input("Input Data?: ")
        if input_data not in ["None", "none", "", "unknown"]:
            self.set_input(input_data)

    # Determine the Target Column (can be None)
    target_column = self.target()
    if target_column is None or ask_everything:
        target_column = input("Target Column? (for unsupervised/transformer just type None): ")
        if target_column in ["None", "none", ""]:
            target_column = None

    # Determine the Feature Columns
    feature_columns = self.features()
    if feature_columns is None or ask_everything:
        feature_columns = input("Feature Columns? (use commas): ")
        feature_columns = [e.strip() for e in feature_columns.split(",")]
        if feature_columns in [["None"], ["none"], [""]]:
            feature_columns = None

    # Registered Endpoints?
    endpoints = self.endpoints()
    if not endpoints or ask_everything:
        endpoints = input("Register Endpoints? (use commas for multiple): ")
        endpoints = [e.strip() for e in endpoints.split(",")]
        if endpoints in [["None"], ["none"], [""]]:
            endpoints = None

    # Model Owner?
    owner = self.get_owner()
    if owner in [None, "unknown"] or ask_everything:
        owner = input("Model Owner: ")
        if owner in ["None", "none", ""]:
            owner = "unknown"

    # Model Class Labels (if it's a classifier)
    if self.model_type == ModelType.CLASSIFIER:
        class_labels = self.class_labels()
        if class_labels is None or ask_everything:
            class_labels = input("Class Labels? (use commas): ")
            class_labels = [e.strip() for e in class_labels.split(",")]
            if class_labels in [["None"], ["none"], [""]]:
                class_labels = None
        self.set_class_labels(class_labels)

    # Now that we have all the details, let's onboard the Model with all the args
    return self.onboard_with_args(self.model_type, target_column, feature_columns, endpoints, owner)

onboard_with_args(model_type, target_column=None, feature_list=None, endpoints=None, owner=None)

Onboard the Model with the given arguments

Parameters:

    model_type (ModelType): Model Type (required)
    target_column (str): Target Column (default: None)
    feature_list (list): List of Feature Columns (default: None)
    endpoints (list): List of Endpoints (default: None)
    owner (str): Model Owner (default: None)

Returns:

    bool: True if the Model is successfully onboarded, False otherwise

Source code in src/workbench/core/artifacts/model_core.py
def onboard_with_args(
    self,
    model_type: ModelType,
    target_column: str = None,
    feature_list: list = None,
    endpoints: list = None,
    owner: str = None,
) -> bool:
    """Onboard the Model with the given arguments

    Args:
        model_type (ModelType): Model Type
        target_column (str): Target Column
        feature_list (list): List of Feature Columns
        endpoints (list, optional): List of Endpoints. Defaults to None.
        owner (str, optional): Model Owner. Defaults to None.
    Returns:
        bool: True if the Model is successfully onboarded, False otherwise
    """
    # Set the status to onboarding
    self.set_status("onboarding")

    # Set All the Details
    self._set_model_type(model_type)
    if target_column:
        self.set_target(target_column)
    if feature_list:
        self.set_features(feature_list)
    if endpoints:
        for endpoint in endpoints:
            self.register_endpoint(endpoint)
    if owner:
        self.set_owner(owner)

    # Load the training metrics and inference metrics
    self._load_training_metrics()
    self._load_inference_metrics()

    # Remove the needs_onboard tag
    self.remove_health_tag("needs_onboard")
    self.set_status("ready")

    # Run a health check and refresh the meta
    time.sleep(2)  # Give the AWS Metadata a chance to update
    self.health_check()
    self.refresh_meta()
    self.details(recompute=True)
    return True
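
Example: a hedged, non-interactive sketch (all names are hypothetical placeholders for your own columns and endpoints):

from workbench.core.artifacts.model_core import ModelCore, ModelType

my_model = ModelCore("my-model")
my_model.onboard_with_args(
    model_type=ModelType.REGRESSOR,
    target_column="target",
    feature_list=["feature_1", "feature_2"],
    endpoints=["my-model-end"],
    owner="data-science",
)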

refresh_meta()

Refresh the Artifact's metadata

Source code in src/workbench/core/artifacts/model_core.py
def refresh_meta(self):
    """Refresh the Artifact's metadata"""
    self.model_meta = self.meta.model(self.model_name)
    self.latest_model = self.model_meta["ModelPackageList"][0]
    self.description = self.latest_model.get("ModelPackageDescription", "-")
    self.training_job_name = self._extract_training_job_name()

register_endpoint(endpoint_name)

Add this endpoint to the set of registered endpoints for the model

Parameters:

    endpoint_name (str): Name of the endpoint (required)
Source code in src/workbench/core/artifacts/model_core.py
def register_endpoint(self, endpoint_name: str):
    """Add this endpoint to the set of registered endpoints for the model

    Args:
        endpoint_name (str): Name of the endpoint
    """
    self.log.important(f"Registering Endpoint {endpoint_name} with Model {self.uuid}...")
    registered_endpoints = set(self.workbench_meta().get("workbench_registered_endpoints", []))
    registered_endpoints.add(endpoint_name)
    self.upsert_workbench_meta({"workbench_registered_endpoints": list(registered_endpoints)})

    # Remove any health tags
    self.remove_health_tag("no_endpoint")

    # A new endpoint means we need to refresh our inference path
    time.sleep(2)  # Give the AWS Metadata a chance to update
    self.endpoint_inference_path = self.get_endpoint_inference_path()
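
Example: a minimal sketch (both names are hypothetical; registration is set-based, so repeated calls are safe):

from workbench.core.artifacts.model_core import ModelCore

my_model = ModelCore("my-model")
my_model.register_endpoint("my-model-end")
print(my_model.endpoints())  # ["my-model-end", ...]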

remove_endpoint(endpoint_name)

Remove this endpoint from the set of registered endpoints for the model

Parameters:

    endpoint_name (str): Name of the endpoint (required)
Source code in src/workbench/core/artifacts/model_core.py
def remove_endpoint(self, endpoint_name: str):
    """Remove this endpoint from the set of registered endpoints for the model

    Args:
        endpoint_name (str): Name of the endpoint
    """
    self.log.important(f"Removing Endpoint {endpoint_name} from Model {self.uuid}...")
    registered_endpoints = set(self.workbench_meta().get("workbench_registered_endpoints", []))
    registered_endpoints.discard(endpoint_name)
    self.upsert_workbench_meta({"workbench_registered_endpoints": list(registered_endpoints)})

    # If we have NO endpoints, then set a health tag
    if not registered_endpoints:
        self.add_health_tag("no_endpoint")
        self.details(recompute=True)

    # An endpoint change means the AWS metadata needs a moment to update
    time.sleep(2)

set_class_labels(labels)

Set the class labels for this Model (if it's a classifier)

Parameters:

    labels (list[str]): List of class labels (required)
Source code in src/workbench/core/artifacts/model_core.py
def set_class_labels(self, labels: list[str]):
    """Return the class labels for this Model (if it's a classifier)

    Args:
        labels (list[str]): List of class labels
    """
    if self.model_type == ModelType.CLASSIFIER:
        self.upsert_workbench_meta({"class_labels": labels})
    else:
        self.log.error(f"Model {self.model_name} is not a classifier!")

set_features(feature_columns)

Set the features for this Model

Parameters:

    feature_columns (list[str]): List of feature columns (required)
Source code in src/workbench/core/artifacts/model_core.py
def set_features(self, feature_columns: list[str]):
    """Set the features for this Model

    Args:
        feature_columns (list[str]): List of feature columns
    """
    self.upsert_workbench_meta({"workbench_model_features": feature_columns})

set_input(input, force=False)

Override: Set the input data for this artifact

Parameters:

    input (str): Name of input for this artifact (required)
    force (bool): Force the input to be set (default: False)

Note: Manual override of the input is not allowed for Models unless force=True

Source code in src/workbench/core/artifacts/model_core.py
def set_input(self, input: str, force: bool = False):
    """Override: Set the input data for this artifact

    Args:
        input (str): Name of input for this artifact
        force (bool, optional): Force the input to be set (default: False)
    Note:
        Manual override of the input is not allowed for Models unless force=True
    """
    if not force:
        self.log.warning(f"Model {self.uuid}: Does not allow manual override of the input!")
        return

    # Okay we're going to allow this to be set
    self.log.important(f"{self.uuid}: Setting input to {input}...")
    self.log.important("Be careful with this! It breaks automatic provenance of the artifact!")
    self.upsert_workbench_meta({"workbench_input": input})

set_pipeline(pipeline)

Set the pipeline for this model

Parameters:

    pipeline (str): Pipeline that was used to create this model (required)
Source code in src/workbench/core/artifacts/model_core.py
def set_pipeline(self, pipeline: str):
    """Set the pipeline for this model

    Args:
        pipeline (str): Pipeline that was used to create this model
    """
    self.upsert_workbench_meta({"workbench_pipeline": pipeline})

set_target(target_column)

Set the target for this Model

Parameters:

    target_column (str): Target column for this Model (required)
Source code in src/workbench/core/artifacts/model_core.py
def set_target(self, target_column: str):
    """Set the target for this Model

    Args:
        target_column (str): Target column for this Model
    """
    self.upsert_workbench_meta({"workbench_model_target": target_column})

shapley_values(capture_uuid='auto_inference')

Retrieve the Shapley values for this model

Parameters:

    capture_uuid (str): Specific capture_uuid (default: "auto_inference")

Returns:

    Union[list[pd.DataFrame], pd.DataFrame, None]: DataFrame(s) of the Shapley values, or None if not found

Notes

This may or may not exist based on whether an Endpoint ran Shapley

Source code in src/workbench/core/artifacts/model_core.py
def shapley_values(self, capture_uuid: str = "auto_inference") -> Union[list[pd.DataFrame], pd.DataFrame, None]:
    """Retrieve the Shapely values for this model

    Args:
        capture_uuid (str, optional): Specific capture_uuid (default: training_holdout)

    Returns:
        pd.DataFrame: Dataframe(s) of the shapley values or None if not found

    Notes:
        This may or may not exist based on whether an Endpoint ran Shapley
    """

    # Sanity check the inference path (which may or may not exist)
    if self.endpoint_inference_path is None:
        return None

    # Construct the S3 path for the Shapley values
    shapley_s3_path = f"{self.endpoint_inference_path}/{capture_uuid}"

    # Multiple CSV if classifier
    if self.model_type == ModelType.CLASSIFIER:
        # CSVs for shap values are indexed by prediction class
        # Because we don't know how many classes there are, we need to search through
        # a list of S3 objects in the parent folder
        s3_paths = wr.s3.list_objects(shapley_s3_path)
        return [pull_s3_data(f) for f in s3_paths if "inference_shap_values" in f]

    # One CSV if regressor
    if self.model_type in [ModelType.REGRESSOR, ModelType.QUANTILE_REGRESSOR]:
        s3_path = f"{shapley_s3_path}/inference_shap_values.csv"
        return pull_s3_data(s3_path)
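
Example: a sketch of handling both return shapes described above (the model name is hypothetical):

from workbench.core.artifacts.model_core import ModelCore

shap = ModelCore("my-model").shapley_values()
if isinstance(shap, list):   # Classifier: one DataFrame per prediction class
    for df in shap:
        print(df.head())
elif shap is not None:       # Regressor: a single DataFrame
    print(shap.head())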

size()

Return the size of this data in MegaBytes

Source code in src/workbench/core/artifacts/model_core.py
def size(self) -> float:
    """Return the size of this data in MegaBytes"""
    return 0.0

target()

Return the target for this Model (if supervised, else None)

Returns:

Name Type Description
str Union[str, None]

Target column for this Model (if supervised, else None)

Source code in src/workbench/core/artifacts/model_core.py
def target(self) -> Union[str, None]:
    """Return the target for this Model (if supervised, else None)

    Returns:
        str: Target column for this Model (if supervised, else None)
    """
    return self.workbench_meta().get("workbench_model_target")  # Returns None if not found

ModelType

Bases: Enum

Enumerated Types for Workbench Model Types

Source code in src/workbench/core/artifacts/model_core.py
class ModelType(Enum):
    """Enumerated Types for Workbench Model Types"""

    CLASSIFIER = "classifier"
    REGRESSOR = "regressor"
    CLUSTERER = "clusterer"
    TRANSFORMER = "transformer"
    PROJECTION = "projection"
    UNSUPERVISED = "unsupervised"
    QUANTILE_REGRESSOR = "quantile_regressor"
    DETECTOR = "detector"
    UNKNOWN = "unknown"