Cloud GPUs that feel operationally local
Sahasra is Pravina Lab's JAX-first runtime for sending compiled graphs to remote workers, keeping tensors remote when possible, and making placement decisions visible instead of magical.
Control Plane View
Scheduler state that stays legible
Placement
Least-loaded
Healthy workers are ranked by active executions, stored tensor bytes, and tensor count.
Lifecycle
Session affinity
Once a session lands, it stays pinned. Worker loss is explicit instead of hidden.
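The placement and affinity rules above can be sketched in a few lines of Python. This is an illustrative model, not Sahasra's internals — `Worker` and `pick_worker` are assumed names — but the fields mirror what the control plane exposes via GET /workers.

```python
from dataclasses import dataclass

@dataclass
class Worker:
    # Illustrative mirror of the fields reported by GET /workers.
    id: str
    status: str
    active_executions: int = 0
    stored_tensor_bytes: int = 0
    tensor_count: int = 0

def pick_worker(workers, pinned=None):
    """Least-loaded placement with session affinity.

    New sessions go to the healthy worker with the lowest live load,
    ranked by active executions, stored tensor bytes, then tensor count.
    A pinned session may only use its bound worker; if that worker is
    gone, this fails loudly instead of migrating.
    """
    if pinned is not None:
        bound = [w for w in workers if w.id == pinned]
        if not bound or bound[0].status != "healthy":
            raise RuntimeError(f"session worker {pinned!r} is offline")
        return bound[0]
    healthy = [w for w in workers if w.status == "healthy"]
    if not healthy:
        raise RuntimeError("no healthy workers available")
    return min(healthy, key=lambda w: (w.active_executions,
                                       w.stored_tensor_bytes,
                                       w.tensor_count))
```

Fed the two workers from the payload below, this ranking picks worker-a: zero active executions beats one before byte counts are even compared.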
GET /workers
[
{
"id": "worker-a",
"status": "healthy",
"active_executions": 0,
"tensor_count": 0,
"stored_tensor_bytes": 0
},
{
"id": "worker-b",
"status": "healthy",
"active_executions": 1,
"tensor_count": 150,
"stored_tensor_bytes": 5089200
}
]
For the developer
Cloud GPU access that does not get in your way
Most GPU runtimes ask you to rewrite your training loop, manage infrastructure, or accept a black box. Sahasra adds remote execution on top of the JAX code you already wrote — you stay in control of the code, Sahasra handles the placement.
Your JAX code stays the same
Add three lines — a client, a session, and an execute call. Your @jax.jit functions are untouched. Sahasra lowers the graph locally and ships it to the remote worker.
Estimate before you commit
Call estimate_jitted() first. It reports whether the workload is feasible on your target GPU class before you burn any real compute.
Tensors stay on the worker
Multi-step training reuses remote tensor handles. Parameters do not round-trip every step — they stay resident on the worker until you close the session.
Failures are never silent
If your bound worker goes offline, Sahasra marks the session failed with a clear error. No mysterious state, no hidden migrations, no wondering what happened.
import jax
import jax.numpy as jnp
@jax.jit
def train_step(x):
return jnp.tanh(x @ x)
x = jnp.ones((1024, 1024), dtype=jnp.float32)
y = train_step(x)
print(y.shape)
import jax
import jax.numpy as jnp
from sahasra import RemoteJaxSession, SahasraClient
client = SahasraClient("http://your-control-plane")
session = RemoteJaxSession.create(client, gpu_class="g5")
@jax.jit
def train_step(x):
return jnp.tanh(x @ x)
x = jnp.ones((1024, 1024), dtype=jnp.float32)
execution = session.execute_jitted(train_step, sample_inputs=(x,))
print(execution.runtime_mode)  # jax_export_local
The function body does not change. Sahasra wraps the execution path, not your model logic.
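The local-lowering step is visible with JAX's own export machinery. This is a standalone sketch of the idea using the public jax.export module, not Sahasra's code: a jitted function is serialized into bytes that a control plane could ship to a worker, where it is deserialized and run.

```python
import jax
import jax.numpy as jnp
from jax import export

@jax.jit
def train_step(x):
    return jnp.tanh(x @ x)

# Describe the input by shape/dtype only; no real data is needed to lower.
spec = jax.ShapeDtypeStruct((8, 8), jnp.float32)
exported = export.export(train_step)(spec)

# The serialized artifact is the kind of payload a control plane forwards.
blob = exported.serialize()

# On the "worker" side: deserialize and call with real inputs.
restored = export.deserialize(blob)
x = jnp.ones((8, 8), dtype=jnp.float32)
y = restored.call(x)
print(y.shape)  # (8, 8)
```

The export artifact is self-contained StableHLO, which is why the worker can execute it without ever seeing your Python source.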
Validated results — run locally end to end
100
Remote training steps
across 4 workers in one run
93.75%
Step accuracy
on a real JAX MLP neural-net demo
Zero tensor leak
all sessions cleaned to zero on close
~5 lines
To switch
from local JAX to running remotely
Why Sahasra
A runtime story built around visibility, not hand-waving
JAX-first remote execution
Lower local jitted functions, ship graph payloads through the control plane, and execute exported JAX programs behind a standalone worker boundary.
Visible scheduler state
Workers self-register, heartbeat their live load, and expose tensor pressure, active executions, and last-seen health through the control plane.
Remote tensor reuse
Repeated steps can feed remote tensor handles back into later executions so parameters stay resident instead of bouncing across the wire.
Explicit failure semantics
Sessions stay bound to one worker. If that worker goes offline, Sahasra fails clearly instead of silently migrating state and hiding runtime risk.
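The failure contract is simple enough to model directly. A toy sketch under assumed names (`BoundSession` and `SessionFailed` are illustrative, not Sahasra's API): a session remembers its one worker forever, and a vanished worker turns every later call into the same loud error rather than a quiet re-bind.

```python
class SessionFailed(RuntimeError):
    """Raised when a session's bound worker is gone. The session never re-binds."""

class BoundSession:
    # Toy model of explicit failure semantics: one session, one worker,
    # no silent migration of state.
    def __init__(self, worker_id, registry):
        self.worker_id = worker_id   # fixed at creation; never changes
        self.registry = registry     # worker_id -> {"status": ...}
        self.failed = False

    def execute(self, graph):
        worker = self.registry.get(self.worker_id)
        if self.failed or worker is None or worker["status"] != "healthy":
            self.failed = True       # sticky: the session stays failed
            raise SessionFailed(
                f"worker {self.worker_id!r} is offline; session failed "
                "(no migration, no hidden retry)")
        return f"executed {graph} on {self.worker_id}"
```

Marking `failed` sticky is the point: once the bound worker disappears, the session refuses all further work instead of pretending the tensors it held still exist somewhere.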
Control plane loop
Sahasra is deliberately opinionated: keep scheduling rules simple, keep worker state observable, and keep failures explicit.
Create a session
The client requests a JAX session, and the control plane binds it to the lowest-load healthy worker.
Lower the graph locally
Your local machine lowers a jitted JAX function and packages StableHLO-style metadata plus serialized export artifacts when available.
Execute through the worker
The API forwards the graph to the bound worker, which runs it locally and stores outputs in the worker tensor store.
Reuse or close cleanly
Subsequent steps can reuse remote tensor handles. Closing the session deletes worker-owned tensors and returns the runtime to zero.
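The reuse-and-cleanup half of the loop reduces to a worker-owned tensor store. A minimal sketch under assumed names (`TensorStore` is illustrative, not the real implementation): outputs stay resident under opaque handles, later steps read them back without re-uploading, and closing the session deletes everything it owned.

```python
import uuid

class TensorStore:
    """Toy worker-side tensor store: handles in, payloads resident, zero on close."""

    def __init__(self):
        self.tensors = {}   # handle -> payload kept on the worker
        self.owned = {}     # session_id -> handles that session created

    def put(self, session_id, payload):
        handle = str(uuid.uuid4())          # the client only ever sees this
        self.tensors[handle] = payload
        self.owned.setdefault(session_id, set()).add(handle)
        return handle

    def get(self, handle):
        return self.tensors[handle]         # reuse: no round-trip upload

    def close_session(self, session_id):
        # Closing a session returns the worker's tensor pressure to zero.
        for handle in self.owned.pop(session_id, set()):
            del self.tensors[handle]

# Multi-step reuse: step N+1 consumes the handle produced by step N.
store = TensorStore()
h = store.put("session-1", b"params-step-0")
for step in range(3):
    params = store.get(h)                   # stays resident on the worker
    h = store.put("session-1", params + b"+")
store.close_session("session-1")
print(len(store.tensors))  # 0
```

Nothing clever is required for the "zero tensor leak" property: ownership tracking plus an explicit close is enough.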
Phase 2 outcome
Minimal coordinator, real behavior
- Workers self-register and continuously report heartbeat load state.
- The API chooses the least-loaded healthy worker, preferring idle workers first.
- Bound sessions fail clearly if their worker disappears; there is no silent tensor migration.
- Closing a session deletes worker-owned tensors and returns runtime pressure to zero.
Roadmap
Built to evolve without rewriting the contract
Drop-in JAX code, remote runtime semantics
The developer experience is intentionally direct: create a session, estimate a jitted function, execute it, and optionally keep the resulting tensors remote for the next step.
What you get right now
Session creation, graph estimation, remote execution requests, remote tensor handles, worker heartbeats, least-loaded placement, cleanup, and offline-worker failure behavior.
What it is not pretending to be yet
Sahasra is not claiming cross-worker tensor migration or finished distributed training. The current value is a clean, inspectable runtime contract that already works with real local workers.
import jax
import jax.numpy as jnp
from sahasra import RemoteJaxSession, SahasraClient
client = SahasraClient("http://127.0.0.1:8000")
session = RemoteJaxSession.create(client, gpu_class="g5")
@jax.jit
def train_step(x):
return jnp.tanh(x @ x)
x = jnp.ones((1024, 1024), dtype=jnp.float32)
estimate = session.estimate_jitted(train_step, sample_inputs=(x,))
execution = session.execute_jitted(train_step, sample_inputs=(x,))
print(estimate.feasible, execution.runtime_mode, len(execution.output_tensors))
Contact
Need a runtime layer that stays understandable under load?
Sahasra is being built as infrastructure for serious ML iteration: visible placement, remote execution, explicit failure, and a clean path from local development to real workers.