Why Ray for Distributed ML?
Ray is an open-source distributed computing framework from Anyscale that unifies the ML lifecycle under a single Python-native runtime. Unlike Spark — which is optimized for batch ETL and SQL analytics — Ray is purpose-built for heterogeneous ML workloads: it handles CPU-bound preprocessing, GPU-bound training, hyperparameter search, and real-time inference within a single cluster, sharing resources dynamically instead of running separate infrastructure for each stage.
The key insight behind Ray is that ML teams spend most of their time gluing together incompatible distributed systems: a Spark cluster for data prep, a separate training cluster for GPU jobs, a third system for HPO, and yet another for serving. Ray replaces this zoo with a unified compute layer. If you are already running MLflow-tracked training pipelines and hitting single-machine limits, Ray is the natural next step.
Ray Core
Remote tasks (any Python function) and actors (any Python class) scheduled across the cluster. These are the distributed computing primitives that power all higher-level libraries.
Ray AI Libraries
Ray Data, Ray Train, Ray Tune, and Ray Serve — high-level abstractions for each phase of the ML lifecycle built on Ray Core.
KubeRay
Kubernetes operator for Ray clusters. RayCluster, RayJob, and RayService CRDs for production deployment with autoscaling and spot instance support.
Ray Core — Remote Tasks and Actors
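Ray Core provides two primitives. Remote tasks are stateless functions that run on any worker. Actors are stateful objects that live on a specific worker, keep local state between calls, and by default execute their method calls one at a time in submission order. Both are declared with @ray.remote. Calling .remote() on a task or on an actor method returns a future (ObjectRef) immediately instead of a value. Instantiating an actor with .remote() returns an actor handle.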
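A rough standard-library analogy helps here; it is not how Ray is implemented. A remote task behaves like a function submitted to a thread pool that hands back a future. An actor behaves like an object whose method calls pass through a single-threaded executor, so they run one at a time, in submission order, against shared state. `square` and `CounterActor` are illustrative names only.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List


def square(x: int) -> int:
    """Stateless work: like a remote task, any pool thread may run it."""
    return x * x


class CounterActor:
    """Stateful object whose calls run one at a time, like a Ray actor.

    A single-thread executor processes method calls in submission order,
    so `self.total` never sees concurrent updates.
    """

    def __init__(self) -> None:
        self.total = 0
        self._executor = ThreadPoolExecutor(max_workers=1)

    def _add(self, value: int) -> int:
        self.total += value
        return self.total

    def add(self, value: int) -> Future:
        # Returns a future immediately, analogous to an ObjectRef
        return self._executor.submit(self._add, value)

    def shutdown(self) -> None:
        self._executor.shutdown(wait=True)


# "Tasks": submit everything first, then block on the futures
with ThreadPoolExecutor(max_workers=4) as pool:
    task_futures = [pool.submit(square, i) for i in range(8)]
    squares: List[int] = [f.result() for f in task_futures]

# "Actor": calls queue up and execute sequentially against shared state
actor = CounterActor()
actor_futures = [actor.add(v) for v in [5, 10, 20]]
running_totals = [f.result() for f in actor_futures]
actor.shutdown()

print(squares)         # [0, 1, 4, 9, 16, 25, 36, 49]
print(running_totals)  # [5, 15, 35]
```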
import ray
import time
import numpy as np
# Initialize Ray — auto-detects local CPUs, or connects to a cluster
ray.init() # local: ray.init() | cluster: ray.init("ray://head-node:10001")
# ── Remote tasks: stateless parallel functions ────────────────────────
@ray.remote
def preprocess_shard(shard_path: str) -> np.ndarray:
    """Runs on any available worker. No shared state."""
    data = np.load(shard_path)
    return (data - data.mean()) / data.std()

# Submit 100 shards in parallel: .remote() returns futures immediately.
# np.load() reads local or shared-filesystem paths (it cannot open s3:// URIs),
# so the shards are assumed to sit on a mount visible to every worker
# (/mnt/data is a hypothetical mount point).
shards = [f"/mnt/data/shard_{i:04d}.npy" for i in range(100)]
futures = [preprocess_shard.remote(path) for path in shards]
# Block on results (tasks run in parallel, as many at once as free CPUs allow)
results = ray.get(futures)  # list of np.ndarray
combined = np.concatenate(results, axis=0)
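# ── Resource hints: reserve GPU or custom resources ───────────────────
@ray.remote(num_gpus=1, num_cpus=4)
def gpu_inference(batch: np.ndarray) -> np.ndarray:
    import torch
    # Loads the full pickled model onto the GPU Ray assigned to this task.
    # Reloading on every call is fine for a demo; an actor that loads the
    # model once in __init__ avoids the repeated cost.
    model = torch.load("/model/checkpoint.pt", map_location="cuda")
    model.eval()
    with torch.no_grad():
        return model(torch.from_numpy(batch).cuda()).cpu().numpy()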
# ── Actors: stateful objects with persistent local state ──────────────
@ray.remote
class ParameterServer:
    """Central parameter store. The actor runs one method call at a time,
    so concurrent gradient pushes never race on self.params."""

    def __init__(self, model_size: int):
        self.params = np.zeros(model_size, dtype=np.float32)
        self.update_count = 0

    def push_gradient(self, grad: np.ndarray, lr: float = 0.001) -> None:
        self.params -= lr * grad
        self.update_count += 1

    def pull_params(self) -> np.ndarray:
        return self.params.copy()

    def get_update_count(self) -> int:
        return self.update_count

# Named actor: other drivers can look it up with ray.get_actor("param_server").
# Add lifetime="detached" to .options() if it must outlive this driver.
ps = ParameterServer.options(name="param_server").remote(model_size=10_000)

# Workers push gradients asynchronously
@ray.remote
def worker_train(worker_id: int, ps_handle) -> None:
    for step in range(100):
        params = ray.get(ps_handle.pull_params.remote())
        grad = np.random.randn(10_000).astype(np.float32)  # simulated gradient
        ps_handle.push_gradient.remote(grad, lr=0.001)  # fire-and-forget
        if step % 10 == 0:
            count = ray.get(ps_handle.get_update_count.remote())
            print(f"Worker {worker_id} step {step}, total updates: {count}")

# Launch 4 workers in parallel
ray.get([worker_train.remote(i, ps) for i in range(4)])
# ── Object store: zero-copy shared memory between tasks ───────────────
# Put the large array in Ray's shared-memory object store once; tasks on
# the same node then read it without copying
large_dataset = np.random.rand(10_000_000, 128).astype(np.float32)
dataset_ref = ray.put(large_dataset)  # stored in shared memory

@ray.remote
def process_slice(data: np.ndarray, start: int, end: int) -> float:
    # Ray resolves top-level ObjectRef arguments before the task runs, so
    # `data` is already a read-only, zero-copy NumPy view (no ray.get needed)
    return float(data[start:end].mean())

# Workers on the same node share one copy, and each other node fetches the
# object once, instead of the array being re-serialized for every task
chunk = 1_000_000
means = ray.get([
    process_slice.remote(dataset_ref, i * chunk, (i + 1) * chunk)
    for i in range(10)
])
Ray Data — Scalable Dataset Preprocessing
Ray Data is a distributed dataset library that handles the preprocessing gap between raw storage and model training. It reads Parquet, CSV, JSON, images, and TFRecords from S3/GCS/HDFS, applies transformations in parallel across workers, and streams batches directly into the training loop — eliminating the need for a separate Spark preprocessing job. Ray Data integrates natively with Ray Train so that data loading and training share the same cluster resources.
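Ray Data's lazy, streaming execution can be pictured with plain Python generators. Nothing runs until the consumer pulls a batch, and each batch flows through every transform before the next one is read. The sketch below is single-process and uses hypothetical helpers (`read_rows`, `batched`, `map_batches`, `filter_rows`). Ray additionally runs these stages in parallel across workers and applies backpressure between them.

```python
from typing import Callable, Dict, Iterable, Iterator, List

Row = Dict[str, object]
Batch = List[Row]

READ_LOG: List[int] = []  # records which rows the source has produced


def read_rows(n: int) -> Iterator[Row]:
    """Hypothetical lazy source, standing in for a file-based reader."""
    labels = ["positive", "negative", "neutral"]
    for i in range(n):
        READ_LOG.append(i)
        yield {"id": i, "label": labels[i % 3]}


def batched(rows: Iterable[Row], batch_size: int) -> Iterator[Batch]:
    batch: Batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final partial batch


def map_batches(batches: Iterable[Batch], fn: Callable[[Batch], Batch]) -> Iterator[Batch]:
    for batch in batches:
        yield fn(batch)


def filter_rows(batches: Iterable[Batch], pred: Callable[[Row], bool]) -> Iterator[Batch]:
    for batch in batches:
        kept = [row for row in batch if pred(row)]
        if kept:
            yield kept


def encode(batch: Batch) -> Batch:
    mapping = {"positive": 1, "negative": 0, "neutral": 2}
    return [{**row, "label_int": mapping[str(row["label"])]} for row in batch]


# Building the pipeline does no work yet: every stage is a lazy generator
pipeline = filter_rows(
    map_batches(batched(read_rows(10), batch_size=4), encode),
    lambda row: row["label"] != "neutral",
)
rows_read_before_consuming = len(READ_LOG)  # 0

# The "training loop" pulls batches one at a time
consumed = [batch for batch in pipeline]
print([[row["id"] for row in batch] for batch in consumed])  # [[0, 1, 3], [4, 6, 7], [9]]
```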
import ray
import numpy as np
ray.init()
# ── Read Parquet from S3 (read parallelism is automatic unless set) ───
ds = ray.data.read_parquet(
    "s3://ml-data/features/",
    columns=["user_id", "feature_vec", "label"],
    parallelism=200,  # override: number of parallel read tasks
)
print(ds.schema())  # inspects schema without full materialization
print(ds.count())   # triggers execution, returns row count
# ── Transformations: lazy, fused into a single execution plan ────────
def normalize_features(batch: dict) -> dict:
    """batch is a dict of numpy arrays, one key per column.
    Assumes feature_vec holds fixed-length vectors (a 2-D array per batch)."""
    features = batch["feature_vec"].astype(np.float32)
    # These are per-batch statistics; for a stable transform, compute the
    # global mean/std once (e.g. in a separate pass) and reuse them here
    mean = features.mean(axis=0)
    std = features.std(axis=0) + 1e-8
    batch["feature_vec"] = (features - mean) / std
    return batch

def encode_labels(batch: dict) -> dict:
    label_map = {"positive": 1, "negative": 0, "neutral": 2}
    batch["label_int"] = np.array(
        [label_map[label] for label in batch["label"]], dtype=np.int64
    )
    return batch

# .map_batches() applies each function to batches in parallel; with
# batch_format="numpy", every batch is a dict of NumPy arrays
ds = (
    ds
    .map_batches(normalize_features, batch_size=1024, batch_format="numpy")
    .map_batches(encode_labels, batch_size=1024, batch_format="numpy")
    .filter(lambda row: row["label"] != "neutral")  # row-level filter
    .drop_columns(["label"])  # remove string label
)
# ── Train/test split ──────────────────────────────────────────────────
train_ds, test_ds = ds.train_test_split(test_size=0.1, shuffle=True, seed=42)
print(f"Train rows: {train_ds.count()}, Test rows: {test_ds.count()}")
# ── Materialize to disk for repeated access ───────────────────────────
# Save as Parquet — preserves column types, readable by downstream tools
train_ds.write_parquet("s3://ml-data/preprocessed/train/")
test_ds.write_parquet( "s3://ml-data/preprocessed/test/")
# ── Image preprocessing pipeline ─────────────────────────────────────
from torchvision import transforms
from PIL import Image

image_ds = ray.data.read_images(
    "s3://ml-data/images/",
    include_paths=True,
    size=(224, 224),  # resize all images on read
)

def augment_image(batch: dict) -> dict:
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    augmented = []
    for img_array in batch["image"]:  # each image is an HxWxC uint8 array
        img = Image.fromarray(img_array)
        augmented.append(transform(img).numpy())
    batch["tensor"] = np.stack(augmented)
    return batch

image_ds = image_ds.map_batches(
    augment_image,
    batch_size=64,
    num_gpus=0,     # augmentation runs on CPU workers
    concurrency=8,  # 8 parallel transform workers
)
Ray Train — Distributed Training with PyTorch and Lightning
Ray Train orchestrates distributed training across multiple GPUs and nodes. It handles process group initialization, gradient synchronization, checkpoint saving and restoring, and fault tolerance. For PyTorch, it wraps DistributedDataParallel automatically — you write a single-worker training function and Ray Train replicates it across the requested worker count. For PyTorch Lightning users, RayLightningEnvironment and RayTrainReportCallback plug in with minimal code changes.
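The gradient synchronization that DistributedDataParallel performs can be checked with a few lines of arithmetic. Suppose each of N workers computes the mean-loss gradient on an equally sized shard. Averaging those gradients, which is what the all-reduce does, reproduces the full-batch gradient. The sketch uses a one-parameter least-squares model purely for illustration; it is neither Ray nor PyTorch code, and the helper names are hypothetical.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y)


def mean_grad(w: float, samples: List[Sample]) -> float:
    """d/dw of the mean of 0.5 * (w*x - y)^2 over the samples."""
    return sum((w * x - y) * x for x, y in samples) / len(samples)


def shard(samples: List[Sample], num_workers: int) -> List[List[Sample]]:
    """Round-robin split, like a distributed sampler (equal sizes here)."""
    return [samples[rank::num_workers] for rank in range(num_workers)]


def all_reduce_mean(values: List[float]) -> float:
    """What DDP's all-reduce computes: every worker ends up with the average."""
    return sum(values) / len(values)


data = [(float(x), 2.0 * x + 1.0) for x in range(8)]  # 8 samples, 4 workers
w = 0.5
per_worker = [mean_grad(w, s) for s in shard(data, num_workers=4)]
synced = all_reduce_mean(per_worker)
full_batch = mean_grad(w, data)
print(per_worker, synced, full_batch)
```

The equality relies on equal shard sizes. With uneven shards, a plain average of per-worker means weights small shards too heavily.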
import ray
from ray.train import ScalingConfig, RunConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
import ray.train.torch as ray_torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
ray.init()
# ── Define the per-worker training function ───────────────────────────
def train_func(config: dict):
    """This function runs on EVERY worker. Ray handles DDP setup."""
    import os
    import tempfile

    # Hyperparameters from config dict (populated by Ray Tune or caller)
    lr = config.get("lr", 1e-3)
    batch_size = config.get("batch_size", 256)
    epochs = config.get("epochs", 10)
    hidden_dim = config.get("hidden_dim", 256)
    dropout = config.get("dropout", 0.3)
    weight_decay = config.get("weight_decay", 1e-4)

    # ── Model ─────────────────────────────────────────────────────────
    class MLP(nn.Module):
        def __init__(self, input_dim, hidden_dim, output_dim, dropout):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(hidden_dim // 2, output_dim),
            )

        def forward(self, x):
            return self.net(x)

    model = MLP(input_dim=128, hidden_dim=hidden_dim, output_dim=3, dropout=dropout)
    # ray_torch.prepare_model() moves the model to this worker's GPU and wraps it in DDP
    model = ray_torch.prepare_model(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    # ── Data loading: each worker gets its own shard via Ray Data ─────
    # The shard comes from the datasets={"train": ...} argument of TorchTrainer
    train_dataset = ray.train.get_dataset_shard("train")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in train_dataset.iter_torch_batches(
            batch_size=batch_size,
            dtypes={"feature_vec": torch.float32, "label_int": torch.long},
        ):
            features = batch["feature_vec"]
            labels = batch["label_int"]
            optimizer.zero_grad()
            logits = model(features)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        metrics = {"loss": avg_loss, "epoch": epoch, "lr": scheduler.get_last_lr()[0]}
        # Every worker must call ray.train.report(). Only rank 0 attaches a
        # checkpoint, since all DDP replicas hold identical weights.
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = None
            if ray.train.get_context().get_world_rank() == 0:
                # prepare_model() only wraps the model in DDP when num_workers > 1
                unwrapped = model.module if hasattr(model, "module") else model
                torch.save(
                    {"model_state": unwrapped.state_dict(),
                     "optimizer_state": optimizer.state_dict(),
                     "epoch": epoch},
                    os.path.join(tmpdir, "checkpoint.pt"),
                )
                checkpoint = ray.train.Checkpoint.from_directory(tmpdir)
            ray.train.report(metrics=metrics, checkpoint=checkpoint)
# ── Configure distributed training ───────────────────────────────────
scaling_config = ScalingConfig(
    num_workers=4,  # 4 worker processes (each gets 1 GPU)
    use_gpu=True,
    resources_per_worker={"CPU": 4, "GPU": 1},
    placement_strategy="SPREAD",  # spread workers across nodes
)
run_config = RunConfig(
    name="mlp_training_run",
    storage_path="s3://ml-runs/ray-train/",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,  # keep only the 3 best checkpoints by loss
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    ),
    failure_config=ray.train.FailureConfig(max_failures=2),  # retry on worker failure
)
# ── Load preprocessed Ray Data datasets ──────────────────────────────
train_ds = ray.data.read_parquet("s3://ml-data/preprocessed/train/")
test_ds = ray.data.read_parquet("s3://ml-data/preprocessed/test/")
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={
        "lr": 5e-4,
        "batch_size": 512,
        "epochs": 20,
        "hidden_dim": 512,
    },
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_ds},  # sharded automatically across workers
)
result = trainer.fit()
# best_checkpoints holds (Checkpoint, metrics) tuples for the retained checkpoints
best_ckpt, best_metrics = min(result.best_checkpoints, key=lambda cm: cm[1]["loss"])
print(f"Best checkpoint: {best_ckpt} (loss={best_metrics['loss']:.4f})")
print(f"Metrics: {result.metrics}")
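CheckpointConfig(num_to_keep=3, checkpoint_score_attribute="loss", checkpoint_score_order="min") keeps only the best-scoring checkpoints. The retention rule itself is easy to sketch. The helper `retained` and the loss values are made up for illustration; this is a simplified model of the behavior, not Ray's implementation.

```python
from typing import Dict, List


def retained(reports: List[Dict[str, float]], num_to_keep: int,
             score_attr: str = "loss", order: str = "min") -> List[int]:
    """Return the epochs whose checkpoints survive, best first.

    Each report is the metrics dict sent alongside a checkpoint. Reports are
    ranked by score_attr (ascending for "min", descending for "max") and the
    top num_to_keep are kept.
    """
    ranked = sorted(reports, key=lambda m: m[score_attr], reverse=(order == "max"))
    return [int(m["epoch"]) for m in ranked[:num_to_keep]]


history = [  # made-up per-epoch losses
    {"epoch": 0, "loss": 0.91},
    {"epoch": 1, "loss": 0.62},
    {"epoch": 2, "loss": 0.55},
    {"epoch": 3, "loss": 0.58},  # slight regression, still in the top 3
    {"epoch": 4, "loss": 0.49},
]
kept = retained(history, num_to_keep=3)
print(kept)  # [4, 2, 3]
```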
Note: Setting max_failures=3 in FailureConfig is typically enough to survive spot-instance preemption patterns without manual intervention; the example above uses max_failures=2.
Ray Tune — Distributed Hyperparameter Optimization
Ray Tune runs hyperparameter search across hundreds of trials in parallel, each trial running on its own slice of cluster resources. It supports:

- trial schedulers such as ASHA and Population-Based Training (PBT);
- Bayesian optimization via Optuna and Ax;
- grid and random search as baselines.

The ASHA (Asynchronous Successive Halving) scheduler is a common choice for deep learning; Tune's default scheduler is plain FIFO. ASHA aggressively terminates underperforming trials early, freeing GPU hours for promising configurations.
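In a simplified model of ASHA, the settings used below (max_t=30, grace_period=5, reduction_factor=3) produce decision points ("rungs") at 5 and 15 epochs. At each rung, a trial continues only if its metric ranks within the best 1/reduction_factor of the results recorded there so far. Because the rule never waits for the full cohort, early trials usually pass. The helper names below are hypothetical. Ray's scheduler differs in details, such as how it interpolates the cutoff.

```python
import math
from typing import Dict, List


def rung_milestones(grace_period: int, max_t: int, reduction_factor: int) -> List[int]:
    """Epochs at which a trial is compared with its peers: r, r*eta, ... < max_t."""
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def should_continue(recorded: List[float], value: float, reduction_factor: int) -> bool:
    """Asynchronous promotion test for mode="min".

    `recorded` holds losses already reported at this rung by other trials.
    The trial continues if its loss ranks within the best
    ceil(n / reduction_factor) of all n results at the rung, itself included.
    """
    results = sorted(recorded + [value])
    keep = max(1, math.ceil(len(results) / reduction_factor))
    return value <= results[keep - 1]


rungs = rung_milestones(grace_period=5, max_t=30, reduction_factor=3)
print(rungs)  # [5, 15]

rung_losses: Dict[int, List[float]] = {r: [] for r in rungs}
decisions = []
for loss in [0.80, 0.60, 0.90, 0.55, 0.70, 0.95]:  # made-up losses reaching rung 5
    decisions.append(should_continue(rung_losses[5], loss, reduction_factor=3))
    rung_losses[5].append(loss)
print(decisions)  # [True, True, False, True, False, False]
```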
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig
ray.init()
# ── Search space definition ───────────────────────────────────────────
search_space = {
    "lr": tune.loguniform(1e-5, 1e-2),  # log-uniform between 1e-5 and 1e-2
    "batch_size": tune.choice([128, 256, 512, 1024]),
    "hidden_dim": tune.choice([128, 256, 512]),
    "dropout": tune.uniform(0.1, 0.5),
    "weight_decay": tune.loguniform(1e-6, 1e-3),
    "epochs": 30,  # fixed, so trials can run up to ASHA's max_t
}
# ── ASHA: aggressive early stopping for deep learning HPO ────────────
asha_scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=30,            # maximum epochs per trial
    grace_period=5,      # minimum epochs before pruning
    reduction_factor=3,  # keep top 1/3 of trials at each rung
)
# ── Optuna Bayesian search for smarter candidate generation ──────────
optuna_search = OptunaSearch(
    metric="loss",
    mode="min",
    points_to_evaluate=[
        # Seed with a known-good configuration. The search space is nested
        # under "train_loop_config" (see the Tuner below), and OptunaSearch
        # flattens nested spaces into "/"-joined parameter names
        {"train_loop_config/lr": 1e-3,
         "train_loop_config/batch_size": 256,
         "train_loop_config/hidden_dim": 256,
         "train_loop_config/dropout": 0.3,
         "train_loop_config/weight_decay": 1e-4},
    ],
)
# ── Trainable: hand the TorchTrainer itself to Tune ──────────────────
# Tune samples a train_loop_config per trial and receives every per-epoch
# ray.train.report() from train_func, which ASHA needs in order to stop
# weak trials at each rung
trainer = TorchTrainer(
    train_loop_per_worker=train_func,  # same train_func as before
    scaling_config=ScalingConfig(
        num_workers=2, use_gpu=True,
        resources_per_worker={"CPU": 2, "GPU": 1},
    ),
    datasets={"train": ray.data.read_parquet("s3://ml-data/preprocessed/train/")},
)
tuner = tune.Tuner(
    trainer,
    param_space={"train_loop_config": search_space},
    tune_config=tune.TuneConfig(
        num_samples=50,  # total trials to run
        scheduler=asha_scheduler,
        search_alg=optuna_search,
        max_concurrent_trials=8,  # 8 trials x 2 GPUs = 16 GPUs at full concurrency
    ),
    run_config=RunConfig(
        name="hpo_run_v1",
        storage_path="s3://ml-runs/ray-tune/",
    ),
)
results = tuner.fit()
# ── Analyze results ───────────────────────────────────────────────────
best_result = results.get_best_result(metric="loss", mode="min")
print(f"Best config: {best_result.config}")
print(f"Best loss: {best_result.metrics['loss']:.4f}")
print(f"Best checkpoint: {best_result.checkpoint}")
# ── Population-Based Training (PBT): mutate hyperparameters mid-training
from ray.tune.schedulers import PopulationBasedTraining

# PBT copies weights between trials through checkpoints, so train_func must
# resume from ray.train.get_checkpoint() for exploited trials to continue
pbt_scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="loss",
    mode="min",
    perturbation_interval=5,  # perturb every 5 reported iterations (epochs here)
    hyperparam_mutations={
        "lr": tune.loguniform(1e-5, 1e-2),
        "dropout": tune.uniform(0.1, 0.5),
    },
    quantile_fraction=0.25,  # bottom 25% clone a top-25% trial, then perturb its hyperparameters
)
Once Ray Tune has found the best hyperparameters, track the final trained model in MLflow and promote it from staging to production. This follows the same MLOps CI/CD champion-challenger promotion patterns: Ray Tune's best checkpoint can be loaded directly as the challenger model for AUC comparison.
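The AUC comparison needs nothing more than held-out labels and each model's scores. Below is a minimal sketch of a promotion gate, assuming a rule chosen purely for illustration: the challenger must beat the champion by a hypothetical `min_gain` margin. The function names and data are illustrative and are not part of MLflow or Ray.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win (the Mann-Whitney U formulation).
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def promote_challenger(labels: Sequence[int], champion_scores: Sequence[float],
                       challenger_scores: Sequence[float], min_gain: float = 0.005) -> bool:
    """Hypothetical gate: promote only if the challenger's AUC is higher by min_gain."""
    return roc_auc(labels, challenger_scores) >= roc_auc(labels, champion_scores) + min_gain


labels = [1, 0, 1, 1, 0, 0, 1, 0]                        # made-up holdout labels
champion = [0.9, 0.3, 0.6, 0.4, 0.5, 0.2, 0.7, 0.45]     # made-up scores
challenger = [0.95, 0.2, 0.7, 0.6, 0.3, 0.1, 0.8, 0.4]
print(roc_auc(labels, champion), roc_auc(labels, challenger))  # 0.875 1.0
print(promote_challenger(labels, champion, challenger))        # True
```

The nested loop is O(P x N), which is fine for a sketch. A sort-based rank computation scales better on large holdout sets.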
Ray Serve — Production Model Serving
Ray Serve is an ML-native serving framework built on Ray. Unlike Flask or FastAPI wrappers around a model, Serve handles request batching, autoscaling, multi-model deployment graphs, and GPU resource allocation natively. Each deployment is a Python class decorated with @serve.deployment. Every replica of a deployment runs as a Ray actor, which gives it persistent GPU memory and request queue semantics.
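Serve's autoscaler sizes a deployment from observed load. Roughly, the desired replica count is the total number of ongoing requests divided by target_ongoing_requests, rounded up and clamped to [min_replicas, max_replicas]. The upscale and downscale delays require that decision to persist before the replica count changes. The sketch below is a simplified model of that policy, using the same values as the autoscaling_config in the code that follows. `desired_replicas` and `DelayedScaler` are hypothetical names, and Ray's actual policy adds smoothing and other details.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class AutoscalingConfig:
    min_replicas: int = 1
    max_replicas: int = 8
    target_ongoing_requests: float = 20.0
    upscale_delay_s: float = 10.0
    downscale_delay_s: float = 60.0


def desired_replicas(total_ongoing: int, cfg: AutoscalingConfig) -> int:
    """Replicas needed so each one carries about target_ongoing_requests."""
    raw = math.ceil(total_ongoing / cfg.target_ongoing_requests)
    return max(cfg.min_replicas, min(cfg.max_replicas, raw))


class DelayedScaler:
    """Acts only once a scale-up or scale-down decision has persisted for its delay."""

    def __init__(self, cfg: AutoscalingConfig, replicas: int) -> None:
        self.cfg = cfg
        self.replicas = replicas
        self._pending_since: Optional[float] = None
        self._pending_direction = 0

    def observe(self, now_s: float, total_ongoing: int) -> int:
        target = desired_replicas(total_ongoing, self.cfg)
        direction = (target > self.replicas) - (target < self.replicas)
        if direction == 0:
            self._pending_since, self._pending_direction = None, 0
            return self.replicas
        if direction != self._pending_direction:
            # New trend: start the clock but do not act yet
            self._pending_since, self._pending_direction = now_s, direction
        delay = self.cfg.upscale_delay_s if direction > 0 else self.cfg.downscale_delay_s
        if now_s - self._pending_since >= delay:
            self.replicas = target
            self._pending_since, self._pending_direction = None, 0
        return self.replicas


cfg = AutoscalingConfig()
scaler = DelayedScaler(cfg, replicas=2)
timeline = [(0, 100), (5, 100), (10, 100), (20, 10), (50, 10), (80, 10)]  # (seconds, load)
observed = [scaler.observe(t, load) for t, load in timeline]
print(observed)  # [2, 2, 5, 5, 5, 1]
```

The longer downscale delay is what keeps a brief lull from tearing down a replica that must later reload its model.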
import ray
from ray import serve
from ray.serve.handle import DeploymentHandle
import numpy as np
import torch
ray.init()
serve.start(detached=True, http_options={"host": "0.0.0.0", "port": 8000})
# ── Basic deployment: GPU model with request batching ─────────────────
@serve.deployment(
    num_replicas=2,
    ray_actor_options={"num_gpus": 1, "num_cpus": 2},
    max_ongoing_requests=50,  # max in-flight requests per replica
)
class ClassifierDeployment:
    def __init__(self, model_path: str):
        self.model = torch.load(model_path, map_location="cuda")
        self.model.eval()
        print(f"Model loaded on GPU: {torch.cuda.get_device_name(0)}")

    @serve.batch(max_batch_size=64, batch_wait_timeout_s=0.05)
    async def handle_batch(self, requests: list[dict]) -> list[dict]:
        """Callers pass one dict; @serve.batch groups concurrent calls into `requests`."""
        features = torch.tensor(
            [req["features"] for req in requests],
            dtype=torch.float32, device="cuda"
        )
        with torch.no_grad():
            logits = self.model(features)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        # Return exactly one result per input request, in the same order
        return [
            {"label": int(np.argmax(probs[i])),
             "confidence": float(probs[i].max())}
            for i in range(len(requests))
        ]

    async def __call__(self, request) -> dict:
        body = await request.json()
        return await self.handle_batch(body)
# ── Autoscaling configuration ─────────────────────────────────────────
@serve.deployment(
    autoscaling_config={
        "min_replicas": 1,
        "max_replicas": 8,
        "target_ongoing_requests": 20,  # aim for ~20 in-flight requests per replica
        "upscale_delay_s": 10,
        "downscale_delay_s": 60,
    },
    ray_actor_options={"num_gpus": 1},
)
class AutoscaledClassifier:
    def __init__(self):
        self.model = torch.load("/models/classifier.pt", map_location="cuda")
        self.model.eval()

    async def __call__(self, request) -> dict:
        body = await request.json()
        features = torch.tensor(body["features"], dtype=torch.float32).unsqueeze(0).cuda()
        with torch.no_grad():
            probs = torch.softmax(self.model(features), dim=-1).cpu().numpy()[0]
        return {"label": int(np.argmax(probs)), "confidence": float(probs.max())}
# ── Deployment graph: multi-model pipeline ────────────────────────────
@serve.deployment(num_replicas=2, ray_actor_options={"num_cpus": 2})
class Preprocessor:
    def __init__(self):
        self.mean = np.load("/models/feature_mean.npy")
        self.std = np.load("/models/feature_std.npy")

    def preprocess(self, raw_features: list[float]) -> list[float]:
        arr = (np.array(raw_features, dtype=np.float32) - self.mean) / (self.std + 1e-8)
        return arr.tolist()

@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 1})
class Router:
    """Orchestrates multi-stage inference: preprocess → classify."""

    def __init__(self, preprocessor: DeploymentHandle, classifier: DeploymentHandle):
        self.preprocessor = preprocessor
        self.classifier = classifier

    async def __call__(self, request) -> dict:
        body = await request.json()
        cleaned = await self.preprocessor.preprocess.remote(body["raw_features"])
        # handle_batch is @serve.batch-decorated: send one item, get one result back
        result = await self.classifier.handle_batch.remote({"features": cleaned})
        return {"prediction": result, "request_id": body.get("id")}
# ── Bind deployment graph and deploy ─────────────────────────────────
preprocessor = Preprocessor.bind()
classifier = ClassifierDeployment.bind(model_path="/models/classifier.pt")
router = Router.bind(preprocessor, classifier)
serve.run(router, route_prefix="/predict")
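The batching that @serve.batch performs (collect concurrent calls into one list, run once, fan results back out) can be sketched with asyncio. `MiniBatcher` is a hypothetical class, not Ray Serve's implementation, which adds error handling and per-replica queueing. A batch is flushed when it is full or when the wait timeout since its first request expires, whichever comes first.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple


class MiniBatcher:
    """Groups concurrent submit() calls into batches of up to max_batch_size."""

    def __init__(self, handler: Callable[[List[float]], Awaitable[List[float]]],
                 max_batch_size: int, batch_wait_timeout_s: float) -> None:
        self.handler = handler
        self.max_batch_size = max_batch_size
        self.timeout = batch_wait_timeout_s
        self.queue: "asyncio.Queue[Tuple[float, asyncio.Future]]" = asyncio.Queue()
        self.batch_sizes: List[int] = []

    async def submit(self, item: float) -> float:
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((item, fut))
        return await fut  # resolved once this item's batch has run

    async def run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]  # wait for the first request
            deadline = loop.time() + self.timeout
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            results = await self.handler([item for item, _ in batch])
            for (_, fut), result in zip(batch, results):
                fut.set_result(result)


async def double_all(xs: List[float]) -> List[float]:
    """Stand-in for one batched forward pass: one output per input, in order."""
    return [2 * x for x in xs]


async def main() -> Tuple[List[float], List[int]]:
    batcher = MiniBatcher(double_all, max_batch_size=4, batch_wait_timeout_s=0.05)
    worker = asyncio.create_task(batcher.run())
    # Ten "requests" arrive concurrently
    outputs = await asyncio.gather(*(batcher.submit(float(i)) for i in range(10)))
    worker.cancel()
    return list(outputs), batcher.batch_sizes


outputs, sizes = asyncio.run(main())
print(sizes)  # [4, 4, 2]: the last batch flushes on the timeout
```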
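Note: Ray Serve's @serve.batch decorator is one of its most impactful features for GPU efficiency.

- Without batching, each request triggers a separate forward pass, and most of the GPU's compute capacity sits idle.
- With batching, Serve holds incoming requests until max_batch_size requests have arrived or batch_wait_timeout_s seconds have passed, whichever comes first. It then processes them in a single forward pass.
- Under sustained load on a T4 GPU, batching typically increases throughput 10–20x compared to single-request inference. Because queues drain faster, it can also cut per-request latency by 3–5x.
- The trade-off is that an individual request may wait up to batch_wait_timeout_s before its batch runs.

KubeRay — Running Ray on Kubernetes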
KubeRay is the Kubernetes operator for Ray. It introduces three CRDs:

- RayCluster for persistent clusters;
- RayJob for submitting a single job to an ephemeral cluster;
- RayService for Ray Serve deployments with zero-downtime upgrades.

This approach parallels how the Spark Operator manages SparkApplication CRDs, with one difference. A RayCluster is persistent (the head node stays up between jobs), whereas each SparkApplication gets its own ephemeral driver and executors. RayJob gives Ray the same per-job lifecycle.
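Because these CRDs are ordinary Kubernetes objects, manifests can also be generated programmatically, for example to stamp out one RayJob per training run. The minimal sketch below uses only field names that appear in the RayJob YAML that follows. `make_rayjob`, its parameters, and the run name are hypothetical. kubectl accepts JSON as well as YAML, so the printed output can be passed to `kubectl apply -f`.

```python
import json
from typing import Any, Dict


def make_rayjob(name: str, entrypoint: str, image: str, num_gpu_workers: int,
                ray_version: str = "2.10.0", namespace: str = "ml") -> Dict[str, Any]:
    """Build a RayJob manifest as a plain dict (hypothetical helper)."""
    if num_gpu_workers < 1:
        raise ValueError("need at least one GPU worker")
    return {
        "apiVersion": "ray.io/v1",
        "kind": "RayJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "entrypoint": entrypoint,
            "shutdownAfterJobFinishes": True,
            "rayClusterSpec": {
                "rayVersion": ray_version,
                "headGroupSpec": {
                    # Head node schedules only; it runs no tasks
                    "rayStartParams": {"num-cpus": "0"},
                    "template": {"spec": {"containers": [
                        {"name": "ray-head", "image": image},
                    ]}},
                },
                "workerGroupSpecs": [{
                    "groupName": "gpu-workers",
                    "replicas": num_gpu_workers,
                    "rayStartParams": {"num-gpus": "1"},
                    "template": {"spec": {"containers": [{
                        "name": "ray-worker",
                        "image": image,
                        "resources": {"limits": {"nvidia.com/gpu": "1"}},
                    }]}},
                }],
            },
        },
    }


manifest = make_rayjob(
    name="training-run-002",
    entrypoint="python /app/train.py --config /app/config.yaml",
    image="my-registry/ml-training:v1.2.0",
    num_gpu_workers=2,
)
print(json.dumps(manifest, indent=2))
```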
# ── Install KubeRay operator ──────────────────────────────────────────
helm repo add kuberay https://ray-project.github.io/kuberay-helm/
helm repo update
helm install kuberay-operator kuberay/kuberay-operator --namespace kuberay-system --create-namespace --version 1.1.0
# ── RayCluster: persistent cluster for interactive workloads ──────────
# raycluster.yaml
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: ray-ml-cluster
  namespace: ml
spec:
  rayVersion: "2.10.0"
  enableInTreeAutoscaling: true   # required for minReplicas/maxReplicas autoscaling
  headGroupSpec:
    rayStartParams:
      dashboard-host: "0.0.0.0"
      num-cpus: "0"   # head node: no tasks, only scheduling
    template:
      spec:
        containers:
          - name: ray-head
            image: rayproject/ray-ml:2.10.0-gpu
            resources:
              limits:
                cpu: "4"
                memory: "16Gi"
            ports:
              - containerPort: 6379    # GCS port
              - containerPort: 8265    # Dashboard
              - containerPort: 10001   # Client port
  workerGroupSpecs:
    - groupName: gpu-workers
      replicas: 4
      minReplicas: 1
      maxReplicas: 8   # autoscale between 1–8 GPU workers
      rayStartParams:
        num-gpus: "1"
      template:
        spec:
          containers:
            - name: ray-worker
              image: rayproject/ray-ml:2.10.0-gpu
              resources:
                limits:
                  # g4dn.2xlarge has 8 vCPU / 32 GiB; leave headroom for the
                  # kubelet's reservations and DaemonSets so the pod can schedule
                  cpu: "7"
                  memory: "28Gi"
                  nvidia.com/gpu: "1"
              env:
                - name: RAY_worker_register_timeout_seconds
                  value: "120"
          nodeSelector:
            node.kubernetes.io/instance-type: "g4dn.2xlarge"
          tolerations:
            - key: "nvidia.com/gpu"
              operator: "Exists"
              effect: "NoSchedule"
# ── RayJob: ephemeral cluster per job run ─────────────────────────────
# rayjob.yaml
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: training-run-001
  namespace: ml
spec:
  submissionMode: K8sJobMode
  entrypoint: "python /app/train.py --config /app/config.yaml"
  runtimeEnvYaml: |
    pip:
      - torch==2.2.0
      - ray[train]==2.10.0
      - mlflow==2.13.0
    env_vars:
      MLFLOW_TRACKING_URI: "http://mlflow.ml.svc:5000"
      S3_BUCKET: "ml-data"
  shutdownAfterJobFinishes: true   # delete the RayCluster once the job completes...
  ttlSecondsAfterFinished: 3600    # ...after this delay (1 hour); lower it to cut idle GPU cost
  rayClusterSpec:
    rayVersion: "2.10.0"
    headGroupSpec:
      rayStartParams: {num-cpus: "0"}
      template:
        spec:
          containers:
            - name: ray-head
              image: my-registry/ml-training:v1.2.0
              resources:
                limits: {cpu: "4", memory: "16Gi"}
    workerGroupSpecs:
      - groupName: gpu-workers
        replicas: 2
        rayStartParams: {num-gpus: "1"}
        template:
          spec:
            containers:
              - name: ray-worker
                image: my-registry/ml-training:v1.2.0
                resources:
                  limits: {cpu: "8", memory: "32Gi", "nvidia.com/gpu": "1"}
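# ── RayService: serving with zero-downtime upgrade ────────────────────
# rayservice.yaml
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: classifier-service
  namespace: ml
spec:
  serviceUnhealthySecondThreshold: 120
  deploymentUnhealthySecondThreshold: 60
  serveConfigV2: |
    applications:
      - name: classifier
        route_prefix: /predict
        import_path: serve_app:router
        runtime_env:
          working_dir: "s3://ml-apps/serve/v1.3.0.zip"
          pip: ["torch==2.2.0", "numpy==1.26.0"]
        deployments:
          - name: Router
            num_replicas: 1
          - name: ClassifierDeployment
            num_replicas: 2
            ray_actor_options:
              num_gpus: 1
  rayClusterConfig:
    rayVersion: "2.10.0"
    headGroupSpec:
      rayStartParams: {num-cpus: "0"}
      template:
        spec:
          containers:
            - name: ray-head
              image: rayproject/ray-ml:2.10.0-gpu
              resources:
                limits: {cpu: "4", memory: "16Gi"}
    # The head advertises no CPUs or GPUs, so Serve replicas need worker pods
    workerGroupSpecs:
      - groupName: gpu-workers
        replicas: 2
        rayStartParams: {num-gpus: "1"}
        template:
          spec:
            containers:
              - name: ray-worker
                image: rayproject/ray-ml:2.10.0-gpu
                resources:
                  limits: {cpu: "8", memory: "32Gi", "nvidia.com/gpu": "1"}
Integrating Ray with MLflow and Feature Stores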
limits: {cpu: "4", memory: "16Gi"}Integrating Ray with MLflow and Feature Stores
Ray Train and Ray Tune both integrate with MLflow for experiment tracking. Every trial in a Tune run can log metrics, parameters, and checkpoints to an MLflow tracking server, making HPO results searchable alongside baseline experiments. The MLflow autolog feature works inside Ray workers as long as the tracking URI is set as an environment variable on all worker pods.
import ray
import mlflow
from ray import tune
from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow
ray.init()
# ── MLflow logging inside a Ray Tune trainable ────────────────────────
def train_func_with_mlflow(config: dict):
    import mlflow.pytorch
    # setup_mlflow() configures MLflow for this trial (tracking URI, experiment)
    setup_mlflow(
        config,
        experiment_name="ray-distributed-training",
        tracking_uri=config.get("mlflow_uri", "http://mlflow:5000"),
    )
    # autolog() only hooks into PyTorch Lightning's Trainer.fit(); the plain
    # loop below therefore logs its metrics explicitly
    mlflow.pytorch.autolog(log_every_n_epoch=1, checkpoint=False)
    with mlflow.start_run(nested=True):
        mlflow.log_params(config)
        # ... training loop as before ...
        for epoch in range(config["epochs"]):
            loss = run_epoch(config)  # hypothetical helper: runs one epoch, returns its loss
            mlflow.log_metric("train_loss", loss, step=epoch)
            ray.train.report({"loss": loss})
# ── MLflow callback for Ray Tune: log all trial metrics automatically ─
# Combined with setup_mlflow() above, every trial is recorded in two
# experiments; in practice, choose one of the two mechanisms
tuner = tune.Tuner(
    train_func_with_mlflow,
    param_space={
        "lr": tune.loguniform(1e-4, 1e-2),
        "hidden_dim": tune.choice([128, 256]),
        "epochs": tune.choice([10, 20]),
        "mlflow_uri": "http://mlflow.ml.svc:5000",
    },
    tune_config=tune.TuneConfig(num_samples=20),
    run_config=ray.train.RunConfig(
        callbacks=[
            MLflowLoggerCallback(
                tracking_uri="http://mlflow.ml.svc:5000",
                experiment_name="ray-tune-hpo",
                save_artifact=True,  # save trial checkpoints as MLflow artifacts
            )
        ]
    ),
)
results = tuner.fit()
# ── Feature loading from Feast inside Ray tasks ────────────────────────
# Ray + Feast: fetch historical (offline-store) features in parallel, then
# convert the result into a Ray dataset for training
import pandas as pd

def load_features_from_feast(entity_df_path: str) -> pd.DataFrame:
    """Fetch point-in-time correct historical features from Feast for training."""
    from feast import FeatureStore
    store = FeatureStore(repo_path="/feast/feature_repo")
    entity_df = pd.read_parquet(entity_df_path)
    training_df = store.get_historical_features(
        entity_df=entity_df,
        features=[
            "user_features:click_rate_7d",
            "user_features:session_count_30d",
            "item_features:popularity_score",
            "item_features:avg_rating",
        ],
    ).to_df()
    return training_df

# Each partition is fetched by a Ray remote task, in parallel across workers
@ray.remote
def fetch_partition(partition_path: str) -> pd.DataFrame:
    return load_features_from_feast(partition_path)

partitions = [f"s3://data/entities/partition_{i}.parquet" for i in range(50)]
feature_dfs = ray.get([fetch_partition.remote(p) for p in partitions])
full_feature_df = pd.concat(feature_dfs, ignore_index=True)
ray_dataset = ray.data.from_pandas(full_feature_df)
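get_historical_features() performs a point-in-time join. For every entity row, it picks the most recent feature value whose timestamp is at or before that row's event timestamp, so training never sees values from the future. Below is a pure-Python sketch of that rule over a toy in-memory feature log; `point_in_time_lookup` and the data are hypothetical. Feast additionally applies each feature view's TTL and runs the join inside the offline store.

```python
import bisect
from datetime import datetime
from typing import Dict, List, Optional, Tuple

# Toy feature log: entity -> list of (timestamp, value), sorted by timestamp
FeatureLog = Dict[str, List[Tuple[datetime, float]]]


def point_in_time_lookup(log: FeatureLog, entity: str,
                         event_ts: datetime) -> Optional[float]:
    """Latest value at or before event_ts, or None if nothing existed yet."""
    history = log.get(entity, [])
    timestamps = [ts for ts, _ in history]
    idx = bisect.bisect_right(timestamps, event_ts)  # number of entries <= event_ts
    return history[idx - 1][1] if idx > 0 else None


click_rate_7d: FeatureLog = {
    "user_1": [(datetime(2024, 1, 1), 0.10), (datetime(2024, 1, 8), 0.25)],
    "user_2": [(datetime(2024, 1, 5), 0.40)],
}

entity_rows = [
    ("user_1", datetime(2024, 1, 3)),  # between updates: 0.10
    ("user_1", datetime(2024, 1, 8)),  # exactly at an update: 0.25
    ("user_2", datetime(2024, 1, 2)),  # before the first value: None
]
training_rows = [
    (entity, ts, point_in_time_lookup(click_rate_7d, entity, ts))
    for entity, ts in entity_rows
]
print([value for _, _, value in training_rows])  # [0.1, 0.25, None]
```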
Note: Use Feast's get_historical_features() in Ray preprocessing tasks to generate training data. At serving time, create the FeatureStore client in the Ray Serve deployment's __init__, so each request can fetch online features from the Redis online store with get_online_features(). This mirrors the pattern described in the Feast feature retrieval documentation. It eliminates training-serving skew because both paths use the same feature definitions.
Ray Production Checklist
Pin Ray versions across all components. The Ray cluster image, the Python SDK, and all AI library packages (ray[train], ray[tune], ray[serve]) must match exactly. Version skew between head and workers causes hard-to-diagnose serialization and connection failures.
Set num-cpus: '0' on the head node in KubeRay to prevent the scheduler from placing computation tasks there. The head node runs the GCS and the dashboard, and CPU contention on it degrades scheduling latency cluster-wide.
Configure max_ongoing_requests per Ray Serve replica to bound memory usage. If a replica admits more in-flight requests than its model can process, queued requests accumulate in memory, and the replica can run out of memory when an upstream traffic spike overwhelms inference throughput.
Use RayJob (not RayCluster + manual submission) for batch training runs — RayJob creates an ephemeral cluster, submits the job, waits for completion, then cleans up all pods, eliminating idle GPU costs between training runs
Enable Ray Serve autoscaling with a conservative downscale_delay_s (120s or more; the example above uses 60s) to prevent replica thrashing on bursty traffic. Each new replica reloads the model into GPU memory, which takes 30–90 seconds per GPU.
Store Ray Train checkpoints on S3 or GCS (not local disk) by setting storage_path in RunConfig — local disk checkpoints are lost when the head node pod restarts and prevent fault-tolerant resumption
Set RAY_worker_register_timeout_seconds to 120 in worker pod env — the default 30s is too short for nodes pulling large container images (GPU containers are typically 5–15 GB) and causes spurious worker registration failures
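Configure Prometheus scraping for Ray metrics. The Ray head exposes /metrics on port 8080, and each worker exposes its own metrics endpoint. Worker pods come and go and are not fronted by a Service, so use a PodMonitor CRD to discover and scrape them dynamically; a ServiceMonitor can cover the head service.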
Test fault tolerance explicitly: kill a worker pod mid-training and verify that Ray Train resumes from the last checkpoint on the replacement worker — do this in staging before relying on it in production spot instance pools
Use RayService for Ray Serve production deployments instead of serve.run() in a RayCluster — RayService handles zero-downtime rolling upgrades by deploying the new serve config before draining traffic from the old deployment
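The first checklist item can be enforced in CI with a small check over the pins collected from each component. The sketch assumes those pins are available as plain `name==version` strings per source (cluster image, runtime_env, requirements file). The helper names and sources are hypothetical, and extras such as `ray[train]` count as the `ray` package.

```python
import re
from typing import Dict, Iterable, Set

# Package name, optional extras such as [train], then an exact "==" pin
_PIN = re.compile(r"^(?P<name>[A-Za-z0-9_.-]+)(\[[^\]]*\])?==(?P<version>\S+)$")


def ray_versions(pins_by_source: Dict[str, Iterable[str]]) -> Dict[str, Set[str]]:
    """Collect every pinned Ray version per source."""
    found: Dict[str, Set[str]] = {}
    for source, pins in pins_by_source.items():
        versions: Set[str] = set()
        for pin in pins:
            match = _PIN.match(pin.strip())
            if match and match.group("name").lower() == "ray":
                versions.add(match.group("version"))
        found[source] = versions
    return found


def check_ray_pins(pins_by_source: Dict[str, Iterable[str]]) -> str:
    """Return the single Ray version in use, or raise ValueError on skew."""
    found = ray_versions(pins_by_source)
    if not found:
        raise ValueError("no sources given")
    missing = sorted(source for source, versions in found.items() if not versions)
    if missing:
        raise ValueError(f"no Ray pin found in: {missing}")
    all_versions = set().union(*found.values())
    if len(all_versions) > 1:
        raise ValueError(f"Ray version skew across components: {found}")
    return all_versions.pop()


pins = {
    "image": ["ray==2.10.0"],  # hypothetical: read from the cluster image's metadata
    "runtime_env": ["torch==2.2.0", "ray[train]==2.10.0", "mlflow==2.13.0"],
    "requirements.txt": ["ray[tune]==2.10.0", "ray[serve]==2.10.0"],
}
print(check_ray_pins(pins))  # 2.10.0
```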
For related patterns: fine-tuning smaller open-source models on a single GPU is covered in Fine-Tuning Open Models with LoRA and QLoRA — Ray Train scales those same techniques to multi-node clusters when your dataset or model exceeds single-GPU capacity.
Hitting single-machine limits on ML training or hyperparameter search, managing separate Spark, training, and serving clusters, or looking to consolidate your ML infrastructure on Kubernetes?
We design and implement distributed ML platforms with Ray — from Ray Core actor design and Ray Data preprocessing pipeline configuration to Ray Train multi-GPU training setup with fault-tolerant S3 checkpointing, Ray Tune HPO with ASHA and Optuna search, MLflow integration for experiment tracking across all trials, Ray Serve deployment graph design with request batching and autoscaling, KubeRay operator installation with RayCluster and RayJob CRD configuration for ephemeral training jobs, RayService rolling upgrade configuration, GPU node pool setup with spot instance toleration and Karpenter autoscaler integration, and Prometheus metrics scraping for cluster observability. Let’s talk.