What's a MetaEndpoint?
A MetaEndpoint is a deployed Endpoint backed by a directed acyclic graph (DAG) of other endpoints + aggregation nodes. From the caller's perspective it's just an Endpoint — endpoint.inference(df) returns a DataFrame. The DAG machinery is server-side.
MetaEndpoint lets you compose multiple deployed endpoints into a single inference target. Two canonical shapes:
- Feature pipelines — fan out to several feature endpoints in parallel, merge the columns, optionally feed a predictor.
- Ensembles — fan out to several predictor endpoints, aggregate their predictions (mean, weighted mean, vote, calibrated confidence weighting, …) into a single output.
The same DAG abstraction covers both.
Quick Start: Feature Pipeline
Combine the 2D and 3D-fast feature endpoints into a single endpoint that returns merged feature columns per molecule:
from workbench.api import MetaEndpoint
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Concat
dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("combine")
end = MetaEndpoint.create(
name="smiles-to-2d-3d-features",
dag=dag,
description="2D RDKit/Mordred + 3D-fast features",
tags=["meta", "features"],
)
# Use it like any other endpoint
import pandas as pd
df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]})
result = end.inference(df)
# result has the input columns + 2D + 3D feature columns
Quick Start: Ensemble
Combine three predictor endpoints (XGBoost + PyTorch + ChemProp) into one ensemble:
from workbench.api import MetaEndpoint
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Mean
dag = MetaEndpointDAG()
dag.add_endpoint("logd-reg-xgb")
dag.add_endpoint("logd-reg-pytorch")
dag.add_endpoint("logd-reg-chemprop")
dag.add_aggregation(Mean(name="ensemble"))
dag.add_edge("logd-reg-xgb", "ensemble")
dag.add_edge("logd-reg-pytorch", "ensemble")
dag.add_edge("logd-reg-chemprop", "ensemble")
dag.set_input_node("logd-reg-xgb", "logd-reg-pytorch", "logd-reg-chemprop")
dag.set_output_node("ensemble")
end = MetaEndpoint.create(name="logd-meta", dag=dag, tags=["meta", "logd"])
The output has the standard prediction / prediction_std (ensemble disagreement) / confidence columns alongside whatever pass-through columns the input had.
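A quick contract check using the end endpoint from the quick start (exact column order may vary):
import pandas as pd
df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]})
result = end.inference(df)
print(result.columns.tolist())
# e.g. ['smiles', 'prediction', 'prediction_std', 'confidence']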
Async Auto-Detection
If any child endpoint in the DAG is deployed as async (e.g. smiles-to-3d-full-v1), the MetaEndpoint is automatically deployed as async too — its 60-minute invocation budget needs to accommodate the slowest child. You don't specify this; MetaEndpoint.create() detects it via dag.has_async_endpoint() and chooses the deploy mode.
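You can inspect the decision ahead of time with the same DAG methods that create() calls internally:
dag.populate_async_flags()   # one workbench_meta lookup per unique child endpoint
if dag.has_async_endpoint():
    print("create() will deploy this MetaEndpoint as async")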
DAG Building Blocks
Endpoint nodes
Every endpoint node refers to a deployed Workbench endpoint by name. Endpoint nodes can be:
- Input nodes — receive the caller's input DataFrame directly. Declared with dag.set_input_node(...).
- Downstream endpoint nodes — take their input from a single upstream parent (e.g. a Concat aggregation feeding a predictor).
dag.add_endpoint("smiles-to-2d-v1") # node name = endpoint name (default)
dag.add_endpoint("smiles-to-2d-v1", node_name="left_2d") # explicit node name (for aliasing)
Aggregation nodes
| Class | Use case | Output |
| --- | --- | --- |
| Concat | Column-union of feature outputs from parallel branches | All upstream columns merged |
| Mean | Equal-weight average of predictions | prediction, prediction_std, confidence |
| WeightedMean(weights=[…]) | Static-weight average | Same |
| Vote | Majority vote for classifiers | prediction (label), confidence (winner share) |
| ConfidenceWeighted(model_weights=…) | Per-row weights from upstream confidences | Same as Mean, with calibrated confidence |
| InverseMaeWeighted(model_weights=…) | Static weights from inverse-MAE | Same |
| ScaledConfidenceWeighted(model_weights=…) | Static MAE × per-row confidence | Same |
| CalibratedConfidenceWeighted(model_weights=…, corr_scale=…) | Confidence × \|conf-error correlation\| | Same |
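Continuing the ensemble quick start, a static-weight variant that leans on the ChemProp model would swap in WeightedMean (the weights here are illustrative, and are assumed to apply in the order the upstream parents are connected):
from workbench.utils.aggregation_nodes import WeightedMean
dag.add_aggregation(WeightedMean(name="ensemble", weights=[0.25, 0.25, 0.5]))
dag.add_edge("logd-reg-xgb", "ensemble")
dag.add_edge("logd-reg-pytorch", "ensemble")
dag.add_edge("logd-reg-chemprop", "ensemble")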
Edges
Edges declare data flow:
dag.add_edge("smiles-to-2d-v1", "combine") # 2D output flows into combine
dag.add_edge("combine", "predictor") # combined features flow into predictor
Endpoint nodes accept at most one inbound edge (one source for their input DataFrame). Aggregation nodes can have any number of inbound edges.
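Putting both edge patterns together: parallel feature branches merged by a Concat, then fed into a single downstream predictor (the "logd-predictor" endpoint name is hypothetical):
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Concat

dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_endpoint("logd-predictor", node_name="predictor")  # hypothetical predictor endpoint
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.add_edge("combine", "predictor")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("predictor")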
Validation
Call dag.validate() to fail loud on misconfiguration before any inference round-trips. Checks include cycle detection, dangling endpoint nodes, and reachability from input to output.
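For example, forgetting to declare an output node fails before any endpoint is touched:
dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.set_input_node("smiles-to-2d-v1")
dag.validate()  # raises ValueError: "DAG has no output node"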
Row Alignment
The walker injects a synthetic __dag_row_id column at the start of every run() and strips it before returning. Aggregation nodes use it as the join key so callers do not need to supply any id column on their input data — and any id-like column they do supply just flows through as a regular pass-through column.
This gives MetaEndpoints a clean contract: any DataFrame with the columns your input-node endpoints expect is a valid input.
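For instance, a caller-supplied compound_id column (hypothetical) is never used for joining; it simply rides along and comes back untouched:
df = pd.DataFrame({"compound_id": ["c-001", "c-002"], "smiles": ["CCO", "c1ccccc1"]})
result = end.inference(df)  # alignment happened on __dag_row_id; compound_id passed through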
How It Works
Creation flow
When you call MetaEndpoint.create(name, dag, ...):
- Validate — dag.validate() fails loud on cycles, dangling nodes, etc.
- Resolve async flags — looks up each child's workbench_meta to record per-endpoint async status.
- Lineage anchor — backtraces the first input endpoint to a FeatureSet/target/feature_list (Workbench Models need to point at a FeatureSet).
- Build + register the Model — runs the standard FeatureSet.to_model() flow, passing the DAG dict + region + bucket as custom_args. The meta_endpoint.template substitutes those placeholders, the SageMaker training job persists meta_endpoint_config.json as the model artifact, and the model package is registered with the standard inference image.
- Deploy — model.to_endpoint(...) deploys an Endpoint, async if any child is async (with max_instances=1 and 5-minute idle drain). inference_batch_size is auto-set to the minimum across DAG children.
Inference flow
When the deployed MetaEndpoint receives a request:
- The container deserializes the DAG from the model artifact.
- The walker traverses nodes in topological order.
- Endpoint nodes call fast_inference (sync child) or async_inference (async child) via workbench_bridges — the transport is decided per child by the async flags captured at deploy time.
- Aggregation nodes apply their combination logic on the upstream outputs.
- The output node's DataFrame is returned to the caller (with the synthetic __dag_row_id stripped).
Failure policy is fail-fast: any exception in any node propagates out and the request fails. (Future: per-node failure policies for partial-result aggregation.)
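The same walker is runnable client-side for debugging — dag.run() executes the DAG locally, and the endpoint_invoker hook (documented under run() in the API reference below) lets you wrap each child call. A sketch with a hypothetical timing wrapper:
import time
import pandas as pd
from workbench.api import Endpoint

def timed_invoker(endpoint_name, df):
    # Hypothetical invoker: same (endpoint_name, df) -> df contract, plus timing.
    start = time.time()
    result = Endpoint(endpoint_name).inference(df)
    print(f"{endpoint_name}: {len(df)} rows in {time.time() - start:.1f}s")
    return result

df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"]})
local_result = dag.run(df, endpoint_invoker=timed_invoker)  # dag from the quick starts above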
CLI: Ensemble Simulator
Before committing to an ensemble shape, the ensemble_sim CLI lets you evaluate aggregation strategies offline against captured cross-fold predictions:
ensemble_sim logd-reg-xgb logd-reg-pytorch logd-reg-chemprop
This reports per-strategy MAE / RMSE / R² so you can pick the aggregation node that performs best before building the DAG. The same simulator will also be wired into MetaEndpoint.create() for auto-tuning when a DAG includes a strategy-tunable aggregation node (planned).
API Reference
MetaEndpoint: An Endpoint backed by a directed acyclic graph (DAG) of
child endpoints and aggregation nodes.
A MetaEndpoint behaves identically to a regular Endpoint at runtime —
callers do endpoint.inference(df) and get a DataFrame back. The DAG
machinery is server-side: the deployed container loads the serialized
DAG, dispatches each child invocation to fast_inference (sync) or
async_inference (async), and runs aggregation nodes locally.
Common usage:
from workbench.api import MetaEndpoint
from workbench.utils.meta_endpoint_dag import MetaEndpointDAG
from workbench.utils.aggregation_nodes import Concat
dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("combine")
end = MetaEndpoint.create(name="my-features-meta", dag=dag)
# Input does not need any id column — the DAG handles row alignment internally.
df = end.inference(input_df)
If any child endpoint in the DAG is async (e.g. smiles-to-3d-full-v1),
the MetaEndpoint is automatically deployed as async too — its invocation
budget needs to accommodate the slowest child.
Bases: Endpoint
Endpoint backed by a MetaEndpointDAG.
Constructor wraps an existing deployed MetaEndpoint by name, identical to Endpoint. Use create() to build and deploy a new one from a DAG.
Source code in src/workbench/api/meta_endpoint.py
class MetaEndpoint(Endpoint):
"""Endpoint backed by a :class:`MetaEndpointDAG`.
Constructor wraps an existing deployed MetaEndpoint by name, identical
to :class:`Endpoint`. Use :meth:`create` to build and deploy a new one
from a DAG.
"""
@classmethod
def create(
cls,
name: str,
dag: MetaEndpointDAG,
description: str | None = None,
tags: list[str] | None = None,
) -> "MetaEndpoint":
"""Build, register, and deploy a MetaEndpoint from a DAG.
Steps:
1. Validate the DAG; populate per-endpoint async flags.
2. Backtrace lineage from a primary endpoint to satisfy
Workbench's Model machinery (FeatureSet, target, features).
3. Run the standard ``FeatureSet.to_model()`` flow, passing the
DAG / region / bucket as ``custom_args`` so the meta-endpoint
template fills them in at training time.
4. Set DAG-specific ``workbench_meta`` keys on the resulting Model.
5. Deploy the endpoint (async if any DAG child is async).
Args:
name: Endpoint / Model name.
dag: A :class:`MetaEndpointDAG` describing the data flow.
description: Optional description for the registered model.
tags: Optional list of Workbench tags.
Returns:
The deployed MetaEndpoint, ready for ``.inference()``.
"""
Artifact.is_name_valid(name, delimiter="-", lower_case=False)
log.important(f"Validating DAG for MetaEndpoint '{name}'...")
dag.validate()
dag.populate_async_flags()
is_async = dag.has_async_endpoint()
log.important(
f"DAG: {len(dag._endpoints)} endpoints, {len(dag._aggregations)} aggregation nodes "
f"({'async' if is_async else 'sync'} deployment)"
)
# Backtrace lineage from a primary endpoint to satisfy Workbench Model
# machinery (every Model needs a FeatureSet to hang off of).
feature_list, feature_set_name, target_column = cls._derive_lineage(dag)
# Build the model via the standard FeatureSet → Model flow. The
# meta-endpoint template's `{{dag}}`, `{{aws_region}}`, `{{s3_bucket}}`
# placeholders are filled from custom_args.
aws_clamp = AWSAccountClamp()
sm_session = aws_clamp.sagemaker_session()
workbench_bucket = ConfigManager().get_config("WORKBENCH_BUCKET")
feature_set = FeatureSet(feature_set_name)
feature_set.to_model(
name=name,
model_type=cls._derive_model_type(dag),
model_framework=ModelFramework.META,
tags=tags or [name],
description=description or f"MetaEndpoint DAG over: {', '.join(dag._endpoints.values())}",
target_column=target_column,
feature_list=feature_list,
custom_args={
"dag": dag.to_dict(),
"aws_region": sm_session.boto_region_name,
"s3_bucket": workbench_bucket,
},
)
# Append DAG-specific workbench_meta on top of what FeaturesToModel
# already set (model_type, framework, features, target, training view).
output_model = ModelCore(name)
output_model.upsert_workbench_meta({"endpoints": list(dag._endpoints.values())})
output_model.upsert_workbench_meta({"meta_endpoint_dag": dag.to_dict()})
# Deploy. MetaEndpoint containers are thin orchestrators — actual
# compute happens in the child endpoints, which scale on their own
# backlog. One meta instance can already drive 100s of concurrent
# child calls (async_inference uses a 64-thread worker pool
# internally), so additional meta instances don't help. Async deploy
# is therefore 0→1 with idle drain; sync is fixed 1.
log.important(f"Deploying MetaEndpoint '{name}' ({'async' if is_async else 'sync'})...")
model = Model(name)
if is_async:
endpoint = model.to_endpoint(
tags=tags or [name],
async_endpoint=True,
max_instances=1,
scale_in_idle_minutes=5,
)
else:
endpoint = model.to_endpoint(
tags=tags or [name],
async_endpoint=False,
)
# Auto-derive inference_batch_size from the smallest tolerance among
# children — chunks the meta receives from SageMaker get fanned out
# as-is to every child, so the meta's chunk size shouldn't exceed
# the smallest child's batch size.
min_batch = dag.min_child_batch_size()
endpoint.upsert_workbench_meta({"inference_batch_size": min_batch})
log.important(f"Set inference_batch_size={min_batch} (min across DAG children)")
# Publish the largest child fleet size as effective_max_instances so
# callers (e.g. InferenceCache) can size their work units to fill
# downstream child capacity rather than the meta's own
# max_instances=1 (which only describes the orchestrator layer).
effective_max = dag.max_child_max_instances()
endpoint.upsert_workbench_meta({"effective_max_instances": effective_max})
log.important(f"Set effective_max_instances={effective_max} (max across DAG children)")
log.important(f"MetaEndpoint '{name}' created successfully!")
return cls(name)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@classmethod
def _derive_model_type(cls, dag: MetaEndpointDAG) -> ModelType:
"""Pick the most accurate :class:`ModelType` for the DAG's output.
- Output node is a terminal endpoint → borrow that endpoint's declared
type (e.g., a downstream predictor endpoint contributes its own type).
- Output node is :class:`~workbench.utils.aggregation_nodes.Concat` →
``TRANSFORMER`` (column-union of feature outputs).
- Output node is :class:`~workbench.utils.aggregation_nodes.Vote` →
``CLASSIFIER`` (majority vote of class labels).
- Output node is any other prediction aggregator → ``REGRESSOR``.
"""
from workbench.utils.aggregation_nodes import Concat, Vote
output_name = dag._output_node
if output_name in dag._endpoints:
ep_name = dag._endpoints[output_name]
try:
return Model(ep_name).model_type
except Exception:
return ModelType.REGRESSOR
agg = dag._aggregations[output_name]
if isinstance(agg, Concat):
return ModelType.TRANSFORMER
if isinstance(agg, Vote):
return ModelType.CLASSIFIER
return ModelType.REGRESSOR
@classmethod
def _derive_lineage(cls, dag: MetaEndpointDAG) -> tuple[list[str], str, str | None]:
"""Backtrace from the first input endpoint to find a FeatureSet + lineage.
Workbench Models need to trace back to a FeatureSet. For DAG-based
MetaEndpoints there isn't a single canonical FeatureSet, so we use
the first input node's lineage as a representative anchor.
Returns ``(feature_list, feature_set_name, target_column)``.
``target_column`` may be ``None`` for pure feature-pipeline DAGs
whose primary endpoint is a feature endpoint.
"""
if not dag._input_nodes:
raise ValueError("DAG has no input nodes — cannot derive lineage")
primary_endpoint_name = dag._endpoints[dag._input_nodes[0]]
ep = Endpoint(primary_endpoint_name)
if not ep.exists():
raise ValueError(f"Primary endpoint '{primary_endpoint_name}' does not exist")
primary_model = Model(ep.get_input())
feature_list = primary_model.features() or list(dag.input_columns())
feature_set_name = primary_model.get_input()
target_column = primary_model.target()
log.info(
f"Lineage anchor: {primary_endpoint_name} -> {primary_model.name} -> {feature_set_name} "
f"(target: {target_column})"
)
return feature_list, feature_set_name, target_column
def get_dag(self) -> MetaEndpointDAG:
"""Reconstruct the MetaEndpointDAG from this endpoint's stored metadata."""
meta = self.workbench_meta() or {}
dag_dict = meta.get("meta_endpoint_dag")
if not dag_dict:
raise ValueError(
f"MetaEndpoint '{self.name}' has no DAG in workbench_meta. Recreate via MetaEndpoint.create()."
)
return MetaEndpointDAG.from_dict(dag_dict)
create(name, dag, description=None, tags=None) classmethod
Build, register, and deploy a MetaEndpoint from a DAG.
Steps:
1. Validate the DAG; populate per-endpoint async flags.
2. Backtrace lineage from a primary endpoint to satisfy Workbench's Model machinery (FeatureSet, target, features).
3. Run the standard FeatureSet.to_model() flow, passing the DAG / region / bucket as custom_args so the meta-endpoint template fills them in at training time.
4. Set DAG-specific workbench_meta keys on the resulting Model.
5. Deploy the endpoint (async if any DAG child is async).
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Endpoint / Model name. | required |
| dag | MetaEndpointDAG | A MetaEndpointDAG describing the data flow. | required |
| description | str \| None | Optional description for the registered model. | None |
| tags | list[str] \| None | Optional list of Workbench tags. | None |
Returns:
| Type | Description |
| --- | --- |
| MetaEndpoint | The deployed MetaEndpoint, ready for .inference(). |
get_dag()
Reconstruct the MetaEndpointDAG from this endpoint's stored metadata.
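A short round trip, wrapping an already-deployed MetaEndpoint and pulling its DAG back out for inspection:
from workbench.api import MetaEndpoint

end = MetaEndpoint("smiles-to-2d-3d-features")
dag = end.get_dag()
print(dag.topological_order())  # parents before children, e.g. [..., "combine"]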
MetaEndpointDAG — a directed acyclic graph of endpoints and aggregation nodes describing an inference-time data flow.
A DAG has two kinds of nodes:
- Endpoint nodes — references to deployed Workbench Endpoint instances by name. The DAG defers actual Endpoint instantiation until execution / column-contract resolution.
- Aggregation nodes — instances of AggregationNode subclasses that combine outputs from upstream nodes.
DAG construction is explicit:
dag = MetaEndpointDAG()
dag.add_endpoint("smiles-to-2d-v1")
dag.add_endpoint("smiles-to-3d-fast-v1")
dag.add_aggregation(Concat(name="combine"))
dag.add_edge("smiles-to-2d-v1", "combine")
dag.add_edge("smiles-to-3d-fast-v1", "combine")
dag.set_input_node("smiles-to-2d-v1", "smiles-to-3d-fast-v1")
dag.set_output_node("combine")
dag.validate()
Row-alignment across parallel branches: the walker injects a synthetic DAG_ROW_ID column at the start of every run() and strips it before returning. Aggregation nodes use it as the join key, so callers do not need to supply (or care about) any id column on their input data.
Validation runs at construction time so misconfigured DAGs fail loud before any inference round-trips.
A typed DAG of endpoints + aggregation nodes.
The DAG joins parallel branches using an internal synthetic row id (DAG_ROW_ID) injected by run() — callers don't need to supply any id column on their input.
Source code in src/workbench/utils/meta_endpoint_dag.py
class MetaEndpointDAG:
"""A typed DAG of endpoints + aggregation nodes.
The DAG joins parallel branches using an internal synthetic row id
(:data:`DAG_ROW_ID`) injected by :meth:`run` — callers don't need to
supply any id column on their input.
"""
def __init__(self):
self._endpoints: Dict[str, str] = {} # node_name → endpoint_name
self._endpoint_async_flags: Dict[str, bool] = {} # populated by populate_async_flags()
self._aggregations: Dict[str, AggregationNode] = {}
self._edges: List[tuple[str, str]] = [] # (from_node, to_node)
self._input_nodes: List[str] = []
self._output_node: Optional[str] = None
# ------------------------------------------------------------------
# Construction
# ------------------------------------------------------------------
def add_endpoint(self, endpoint_name: str, node_name: Optional[str] = None) -> str:
"""Add an endpoint reference to the DAG.
Args:
endpoint_name: Name of a deployed Workbench endpoint.
node_name: Optional unique node name (defaults to ``endpoint_name``).
Returns:
The node name (so callers can chain).
"""
node = node_name or endpoint_name
if node in self._endpoints or node in self._aggregations:
raise ValueError(f"Node '{node}' already exists in this DAG")
self._endpoints[node] = endpoint_name
return node
def add_aggregation(self, node: AggregationNode) -> str:
"""Add an :class:`AggregationNode` to the DAG.
The node's ``name`` must be unique across the DAG.
"""
if node.name in self._endpoints or node.name in self._aggregations:
raise ValueError(f"Node '{node.name}' already exists in this DAG")
self._aggregations[node.name] = node
return node.name
def add_edge(self, from_node: str, to_node: str) -> None:
"""Declare data flow from ``from_node`` to ``to_node``.
Endpoint nodes accept at most one inbound edge (their input DataFrame
comes from a single upstream producer). Aggregation nodes can have
any number of inbound edges.
"""
if from_node not in self._all_nodes():
raise ValueError(f"Edge from unknown node '{from_node}'")
if to_node not in self._all_nodes():
raise ValueError(f"Edge to unknown node '{to_node}'")
if to_node in self._endpoints and self._parents_of(to_node):
raise ValueError(
f"Endpoint node '{to_node}' already has an upstream parent "
f"('{self._parents_of(to_node)[0]}'); endpoints take input "
f"from at most one source."
)
self._edges.append((from_node, to_node))
def set_input_node(self, *nodes: str) -> None:
"""Declare which nodes receive the DAG's input DataFrame directly."""
for n in nodes:
if n not in self._endpoints:
raise ValueError(f"Input nodes must be endpoint nodes; '{n}' is not")
self._input_nodes = list(nodes)
def set_output_node(self, node: str) -> None:
"""Declare the terminal node whose output is the DAG's output."""
if node not in self._all_nodes():
raise ValueError(f"Unknown output node '{node}'")
self._output_node = node
# ------------------------------------------------------------------
# Inspection
# ------------------------------------------------------------------
def _all_nodes(self) -> List[str]:
return list(self._endpoints.keys()) + list(self._aggregations.keys())
def _parents_of(self, node: str) -> List[str]:
return [src for src, dst in self._edges if dst == node]
def topological_order(self) -> List[str]:
"""Return nodes in topological order (parents before children).
Raises:
ValueError: If the DAG contains a cycle.
"""
in_degree = {n: 0 for n in self._all_nodes()}
for _, dst in self._edges:
in_degree[dst] += 1
ready = [n for n, deg in in_degree.items() if deg == 0]
order: List[str] = []
while ready:
node = ready.pop(0)
order.append(node)
for src, dst in self._edges:
if src == node:
in_degree[dst] -= 1
if in_degree[dst] == 0:
ready.append(dst)
if len(order) != len(in_degree):
raise ValueError("DAG contains a cycle")
return order
# ------------------------------------------------------------------
# Column contract
# ------------------------------------------------------------------
def input_columns(self) -> List[str]:
"""Union of input columns required by every node that receives the
caller's input directly.
Used by :class:`MetaEndpoint` as a fallback when deriving the
feature list during lineage anchoring.
"""
from workbench.api import Endpoint
if not self._input_nodes:
raise ValueError("DAG has no input nodes — call set_input_node() first")
seen = set()
cols: List[str] = []
for node in self._input_nodes:
for c in Endpoint(self._endpoints[node]).input_columns():
if c not in seen:
seen.add(c)
cols.append(c)
return cols
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
def validate(self) -> "MetaEndpointDAG":
"""Validate the DAG. Returns self for chaining; raises on failure.
Checks:
- At least one input node and exactly one output node declared
- No cycles
- Aggregation nodes have at least one parent
- Endpoint nodes are either input nodes (zero parents) or have
exactly one upstream parent — never both
- The output node is reachable from the input nodes
"""
if not self._input_nodes:
raise ValueError("DAG has no input nodes")
if self._output_node is None:
raise ValueError("DAG has no output node")
order = self.topological_order() # raises on cycle
for ep_node in self._endpoints:
parents = self._parents_of(ep_node)
is_input = ep_node in self._input_nodes
if is_input and parents:
raise ValueError(
f"Endpoint node '{ep_node}' is declared as an input node but has "
f"upstream parents {parents}; pick one or the other."
)
if not is_input and not parents:
raise ValueError(
f"Endpoint node '{ep_node}' has no upstream parent and is not "
f"declared as an input node — it has no source for its input DataFrame."
)
for name in self._aggregations:
if not self._parents_of(name):
raise ValueError(f"Aggregation node '{name}' has no upstream parents")
reachable = set(self._input_nodes)
for node in order:
if node in reachable:
for src, dst in self._edges:
if src == node:
reachable.add(dst)
if self._output_node not in reachable:
raise ValueError(f"Output node '{self._output_node}' is not reachable from input nodes {self._input_nodes}")
return self
# ------------------------------------------------------------------
# Execution (client-side walker)
# ------------------------------------------------------------------
def run(
self,
input_df: pd.DataFrame,
endpoint_invoker: Optional[EndpointInvoker] = None,
) -> pd.DataFrame:
"""Execute the DAG against ``input_df`` and return the output node's DataFrame.
The walker injects a synthetic :data:`DAG_ROW_ID` column at entry
(used internally to align rows across parallel branches) and
strips it before returning. Callers don't need to supply any id
column.
Walks nodes in topological order. Endpoint nodes call
:meth:`Endpoint.inference` on either the caller's ``input_df`` (input
nodes) or their upstream parent's cached output. Aggregation nodes
receive the cached outputs of all their parents and apply their
combination logic.
Failure policy is fail-fast: any exception in any node propagates
out and the DAG run aborts.
Args:
input_df: DataFrame supplied by the caller. Must contain the
columns required by every input-node endpoint. No id
column is required.
endpoint_invoker: Optional callable ``(endpoint_name, df) -> df``
used to invoke endpoint nodes. Defaults to using the full
Workbench ``Endpoint`` API class — appropriate for client-side
use. Pass a ``fast_inference``-backed invoker when running
inside a deployed SageMaker container where the full
Workbench config isn't available.
Returns:
The DataFrame at the DAG's output node, with the synthetic
:data:`DAG_ROW_ID` column removed.
"""
if self._output_node is None:
raise ValueError("DAG has no output node — call set_output_node() first")
if DAG_ROW_ID in input_df.columns:
raise ValueError(
f"input_df already contains the reserved column '{DAG_ROW_ID}'. " f"Remove it before calling run()."
)
# Inject the synthetic row id. Endpoints will pass this through as an
# unknown input column; aggregation nodes use it as their join key.
input_df = input_df.copy()
input_df[DAG_ROW_ID] = range(len(input_df))
outputs: Dict[str, pd.DataFrame] = {}
for node in self.topological_order():
if node in self._endpoints:
outputs[node] = self._run_endpoint(node, input_df, outputs, endpoint_invoker)
else:
outputs[node] = self._run_aggregation(node, outputs)
result = outputs[self._output_node]
if DAG_ROW_ID in result.columns:
result = result.drop(columns=[DAG_ROW_ID])
return result
def _run_endpoint(
self,
node: str,
input_df: pd.DataFrame,
outputs: Dict[str, pd.DataFrame],
endpoint_invoker: Optional[EndpointInvoker],
) -> pd.DataFrame:
"""Execute a single endpoint node.
Source DataFrame is the caller's input for input nodes, or the
single upstream parent's output otherwise. The full DataFrame is
passed to ``endpoint.inference()`` — metadata columns
(project_id, owner, etc.) flow through alongside the endpoint's
added columns, matching standard Workbench inference behavior.
The walker-injected :data:`DAG_ROW_ID` column must survive the
endpoint round-trip so downstream aggregation nodes can join on
it. If an endpoint silently strips unknown input columns, this
will fail loudly — better than misaligned rows.
"""
endpoint_name = self._endpoints[node]
parents = self._parents_of(node)
source_df = input_df if not parents else outputs[parents[0]]
if endpoint_invoker is not None:
result = endpoint_invoker(endpoint_name, source_df)
else:
from workbench.api import Endpoint
result = Endpoint(endpoint_name).inference(source_df)
if DAG_ROW_ID not in result.columns:
raise RuntimeError(
f"Endpoint '{endpoint_name}' dropped the walker-injected '{DAG_ROW_ID}' "
f"column from its output. The DAG can't align rows across branches "
f"without it. Endpoints must pass unknown input columns through to "
f"their output."
)
return result
def _run_aggregation(self, node: str, outputs: Dict[str, pd.DataFrame]) -> pd.DataFrame:
"""Execute a single aggregation node."""
agg = self._aggregations[node]
upstream = [outputs[p] for p in self._parents_of(node)]
return agg.apply(upstream)
# ------------------------------------------------------------------
# Serialization (model artifact + workbench_meta storage)
# ------------------------------------------------------------------
def to_dict(self) -> dict:
"""Serialize the DAG topology to a JSON-friendly dict.
Aggregation nodes are serialized by class name + constructor kwargs;
deserialization (:meth:`from_dict`) requires the same class to be
importable.
Per-endpoint ``is_async`` flags are included only if
:meth:`populate_async_flags` has been called. The deployed
inference container relies on these flags to dispatch invocations
to ``fast_inference`` or ``async_inference``.
"""
return {
"endpoints": dict(self._endpoints),
"endpoint_async": dict(self._endpoint_async_flags),
"aggregations": [_serialize_aggregation(a) for a in self._aggregations.values()],
"edges": [list(e) for e in self._edges],
"input_nodes": list(self._input_nodes),
"output_node": self._output_node,
}
def to_json(self) -> str:
return json.dumps(self.to_dict(), indent=2)
def populate_async_flags(self) -> None:
"""Look up each endpoint's async flag via ``workbench_meta`` and store it.
Flags are keyed by endpoint name (not node name) so the deployed
invoker can dispatch directly on the value passed by the walker.
Called by :meth:`MetaEndpoint.create` before serializing the DAG
for deployment. Hits AWS once per unique endpoint name, so isolated
as an explicit step rather than running implicitly in :meth:`to_dict`.
"""
from workbench.api import Endpoint
for endpoint_name in set(self._endpoints.values()):
meta = Endpoint(endpoint_name).workbench_meta() or {}
self._endpoint_async_flags[endpoint_name] = bool(meta.get("async_endpoint"))
def has_async_endpoint(self) -> bool:
"""Return True if any endpoint in the DAG is deployed as async.
Used by :meth:`MetaEndpoint.create` to decide whether the meta
endpoint itself must be deployed as async. Lazily calls
:meth:`populate_async_flags` if not yet populated.
"""
if not self._endpoint_async_flags and self._endpoints:
self.populate_async_flags()
return any(self._endpoint_async_flags.values())
def min_child_batch_size(self) -> int:
"""Minimum ``inference_batch_size`` across all child endpoints.
:meth:`MetaEndpoint.create` uses this to set the meta endpoint's
own ``inference_batch_size`` — chunks the meta receives from
SageMaker get fanned out as-is to every child, so the meta's
chunk size shouldn't exceed the smallest tolerance among children.
Smaller chunks at the meta level also mean smaller failure blast
radius (one bad row fails one small chunk, not a large one).
Returns:
int: Minimum batch size; ``100`` if the DAG has no endpoints
(the standard sync default).
"""
from workbench.api import Endpoint
if not self._endpoints:
return 100
sizes = [Endpoint(ep_name).inference_batch_size() for ep_name in set(self._endpoints.values())]
return min(sizes)
def max_child_max_instances(self) -> int:
"""Maximum ``max_instances`` across all child endpoints.
:class:`MetaEndpoint.create` uses this to publish an
``effective_max_instances`` hint into the meta endpoint's
``workbench_meta``. The meta itself deploys with ``max_instances=1``
(it's a thin orchestrator), but downstream tooling like
:class:`InferenceCache` sizes its work units to fill fleet capacity —
and the relevant capacity is the child fleets, not the meta's
single orchestrator instance.
Returns:
int: Maximum ``max_instances`` seen on any child endpoint;
``1`` if no child has it set or the DAG has no endpoints.
"""
from workbench.api import Endpoint
seen: List[int] = []
for ep_name in set(self._endpoints.values()):
meta = Endpoint(ep_name).workbench_meta() or {}
if meta.get("max_instances") is not None:
seen.append(int(meta["max_instances"]))
return max(seen) if seen else 1
@classmethod
def from_dict(cls, data: dict) -> "MetaEndpointDAG":
dag = cls()
for node_name, endpoint_name in data.get("endpoints", {}).items():
dag.add_endpoint(endpoint_name, node_name=node_name)
dag._endpoint_async_flags = dict(data.get("endpoint_async", {}))
for agg_data in data.get("aggregations", []):
dag.add_aggregation(_deserialize_aggregation(agg_data))
for src, dst in data.get("edges", []):
dag.add_edge(src, dst)
if data.get("input_nodes"):
dag.set_input_node(*data["input_nodes"])
if data.get("output_node"):
dag.set_output_node(data["output_node"])
return dag
@classmethod
def from_json(cls, payload: str) -> "MetaEndpointDAG":
return cls.from_dict(json.loads(payload))
add_aggregation(node)
Add an AggregationNode to the DAG.
The node's name must be unique across the DAG.
add_edge(from_node, to_node)
Declare data flow from from_node to to_node.
Endpoint nodes accept at most one inbound edge (their input DataFrame comes from a single upstream producer). Aggregation nodes can have any number of inbound edges.
add_endpoint(endpoint_name, node_name=None)
Add an endpoint reference to the DAG.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| endpoint_name | str | Name of a deployed Workbench endpoint. | required |
| node_name | Optional[str] | Optional unique node name (defaults to endpoint_name). | None |
Returns:
| Type | Description |
| --- | --- |
| str | The node name (so callers can chain). |
has_async_endpoint()
Return True if any endpoint in the DAG is deployed as async.
Used by MetaEndpoint.create() to decide whether the meta endpoint itself must be deployed as async. Lazily calls populate_async_flags() if not yet populated.
input_columns()
Union of input columns required by every node that receives the caller's input directly.
Used by MetaEndpoint as a fallback when deriving the feature list during lineage anchoring.
max_child_max_instances()
Maximum max_instances across all child endpoints.
MetaEndpoint.create() uses this to publish an effective_max_instances hint into the meta endpoint's workbench_meta. The meta itself deploys with max_instances=1 (it's a thin orchestrator), but downstream tooling like InferenceCache sizes its work units to fill fleet capacity — and the relevant capacity is the child fleets, not the meta's single orchestrator instance.
Returns:
| Type | Description |
| --- | --- |
| int | Maximum max_instances seen on any child endpoint; 1 if no child has it set or the DAG has no endpoints. |
min_child_batch_size()
Minimum inference_batch_size across all child endpoints.
MetaEndpoint.create() uses this to set the meta endpoint's own inference_batch_size — chunks the meta receives from SageMaker get fanned out as-is to every child, so the meta's chunk size shouldn't exceed the smallest tolerance among children. Smaller chunks at the meta level also mean smaller failure blast radius (one bad row fails one small chunk, not a large one).
Returns:
| Type | Description |
| --- | --- |
| int | Minimum batch size; 100 if the DAG has no endpoints (the standard sync default). |
populate_async_flags()
Look up each endpoint's async flag via workbench_meta and store it.
Flags are keyed by endpoint name (not node name) so the deployed invoker can dispatch directly on the value passed by the walker.
Called by MetaEndpoint.create() before serializing the DAG for deployment. It hits AWS once per unique endpoint name, so it is isolated as an explicit step rather than running implicitly in to_dict().
run(input_df, endpoint_invoker=None)
Execute the DAG against input_df and return the output node's DataFrame.
The walker injects a synthetic DAG_ROW_ID column at entry (used internally to align rows across parallel branches) and strips it before returning. Callers don't need to supply any id column.
Walks nodes in topological order. Endpoint nodes call Endpoint.inference() on either the caller's input_df (input nodes) or their upstream parent's cached output. Aggregation nodes receive the cached outputs of all their parents and apply their combination logic.
Failure policy is fail-fast: any exception in any node propagates out and the DAG run aborts.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_df | DataFrame | DataFrame supplied by the caller. Must contain the columns required by every input-node endpoint. No id column is required. | required |
| endpoint_invoker | Optional[EndpointInvoker] | Optional callable (endpoint_name, df) -> df used to invoke endpoint nodes. Defaults to using the full Workbench Endpoint API class — appropriate for client-side use. Pass a fast_inference-backed invoker when running inside a deployed SageMaker container where the full Workbench config isn't available. | None |
Returns:
| Type | Description |
| --- | --- |
| DataFrame | The DataFrame at the DAG's output node, with the synthetic DAG_ROW_ID column removed. |
set_input_node(*nodes)
Declare which nodes receive the DAG's input DataFrame directly.
set_output_node(node)
Declare the terminal node whose output is the DAG's output.
to_dict()
Serialize the DAG topology to a JSON-friendly dict.
Aggregation nodes are serialized by class name + constructor kwargs; deserialization (from_dict()) requires the same class to be importable.
Per-endpoint is_async flags are included only if populate_async_flags() has been called. The deployed inference container relies on these flags to dispatch invocations to fast_inference or async_inference.
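A round-trip sketch (this is the same serialization path the model artifact takes; dag as built in the quick starts):
payload = dag.to_json()                        # json.dumps(dag.to_dict(), indent=2)
restored = MetaEndpointDAG.from_json(payload)
assert restored.to_dict() == dag.to_dict()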
topological_order()
Return nodes in topological order (parents before children).
Raises:
| Type | Description |
| --- | --- |
| ValueError | If the DAG contains a cycle. |
validate()
Validate the DAG. Returns self for chaining; raises on failure.
Checks:
- At least one input node and exactly one output node declared
- No cycles
- Aggregation nodes have at least one parent
- Endpoint nodes are either input nodes (zero parents) or have exactly one upstream parent — never both
- The output node is reachable from the input nodes
Aggregation nodes for MetaEndpointDAG.
An aggregation node combines outputs from one or more upstream nodes (Endpoint or other AggregationNode instances) into a single DataFrame.
Two broad categories:
- Column-union aggregators (Concat): join feature outputs from parallel feature endpoints into a single wide row per id — used for feature-pipeline DAGs (e.g. [2D] + [3D] → Concat).
- Prediction aggregators (Mean, WeightedMean, Vote, plus the ensemble-strategy ports ConfidenceWeighted, InverseMaeWeighted, ScaledConfidenceWeighted, CalibratedConfidenceWeighted): combine prediction columns from multiple predictor endpoints into a single ensemble prediction with confidence — used for ensemble combination.
Each node declares its input/output column contract so the DAG can be validated statically before any inference runs.
AggregationNode
Base class for DAG aggregation nodes.
Subclasses implement apply() to combine upstream DataFrames and
declare output_columns() for static DAG validation.
Aggregation nodes always join across upstream branches using the walker-injected DAG_ROW_ID column, so they don't need to know anything about the caller's id conventions (or whether the caller has any).
Source code in src/workbench/utils/aggregation_nodes.py
class AggregationNode:
"""Base class for DAG aggregation nodes.
Subclasses implement ``apply()`` to combine upstream DataFrames and
declare ``output_columns()`` for static DAG validation.
Aggregation nodes always join across upstream branches using the
walker-injected :data:`DAG_ROW_ID` column, so they don't need to
know anything about the caller's id conventions (or whether the
caller has any).
"""
def __init__(self, name: str):
self.name = name
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
"""Combine upstream DataFrames into one. Subclasses must override."""
raise NotImplementedError
def input_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
"""The columns this node expects across all upstream outputs.
Default: union of all upstream output columns. Subclasses can
narrow this if they only consume specific columns.
"""
seen = set()
cols: List[str] = []
for upstream in upstream_outputs:
for c in upstream:
if c not in seen:
seen.add(c)
cols.append(c)
return cols
def output_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
"""The columns this node emits. Subclasses must override."""
raise NotImplementedError
apply(upstream)
Combine upstream DataFrames into one. Subclasses must override.
input_columns(upstream_outputs)
The columns this node expects across all upstream outputs.
Default: union of all upstream output columns. Subclasses can narrow this if they only consume specific columns.
Source code in src/workbench/utils/aggregation_nodes.py
| def input_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
"""The columns this node expects across all upstream outputs.
Default: union of all upstream output columns. Subclasses can
narrow this if they only consume specific columns.
"""
seen = set()
cols: List[str] = []
for upstream in upstream_outputs:
for c in upstream:
if c not in seen:
seen.add(c)
cols.append(c)
return cols
|
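For example, the default union preserves first-seen column order across branches (Concat inherits this default unchanged):
from workbench.utils.aggregation_nodes import Concat

node = Concat(name="combine")
node.input_columns([["id", "smiles", "f2d_0"], ["id", "smiles", "f3d_0"]])
# -> ["id", "smiles", "f2d_0", "f3d_0"]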
output_columns(upstream_outputs)
The columns this node emits. Subclasses must override.
Source code in src/workbench/utils/aggregation_nodes.py
| def output_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
"""The columns this node emits. Subclasses must override."""
raise NotImplementedError
|
CalibratedConfidenceWeighted
Bases: _StrategyAggregator
Per-row weights = confidence × |conf-error correlation| (normalized).
Rewards models whose confidence actually predicts accuracy.
Source code in src/workbench/utils/aggregation_nodes.py
| class CalibratedConfidenceWeighted(_StrategyAggregator):
"""Per-row weights = ``confidence × |conf-error correlation|`` (normalized).
Rewards models whose confidence actually predicts accuracy.
"""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
self._check_arity(upstream)
ids, preds, confs = self._stack(upstream)
calibrated = confs * self.corr_scale
weights = conf_weights_with_fallback(calibrated, self.model_weights)
return self._build_output(
upstream,
ids,
prediction=(preds * weights).sum(axis=1),
prediction_std=preds.std(axis=1),
confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
)
|
Concat
Bases: AggregationNode
Column-union aggregator. Joins upstream DataFrames on the walker's
synthetic row id.
Use for feature-pipeline DAGs where parallel feature endpoints
contribute disjoint feature column sets that need to be merged into a
single wide row.
Source code in src/workbench/utils/aggregation_nodes.py
| class Concat(AggregationNode):
"""Column-union aggregator. Joins upstream DataFrames on the walker's
synthetic row id.
Use for feature-pipeline DAGs where parallel feature endpoints
contribute disjoint feature column sets that need to be merged into a
single wide row.
"""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
if not upstream:
raise ValueError(f"Concat[{self.name}]: requires at least one upstream DataFrame")
out = upstream[0]
for df in upstream[1:]:
new_cols = [c for c in df.columns if c == DAG_ROW_ID or c not in out.columns]
out = out.merge(df[new_cols], on=DAG_ROW_ID, how="inner")
return out
def output_columns(self, upstream_outputs: List[List[str]]) -> List[str]:
seen = set()
cols: List[str] = []
for upstream in upstream_outputs:
for c in upstream:
if c not in seen:
seen.add(c)
cols.append(c)
return cols
|
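A minimal illustration of the join behavior, assuming DAG_ROW_ID is importable from workbench.utils.aggregation_nodes (in normal use the DAG walker injects this column for you):
import pandas as pd
from workbench.utils.aggregation_nodes import Concat, DAG_ROW_ID

left = pd.DataFrame({DAG_ROW_ID: [0, 1], "smiles": ["CCO", "c1ccccc1"], "f2d_0": [1.1, 2.2]})
right = pd.DataFrame({DAG_ROW_ID: [0, 1], "smiles": ["CCO", "c1ccccc1"], "f3d_0": [0.3, 0.4]})

combined = Concat(name="combine").apply([left, right])
# Columns: DAG_ROW_ID, smiles, f2d_0, f3d_0 -- the duplicate 'smiles' from the
# second frame is skipped because it's already present in the first.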
ConfidenceWeighted
Bases: _StrategyAggregator
Per-row weights = upstream confidences (normalized).
Falls back to static model_weights when row confidences sum to ~0.
Source code in src/workbench/utils/aggregation_nodes.py
| class ConfidenceWeighted(_StrategyAggregator):
"""Per-row weights = upstream confidences (normalized).
Falls back to static ``model_weights`` when row confidences sum to ~0.
"""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
self._check_arity(upstream)
ids, preds, confs = self._stack(upstream)
weights = conf_weights_with_fallback(confs, self.model_weights)
return self._build_output(
upstream,
ids,
prediction=(preds * weights).sum(axis=1),
prediction_std=preds.std(axis=1),
confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
)
|
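As a worked example: a row with predictions [2.0, 3.0, 4.0] and confidences [0.2, 0.3, 0.5] gets weights equal to its confidences (they already sum to 1), so the ensemble prediction is 0.2·2.0 + 0.3·3.0 + 0.5·4.0 = 3.3.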
InverseMaeWeighted
Bases: _StrategyAggregator
Static per-model weights from inverse-MAE.
The caller passes the inverse-MAE-derived weights directly via
model_weights. Identical to WeightedMean for the prediction
column, but additionally computes calibrated ensemble confidence.
Source code in src/workbench/utils/aggregation_nodes.py
| class InverseMaeWeighted(_StrategyAggregator):
"""Static per-model weights from inverse-MAE.
The caller passes the inverse-MAE-derived weights directly via
``model_weights``. Identical to ``WeightedMean`` for the prediction
column, but additionally computes calibrated ensemble confidence.
"""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
self._check_arity(upstream)
ids, preds, confs = self._stack(upstream)
return self._build_output(
upstream,
ids,
prediction=(preds * self.model_weights).sum(axis=1),
prediction_std=preds.std(axis=1),
confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
)
|
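The inverse-MAE weights themselves are computed by the caller, e.g. from per-model hold-out MAEs (hypothetical numbers below; how they reach the node is via model_weights, whose constructor plumbing lives in _StrategyAggregator and isn't shown here):
import numpy as np

# Hypothetical hold-out MAEs for xgb / pytorch / chemprop
maes = np.array([0.40, 0.50, 0.80])
weights = (1.0 / maes) / (1.0 / maes).sum()
# -> approximately [0.43, 0.35, 0.22]; pass these as model_weights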
Mean
Bases: _PredictionAggregator
Simple equal-weight mean of predictions.
Source code in src/workbench/utils/aggregation_nodes.py
| class Mean(_PredictionAggregator):
"""Simple equal-weight mean of predictions."""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
ids, preds, confs = self._stack(upstream)
return self._build_output(
upstream,
ids,
prediction=preds.mean(axis=1),
prediction_std=preds.std(axis=1),
confidence=confs.mean(axis=1),
)
|
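For instance, three predictions [2.0, 3.0, 4.0] for a row yield prediction = 3.0 and, assuming _stack returns pandas frames (so std uses the sample default ddof=1), prediction_std = 1.0; wider model disagreement directly raises prediction_std.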
ScaledConfidenceWeighted
Bases: _StrategyAggregator
Per-row weights = model_weights × confidence (normalized).
Often the top performer in practice — combines static MAE-derived
weighting with per-row confidence scaling.
Source code in src/workbench/utils/aggregation_nodes.py
| class ScaledConfidenceWeighted(_StrategyAggregator):
"""Per-row weights = ``model_weights × confidence`` (normalized).
Often the top performer in practice — combines static MAE-derived
weighting with per-row confidence scaling.
"""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
self._check_arity(upstream)
ids, preds, confs = self._stack(upstream)
scaled = confs * self.model_weights
weights = conf_weights_with_fallback(scaled, self.model_weights)
return self._build_output(
upstream,
ids,
prediction=(preds * weights).sum(axis=1),
prediction_std=preds.std(axis=1),
confidence=ensemble_confidence(preds, confs, self.corr_scale, self.model_weights, self.optimal_alpha),
)
|
Vote
Bases: _PredictionAggregator
Majority-vote aggregator for classifier predictions.
Expects each upstream's prediction column to hold class labels
(string or int). Output prediction is the most common label per
row; prediction_std is 0 (placeholder for contract symmetry);
confidence is the fraction of upstream models that voted for the
winning label.
Source code in src/workbench/utils/aggregation_nodes.py
| class Vote(_PredictionAggregator):
"""Majority-vote aggregator for classifier predictions.
Expects each upstream's ``prediction`` column to hold class labels
(string or int). Output ``prediction`` is the most common label per
row; ``prediction_std`` is 0 (placeholder for contract symmetry);
``confidence`` is the fraction of upstream models that voted for the
winning label.
"""
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
if not upstream:
raise ValueError(f"Vote[{self.name}]: requires at least one upstream DataFrame")
ids = upstream[0][[DAG_ROW_ID]].copy()
for df in upstream[1:]:
ids = ids.merge(df[[DAG_ROW_ID]], on=DAG_ROW_ID, how="inner")
labels = pd.concat(
[
ids.merge(df[[DAG_ROW_ID, "prediction"]], on=DAG_ROW_ID)["prediction"].rename(f"_p{i}")
for i, df in enumerate(upstream)
],
axis=1,
)
modes = labels.mode(axis=1)[0]
winner_share = (labels.eq(modes, axis=0)).sum(axis=1) / labels.shape[1]
return self._build_output(
upstream,
ids,
prediction=modes.to_numpy(),
prediction_std=0.0,
confidence=winner_share.to_numpy(),
)
|
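For example, a row where three classifiers emit ['active', 'active', 'inactive'] yields prediction 'active' with confidence 2/3. On an exact tie, labels.mode(axis=1)[0] takes the first mode in pandas' sorted order, so the smaller label wins.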
WeightedMean
Bases: _PredictionAggregator
Static-weight mean — caller supplies one weight per upstream.
Source code in src/workbench/utils/aggregation_nodes.py
| class WeightedMean(_PredictionAggregator):
"""Static-weight mean — caller supplies one weight per upstream."""
def __init__(self, name: str, weights: List[float]):
super().__init__(name)
if not weights:
raise ValueError("WeightedMean: weights must be a non-empty list")
w = np.asarray(weights, dtype=np.float64)
if (w < 0).any():
raise ValueError("WeightedMean: weights must be non-negative")
if w.sum() <= 0:
raise ValueError("WeightedMean: at least one weight must be positive")
self.weights = w / w.sum()
def apply(self, upstream: List[pd.DataFrame]) -> pd.DataFrame:
if len(upstream) != len(self.weights):
raise ValueError(
f"WeightedMean[{self.name}]: got {len(upstream)} upstream frames " f"but {len(self.weights)} weights"
)
ids, preds, confs = self._stack(upstream)
return self._build_output(
upstream,
ids,
prediction=(preds * self.weights).sum(axis=1),
prediction_std=preds.std(axis=1),
confidence=(confs * self.weights).sum(axis=1),
)
|
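Because the constructor normalizes, raw weights don't need to sum to 1:
from workbench.utils.aggregation_nodes import WeightedMean

# [2, 1, 1] normalizes to [0.5, 0.25, 0.25]
node = WeightedMean(name="ensemble", weights=[2.0, 1.0, 1.0])
Wire it into a DAG with dag.add_aggregation(node), just like Mean in the ensemble Quick Start.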
Questions?
The SuperCowPowers team is happy to answer any questions you may have about AWS and Workbench.