AthenaSource

API Classes

Found a method here you want to use? The API Classes have method pass-through, so just call the method on the DataSource API Class and, voilà, it works the same.
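
A quick sketch of that pass-through pattern is shown below; the `workbench.api` import path and the "test_data" name are illustrative assumptions, not something documented on this page:

```python
from workbench.api import DataSource  # assumed import path for the API class

# Construct the high-level API object (the DataSource name is a placeholder)
ds = DataSource("test_data")

# num_rows() and details() are defined on AthenaSource (documented below), but
# the pass-through lets us call them directly on the DataSource API object
print(ds.num_rows())
print(ds.details())
```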

AthenaSource: Workbench Data Source accessible through Athena

AthenaSource

Bases: DataSourceAbstract

AthenaSource: Workbench Data Source accessible through Athena

Common Usage
my_data = AthenaSource(data_uuid, database="workbench")
my_data.summary()
my_data.details()
df = my_data.query(f"select * from {data_uuid} limit 5")
Source code in src/workbench/core/artifacts/athena_source.py
class AthenaSource(DataSourceAbstract):
    """AthenaSource: Workbench Data Source accessible through Athena

    Common Usage:
        ```python
        my_data = AthenaSource(data_uuid, database="workbench")
        my_data.summary()
        my_data.details()
        df = my_data.query(f"select * from {data_uuid} limit 5")
        ```
    """

    def __init__(self, data_uuid, database="workbench", **kwargs):
        """AthenaSource Initialization

        Args:
            data_uuid (str): Name of Athena Table
            database (str): Athena Database Name (default: workbench)
        """
        # Ensure the data_uuid is a valid name/id
        self.is_name_valid(data_uuid)

        # Call superclass init
        super().__init__(data_uuid, database, **kwargs)

        # Grab our metadata from the Meta class
        self.log.info(f"Retrieving metadata for: {self.uuid}...")
        self.data_source_meta = self.meta.data_source(data_uuid, database=database)
        if self.data_source_meta is None:
            self.log.error(f"Unable to find {database}:{self.table} in Glue Catalogs...")
            return

        # Call superclass post init
        super().__post_init__()

        # All done
        self.log.debug(f"AthenaSource Initialized: {database}.{self.table}")

    def refresh_meta(self):
        """Refresh our internal AWS Broker catalog metadata"""
        self.data_source_meta = self.meta.data_source(self.uuid, database=self.database)

    def exists(self) -> bool:
        """Validation Checks for this Data Source"""

        # Are we able to pull AWS Metadata for this table_name?"""
        # Do we have a valid data_source_meta?
        if getattr(self, "data_source_meta", None) is None:
            self.log.debug(f"AthenaSource {self.table} not found in Workbench Metadata...")
            return False
        return True

    def arn(self) -> str:
        """AWS ARN (Amazon Resource Name) for this artifact"""
        # Grab our Workbench Role Manager, get our AWS account id, and region for ARN creation
        account_id = self.aws_account_clamp.account_id
        region = self.aws_account_clamp.region
        arn = f"arn:aws:glue:{region}:{account_id}:table/{self.database}/{self.table}"
        return arn

    def workbench_meta(self) -> dict:
        """Get the Workbench specific metadata for this Artifact"""

        # Sanity Check if we have invalid AWS Metadata
        if self.data_source_meta is None:
            if not self.exists():
                self.log.error(f"DataSource {self.uuid} doesn't appear to exist...")
            else:
                self.log.critical(f"Unable to get AWS Metadata for {self.table}")
                self.log.critical("Malformed Artifact! Delete this Artifact and recreate it!")
            return {}

        # Get the Workbench Metadata from the 'Parameters' section of the DataSource Metadata
        params = self.data_source_meta.get("Parameters", {})
        return {key: decode_value(value) for key, value in params.items() if "workbench" in key}

    def upsert_workbench_meta(self, new_meta: dict):
        """Add Workbench specific metadata to this Artifact

        Args:
            new_meta (dict): Dictionary of new metadata to add
        """
        self.log.important(f"Upserting Workbench Metadata {self.uuid}:{str(new_meta)[:50]}...")

        # Give a warning message for keys that don't start with workbench_
        for key in new_meta.keys():
            if not key.startswith("workbench_"):
                self.log.warning("Append 'workbench_' to key names to avoid overwriting AWS meta data")

        # Now convert any non-string values to JSON strings
        for key, value in new_meta.items():
            if not isinstance(value, str):
                new_meta[key] = json.dumps(value, cls=CustomEncoder)

        # Store our updated metadata
        try:
            wr.catalog.upsert_table_parameters(
                parameters=new_meta,
                database=self.database,
                table=self.table,
                boto3_session=self.boto3_session,
            )
        except botocore.exceptions.ClientError as e:
            error_code = e.response["Error"]["Code"]
            if error_code == "InvalidInputException":
                self.log.error(f"Unable to upsert metadata for {self.table}")
                self.log.error("Probably because the metadata is too large")
                self.log.error(new_meta)
            elif error_code == "ConcurrentModificationException":
                self.log.warning("ConcurrentModificationException... trying again...")
                time.sleep(5)
                wr.catalog.upsert_table_parameters(
                    parameters=new_meta,
                    database=self.database,
                    table=self.table,
                    boto3_session=self.boto3_session,
                )
            else:
                self.log.critical(f"Failed to upsert metadata: {e}")
                self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")
        except Exception as e:
            self.log.critical(f"Failed to upsert metadata: {e}")
            self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")

    def size(self) -> float:
        """Return the size of this data in MegaBytes"""
        size_in_bytes = sum(wr.s3.size_objects(self.s3_storage_location(), boto3_session=self.boto3_session).values())
        size_in_mb = size_in_bytes / 1_000_000
        return size_in_mb

    def aws_meta(self) -> dict:
        """Get the FULL AWS metadata for this artifact"""
        return self.data_source_meta

    def aws_url(self):
        """The AWS URL for looking at/querying this data source"""
        workbench_details = self.workbench_meta().get("workbench_details", {})
        return workbench_details.get("aws_url", "unknown")

    def created(self) -> datetime:
        """Return the datetime when this artifact was created"""
        return self.data_source_meta["CreateTime"]

    def modified(self) -> datetime:
        """Return the datetime when this artifact was last modified"""
        return self.data_source_meta["UpdateTime"]

    def hash(self) -> str:
        """Get the hash for the set of Parquet files used for this Artifact"""
        s3_uri = self.s3_storage_location()
        return compute_parquet_hash(s3_uri, self.boto3_session)

    def table_hash(self) -> str:
        """Get the table hash for this AthenaSource"""
        s3_scratch = f"s3://{self.workbench_bucket}/temp/athena_output"
        return compute_athena_table_hash(self.database, self.table, self.boto3_session, s3_scratch)

    def num_rows(self) -> int:
        """Return the number of rows for this Data Source"""
        count_df = self.query(f'select count(*) AS workbench_count from "{self.database}"."{self.table}"')
        return count_df["workbench_count"][0] if count_df is not None else 0

    def num_columns(self) -> int:
        """Return the number of columns for this Data Source"""
        return len(self.columns)

    @property
    def columns(self) -> list[str]:
        """Return the column names for this Athena Table"""
        return [item["Name"] for item in self.data_source_meta["StorageDescriptor"]["Columns"]]

    @property
    def column_types(self) -> list[str]:
        """Return the column types of the internal AthenaSource"""
        return [item["Type"] for item in self.data_source_meta["StorageDescriptor"]["Columns"]]

    def query(self, query: str) -> Union[pd.DataFrame, None]:
        """Query the AthenaSource

        Args:
            query (str): The query to run against the AthenaSource

        Returns:
            pd.DataFrame: The results of the query
        """

        # Call internal class _query method
        return self.database_query(self.database, query)

    @classmethod
    def database_query(cls, database: str, query: str) -> Union[pd.DataFrame, None]:
        """Specify the Database and Query the Athena Service

        Args:
            database (str): The Athena Database to query
            query (str): The query to run against the AthenaSource

        Returns:
            pd.DataFrame: The results of the query
        """
        cls.log.debug(f"Executing Query: {query}...")
        try:
            df = wr.athena.read_sql_query(
                sql=query,
                database=database,
                ctas_approach=False,
                boto3_session=cls.boto3_session,
            )
            scanned_bytes = df.query_metadata["Statistics"]["DataScannedInBytes"]
            if scanned_bytes > 0:
                cls.log.debug(f"Athena Query successful (scanned bytes: {scanned_bytes})")
            return df
        except wr.exceptions.QueryFailed as e:
            cls.log.critical(f"Failed to execute query: {e}")
            return None

    def execute_statement(self, query: str, silence_errors: bool = False):
        """Execute a non-returning SQL statement in Athena with retries.

        Args:
            query (str): The query to run against the AthenaSource
            silence_errors (bool): Silence errors (default: False)
        """
        attempt = 0
        max_retries = 3
        retry_delay = 10
        while attempt < max_retries:
            try:
                # Start the query execution
                query_execution_id = wr.athena.start_query_execution(
                    sql=query,
                    database=self.database,
                    boto3_session=self.boto3_session,
                )
                self.log.debug(f"QueryExecutionId: {query_execution_id}")

                # Wait for the query to complete
                wr.athena.wait_query(query_execution_id=query_execution_id, boto3_session=self.boto3_session)
                self.log.debug(f"Statement executed successfully: {query_execution_id}")
                break  # If successful, exit the retry loop
            except wr.exceptions.QueryFailed as e:
                if "AlreadyExistsException" in str(e):
                    self.log.warning(f"Table already exists: {e} \nIgnoring...")
                    break  # No need to retry for this error
                elif "ConcurrentModificationException" in str(e):
                    self.log.warning(f"Concurrent modification detected: {e}\nRetrying...")
                    attempt += 1
                    if attempt < max_retries:
                        time.sleep(retry_delay)
                    else:
                        if not silence_errors:
                            self.log.critical(f"Failed to execute statement after {max_retries} attempts: {e}")
                        raise
                else:
                    if not silence_errors:
                        self.log.critical(f"Failed to execute statement: {e}")
                    raise

    def s3_storage_location(self) -> str:
        """Get the S3 Storage Location for this Data Source"""
        return self.data_source_meta["StorageDescriptor"]["Location"]

    def athena_test_query(self):
        """Validate that Athena Queries are working"""
        query = f'select count(*) as workbench_count from "{self.table}"'
        df = wr.athena.read_sql_query(
            sql=query,
            database=self.database,
            ctas_approach=False,
            boto3_session=self.boto3_session,
        )
        scanned_bytes = df.query_metadata["Statistics"]["DataScannedInBytes"]
        self.log.info(f"Athena TEST Query successful (scanned bytes: {scanned_bytes})")

    def descriptive_stats(self, recompute: bool = False) -> dict[dict]:
        """Compute Descriptive Stats for all the numeric columns in a DataSource

        Args:
            recompute (bool): Recompute the descriptive stats (default: False)

        Returns:
            dict(dict): A dictionary of descriptive stats for each column in the form
                 {'col1': {'min': 0, 'q1': 1, 'median': 2, 'q3': 3, 'max': 4},
                  'col2': ...}
        """

        # First check if we have already computed the descriptive stats
        stat_dict = self.workbench_meta().get("workbench_descriptive_stats")
        if stat_dict and not recompute:
            return stat_dict

        # Call the SQL function to compute descriptive stats
        stat_dict = sql.descriptive_stats(self)

        # Push the descriptive stat data into our DataSource Metadata
        self.upsert_workbench_meta({"workbench_descriptive_stats": stat_dict})

        # Return the descriptive stats
        return stat_dict

    @cache_dataframe("sample")
    def sample(self) -> pd.DataFrame:
        """Pull a sample of rows from the DataSource

        Returns:
            pd.DataFrame: A sample DataFrame for an Athena DataSource
        """

        # Call the SQL function to pull a sample of the rows
        return sql.sample_rows(self)

    @cache_dataframe("outliers")
    def outliers(self, scale: float = 1.5, use_stddev=False) -> pd.DataFrame:
        """Compute outliers for all the numeric columns in a DataSource

        Args:
            scale (float): The scale to use for the IQR (default: 1.5)
            use_stddev (bool): Use Standard Deviation instead of IQR (default: False)

        Returns:
            pd.DataFrame: A DataFrame of outliers from this DataSource

        Notes:
            Uses the IQR * 1.5 (~= 2.5 Sigma) (use 1.7 for ~= 3 Sigma)
            The scale parameter can be adjusted to change the IQR multiplier
        """

        # Compute outliers using the SQL Outliers class
        sql_outliers = sql.outliers.Outliers()
        return sql_outliers.compute_outliers(self, scale=scale, use_stddev=use_stddev)

    @cache_dataframe("smart_sample")
    def smart_sample(self, recompute: bool = False) -> pd.DataFrame:
        """Get a smart sample dataframe for this DataSource

        Args:
            recompute (bool): Recompute the smart sample (default: False)

        Returns:
            pd.DataFrame: A combined DataFrame of sample data + outliers
        """

        # Compute/recompute the smart sample
        self.log.important(f"Computing Smart Sample {self.uuid}...")

        # Outliers DataFrame
        outlier_rows = self.outliers()

        # Sample DataFrame
        sample_rows = self.sample()
        sample_rows["outlier_group"] = "sample"

        # Combine the sample rows with the outlier rows
        all_rows = pd.concat([outlier_rows, sample_rows]).reset_index(drop=True)

        # Drop duplicates
        all_except_outlier_group = [col for col in all_rows.columns if col != "outlier_group"]
        all_rows = all_rows.drop_duplicates(subset=all_except_outlier_group, ignore_index=True)

        # Return the smart_sample data
        return all_rows

    def correlations(self, recompute: bool = False) -> dict[dict]:
        """Compute Correlations for all the numeric columns in a DataSource

        Args:
            recompute (bool): Recompute the column stats (default: False)

        Returns:
            dict(dict): A dictionary of correlations for each column in this format
                 {'col1': {'col2': 0.5, 'col3': 0.9, 'col4': 0.4, ...},
                  'col2': {'col1': 0.5, 'col3': 0.8, 'col4': 0.3, ...}}
        """

        # First check if we have already computed the correlations
        correlations_dict = self.workbench_meta().get("workbench_correlations")
        if correlations_dict and not recompute:
            return correlations_dict

        # Call the SQL function to compute correlations
        correlations_dict = sql.correlations(self)

        # Push the correlation data into our DataSource Metadata
        self.upsert_workbench_meta({"workbench_correlations": correlations_dict})

        # Return the correlation data
        return correlations_dict

    def column_stats(self, recompute: bool = False) -> dict[dict]:
        """Compute Column Stats for all the columns in a DataSource

        Args:
            recompute (bool): Recompute the column stats (default: False)

        Returns:
            dict(dict): A dictionary of stats for each column this format
            NB: String columns will NOT have num_zeros, descriptive_stats or correlation data
                {'col1': {'dtype': 'string', 'unique': 4321, 'nulls': 12},
                 'col2': {'dtype': 'int', 'unique': 4321, 'nulls': 12, 'num_zeros': 100,
                          'descriptive_stats': {...}, 'correlations': {...}},
                 ...}
        """

        # First check if we have already computed the column stats
        columns_stats_dict = self.workbench_meta().get("workbench_column_stats")
        if columns_stats_dict and not recompute:
            return columns_stats_dict

        # Call the SQL function to compute column stats
        column_stats_dict = sql.column_stats(self, recompute=recompute)

        # Push the column stats data into our DataSource Metadata
        self.upsert_workbench_meta({"workbench_column_stats": column_stats_dict})

        # Return the column stats data
        return column_stats_dict

    def value_counts(self, recompute: bool = False) -> dict[dict]:
        """Compute 'value_counts' for all the string columns in a DataSource

        Args:
            recompute (bool): Recompute the value counts (default: False)

        Returns:
            dict(dict): A dictionary of value counts for each column in the form
                 {'col1': {'value_1': 42, 'value_2': 16, 'value_3': 9,...},
                  'col2': ...}
        """

        # First check if we have already computed the value counts
        value_counts_dict = self.workbench_meta().get("workbench_value_counts")
        if value_counts_dict and not recompute:
            return value_counts_dict

        # Call the SQL function to compute value_counts
        value_count_dict = sql.value_counts(self)

        # Push the value_count data into our DataSource Metadata
        self.upsert_workbench_meta({"workbench_value_counts": value_count_dict})

        # Return the value_count data
        return value_count_dict

    def details(self, recompute: bool = False) -> dict[dict]:
        """Additional Details about this AthenaSource Artifact

        Args:
            recompute (bool): Recompute the details (default: False)

        Returns:
            dict(dict): A dictionary of details about this AthenaSource
        """
        self.log.info(f"Computing DataSource Details ({self.uuid})...")

        # Get the details from the base class
        details = super().details()

        # Compute additional details
        details["s3_storage_location"] = self.s3_storage_location()
        details["storage_type"] = "athena"

        # Compute our AWS URL
        query = f'select * from "{self.database}.{self.table}" limit 10'
        query_exec_id = wr.athena.start_query_execution(
            sql=query, database=self.database, boto3_session=self.boto3_session
        )
        base_url = "https://console.aws.amazon.com/athena/home"
        details["aws_url"] = f"{base_url}?region={self.aws_region}#query/history/{query_exec_id}"

        # Push the aws_url data into our DataSource Metadata
        # FIXME: We need to revisit this but doing an upsert just for aws_url is silly
        # self.upsert_workbench_meta({"workbench_details": {"aws_url": details["aws_url"]}})

        # Convert any datetime fields to ISO-8601 strings
        details = convert_all_to_iso8601(details)

        # Add the column stats
        details["column_stats"] = self.column_stats()

        # Return the details data
        return details

    def delete(self):
        """Instance Method: Delete the AWS Data Catalog Table and S3 Storage Objects"""

        # Make sure the AthenaSource exists
        if not self.exists():
            self.log.warning(f"Trying to delete an AthenaSource that doesn't exist: {self.uuid}")

        # Call the Class Method to delete the AthenaSource
        AthenaSource.managed_delete(self.uuid, database=self.database)

    @classmethod
    def managed_delete(cls, data_source_name: str, database: str = "workbench"):
        """Class Method: Delete the AWS Data Catalog Table and S3 Storage Objects

        Args:
            data_source_name (str): Name of DataSource (AthenaSource)
            database (str): Athena Database Name (default: workbench)
        """
        table = data_source_name  # The table name is the same as the data_source_name

        # Check if the Glue Catalog Table exists
        if not wr.catalog.does_table_exist(database, table, boto3_session=cls.boto3_session):
            cls.log.info(f"DataSource {table} not found in database {database}.")
            return

        # Delete any views associated with this AthenaSource
        cls.delete_views(table, database)

        # Delete S3 Storage Objects (if they exist)
        try:
            # Make an AWS Query to get the S3 storage location
            s3_path = wr.catalog.get_table_location(database, table, boto3_session=cls.boto3_session)

            # Delete Data Catalog Table
            cls.log.info(f"Deleting DataCatalog Table: {database}.{table}...")
            wr.catalog.delete_table_if_exists(database, table, boto3_session=cls.boto3_session)

            # Make sure we add the trailing slash
            s3_path = s3_path if s3_path.endswith("/") else f"{s3_path}/"
            cls.log.info(f"Deleting S3 Storage Objects: {s3_path}...")
            wr.s3.delete_objects(s3_path, boto3_session=cls.boto3_session)
        except Exception as e:
            cls.log.error(f"Failure when trying to delete {data_source_name}: {e}")

        # Delete any dataframes that were stored in the Dataframe Cache
        cls.log.info("Deleting Dataframe Cache...")
        cls.df_cache.delete_recursive(data_source_name)

    @classmethod
    def delete_views(cls, table: str, database: str):
        """Delete any views associated with this FeatureSet

        Args:
            table (str): Name of Athena Table
            database (str): Athena Database Name
        """
        from workbench.core.views.view_utils import delete_views_and_supplemental_data

        delete_views_and_supplemental_data(table, database, cls.boto3_session)

column_types: list[str] property

Return the column types of the internal AthenaSource

columns: list[str] property

Return the column names for this Athena Table
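
As a quick sketch (the DataSource name is a placeholder), the two properties line up one-to-one, so they can be zipped into a name-to-type mapping:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")  # placeholder name

print(ds.columns)       # column names from the Glue StorageDescriptor
print(ds.column_types)  # Athena types, same order as ds.columns

# Combine them into a {column_name: athena_type} dict
schema = dict(zip(ds.columns, ds.column_types))
print(schema)
```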

__init__(data_uuid, database='workbench', **kwargs)

AthenaSource Initialization

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_uuid | str | Name of Athena Table | required |
| database | str | Athena Database Name (default: workbench) | 'workbench' |

Source code in src/workbench/core/artifacts/athena_source.py
def __init__(self, data_uuid, database="workbench", **kwargs):
    """AthenaSource Initialization

    Args:
        data_uuid (str): Name of Athena Table
        database (str): Athena Database Name (default: workbench)
    """
    # Ensure the data_uuid is a valid name/id
    self.is_name_valid(data_uuid)

    # Call superclass init
    super().__init__(data_uuid, database, **kwargs)

    # Grab our metadata from the Meta class
    self.log.info(f"Retrieving metadata for: {self.uuid}...")
    self.data_source_meta = self.meta.data_source(data_uuid, database=database)
    if self.data_source_meta is None:
        self.log.error(f"Unable to find {database}:{self.table} in Glue Catalogs...")
        return

    # Call superclass post init
    super().__post_init__()

    # All done
    self.log.debug(f"AthenaSource Initialized: {database}.{self.table}")

arn()

AWS ARN (Amazon Resource Name) for this artifact

Source code in src/workbench/core/artifacts/athena_source.py
def arn(self) -> str:
    """AWS ARN (Amazon Resource Name) for this artifact"""
    # Grab our Workbench Role Manager, get our AWS account id, and region for ARN creation
    account_id = self.aws_account_clamp.account_id
    region = self.aws_account_clamp.region
    arn = f"arn:aws:glue:{region}:{account_id}:table/{self.database}/{self.table}"
    return arn

athena_test_query()

Validate that Athena Queries are working

Source code in src/workbench/core/artifacts/athena_source.py
def athena_test_query(self):
    """Validate that Athena Queries are working"""
    query = f'select count(*) as workbench_count from "{self.table}"'
    df = wr.athena.read_sql_query(
        sql=query,
        database=self.database,
        ctas_approach=False,
        boto3_session=self.boto3_session,
    )
    scanned_bytes = df.query_metadata["Statistics"]["DataScannedInBytes"]
    self.log.info(f"Athena TEST Query successful (scanned bytes: {scanned_bytes})")

aws_meta()

Get the FULL AWS metadata for this artifact

Source code in src/workbench/core/artifacts/athena_source.py
def aws_meta(self) -> dict:
    """Get the FULL AWS metadata for this artifact"""
    return self.data_source_meta

aws_url()

The AWS URL for looking at/querying this data source

Source code in src/workbench/core/artifacts/athena_source.py
def aws_url(self):
    """The AWS URL for looking at/querying this data source"""
    workbench_details = self.workbench_meta().get("workbench_details", {})
    return workbench_details.get("aws_url", "unknown")

column_stats(recompute=False)

Compute Column Stats for all the columns in a DataSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recompute | bool | Recompute the column stats (default: False) | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict[dict] | A dictionary of stats for each column in this format: {'col1': {'dtype': 'string', 'unique': 4321, 'nulls': 12}, 'col2': {'dtype': 'int', 'unique': 4321, 'nulls': 12, 'num_zeros': 100, 'descriptive_stats': {...}, 'correlations': {...}}, ...} NB: String columns will NOT have num_zeros, descriptive_stats, or correlation data |

Source code in src/workbench/core/artifacts/athena_source.py
def column_stats(self, recompute: bool = False) -> dict[dict]:
    """Compute Column Stats for all the columns in a DataSource

    Args:
        recompute (bool): Recompute the column stats (default: False)

    Returns:
        dict(dict): A dictionary of stats for each column this format
        NB: String columns will NOT have num_zeros, descriptive_stats or correlation data
            {'col1': {'dtype': 'string', 'unique': 4321, 'nulls': 12},
             'col2': {'dtype': 'int', 'unique': 4321, 'nulls': 12, 'num_zeros': 100,
                      'descriptive_stats': {...}, 'correlations': {...}},
             ...}
    """

    # First check if we have already computed the column stats
    columns_stats_dict = self.workbench_meta().get("workbench_column_stats")
    if columns_stats_dict and not recompute:
        return columns_stats_dict

    # Call the SQL function to compute column stats
    column_stats_dict = sql.column_stats(self, recompute=recompute)

    # Push the column stats data into our DataSource Metadata
    self.upsert_workbench_meta({"workbench_column_stats": column_stats_dict})

    # Return the column stats data
    return column_stats_dict
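
A minimal usage sketch (the DataSource and column names are placeholders); results are cached in the Workbench metadata, so pass recompute=True to force a refresh:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")                  # placeholder name
stats = ds.column_stats()                       # cached in workbench_column_stats after first run
print(stats.get("my_column", {}).get("nulls"))  # hypothetical column name
stats = ds.column_stats(recompute=True)         # force a recompute and re-upsert the metadata
```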

correlations(recompute=False)

Compute Correlations for all the numeric columns in a DataSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recompute | bool | Recompute the correlations (default: False) | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict | A dictionary of correlations for each column in this format: {'col1': {'col2': 0.5, 'col3': 0.9, 'col4': 0.4, ...}, 'col2': {'col1': 0.5, 'col3': 0.8, 'col4': 0.3, ...}} |

Source code in src/workbench/core/artifacts/athena_source.py
def correlations(self, recompute: bool = False) -> dict[dict]:
    """Compute Correlations for all the numeric columns in a DataSource

    Args:
        recompute (bool): Recompute the column stats (default: False)

    Returns:
        dict(dict): A dictionary of correlations for each column in this format
             {'col1': {'col2': 0.5, 'col3': 0.9, 'col4': 0.4, ...},
              'col2': {'col1': 0.5, 'col3': 0.8, 'col4': 0.3, ...}}
    """

    # First check if we have already computed the correlations
    correlations_dict = self.workbench_meta().get("workbench_correlations")
    if correlations_dict and not recompute:
        return correlations_dict

    # Call the SQL function to compute correlations
    correlations_dict = sql.correlations(self)

    # Push the correlation data into our DataSource Metadata
    self.upsert_workbench_meta({"workbench_correlations": correlations_dict})

    # Return the correlation data
    return correlations_dict

created()

Return the datetime when this artifact was created

Source code in src/workbench/core/artifacts/athena_source.py
def created(self) -> datetime:
    """Return the datetime when this artifact was created"""
    return self.data_source_meta["CreateTime"]

database_query(database, query) classmethod

Specify the Database and Query the Athena Service

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| database | str | The Athena Database to query | required |
| query | str | The query to run against the AthenaSource | required |

Returns:

| Type | Description |
| --- | --- |
| Union[DataFrame, None] | pd.DataFrame: The results of the query |

Source code in src/workbench/core/artifacts/athena_source.py
@classmethod
def database_query(cls, database: str, query: str) -> Union[pd.DataFrame, None]:
    """Specify the Database and Query the Athena Service

    Args:
        database (str): The Athena Database to query
        query (str): The query to run against the AthenaSource

    Returns:
        pd.DataFrame: The results of the query
    """
    cls.log.debug(f"Executing Query: {query}...")
    try:
        df = wr.athena.read_sql_query(
            sql=query,
            database=database,
            ctas_approach=False,
            boto3_session=cls.boto3_session,
        )
        scanned_bytes = df.query_metadata["Statistics"]["DataScannedInBytes"]
        if scanned_bytes > 0:
            cls.log.debug(f"Athena Query successful (scanned bytes: {scanned_bytes})")
        return df
    except wr.exceptions.QueryFailed as e:
        cls.log.critical(f"Failed to execute query: {e}")
        return None
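
Because database_query() is a classmethod, you can run ad-hoc Athena SQL without constructing a specific AthenaSource. A sketch (the table name is a placeholder):

```python
from workbench.core.artifacts.athena_source import AthenaSource

# Query any table in the 'workbench' database (table name is illustrative)
df = AthenaSource.database_query("workbench", 'select count(*) as n from "test_data"')

# database_query() returns None if the Athena query fails
if df is not None:
    print(df["n"][0])
```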

delete()

Instance Method: Delete the AWS Data Catalog Table and S3 Storage Objects

Source code in src/workbench/core/artifacts/athena_source.py
def delete(self):
    """Instance Method: Delete the AWS Data Catalog Table and S3 Storage Objects"""

    # Make sure the AthenaSource exists
    if not self.exists():
        self.log.warning(f"Trying to delete an AthenaSource that doesn't exist: {self.uuid}")

    # Call the Class Method to delete the AthenaSource
    AthenaSource.managed_delete(self.uuid, database=self.database)

delete_views(table, database) classmethod

Delete any views associated with this AthenaSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| table | str | Name of Athena Table | required |
| database | str | Athena Database Name | required |

Source code in src/workbench/core/artifacts/athena_source.py
@classmethod
def delete_views(cls, table: str, database: str):
    """Delete any views associated with this FeatureSet

    Args:
        table (str): Name of Athena Table
        database (str): Athena Database Name
    """
    from workbench.core.views.view_utils import delete_views_and_supplemental_data

    delete_views_and_supplemental_data(table, database, cls.boto3_session)

descriptive_stats(recompute=False)

Compute Descriptive Stats for all the numeric columns in a DataSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recompute | bool | Recompute the descriptive stats (default: False) | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict | A dictionary of descriptive stats for each column in the form {'col1': {'min': 0, 'q1': 1, 'median': 2, 'q3': 3, 'max': 4}, 'col2': ...} |

Source code in src/workbench/core/artifacts/athena_source.py
def descriptive_stats(self, recompute: bool = False) -> dict[dict]:
    """Compute Descriptive Stats for all the numeric columns in a DataSource

    Args:
        recompute (bool): Recompute the descriptive stats (default: False)

    Returns:
        dict(dict): A dictionary of descriptive stats for each column in the form
             {'col1': {'min': 0, 'q1': 1, 'median': 2, 'q3': 3, 'max': 4},
              'col2': ...}
    """

    # First check if we have already computed the descriptive stats
    stat_dict = self.workbench_meta().get("workbench_descriptive_stats")
    if stat_dict and not recompute:
        return stat_dict

    # Call the SQL function to compute descriptive stats
    stat_dict = sql.descriptive_stats(self)

    # Push the descriptive stat data into our DataSource Metadata
    self.upsert_workbench_meta({"workbench_descriptive_stats": stat_dict})

    # Return the descriptive stats
    return stat_dict
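
A short sketch of reading the cached stats (the DataSource name is a placeholder); the min/q1/median/q3/max keys come from the return format documented above:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")  # placeholder name
stats = ds.descriptive_stats()  # computed once, then cached in the DataSource metadata

for col, col_stats in stats.items():
    print(col, col_stats["min"], col_stats["median"], col_stats["max"])
```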

details(recompute=False)

Additional Details about this AthenaSource Artifact

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recompute | bool | Recompute the details (default: False) | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict | A dictionary of details about this AthenaSource |

Source code in src/workbench/core/artifacts/athena_source.py
def details(self, recompute: bool = False) -> dict[dict]:
    """Additional Details about this AthenaSource Artifact

    Args:
        recompute (bool): Recompute the details (default: False)

    Returns:
        dict(dict): A dictionary of details about this AthenaSource
    """
    self.log.info(f"Computing DataSource Details ({self.uuid})...")

    # Get the details from the base class
    details = super().details()

    # Compute additional details
    details["s3_storage_location"] = self.s3_storage_location()
    details["storage_type"] = "athena"

    # Compute our AWS URL
    query = f'select * from "{self.database}.{self.table}" limit 10'
    query_exec_id = wr.athena.start_query_execution(
        sql=query, database=self.database, boto3_session=self.boto3_session
    )
    base_url = "https://console.aws.amazon.com/athena/home"
    details["aws_url"] = f"{base_url}?region={self.aws_region}#query/history/{query_exec_id}"

    # Push the aws_url data into our DataSource Metadata
    # FIXME: We need to revisit this but doing an upsert just for aws_url is silly
    # self.upsert_workbench_meta({"workbench_details": {"aws_url": details["aws_url"]}})

    # Convert any datetime fields to ISO-8601 strings
    details = convert_all_to_iso8601(details)

    # Add the column stats
    details["column_stats"] = self.column_stats()

    # Return the details data
    return details

execute_statement(query, silence_errors=False)

Execute a non-returning SQL statement in Athena with retries.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | str | The query to run against the AthenaSource | required |
| silence_errors | bool | Silence errors (default: False) | False |

Source code in src/workbench/core/artifacts/athena_source.py
def execute_statement(self, query: str, silence_errors: bool = False):
    """Execute a non-returning SQL statement in Athena with retries.

    Args:
        query (str): The query to run against the AthenaSource
        silence_errors (bool): Silence errors (default: False)
    """
    attempt = 0
    max_retries = 3
    retry_delay = 10
    while attempt < max_retries:
        try:
            # Start the query execution
            query_execution_id = wr.athena.start_query_execution(
                sql=query,
                database=self.database,
                boto3_session=self.boto3_session,
            )
            self.log.debug(f"QueryExecutionId: {query_execution_id}")

            # Wait for the query to complete
            wr.athena.wait_query(query_execution_id=query_execution_id, boto3_session=self.boto3_session)
            self.log.debug(f"Statement executed successfully: {query_execution_id}")
            break  # If successful, exit the retry loop
        except wr.exceptions.QueryFailed as e:
            if "AlreadyExistsException" in str(e):
                self.log.warning(f"Table already exists: {e} \nIgnoring...")
                break  # No need to retry for this error
            elif "ConcurrentModificationException" in str(e):
                self.log.warning(f"Concurrent modification detected: {e}\nRetrying...")
                attempt += 1
                if attempt < max_retries:
                    time.sleep(retry_delay)
                else:
                    if not silence_errors:
                        self.log.critical(f"Failed to execute statement after {max_retries} attempts: {e}")
                    raise
            else:
                if not silence_errors:
                    self.log.critical(f"Failed to execute statement: {e}")
                raise
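
A sketch of running a non-returning statement (the view and table names are placeholders). Per the implementation above, AlreadyExistsException is ignored and ConcurrentModificationException is retried:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")  # placeholder name

# Create a simple view over the table (Athena DDL; view name is illustrative)
ds.execute_statement(
    'CREATE OR REPLACE VIEW "test_data_view" AS SELECT * FROM "test_data"',
    silence_errors=True,  # suppresses the critical log on failure; the exception is still raised
)
```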

exists()

Validation Checks for this Data Source

Source code in src/workbench/core/artifacts/athena_source.py
def exists(self) -> bool:
    """Validation Checks for this Data Source"""

    # Are we able to pull AWS Metadata for this table_name?"""
    # Do we have a valid data_source_meta?
    if getattr(self, "data_source_meta", None) is None:
        self.log.debug(f"AthenaSource {self.table} not found in Workbench Metadata...")
        return False
    return True

hash()

Get the hash for the set of Parquet files used for this Artifact

Source code in src/workbench/core/artifacts/athena_source.py
def hash(self) -> str:
    """Get the hash for the set of Parquet files used for this Artifact"""
    s3_uri = self.s3_storage_location()
    return compute_parquet_hash(s3_uri, self.boto3_session)

managed_delete(data_source_name, database='workbench') classmethod

Class Method: Delete the AWS Data Catalog Table and S3 Storage Objects

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_source_name | str | Name of DataSource (AthenaSource) | required |
| database | str | Athena Database Name (default: workbench) | 'workbench' |

Source code in src/workbench/core/artifacts/athena_source.py
@classmethod
def managed_delete(cls, data_source_name: str, database: str = "workbench"):
    """Class Method: Delete the AWS Data Catalog Table and S3 Storage Objects

    Args:
        data_source_name (str): Name of DataSource (AthenaSource)
        database (str): Athena Database Name (default: workbench)
    """
    table = data_source_name  # The table name is the same as the data_source_name

    # Check if the Glue Catalog Table exists
    if not wr.catalog.does_table_exist(database, table, boto3_session=cls.boto3_session):
        cls.log.info(f"DataSource {table} not found in database {database}.")
        return

    # Delete any views associated with this AthenaSource
    cls.delete_views(table, database)

    # Delete S3 Storage Objects (if they exist)
    try:
        # Make an AWS Query to get the S3 storage location
        s3_path = wr.catalog.get_table_location(database, table, boto3_session=cls.boto3_session)

        # Delete Data Catalog Table
        cls.log.info(f"Deleting DataCatalog Table: {database}.{table}...")
        wr.catalog.delete_table_if_exists(database, table, boto3_session=cls.boto3_session)

        # Make sure we add the trailing slash
        s3_path = s3_path if s3_path.endswith("/") else f"{s3_path}/"
        cls.log.info(f"Deleting S3 Storage Objects: {s3_path}...")
        wr.s3.delete_objects(s3_path, boto3_session=cls.boto3_session)
    except Exception as e:
        cls.log.error(f"Failure when trying to delete {data_source_name}: {e}")

    # Delete any dataframes that were stored in the Dataframe Cache
    cls.log.info("Deleting Dataframe Cache...")
    cls.df_cache.delete_recursive(data_source_name)
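
Both deletion paths end up in this class method; a sketch with placeholder names:

```python
from workbench.core.artifacts.athena_source import AthenaSource

# Instance-based delete (checks existence, then calls managed_delete under the hood)
AthenaSource("test_data").delete()

# Or delete by name without constructing the object
AthenaSource.managed_delete("test_data", database="workbench")
```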

modified()

Return the datetime when this artifact was last modified

Source code in src/workbench/core/artifacts/athena_source.py
def modified(self) -> datetime:
    """Return the datetime when this artifact was last modified"""
    return self.data_source_meta["UpdateTime"]

num_columns()

Return the number of columns for this Data Source

Source code in src/workbench/core/artifacts/athena_source.py
def num_columns(self) -> int:
    """Return the number of columns for this Data Source"""
    return len(self.columns)

num_rows()

Return the number of rows for this Data Source

Source code in src/workbench/core/artifacts/athena_source.py
def num_rows(self) -> int:
    """Return the number of rows for this Data Source"""
    count_df = self.query(f'select count(*) AS workbench_count from "{self.database}"."{self.table}"')
    return count_df["workbench_count"][0] if count_df is not None else 0

outliers(scale=1.5, use_stddev=False)

Compute outliers for all the numeric columns in a DataSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scale | float | The scale to use for the IQR (default: 1.5) | 1.5 |
| use_stddev | bool | Use Standard Deviation instead of IQR (default: False) | False |

Returns:

| Type | Description |
| --- | --- |
| DataFrame | pd.DataFrame: A DataFrame of outliers from this DataSource |

Notes:

Uses the IQR * 1.5 (~= 2.5 Sigma); use 1.7 for ~= 3 Sigma. The scale parameter can be adjusted to change the IQR multiplier.

Source code in src/workbench/core/artifacts/athena_source.py
@cache_dataframe("outliers")
def outliers(self, scale: float = 1.5, use_stddev=False) -> pd.DataFrame:
    """Compute outliers for all the numeric columns in a DataSource

    Args:
        scale (float): The scale to use for the IQR (default: 1.5)
        use_stddev (bool): Use Standard Deviation instead of IQR (default: False)

    Returns:
        pd.DataFrame: A DataFrame of outliers from this DataSource

    Notes:
        Uses the IQR * 1.5 (~= 2.5 Sigma) (use 1.7 for ~= 3 Sigma)
        The scale parameter can be adjusted to change the IQR multiplier
    """

    # Compute outliers using the SQL Outliers class
    sql_outliers = sql.outliers.Outliers()
    return sql_outliers.compute_outliers(self, scale=scale, use_stddev=use_stddev)
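
A sketch of the scale and use_stddev options (the DataSource name is a placeholder):

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")                  # placeholder name
iqr_outliers = ds.outliers()                    # IQR * 1.5 (~= 2.5 sigma)
wider_net = ds.outliers(scale=1.7)              # IQR * 1.7 (~= 3 sigma)
stddev_outliers = ds.outliers(use_stddev=True)  # standard-deviation based instead of IQR
```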

query(query)

Query the AthenaSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | str | The query to run against the AthenaSource | required |

Returns:

| Type | Description |
| --- | --- |
| Union[DataFrame, None] | pd.DataFrame: The results of the query |

Source code in src/workbench/core/artifacts/athena_source.py
def query(self, query: str) -> Union[pd.DataFrame, None]:
    """Query the AthenaSource

    Args:
        query (str): The query to run against the AthenaSource

    Returns:
        pd.DataFrame: The results of the query
    """

    # Call internal class _query method
    return self.database_query(self.database, query)
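
A minimal query sketch (the DataSource name is a placeholder); like database_query(), this returns None if the Athena query fails:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")  # placeholder name
df = ds.query('select * from "test_data" limit 5')

if df is not None:
    print(df.head())
```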

refresh_meta()

Refresh our internal AWS Broker catalog metadata

Source code in src/workbench/core/artifacts/athena_source.py
def refresh_meta(self):
    """Refresh our internal AWS Broker catalog metadata"""
    self.data_source_meta = self.meta.data_source(self.uuid, database=self.database)

s3_storage_location()

Get the S3 Storage Location for this Data Source

Source code in src/workbench/core/artifacts/athena_source.py
def s3_storage_location(self) -> str:
    """Get the S3 Storage Location for this Data Source"""
    return self.data_source_meta["StorageDescriptor"]["Location"]

sample()

Pull a sample of rows from the DataSource

Returns:

| Type | Description |
| --- | --- |
| DataFrame | pd.DataFrame: A sample DataFrame for an Athena DataSource |

Source code in src/workbench/core/artifacts/athena_source.py
@cache_dataframe("sample")
def sample(self) -> pd.DataFrame:
    """Pull a sample of rows from the DataSource

    Returns:
        pd.DataFrame: A sample DataFrame for an Athena DataSource
    """

    # Call the SQL function to pull a sample of the rows
    return sql.sample_rows(self)

size()

Return the size of this data in MegaBytes

Source code in src/workbench/core/artifacts/athena_source.py
def size(self) -> float:
    """Return the size of this data in MegaBytes"""
    size_in_bytes = sum(wr.s3.size_objects(self.s3_storage_location(), boto3_session=self.boto3_session).values())
    size_in_mb = size_in_bytes / 1_000_000
    return size_in_mb

smart_sample(recompute=False)

Get a smart sample dataframe for this DataSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recompute | bool | Recompute the smart sample (default: False) | False |

Returns:

| Type | Description |
| --- | --- |
| DataFrame | pd.DataFrame: A combined DataFrame of sample data + outliers |

Source code in src/workbench/core/artifacts/athena_source.py
@cache_dataframe("smart_sample")
def smart_sample(self, recompute: bool = False) -> pd.DataFrame:
    """Get a smart sample dataframe for this DataSource

    Args:
        recompute (bool): Recompute the smart sample (default: False)

    Returns:
        pd.DataFrame: A combined DataFrame of sample data + outliers
    """

    # Compute/recompute the smart sample
    self.log.important(f"Computing Smart Sample {self.uuid}...")

    # Outliers DataFrame
    outlier_rows = self.outliers()

    # Sample DataFrame
    sample_rows = self.sample()
    sample_rows["outlier_group"] = "sample"

    # Combine the sample rows with the outlier rows
    all_rows = pd.concat([outlier_rows, sample_rows]).reset_index(drop=True)

    # Drop duplicates
    all_except_outlier_group = [col for col in all_rows.columns if col != "outlier_group"]
    all_rows = all_rows.drop_duplicates(subset=all_except_outlier_group, ignore_index=True)

    # Return the smart_sample data
    return all_rows
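
A short sketch (the DataSource name is a placeholder); the returned DataFrame is the de-duplicated union of sample() and outliers():

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")             # placeholder name
df = ds.smart_sample()                     # sample rows + outlier rows, duplicates dropped
print(df["outlier_group"].value_counts())  # 'sample' for sampled rows, group labels for outliers
```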

table_hash()

Get the table hash for this AthenaSource

Source code in src/workbench/core/artifacts/athena_source.py
def table_hash(self) -> str:
    """Get the table hash for this AthenaSource"""
    s3_scratch = f"s3://{self.workbench_bucket}/temp/athena_output"
    return compute_athena_table_hash(self.database, self.table, self.boto3_session, s3_scratch)

upsert_workbench_meta(new_meta)

Add Workbench specific metadata to this Artifact

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| new_meta | dict | Dictionary of new metadata to add | required |

Source code in src/workbench/core/artifacts/athena_source.py
def upsert_workbench_meta(self, new_meta: dict):
    """Add Workbench specific metadata to this Artifact

    Args:
        new_meta (dict): Dictionary of new metadata to add
    """
    self.log.important(f"Upserting Workbench Metadata {self.uuid}:{str(new_meta)[:50]}...")

    # Give a warning message for keys that don't start with workbench_
    for key in new_meta.keys():
        if not key.startswith("workbench_"):
            self.log.warning("Append 'workbench_' to key names to avoid overwriting AWS meta data")

    # Now convert any non-string values to JSON strings
    for key, value in new_meta.items():
        if not isinstance(value, str):
            new_meta[key] = json.dumps(value, cls=CustomEncoder)

    # Store our updated metadata
    try:
        wr.catalog.upsert_table_parameters(
            parameters=new_meta,
            database=self.database,
            table=self.table,
            boto3_session=self.boto3_session,
        )
    except botocore.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == "InvalidInputException":
            self.log.error(f"Unable to upsert metadata for {self.table}")
            self.log.error("Probably because the metadata is too large")
            self.log.error(new_meta)
        elif error_code == "ConcurrentModificationException":
            self.log.warning("ConcurrentModificationException... trying again...")
            time.sleep(5)
            wr.catalog.upsert_table_parameters(
                parameters=new_meta,
                database=self.database,
                table=self.table,
                boto3_session=self.boto3_session,
            )
        else:
            self.log.critical(f"Failed to upsert metadata: {e}")
            self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")
    except Exception as e:
        self.log.critical(f"Failed to upsert metadata: {e}")
        self.log.critical(f"{self.uuid} is Malformed! Delete this Artifact and recreate it!")

value_counts(recompute=False)

Compute 'value_counts' for all the string columns in a DataSource

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recompute | bool | Recompute the value counts (default: False) | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict | A dictionary of value counts for each column in the form {'col1': {'value_1': 42, 'value_2': 16, 'value_3': 9, ...}, 'col2': ...} |

Source code in src/workbench/core/artifacts/athena_source.py
def value_counts(self, recompute: bool = False) -> dict[dict]:
    """Compute 'value_counts' for all the string columns in a DataSource

    Args:
        recompute (bool): Recompute the value counts (default: False)

    Returns:
        dict(dict): A dictionary of value counts for each column in the form
             {'col1': {'value_1': 42, 'value_2': 16, 'value_3': 9,...},
              'col2': ...}
    """

    # First check if we have already computed the value counts
    value_counts_dict = self.workbench_meta().get("workbench_value_counts")
    if value_counts_dict and not recompute:
        return value_counts_dict

    # Call the SQL function to compute value_counts
    value_count_dict = sql.value_counts(self)

    # Push the value_count data into our DataSource Metadata
    self.upsert_workbench_meta({"workbench_value_counts": value_count_dict})

    # Return the value_count data
    return value_count_dict
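
A minimal sketch (the DataSource and column names are placeholders); like the other stats methods, the result is cached in the Workbench metadata:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")             # placeholder name
counts = ds.value_counts()                 # cached in workbench_value_counts after first run
print(counts.get("my_string_column", {}))  # hypothetical column name
```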

workbench_meta()

Get the Workbench specific metadata for this Artifact

Source code in src/workbench/core/artifacts/athena_source.py
def workbench_meta(self) -> dict:
    """Get the Workbench specific metadata for this Artifact"""

    # Sanity Check if we have invalid AWS Metadata
    if self.data_source_meta is None:
        if not self.exists():
            self.log.error(f"DataSource {self.uuid} doesn't appear to exist...")
        else:
            self.log.critical(f"Unable to get AWS Metadata for {self.table}")
            self.log.critical("Malformed Artifact! Delete this Artifact and recreate it!")
        return {}

    # Get the Workbench Metadata from the 'Parameters' section of the DataSource Metadata
    params = self.data_source_meta.get("Parameters", {})
    return {key: decode_value(value) for key, value in params.items() if "workbench" in key}
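
A small sketch of reading the metadata back (the DataSource name and key are placeholders); only the 'workbench*' keys from the Glue table Parameters are returned:

```python
from workbench.core.artifacts.athena_source import AthenaSource

ds = AthenaSource("test_data")      # placeholder name
meta = ds.workbench_meta()          # decoded 'workbench*' parameters only
print(meta.get("workbench_owner"))  # hypothetical key
```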