Remote GPU Runtime | Phase 2

Cloud GPUs that feel operationally local

Sahasra is Pravina Lab's JAX-first runtime for sending compiled graphs to remote workers, keeping tensors remote when possible, and making placement decisions visible instead of magical.

Control Plane View

Scheduler state that stays legible

Placement

Least-loaded

Healthy workers are ranked by active executions, stored tensor bytes, and tensor count.

Lifecycle

Session affinity

Once a session lands, it stays pinned. Worker loss is explicit instead of hidden.
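Session affinity can be pictured as a binding table that is written once and consulted on every request. The sketch below is a hypothetical in-memory model, not Sahasra's implementation. `SessionTable`, `SessionFailed`, and the error text are illustrative names. When a worker is lost, its sessions fail rather than move.

```python
class SessionFailed(RuntimeError):
    """Raised when a session's bound worker is no longer available."""


class SessionTable:
    """Hypothetical control-plane table pinning each session to one worker."""

    def __init__(self) -> None:
        self._bindings: dict[str, str] = {}
        self._failed: dict[str, str] = {}

    def bind(self, session_id: str, worker_id: str) -> None:
        # Placement happens once; a session is never re-bound.
        if session_id in self._bindings:
            raise ValueError(f"session {session_id} is already bound")
        self._bindings[session_id] = worker_id

    def mark_worker_offline(self, worker_id: str) -> list[str]:
        # Fail every session pinned to the lost worker instead of migrating it.
        lost = [s for s, w in self._bindings.items() if w == worker_id]
        for session_id in lost:
            self._failed[session_id] = f"bound worker {worker_id} went offline"
        return lost

    def route(self, session_id: str) -> str:
        # Requests always go to the bound worker, or fail with the recorded reason.
        if session_id in self._failed:
            raise SessionFailed(self._failed[session_id])
        return self._bindings[session_id]


table = SessionTable()
table.bind("s1", "worker-a")
print(table.route("s1"))  # worker-a
table.mark_worker_offline("worker-a")
try:
    table.route("s1")
except SessionFailed as exc:
    print(f"session failed: {exc}")  # session failed: bound worker worker-a went offline
```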

GET /workers

[
  {
    "id": "worker-a",
    "status": "healthy",
    "active_executions": 0,
    "tensor_count": 0,
    "stored_tensor_bytes": 0
  },
  {
    "id": "worker-b",
    "status": "healthy",
    "active_executions": 1,
    "tensor_count": 150,
    "stored_tensor_bytes": 5089200
  }
]

Today's target is scheduler correctness and runtime visibility, not distributed training abstractions. Sahasra is building the control plane first so that the hard behavior is understandable before scale is added.

For the developer

Cloud GPU access that does not get in your way

Most GPU runtimes ask you to rewrite your training loop, manage infrastructure, or accept a black box. Sahasra adds remote execution on top of the JAX code you already wrote: you stay in control of the code, and Sahasra handles placement.

Your JAX code stays the same

Add an import and three lines: a client, a session, and an execute call. Your @jax.jit functions are untouched; Sahasra lowers the graph locally and ships it to the remote worker.

Estimate before you commit

Call session.estimate_jitted() first. It tells you whether the workload is feasible on your target GPU class before you burn any real compute.
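Plain JAX can approximate the kind of check an estimate can run before any compute starts. This is a rough, hypothetical stand-in for `estimate_jitted()`, not its actual logic. It uses `jax.eval_shape` to get output shapes without executing anything, sums input and output bytes, and compares the total against an assumed 24 GiB budget for `g5` (the memory of the NVIDIA A10G in AWS G5 instances). It ignores intermediates and compiler workspace, so treat it as a lower bound.

```python
import math

import jax
import jax.numpy as jnp

# Assumed budget: "g5" taken to mean an AWS G5 instance with a 24 GiB A10G.
GPU_MEMORY_BYTES = {"g5": 24 * 1024**3}


def nbytes(tree) -> int:
    # Sum shape * itemsize over every leaf (arrays or ShapeDtypeStructs).
    return sum(
        math.prod(leaf.shape) * jnp.dtype(leaf.dtype).itemsize
        for leaf in jax.tree_util.tree_leaves(tree)
    )


def rough_estimate(fn, sample_inputs, gpu_class: str) -> dict:
    # eval_shape traces fn abstractly: no FLOPs run, no device buffers allocated.
    out_shapes = jax.eval_shape(fn, *sample_inputs)
    input_bytes = nbytes(sample_inputs)
    output_bytes = nbytes(out_shapes)
    return {
        "input_bytes": input_bytes,
        "output_bytes": output_bytes,
        # Lower bound only: intermediates and workspace are not counted.
        "feasible": input_bytes + output_bytes <= GPU_MEMORY_BYTES[gpu_class],
    }


@jax.jit
def train_step(x):
    return jnp.tanh(x @ x)


# A shape/dtype spec is enough; no 1024x1024 array is materialized.
spec = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)
print(rough_estimate(train_step, (spec,), "g5"))
# {'input_bytes': 4194304, 'output_bytes': 4194304, 'feasible': True}
```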

Tensors stay on the worker

Multi-step training reuses remote tensor handles. Parameters do not round-trip every step; they stay resident on the worker until you close the session.

Failures are never silent

If your bound worker goes offline, Sahasra marks the session failed with a clear error. No mysterious state, no hidden migrations, no wondering what happened.

Without Sahasra
import jax
import jax.numpy as jnp

@jax.jit
def train_step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((1024, 1024), dtype=jnp.float32)
y = train_step(x)
print(y.shape)

With Sahasra
import jax
import jax.numpy as jnp
from sahasra import RemoteJaxSession, SahasraClient

client = SahasraClient("http://your-control-plane")
session = RemoteJaxSession.create(client, gpu_class="g5")

@jax.jit
def train_step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((1024, 1024), dtype=jnp.float32)
execution = session.execute_jitted(train_step, sample_inputs=(x,))
print(execution.runtime_mode)  # jax_export_local

The function body does not change. Sahasra wraps the execution path, not your model logic.

Validated results: run locally end to end

  • 100 remote training steps across 4 workers in one run
  • 93.75% step accuracy on a real JAX MLP neural-net demo
  • Zero tensor leak on 4 workers: all sessions cleaned to zero on close
  • ~5 lines to switch from local JAX to running remotely

Why Sahasra

A runtime story built around visibility, not hand-waving

JAX-first remote execution

Lower local jitted functions, ship graph payloads through the control plane, and execute exported JAX programs behind a standalone worker boundary.

Visible scheduler state

Workers self-register, heartbeat their live load, and expose tensor pressure, active executions, and last-seen health through the control plane.
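Health can be derived from heartbeat age instead of being stored as a flag. The sketch below is a hypothetical control-plane registry. The 10-second timeout, `WorkerRegistry`, and the injectable clock are assumptions for illustration. A worker that stops heartbeating shows up as `offline` on the next read, without a separate failure detector.

```python
import time


class WorkerRegistry:
    """Hypothetical control-plane view of self-registered workers."""

    def __init__(self, heartbeat_timeout_s: float = 10.0, clock=time.monotonic) -> None:
        self.heartbeat_timeout_s = heartbeat_timeout_s
        self._clock = clock
        self._workers: dict[str, dict] = {}

    def heartbeat(self, worker_id: str, **load) -> None:
        # Registration and heartbeat are the same call; the latest load report wins.
        self._workers[worker_id] = {**load, "last_seen": self._clock()}

    def snapshot(self) -> list[dict]:
        # Status comes from heartbeat age, so a silent worker reads as offline.
        now = self._clock()
        return [
            {
                "id": worker_id,
                "status": "healthy" if now - w["last_seen"] <= self.heartbeat_timeout_s else "offline",
                **w,
            }
            for worker_id, w in sorted(self._workers.items())
        ]


# Fake clock so the example is deterministic.
now = [0.0]
registry = WorkerRegistry(heartbeat_timeout_s=10.0, clock=lambda: now[0])
registry.heartbeat("worker-a", active_executions=0, tensor_count=0, stored_tensor_bytes=0)
registry.heartbeat("worker-b", active_executions=1, tensor_count=150, stored_tensor_bytes=5089200)
now[0] = 12.0  # worker-a stays silent; worker-b heartbeats again
registry.heartbeat("worker-b", active_executions=1, tensor_count=150, stored_tensor_bytes=5089200)
print([(w["id"], w["status"]) for w in registry.snapshot()])
# [('worker-a', 'offline'), ('worker-b', 'healthy')]
```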

Remote tensor reuse

Repeated steps can feed remote tensor handles back into later executions so parameters stay resident instead of bouncing across the wire.

Explicit failure semantics

Sessions stay bound to one worker. If that worker goes offline, Sahasra fails clearly instead of silently migrating state and hiding runtime risk.

Runtime Flow

Control plane loop

Sahasra is deliberately opinionated: keep scheduling rules simple, keep worker state observable, and keep failures explicit.

01

Create a session

The client requests a JAX session, and the control plane binds it to the least-loaded healthy worker.

02

Lower the graph locally

Your local machine lowers a jitted JAX function and packages StableHLO-style metadata plus serialized export artifacts when available.

03

Execute through the worker

The API forwards the graph to the bound worker, which executes it and stores the outputs in its tensor store.

04

Reuse or close cleanly

Subsequent steps can reuse remote tensor handles. Closing the session deletes worker-owned tensors and returns the runtime to zero.
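JAX's own export API can approximate steps 02 and 03. It lowers a jitted function to a StableHLO module and serializes it to bytes that another process can deserialize and call. The sketch assumes Sahasra's "serialized export artifacts" resemble `jax.export` output, which this page does not confirm. It needs a JAX release that ships `jax.export` as a public module, and it exports for the local default backend so the whole round trip runs on one machine.

```python
import jax
import jax.numpy as jnp
from jax import export


@jax.jit
def train_step(x):
    return jnp.tanh(x @ x)


# Lowering needs shapes and dtypes, not real data.
spec = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)

# Client side: trace and lower locally, then serialize the exported program.
exported = export.export(train_step)(spec)
stablehlo_text = exported.mlir_module()  # human-readable StableHLO module
payload = exported.serialize()           # bytes that could be shipped to a worker

# Worker side: rebuild the program from bytes and run it on real inputs.
rehydrated = export.deserialize(payload)
y = rehydrated.call(jnp.ones((1024, 1024), dtype=jnp.float32))
print(y.shape)  # (1024, 1024)
```

A program meant for a GPU worker would be exported with an explicit target such as `platforms=("cuda",)`, and calling it would then require a matching device.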

Phase 2 outcome

Minimal coordinator, real behavior

  • Workers self-register and continuously report heartbeat load state.
  • The API chooses the least-loaded healthy worker, preferring idle workers first.
  • Bound sessions fail clearly if their worker disappears; there is no silent tensor migration.
  • Closing a session deletes worker-owned tensors and returns runtime pressure to zero.
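Handle reuse and cleanup come down to a worker-side store that owns tensors per session. The sketch below is hypothetical. `WorkerTensorStore` is an illustrative name, raw bytes stand in for device arrays, and `tensor_count` / `stored_tensor_bytes` only mirror the GET /workers fields. Each step reads its input by handle on the worker, so only short handle strings would need to cross the network. Closing the session frees everything the session owns and leaves other sessions untouched.

```python
import itertools


class WorkerTensorStore:
    """Hypothetical worker-side store mapping handles to session-owned payloads."""

    def __init__(self) -> None:
        self._tensors: dict[str, tuple[str, bytes]] = {}
        self._ids = itertools.count()

    def put(self, session_id: str, payload: bytes) -> str:
        # Keep the output on the worker and hand back an opaque handle.
        handle = f"tensor-{next(self._ids)}"
        self._tensors[handle] = (session_id, payload)
        return handle

    def get(self, handle: str) -> bytes:
        return self._tensors[handle][1]

    def close_session(self, session_id: str) -> int:
        # Delete every tensor this session owns and report how many were freed.
        owned = [h for h, (owner, _) in self._tensors.items() if owner == session_id]
        for handle in owned:
            del self._tensors[handle]
        return len(owned)

    @property
    def tensor_count(self) -> int:
        return len(self._tensors)

    @property
    def stored_tensor_bytes(self) -> int:
        return sum(len(payload) for _, payload in self._tensors.values())


store = WorkerTensorStore()
params = store.put("s1", bytes(1024))  # parameters uploaded once
for _ in range(3):
    # Each step reads params by handle and stores the update under a new handle;
    # earlier versions stay resident until the session closes.
    updated = bytes(b + 1 for b in store.get(params))
    params = store.put("s1", updated)

print(store.tensor_count, store.stored_tensor_bytes)  # 4 4096
print(store.close_session("s1"), store.stored_tensor_bytes)  # 4 0
```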

Roadmap

Built to evolve without rewriting the contract

Now: scheduler correctness, heartbeat health, session affinity, and local multi-worker validation.
Next: EC2-hosted GPU execution without changing the session semantics already proven locally.
Later: broader runtime portability beyond JAX export artifacts, without sacrificing observability.

Developer Flow

Drop-in JAX code, remote runtime semantics

The developer experience is intentionally direct: create a session, estimate a jitted function, execute it, and optionally keep the resulting tensors remote for the next step.

What you get right now

Session creation, graph estimation, remote execution requests, remote tensor handles, worker heartbeats, least-loaded placement, cleanup, and offline-worker failure behavior.

What it is not pretending to be yet

Sahasra is not claiming cross-worker tensor migration or finished distributed training. The current value is a clean, inspectable runtime contract that already works with real local workers.

JAX SDK sample
import jax
import jax.numpy as jnp

from sahasra import RemoteJaxSession, SahasraClient

client = SahasraClient("http://127.0.0.1:8000")
session = RemoteJaxSession.create(client, gpu_class="g5")

@jax.jit
def train_step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((1024, 1024), dtype=jnp.float32)

estimate = session.estimate_jitted(train_step, sample_inputs=(x,))
execution = session.execute_jitted(train_step, sample_inputs=(x,))

print(estimate.feasible, execution.runtime_mode, len(execution.output_tensors))

Contact

Need a runtime layer that stays understandable under load?

Sahasra is being built as infrastructure for serious ML iteration: visible placement, remote execution, explicit failure, and a clean path from local development to real workers.

Talk to Pravina Lab