Multi-Agent DAG Orchestration

Building on Example 01, this example demonstrates a multi-agent pipeline with a directed acyclic graph (DAG) topology. Multiple agents are wired together so upstream results automatically flow into downstream agents’ context. It also shows function nodes (pure Python execution without LLM) and all three tool registration styles.

Prerequisites: Read 01_single_agent.py first.

What you’ll learn:

  • How to build a multi-agent DAG with multiple upstream dependencies per node

  • How to declare each node's upstream dependencies with depends_on

  • How to create function nodes (pure Python, no LLM) using PipelineNode(fn=...)

  • All three tool registration styles: direct callable, async decorator, sync decorator

  • How downstream agents automatically receive upstream results

Architecture:

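Four agents form a DAG. The planner, runner and analyzer run in sequence (planner → runner → analyzer), and the reporter depends on all three, so it receives the plan, the raw results and the analysis. A pure Python function node, save_report, runs after the reporter and saves the final report to disk.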
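Because runner depends on planner and analyzer depends on runner, the graph runs one node at a time, even though the reporter has three inbound edges. One way to check this is to feed the same depends_on lists into Python's standard-library graphlib.TopologicalSorter and group nodes into "waves" that could start together. This is a sketch for reasoning about the topology only. It does not use Dragon, and the DEPENDS_ON dictionary and the ready_waves helper are illustrative, mirroring the depends_on values in the listing.

```python
from graphlib import TopologicalSorter

# Mirrors the depends_on lists of the pipeline in the listing (illustrative).
DEPENDS_ON = {
    "planner_agent": [],
    "runner_agent": ["planner_agent"],
    "analyzer_agent": ["runner_agent"],
    "reporter_agent": ["planner_agent", "runner_agent", "analyzer_agent"],
    "save_report": ["reporter_agent"],
}


def ready_waves(graph: dict[str, list[str]]) -> list[tuple[str, ...]]:
    """Group nodes into waves; nodes in the same wave have no mutual dependency."""
    sorter = TopologicalSorter(graph)
    sorter.prepare()  # raises graphlib.CycleError if the graph is not a DAG
    waves = []
    while sorter.is_active():
        ready = tuple(sorted(sorter.get_ready()))
        waves.append(ready)
        sorter.done(*ready)  # mark the whole wave finished before asking again
    return waves


if __name__ == "__main__":
    for i, wave in enumerate(ready_waves(DEPENDS_ON)):
        print(i, wave)
```

Every wave holds a single node, which is why the four agents execute one after another despite the DAG wiring. Two independent roots, by contrast, land in the same wave.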

Main Code

Below is the complete example:

Listing 28 02_multi_agent_dag.py: Multi-agent DAG with function nodes
"""02 — Multi-Agent DAG + Function Nodes + Tool Registration Styles.

**Prerequisites:** Read ``01_single_agent.py`` first.

**What's new:**

* **Multi-agent DAG** — four agents wired as a directed acyclic graph
  (planner → runner → analyzer, with reporter depending on all three)
* **Function node** — ``save_report`` runs as plain Python in Dragon Batch
  (no LLM), demonstrating ``PipelineNode(fn=...)``
* **All three tool registration styles:**

  1. ``registry.register(fn)`` — direct callable (sync)
  2. ``@registry.tool`` decorator — inline async tool
  3. ``@registry.tool`` decorator — inline sync tool

* **Async tools** — ``analyze_convergence`` is an async function
* **:param: annotations** — auto-extracted into JSON schemas for the LLM

Architecture::

    planner_agent ──► runner_agent ──► analyzer_agent ──┐
         │                │                             ├──► reporter_agent
         └────────────────┴─────────────────────────────┘         │
                                                          save_report (fn)

Usage::

    dragon 02_multi_agent_dag.py
"""

import asyncio
import json
import math
import os
import time
import traceback

import dragon
import multiprocessing as mp

from dragon.ai.agent.core import create_sub_agent
from dragon.ai.agent.config import (
    AgentConfig,
    OrchestratorConfig,
    Pipeline,
    PipelineNode,
    TaskResult,
    TaskStatus,
    DISPATCH_ID_KEY,
    RESULT_KEY,
    STATUS_KEY,
)
from dragon.ai.agent.tools import ToolRegistry
from dragon.ai.agent.orchestrator import DAGOrchestrator
from dragon.data.ddict import DDict
from dragon.infrastructure.policy import Policy
from dragon.native.event import Event
from dragon.native.machine import Node, System
from dragon.native.process import Process
from dragon.native.queue import Queue
from dragon.workflows.batch import Batch

from dragon.ai.inference.config import (
    BatchingConfig,
    HardwareConfig,
    InferenceConfig,
    ModelConfig,
)
from dragon.ai.inference.inference_utils import Inference

# ---------------------------------------------------------------------------
# Style 1: Import existing sync tools, register via registry.register(fn)
# ---------------------------------------------------------------------------
from tools import propose_experiment
from tools.runner import (
    launch_experiment,
    check_progress,
    collect_results,
    cleanup_experiment_state,
)


# ===========================================================================
# User-configurable constants
# ===========================================================================

MODEL_NAME = "/path/to/your/model"
HF_TOKEN = ""


# ===========================================================================
# Inference Pipeline Configuration
# ===========================================================================

INFERENCE_CONFIG = InferenceConfig(
    model=ModelConfig(
        model_name=MODEL_NAME,
        hf_token=HF_TOKEN,
        tp_size=2,
        max_tokens=8192,
        max_model_len=32768,
    ),
    hardware=HardwareConfig(
        num_nodes=1,
        num_gpus=2,
        num_inf_workers_per_cpu=1,
    ),
    batching=BatchingConfig(
        batch_wait_seconds=0.1,
        max_batch_size=32,
    ),
)


# ===========================================================================
# Tool registries — demonstrating all registration styles
# ===========================================================================

# --- Style 1: registry.register(fn) — sync tools from tools/ package ------
planner_registry = ToolRegistry()
planner_registry.register(propose_experiment)

runner_registry = ToolRegistry()
runner_registry.register(launch_experiment)
runner_registry.register(check_progress)
runner_registry.register(collect_results)

# --- Style 2: @registry.tool — inline async tool --------------------------
# The decorator detects async callables automatically.  :param: annotations
# are extracted into the JSON schema the LLM receives.
analyzer_registry = ToolRegistry()


@analyzer_registry.tool
async def analyze_convergence(results: list) -> dict:
    """Analyse Monte Carlo convergence from experiment results.

    Computes the convergence rate (least-squares slope of log10(error)
    against log10(n_samples)) and determines whether the experiment meets
    the theoretical 1/sqrt(n) rate.

    :param results: List of dicts with keys 'n_samples' and 'absolute_error'.
    """
    await asyncio.sleep(0.01)  # simulate async I/O

    if not results or len(results) < 2:
        return {"error": "Need at least 2 data points."}

    sorted_r = sorted(results, key=lambda r: r["n_samples"])
    errors = [r["absolute_error"] for r in sorted_r]
    samples = [r["n_samples"] for r in sorted_r]

    log_n = [math.log10(n) for n in samples]
    log_e = [math.log10(max(e, 1e-12)) for e in errors]  # clamp: log10(0) is undefined

    # Least-squares fit over all points; results may hold several runs
    # (e.g. different seeds) for the same sample size.
    mean_n = sum(log_n) / len(log_n)
    mean_e = sum(log_e) / len(log_e)
    var_n = sum((x - mean_n) ** 2 for x in log_n)
    if var_n == 0:
        return {"error": "Need at least 2 distinct sample sizes."}
    slope = sum((x - mean_n) * (y - mean_e)
                for x, y in zip(log_n, log_e)) / var_n

    return {
        "convergence_rate": round(slope, 4),
        "expected_rate": -0.5,
        "best_error": min(errors),
        "worst_error": max(errors),
        "n_experiments": len(results),
        "meets_theory": abs(slope - (-0.5)) < 0.15,
    }


# --- Style 3: @registry.tool — inline sync tool ---------------------------
reporter_registry = ToolRegistry()


@reporter_registry.tool
def format_results_table(
    results: list,
    convergence_info: dict,
    plan_summary: str,
) -> dict:
    """Format experiment results into a Markdown table for the final report.

    :param results: List of result dicts from collect_results.
    :param convergence_info: Dict from analyze_convergence.
    :param plan_summary: One-line summary of the experiment plan.
    """
    lines = [
        f"## Monte Carlo π Estimation — {plan_summary}",
        "",
        "| Samples | π Estimate | Abs Error | Wall Time (s) | Node |",
        "|--------:|-----------:|----------:|--------------:|------|",
    ]
    for r in sorted(results, key=lambda x: x.get("n_samples", 0)):
        lines.append(
            # Numeric default: the ',' format spec raises ValueError on a str.
            f"| {r.get('n_samples', 0):>7,} "
            f"| {r.get('pi_estimate', 0):.6f} "
            f"| {r.get('absolute_error', 0):.6f} "
            f"| {r.get('wall_time_s', 0):>13.3f} "
            f"| {r.get('hostname', '?')} |"
        )
    conv = convergence_info or {}
    lines += [
        "",
        f"**Convergence rate:** {conv.get('convergence_rate', 'N/A')} "
        f"(expected ≈ {conv.get('expected_rate', -0.5)})",
        f"**Meets theoretical rate:** {conv.get('meets_theory', 'N/A')}",
    ]
    return {"markdown_table": "\n".join(lines)}


# ===========================================================================
# Function node: save_report
#
# A PipelineNode with fn= runs as a plain Python function in Dragon Batch.
# No LLM, no agent queue.  Receives TaskResult tokens from upstream nodes,
# reads from DDict, writes files, returns a TaskResult.
# ===========================================================================

REPORT_DIR = os.environ.get("DRAGON_REPORT_DIR", os.getcwd())


def save_report(*upstreams: TaskResult) -> TaskResult:
    """Write the reporter's output to disk as Markdown and JSON."""
    # depends_on=["reporter_agent"], so the first token is the reporter's.
    upstream = upstreams[0]
    task_id = upstream.task_id
    serialized_ddict = upstream.serialized_ddict

    print(f"\n[save_report] Function node started (task_id={task_id[:8]}...)",
          flush=True)

    ddict = DDict.attach(serialized_ddict)
    try:
        # Look up the reporter's dispatch ID, then its result, in the DDict.
        reporter_dispatch_key = DISPATCH_ID_KEY.format(
            task_id=task_id, agent_id="reporter_agent"
        )
        reporter_dispatch_id = ddict[reporter_dispatch_key]
        reporter_result_key = RESULT_KEY.format(
            task_id=task_id,
            agent_id="reporter_agent",
            dispatch_id=reporter_dispatch_id,
        )
        reporter_result = ddict[reporter_result_key]

        report_text = (
            reporter_result.get("response", str(reporter_result))
            if isinstance(reporter_result, dict)
            else str(reporter_result)
        )

        os.makedirs(REPORT_DIR, exist_ok=True)  # DRAGON_REPORT_DIR may not exist yet

        # Explicit UTF-8: the report contains non-ASCII characters such as π.
        md_path = os.path.join(REPORT_DIR, "monte_carlo_report.md")
        with open(md_path, "w", encoding="utf-8") as f:
            f.write(report_text)
        print(f"[save_report] Written: {md_path}", flush=True)

        json_path = os.path.join(REPORT_DIR, "monte_carlo_report.json")
        artifact = {
            "task_id": task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S%z"),
            "report": report_text,
            "source_agent": "reporter_agent",
        }
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(artifact, f, indent=2)
        print(f"[save_report] Written: {json_path}", flush=True)

        # Record this node's dispatch ID, result and status under the same
        # DDict key scheme the agents use.
        own_dispatch_id = f"fn-save-report-{task_id[:8]}"
        ddict[DISPATCH_ID_KEY.format(task_id=task_id, agent_id="save_report")] = own_dispatch_id
        ddict[RESULT_KEY.format(task_id=task_id, agent_id="save_report", dispatch_id=own_dispatch_id)] = {
            "response": f"Report saved to:\n  - {md_path}\n  - {json_path}"
        }
        ddict[STATUS_KEY.format(task_id=task_id, agent_id="save_report", dispatch_id=own_dispatch_id)] = TaskStatus.DONE
    finally:
        ddict.detach()

    print("[save_report] Function node complete.\n", flush=True)
    return TaskResult(
        task_id=task_id,
        agent_id="save_report",
        status=TaskStatus.DONE,
        serialized_ddict=serialized_ddict,
    )


# ===========================================================================
# DAG pipeline
# ===========================================================================

pipeline = Pipeline(nodes=[
    PipelineNode(
        agent_id="planner_agent",
        task_description=(
            "You are a scientific experiment planner.  The user wants to study "
            "Monte Carlo convergence for estimating π.\n\n"
            "Propose an experiment plan by calling propose_experiment with:\n"
            "  - description, sample_sizes, convergence_target, methodology\n\n"
            "Report the approved plan verbatim as your final answer."
        ),
        depends_on=[],
    ),
    PipelineNode(
        agent_id="runner_agent",
        task_description=(
            "You manage parallel Monte Carlo simulations on an HPC cluster.\n\n"
            "Tools:\n"
            "  1. launch_experiment(sample_sizes, seeds) — launches ALL "
            "simulations in parallel.\n"
            "  2. check_progress() — shows done/running/pending status.\n"
            "  3. collect_results() — call ONLY when all_done=true.\n\n"
            "STRICT workflow:\n"
            "  1. Call launch_experiment.\n"
            "  2. Call check_progress until all_done=true.\n"
            "  3. Call collect_results — REQUIRED before final answer.\n"
            "  4. Report collect_results output verbatim."
        ),
        depends_on=["planner_agent"],
    ),
    PipelineNode(
        agent_id="analyzer_agent",
        task_description=(
            "Call analyze_convergence with a list of dicts (keys: "
            "'n_samples', 'absolute_error').  Report all metrics verbatim."
        ),
        depends_on=["runner_agent"],
    ),
    PipelineNode(
        agent_id="reporter_agent",
        task_description=(
            "Write a structured report.\n\n"
            "STRICT workflow:\n"
            "  1. Call format_results_table to get a Markdown table.\n"
            "  2. COPY-PASTE the actual table into your final answer.\n"
            "  3. Include sections: Plan, Results Table, Parallel Execution,\n"
            "     Convergence Analysis, Quality Assessment, Recommendations.\n\n"
            "Never invent data — use only data from upstream agents."
        ),
        depends_on=["planner_agent", "runner_agent", "analyzer_agent"],
    ),
    # Function node — no LLM, runs as plain Python
    PipelineNode(
        agent_id="save_report",
        fn=save_report,
        depends_on=["reporter_agent"],
    ),
])


# ===========================================================================
# Helpers
# ===========================================================================

def _make_agent_kwargs(agent_id, name, role, registry, inference_queue,
                       max_tool_call_iterations=20):
    """Build the keyword arguments passed to create_sub_agent for one agent."""
    return {
        "config": AgentConfig(
            agent_id=agent_id, name=name, role=role,
            inference_queue=inference_queue,
            max_tool_call_iterations=max_tool_call_iterations,
        ),
        "tool_registry": registry,
        "shutdown_event": Event(),
        "reply_queue": Queue(),
    }


def _start_agents(specs, policies):
    """Start one process per agent and collect each agent's input queue.

    Each agent sends its input queue back on its reply_queue when ready.
    """
    procs = []
    for spec, policy in zip(specs, policies):
        p = Process(target=create_sub_agent, kwargs=spec, policy=policy)
        p.start()
        procs.append(p)
    queues = {}
    for spec in specs:
        aid = spec["config"].agent_id
        queues[aid] = spec["reply_queue"].get()
        print(f"[startup] Agent '{aid}' ready.", flush=True)
    return procs, queues


# ===========================================================================
# Main
# ===========================================================================

async def main():
    # Shared request queue of the inference pipeline; every agent sends to it.
    input_queue = Queue()

    print("[startup] Initializing inference pipeline...", flush=True)

    inference_pipeline = None
    try:
        inference_pipeline = Inference(INFERENCE_CONFIG, input_queue)
        inference_pipeline.initialize()
    except Exception as exc:
        print(f"\n[FATAL] Inference pipeline failed to initialize: {exc}", flush=True)
        traceback.print_exc()
        if inference_pipeline is not None:
            inference_pipeline.destroy()
        return
    print("[startup] Inference pipeline ready.\n", flush=True)

    # Prefer the second node for the simulation runner when one is available.
    my_alloc = System()
    node_list = my_alloc.nodes
    compute_host = (
        Node(node_list[1]).hostname if len(node_list) > 1
        else Node(node_list[0]).hostname
    )
    compute_policy = Policy(
        placement=Policy.Placement.HOST_NAME,
        host_name=compute_host,
    )

    procs, agent_specs = [], []
    try:
        agent_specs = [
            _make_agent_kwargs(
                "planner_agent", "Experiment Planner",
                "You are an experiment planner for Monte Carlo convergence "
                "studies on an HPC cluster.",
                planner_registry, input_queue,
            ),
            _make_agent_kwargs(
                "runner_agent", "Parallel Simulation Runner",
                "You manage parallel Monte Carlo simulations.\n"
                "You MUST call all three tools in order: "
                "launch_experiment → check_progress → collect_results.\n"
                "NEVER give a final answer without calling collect_results first.",
                runner_registry, input_queue,
                max_tool_call_iterations=60,
            ),
            _make_agent_kwargs(
                "analyzer_agent", "Convergence Analyzer",
                "Analyse convergence.  Call analyze_convergence with "
                "n_samples + absolute_error dicts.  Report verbatim.",
                analyzer_registry, input_queue,
            ),
            _make_agent_kwargs(
                "reporter_agent", "Report Writer",
                "Write a structured report with Markdown tables.  "
                "Always include tool output verbatim — "
                "never use placeholder variables.",
                reporter_registry, input_queue,
            ),
        ]
        # Same order as agent_specs: only runner_agent is pinned to compute_host.
        policies = [None, compute_policy, None, None]

        procs, queues = _start_agents(agent_specs, policies)
        for spec in agent_specs:
            spec["config"].input_queue = queues[spec["config"].agent_id]

        orchestrator = DAGOrchestrator(
            config=OrchestratorConfig(
                agents=[s["config"] for s in agent_specs],
                poll_interval=0.5,
                poll_timeout=14400.0,
            ),
            pipeline=pipeline,
        )

        user_input = (
            "Design and run a Monte Carlo convergence experiment for π "
            "and save the report to disk."
        )

        batch = Batch()
        try:
            print("=" * 60, flush=True)
            print("Dragon AI — 02 Multi-Agent DAG + Function Node", flush=True)
            print("=" * 60, flush=True)
            print(f"Request: {user_input}\n", flush=True)

            result = orchestrator.run(
                user_input=user_input,
                batch=batch,
            )

            print("\n" + "=" * 60, flush=True)
            print("FINAL RESULT", flush=True)
            print("=" * 60, flush=True)
            print(result, flush=True)

        except Exception as exc:
            print(f"\n[error] Pipeline failed: {exc}", flush=True)
            traceback.print_exc()
        finally:
            orchestrator.destroy()
            batch.join()
            batch.destroy()

    except Exception as exc:
        print(f"\n[error] Fatal: {exc}", flush=True)
        traceback.print_exc()
    finally:
        cleanup_experiment_state()
        for spec in agent_specs:
            try:
                spec["shutdown_event"].set()
            except Exception:
                pass
        for p in procs:
            try:
                p.join()
            except Exception:
                pass
        print("\n[teardown] All agents stopped.", flush=True)
        try:
            inference_pipeline.destroy()
        except Exception:
            pass
        print("[teardown] Inference pipeline stopped.", flush=True)


if __name__ == "__main__":
    mp.set_start_method("dragon")
    asyncio.run(main())
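The analyzer's meets_theory check compares a fitted slope with -0.5 because the statistical error of a Monte Carlo estimate shrinks like 1/sqrt(n). In that case log10(error) ≈ c - 0.5·log10(n), so the slope on a log-log plot should be close to -0.5. The standalone sketch below is plain Python with no Dragon. loglog_slope and estimate_pi are illustrative helpers, not functions from the tools/ package. It fits the slope by least squares twice: first on synthetic errors that follow the ideal rate exactly, then on a small dart-throwing estimate of π.

```python
import math
import random
from typing import Sequence


def loglog_slope(ns: Sequence[int], errors: Sequence[float]) -> float:
    """Least-squares slope of log10(error) against log10(n)."""
    xs = [math.log10(n) for n in ns]
    ys = [math.log10(max(e, 1e-12)) for e in errors]  # clamp so log10 never sees 0
    mx = sum(xs) / len(xs)
    my = sum(ys) / len(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("need at least two distinct sample sizes")
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sxx


def estimate_pi(n: int, seed: int) -> float:
    """Estimate pi as 4 * (fraction of random points inside the unit quarter circle)."""
    rng = random.Random(seed)
    inside = 0
    for _ in range(n):
        x, y = rng.random(), rng.random()
        if x * x + y * y <= 1.0:
            inside += 1
    return 4.0 * inside / n


NS = [1_000, 10_000, 100_000]

# Ideal case: error = C / sqrt(n) gives a slope of exactly -0.5.
ideal_slope = loglog_slope(NS, [1.6 / math.sqrt(n) for n in NS])

# Empirical case: average the absolute error over a few seeds per sample size.
mean_errors = [
    sum(abs(estimate_pi(n, seed) - math.pi) for seed in range(8)) / 8 for n in NS
]
empirical_slope = loglog_slope(NS, mean_errors)

if __name__ == "__main__":
    print(f"ideal slope:     {ideal_slope:.4f}")
    print(f"empirical slope: {empirical_slope:.4f}  (noisy; expect roughly -0.5)")
```

The empirical slope is noisy; using more seeds per sample size narrows its spread around -0.5.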

Key Concepts

DAG Topology:

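Each PipelineNode lists its upstream agent IDs in depends_on, and a node may list several. The framework waits for all upstream nodes to complete, then passes their results into the downstream agent's context automatically. In this example, reporter_agent lists planner_agent, runner_agent and analyzer_agent, so its context carries the plan, the raw results and the convergence analysis.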
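The waiting rule can be illustrated with a small asyncio scheduler. This is a hypothetical sketch, not DAGOrchestrator's implementation: run_dag, fake_agent and the shape of the context dictionary are assumptions made for illustration. Each node becomes a task that first awaits every upstream task, then receives a dict that maps upstream IDs to their results.

```python
import asyncio
from graphlib import TopologicalSorter
from typing import Awaitable, Callable, Dict, List

# A node takes {upstream_id: upstream_result} and returns its own result.
NodeFn = Callable[[Dict[str, str]], Awaitable[str]]


async def run_dag(nodes: Dict[str, NodeFn],
                  depends_on: Dict[str, List[str]]) -> Dict[str, str]:
    tasks: Dict[str, "asyncio.Task[str]"] = {}

    async def run_node(name: str) -> str:
        # Block until every upstream node has finished.
        upstream = await asyncio.gather(*(tasks[d] for d in depends_on[name]))
        context = dict(zip(depends_on[name], upstream))
        return await nodes[name](context)

    # static_order() raises graphlib.CycleError for a cyclic graph.
    for name in TopologicalSorter(depends_on).static_order():
        tasks[name] = asyncio.create_task(run_node(name))
    finished = await asyncio.gather(*tasks.values())
    return dict(zip(tasks, finished))


def fake_agent(name: str) -> NodeFn:
    """Hypothetical stand-in for an LLM agent: reports which upstreams it saw."""
    async def agent(context: Dict[str, str]) -> str:
        await asyncio.sleep(0)  # yield to the event loop, as a real call would
        return f"{name}<-{','.join(sorted(context))}"
    return agent


DEPENDS_ON = {
    "planner_agent": [],
    "runner_agent": ["planner_agent"],
    "analyzer_agent": ["runner_agent"],
    "reporter_agent": ["planner_agent", "runner_agent", "analyzer_agent"],
}

results = asyncio.run(run_dag({n: fake_agent(n) for n in DEPENDS_ON}, DEPENDS_ON))
```

With this wiring, results["reporter_agent"] records that the reporter saw all three upstream agents, while the planner saw none. Creating every task before any of them runs means a node can safely look up its upstream tasks by ID.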

Function Nodes:

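A PipelineNode(fn=...) runs a plain Python function in Dragon Batch, with no LLM and no agent queue. The function receives the upstream TaskResult objects as positional arguments and returns a TaskResult. In the listing, save_report uses the DDict handle carried by its upstream result to read the reporter's output. Useful for: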

  • Aggregating/reporting results

  • Data transformation

  • Writing to disk or database

  • Any non-agentic computation
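A function node is a callable that accepts the upstream results as positional arguments and returns a result of its own. The sketch below shows a data-transformation node in that shape. FakeTaskResult is a hypothetical stand-in for TaskResult that carries its payload inline, and the string "DONE" stands in for TaskStatus.DONE. In the listing, the real TaskResult instead carries a serialized DDict handle, and save_report reads the payload from the DDict.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class FakeTaskResult:
    """Hypothetical stand-in for TaskResult, with the payload carried inline."""
    task_id: str
    agent_id: str
    status: str
    payload: Dict[str, Any] = field(default_factory=dict)


def merge_metrics(*upstreams: FakeTaskResult) -> FakeTaskResult:
    """Function node: no LLM, just merges upstream payloads into one dict."""
    if not upstreams:
        raise ValueError("a function node needs at least one upstream result")
    task_id = upstreams[0].task_id
    merged: Dict[str, Any] = {}
    for up in upstreams:
        if up.task_id != task_id:
            raise ValueError("all upstream results must belong to the same task")
        # Prefix keys with the producing agent so values cannot collide.
        for key, value in up.payload.items():
            merged[f"{up.agent_id}.{key}"] = value
    return FakeTaskResult(task_id=task_id, agent_id="merge_metrics",
                          status="DONE", payload=merged)


out = merge_metrics(
    FakeTaskResult("t1", "runner_agent", "DONE", {"n_experiments": 4}),
    FakeTaskResult("t1", "analyzer_agent", "DONE", {"convergence_rate": -0.49}),
)
```

Because the node is an ordinary function, it can be unit-tested like this without starting agents or an inference backend.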

All Three Registration Styles:

  1. Direct callable: registry.register(existing_function)

  2. Async decorator: @registry.tool async def ...

  3. Sync decorator: @registry.tool def ...

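In every style, the tool's JSON schema for the LLM is generated automatically from the function signature, with parameter descriptions taken from the :param: fields of the docstring.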
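The three styles differ only in how the callable reaches the registry, because a decorator is register() applied at definition time. The toy registry below is a hypothetical sketch, not Dragon's ToolRegistry. ToyToolRegistry, its schemas attribute and the Python-to-JSON type mapping are assumptions. It shows one plausible way to derive a schema from a signature and :param: docstring fields, and why sync and async tools can share a single decorator.

```python
import inspect
import re
from typing import Any, Callable, Dict

# Assumed mapping from Python annotations to JSON-schema types.
_PY_TO_JSON = {str: "string", int: "integer", float: "number",
               bool: "boolean", list: "array", dict: "object"}
_PARAM_RE = re.compile(r":param\s+(\w+):\s*(.+)")


class ToyToolRegistry:
    """Hypothetical minimal registry; NOT Dragon's ToolRegistry."""

    def __init__(self) -> None:
        self.tools: Dict[str, Callable[..., Any]] = {}
        self.schemas: Dict[str, Dict[str, Any]] = {}

    def register(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        doc = inspect.getdoc(fn) or ""
        descriptions = dict(_PARAM_RE.findall(doc))  # {param_name: text}
        properties: Dict[str, Any] = {}
        required = []
        for name, param in inspect.signature(fn).parameters.items():
            prop = {"type": _PY_TO_JSON.get(param.annotation, "string")}
            if name in descriptions:
                prop["description"] = descriptions[name].strip()
            properties[name] = prop
            if param.default is inspect.Parameter.empty:
                required.append(name)
        self.tools[fn.__name__] = fn
        self.schemas[fn.__name__] = {
            "name": fn.__name__,
            "description": doc.split("\n\n")[0],  # first docstring paragraph
            "is_async": inspect.iscoroutinefunction(fn),
            "parameters": {"type": "object", "properties": properties,
                           "required": required},
        }
        return fn  # returning fn lets register() double as a decorator

    tool = register  # the decorator form is the same method


reg = ToyToolRegistry()


def propose(description: str, sample_sizes: list) -> dict:
    """Propose a plan.

    :param description: What to study.
    :param sample_sizes: Sample sizes to run.
    """
    return {"description": description, "sample_sizes": sample_sizes}


reg.register(propose)  # style 1: direct callable


@reg.tool  # style 2: async decorator
async def analyze(results: list) -> dict:
    """Analyse results.

    :param results: List of result dicts.
    """
    return {"n": len(results)}


@reg.tool  # style 3: sync decorator
def fmt(table: dict, title: str = "Report") -> dict:
    """Format a table.

    :param table: Data to format.
    """
    return {"title": title, "rows": len(table)}
```

A real registry would also need to handle optional types, nested schemas and awaiting async tools at call time. The sketch only records is_async so a caller knows whether to await.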

Installation

See Example 01 (same dependencies).

System Description

Tested on HPE Cray EX:

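  • 1–2 nodes sufficient; with two or more nodes, runner_agent is placed on the second node

  • Main vLLM backend: 1–2 GPUs (the INFERENCE_CONFIG in the listing uses tp_size=2 and num_gpus=2)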

  • Agents: any available CPU

How to Run

dragon 02_multi_agent_dag.py

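Reports are written to the directory named by the DRAGON_REPORT_DIR environment variable, or to the current working directory if it is unset.

Example output (abridged and illustrative):

$ dragon 02_multi_agent_dag.py
Planner: Proposing strategy...
Runner: Executing strategy...
Analyzer: Reviewing execution...
Reporter: Generating final report...
Report saved to:
  - <REPORT_DIR>/monte_carlo_report.md
  - <REPORT_DIR>/monte_carlo_report.json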

Next Steps

  • 03 — Human-in-the-Loop (add approval gates before tool execution)