Skip to content

AthenaSource

API Classes

Found a method here you want to use? The API Classes have method pass-through, so just call the method on the DataSource API Class and voilà, it works the same.

AthenaSource: SageWorks Data Source accessible through Athena

AthenaSource

Bases: DataSourceAbstract

AthenaSource: SageWorks Data Source accessible through Athena

Common Usage
my_data = AthenaSource(data_uuid, database="sageworks")
my_data.summary()
my_data.details()
df = my_data.query(f"select * from {data_uuid} limit 5")
Source code in src/sageworks/core/artifacts/athena_source.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
class AthenaSource(DataSourceAbstract):
    """AthenaSource: SageWorks Data Source accessible through Athena

    Common Usage:
        ```
        my_data = AthenaSource(data_uuid, database="sageworks")
        my_data.summary()
        my_data.details()
        df = my_data.query(f"select * from {data_uuid} limit 5")
        ```
    """

    def __init__(self, data_uuid, database="sageworks", force_refresh: bool = False):
        """AthenaSource Initialization

        Args:
            data_uuid (str): Name of Athena Table
            database (str): Athena Database Name (default: sageworks)
            force_refresh (bool): Force refresh of AWS Metadata (default: False)
        """
        # Ensure the data_uuid is a valid name/id
        self.ensure_valid_name(data_uuid)

        # Call superclass init
        super().__init__(data_uuid, database)

        # Set to True whenever we write table parameters to Glue, so the next
        # sageworks_meta() call re-reads the catalog instead of using stale metadata
        self.metadata_refresh_needed = False

        # Setup our AWS Metadata Broker
        self.catalog_table_meta = self.meta_broker.data_source_details(
            data_uuid, self.get_database(), refresh=force_refresh
        )
        if self.catalog_table_meta is None:
            self.log.error(f"Unable to find {self.get_database()}:{self.table} in Glue Catalogs...")

        # Call superclass post init (runs after catalog_table_meta has been assigned)
        super().__post_init__()

        # All done
        self.log.debug(f"AthenaSource Initialized: {self.get_database()}.{self.table}")

    def refresh_meta(self):
        """Refresh our internal AWS Broker catalog metadata

        Forces a refresh of the Data Catalog metadata and re-reads this table's entry.
        If the database or table is missing from the catalog, catalog_table_meta becomes None
        (so exists() reports False) instead of raising a KeyError.
        """
        _catalog_meta = self.aws_broker.get_metadata(ServiceCategory.DATA_CATALOG, force_refresh=True)
        database_meta = (_catalog_meta or {}).get(self.get_database()) or {}
        self.catalog_table_meta = database_meta.get(self.table)
        self.metadata_refresh_needed = False

    def exists(self) -> bool:
        """Validation Checks for this Data Source

        Returns:
            bool: True if Glue catalog metadata was found for this table, False otherwise
        """
        # getattr() is used because catalog_table_meta is only assigned after super().__init__()
        # returns; presumably this guards calls to exists() made before then
        if getattr(self, "catalog_table_meta", None) is None:
            self.log.debug(f"AthenaSource {self.table} not found in SageWorks Metadata...")
            return False
        return True

    def arn(self) -> str:
        """AWS ARN (Amazon Resource Name) for this artifact"""
        # Grab our SageWorks Role Manager, get our AWS account id, and region for ARN creation
        account_id = self.aws_account_clamp.account_id
        region = self.aws_account_clamp.region
        arn = f"arn:aws:glue:{region}:{account_id}:table/{self.get_database()}/{self.table}"
        return arn

    def sageworks_meta(self) -> dict:
        """Get the SageWorks specific metadata for this Artifact

        Returns:
            dict: The SageWorks metadata, or an empty dict if no catalog metadata is available
        """
        self.log.info(f"Retrieving SageWorks Metadata for Artifact: {self.uuid}...")

        # Refresh first: a successful upsert marks our cached catalog metadata as stale,
        # and the refresh itself may leave catalog_table_meta as None
        if self.metadata_refresh_needed:
            self.refresh_meta()

        # Sanity Check if we have invalid AWS Metadata
        # (exists() is False exactly when catalog_table_meta is None)
        if self.catalog_table_meta is None:
            self.log.error(f"DataSource {self.uuid} doesn't appear to exist...")
            return {}

        # Get the SageWorks Metadata from the Catalog Table Metadata
        return sageworks_meta_from_catalog_table_meta(self.catalog_table_meta)

    def upsert_sageworks_meta(self, new_meta: dict):
        """Add SageWorks specific metadata to this Artifact

        Non-string values are JSON encoded before being stored as Glue table parameters.
        The caller's dictionary is not modified. Failures are logged, not raised.

        Args:
            new_meta (dict): Dictionary of new metadata to add
        """

        # Give a warning message for keys that don't start with sageworks_
        for key in new_meta:
            if not key.startswith("sageworks_"):
                self.log.warning(f"Append 'sageworks_' to key names to avoid overwriting AWS meta data (key: {key})")

        # Convert any non-string values to JSON strings (into a NEW dict, so we don't mutate the caller's)
        str_meta = {
            key: value if isinstance(value, str) else json.dumps(value, cls=CustomEncoder)
            for key, value in new_meta.items()
        }

        # Store our updated metadata
        try:
            self._upsert_table_parameters(str_meta)
        except botocore.exceptions.ClientError as e:
            error_code = e.response["Error"]["Code"]
            if error_code == "InvalidInputException":
                self.log.error(f"Unable to upsert metadata for {self.table}")
                self.log.error("Probably because the metadata is too large")
                self.log.error(str_meta)
            elif error_code == "ConcurrentModificationException":
                self.log.warning("ConcurrentModificationException... trying again...")
                time.sleep(5)
                # A single retry; a second failure is logged like every other failure path
                try:
                    self._upsert_table_parameters(str_meta)
                except Exception as retry_error:
                    self.log.critical(f"Failed to upsert metadata: {retry_error}")
                    self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")
            else:
                self.log.critical(f"Failed to upsert metadata: {e}")
                self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")
        except Exception as e:
            self.log.critical(f"Failed to upsert metadata: {e}")
            self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")

    def _upsert_table_parameters(self, parameters: dict):
        """Write string parameters to this Glue table and mark our cached metadata as stale

        Args:
            parameters (dict): Parameter names mapped to string values
        """
        wr.catalog.upsert_table_parameters(
            parameters=parameters,
            database=self.get_database(),
            table=self.table,
            boto3_session=self.boto3_session,
        )
        # Only reached on success (including a successful retry)
        self.metadata_refresh_needed = True

    def size(self) -> float:
        """Return the size of this data in MegaBytes"""
        size_in_bytes = sum(wr.s3.size_objects(self.s3_storage_location(), boto3_session=self.boto3_session).values())
        size_in_mb = size_in_bytes / 1_000_000
        return size_in_mb

    def aws_meta(self) -> dict:
        """Get the FULL AWS metadata for this artifact"""
        return self.catalog_table_meta

    def aws_url(self):
        """The AWS URL for looking at/querying this data source"""
        sageworks_details = self.sageworks_meta().get("sageworks_details", {})
        return sageworks_details.get("aws_url", "unknown")

    def created(self) -> datetime:
        """Return the datetime when this artifact was created"""
        return self.catalog_table_meta["CreateTime"]

    def modified(self) -> datetime:
        """Return the datetime when this artifact was last modified"""
        return self.catalog_table_meta["UpdateTime"]

    def num_rows(self) -> int:
        """Return the number of rows for this Data Source (0 if the count query fails)"""
        count_df = self.query(f'select count(*) AS sageworks_count from "{self.get_database()}"."{self.table}"')
        return count_df["sageworks_count"][0] if count_df is not None else 0

    def num_columns(self) -> int:
        """Return the number of columns for this Data Source"""
        return len(self.columns)

    @property
    def columns(self) -> list[str]:
        """Return the column names for this Athena Table"""
        return [item["Name"] for item in self.catalog_table_meta["StorageDescriptor"]["Columns"]]

    @property
    def column_types(self) -> list[str]:
        """Return the column types of the internal AthenaSource"""
        return [item["Type"] for item in self.catalog_table_meta["StorageDescriptor"]["Columns"]]

    def query(self, query: str) -> Union[pd.DataFrame, None]:
        """Query the AthenaSource

        Args:
            query (str): The query to run against the AthenaSource

        Returns:
            pd.DataFrame: The results of the query (None if the query failed)
        """
        self.log.debug(f"Executing Query: {query}...")
        try:
            df = wr.athena.read_sql_query(
                sql=query,
                database=self.get_database(),
                ctas_approach=False,
                boto3_session=self.boto3_session,
            )
            scanned_bytes = df.query_metadata["Statistics"]["DataScannedInBytes"]
            if scanned_bytes > 0:
                self.log.debug(f"Athena Query successful (scanned bytes: {scanned_bytes})")
            return df
        except wr.exceptions.QueryFailed as e:
            self.log.critical(f"Failed to execute query: {e}")
            return None

    def execute_statement(self, query: str, silence_errors: bool = False):
        """Execute a non-returning SQL statement in Athena

        Args:
            query (str): The query to run against the AthenaSource
            silence_errors (bool): Silence errors (default: False)

        Raises:
            wr.exceptions.QueryFailed: If the statement fails (other than AlreadyExistsException)
            botocore.exceptions.ClientError: If Athena rejects the request
        """
        try:
            # Start the query execution
            query_execution_id = wr.athena.start_query_execution(
                sql=query,
                database=self.get_database(),
                boto3_session=self.boto3_session,
            )
            self.log.debug(f"QueryExecutionId: {query_execution_id}")

            # Wait for the query to complete
            wr.athena.wait_query(query_execution_id=query_execution_id, boto3_session=self.boto3_session)
            self.log.debug(f"Statement executed successfully: {query_execution_id}")
        except wr.exceptions.QueryFailed as e:
            if "AlreadyExistsException" in str(e):
                self.log.warning(f"Table already exists. Ignoring: {e}")
            else:
                if not silence_errors:
                    self.log.error(f"Failed to execute statement: {e}")
                raise
        except botocore.exceptions.ClientError as e:
            error_code = e.response["Error"]["Code"]
            if error_code == "InvalidRequestException":
                self.log.error(f"Invalid Query: {query}")
            else:
                self.log.error(f"Failed to execute statement: {e}")
            raise

    def s3_storage_location(self) -> str:
        """Get the S3 Storage Location for this Data Source"""
        return self.catalog_table_meta["StorageDescriptor"]["Location"]

    def athena_test_query(self):
        """Validate that Athena Queries are working"""
        query = f"select count(*) as sageworks_count from {self.table}"
        df = wr.athena.read_sql_query(
            sql=query,
            database=self.get_database(),
            ctas_approach=False,
            boto3_session=self.boto3_session,
        )
        scanned_bytes = df.query_metadata["Statistics"]["DataScannedInBytes"]
        self.log.info(f"Athena TEST Query successful (scanned bytes: {scanned_bytes})")

    def sample_impl(self) -> pd.DataFrame:
        """Pull a sample of rows from the DataSource

        Returns:
            pd.DataFrame: A sample DataFrame for an Athena DataSource
        """

        # Call the SQL function to pull a sample of the rows
        return sample_rows.sample_rows(self)

    def descriptive_stats(self, recompute: bool = False) -> dict[dict]:
        """Compute Descriptive Stats for all the numeric columns in a DataSource

        Args:
            recompute (bool): Recompute the descriptive stats (default: False)

        Returns:
            dict(dict): A dictionary of descriptive stats for each column in the form
                 {'col1': {'min': 0, 'q1': 1, 'median': 2, 'q3': 3, 'max': 4},
                  'col2': ...}
        """

        # First check if we have already computed the descriptive stats (an empty result counts as missing)
        stat_dict = self.sageworks_meta().get("sageworks_descriptive_stats")
        if stat_dict and not recompute:
            return stat_dict

        # Call the SQL function to compute descriptive stats
        stat_dict = descriptive_stats.descriptive_stats(self)

        # Push the descriptive stat data into our DataSource Metadata
        self.upsert_sageworks_meta({"sageworks_descriptive_stats": stat_dict})

        # Return the descriptive stats
        return stat_dict

    def outliers_impl(self, scale: float = 1.5, use_stddev=False) -> pd.DataFrame:
        """Compute outliers for all the numeric columns in a DataSource

        Args:
            scale (float): The scale to use for the IQR (default: 1.5)
            use_stddev (bool): Use Standard Deviation instead of IQR (default: False)

        Returns:
            pd.DataFrame: A DataFrame of outliers from this DataSource

        Notes:
            Uses the IQR * 1.5 (~= 2.5 Sigma) (use 1.7 for ~= 3 Sigma)
            The scale parameter can be adjusted to change the IQR multiplier
        """

        # Compute outliers using the SQL Outliers class
        sql_outliers = outliers.Outliers()
        return sql_outliers.compute_outliers(self, scale=scale, use_stddev=use_stddev)

    def smart_sample(self, recompute: bool = False) -> pd.DataFrame:
        """Get a smart sample dataframe for this DataSource

        Args:
            recompute (bool): Recompute the smart sample (default: False)

        Returns:
            pd.DataFrame: A combined DataFrame of sample data + outliers
        """

        # Check if we have cached smart_sample data (read the storage key only once)
        storage_key = f"data_source:{self.uuid}:smart_sample"
        if not recompute:
            cached_json = self.data_storage.get(storage_key)
            if cached_json:
                return pd.read_json(StringIO(cached_json))

        # Compute/recompute the smart sample
        self.log.important(f"Computing Smart Sample {self.uuid}...")

        # Outliers DataFrame
        outlier_df = self.outliers(recompute=recompute)

        # Sample DataFrame: copy() so tagging the rows doesn't modify the object sample() returned
        # (named sample_df so the module-level sample_rows isn't shadowed)
        sample_df = self.sample(recompute=recompute).copy()
        sample_df["outlier_group"] = "sample"

        # Combine the sample rows with the outlier rows
        all_rows = pd.concat([outlier_df, sample_df]).reset_index(drop=True)

        # Drop duplicates: outlier rows come first, so a row that is both keeps its outlier group
        all_except_outlier_group = [col for col in all_rows.columns if col != "outlier_group"]
        all_rows = all_rows.drop_duplicates(subset=all_except_outlier_group, ignore_index=True)

        # Cache the smart_sample data
        self.data_storage.set(storage_key, all_rows.to_json())

        # Return the smart_sample data
        return all_rows

    def correlations(self, recompute: bool = False) -> dict[dict]:
        """Compute Correlations for all the numeric columns in a DataSource

        Args:
            recompute (bool): Recompute the correlations (default: False)

        Returns:
            dict(dict): A dictionary of correlations for each column in this format
                 {'col1': {'col2': 0.5, 'col3': 0.9, 'col4': 0.4, ...},
                  'col2': {'col1': 0.5, 'col3': 0.8, 'col4': 0.3, ...}}
        """

        # First check if we have already computed the correlations (an empty result counts as missing)
        correlations_dict = self.sageworks_meta().get("sageworks_correlations")
        if correlations_dict and not recompute:
            return correlations_dict

        # Call the SQL function to compute correlations
        correlations_dict = correlations.correlations(self)

        # Push the correlation data into our DataSource Metadata
        self.upsert_sageworks_meta({"sageworks_correlations": correlations_dict})

        # Return the correlation data
        return correlations_dict

    def column_stats(self, recompute: bool = False) -> dict[dict]:
        """Compute Column Stats for all the columns in a DataSource

        Args:
            recompute (bool): Recompute the column stats (default: False)

        Returns:
            dict(dict): A dictionary of stats for each column in this format
            NB: String columns will NOT have num_zeros, descriptive_stats or correlation data
                {'col1': {'dtype': 'string', 'unique': 4321, 'nulls': 12},
                 'col2': {'dtype': 'int', 'unique': 4321, 'nulls': 12, 'num_zeros': 100,
                          'descriptive_stats': {...}, 'correlations': {...}},
                 ...}
        """

        # First check if we have already computed the column stats (an empty result counts as missing)
        column_stats_dict = self.sageworks_meta().get("sageworks_column_stats")
        if column_stats_dict and not recompute:
            return column_stats_dict

        # Call the SQL function to compute column stats
        column_stats_dict = column_stats.column_stats(self, recompute=recompute)

        # Push the column stats data into our DataSource Metadata
        self.upsert_sageworks_meta({"sageworks_column_stats": column_stats_dict})

        # Return the column stats data
        return column_stats_dict

    def value_counts(self, recompute: bool = False) -> dict[dict]:
        """Compute 'value_counts' for all the string columns in a DataSource

        Args:
            recompute (bool): Recompute the value counts (default: False)

        Returns:
            dict(dict): A dictionary of value counts for each column in the form
                 {'col1': {'value_1': 42, 'value_2': 16, 'value_3': 9,...},
                  'col2': ...}
        """

        # First check if we have already computed the value counts (an empty result counts as missing)
        value_counts_dict = self.sageworks_meta().get("sageworks_value_counts")
        if value_counts_dict and not recompute:
            return value_counts_dict

        # Call the SQL function to compute value_counts
        value_counts_dict = value_counts.value_counts(self)

        # Push the value_count data into our DataSource Metadata
        self.upsert_sageworks_meta({"sageworks_value_counts": value_counts_dict})

        # Return the value_count data
        return value_counts_dict

    def details(self, recompute: bool = False) -> dict[dict]:
        """Additional Details about this AthenaSource Artifact

        Args:
            recompute (bool): Recompute the details (default: False)

        Returns:
            dict(dict): A dictionary of details about this AthenaSource
        """

        # Check if we have cached version of the DataSource Details
        storage_key = f"data_source:{self.uuid}:details"
        cached_details = self.data_storage.get(storage_key)
        if cached_details and not recompute:
            return cached_details

        self.log.info(f"Recomputing DataSource Details ({self.uuid})...")

        # Get the details from the base class
        details = super().details()

        # Compute additional details
        details["s3_storage_location"] = self.s3_storage_location()
        details["storage_type"] = "athena"

        # Compute our AWS URL: start a small query so the console URL can link to its history entry
        query = f"select * from {self.get_database()}.{self.table} limit 10"
        query_exec_id = wr.athena.start_query_execution(
            sql=query, database=self.get_database(), boto3_session=self.boto3_session
        )
        base_url = "https://console.aws.amazon.com/athena/home"
        details["aws_url"] = f"{base_url}?region={self.aws_region}#query/history/{query_exec_id}"

        # Push the aws_url data into our DataSource Metadata
        self.upsert_sageworks_meta({"sageworks_details": {"aws_url": details["aws_url"]}})

        # Convert any datetime fields to ISO-8601 strings
        details = convert_all_to_iso8601(details)

        # Add the column stats
        details["column_stats"] = self.column_stats()

        # Cache the details
        self.data_storage.set(storage_key, details)

        # Return the details data
        return details

    def delete(self):
        """Delete the AWS Data Catalog Table and S3 Storage Objects"""

        # Make sure the AthenaSource exists (deletion still proceeds as best-effort cleanup)
        if not self.exists():
            self.log.warning(f"Trying to delete a AthenaSource that doesn't exist: {self.table}")

        # Delete any views associated with this AthenaSource
        self.delete_views()

        # Delete Data Catalog Table
        self.log.info(f"Deleting DataCatalog Table: {self.get_database()}.{self.table}...")
        wr.catalog.delete_table_if_exists(self.get_database(), self.table, boto3_session=self.boto3_session)

        # Delete S3 Storage Objects (if they exist)
        try:
            # Make sure we add the trailing slash so the prefix can't also match sibling paths
            s3_path = self.s3_storage_location()
            s3_path = s3_path if s3_path.endswith("/") else f"{s3_path}/"

            self.log.info(f"Deleting S3 Storage Objects: {s3_path}...")
            wr.s3.delete_objects(s3_path, boto3_session=self.boto3_session)
        except Exception as e:
            self.log.error(f"Failed to delete S3 Storage Objects: {e}")
            self.log.warning("Malformed Artifact... good thing it's being deleted...")

        # Delete any data in the Cache
        for key in self.data_storage.list_subkeys(f"data_source:{self.uuid}:"):
            self.log.info(f"Deleting Cache Key {key}...")
            self.data_storage.delete(key)

    def delete_views(self):
        """Delete any views (and their supplemental data) associated with this AthenaSource"""
        # Local import, presumably to avoid a circular import with the views package
        from sageworks.core.views.view_utils import delete_views_and_supplemental_data

        delete_views_and_supplemental_data(self)
column_types: list[str] property

Return the column types of the internal AthenaSource

columns: list[str] property

Return the column names for this Athena Table

__init__(data_uuid, database='sageworks', force_refresh=False)

AthenaSource Initialization

Parameters:

Name Type Description Default
data_uuid str

Name of Athena Table

required
database str

Athena Database Name (default: sageworks)

'sageworks'
force_refresh bool

Force refresh of AWS Metadata (default: False)

False
Source code in src/sageworks/core/artifacts/athena_source.py
def __init__(self, data_uuid, database="sageworks", force_refresh: bool = False):
    """Construct an AthenaSource bound to an existing Athena/Glue table.

    Args:
        data_uuid (str): Name of Athena Table
        database (str): Athena Database Name (default: sageworks)
        force_refresh (bool): Bypass cached AWS metadata and re-read it (default: False)
    """
    # Reject malformed names before any base-class setup happens
    self.ensure_valid_name(data_uuid)
    super().__init__(data_uuid, database)

    # Nothing has been written to the catalog yet, so no metadata refresh is pending
    self.metadata_refresh_needed = False

    # Look up this table's Glue catalog entry through the metadata broker
    table_meta = self.meta_broker.data_source_details(data_uuid, self.get_database(), refresh=force_refresh)
    self.catalog_table_meta = table_meta
    if table_meta is None:
        self.log.error(f"Unable to find {self.get_database()}:{self.table} in Glue Catalogs...")

    # Base-class post-init runs only after catalog_table_meta has been assigned
    super().__post_init__()
    self.log.debug(f"AthenaSource Initialized: {self.get_database()}.{self.table}")

arn()

AWS ARN (Amazon Resource Name) for this artifact

Source code in src/sageworks/core/artifacts/athena_source.py
def arn(self) -> str:
    """Build the Glue table ARN for this artifact from the account id, region, database and table."""
    clamp = self.aws_account_clamp
    account, aws_region = clamp.account_id, clamp.region
    return f"arn:aws:glue:{aws_region}:{account}:table/{self.get_database()}/{self.table}"

athena_test_query()

Validate that Athena Queries are working

Source code in src/sageworks/core/artifacts/athena_source.py
def athena_test_query(self):
    """Run a trivial count query to confirm Athena queries work for this table."""
    count_sql = f"select count(*) as sageworks_count from {self.table}"
    result_df = wr.athena.read_sql_query(
        sql=count_sql,
        database=self.get_database(),
        ctas_approach=False,
        boto3_session=self.boto3_session,
    )
    statistics = result_df.query_metadata["Statistics"]
    bytes_scanned = statistics["DataScannedInBytes"]
    self.log.info(f"Athena TEST Query successful (scanned bytes: {bytes_scanned})")

aws_meta()

Get the FULL AWS metadata for this artifact

Source code in src/sageworks/core/artifacts/athena_source.py
def aws_meta(self) -> dict:
    """Return the raw (complete) Glue catalog metadata held for this artifact."""
    full_meta = self.catalog_table_meta
    return full_meta

aws_url()

The AWS URL for looking at/querying this data source

Source code in src/sageworks/core/artifacts/athena_source.py
def aws_url(self):
    """Return the stored AWS console URL for this data source, or "unknown" if none is recorded."""
    meta = self.sageworks_meta()
    return meta.get("sageworks_details", {}).get("aws_url", "unknown")

column_stats(recompute=False)

Compute Column Stats for all the columns in a DataSource

Parameters:

Name Type Description Default
recompute bool

Recompute the column stats (default: False)

False

Returns:

Name Type Description
dict dict

A dictionary of stats for each column in this format. NB: String columns will NOT have num_zeros, descriptive_stats or correlation data.

{'col1': {'dtype': 'string', 'unique': 4321, 'nulls': 12}, 'col2': {'dtype': 'int', 'unique': 4321, 'nulls': 12, 'num_zeros': 100, 'descriptive_stats': {...}, 'correlations': {...}}, ...}

Source code in src/sageworks/core/artifacts/athena_source.py
def column_stats(self, recompute: bool = False) -> dict[dict]:
    """Compute Column Stats for all the columns in a DataSource

    Args:
        recompute (bool): Recompute the column stats (default: False)

    Returns:
        dict(dict): A dictionary of stats for each column in this format
        NB: String columns will NOT have num_zeros, descriptive_stats or correlation data
            {'col1': {'dtype': 'string', 'unique': 4321, 'nulls': 12},
             'col2': {'dtype': 'int', 'unique': 4321, 'nulls': 12, 'num_zeros': 100,
                      'descriptive_stats': {...}, 'correlations': {...}},
             ...}
    """

    # Reuse stats cached in the SageWorks metadata (an empty result counts as missing)
    column_stats_dict = self.sageworks_meta().get("sageworks_column_stats")
    if column_stats_dict and not recompute:
        return column_stats_dict

    # Call the SQL function to compute column stats
    column_stats_dict = column_stats.column_stats(self, recompute=recompute)

    # Push the column stats data into our DataSource Metadata so later calls can reuse it
    self.upsert_sageworks_meta({"sageworks_column_stats": column_stats_dict})

    # Return the column stats data
    return column_stats_dict

correlations(recompute=False)

Compute Correlations for all the numeric columns in a DataSource

Parameters:

Name Type Description Default
recompute bool

Recompute the correlations (default: False)

False

Returns:

Name Type Description
dict dict

A dictionary of correlations for each column in this format {'col1': {'col2': 0.5, 'col3': 0.9, 'col4': 0.4, ...}, 'col2': {'col1': 0.5, 'col3': 0.8, 'col4': 0.3, ...}}

Source code in src/sageworks/core/artifacts/athena_source.py
def correlations(self, recompute: bool = False) -> dict[dict]:
    """Compute Correlations for all the numeric columns in a DataSource

    Args:
        recompute (bool): Recompute the correlations (default: False)

    Returns:
        dict(dict): A dictionary of correlations for each column in this format
             {'col1': {'col2': 0.5, 'col3': 0.9, 'col4': 0.4, ...},
              'col2': {'col1': 0.5, 'col3': 0.8, 'col4': 0.3, ...}}
    """

    # Reuse correlations cached in the SageWorks metadata (an empty result counts as missing)
    correlations_dict = self.sageworks_meta().get("sageworks_correlations")
    if correlations_dict and not recompute:
        return correlations_dict

    # Call the SQL function to compute correlations
    correlations_dict = correlations.correlations(self)

    # Push the correlation data into our DataSource Metadata so later calls can reuse it
    self.upsert_sageworks_meta({"sageworks_correlations": correlations_dict})

    # Return the correlation data
    return correlations_dict

created()

Return the datetime when this artifact was created

Source code in src/sageworks/core/artifacts/athena_source.py
def created(self) -> datetime:
    """Creation timestamp of this artifact, taken from its Glue catalog entry."""
    table_meta = self.catalog_table_meta
    return table_meta["CreateTime"]

delete()

Delete the AWS Data Catalog Table and S3 Storage Objects

Source code in src/sageworks/core/artifacts/athena_source.py
def delete(self):
    """Remove this AthenaSource: its views, Glue catalog table, S3 objects and cache entries."""
    if not self.exists():
        # Only a warning: deletion proceeds regardless as best-effort cleanup
        self.log.warning(f"Trying to delete a AthenaSource that doesn't exist: {self.table}")

    self.delete_views()

    # Drop the Glue Data Catalog table
    database = self.get_database()
    self.log.info(f"Deleting DataCatalog Table: {database}.{self.table}...")
    wr.catalog.delete_table_if_exists(database, self.table, boto3_session=self.boto3_session)

    # Best-effort S3 cleanup: failures are logged, not raised
    try:
        s3_path = self.s3_storage_location()
        if not s3_path.endswith("/"):
            # Trailing slash so the prefix can't also match sibling paths
            s3_path += "/"
        self.log.info(f"Deleting S3 Storage Objects: {s3_path}...")
        wr.s3.delete_objects(s3_path, boto3_session=self.boto3_session)
    except Exception as e:
        self.log.error(f"Failed to delete S3 Storage Objects: {e}")
        self.log.warning("Malformed Artifact... good thing it's being deleted...")

    # Purge every cache entry stored under this data source's key prefix
    cache_prefix = f"data_source:{self.uuid}:"
    for key in self.data_storage.list_subkeys(cache_prefix):
        self.log.info(f"Deleting Cache Key {key}...")
        self.data_storage.delete(key)

delete_views()

Delete any views associated with this AthenaSource

Source code in src/sageworks/core/artifacts/athena_source.py
def delete_views(self):
    """Delete any views (and their supplemental data) associated with this DataSource"""
    # Imported at call time rather than module level; presumably to avoid a circular
    # import between artifacts and views -- TODO confirm
    from sageworks.core.views.view_utils import delete_views_and_supplemental_data as _delete_all

    _delete_all(self)

descriptive_stats(recompute=False)

Compute Descriptive Stats for all the numeric columns in a DataSource

Parameters:

Name Type Description Default
recompute bool

Recompute the descriptive stats (default: False)

False

Returns:

Name Type Description
dict dict

A dictionary of descriptive stats for each column in the form {'col1': {'min': 0, 'q1': 1, 'median': 2, 'q3': 3, 'max': 4}, 'col2': ...}

Source code in src/sageworks/core/artifacts/athena_source.py
def descriptive_stats(self, recompute: bool = False) -> dict[dict]:
    """Descriptive statistics for every numeric column of this DataSource.

    Results are stored in the SageWorks metadata under ``sageworks_descriptive_stats``
    and reused on later calls unless ``recompute`` is True or nothing usable is stored.

    Args:
        recompute (bool): Force the stats to be recomputed (default: False)

    Returns:
        dict(dict): Per-column stats in the form
             {'col1': {'min': 0, 'q1': 1, 'median': 2, 'q3': 3, 'max': 4},
              'col2': ...}
    """
    cached_stats = self.sageworks_meta().get("sageworks_descriptive_stats")
    if not recompute and cached_stats:
        return cached_stats

    # Compute fresh stats with the SQL helper module, then persist them in the metadata
    fresh_stats = descriptive_stats.descriptive_stats(self)
    self.upsert_sageworks_meta({"sageworks_descriptive_stats": fresh_stats})
    return fresh_stats

details(recompute=False)

Additional Details about this AthenaSource Artifact

Parameters:

Name Type Description Default
recompute bool

Recompute the details (default: False)

False

Returns:

Name Type Description
dict dict

A dictionary of details about this AthenaSource

Source code in src/sageworks/core/artifacts/athena_source.py
def details(self, recompute: bool = False) -> dict[dict]:
    """Additional details about this AthenaSource Artifact.

    Details are cached in data storage under ``data_source:<uuid>:details`` and
    returned from there unless ``recompute`` is True or nothing is cached.

    Args:
        recompute (bool): Recompute the details (default: False)

    Returns:
        dict(dict): Base-class details plus the S3 location, storage type,
            an Athena console URL, and per-column stats
    """
    cache_key = f"data_source:{self.uuid}:details"
    cached = self.data_storage.get(cache_key)
    if not recompute and cached:
        return cached

    self.log.info(f"Recomputing DataSource Details ({self.uuid})...")

    # Start from the base-class details, then add the Athena-specific fields
    info = super().details()
    info["s3_storage_location"] = self.s3_storage_location()
    info["storage_type"] = "athena"

    # Launch a small sample query so the console link points at a real query execution
    sample_sql = f"select * from {self.get_database()}.{self.table} limit 10"
    execution_id = wr.athena.start_query_execution(
        sql=sample_sql, database=self.get_database(), boto3_session=self.boto3_session
    )
    console = "https://console.aws.amazon.com/athena/home"
    aws_url = f"{console}?region={self.aws_region}#query/history/{execution_id}"
    info["aws_url"] = aws_url

    # Persist the console URL in the SageWorks metadata as well
    self.upsert_sageworks_meta({"sageworks_details": {"aws_url": aws_url}})

    # Datetime values become ISO-8601 strings before the column stats are attached
    info = convert_all_to_iso8601(info)
    info["column_stats"] = self.column_stats()

    self.data_storage.set(cache_key, info)
    return info

execute_statement(query, silence_errors=False)

Execute a non-returning SQL statement in Athena

Parameters:

Name Type Description Default
query str

The query to run against the AthenaSource

required
silence_errors bool

Silence errors (default: False)

False
Source code in src/sageworks/core/artifacts/athena_source.py
def execute_statement(self, query: str, silence_errors: bool = False):
    """Run a SQL statement in Athena that returns no result set.

    Args:
        query (str): The SQL statement to run against the AthenaSource's database
        silence_errors (bool): Skip the error log for failed statements (default: False).
            The exception is re-raised either way; this only affects QueryFailed errors.

    Raises:
        wr.exceptions.QueryFailed: The statement failed. An "AlreadyExistsException"
            failure is the exception: it is logged as a warning and ignored.
        botocore.exceptions.ClientError: The Athena API rejected the request.
    """
    try:
        execution_id = wr.athena.start_query_execution(
            sql=query,
            database=self.get_database(),
            boto3_session=self.boto3_session,
        )
        self.log.debug(f"QueryExecutionId: {execution_id}")

        # Wait for the statement to finish before reporting success
        wr.athena.wait_query(query_execution_id=execution_id, boto3_session=self.boto3_session)
        self.log.debug(f"Statement executed successfully: {execution_id}")
    except wr.exceptions.QueryFailed as err:
        # Creating a table that already exists is tolerated
        if "AlreadyExistsException" in str(err):
            self.log.warning(f"Table already exists. Ignoring: {err}")
            return
        if not silence_errors:
            self.log.error(f"Failed to execute statement: {err}")
        raise
    except botocore.exceptions.ClientError as err:
        if err.response["Error"]["Code"] == "InvalidRequestException":
            self.log.error(f"Invalid Query: {query}")
        else:
            self.log.error(f"Failed to execute statement: {err}")
        raise

exists()

Check whether this Data Source exists (i.e., whether its AWS Data Catalog metadata was found)

Source code in src/sageworks/core/artifacts/athena_source.py
def exists(self) -> bool:
    """Check whether this Data Source exists.

    Returns:
        bool: True when AWS Data Catalog metadata (``catalog_table_meta``) was found
            for this table. False when that attribute is missing or None.
    """
    found = getattr(self, "catalog_table_meta", None) is not None
    if not found:
        self.log.debug(f"AthenaSource {self.table} not found in SageWorks Metadata...")
    return found

modified()

Return the datetime when this artifact was last modified

Source code in src/sageworks/core/artifacts/athena_source.py
def modified(self) -> datetime:
    """Timestamp of the most recent update to this artifact.

    Returns:
        datetime: The ``UpdateTime`` entry of the cached Data Catalog table metadata
    """
    table_meta = self.catalog_table_meta
    update_time = table_meta["UpdateTime"]
    return update_time

num_columns()

Return the number of columns for this Data Source

Source code in src/sageworks/core/artifacts/athena_source.py
def num_columns(self) -> int:
    """Number of columns in this Data Source (the length of ``self.columns``)"""
    column_names = self.columns
    return len(column_names)

num_rows()

Return the number of rows for this Data Source

Source code in src/sageworks/core/artifacts/athena_source.py
def num_rows(self) -> int:
    """Return the number of rows for this Data Source

    Returns:
        int: The row count from a ``count(*)`` query. Returns 0 when the query
             fails (``self.query`` returns None) or yields no rows.
    """
    table_ref = f'"{self.get_database()}"."{self.table}"'
    count_df = self.query(f"select count(*) AS sageworks_count from {table_ref}")
    if count_df is None or count_df.empty:
        return 0
    # Cast the numpy scalar to a plain int so the value matches the annotation
    # and stays JSON-serializable for callers that store it
    return int(count_df["sageworks_count"].iloc[0])

outliers_impl(scale=1.5, use_stddev=False)

Compute outliers for all the numeric columns in a DataSource

Parameters:

Name Type Description Default
scale float

The scale to use for the IQR (default: 1.5)

1.5
use_stddev bool

Use Standard Deviation instead of IQR (default: False)

False

Returns:

Type Description
DataFrame

pd.DataFrame: A DataFrame of outliers from this DataSource

Notes

Uses IQR * 1.5 (≈ 2.7 sigma for normally distributed data; use 1.7 for ≈ 3 sigma). The scale parameter can be adjusted to change the IQR multiplier.

Source code in src/sageworks/core/artifacts/athena_source.py
def outliers_impl(self, scale: float = 1.5, use_stddev=False) -> pd.DataFrame:
    """Find outlier rows across all numeric columns of this DataSource.

    Args:
        scale (float): Multiplier applied to the IQR (default: 1.5)
        use_stddev (bool): Use Standard Deviation instead of IQR (default: False)

    Returns:
        pd.DataFrame: Outlier rows pulled from this DataSource

    Notes:
        For normally distributed data, an IQR multiplier of 1.5 is roughly 2.7 sigma
        (use 1.7 for about 3 sigma). Adjust ``scale`` to change the IQR multiplier.
    """
    # The SQL-side computation lives in the Outliers helper class
    return outliers.Outliers().compute_outliers(self, scale=scale, use_stddev=use_stddev)

query(query)

Query the AthenaSource

Parameters:

Name Type Description Default
query str

The query to run against the AthenaSource

required

Returns:

Type Description
Union[DataFrame, None]

pd.DataFrame: The results of the query

Source code in src/sageworks/core/artifacts/athena_source.py
def query(self, query: str) -> Union[pd.DataFrame, None]:
    """Run a SQL query against the AthenaSource's database.

    Args:
        query (str): The SQL query text

    Returns:
        pd.DataFrame: The query results, or None if Athena reports the query failed
    """
    self.log.debug(f"Executing Query: {query}...")
    try:
        # ctas_approach=False: read Athena's regular query output instead of
        # wrapping the query in a temporary CTAS table
        result_df = wr.athena.read_sql_query(
            sql=query,
            database=self.get_database(),
            ctas_approach=False,
            boto3_session=self.boto3_session,
        )
    except wr.exceptions.QueryFailed as err:
        self.log.critical(f"Failed to execute query: {err}")
        return None
    else:
        scanned = result_df.query_metadata["Statistics"]["DataScannedInBytes"]
        if scanned > 0:
            self.log.debug(f"Athena Query successful (scanned bytes: {scanned})")
        return result_df

refresh_meta()

Refresh our internal AWS Broker catalog metadata

Source code in src/sageworks/core/artifacts/athena_source.py
def refresh_meta(self):
    """Force-refresh the AWS Broker's Data Catalog metadata and re-read this table's entry"""
    all_catalogs = self.aws_broker.get_metadata(ServiceCategory.DATA_CATALOG, force_refresh=True)
    database_tables = all_catalogs[self.get_database()]
    # .get(): if the table is absent from the catalog, catalog_table_meta becomes None
    self.catalog_table_meta = database_tables.get(self.table)
    self.metadata_refresh_needed = False

s3_storage_location()

Get the S3 Storage Location for this Data Source

Source code in src/sageworks/core/artifacts/athena_source.py
def s3_storage_location(self) -> str:
    """S3 location backing this Data Source (the table's StorageDescriptor ``Location``)"""
    storage_descriptor = self.catalog_table_meta["StorageDescriptor"]
    return storage_descriptor["Location"]

sageworks_meta()

Get the SageWorks specific metadata for this Artifact

Source code in src/sageworks/core/artifacts/athena_source.py
def sageworks_meta(self) -> dict:
    """Get the SageWorks specific metadata for this Artifact

    Returns:
        dict: SageWorks metadata parsed from the Data Catalog table metadata,
              or an empty dict if no catalog table metadata is available
    """
    self.log.info(f"Retrieving SageWorks Metadata for Artifact: {self.uuid}...")

    # Missing/None catalog metadata is exactly the "doesn't exist" condition used by
    # exists(), so one check covers it. getattr also avoids an AttributeError when the
    # attribute was never set.
    if getattr(self, "catalog_table_meta", None) is None:
        self.log.error(f"DataSource {self.uuid} doesn't appear to exist...")
        return {}

    # A prior upsert may have made our cached catalog metadata stale
    if self.metadata_refresh_needed:
        self.refresh_meta()

    # Get the SageWorks Metadata from the Catalog Table Metadata
    return sageworks_meta_from_catalog_table_meta(self.catalog_table_meta)

sample_impl()

Pull a sample of rows from the DataSource

Returns:

Type Description
DataFrame

pd.DataFrame: A sample DataFrame for an Athena DataSource

Source code in src/sageworks/core/artifacts/athena_source.py
def sample_impl(self) -> pd.DataFrame:
    """Fetch a sample of rows from this DataSource.

    Returns:
        pd.DataFrame: The sampled rows produced by the ``sample_rows`` SQL helper
    """
    sampled = sample_rows.sample_rows(self)
    return sampled

size()

Return the size of this data in MegaBytes

Source code in src/sageworks/core/artifacts/athena_source.py
def size(self) -> float:
    """Return the size of this data in MegaBytes

    Returns:
        float: Total size, in MB (1 MB = 1,000,000 bytes), of the S3 objects under
               this Data Source's storage location
    """
    # Force a trailing slash, as delete() does, so the prefix listing covers only this
    # table's directory and not sibling prefixes sharing the same leading characters
    s3_path = self.s3_storage_location()
    s3_path = s3_path if s3_path.endswith("/") else f"{s3_path}/"

    # size_objects maps path -> size; awswrangler types the sizes as Optional[int],
    # so treat a missing size as 0 instead of letting sum() raise TypeError
    object_sizes = wr.s3.size_objects(s3_path, boto3_session=self.boto3_session)
    size_in_bytes = sum(obj_size or 0 for obj_size in object_sizes.values())
    return size_in_bytes / 1_000_000

smart_sample(recompute=False)

Get a smart sample dataframe for this DataSource

Parameters:

Name Type Description Default
recompute bool

Recompute the smart sample (default: False)

False

Returns:

Type Description
DataFrame

pd.DataFrame: A combined DataFrame of sample data + outliers

Source code in src/sageworks/core/artifacts/athena_source.py
def smart_sample(self, recompute: bool = False) -> pd.DataFrame:
    """Get a smart sample dataframe for this DataSource

    The smart sample combines the outlier rows with a plain row sample. Sampled rows
    are tagged with ``outlier_group == "sample"``, and rows are de-duplicated on the
    data columns. The result is cached as JSON under
    ``data_source:<uuid>:smart_sample``.

    Args:
        recompute (bool): Recompute the smart sample (default: False)

    Returns:
        pd.DataFrame: A combined DataFrame of sample data + outliers
    """
    storage_key = f"data_source:{self.uuid}:smart_sample"

    # Look up the cached JSON only once
    if not recompute:
        cached_json = self.data_storage.get(storage_key)
        if cached_json:
            return pd.read_json(StringIO(cached_json))

    self.log.important(f"Computing Smart Sample {self.uuid}...")

    outlier_df = self.outliers(recompute=recompute)

    # assign() returns a copy, so the DataFrame handed back by self.sample() is left
    # untouched. The local name also avoids shadowing the sample_rows helper module.
    sample_df = self.sample(recompute=recompute).assign(outlier_group="sample")

    all_rows = pd.concat([outlier_df, sample_df]).reset_index(drop=True)

    # Outliers come first, so a row that is both an outlier and sampled keeps its outlier label
    data_columns = [col for col in all_rows.columns if col != "outlier_group"]
    all_rows = all_rows.drop_duplicates(subset=data_columns, ignore_index=True)

    self.data_storage.set(storage_key, all_rows.to_json())
    return all_rows

upsert_sageworks_meta(new_meta)

Add SageWorks specific metadata to this Artifact

Parameters:

Name Type Description Default
new_meta dict

Dictionary of new metadata to add

required
Source code in src/sageworks/core/artifacts/athena_source.py
def upsert_sageworks_meta(self, new_meta: dict):
    """Add SageWorks specific metadata to this Artifact

    Non-string values are JSON-encoded (using CustomEncoder) and then stored as
    Data Catalog table parameters. The caller's dict is not modified.

    Args:
        new_meta (dict): Dictionary of new metadata to add
    """
    # Give a warning message for keys that don't start with sageworks_
    for key in new_meta:
        if not key.startswith("sageworks_"):
            self.log.warning("Append 'sageworks_' to key names to avoid overwriting AWS meta data")

    # Build an encoded copy instead of overwriting the caller's values in place
    encoded_meta = {
        key: value if isinstance(value, str) else json.dumps(value, cls=CustomEncoder)
        for key, value in new_meta.items()
    }

    def _store():
        """Write encoded_meta to the catalog table and mark our cached metadata stale."""
        wr.catalog.upsert_table_parameters(
            parameters=encoded_meta,
            database=self.get_database(),
            table=self.table,
            boto3_session=self.boto3_session,
        )
        # sageworks_meta() calls refresh_meta() when this flag is set
        self.metadata_refresh_needed = True

    try:
        _store()
    except botocore.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == "InvalidInputException":
            self.log.error(f"Unable to upsert metadata for {self.table}")
            self.log.error("Probably because the metadata is too large")
            self.log.error(encoded_meta)
        elif error_code == "ConcurrentModificationException":
            self.log.warning("ConcurrentModificationException... trying again...")
            time.sleep(5)
            # Single retry; on success it also flags the cached metadata as stale
            _store()
        else:
            self.log.critical(f"Failed to upsert metadata: {e}")
            self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")
    except Exception as e:
        self.log.critical(f"Failed to upsert metadata: {e}")
        self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")

value_counts(recompute=False)

Compute 'value_counts' for all the string columns in a DataSource

Parameters:

Name Type Description Default
recompute bool

Recompute the value counts (default: False)

False

Returns:

Name Type Description
dict dict

A dictionary of value counts for each column in the form {'col1': {'value_1': 42, 'value_2': 16, 'value_3': 9,...}, 'col2': ...}

Source code in src/sageworks/core/artifacts/athena_source.py
def value_counts(self, recompute: bool = False) -> dict[dict]:
    """Per-value counts for every string column of this DataSource.

    Results are stored in the SageWorks metadata under ``sageworks_value_counts``
    and reused on later calls unless ``recompute`` is True or nothing usable is stored.

    Args:
        recompute (bool): Force the value counts to be recomputed (default: False)

    Returns:
        dict(dict): Per-column counts in the form
             {'col1': {'value_1': 42, 'value_2': 16, 'value_3': 9,...},
              'col2': ...}
    """
    cached_counts = self.sageworks_meta().get("sageworks_value_counts")
    if not recompute and cached_counts:
        return cached_counts

    # Compute with the value_counts SQL helper and persist the result in the metadata
    counts = value_counts.value_counts(self)
    self.upsert_sageworks_meta({"sageworks_value_counts": counts})
    return counts