Batched compute with run_many

Feedback wanted

run_many, _run_batch, and BatchProtocolError live on main only — not in the latest PyPI release. To try the examples below, install from git: pip install "git+https://github.com/facebookresearch/exca".

The underlying semantics (per-item caching, _run_batch, item_uid) are settling. Comments via GitHub issues are welcome.

What run_many adds

Step.run(value) runs the step on a single input. Use run_many to run the same configuration over many inputs, with one cache entry per (step, input) pair:

step = Multiply(coeff=2.0, infra={"backend": "Cached", "folder": cache})

# Single input
step.run(5.0)                       # 10.0

# Many inputs — same configuration, one cache entry per input
for r in step.run_many([1.0, 2.0, 3.0]):
    print(r)                        # 2.0, then 4.0, then 6.0

step.run_many(...) returns a StepItems iterator that yields one result per input, in input order. Iterate over it rather than calling list() on it, so the whole result set is never held in memory. Without a backend, results are computed lazily as you pull them; with one, see Distribution across workers. Re-running with overlapping inputs reuses the cache entries from the previous call. run(value) is sugar over run_many([value]) that returns the single result.
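The per-input caching and lazy iteration described above can be pictured with a toy model. This is a minimal sketch and not Exca's implementation: run_many_sketch, the dict-based cache, and the use of repr(value) as a uid are all illustrative assumptions.

```python
from typing import Callable, Dict, Iterable, Iterator, List


def run_many_sketch(
    func: Callable[[float], float],
    values: Iterable[float],
    cache: Dict[str, float],
) -> Iterator[float]:
    """Yield func(value) for each value, in input order, caching per input."""
    for value in values:
        uid = repr(value)  # assumption: stand-in for the real uid machinery
        if uid not in cache:
            cache[uid] = func(value)  # computed only when the item is pulled
        yield cache[uid]


calls: List[float] = []  # records which inputs were actually computed


def double(x: float) -> float:
    calls.append(x)
    return 2.0 * x


cache: Dict[str, float] = {}
# list() is acceptable here only because the demo inputs are tiny
first = list(run_many_sketch(double, [1.0, 2.0, 3.0], cache))
second = list(run_many_sketch(double, [2.0, 3.0, 4.0], cache))
```

In this model the second call computes only 4.0, because 2.0 and 3.0 are already in the cache from the first call.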

Per-input identity: item_uid

By default, cache keys come from the input value via Exca’s uid machinery (exca.confdict.UidMaker). For inputs the default can’t key reliably — typically arrays or other large / unhashable objects — override item_uid(value) on the step to return a stable string. A content hash is a safe default:

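import hashlib

import numpy as np
from exca import steps


class Embed(steps.Step):
    def item_uid(self, value: np.ndarray) -> str:
        # Hash the raw bytes. tobytes() ignores dtype and shape, so also hash
        # those if two inputs could differ only in dtype or shape.
        return hashlib.sha256(value.tobytes()).hexdigest()

    def _run(self, value: np.ndarray) -> np.ndarray:
        return embed(value)  # embed: placeholder for your own function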

item_uid is consulted once at chain entry on the caller-provided value, and the result is propagated unchanged to every sub-step. This is what makes downstream cache lookups lazy: a downstream cache hit can short-circuit the whole pipeline without running upstream steps. Per-input identity that depends on an upstream’s output is therefore not supported (yet).
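To see why a single uid shared by every sub-step allows short-circuiting, consider the toy chain below. It is a sketch under stated assumptions, not the library's code: run_chain_sketch, the Stage alias, and the per-stage dict caches are hypothetical, and repr(value) stands in for item_uid.

```python
from typing import Callable, Dict, List, Tuple

# A stage is (function, per-stage cache keyed by the entry uid)
Stage = Tuple[Callable[[float], float], Dict[str, float]]


def run_chain_sketch(stages: List[Stage], value: float, log: List[str]) -> float:
    uid = repr(value)  # computed once, on the caller-provided value
    start, result = 0, value
    # Look for the most downstream stage that already holds this uid
    for index in range(len(stages) - 1, -1, -1):
        cache = stages[index][1]
        if uid in cache:
            result, start = cache[uid], index + 1
            break
    # Run only the stages after the cache hit, storing under the same uid
    for index in range(start, len(stages)):
        func, cache = stages[index]
        result = func(result)
        cache[uid] = result
        log.append(f"ran stage {index}")
    return result


log: List[str] = []
stages: List[Stage] = [(lambda x: x + 1.0, {}), (lambda x: x * 10.0, {})]
first = run_chain_sketch(stages, 2.0, log)   # cold: both stages run
stages[0][1].clear()                          # drop the upstream cache entry
second = run_chain_sketch(stages, 2.0, log)  # last stage is cached: nothing runs
```

The second call returns from the last stage's cache even though the upstream entry is gone, which is only possible because the key does not depend on the upstream output.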

Long item_uids are truncated to 256 characters (Step._ITEM_UID_MAX_LENGTH) to keep on-disk paths sane; truncation preserves identity via a hashed middle section.
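One way to truncate a long uid while keeping it unique is to replace its middle with a hash of the full string. The sketch below illustrates that idea only; the exact layout Exca uses is not described here, so truncate_uid_sketch, the 16-character digest, and the dash separators are assumptions.

```python
import hashlib

MAX_LENGTH = 256  # mirrors the documented limit; the layout below is assumed


def truncate_uid_sketch(uid: str, max_length: int = MAX_LENGTH) -> str:
    """Shorten uid to max_length characters, keeping a hash of the full string."""
    if len(uid) <= max_length:
        return uid
    digest = hashlib.sha256(uid.encode("utf8")).hexdigest()[:16]
    marker = f"-{digest}-"
    keep = max_length - len(marker)
    if keep < 2:
        raise ValueError("max_length is too small to hold the hash marker")
    head = keep // 2
    tail = keep - head
    # Readable head and tail, with the hash of the whole uid in the middle
    return uid[:head] + marker + uid[-tail:]


long_a = "a" * 300 + "x" + "b" * 300
long_b = "a" * 300 + "y" + "b" * 300  # differs only in the dropped middle
short_a = truncate_uid_sketch(long_a)
short_b = truncate_uid_sketch(long_b)
```

Because the digest covers the whole original string, two uids that differ only in the removed middle still map to different truncated names.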

Vectorised compute: _run_batch

Override _run_batch instead of _run when per-input cost is dominated by setup that should amortise across the batch (model load, GPU transfer, …):

class Embed(steps.Step):
    model_path: str

    def _run_batch(self, values):
        model = load_model(self.model_path)        # loaded once per call
        for v in values:
            yield model(v)                         # in order, 1 per input

_run_batch must yield exactly one result per input, in order. The framework validates this and raises BatchProtocolError on under- or over-yield. If _run_batch raises partway through a batch, the exception is annotated with the uids of the items that were consumed but not yet yielded, so you can see which items were in flight.
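The one-result-per-input rule can be checked by counting how many inputs the batch function has consumed against how many results it has yielded. The following is a minimal sketch of that kind of check, not the framework's validator: checked_batch and ProtocolErrorSketch are hypothetical stand-ins, and the real BatchProtocolError may carry different information.

```python
from typing import Any, Callable, Iterable, Iterator, List


class ProtocolErrorSketch(RuntimeError):
    """Hypothetical stand-in for BatchProtocolError."""


def checked_batch(
    run_batch: Callable[[Iterator[Any]], Iterable[Any]], values: List[Any]
) -> Iterator[Any]:
    consumed = 0

    def feed() -> Iterator[Any]:
        nonlocal consumed
        for value in values:
            consumed += 1  # the batch function has pulled one more input
            yield value

    yielded = 0
    for result in run_batch(feed()):
        yielded += 1
        if yielded > consumed:
            raise ProtocolErrorSketch(f"over-yield: result {yielded} has no matching input")
        yield result
    if yielded < len(values):
        raise ProtocolErrorSketch(f"under-yield: {yielded} results for {len(values)} inputs")


def doubled(vs: Iterator[int]) -> Iterator[int]:
    for v in vs:
        yield 2 * v  # exactly one result per input


def drops_odd(vs: Iterator[int]) -> Iterator[int]:
    for v in vs:
        if v % 2 == 0:
            yield v  # silently skips odd inputs: under-yield


ok = list(checked_batch(doubled, [1, 2, 3]))
try:
    list(checked_batch(drops_odd, [1, 2, 3]))
    under_error = ""
except ProtocolErrorSketch as exc:
    under_error = str(exc)
```

A filtering _run_batch such as drops_odd is the typical way to trip this check; filtering has to happen before the inputs reach run_many.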

A single-value call and a batched call share the same cache: if step.run(v) writes uid X, next(iter(step.run_many([v]))) reads X back.

Distribution across workers

When infra.backend is Slurm, LocalProcess, ProcessPool, or ThreadPool, the backend splits items across workers. The distribution is set by max_jobs and min_items_per_job:

step = Embed(
    model_path="...",
    infra={
        "backend": "Slurm",
        "folder": cache,
        "max_jobs": 16,
        "min_items_per_job": 4,
        "gpus_per_node": 1,
    },
)
for embedding in step.run_many(paths):           # M items → up to 16 jobs
    print(embedding.shape)                         # read back in input order

Each worker runs _run_batch on its sub-batch and writes results to the shared cache. With a distributed backend, the driver waits for all jobs to finish; the returned iterator then reads results from the cache one at a time, so the full set is never held in memory and results are never round-tripped through the job pickle. Execution order within a batch is non-deterministic; output order matches input order.
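As an illustration of how max_jobs and min_items_per_job could interact, the sketch below splits a list into contiguous sub-batches. The splitting policy shown is an assumption made for the example (use as many jobs as allowed without dropping below the per-job minimum); split_items_sketch is a hypothetical helper and Exca's actual policy may differ.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_items_sketch(
    items: Sequence[T], max_jobs: int, min_items_per_job: int
) -> List[List[T]]:
    """Split items into contiguous sub-batches, one per job."""
    if max_jobs < 1 or min_items_per_job < 1:
        raise ValueError("max_jobs and min_items_per_job must be >= 1")
    if not items:
        return []
    # Assumed policy: as many jobs as allowed while each keeps the minimum size
    num_jobs = max(1, min(max_jobs, len(items) // min_items_per_job))
    base, extra = divmod(len(items), num_jobs)
    chunks: List[List[T]] = []
    start = 0
    for job in range(num_jobs):
        size = base + (1 if job < extra else 0)  # spread the remainder evenly
        chunks.append(list(items[start:start + size]))
        start += size
    return chunks


small = split_items_sketch(list(range(10)), max_jobs=16, min_items_per_job=4)
large = split_items_sketch(list(range(100)), max_jobs=16, min_items_per_job=4)
```

Under this assumed policy, 10 items with a minimum of 4 per job give 2 jobs of 5 items, while 100 items hit the max_jobs cap and give 16 jobs of 6 or 7 items.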

What’s stable

Pinned by tests — safe to rely on:

  • One cache entry per (step, uid). No fan-out, no filtering, no reordering.

  • Output ordering preserved. A StepItems iterator yields in input order, even if execution was unordered.

  • Duplicate uids preserved. If you pass [v, v, v], you get three results; the cache is hit once on disk.

  • Single-value and batched calls share cache entries.

  • _run_batch is streaming with fail-fast caching. Partial results are cached up to the point of failure.

  • Errors are cached and re-raised on the next call until cleared or retried (same as the single-input case).

  • Chain dispatches per step. Each step in a chain submits its own batch through its own backend. Without a chain-level infra, sub-step budgets (e.g. max_jobs) are independent; with one, the whole chain becomes a single job.
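Two of the guarantees above, preserved output order and preserved duplicates, can be pictured together in a toy model. This is an illustrative sketch rather than the library's code: ordered_results_sketch, the dict cache, and repr(value) as the uid are assumptions, and reversing the execution order merely mimics unordered workers.

```python
from typing import Callable, Dict, Iterator, List


def ordered_results_sketch(
    func: Callable[[float], float],
    values: List[float],
    cache: Dict[str, float],
) -> Iterator[float]:
    uids = [repr(v) for v in values]  # assumption: stand-in for item_uid
    # One pending entry per distinct uid that is not cached yet
    pending = {uid: v for uid, v in zip(uids, values) if uid not in cache}
    # Execute in a different order than the input to mimic unordered workers
    for uid in reversed(list(pending)):
        cache[uid] = func(pending[uid])
    # Read back one result per input, in input order, duplicates included
    for uid in uids:
        yield cache[uid]


computed: List[float] = []  # records which inputs were actually computed


def double(x: float) -> float:
    computed.append(x)
    return 2.0 * x


results = list(ordered_results_sketch(double, [3.0, 1.0, 3.0, 3.0], {}))
```

Four inputs produce four results in input order, while only the two distinct values are computed and stored.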