Meta Endpoints

What's a MetaEndpoint?

A MetaEndpoint is a deployed Endpoint backed by a directed acyclic graph (DAG) of other endpoints + aggregation nodes. From the caller's perspective it's just an Endpoint — endpoint.inference(df) returns a DataFrame. The DAG machinery is server-side.

MetaEndpoint lets you compose multiple deployed endpoints into a single inference target. Two canonical shapes:

  • Feature pipelines — fan out to several feature endpoints in parallel, merge the columns, optionally feed a predictor.
  • Ensembles — fan out to several predictor endpoints, aggregate their predictions (mean, weighted mean, vote, calibrated confidence weighting, …) into a single output.

The same DAG abstraction covers both.

Quick Start: Feature Pipeline

Combine the 2D and 3D-fast feature endpoints into a single endpoint that returns merged feature columns per molecule:

from workbench.api import MetaEndpoint
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Concat

dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("combine")

end = MetaEndpoint.create(
    name="smiles-to-2d-3d-features",
    dag=dag,
    description="2D RDKit/Mordred + 3D-fast features",
    tags=["meta", "features"],
)

# Use it like any other endpoint
import pandas as pd
df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]})
result = end.inference(df)
# result has the input columns + 2D + 3D feature columns

Quick Start: Ensemble

Combine three predictor endpoints (XGBoost + PyTorch + ChemProp) into one ensemble:

from workbench.api import MetaEndpoint
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Mean

dag = MetaEndpointDAG()
dag.add_endpoint("logd-reg-xgb")
dag.add_endpoint("logd-reg-pytorch")
dag.add_endpoint("logd-reg-chemprop")
dag.add_aggregation(Mean(name="ensemble"))
dag.add_edge("logd-reg-xgb", "ensemble")
dag.add_edge("logd-reg-pytorch", "ensemble")
dag.add_edge("logd-reg-chemprop", "ensemble")
dag.set_input_node("logd-reg-xgb", "logd-reg-pytorch", "logd-reg-chemprop")
dag.set_output_node("ensemble")

end = MetaEndpoint.create(name="logd-meta", dag=dag, tags=["meta", "logd"])

The output has the standard prediction / prediction_std (ensemble disagreement) / confidence columns alongside whatever pass-through columns the input had.
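To make that output contract concrete, here is an illustrative sketch of what an equal-weight Mean aggregation computes per row. The column names follow the text above; the exact confidence formula is an assumption, not Workbench's actual calibration.

```python
import statistics

def mean_aggregate(member_predictions: list[float]) -> dict:
    """Combine one row's member predictions (illustrative only)."""
    mean = statistics.fmean(member_predictions)
    # prediction_std captures ensemble disagreement across members
    std = statistics.stdev(member_predictions) if len(member_predictions) > 1 else 0.0
    # Assumption: map agreement to a (0, 1] confidence score
    confidence = 1.0 / (1.0 + std)
    return {"prediction": mean, "prediction_std": std, "confidence": confidence}

row = mean_aggregate([2.1, 1.9, 2.0])  # three members, one row
```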

Async Auto-Detection

If any child endpoint in the DAG is deployed as async (e.g. smiles-to-3d-full-v1), the MetaEndpoint is automatically deployed as async too — its 60-minute invocation budget needs to accommodate the slowest child. You don't specify this; MetaEndpoint.create() detects it via dag.has_async_endpoint() and chooses the deploy mode.
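The decision itself is a simple any() over the child flags; a minimal sketch (the function name is illustrative, not the actual internals):

```python
def choose_deploy_mode(child_async_flags: dict[str, bool]) -> str:
    """One async child forces the whole MetaEndpoint async, since the
    meta's invocation budget must cover its slowest child."""
    return "async" if any(child_async_flags.values()) else "sync"

mode = choose_deploy_mode({"smiles-to-2d-v1": False, "smiles-to-3d-full-v1": True})
```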

DAG Building Blocks

Endpoint nodes

Every endpoint node refers to a deployed Workbench endpoint by name. Endpoint nodes can be:

  • Input nodes — receive the caller's input DataFrame directly. Declared with dag.set_input_node(...).
  • Downstream endpoint nodes — take their input from a single upstream parent (e.g. a Concat aggregation feeding a predictor).
dag.add_endpoint("smiles-to-2d-v1")                       # node name = endpoint name (default)
dag.add_endpoint("smiles-to-2d-v1", node_name="left_2d")  # explicit node name (for aliasing)

Aggregation nodes

  • Concat — column-union of feature outputs from parallel branches; output: all upstream columns merged.
  • Mean — equal-weight average of predictions; output: prediction, prediction_std, confidence.
  • WeightedMean(weights=[…]) — static-weight average; same output columns as Mean.
  • Vote — majority vote for classifiers; output: prediction (label), confidence (winner share).
  • ConfidenceWeighted(model_weights=…) — per-row weights from upstream confidences; same as Mean, with calibrated confidence.
  • InverseMaeWeighted(model_weights=…) — static weights from inverse MAE; same output columns as Mean.
  • ScaledConfidenceWeighted(model_weights=…) — static MAE weights × per-row confidence; same output columns as Mean.
  • CalibratedConfidenceWeighted(model_weights=…, corr_scale=…) — confidence × |conf-error correlation|; same output columns as Mean.
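As a worked example of the static-weight variants, a hedged sketch of the arithmetic behind WeightedMean (the normalization choice is an assumption):

```python
def weighted_mean(preds: list[float], weights: list[float]) -> float:
    """Static-weight average; weights are normalized to sum to 1."""
    total = sum(weights)
    return sum(p * w for p, w in zip(preds, weights)) / total

# Three members, the first trusted twice as much as the others
p = weighted_mean([2.4, 2.0, 1.8], [2.0, 1.0, 1.0])
```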

Edges

Edges declare data flow:

dag.add_edge("smiles-to-2d-v1", "combine")    # 2D output flows into combine
dag.add_edge("combine", "predictor")           # combined features flow into predictor

Endpoint nodes accept at most one inbound edge (one source for their input DataFrame). Aggregation nodes can have any number of inbound edges.
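The inbound-edge rule condenses to a few lines; this sketch mirrors the add_edge logic in the MetaEndpointDAG source (node names are illustrative):

```python
class EdgeRules:
    """Illustrative: endpoint nodes take at most one parent,
    aggregation nodes take any number."""

    def __init__(self, endpoints: set[str], aggregations: set[str]):
        self.endpoints, self.aggregations = endpoints, aggregations
        self.edges: list[tuple[str, str]] = []

    def add_edge(self, src: str, dst: str) -> None:
        nodes = self.endpoints | self.aggregations
        if src not in nodes or dst not in nodes:
            raise ValueError(f"Unknown node in edge {src!r} -> {dst!r}")
        if dst in self.endpoints and any(d == dst for _, d in self.edges):
            raise ValueError(f"Endpoint node {dst!r} already has a parent")
        self.edges.append((src, dst))

rules = EdgeRules({"feat-a", "feat-b", "predictor"}, {"combine"})
rules.add_edge("feat-a", "combine")     # aggregations take many parents
rules.add_edge("feat-b", "combine")
rules.add_edge("combine", "predictor")  # endpoints take exactly one
```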

Validation

Call dag.validate() to fail loud on misconfiguration before any inference round-trips. Checks include cycle detection, dangling endpoint nodes, and reachability from input to output.
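Cycle detection is the classic Kahn's algorithm; the core of the check, mirroring topological_order() in the DAG source:

```python
def topological_order(nodes: list[str], edges: list[tuple[str, str]]) -> list[str]:
    """Kahn's algorithm: repeatedly emit nodes with no unprocessed parents."""
    in_degree = {n: 0 for n in nodes}
    for _, dst in edges:
        in_degree[dst] += 1
    ready = [n for n, deg in in_degree.items() if deg == 0]
    order = []
    while ready:
        node = ready.pop(0)
        order.append(node)
        for src, dst in edges:
            if src == node:
                in_degree[dst] -= 1
                if in_degree[dst] == 0:
                    ready.append(dst)
    if len(order) != len(in_degree):  # leftover nodes are stuck in a cycle
        raise ValueError("DAG contains a cycle")
    return order

order = topological_order(["2d", "3d", "combine"], [("2d", "combine"), ("3d", "combine")])
```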

Row Alignment

The walker injects a synthetic __dag_row_id column at the start of every run() and strips it before returning. Aggregation nodes use it as the join key so callers do not need to supply any id column on their input data — and any id-like column they do supply just flows through as a regular pass-through column.

This gives MetaEndpoints a clean contract: any DataFrame with the columns your input-node endpoints expect is a valid input.
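A hedged pandas sketch of the same alignment trick (the column name __dag_row_id matches the text; the merge strategy and toy branches are assumptions for illustration):

```python
import pandas as pd

ROW_ID = "__dag_row_id"

def run_branches(df: pd.DataFrame, branches) -> pd.DataFrame:
    """Inject a synthetic row id, fan out, join each branch's output
    on the id, then strip it before returning."""
    work = df.copy()
    work[ROW_ID] = range(len(work))
    merged = work
    for branch in branches:
        merged = merged.merge(branch(work), on=ROW_ID, how="left")
    return merged.drop(columns=[ROW_ID])

# Toy "feature endpoints" that return rows in scrambled order
def feat_a(chunk):
    return pd.DataFrame({ROW_ID: [1, 0], "a": [10, 20]})

def feat_b(chunk):
    return pd.DataFrame({ROW_ID: [0, 1], "b": [1, 2]})

result = run_branches(pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]}), [feat_a, feat_b])
# Rows line up by the synthetic id, regardless of each branch's output order
```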

How It Works

Creation flow

When you call MetaEndpoint.create(name, dag, ...):

  1. Validate — dag.validate() fails loud on cycles, dangling nodes, etc.
  2. Resolve async flags — looks up each child's workbench_meta to record per-endpoint async status.
  3. Lineage anchor — backtraces the first input endpoint to a FeatureSet/target/feature_list (Workbench Models need to point at a FeatureSet).
  4. Build + register the Model — runs the standard FeatureSet.to_model() flow, passing the DAG dict + region + bucket as custom_args. The meta_endpoint.template substitutes those placeholders, the SageMaker training job persists meta_endpoint_config.json as the model artifact, and the model package is registered with the standard inference image.
  5. Deploy — model.to_endpoint(...) deploys an Endpoint, async if any child is async (with max_instances=1 and 5-minute idle drain). inference_batch_size is auto-set to the minimum across DAG children.

Inference flow

When the deployed MetaEndpoint receives a request:

  1. The container deserializes the DAG from the model artifact.
  2. The walker traverses nodes in topological order.
  3. Endpoint nodes call fast_inference (sync child) or async_inference (async child) via workbench_bridges — the transport is decided per child by the async flags captured at deploy time.
  4. Aggregation nodes apply their combination logic on the upstream outputs.
  5. The output node's DataFrame is returned to the caller (with the synthetic __dag_row_id stripped).

Failure policy is fail-fast: any exception in any node propagates out and the request fails. (Future: per-node failure policies for partial-result aggregation.)
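The walk can be pictured with a small fail-fast sketch; the lambdas are toy stand-ins for the endpoint and aggregation calls:

```python
def walk(order, node_fns, parents, dag_input):
    """Execute nodes in topological order; any node's exception propagates
    (fail-fast). Input nodes receive the DAG input directly."""
    results = {}
    for name in order:
        upstream = [results[p] for p in parents.get(name, [])]
        results[name] = node_fns[name](upstream or [dag_input])
    return results[order[-1]]

out = walk(
    order=["m1", "m2", "ensemble"],
    node_fns={
        "m1": lambda ins: [x + 0.1 for x in ins[0]],                   # toy predictor
        "m2": lambda ins: [x - 0.1 for x in ins[0]],                   # toy predictor
        "ensemble": lambda ins: [sum(v) / len(v) for v in zip(*ins)],  # mean aggregation
    },
    parents={"ensemble": ["m1", "m2"]},
    dag_input=[1.0, 2.0],
)
```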

CLI: Ensemble Simulator

Before committing to an ensemble shape, the ensemble_sim CLI lets you evaluate aggregation strategies offline against captured cross-fold predictions:

ensemble_sim logd-reg-xgb logd-reg-pytorch logd-reg-chemprop

This reports per-strategy MAE / RMSE / R² so you can pick the aggregation node that performs best before building the DAG. The same simulator will also be wired into MetaEndpoint.create() for auto-tuning when a DAG includes a strategy-tunable aggregation node (planned).
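Offline strategy comparison is just arithmetic over captured predictions; a toy sketch of the kind of scoring ensemble_sim reports (numbers are made up):

```python
def mae(preds, truth):
    return sum(abs(p - t) for p, t in zip(preds, truth)) / len(truth)

# Toy cross-fold predictions from three members, plus ground truth
truth = [2.0, 3.0, 1.5]
members = {
    "xgb":      [2.2, 2.9, 1.4],
    "pytorch":  [1.8, 3.2, 1.7],
    "chemprop": [2.0, 3.0, 1.6],
}

# Equal-weight mean strategy over member columns
mean_preds = [sum(vals) / len(vals) for vals in zip(*members.values())]

scores = {name: mae(p, truth) for name, p in members.items()}
scores["mean-ensemble"] = mae(mean_preds, truth)
```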

API Reference

MetaEndpoint: An Endpoint backed by a directed acyclic graph (DAG) of child endpoints and aggregation nodes.

A MetaEndpoint behaves identically to a regular Endpoint at runtime — callers do endpoint.inference(df) and get a DataFrame back. The DAG machinery is server-side: the deployed container loads the serialized DAG, dispatches each child invocation to fast_inference (sync) or async_inference (async), and runs aggregation nodes locally.

Common usage::

from workbench.api import MetaEndpoint
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Concat

dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("combine")

end = MetaEndpoint.create(name="my-features-meta", dag=dag)
# Input does not need any id column — the DAG handles row alignment internally.
df = end.inference(input_df)

If any child endpoint in the DAG is async (e.g. smiles-to-3d-full-v1), the MetaEndpoint is automatically deployed as async too — its invocation budget needs to accommodate the slowest child.

MetaEndpoint

Bases: Endpoint

Endpoint backed by a :class:MetaEndpointDAG.

Constructor wraps an existing deployed MetaEndpoint by name, identical to :class:Endpoint. Use :meth:create to build and deploy a new one from a DAG.

Source code in src/workbench/api/meta_endpoint.py
class MetaEndpoint(Endpoint):
    """Endpoint backed by a :class:`MetaEndpointDAG`.

    Constructor wraps an existing deployed MetaEndpoint by name, identical
    to :class:`Endpoint`. Use :meth:`create` to build and deploy a new one
    from a DAG.
    """

    @classmethod
    def create(
        cls,
        name: str,
        dag: MetaEndpointDAG,
        description: str | None = None,
        tags: list[str] | None = None,
    ) -> "MetaEndpoint":
        """Build, register, and deploy a MetaEndpoint from a DAG.

        Steps:
          1. Validate the DAG; populate per-endpoint async flags.
          2. Backtrace lineage from a primary endpoint to satisfy
             Workbench's Model machinery (FeatureSet, target, features).
          3. Run the standard ``FeatureSet.to_model()`` flow, passing the
             DAG / region / bucket as ``custom_args`` so the meta-endpoint
             template fills them in at training time.
          4. Set DAG-specific ``workbench_meta`` keys on the resulting Model.
          5. Deploy the endpoint (async if any DAG child is async).

        Args:
            name: Endpoint / Model name.
            dag: A :class:`MetaEndpointDAG` describing the data flow.
            description: Optional description for the registered model.
            tags: Optional list of Workbench tags.

        Returns:
            The deployed MetaEndpoint, ready for ``.inference()``.
        """
        Artifact.is_name_valid(name, delimiter="-", lower_case=False)

        log.important(f"Validating DAG for MetaEndpoint '{name}'...")
        dag.validate()
        dag.populate_async_flags()
        is_async = dag.has_async_endpoint()
        log.important(
            f"DAG: {len(dag._endpoints)} endpoints, {len(dag._aggregations)} aggregation nodes "
            f"({'async' if is_async else 'sync'} deployment)"
        )

        # Backtrace lineage from a primary endpoint to satisfy Workbench Model
        # machinery (every Model needs a FeatureSet to hang off of).
        feature_list, feature_set_name, target_column = cls._derive_lineage(dag)

        # Build the model via the standard FeatureSet → Model flow. The
        # meta-endpoint template's `{{dag}}`, `{{aws_region}}`, `{{s3_bucket}}`
        # placeholders are filled from custom_args.
        aws_clamp = AWSAccountClamp()
        sm_session = aws_clamp.sagemaker_session()
        workbench_bucket = ConfigManager().get_config("WORKBENCH_BUCKET")

        feature_set = FeatureSet(feature_set_name)
        feature_set.to_model(
            name=name,
            model_type=cls._derive_model_type(dag),
            model_framework=ModelFramework.META,
            tags=tags or [name],
            description=description or f"MetaEndpoint DAG over: {', '.join(dag._endpoints.values())}",
            target_column=target_column,
            feature_list=feature_list,
            custom_args={
                "dag": dag.to_dict(),
                "aws_region": sm_session.boto_region_name,
                "s3_bucket": workbench_bucket,
            },
        )

        # Append DAG-specific workbench_meta on top of what FeaturesToModel
        # already set (model_type, framework, features, target, training view).
        output_model = ModelCore(name)
        output_model.upsert_workbench_meta({"endpoints": list(dag._endpoints.values())})
        output_model.upsert_workbench_meta({"meta_endpoint_dag": dag.to_dict()})

        # Deploy. MetaEndpoint containers are thin orchestrators — actual
        # compute happens in the child endpoints, which scale on their own
        # backlog. One meta instance can already drive 100s of concurrent
        # child calls (async_inference uses a 64-thread worker pool
        # internally), so additional meta instances don't help. Async deploy
        # is therefore 0→1 with idle drain; sync is fixed 1.
        log.important(f"Deploying MetaEndpoint '{name}' ({'async' if is_async else 'sync'})...")
        model = Model(name)
        if is_async:
            endpoint = model.to_endpoint(
                tags=tags or [name],
                async_endpoint=True,
                max_instances=1,
                scale_in_idle_minutes=5,
            )
        else:
            endpoint = model.to_endpoint(
                tags=tags or [name],
                async_endpoint=False,
            )

        # Auto-derive inference_batch_size from the smallest tolerance among
        # children — chunks the meta receives from SageMaker get fanned out
        # as-is to every child, so the meta's chunk size shouldn't exceed
        # the smallest child's batch size.
        min_batch = dag.min_child_batch_size()
        endpoint.upsert_workbench_meta({"inference_batch_size": min_batch})
        log.important(f"Set inference_batch_size={min_batch} (min across DAG children)")

        # Publish the largest child fleet size as effective_max_instances so
        # callers (e.g. InferenceCache) can size their work units to fill
        # downstream child capacity rather than the meta's own
        # max_instances=1 (which only describes the orchestrator layer).
        effective_max = dag.max_child_max_instances()
        endpoint.upsert_workbench_meta({"effective_max_instances": effective_max})
        log.important(f"Set effective_max_instances={effective_max} (max across DAG children)")

        log.important(f"MetaEndpoint '{name}' created successfully!")
        return cls(name)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _derive_model_type(cls, dag: MetaEndpointDAG) -> ModelType:
        """Pick the most accurate :class:`ModelType` for the DAG's output.

        - Output node is a terminal endpoint → borrow that endpoint's declared
          type (e.g., a downstream predictor endpoint contributes its own type).
        - Output node is :class:`~workbench.utils.aggregation_nodes.Concat` →
          ``TRANSFORMER`` (column-union of feature outputs).
        - Output node is :class:`~workbench.utils.aggregation_nodes.Vote` →
          ``CLASSIFIER`` (majority vote of class labels).
        - Output node is any other prediction aggregator → ``REGRESSOR``.
        """
        from workbench.utils.aggregation_nodes import Concat, Vote

        output_name = dag._output_node
        if output_name in dag._endpoints:
            ep_name = dag._endpoints[output_name]
            try:
                return Model(ep_name).model_type
            except Exception:
                return ModelType.REGRESSOR

        agg = dag._aggregations[output_name]
        if isinstance(agg, Concat):
            return ModelType.TRANSFORMER
        if isinstance(agg, Vote):
            return ModelType.CLASSIFIER
        return ModelType.REGRESSOR

    @classmethod
    def _derive_lineage(cls, dag: MetaEndpointDAG) -> tuple[list[str], str, str | None]:
        """Backtrace from the first input endpoint to find a FeatureSet + lineage.

        Workbench Models need to trace back to a FeatureSet. For DAG-based
        MetaEndpoints there isn't a single canonical FeatureSet, so we use
        the first input node's lineage as a representative anchor.

        Returns ``(feature_list, feature_set_name, target_column)``.
        ``target_column`` may be ``None`` for pure feature-pipeline DAGs
        whose primary endpoint is a feature endpoint.
        """
        if not dag._input_nodes:
            raise ValueError("DAG has no input nodes — cannot derive lineage")
        primary_endpoint_name = dag._endpoints[dag._input_nodes[0]]

        ep = Endpoint(primary_endpoint_name)
        if not ep.exists():
            raise ValueError(f"Primary endpoint '{primary_endpoint_name}' does not exist")

        primary_model = Model(ep.get_input())
        feature_list = primary_model.features() or list(dag.input_columns())
        feature_set_name = primary_model.get_input()
        target_column = primary_model.target()

        log.info(
            f"Lineage anchor: {primary_endpoint_name} -> {primary_model.name} -> {feature_set_name} "
            f"(target: {target_column})"
        )
        return feature_list, feature_set_name, target_column

    def get_dag(self) -> MetaEndpointDAG:
        """Reconstruct the MetaEndpointDAG from this endpoint's stored metadata."""
        meta = self.workbench_meta() or {}
        dag_dict = meta.get("meta_endpoint_dag")
        if not dag_dict:
            raise ValueError(
                f"MetaEndpoint '{self.name}' has no DAG in workbench_meta. Recreate via MetaEndpoint.create()."
            )
        return MetaEndpointDAG.from_dict(dag_dict)

create(name, dag, description=None, tags=None) classmethod

Build, register, and deploy a MetaEndpoint from a DAG.

Steps
  1. Validate the DAG; populate per-endpoint async flags.
  2. Backtrace lineage from a primary endpoint to satisfy Workbench's Model machinery (FeatureSet, target, features).
  3. Run the standard FeatureSet.to_model() flow, passing the DAG / region / bucket as custom_args so the meta-endpoint template fills them in at training time.
  4. Set DAG-specific workbench_meta keys on the resulting Model.
  5. Deploy the endpoint (async if any DAG child is async).

Parameters:

  • name (str, required) — Endpoint / Model name.
  • dag (MetaEndpointDAG, required) — A :class:MetaEndpointDAG describing the data flow.
  • description (str | None, default None) — Optional description for the registered model.
  • tags (list[str] | None, default None) — Optional list of Workbench tags.

Returns:

  • MetaEndpoint — The deployed MetaEndpoint, ready for .inference().

get_dag()

Reconstruct the MetaEndpointDAG from this endpoint's stored metadata.


MetaEndpointDAG — a directed acyclic graph of endpoints and aggregation nodes describing an inference-time data flow.

A DAG has two kinds of nodes:

  • Endpoint nodes — references to deployed Workbench Endpoint instances by name. The DAG defers actual Endpoint instantiation until execution / column-contract resolution.

  • Aggregation nodes — instances of :class:AggregationNode subclasses that combine outputs from upstream nodes.

DAG construction is explicit::

dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("combine")
dag.validate()

Row-alignment across parallel branches: the walker injects a synthetic :data:DAG_ROW_ID column at the start of every run() and strips it before returning. Aggregation nodes use it as the join key, so callers do not need to supply (or care about) any id column on their input data.

Validation runs at construction time so misconfigured DAGs fail loud before any inference round-trips.

MetaEndpointDAG

A typed DAG of endpoints + aggregation nodes.

The DAG joins parallel branches using an internal synthetic row id (:data:DAG_ROW_ID) injected by :meth:run — callers don't need to supply any id column on their input.

Source code in src/workbench/utils/meta_endpoint_dag.py
class MetaEndpointDAG:
    """A typed DAG of endpoints + aggregation nodes.

    The DAG joins parallel branches using an internal synthetic row id
    (:data:`DAG_ROW_ID`) injected by :meth:`run` — callers don't need to
    supply any id column on their input.
    """

    def __init__(self):
        self._endpoints: Dict[str, str] = {}  # node_name → endpoint_name
        self._endpoint_async_flags: Dict[str, bool] = {}  # populated by populate_async_flags()
        self._aggregations: Dict[str, AggregationNode] = {}
        self._edges: List[tuple[str, str]] = []  # (from_node, to_node)
        self._input_nodes: List[str] = []
        self._output_node: Optional[str] = None

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def add_endpoint(self, endpoint_name: str, node_name: Optional[str] = None) -> str:
        """Add an endpoint reference to the DAG.

        Args:
            endpoint_name: Name of a deployed Workbench endpoint.
            node_name: Optional unique node name (defaults to ``endpoint_name``).

        Returns:
            The node name (so callers can chain).
        """
        node = node_name or endpoint_name
        if node in self._endpoints or node in self._aggregations:
            raise ValueError(f"Node '{node}' already exists in this DAG")
        self._endpoints[node] = endpoint_name
        return node

    def add_aggregation(self, node: AggregationNode) -> str:
        """Add an :class:`AggregationNode` to the DAG.

        The node's ``name`` must be unique across the DAG.
        """
        if node.name in self._endpoints or node.name in self._aggregations:
            raise ValueError(f"Node '{node.name}' already exists in this DAG")
        self._aggregations[node.name] = node
        return node.name

    def add_edge(self, from_node: str, to_node: str) -> None:
        """Declare data flow from ``from_node`` to ``to_node``.

        Endpoint nodes accept at most one inbound edge (their input DataFrame
        comes from a single upstream producer). Aggregation nodes can have
        any number of inbound edges.
        """
        if from_node not in self._all_nodes():
            raise ValueError(f"Edge from unknown node '{from_node}'")
        if to_node not in self._all_nodes():
            raise ValueError(f"Edge to unknown node '{to_node}'")
        if to_node in self._endpoints and self._parents_of(to_node):
            raise ValueError(
                f"Endpoint node '{to_node}' already has an upstream parent "
                f"('{self._parents_of(to_node)[0]}'); endpoints take input "
                f"from at most one source."
            )
        self._edges.append((from_node, to_node))

    def set_input_node(self, *nodes: str) -> None:
        """Declare which nodes receive the DAG's input DataFrame directly."""
        for n in nodes:
            if n not in self._endpoints:
                raise ValueError(f"Input nodes must be endpoint nodes; '{n}' is not")
        self._input_nodes = list(nodes)

    def set_output_node(self, node: str) -> None:
        """Declare the terminal node whose output is the DAG's output."""
        if node not in self._all_nodes():
            raise ValueError(f"Unknown output node '{node}'")
        self._output_node = node

    # ------------------------------------------------------------------
    # Inspection
    # ------------------------------------------------------------------

    def _all_nodes(self) -> List[str]:
        return list(self._endpoints.keys()) + list(self._aggregations.keys())

    def _parents_of(self, node: str) -> List[str]:
        return [src for src, dst in self._edges if dst == node]

    def topological_order(self) -> List[str]:
        """Return nodes in topological order (parents before children).

        Raises:
            ValueError: If the DAG contains a cycle.
        """
        in_degree = {n: 0 for n in self._all_nodes()}
        for _, dst in self._edges:
            in_degree[dst] += 1

        ready = [n for n, deg in in_degree.items() if deg == 0]
        order: List[str] = []
        while ready:
            node = ready.pop(0)
            order.append(node)
            for src, dst in self._edges:
                if src == node:
                    in_degree[dst] -= 1
                    if in_degree[dst] == 0:
                        ready.append(dst)

        if len(order) != len(in_degree):
            raise ValueError("DAG contains a cycle")
        return order

    # ------------------------------------------------------------------
    # Column contract
    # ------------------------------------------------------------------

    def input_columns(self) -> List[str]:
        """Union of input columns required by every node that receives the
        caller's input directly.

        Used by :class:`MetaEndpoint` as a fallback when deriving the
        feature list during lineage anchoring.
        """
        from workbench.api import Endpoint

        if not self._input_nodes:
            raise ValueError("DAG has no input nodes — call set_input_node() first")

        seen = set()
        cols: List[str] = []
        for node in self._input_nodes:
            for c in Endpoint(self._endpoints[node]).input_columns():
                if c not in seen:
                    seen.add(c)
                    cols.append(c)
        return cols

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate(self) -> "MetaEndpointDAG":
        """Validate the DAG. Returns self for chaining; raises on failure.

        Checks:
          - At least one input node and exactly one output node declared
          - No cycles
          - Aggregation nodes have at least one parent
          - Endpoint nodes are either input nodes (zero parents) or have
            exactly one upstream parent — never both
          - The output node is reachable from the input nodes
        """
        if not self._input_nodes:
            raise ValueError("DAG has no input nodes")
        if self._output_node is None:
            raise ValueError("DAG has no output node")

        order = self.topological_order()  # raises on cycle

        for ep_node in self._endpoints:
            parents = self._parents_of(ep_node)
            is_input = ep_node in self._input_nodes
            if is_input and parents:
                raise ValueError(
                    f"Endpoint node '{ep_node}' is declared as an input node but has "
                    f"upstream parents {parents}; pick one or the other."
                )
            if not is_input and not parents:
                raise ValueError(
                    f"Endpoint node '{ep_node}' has no upstream parent and is not "
                    f"declared as an input node — it has no source for its input DataFrame."
                )

        for name in self._aggregations:
            if not self._parents_of(name):
                raise ValueError(f"Aggregation node '{name}' has no upstream parents")

        reachable = set(self._input_nodes)
        for node in order:
            if node in reachable:
                for src, dst in self._edges:
                    if src == node:
                        reachable.add(dst)
        if self._output_node not in reachable:
            raise ValueError(f"Output node '{self._output_node}' is not reachable from input nodes {self._input_nodes}")

        return self

    # ------------------------------------------------------------------
    # Execution (client-side walker)
    # ------------------------------------------------------------------

    def run(
        self,
        input_df: pd.DataFrame,
        endpoint_invoker: Optional[EndpointInvoker] = None,
    ) -> pd.DataFrame:
        """Execute the DAG against ``input_df`` and return the output node's DataFrame.

        The walker injects a synthetic :data:`DAG_ROW_ID` column at entry
        (used internally to align rows across parallel branches) and
        strips it before returning. Callers don't need to supply any id
        column.

        Walks nodes in topological order. Endpoint nodes call
        :meth:`Endpoint.inference` on either the caller's ``input_df`` (input
        nodes) or their upstream parent's cached output. Aggregation nodes
        receive the cached outputs of all their parents and apply their
        combination logic.

        Failure policy is fail-fast: any exception in any node propagates
        out and the DAG run aborts.

        Args:
            input_df: DataFrame supplied by the caller. Must contain the
                columns required by every input-node endpoint. No id
                column is required.
            endpoint_invoker: Optional callable ``(endpoint_name, df) -> df``
                used to invoke endpoint nodes. Defaults to using the full
                Workbench ``Endpoint`` API class — appropriate for client-side
                use. Pass a ``fast_inference``-backed invoker when running
                inside a deployed SageMaker container where the full
                Workbench config isn't available.

        Returns:
            The DataFrame at the DAG's output node, with the synthetic
            :data:`DAG_ROW_ID` column removed.
        """
        if self._output_node is None:
            raise ValueError("DAG has no output node — call set_output_node() first")
        if DAG_ROW_ID in input_df.columns:
            raise ValueError(
                f"input_df already contains the reserved column '{DAG_ROW_ID}'. Remove it before calling run()."
            )

        # Inject the synthetic row id. Endpoints will pass this through as an
        # unknown input column; aggregation nodes use it as their join key.
        input_df = input_df.copy()
        input_df[DAG_ROW_ID] = range(len(input_df))

        outputs: Dict[str, pd.DataFrame] = {}
        for node in self.topological_order():
            if node in self._endpoints:
                outputs[node] = self._run_endpoint(node, input_df, outputs, endpoint_invoker)
            else:
                outputs[node] = self._run_aggregation(node, outputs)

        result = outputs[self._output_node]
        if DAG_ROW_ID in result.columns:
            result = result.drop(columns=[DAG_ROW_ID])
        return result

    def _run_endpoint(
        self,
        node: str,
        input_df: pd.DataFrame,
        outputs: Dict[str, pd.DataFrame],
        endpoint_invoker: Optional[EndpointInvoker],
    ) -> pd.DataFrame:
        """Execute a single endpoint node.

        Source DataFrame is the caller's input for input nodes, or the
        single upstream parent's output otherwise. The full DataFrame is
        passed to ``endpoint.inference()`` — metadata columns
        (project_id, owner, etc.) flow through alongside the endpoint's
        added columns, matching standard Workbench inference behavior.

        The walker-injected :data:`DAG_ROW_ID` column must survive the
        endpoint round-trip so downstream aggregation nodes can join on
        it. If an endpoint silently strips unknown input columns, this
        will fail loudly — better than misaligned rows.
        """
        endpoint_name = self._endpoints[node]
        parents = self._parents_of(node)
        source_df = input_df if not parents else outputs[parents[0]]

        if endpoint_invoker is not None:
            result = endpoint_invoker(endpoint_name, source_df)
        else:
            from workbench.api import Endpoint

            result = Endpoint(endpoint_name).inference(source_df)

        if DAG_ROW_ID not in result.columns:
            raise RuntimeError(
                f"Endpoint '{endpoint_name}' dropped the walker-injected '{DAG_ROW_ID}' "
                f"column from its output. The DAG can't align rows across branches "
                f"without it. Endpoints must pass unknown input columns through to "
                f"their output."
            )
        return result

    def _run_aggregation(self, node: str, outputs: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """Execute a single aggregation node."""
        agg = self._aggregations[node]
        upstream = [outputs[p] for p in self._parents_of(node)]
        return agg.apply(upstream)

    # ------------------------------------------------------------------
    # Serialization (model artifact + workbench_meta storage)
    # ------------------------------------------------------------------

    def to_dict(self) -> dict:
        """Serialize the DAG topology to a JSON-friendly dict.

        Aggregation nodes are serialized by class name + constructor kwargs;
        deserialization (:meth:`from_dict`) requires the same class to be
        importable.

        Per-endpoint ``is_async`` flags are included only if
        :meth:`populate_async_flags` has been called. The deployed
        inference container relies on these flags to dispatch invocations
        to ``fast_inference`` or ``async_inference``.
        """
        return {
            "endpoints": dict(self._endpoints),
            "endpoint_async": dict(self._endpoint_async_flags),
            "aggregations": [_serialize_aggregation(a) for a in self._aggregations.values()],
            "edges": [list(e) for e in self._edges],
            "input_nodes": list(self._input_nodes),
            "output_node": self._output_node,
        }

    def to_json(self) -> str:
        return json.dumps(self.to_dict(), indent=2)

    def populate_async_flags(self) -> None:
        """Look up each endpoint's async flag via ``workbench_meta`` and store it.

        Flags are keyed by endpoint name (not node name) so the deployed
        invoker can dispatch directly on the value passed by the walker.

        Called by :meth:`MetaEndpoint.create` before serializing the DAG
        for deployment. Hits AWS once per unique endpoint name, so it is
        isolated as an explicit step rather than running implicitly in
        :meth:`to_dict`.
        """
        from workbench.api import Endpoint

        for endpoint_name in set(self._endpoints.values()):
            meta = Endpoint(endpoint_name).workbench_meta() or {}
            self._endpoint_async_flags[endpoint_name] = bool(meta.get("async_endpoint"))

    def has_async_endpoint(self) -> bool:
        """Return True if any endpoint in the DAG is deployed as async.

        Used by :meth:`MetaEndpoint.create` to decide whether the meta
        endpoint itself must be deployed as async. Lazily calls
        :meth:`populate_async_flags` if not yet populated.
        """
        if not self._endpoint_async_flags and self._endpoints:
            self.populate_async_flags()
        return any(self._endpoint_async_flags.values())

    def min_child_batch_size(self) -> int:
        """Minimum ``inference_batch_size`` across all child endpoints.

        :meth:`MetaEndpoint.create` uses this to set the meta endpoint's
        own ``inference_batch_size`` — chunks the meta receives from
        SageMaker get fanned out as-is to every child, so the meta's
        chunk size shouldn't exceed the smallest tolerance among children.

        Smaller chunks at the meta level also mean smaller failure blast
        radius (one bad row fails one small chunk, not a large one).

        Returns:
            int: Minimum batch size; ``100`` if the DAG has no endpoints
            (the standard sync default).
        """
        from workbench.api import Endpoint

        if not self._endpoints:
            return 100
        sizes = [Endpoint(ep_name).inference_batch_size() for ep_name in set(self._endpoints.values())]
        return min(sizes)

    def max_child_max_instances(self) -> int:
        """Maximum ``max_instances`` across all child endpoints.

        :meth:`MetaEndpoint.create` uses this to publish an
        ``effective_max_instances`` hint into the meta endpoint's
        ``workbench_meta``. The meta itself deploys with ``max_instances=1``
        (it's a thin orchestrator), but downstream tooling like
        :class:`InferenceCache` sizes its work units to fill fleet capacity —
        and the relevant capacity is the child fleets, not the meta's
        single orchestrator instance.

        Returns:
            int: Maximum ``max_instances`` seen on any child endpoint;
            ``1`` if no child has it set or the DAG has no endpoints.
        """
        from workbench.api import Endpoint

        seen: List[int] = []
        for ep_name in set(self._endpoints.values()):
            meta = Endpoint(ep_name).workbench_meta() or {}
            if meta.get("max_instances") is not None:
                seen.append(int(meta["max_instances"]))
        return max(seen) if seen else 1

    @classmethod
    def from_dict(cls, data: dict) -> "MetaEndpointDAG":
        dag = cls()
        for node_name, endpoint_name in data.get("endpoints", {}).items():
            dag.add_endpoint(endpoint_name, node_name=node_name)
        dag._endpoint_async_flags = dict(data.get("endpoint_async", {}))
        for agg_data in data.get("aggregations", []):
            dag.add_aggregation(_deserialize_aggregation(agg_data))
        for src, dst in data.get("edges", []):
            dag.add_edge(src, dst)
        if data.get("input_nodes"):
            dag.set_input_node(*data["input_nodes"])
        if data.get("output_node"):
            dag.set_output_node(data["output_node"])
        return dag

    @classmethod
    def from_json(cls, payload: str) -> "MetaEndpointDAG":
        return cls.from_dict(json.loads(payload))

add_aggregation(node)

Add an `AggregationNode` to the DAG.

The node's name must be unique across the DAG.

Source code in src/workbench/utils/meta_endpoint_dag.py
def add_aggregation(self, node: AggregationNode) -> str:
    """Add an :class:`AggregationNode` to the DAG.

    The node's ``name`` must be unique across the DAG.
    """
    if node.name in self._endpoints or node.name in self._aggregations:
        raise ValueError(f"Node '{node.name}' already exists in this DAG")
    self._aggregations[node.name] = node
    return node.name

add_edge(from_node, to_node)

Declare data flow from `from_node` to `to_node`.

Endpoint nodes accept at most one inbound edge (their input DataFrame comes from a single upstream producer). Aggregation nodes can have any number of inbound edges.

Source code in src/workbench/utils/meta_endpoint_dag.py
def add_edge(self, from_node: str, to_node: str) -> None:
    """Declare data flow from ``from_node`` to ``to_node``.

    Endpoint nodes accept at most one inbound edge (their input DataFrame
    comes from a single upstream producer). Aggregation nodes can have
    any number of inbound edges.
    """
    if from_node not in self._all_nodes():
        raise ValueError(f"Edge from unknown node '{from_node}'")
    if to_node not in self._all_nodes():
        raise ValueError(f"Edge to unknown node '{to_node}'")
    if to_node in self._endpoints and self._parents_of(to_node):
        raise ValueError(
            f"Endpoint node '{to_node}' already has an upstream parent "
            f"('{self._parents_of(to_node)[0]}'); endpoints take input "
            f"from at most one source."
        )
    self._edges.append((from_node, to_node))

add_endpoint(endpoint_name, node_name=None)

Add an endpoint reference to the DAG.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `endpoint_name` | `str` | Name of a deployed Workbench endpoint. | *required* |
| `node_name` | `Optional[str]` | Optional unique node name (defaults to `endpoint_name`). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `str` | The node name (so callers can chain). |

Source code in src/workbench/utils/meta_endpoint_dag.py
def add_endpoint(self, endpoint_name: str, node_name: Optional[str] = None) -> str:
    """Add an endpoint reference to the DAG.

    Args:
        endpoint_name: Name of a deployed Workbench endpoint.
        node_name: Optional unique node name (defaults to ``endpoint_name``).

    Returns:
        The node name (so callers can chain).
    """
    node = node_name or endpoint_name
    if node in self._endpoints or node in self._aggregations:
        raise ValueError(f"Node '{node}' already exists in this DAG")
    self._endpoints[node] = endpoint_name
    return node

has_async_endpoint()

Return True if any endpoint in the DAG is deployed as async.

Used by `MetaEndpoint.create` to decide whether the meta endpoint itself must be deployed as async. Lazily calls `populate_async_flags` if not yet populated.

Source code in src/workbench/utils/meta_endpoint_dag.py
def has_async_endpoint(self) -> bool:
    """Return True if any endpoint in the DAG is deployed as async.

    Used by :meth:`MetaEndpoint.create` to decide whether the meta
    endpoint itself must be deployed as async. Lazily calls
    :meth:`populate_async_flags` if not yet populated.
    """
    if not self._endpoint_async_flags and self._endpoints:
        self.populate_async_flags()
    return any(self._endpoint_async_flags.values())

input_columns()

Union of input columns required by every node that receives the caller's input directly.

Used by `MetaEndpoint` as a fallback when deriving the feature list during lineage anchoring.

Source code in src/workbench/utils/meta_endpoint_dag.py
def input_columns(self) -> List[str]:
    """Union of input columns required by every node that receives the
    caller's input directly.

    Used by :class:`MetaEndpoint` as a fallback when deriving the
    feature list during lineage anchoring.
    """
    from workbench.api import Endpoint

    if not self._input_nodes:
        raise ValueError("DAG has no input nodes — call set_input_node() first")

    seen = set()
    cols: List[str] = []
    for node in self._input_nodes:
        for c in Endpoint(self._endpoints[node]).input_columns():
            if c not in seen:
                seen.add(c)
                cols.append(c)
    return cols
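The de-duplication above is an order-preserving union: first occurrence wins, later duplicates are skipped. A tiny standalone sketch of that idiom (the column names are hypothetical):

```python
from typing import List

def ordered_union(lists: List[List[str]]) -> List[str]:
    """Union of several column lists, preserving first-seen order."""
    seen = set()
    cols: List[str] = []
    for lst in lists:
        for c in lst:
            if c not in seen:
                seen.add(c)
                cols.append(c)
    return cols

# Two input-node endpoints that both require "smiles"
print(ordered_union([["smiles", "id"], ["smiles", "conformers"]]))
# → ['smiles', 'id', 'conformers']
```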

max_child_max_instances()

Maximum `max_instances` across all child endpoints.

`MetaEndpoint.create` uses this to publish an `effective_max_instances` hint into the meta endpoint's `workbench_meta`. The meta itself deploys with `max_instances=1` (it's a thin orchestrator), but downstream tooling like `InferenceCache` sizes its work units to fill fleet capacity — and the relevant capacity is the child fleets, not the meta's single orchestrator instance.

Returns:

| Type | Description |
| --- | --- |
| `int` | Maximum `max_instances` seen on any child endpoint; `1` if no child has it set or the DAG has no endpoints. |

Source code in src/workbench/utils/meta_endpoint_dag.py
def max_child_max_instances(self) -> int:
    """Maximum ``max_instances`` across all child endpoints.

    :meth:`MetaEndpoint.create` uses this to publish an
    ``effective_max_instances`` hint into the meta endpoint's
    ``workbench_meta``. The meta itself deploys with ``max_instances=1``
    (it's a thin orchestrator), but downstream tooling like
    :class:`InferenceCache` sizes its work units to fill fleet capacity —
    and the relevant capacity is the child fleets, not the meta's
    single orchestrator instance.

    Returns:
        int: Maximum ``max_instances`` seen on any child endpoint;
        ``1`` if no child has it set or the DAG has no endpoints.
    """
    from workbench.api import Endpoint

    seen: List[int] = []
    for ep_name in set(self._endpoints.values()):
        meta = Endpoint(ep_name).workbench_meta() or {}
        if meta.get("max_instances") is not None:
            seen.append(int(meta["max_instances"]))
    return max(seen) if seen else 1

min_child_batch_size()

Minimum `inference_batch_size` across all child endpoints.

`MetaEndpoint.create` uses this to set the meta endpoint's own `inference_batch_size` — chunks the meta receives from SageMaker get fanned out as-is to every child, so the meta's chunk size shouldn't exceed the smallest tolerance among children.

Smaller chunks at the meta level also mean smaller failure blast radius (one bad row fails one small chunk, not a large one).

Returns:

| Type | Description |
| --- | --- |
| `int` | Minimum batch size; `100` if the DAG has no endpoints (the standard sync default). |

Source code in src/workbench/utils/meta_endpoint_dag.py
def min_child_batch_size(self) -> int:
    """Minimum ``inference_batch_size`` across all child endpoints.

    :meth:`MetaEndpoint.create` uses this to set the meta endpoint's
    own ``inference_batch_size`` — chunks the meta receives from
    SageMaker get fanned out as-is to every child, so the meta's
    chunk size shouldn't exceed the smallest tolerance among children.

    Smaller chunks at the meta level also mean smaller failure blast
    radius (one bad row fails one small chunk, not a large one).

    Returns:
        int: Minimum batch size; ``100`` if the DAG has no endpoints
        (the standard sync default).
    """
    from workbench.api import Endpoint

    if not self._endpoints:
        return 100
    sizes = [Endpoint(ep_name).inference_batch_size() for ep_name in set(self._endpoints.values())]
    return min(sizes)

populate_async_flags()

Look up each endpoint's async flag via `workbench_meta` and store it.

Flags are keyed by endpoint name (not node name) so the deployed invoker can dispatch directly on the value passed by the walker.

Called by `MetaEndpoint.create` before serializing the DAG for deployment. Hits AWS once per unique endpoint name, so it is isolated as an explicit step rather than running implicitly in `to_dict`.

Source code in src/workbench/utils/meta_endpoint_dag.py
def populate_async_flags(self) -> None:
    """Look up each endpoint's async flag via ``workbench_meta`` and store it.

    Flags are keyed by endpoint name (not node name) so the deployed
    invoker can dispatch directly on the value passed by the walker.

    Called by :meth:`MetaEndpoint.create` before serializing the DAG
    for deployment. Hits AWS once per unique endpoint name, so it is
    isolated as an explicit step rather than running implicitly in
    :meth:`to_dict`.
    """
    from workbench.api import Endpoint

    for endpoint_name in set(self._endpoints.values()):
        meta = Endpoint(endpoint_name).workbench_meta() or {}
        self._endpoint_async_flags[endpoint_name] = bool(meta.get("async_endpoint"))

run(input_df, endpoint_invoker=None)

Execute the DAG against `input_df` and return the output node's DataFrame.

The walker injects a synthetic `DAG_ROW_ID` column at entry (used internally to align rows across parallel branches) and strips it before returning. Callers don't need to supply any id column.

Walks nodes in topological order. Endpoint nodes call `Endpoint.inference` on either the caller's `input_df` (input nodes) or their upstream parent's cached output. Aggregation nodes receive the cached outputs of all their parents and apply their combination logic.

Failure policy is fail-fast: any exception in any node propagates out and the DAG run aborts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_df` | `DataFrame` | DataFrame supplied by the caller. Must contain the columns required by every input-node endpoint. No id column is required. | *required* |
| `endpoint_invoker` | `Optional[EndpointInvoker]` | Optional callable `(endpoint_name, df) -> df` used to invoke endpoint nodes. Defaults to using the full Workbench `Endpoint` API class — appropriate for client-side use. Pass a `fast_inference`-backed invoker when running inside a deployed SageMaker container where the full Workbench config isn't available. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `DataFrame` | The DataFrame at the DAG's output node, with the synthetic `DAG_ROW_ID` column removed. |

Source code in src/workbench/utils/meta_endpoint_dag.py
def run(
    self,
    input_df: pd.DataFrame,
    endpoint_invoker: Optional[EndpointInvoker] = None,
) -> pd.DataFrame:
    """Execute the DAG against ``input_df`` and return the output node's DataFrame.

    The walker injects a synthetic :data:`DAG_ROW_ID` column at entry
    (used internally to align rows across parallel branches) and
    strips it before returning. Callers don't need to supply any id
    column.

    Walks nodes in topological order. Endpoint nodes call
    :meth:`Endpoint.inference` on either the caller's ``input_df`` (input
    nodes) or their upstream parent's cached output. Aggregation nodes
    receive the cached outputs of all their parents and apply their
    combination logic.

    Failure policy is fail-fast: any exception in any node propagates
    out and the DAG run aborts.

    Args:
        input_df: DataFrame supplied by the caller. Must contain the
            columns required by every input-node endpoint. No id
            column is required.
        endpoint_invoker: Optional callable ``(endpoint_name, df) -> df``
            used to invoke endpoint nodes. Defaults to using the full
            Workbench ``Endpoint`` API class — appropriate for client-side
            use. Pass a ``fast_inference``-backed invoker when running
            inside a deployed SageMaker container where the full
            Workbench config isn't available.

    Returns:
        The DataFrame at the DAG's output node, with the synthetic
        :data:`DAG_ROW_ID` column removed.
    """
    if self._output_node is None:
        raise ValueError("DAG has no output node — call set_output_node() first")
    if DAG_ROW_ID in input_df.columns:
        raise ValueError(
            f"input_df already contains the reserved column '{DAG_ROW_ID}'. Remove it before calling run()."
        )

    # Inject the synthetic row id. Endpoints will pass this through as an
    # unknown input column; aggregation nodes use it as their join key.
    input_df = input_df.copy()
    input_df[DAG_ROW_ID] = range(len(input_df))

    outputs: Dict[str, pd.DataFrame] = {}
    for node in self.topological_order():
        if node in self._endpoints:
            outputs[node] = self._run_endpoint(node, input_df, outputs, endpoint_invoker)
        else:
            outputs[node] = self._run_aggregation(node, outputs)

    result = outputs[self._output_node]
    if DAG_ROW_ID in result.columns:
        result = result.drop(columns=[DAG_ROW_ID])
    return result
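The row-alignment contract `run()` relies on can be sketched with plain pandas. Here `_dag_row_id` is a stand-in for the real `DAG_ROW_ID` constant, and the two branch outputs are fabricated for illustration:

```python
import pandas as pd

ROW_ID = "_dag_row_id"  # stand-in for DAG_ROW_ID (assumption)

# Caller's input, with the synthetic row id injected at entry
input_df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]})
input_df[ROW_ID] = range(len(input_df))

# Simulate two parallel branch endpoints; one returns rows reordered
branch_a = input_df.assign(feat_a=[1.0, 2.0])
branch_b = input_df.assign(feat_b=[10.0, 20.0]).iloc[::-1]

# An aggregation node joins on the synthetic id, not on row position,
# so the reordered branch still lines up with the right molecules
merged = branch_a.merge(branch_b[[ROW_ID, "feat_b"]], on=ROW_ID)
merged = merged.drop(columns=[ROW_ID])  # stripped before returning
print(merged)
```

Because the join key travels with each row, branch outputs can arrive in any order without misaligning features.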

set_input_node(*nodes)

Declare which nodes receive the DAG's input DataFrame directly.

Source code in src/workbench/utils/meta_endpoint_dag.py
def set_input_node(self, *nodes: str) -> None:
    """Declare which nodes receive the DAG's input DataFrame directly."""
    for n in nodes:
        if n not in self._endpoints:
            raise ValueError(f"Input nodes must be endpoint nodes; '{n}' is not")
    self._input_nodes = list(nodes)

set_output_node(node)

Declare the terminal node whose output is the DAG's output.

Source code in src/workbench/utils/meta_endpoint_dag.py
def set_output_node(self, node: str) -> None:
    """Declare the terminal node whose output is the DAG's output."""
    if node not in self._all_nodes():
        raise ValueError(f"Unknown output node '{node}'")
    self._output_node = node

to_dict()

Serialize the DAG topology to a JSON-friendly dict.

Aggregation nodes are serialized by class name + constructor kwargs; deserialization (`from_dict`) requires the same class to be importable.

Per-endpoint `is_async` flags are included only if `populate_async_flags` has been called. The deployed inference container relies on these flags to dispatch invocations to `fast_inference` or `async_inference`.

Source code in src/workbench/utils/meta_endpoint_dag.py
def to_dict(self) -> dict:
    """Serialize the DAG topology to a JSON-friendly dict.

    Aggregation nodes are serialized by class name + constructor kwargs;
    deserialization (:meth:`from_dict`) requires the same class to be
    importable.

    Per-endpoint ``is_async`` flags are included only if
    :meth:`populate_async_flags` has been called. The deployed
    inference container relies on these flags to dispatch invocations
    to ``fast_inference`` or ``async_inference``.
    """
    return {
        "endpoints": dict(self._endpoints),
        "endpoint_async": dict(self._endpoint_async_flags),
        "aggregations": [_serialize_aggregation(a) for a in self._aggregations.values()],
        "edges": [list(e) for e in self._edges],
        "input_nodes": list(self._input_nodes),
        "output_node": self._output_node,
    }
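A round-trip sketch of the payload shape, using hypothetical endpoint names. The top-level keys match `to_dict()` above; the per-aggregation entry keys are an assumption based on the "class name + constructor kwargs" description:

```python
import json

# Hypothetical two-predictor ensemble serialized in the to_dict() shape
payload = {
    "endpoints": {"pred-a": "xgb-model-v1", "pred-b": "torch-model-v1"},
    "endpoint_async": {"xgb-model-v1": False, "torch-model-v1": True},
    "aggregations": [{"class": "Mean", "kwargs": {"name": "mean"}}],  # assumed shape
    "edges": [["pred-a", "mean"], ["pred-b", "mean"]],
    "input_nodes": ["pred-a", "pred-b"],
    "output_node": "mean",
}

# Everything is JSON-native, so serialize/deserialize is lossless
restored = json.loads(json.dumps(payload, indent=2))
print(restored["output_node"])  # → mean
```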

topological_order()

Return nodes in topological order (parents before children).

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the DAG contains a cycle. |

Source code in src/workbench/utils/meta_endpoint_dag.py
def topological_order(self) -> List[str]:
    """Return nodes in topological order (parents before children).

    Raises:
        ValueError: If the DAG contains a cycle.
    """
    in_degree = {n: 0 for n in self._all_nodes()}
    for _, dst in self._edges:
        in_degree[dst] += 1

    ready = [n for n, deg in in_degree.items() if deg == 0]
    order: List[str] = []
    while ready:
        node = ready.pop(0)
        order.append(node)
        for src, dst in self._edges:
            if src == node:
                in_degree[dst] -= 1
                if in_degree[dst] == 0:
                    ready.append(dst)

    if len(order) != len(in_degree):
        raise ValueError("DAG contains a cycle")
    return order
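The ordering pass is a textbook Kahn's algorithm: seed the queue with zero-in-degree nodes, pop, and decrement children. A standalone sketch with hypothetical node names and no Workbench dependencies:

```python
from typing import Dict, List, Tuple

def topo_order(nodes: List[str], edges: List[Tuple[str, str]]) -> List[str]:
    """Return nodes parents-first; raise ValueError on a cycle."""
    in_degree: Dict[str, int] = {n: 0 for n in nodes}
    for _, dst in edges:
        in_degree[dst] += 1
    ready = [n for n, deg in in_degree.items() if deg == 0]
    order: List[str] = []
    while ready:
        node = ready.pop(0)
        order.append(node)
        for src, dst in edges:
            if src == node:
                in_degree[dst] -= 1
                if in_degree[dst] == 0:
                    ready.append(dst)
    if len(order) != len(in_degree):
        # Some node never reached in-degree 0: a cycle is feeding it
        raise ValueError("DAG contains a cycle")
    return order

# Two parallel feature endpoints feeding a Concat node
print(topo_order(["feat-2d", "feat-3d", "combine"],
                 [("feat-2d", "combine"), ("feat-3d", "combine")]))
# → ['feat-2d', 'feat-3d', 'combine']
```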

validate()

Validate the DAG. Returns self for chaining; raises on failure.

Checks:
  • At least one input node and exactly one output node declared
  • No cycles
  • Aggregation nodes have at least one parent
  • Endpoint nodes are either input nodes (zero parents) or have exactly one upstream parent — never both
  • The output node is reachable from the input nodes
Source code in src/workbench/utils/meta_endpoint_dag.py
def validate(self) -> "MetaEndpointDAG":
    """Validate the DAG. Returns self for chaining; raises on failure.

    Checks:
      - At least one input node and exactly one output node declared
      - No cycles
      - Aggregation nodes have at least one parent
      - Endpoint nodes are either input nodes (zero parents) or have
        exactly one upstream parent — never both
      - The output node is reachable from the input nodes
    """
    if not self._input_nodes:
        raise ValueError("DAG has no input nodes")
    if self._output_node is None:
        raise ValueError("DAG has no output node")

    order = self.topological_order()  # raises on cycle

    for ep_node in self._endpoints:
        parents = self._parents_of(ep_node)
        is_input = ep_node in self._input_nodes
        if is_input and parents:
            raise ValueError(
                f"Endpoint node '{ep_node}' is declared as an input node but has "
                f"upstream parents {parents}; pick one or the other."
            )
        if not is_input and not parents:
            raise ValueError(
                f"Endpoint node '{ep_node}' has no upstream parent and is not "
                f"declared as an input node — it has no source for its input DataFrame."
            )

    for name in self._aggregations:
        if not self._parents_of(name):
            raise ValueError(f"Aggregation node '{name}' has no upstream parents")

    reachable = set(self._input_nodes)
    for node in order:
        if node in reachable:
            for src, dst in self._edges:
                if src == node:
                    reachable.add(dst)
    if self._output_node not in reachable:
        raise ValueError(f"Output node '{self._output_node}' is not reachable from input nodes {self._input_nodes}")

    return self
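The final reachability check is a flood-fill from the declared input nodes, walked in topological order so every parent is marked before its children are considered. A minimal standalone sketch (node names are hypothetical):

```python
from typing import List, Set, Tuple

def reachable_from(inputs: List[str], order: List[str],
                   edges: List[Tuple[str, str]]) -> Set[str]:
    """Nodes reachable from the input set, given a topological order."""
    seen = set(inputs)
    for node in order:
        if node in seen:
            for src, dst in edges:
                if src == node:
                    seen.add(dst)
    return seen

edges = [("pred-a", "mean"), ("pred-b", "mean")]
order = ["pred-a", "pred-b", "orphan", "mean"]
print(reachable_from(["pred-a", "pred-b"], order, edges))
# 'orphan' is never marked, so it could not serve as the output node
```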

Aggregation nodes for MetaEndpointDAG.

An aggregation node combines outputs from one or more upstream nodes (Endpoint or other AggregationNode instances) into a single DataFrame. Two broad categories:

  • Column-union aggregators (Concat): join feature outputs from parallel feature endpoints into a single wide row per id — used for feature-pipeline DAGs (e.g. [2D] + [3D] → Concat).

  • Prediction aggregators (Mean, WeightedMean, Vote, plus the ensemble-strategy ports ConfidenceWeighted, InverseMaeWeighted, ScaledConfidenceWeighted, CalibratedConfidenceWeighted): combine prediction columns from multiple predictor endpoints into a single ensemble prediction with confidence — used for ensemble combination.

Each node declares its input/output column contract so the DAG can be validated statically before any inference runs.

AggregationNode

Base class for DAG aggregation nodes.

Subclasses implement apply() to combine upstream DataFrames and declare output_columns() for static DAG validation.

Aggregation nodes always join across upstream branches using the walker-injected `DAG_ROW_ID` column, so they don't need to know anything about the caller's id conventions (or whether the caller has any).

Source code in src/workbench/utils/aggregation_nodes.py
class AggregationNode:
    """Base class for DAG aggregation nodes.

    Subclasses implement ``apply()`` to combine upstream DataFrames and
    declare ``output_columns()`` for static DAG validation.

    Aggregation nodes always join across upstream branches using the
    walker-injected :data:`DAG_ROW_ID` column, so they don't need to
    know anything about the caller's id conventions (or whether the
    caller has any).
    """

    def __init__(self, name: str):
        self.name = name

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        """Combine upstream DataFrames into one. Subclasses must override."""
        raise NotImplementedError

    def input_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
        """The columns this node expects across all upstream outputs.

        Default: union of all upstream output columns. Subclasses can
        narrow this if they only consume specific columns.
        """
        seen = set()
        cols: List[str] = []
        for upstream in upstream_outputs:
            for c in upstream:
                if c not in seen:
                    seen.add(c)
                    cols.append(c)
        return cols

    def output_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
        """The columns this node emits. Subclasses must override."""
        raise NotImplementedError

apply(upstream)

Combine upstream DataFrames into one. Subclasses must override.

Source code in src/workbench/utils/aggregation_nodes.py
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
    """Combine upstream DataFrames into one. Subclasses must override."""
    raise NotImplementedError

input_columns(upstream_outputs)

The columns this node expects across all upstream outputs.

Default: union of all upstream output columns. Subclasses can narrow this if they only consume specific columns.

Source code in src/workbench/utils/aggregation_nodes.py
def input_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
    """The columns this node expects across all upstream outputs.

    Default: union of all upstream output columns. Subclasses can
    narrow this if they only consume specific columns.
    """
    seen = set()
    cols: List[str] = []
    for upstream in upstream_outputs:
        for c in upstream:
            if c not in seen:
                seen.add(c)
                cols.append(c)
    return cols

output_columns(upstream_outputs)

The columns this node emits. Subclasses must override.

Source code in src/workbench/utils/aggregation_nodes.py
def output_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
    """The columns this node emits. Subclasses must override."""
    raise NotImplementedError

CalibratedConfidenceWeighted

Bases: _StrategyAggregator

Per-row weights = confidence × |conf-error correlation| (normalized).

Rewards models whose confidence actually predicts accuracy.

Source code in src/workbench/utils/aggregation_nodes.py
class CalibratedConfidenceWeighted(_StrategyAggregator):
    """Per-row weights = ``confidence × |conf-error correlation|`` (normalized).

    Rewards models whose confidence actually predicts accuracy.
    """

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        self._check_arity(upstream)
        ids, preds, confs = self._stack(upstream)
        calibrated = confs * self.corr_scale
        weights = conf_weights_with_fallback(calibrated, self.model_weights)
        return self._build_output(
            upstream,
            ids,
            prediction=(preds * weights).sum(axis=1),
            prediction_std=preds.std(axis=1),
            confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
        )

Concat

Bases: AggregationNode

Column-union aggregator. Joins upstream DataFrames on the walker's synthetic row id.

Use for feature-pipeline DAGs where parallel feature endpoints contribute disjoint feature column sets that need to be merged into a single wide row.

Source code in src/workbench/utils/aggregation_nodes.py
class Concat(AggregationNode):
    """Column-union aggregator. Joins upstream DataFrames on the walker's
    synthetic row id.

    Use for feature-pipeline DAGs where parallel feature endpoints
    contribute disjoint feature column sets that need to be merged into a
    single wide row.
    """

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        if not upstream:
            raise ValueError(f"Concat[{self.name}]: requires at least one upstream DataFrame")

        out = upstream[0]
        for df in upstream[1:]:
            new_cols = [c for c in df.columns if c == DAG_ROW_ID or c not in out.columns]
            out = out.merge(df[new_cols], on=DAG_ROW_ID, how="inner")
        return out

    def output_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
        seen = set()
        cols: List[str] = []
        for upstream in upstream_outputs:
            for c in upstream:
                if c not in seen:
                    seen.add(c)
                    cols.append(c)
        return cols
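The merge behavior can be sketched with plain pandas; `ROW_ID` below stands in for the walker-injected `DAG_ROW_ID` constant, and the column names and values are illustrative:

```python
import pandas as pd

ROW_ID = "_dag_row_id"  # stand-in for the walker-injected DAG_ROW_ID

feats_2d = pd.DataFrame({ROW_ID: [0, 1], "smiles": ["CCO", "c1ccccc1"], "mw": [46.07, 78.11]})
feats_3d = pd.DataFrame({ROW_ID: [0, 1], "smiles": ["CCO", "c1ccccc1"], "pmi1": [0.2, 0.9]})

out = feats_2d
for df in [feats_3d]:
    # Take only columns not already present (plus the row id) to avoid _x/_y suffixes.
    new_cols = [c for c in df.columns if c == ROW_ID or c not in out.columns]
    out = out.merge(df[new_cols], on=ROW_ID, how="inner")

print(list(out.columns))  # ['_dag_row_id', 'smiles', 'mw', 'pmi1']
```

Filtering to `new_cols` before the merge is what keeps shared columns like `smiles` from being duplicated with `_x`/`_y` suffixes.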

ConfidenceWeighted

Bases: _StrategyAggregator

Per-row weights = upstream confidences (normalized).

Falls back to static model_weights when row confidences sum to ~0.

Source code in src/workbench/utils/aggregation_nodes.py
class ConfidenceWeighted(_StrategyAggregator):
    """Per-row weights = upstream confidences (normalized).

    Falls back to static ``model_weights`` when row confidences sum to ~0.
    """

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        self._check_arity(upstream)
        ids, preds, confs = self._stack(upstream)
        weights = conf_weights_with_fallback(confs, self.model_weights)
        return self._build_output(
            upstream,
            ids,
            prediction=(preds * weights).sum(axis=1),
            prediction_std=preds.std(axis=1),
            confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
        )
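A minimal numpy sketch of the per-row weighting, with a simplified stand-in for `conf_weights_with_fallback` (the real helper's exact behavior may differ):

```python
import numpy as np

def conf_weights_with_fallback(confs, model_weights, eps=1e-9):
    """Normalize per-row confidences into weights; where a row's confidences
    sum to ~0, fall back to the static model_weights. Simplified stand-in."""
    row_sums = confs.sum(axis=1, keepdims=True)
    safe = np.where(row_sums > eps, row_sums, 1.0)
    fallback = np.broadcast_to(model_weights, confs.shape)
    return np.where(row_sums > eps, confs / safe, fallback)

preds = np.array([[1.0, 3.0], [2.0, 4.0]])   # two rows, two models
confs = np.array([[0.9, 0.1], [0.0, 0.0]])   # second row has no confidence signal
model_weights = np.array([0.5, 0.5])

w = conf_weights_with_fallback(confs, model_weights)
ensemble = (preds * w).sum(axis=1)
# Row 0 leans toward model 0 (conf 0.9); row 1 falls back to the static weights.
print(ensemble)  # [1.2 3. ]
```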

InverseMaeWeighted

Bases: _StrategyAggregator

Static per-model weights from inverse-MAE.

The caller passes the inverse-MAE-derived weights directly via model_weights. Identical to WeightedMean for the prediction column, but additionally computes calibrated ensemble confidence.

Source code in src/workbench/utils/aggregation_nodes.py
class InverseMaeWeighted(_StrategyAggregator):
    """Static per-model weights from inverse-MAE.

    The caller passes the inverse-MAE-derived weights directly via
    ``model_weights``. Identical to ``WeightedMean`` for the prediction
    column, but additionally computes calibrated ensemble confidence.
    """

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        self._check_arity(upstream)
        ids, preds, confs = self._stack(upstream)
        return self._build_output(
            upstream,
            ids,
            prediction=(preds * self.model_weights).sum(axis=1),
            prediction_std=preds.std(axis=1),
            confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
        )
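The library does not compute the inverse-MAE weights itself; a plausible caller-side derivation (illustrative numbers) is:

```python
import numpy as np

# Hold-out MAE per model (illustrative numbers, not real benchmarks).
maes = np.array([0.20, 0.25, 0.50])

# Weights proportional to 1/MAE, normalized to sum to 1:
inv = 1.0 / maes
model_weights = inv / inv.sum()
# The most accurate model (lowest MAE) gets the largest weight.
print(model_weights.round(3))
```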

Mean

Bases: _PredictionAggregator

Simple equal-weight mean of predictions.

Source code in src/workbench/utils/aggregation_nodes.py
class Mean(_PredictionAggregator):
    """Simple equal-weight mean of predictions."""

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        ids, preds, confs = self._stack(upstream)
        return self._build_output(
            upstream,
            ids,
            prediction=preds.mean(axis=1),
            prediction_std=preds.std(axis=1),
            confidence=confs.mean(axis=1),
        )

ScaledConfidenceWeighted

Bases: _StrategyAggregator

Per-row weights = model_weights × confidence (normalized).

Often the top performer in practice — combines static MAE-derived weighting with per-row confidence scaling.

Source code in src/workbench/utils/aggregation_nodes.py
class ScaledConfidenceWeighted(_StrategyAggregator):
    """Per-row weights = ``model_weights × confidence`` (normalized).

    Often the top performer in practice — combines static MAE-derived
    weighting with per-row confidence scaling.
    """

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        self._check_arity(upstream)
        ids, preds, confs = self._stack(upstream)
        scaled = confs * self.model_weights
        weights = conf_weights_with_fallback(scaled, self.model_weights)
        return self._build_output(
            upstream,
            ids,
            prediction=(preds * weights).sum(axis=1),
            prediction_std=preds.std(axis=1),
            confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
        )
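The scaling step can be sketched in numpy (normalization shown inline here; the real code routes through `conf_weights_with_fallback` to handle zero-sum rows):

```python
import numpy as np

preds = np.array([[1.0, 3.0]])
confs = np.array([[0.5, 0.5]])           # equal per-row confidence
model_weights = np.array([0.8, 0.2])     # static, e.g. inverse-MAE-derived

scaled = confs * model_weights            # combine static skill with per-row confidence
weights = scaled / scaled.sum(axis=1, keepdims=True)
ensemble = (preds * weights).sum(axis=1)
# With equal confidences, the static weights dominate: 0.8*1.0 + 0.2*3.0 = 1.4
print(ensemble)  # [1.4]
```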

Vote

Bases: _PredictionAggregator

Majority-vote aggregator for classifier predictions.

Expects each upstream's prediction column to hold class labels (string or int). Output prediction is the most common label per row; prediction_std is 0 (placeholder for contract symmetry); confidence is the fraction of upstream models that voted for the winning label.

Source code in src/workbench/utils/aggregation_nodes.py
class Vote(_PredictionAggregator):
    """Majority-vote aggregator for classifier predictions.

    Expects each upstream's ``prediction`` column to hold class labels
    (string or int). Output ``prediction`` is the most common label per
    row; ``prediction_std`` is 0 (placeholder for contract symmetry);
    ``confidence`` is the fraction of upstream models that voted for the
    winning label.
    """

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        if not upstream:
            raise ValueError(f"Vote[{self.name}]: requires at least one upstream DataFrame")

        ids = upstream[0][[DAG_ROW_ID]].copy()
        for df in upstream[1:]:
            ids = ids.merge(df[[DAG_ROW_ID]], on=DAG_ROW_ID, how="inner")

        labels = pd.concat(
            [
                ids.merge(df[[DAG_ROW_ID, "prediction"]], on=DAG_ROW_ID)["prediction"].rename(f"_p{i}")
                for i, df in enumerate(upstream)
            ],
            axis=1,
        )

        modes = labels.mode(axis=1)[0]
        winner_share = (labels.eq(modes, axis=0)).sum(axis=1) / labels.shape[1]

        return self._build_output(
            upstream,
            ids,
            prediction=modes.to_numpy(),
            prediction_std=0.0,
            confidence=winner_share.to_numpy(),
        )
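The mode/winner-share computation is plain pandas and can be checked in isolation (the labels below are illustrative):

```python
import pandas as pd

# Class-label predictions from three upstream models (one column per model).
labels = pd.DataFrame({
    "_p0": ["toxic", "safe", "toxic"],
    "_p1": ["toxic", "toxic", "safe"],
    "_p2": ["safe", "toxic", "toxic"],
})

modes = labels.mode(axis=1)[0]                                    # winning label per row
winner_share = labels.eq(modes, axis=0).sum(axis=1) / labels.shape[1]

print(list(modes))                 # ['toxic', 'toxic', 'toxic']
print(list(winner_share.round(2)))  # [0.67, 0.67, 0.67]
```

Note that on a tie `DataFrame.mode` returns multiple columns sorted by label, so `[0]` picks a deterministic (alphabetically first) winner.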

WeightedMean

Bases: _PredictionAggregator

Static-weight mean — caller supplies one weight per upstream.

Source code in src/workbench/utils/aggregation_nodes.py
class WeightedMean(_PredictionAggregator):
    """Static-weight mean — caller supplies one weight per upstream."""

    def __init__(self, name: str, weights: List[float]):
        super().__init__(name)
        if not weights:
            raise ValueError("WeightedMean: weights must be a non-empty list")
        w = np.asarray(weights, dtype=np.float64)
        if (w < 0).any():
            raise ValueError("WeightedMean: weights must be non-negative")
        if w.sum() <= 0:
            raise ValueError("WeightedMean: at least one weight must be positive")
        self.weights = w / w.sum()

    def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
        if len(upstream) != len(self.weights):
            raise ValueError(
                f"WeightedMean[{self.name}]: got {len(upstream)} upstream frames but {len(self.weights)} weights"
            )
        ids, preds, confs = self._stack(upstream)
        return self._build_output(
            upstream,
            ids,
            prediction=(preds * self.weights).sum(axis=1),
            prediction_std=preds.std(axis=1),
            confidence=(confs * self.weights).sum(axis=1),
        )
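The weight normalization and weighted sum can be checked with a tiny numpy sketch (numbers are illustrative):

```python
import numpy as np

raw = [2.0, 1.0, 1.0]                      # caller-supplied per-upstream weights
w = np.asarray(raw, dtype=np.float64)
assert (w >= 0).all() and w.sum() > 0      # same checks as WeightedMean.__init__
w = w / w.sum()                            # normalized -> [0.5, 0.25, 0.25]

preds = np.array([[1.0, 2.0, 3.0]])        # one row, three model predictions
ensemble = (preds * w).sum(axis=1)         # 0.5*1 + 0.25*2 + 0.25*3 = 1.75
print(ensemble)  # [1.75]
```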

Questions?

The SuperCowPowers team is happy to answer any questions you may have about AWS and Workbench.