InferenceCache

InferenceCache is a caching wrapper around a Workbench Endpoint. It's handy when an endpoint is slow to invoke and the same inputs show up across calls — the motivating example is the 3D molecular feature endpoint smiles-to-3d-fast-v1, which takes real time to generate conformers and force-field optimize each molecule.

On each inference(df) call, rows whose cache-key value is already in the cache are served from S3, and only the new rows go to the underlying endpoint. Newly-computed rows are written back to the cache. The cache lives in a shared S3-backed DFStore, so once one person has computed a row, everyone gets it for free.
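Conceptually, each inference(df) call splits the input on the cache-key column before anything hits the endpoint. A simplified sketch of that split (not the actual implementation; the real logic lives in inference(), shown in full further below):

import pandas as pd

# Sketch of the hit/miss split performed inside inference()
def split_rows(eval_df: pd.DataFrame, cache_df: pd.DataFrame, key_col: str = "smiles"):
    is_cached = eval_df[key_col].isin(cache_df[key_col])
    cached_hits = cache_df[cache_df[key_col].isin(eval_df[key_col])]  # rows served from the S3 cache
    uncached_df = eval_df[~is_cached]                                 # rows sent to the endpoint
    return cached_hits, uncached_df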

Not the same as workbench.cached.CachedEndpoint

CachedEndpoint caches metadata methods like summary(), details(), and health_check(). InferenceCache caches inference results. Different classes, different concerns.

Example

inference_cache_example.py
from workbench.api import Endpoint, FeatureSet, InferenceCache

# Wrap a slow endpoint in an InferenceCache
endpoint = Endpoint("smiles-to-3d-fast-v1")
cached_endpoint = InferenceCache(endpoint, cache_key_column="smiles")

# Pull a DataFrame of molecules and run inference
df = FeatureSet("feature_endpoint_fs").pull_dataframe()[:50]

# First call: slow (cache is empty, rows go to the endpoint)
results = cached_endpoint.inference(df)

# Second call with the same SMILES: near-instant (all hits)
results_again = cached_endpoint.inference(df)

# Drop a bad row so it recomputes on the next call
cached_endpoint.delete_entries("c1ccc(cc1)C(=O)O")

# Or drop many at once
cached_endpoint.delete_entries(["CCO", "CCN", "CCOCC"])

# Inspect the cache
print(cached_endpoint.cache_size())
print(cached_endpoint.cache_info())

Output (log lines)

InferenceCache[smiles-to-3d-fast-v1]: 0/50 cache hits
InferenceCache[smiles-to-3d-fast-v1]: computing 50 new rows via endpoint
InferenceCache[smiles-to-3d-fast-v1]: 50/50 cache hits
InferenceCache[smiles-to-3d-fast-v1]: removed 1 entries
InferenceCache[smiles-to-3d-fast-v1]: removed 3 entries

Endpoint change detection

By default, InferenceCache keeps the existing cache regardless of endpoint changes. If you want it to automatically clear the cache when the endpoint has been modified since the cache was last written, pass auto_invalidate_cache=True:

cached_endpoint = InferenceCache(endpoint, cache_key_column="smiles", auto_invalidate_cache=True)

A tiny sidecar manifest stores the endpoint's modified() timestamp; when auto-invalidation is enabled, the cache is cleared on the next access if the stored and current timestamps differ.
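If you leave auto-invalidation off, you can run the same check yourself using the manifest that cache_info() exposes. A minimal sketch, assuming a manifest has already been written for this cache:

# Manual endpoint-change check when auto_invalidate_cache=False (sketch)
manifest = cached_endpoint.cache_info()["manifest"] or {}
if manifest.get("endpoint_modified") != str(endpoint.modified()):
    cached_endpoint.clear_cache()  # next inference() call recomputes everything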

Attribute delegation

InferenceCache forwards anything it doesn't define to the wrapped endpoint, so cached_endpoint.name, cached_endpoint.details(), cached_endpoint.fast_inference(), etc. all Just Work.
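Because of this delegation, the wrapper can usually stand in wherever a plain Endpoint is expected. A tiny sketch (report_on is a hypothetical caller, not part of Workbench):

# Code that only touches Endpoint attributes works with the cached wrapper too
def report_on(ep) -> None:
    print(ep.name)           # delegated to the wrapped endpoint
    print(ep.health_check()) # also delegated

report_on(cached_endpoint)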

API Reference

InferenceCache: Client-side caching wrapper around a Workbench Endpoint.

Wraps an Endpoint and stores inference results in a shared S3-backed DFStore keyed on a cache-key column (SMILES by default). On each inference(df) call, rows whose cache-key value is already in the cache are served from S3, and only the remaining rows are sent to the underlying endpoint. Newly computed rows are written back to the cache.

Motivating use case: the smiles-to-3d-fast-v1 feature endpoint is slow (conformer generation + FF optimization), and the same SMILES is frequently re-computed across calls.

Note: this is distinct from workbench.cached.CachedEndpoint, which caches metadata methods (summary, details, health_check). This class caches inference results.

InferenceCache

InferenceCache: Client-side caching wrapper for a Workbench Endpoint.

Common Usage
from workbench.api import Endpoint
from workbench.api.inference_cache import InferenceCache

endpoint = Endpoint("smiles-to-3d-fast-v1")
cached_endpoint = InferenceCache(endpoint, cache_key_column="smiles")

# Drop-in replacement for endpoint.inference()
result_df = cached_endpoint.inference(eval_df)

# Other endpoint methods still work via attribute delegation
print(cached_endpoint.name)
cached_endpoint.details()
Source code in src/workbench/api/inference_cache.py
class InferenceCache:
    """InferenceCache: Client-side caching wrapper for a Workbench Endpoint.

    Common Usage:
        ```python
        from workbench.api import Endpoint
        from workbench.api.inference_cache import InferenceCache

        endpoint = Endpoint("smiles-to-3d-fast-v1")
        cached_endpoint = InferenceCache(endpoint, cache_key_column="smiles")

        # Drop-in replacement for endpoint.inference()
        result_df = cached_endpoint.inference(eval_df)

        # Other endpoint methods still work via attribute delegation
        print(cached_endpoint.name)
        cached_endpoint.details()
        ```
    """

    # Rows per cache write. The endpoint is called once per chunk and the
    # cache is persisted between chunks, so this also bounds the blast radius
    # of an interrupted/failed write to one chunk worth of work.
    #
    # The actual chunk_size on each instance is set in __init__: either the
    # explicit ``chunk_size`` constructor kwarg, or — for async endpoints with
    # max_instances in their workbench_meta — derived from fleet capacity
    # (max_instances × batch_size × 2) so each chunk holds an integer number
    # of full fleet-waves. This avoids the "10 batches / 8 instances → tail"
    # utilization loss. Falls back to this class attribute (DEFAULT_CHUNK_SIZE)
    # for sync endpoints or legacy endpoints without max_instances in meta.
    chunk_size: int = DEFAULT_CHUNK_SIZE

    # Number of fleet-waves per chunk when auto-deriving chunk_size. With k
    # batches per worker, relative tail-variance scales as 1/√k, so bumping
    # k from 2 to 4 cuts tail overhead ~30% without changing batch_size (so
    # per-batch polling cost is unchanged). Crash-recovery loss is one chunk
    # = 4 fleet-waves of work, modest for any reasonable batch pipeline.
    _CHUNK_WAVES = 4

    def __init__(
        self,
        endpoint: Endpoint,
        cache_key_column: str = "smiles",
        output_key_column: Optional[str] = None,
        auto_invalidate_cache: bool = False,
        chunk_size: Optional[int] = None,
    ):
        """Initialize the InferenceCache.

        Args:
            endpoint (Endpoint): The Workbench Endpoint to wrap.
            cache_key_column (str): Name of the column whose values are used
                as the cache key (default: "smiles").
            output_key_column (Optional[str]): Name of the column in the
                endpoint's *output* that contains the original input key
                values. Some endpoints normalize/canonicalize the key column
                (e.g. canonical SMILES) and place the original value in a
                separate column (e.g. "orig_smiles"). When set, the cache
                uses this column's values as the key so future lookups with
                the original input values still hit. When None (default),
                the cache key column in the output is assumed to match the
                input unchanged.
            auto_invalidate_cache (bool): When True, automatically clear the
                cache if the endpoint has been modified since the cache was
                last written. When False (default), the existing cache is
                kept regardless of endpoint changes — the manifest is
                reseeded on first load so subsequent calls have a consistent
                baseline.
            chunk_size (Optional[int]): Rows per cache write. If ``None``
                (default), derived from the endpoint's ``max_instances`` and
                ``inference_batch_size`` to produce full fleet-waves — see
                :meth:`_derive_chunk_size`. Falls back to
                ``DEFAULT_CHUNK_SIZE`` when fleet info isn't available.
        """
        self._endpoint = endpoint
        self.cache_key_column = cache_key_column
        self.output_key_column = output_key_column
        self.cache_path = f"/workbench/inference_cache/{endpoint.name}"
        self.manifest_path = f"{self.cache_path}__meta"
        self._df_store = DFStore()
        self._cache_df: Optional[pd.DataFrame] = None  # lazy-loaded
        self._invalidation_checked = False  # per-instance, one-shot
        self._auto_invalidate_cache = auto_invalidate_cache
        # Canonical dtype map for the cache, captured on first non-empty load
        # and used to coerce subsequent appended chunks so concurrent writers
        # never produce a schema-incompatible dataset.
        self._canonical_dtypes: Optional[dict] = None
        # (column, src_dtype, tgt_dtype) tuples we've already warned about —
        # keeps the coerce loop quiet when the same mismatch recurs every chunk.
        self._coerce_warned: set[tuple] = set()
        self.log = logging.getLogger("workbench")

        # Resolve chunk_size: explicit override wins; else try fleet-derivation;
        # else fall through to the class-level DEFAULT_CHUNK_SIZE.
        if chunk_size is not None:
            self.chunk_size = int(chunk_size)
        else:
            derived = self._derive_chunk_size()
            if derived is not None:
                self.chunk_size = derived

    def _derive_chunk_size(self) -> Optional[int]:
        """Derive chunk_size from the wrapped endpoint's fleet capacity.

        Returns ``capacity × batch_size × _CHUNK_WAVES`` so each chunk
        holds an integer number of full fleet-waves — preventing the
        "10 batches / 8 instances → 2-batch tail" utilization loss on async
        endpoints. ``capacity`` prefers ``effective_max_instances`` when
        present (set by MetaEndpoint to reflect the largest child fleet,
        since the meta itself deploys with ``max_instances=1`` regardless
        of downstream capacity), otherwise falls back to ``max_instances``.
        Returns ``None`` (→ caller should use DEFAULT_CHUNK_SIZE) when the
        endpoint's meta has neither, which is the case for sync endpoints
        and legacy async deploys.
        """
        try:
            meta = self._endpoint.workbench_meta() or {}
        except Exception:
            return None
        capacity = meta.get("effective_max_instances", meta.get("max_instances"))
        if capacity is None:
            return None
        # Mirror AsyncEndpointCore's own resolution: explicit meta override
        # wins, otherwise the core default.
        from workbench.core.artifacts.async_endpoint_core import _DEFAULT_BATCH_SIZE

        batch_size = int(meta.get("inference_batch_size", _DEFAULT_BATCH_SIZE))
        derived = int(capacity) * batch_size * self._CHUNK_WAVES
        self.log.info(
            f"InferenceCache[{self._endpoint.name}]: chunk_size={derived} "
            f"(capacity={capacity} × batch_size={batch_size} × "
            f"{self._CHUNK_WAVES} waves — full fleet utilization per chunk)"
        )
        return derived

    def __getattr__(self, name):
        """Delegate any unrecognized attribute access to the wrapped Endpoint."""
        # __getattr__ is only called when normal lookup fails, so this won't
        # interfere with our own attributes.
        return getattr(self._endpoint, name)

    def inference(self, eval_df: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
        """Run cached inference on ``eval_df``.

        Rows whose ``cache_key_column`` value is already in the cache are
        served from S3; the rest are sent to the underlying endpoint and the
        new results are written back to the cache. The returned DataFrame
        preserves the original row order of ``eval_df``.

        Args:
            eval_df (pd.DataFrame): DataFrame to run predictions on. Must
                contain ``self.cache_key_column``.
            **kwargs (Any): Forwarded to the wrapped ``Endpoint.inference()``
                for uncached rows.

        Returns:
            pd.DataFrame: ``eval_df`` with the endpoint's added columns
            left-joined on ``cache_key_column``.
        """
        key_col = self.cache_key_column
        if key_col not in eval_df.columns:
            raise ValueError(f"eval_df is missing required cache_key_column '{key_col}'")

        cache_df = self._load_cache()

        # Split eval rows into cache hits vs rows we still need to compute
        is_cached = eval_df[key_col].isin(cache_df[key_col])
        uncached_df = eval_df[~is_cached]
        cached_hits = cache_df[cache_df[key_col].isin(eval_df[key_col])]

        hits = len(eval_df) - len(uncached_df)
        self.log.info(f"InferenceCache[{self._endpoint.name}]: {hits}/{len(eval_df)} cache hits")

        # Run the endpoint on the uncached rows. The decorator on
        # _chunked_endpoint_inference handles chunking, per-chunk cache
        # writes, and error recovery so a single failed write doesn't
        # destroy the rest of the batch.
        new_results = pd.DataFrame()
        if not uncached_df.empty:
            to_compute = uncached_df.drop_duplicates(subset=[key_col])
            new_results = self._chunked_endpoint_inference(to_compute, **kwargs)

        # Combine cached + new into a single feature table, then left-join
        # back onto eval_df to preserve row order and any extra input columns.
        # (Filter out empty frames to dodge a pandas FutureWarning about
        # dtype inference on empty entries.)
        frames = [f for f in (cached_hits, new_results) if not f.empty]
        if not frames:
            return eval_df.copy()
        feature_table = pd.concat(frames, ignore_index=True).drop_duplicates(subset=[key_col], keep="last")
        feature_cols = [c for c in feature_table.columns if c not in eval_df.columns]
        return eval_df.merge(feature_table[[key_col] + feature_cols], on=key_col, how="left")

    # ---- cache introspection / maintenance ----
    def cache_size(self) -> int:
        """Number of rows currently in the cache."""
        return len(self._load_cache())

    def cache_info(self) -> dict:
        """Summary of the cache: path, row count, columns, manifest."""
        df = self._load_cache()
        return {
            "path": self.cache_path,
            "rows": len(df),
            "columns": list(df.columns),
            "manifest": self._load_manifest(),
        }

    def clear_cache(self) -> None:
        """Delete the cache (and manifest) from S3 and reset in-memory state."""
        if self._df_store.check(self.cache_path):
            self._df_store.delete(self.cache_path)
        if self._df_store.check(self.manifest_path):
            self._df_store.delete(self.manifest_path)
        self._cache_df = pd.DataFrame(columns=[self.cache_key_column])

    def delete_entries(self, keys: Union[Any, Iterable[Any]]) -> int:
        """Remove one or more entries from the cache by cache-key value(s).

        Use this to drop bad results that should be recomputed on the next
        ``inference()`` call.

        Args:
            keys (Union[Any, Iterable[Any]]): A single cache-key value, or an
                iterable of them.

        Returns:
            int: Number of rows removed from the cache.
        """
        if isinstance(keys, (str, bytes)) or not hasattr(keys, "__iter__"):
            keys = [keys]
        keys = list(keys)

        cache_df = self._load_cache()
        if cache_df.empty:
            return 0

        keep_mask = ~cache_df[self.cache_key_column].isin(keys)
        removed = int((~keep_mask).sum())
        if removed == 0:
            return 0

        new_cache = cache_df[keep_mask].reset_index(drop=True)
        if new_cache.empty:
            # Nothing left — delete the cache file entirely but keep the manifest
            if self._df_store.check(self.cache_path):
                self._df_store.delete(self.cache_path)
        else:
            self._df_store.upsert(self.cache_path, new_cache)
        self._cache_df = new_cache
        self.log.info(f"InferenceCache[{self._endpoint.name}]: removed {removed} entries")
        return removed

    # ---- internals ----

    def _load_cache(self) -> pd.DataFrame:
        """Lazily load the cache DataFrame from DFStore.

        If the cache doesn't yet exist, returns an empty DataFrame that
        still has ``cache_key_column`` defined, so callers can always do
        ``df[cache_key_column]`` without special-casing the empty case.

        On first call, also checks whether the endpoint has been modified
        since the cache was written and auto-invalidates if so.
        """
        if self._cache_df is None:
            if not self._invalidation_checked:
                if self._auto_invalidate_cache:
                    self._check_endpoint_changed()
                else:
                    # Skip the auto-invalidation check and reseed the manifest
                    # so the stored modified time matches the current endpoint.
                    self.log.info(
                        f"InferenceCache[{self._endpoint.name}]: auto_invalidate_cache=False, "
                        f"reseeding manifest and keeping existing cache"
                    )
                    if self._df_store.check(self.cache_path):
                        self._save_manifest()
                self._invalidation_checked = True

            df = self._read_cache_with_retry()
            if df is None:
                df = pd.DataFrame(columns=[self.cache_key_column])
            if not df.empty:
                # Seed canonical only from columns with real data; all-null
                # columns carry no dtype signal and must wait for real values.
                self._seed_canonical_from(df)
            self._cache_df = df
        return self._cache_df

    def _read_cache_with_retry(self, attempts: int = 3, backoff: float = 0.5) -> Optional[pd.DataFrame]:
        """Read the cache, tolerating transient and schema-mismatch failures.

        Handles two distinct error classes:

        - **Transient** (e.g. ``NoSuchKey`` from a concurrent overwrite or
          compaction) — bounded retries with short backoff. The cache may
          recover within a couple of seconds.
        - **Schema-mismatch** (PyArrow ``ArrowTypeError`` / ``ArrowInvalid``
          from incompatible parquet files under the dataset prefix) — not
          transient; don't retry. Log and return ``None`` so the caller
          treats the cache as empty and the affected rows recompute on the
          next inference call. Run :meth:`clear_cache` or inspect the files
          manually to resolve; :meth:`compact` also reads through the same
          path so it will not self-heal this case.
        """
        try:
            import pyarrow as pa  # deferred import; pa is transitively installed

            schema_errs: tuple = (pa.lib.ArrowTypeError, pa.lib.ArrowInvalid)
        except Exception:
            schema_errs = ()

        last_err: Optional[Exception] = None
        for i in range(attempts):
            try:
                return self._df_store.get(self.cache_path)
            except schema_errs as e:
                self.log.error(
                    f"InferenceCache[{self._endpoint.name}]: cache is schema-"
                    f"incompatible ({type(e).__name__}: {e}). Treating as empty; "
                    f"rows will recompute. Run clear_cache() to reset."
                )
                return None
            except Exception as e:
                last_err = e
                self.log.warning(
                    f"InferenceCache[{self._endpoint.name}]: cache read failed "
                    f"(attempt {i + 1}/{attempts}): {type(e).__name__}: {e}"
                )
                time.sleep(backoff * (i + 1))
        self.log.error(
            f"InferenceCache[{self._endpoint.name}]: cache read failed after "
            f"{attempts} attempts, treating as empty: {last_err}"
        )
        return None

    @chunked_with_cache_writes
    def _chunked_endpoint_inference(self, chunk: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Run the wrapped endpoint on one chunk of rows.

        The :func:`chunked_with_cache_writes` decorator handles chunking,
        per-chunk persistence via :meth:`_update_cache`, and error recovery.
        """
        return self._endpoint.inference(chunk, **kwargs)

    def _update_cache(self, new_results: pd.DataFrame) -> None:
        """Persist ``new_results`` as a new file under the cache prefix.

        Uses :meth:`DFStore.append` so concurrent writers each land a distinct
        parquet file under the dataset prefix — eliminating the delete-then-
        write race of the old overwrite-based approach. The in-memory view
        (``self._cache_df``) is updated by concat+dedup so subsequent chunks
        in this process skip rows this worker just computed, but S3 only
        receives the new slice. All-null columns are dropped from the write
        (they'd otherwise pin a spurious dtype from pandas inference and
        drift against later chunks that have real values); missing columns
        on read are restored as ``NaN``. Remaining columns are coerced to
        the canonical schema so concurrently-appended files stay
        Arrow-mergeable.
        """
        if new_results.empty:
            return

        to_write = self._drop_all_null_columns(new_results)
        to_write = self._coerce_to_canonical(to_write)

        self._df_store.append(self.cache_path, to_write)
        self._save_manifest()

        # Update the local view so this worker's later chunks see these rows.
        old_cache = self._cache_df if self._cache_df is not None else pd.DataFrame(columns=[self.cache_key_column])
        frames = [f for f in (old_cache, to_write) if not f.empty]
        self._cache_df = pd.concat(frames, ignore_index=True).drop_duplicates(
            subset=[self.cache_key_column], keep="last"
        )

    @staticmethod
    def _column_has_data(s: pd.Series) -> bool:
        """True when a column has at least one non-null value.

        Columns that are entirely null carry no usable dtype signal — pandas
        infers ``float64`` for all-``NaN``, ``object`` for all-``None``, and
        pyarrow round-trips can reshape them further (``Int64`` for nullable
        ints, etc). We treat those columns as dtype-indeterminate.
        """
        return len(s) > 0 and not s.isna().all()

    def _drop_all_null_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Drop columns whose values are entirely null in this chunk.

        Dropped columns reappear as ``NaN`` on read thanks to Arrow's
        dataset-schema unification, so this is lossless for consumers. The
        win: we don't pin a dtype for a column based on nothing, which would
        drift against later chunks that contain real values.

        The cache-key column is preserved even if somehow all null so the
        write still has a valid key column; downstream code relies on it
        existing.
        """
        keep = [c for c in df.columns if c == self.cache_key_column or self._column_has_data(df[c])]
        if len(keep) == len(df.columns):
            return df
        dropped = [c for c in df.columns if c not in keep]
        # Only log once per (column, instance) — very chatty otherwise.
        for c in dropped:
            key = ("drop_null", c)
            if key not in self._coerce_warned:
                self._coerce_warned.add(key)
                self.log.info(
                    f"InferenceCache[{self._endpoint.name}]: column '{c}' is "
                    f"all-null in this write; dropping from parquet (will "
                    f"read back as NaN until a chunk with real values lands)."
                )
        return df[keep]

    def _seed_canonical_from(self, df: pd.DataFrame) -> None:
        """Add columns with real data to the canonical dtype map.

        Called on first load and on every write so a column's dtype gets
        pinned the first time we actually see a non-null value for it.
        Columns that stay all-null remain absent from the map — no pinning.
        """
        if self._canonical_dtypes is None:
            self._canonical_dtypes = {}
        for col in df.columns:
            if col in self._canonical_dtypes:
                continue
            if self._column_has_data(df[col]):
                self._canonical_dtypes[col] = df[col].dtype

    def _coerce_to_canonical(self, df: pd.DataFrame) -> pd.DataFrame:
        """Cast ``df`` columns to match the canonical schema when one exists.

        Types are preserved as the endpoint produces them. The canonical map
        is seeded lazily, column by column, from the first chunk in which a
        column has at least one non-null value — so an early all-``NaN``
        observation never pins a spurious dtype. Subsequent writes coerce
        to match. No widening is performed.

        - Obviously-incompatible pairs (string/object source → numeric
          target) are skipped silently: the endpoint is the source of truth
          for its own output types, and attempting the cast would just raise
          every chunk.
        - Other coerce failures log a warning **once** per
          ``(column, src, tgt)`` tuple per instance, to avoid per-chunk
          log spam when a column genuinely drifts.
        - If schema drift does poison the dataset on disk,
          :meth:`_read_cache_with_retry` falls back to an empty result so
          rows recompute; :meth:`clear_cache` resets cleanly.
        """
        # Seed / extend canonical from any newly-observed real data.
        self._seed_canonical_from(df)

        if not self._canonical_dtypes:
            return df

        out = df.copy()
        for col, target_dtype in self._canonical_dtypes.items():
            if col not in out.columns:
                continue
            src_dtype = out[col].dtype
            if src_dtype == target_dtype:
                continue

            src_is_stringish = pd.api.types.is_string_dtype(src_dtype) or pd.api.types.is_object_dtype(src_dtype)
            tgt_is_numeric = pd.api.types.is_numeric_dtype(target_dtype) and not pd.api.types.is_bool_dtype(
                target_dtype
            )
            if src_is_stringish and tgt_is_numeric:
                continue

            try:
                out[col] = out[col].astype(target_dtype)
            except Exception as e:
                key = (col, str(src_dtype), str(target_dtype))
                if key not in self._coerce_warned:
                    self._coerce_warned.add(key)
                    self.log.warning(
                        f"InferenceCache[{self._endpoint.name}]: could not coerce "
                        f"column '{col}' from {src_dtype} to {target_dtype}: {e}. "
                        f"Writes will use the endpoint's dtype; run compact() or "
                        f"clear_cache() if the dataset becomes unreadable."
                    )
        return out

    def compact(self) -> int:
        """Merge all per-chunk append files into a single deduped file.

        Append-only writes accumulate one file per ``_update_cache`` call. Over
        time this inflates S3 object count, list costs, and read latency.
        ``compact()`` reads the whole cache (as one dataset), dedups on
        ``cache_key_column``, and rewrites it via ``upsert`` (which uses
        ``mode="overwrite"``). Expected cadence: weekly / monthly as a
        maintenance op, not on the hot inference path. Do not run during
        active inference traffic — the rewrite races with concurrent
        appenders the same way any overwrite does, and can lose recent rows
        (they'll just be recomputed on the next call).

        Returns:
            int: Row count after compaction.
        """
        df = self._read_cache_with_retry()
        if df is None or df.empty:
            self.log.info(f"InferenceCache[{self._endpoint.name}]: nothing to compact")
            return 0

        before = len(df)
        df = df.drop_duplicates(subset=[self.cache_key_column], keep="last").reset_index(drop=True)
        after = len(df)

        self._df_store.upsert(self.cache_path, df)
        self._save_manifest()
        self._cache_df = df
        # Reseed canonical from real data only (all-null columns stay flexible).
        self._canonical_dtypes = None
        self._seed_canonical_from(df)

        self.log.info(f"InferenceCache[{self._endpoint.name}]: compacted {before} -> {after} rows")
        return after

    # ---- endpoint-change detection ----

    def _current_endpoint_modified(self) -> Optional[str]:
        """Read the endpoint's current 'modified' timestamp.

        Stringified so the comparison is robust to tz-aware/naive datetime
        round-tripping through parquet.
        """
        try:
            modified = self._endpoint.modified()
        except Exception as e:
            self.log.warning(
                f"InferenceCache[{self._endpoint.name}]: could not read "
                f"endpoint modified time for change detection: {e}"
            )
            return None
        return str(modified) if modified is not None else None

    def _load_manifest(self) -> Optional[dict]:
        """Load the sidecar manifest (or None if it doesn't exist)."""
        df = self._df_store.get(self.manifest_path)
        if df is None or df.empty:
            return None
        return df.iloc[0].to_dict()

    def _save_manifest(self) -> None:
        """Write the sidecar manifest capturing the endpoint's current state."""
        manifest_df = pd.DataFrame(
            [
                {
                    "endpoint_name": self._endpoint.name,
                    "endpoint_modified": self._current_endpoint_modified(),
                    "cache_key_column": self.cache_key_column,
                }
            ]
        )
        self._df_store.upsert(self.manifest_path, manifest_df)

    def _check_endpoint_changed(self) -> None:
        """Compare the stored manifest against the endpoint's current modified time.

        - If no manifest exists, seed one (first run after a clean slate).
        - If the stored and current modified times differ, warn and clear
          the cache so the next call recomputes from scratch.
        """
        manifest = self._load_manifest()
        current = self._current_endpoint_modified()

        if manifest is None:
            # No manifest yet — seed one if there's already a cache, so the
            # next check has something to compare against. (If there's no
            # cache either, the manifest will be written on first update.)
            if self._df_store.check(self.cache_path) and current is not None:
                self._save_manifest()
            return

        stored = manifest.get("endpoint_modified")
        if stored is None or current is None or stored == current:
            return

        self.log.warning(
            f"InferenceCache[{self._endpoint.name}]: endpoint was modified "
            f"since cache was written (stored={stored}, current={current}). "
            f"Auto-invalidating cache."
        )
        self.clear_cache()

__getattr__(name)

Delegate any unrecognized attribute access to the wrapped Endpoint.

Source code in src/workbench/api/inference_cache.py
def __getattr__(self, name):
    """Delegate any unrecognized attribute access to the wrapped Endpoint."""
    # __getattr__ is only called when normal lookup fails, so this won't
    # interfere with our own attributes.
    return getattr(self._endpoint, name)

__init__(endpoint, cache_key_column='smiles', output_key_column=None, auto_invalidate_cache=False, chunk_size=None)

Initialize the InferenceCache.

Parameters:

    endpoint (Endpoint, required):
        The Workbench Endpoint to wrap.

    cache_key_column (str, default 'smiles'):
        Name of the column whose values are used as the cache key.

    output_key_column (Optional[str], default None):
        Name of the column in the endpoint's output that contains the original input key values. Some endpoints normalize/canonicalize the key column (e.g. canonical SMILES) and place the original value in a separate column (e.g. "orig_smiles"). When set, the cache uses this column's values as the key so future lookups with the original input values still hit. When None, the cache key column in the output is assumed to match the input unchanged.

    auto_invalidate_cache (bool, default False):
        When True, automatically clear the cache if the endpoint has been modified since the cache was last written. When False, the existing cache is kept regardless of endpoint changes — the manifest is reseeded on first load so subsequent calls have a consistent baseline.

    chunk_size (Optional[int], default None):
        Rows per cache write. If None, derived from the endpoint's max_instances and inference_batch_size to produce full fleet-waves (see _derive_chunk_size()). Falls back to DEFAULT_CHUNK_SIZE when fleet info isn't available.
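For a concrete feel of the fleet-wave derivation, a worked example with hypothetical numbers (the real values come from the endpoint's workbench_meta):

# Hypothetical async endpoint: 8 instances, 25 rows per batch, 4 fleet-waves per chunk
capacity, batch_size, waves = 8, 25, 4
chunk_size = capacity * batch_size * waves  # 800 rows per cache write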
Source code in src/workbench/api/inference_cache.py
def __init__(
    self,
    endpoint: Endpoint,
    cache_key_column: str = "smiles",
    output_key_column: Optional[str] = None,
    auto_invalidate_cache: bool = False,
    chunk_size: Optional[int] = None,
):
    """Initialize the InferenceCache.

    Args:
        endpoint (Endpoint): The Workbench Endpoint to wrap.
        cache_key_column (str): Name of the column whose values are used
            as the cache key (default: "smiles").
        output_key_column (Optional[str]): Name of the column in the
            endpoint's *output* that contains the original input key
            values. Some endpoints normalize/canonicalize the key column
            (e.g. canonical SMILES) and place the original value in a
            separate column (e.g. "orig_smiles"). When set, the cache
            uses this column's values as the key so future lookups with
            the original input values still hit. When None (default),
            the cache key column in the output is assumed to match the
            input unchanged.
        auto_invalidate_cache (bool): When True, automatically clear the
            cache if the endpoint has been modified since the cache was
            last written. When False (default), the existing cache is
            kept regardless of endpoint changes — the manifest is
            reseeded on first load so subsequent calls have a consistent
            baseline.
        chunk_size (Optional[int]): Rows per cache write. If ``None``
            (default), derived from the endpoint's ``max_instances`` and
            ``inference_batch_size`` to produce full fleet-waves — see
            :meth:`_derive_chunk_size`. Falls back to
            ``DEFAULT_CHUNK_SIZE`` when fleet info isn't available.
    """
    self._endpoint = endpoint
    self.cache_key_column = cache_key_column
    self.output_key_column = output_key_column
    self.cache_path = f"/workbench/inference_cache/{endpoint.name}"
    self.manifest_path = f"{self.cache_path}__meta"
    self._df_store = DFStore()
    self._cache_df: Optional[pd.DataFrame] = None  # lazy-loaded
    self._invalidation_checked = False  # per-instance, one-shot
    self._auto_invalidate_cache = auto_invalidate_cache
    # Canonical dtype map for the cache, captured on first non-empty load
    # and used to coerce subsequent appended chunks so concurrent writers
    # never produce a schema-incompatible dataset.
    self._canonical_dtypes: Optional[dict] = None
    # (column, src_dtype, tgt_dtype) tuples we've already warned about —
    # keeps the coerce loop quiet when the same mismatch recurs every chunk.
    self._coerce_warned: set[tuple] = set()
    self.log = logging.getLogger("workbench")

    # Resolve chunk_size: explicit override wins; else try fleet-derivation;
    # else fall through to the class-level DEFAULT_CHUNK_SIZE.
    if chunk_size is not None:
        self.chunk_size = int(chunk_size)
    else:
        derived = self._derive_chunk_size()
        if derived is not None:
            self.chunk_size = derived

cache_info()

Summary of the cache: path, row count, columns, manifest.
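The returned dict looks roughly like this (values are illustrative):

{
    "path": "/workbench/inference_cache/smiles-to-3d-fast-v1",
    "rows": 50,
    "columns": ["smiles", ...],  # cache key plus the endpoint's output columns
    "manifest": {"endpoint_name": "smiles-to-3d-fast-v1", "endpoint_modified": "...", "cache_key_column": "smiles"},
}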

Source code in src/workbench/api/inference_cache.py
def cache_info(self) -> dict:
    """Summary of the cache: path, row count, columns, manifest."""
    df = self._load_cache()
    return {
        "path": self.cache_path,
        "rows": len(df),
        "columns": list(df.columns),
        "manifest": self._load_manifest(),
    }

cache_size()

Number of rows currently in the cache.

Source code in src/workbench/api/inference_cache.py
def cache_size(self) -> int:
    """Number of rows currently in the cache."""
    return len(self._load_cache())

clear_cache()

Delete the cache (and manifest) from S3 and reset in-memory state.

Source code in src/workbench/api/inference_cache.py
def clear_cache(self) -> None:
    """Delete the cache (and manifest) from S3 and reset in-memory state."""
    if self._df_store.check(self.cache_path):
        self._df_store.delete(self.cache_path)
    if self._df_store.check(self.manifest_path):
        self._df_store.delete(self.manifest_path)
    self._cache_df = pd.DataFrame(columns=[self.cache_key_column])

compact()

Merge all per-chunk append files into a single deduped file.

Append-only writes accumulate one file per _update_cache call. Over time this inflates S3 object count, list costs, and read latency. compact() reads the whole cache (as one dataset), dedups on cache_key_column, and rewrites it via upsert (which uses mode="overwrite"). Expected cadence: weekly / monthly as a maintenance op, not on the hot inference path. Do not run during active inference traffic — the rewrite races with concurrent appenders the same way any overwrite does, and can lose recent rows (they'll just be recomputed on the next call).

Returns:

    int: Row count after compaction.
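A periodic maintenance job can be as small as this sketch (scheduling is up to you; keep it off the hot inference path):

# Collapse accumulated per-chunk files into a single deduped file
cached_endpoint = InferenceCache(Endpoint("smiles-to-3d-fast-v1"), cache_key_column="smiles")
rows = cached_endpoint.compact()
print(f"cache now holds {rows} rows")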

Source code in src/workbench/api/inference_cache.py
def compact(self) -> int:
    """Merge all per-chunk append files into a single deduped file.

    Append-only writes accumulate one file per ``_update_cache`` call. Over
    time this inflates S3 object count, list costs, and read latency.
    ``compact()`` reads the whole cache (as one dataset), dedups on
    ``cache_key_column``, and rewrites it via ``upsert`` (which uses
    ``mode="overwrite"``). Expected cadence: weekly / monthly as a
    maintenance op, not on the hot inference path. Do not run during
    active inference traffic — the rewrite races with concurrent
    appenders the same way any overwrite does, and can lose recent rows
    (they'll just be recomputed on the next call).

    Returns:
        int: Row count after compaction.
    """
    df = self._read_cache_with_retry()
    if df is None or df.empty:
        self.log.info(f"InferenceCache[{self._endpoint.name}]: nothing to compact")
        return 0

    before = len(df)
    df = df.drop_duplicates(subset=[self.cache_key_column], keep="last").reset_index(drop=True)
    after = len(df)

    self._df_store.upsert(self.cache_path, df)
    self._save_manifest()
    self._cache_df = df
    # Reseed canonical from real data only (all-null columns stay flexible).
    self._canonical_dtypes = None
    self._seed_canonical_from(df)

    self.log.info(f"InferenceCache[{self._endpoint.name}]: compacted {before} -> {after} rows")
    return after

delete_entries(keys)

Remove one or more entries from the cache by cache-key value(s).

Use this to drop bad results that should be recomputed on the next inference() call.

Parameters:

    keys (Union[Any, Iterable[Any]], required):
        A single cache-key value, or an iterable of them.

Returns:

    int: Number of rows removed from the cache.
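The return value is handy for logging how much work will recompute (sketch):

removed = cached_endpoint.delete_entries(["CCO", "CCN"])
print(f"dropped {removed} cached rows; they recompute on the next inference() call")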

Source code in src/workbench/api/inference_cache.py
def delete_entries(self, keys: Union[Any, Iterable[Any]]) -> int:
    """Remove one or more entries from the cache by cache-key value(s).

    Use this to drop bad results that should be recomputed on the next
    ``inference()`` call.

    Args:
        keys (Union[Any, Iterable[Any]]): A single cache-key value, or an
            iterable of them.

    Returns:
        int: Number of rows removed from the cache.
    """
    if isinstance(keys, (str, bytes)) or not hasattr(keys, "__iter__"):
        keys = [keys]
    keys = list(keys)

    cache_df = self._load_cache()
    if cache_df.empty:
        return 0

    keep_mask = ~cache_df[self.cache_key_column].isin(keys)
    removed = int((~keep_mask).sum())
    if removed == 0:
        return 0

    new_cache = cache_df[keep_mask].reset_index(drop=True)
    if new_cache.empty:
        # Nothing left — delete the cache file entirely but keep the manifest
        if self._df_store.check(self.cache_path):
            self._df_store.delete(self.cache_path)
    else:
        self._df_store.upsert(self.cache_path, new_cache)
    self._cache_df = new_cache
    self.log.info(f"InferenceCache[{self._endpoint.name}]: removed {removed} entries")
    return removed

inference(eval_df, **kwargs)

Run cached inference on eval_df.

Rows whose cache_key_column value is already in the cache are served from S3; the rest are sent to the underlying endpoint and the new results are written back to the cache. The returned DataFrame preserves the original row order of eval_df.

Parameters:

    eval_df (DataFrame, required):
        DataFrame to run predictions on. Must contain self.cache_key_column.

    **kwargs (Any, default {}):
        Forwarded to the wrapped Endpoint.inference() for uncached rows.

Returns:

    pd.DataFrame: eval_df with the endpoint's added columns left-joined on cache_key_column.

Source code in src/workbench/api/inference_cache.py
def inference(self, eval_df: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
    """Run cached inference on ``eval_df``.

    Rows whose ``cache_key_column`` value is already in the cache are
    served from S3; the rest are sent to the underlying endpoint and the
    new results are written back to the cache. The returned DataFrame
    preserves the original row order of ``eval_df``.

    Args:
        eval_df (pd.DataFrame): DataFrame to run predictions on. Must
            contain ``self.cache_key_column``.
        **kwargs (Any): Forwarded to the wrapped ``Endpoint.inference()``
            for uncached rows.

    Returns:
        pd.DataFrame: ``eval_df`` with the endpoint's added columns
        left-joined on ``cache_key_column``.
    """
    key_col = self.cache_key_column
    if key_col not in eval_df.columns:
        raise ValueError(f"eval_df is missing required cache_key_column '{key_col}'")

    cache_df = self._load_cache()

    # Split eval rows into cache hits vs rows we still need to compute
    is_cached = eval_df[key_col].isin(cache_df[key_col])
    uncached_df = eval_df[~is_cached]
    cached_hits = cache_df[cache_df[key_col].isin(eval_df[key_col])]

    hits = len(eval_df) - len(uncached_df)
    self.log.info(f"InferenceCache[{self._endpoint.name}]: {hits}/{len(eval_df)} cache hits")

    # Run the endpoint on the uncached rows. The decorator on
    # _chunked_endpoint_inference handles chunking, per-chunk cache
    # writes, and error recovery so a single failed write doesn't
    # destroy the rest of the batch.
    new_results = pd.DataFrame()
    if not uncached_df.empty:
        to_compute = uncached_df.drop_duplicates(subset=[key_col])
        new_results = self._chunked_endpoint_inference(to_compute, **kwargs)

    # Combine cached + new into a single feature table, then left-join
    # back onto eval_df to preserve row order and any extra input columns.
    # (Filter out empty frames to dodge a pandas FutureWarning about
    # dtype inference on empty entries.)
    frames = [f for f in (cached_hits, new_results) if not f.empty]
    if not frames:
        return eval_df.copy()
    feature_table = pd.concat(frames, ignore_index=True).drop_duplicates(subset=[key_col], keep="last")
    feature_cols = [c for c in feature_table.columns if c not in eval_df.columns]
    return eval_df.merge(feature_table[[key_col] + feature_cols], on=key_col, how="left")