
Fine-Tuning Sparse Embeddings for E-Commerce Search | Part 2: Training SPLADE on Modal

Thierry Damiba

·

March 09, 2026

This is Part 2 of a 5-part series on fine-tuning sparse embeddings for e-commerce search. In Part 1, we covered why sparse embeddings beat BM25 for e-commerce. Now we build the training pipeline.

In the last article we made the case for sparse embeddings in e-commerce search. Now we write the code. All source code is available in the GitHub repo, and you can try the fine-tuned models on HuggingFace. Want to skip straight to fine-tuning on your own data? See the sparse-finetune CLI. By the end of this piece, you’ll have a SPLADE model trained on Amazon’s ESCI dataset, running on Modal’s serverless GPUs, with checkpoints saved to persistent storage.

The Dataset: Amazon ESCI

We use Amazon’s ESCI dataset (Shopping Queries Dataset), released for KDD Cup 2022. It’s one of the most realistic e-commerce search benchmarks available:

  • 1.2M+ query-product pairs with human-annotated relevance labels
  • Four relevance grades: Exact (E), Substitute (S), Complement (C), Irrelevant (I)
  • Rich product metadata: titles, descriptions, bullet points, brands

The graded relevance is what makes ESCI interesting:

ESCI relevance gradient from Exact to Irrelevant

For training, we use Exact and Substitute pairs as positives. This teaches the model that both the exact product and reasonable alternatives are relevant, matching how real shoppers think.

Loading the Data

from datasets import load_dataset
from src.data.text_builder import build_product_text

def load_esci_training_data(max_samples=None):
    """Load ESCI dataset as anchor-positive pairs for contrastive training."""
    dataset = load_dataset("tasksource/esci", split="train")

    pairs = []
    for row in dataset:
        if row["relevance_label"] not in ("E", "S"):
            continue

        query = row["query"]
        product_text = build_product_text(
            title=row["product_title"],
            brand=row.get("product_brand", ""),
            description=row.get("product_description", ""),
            bullets=row.get("product_bullet_point", []),
        )
        pairs.append({"anchor": query, "positive": product_text})

        if max_samples and len(pairs) >= max_samples:
            break

    return pairs

Product Text Formatting

How you format product text matters for sparse embeddings. Unlike dense models that capture broad semantic meaning, SPLADE is lexically grounded: the specific tokens in your text determine which vocabulary dimensions activate:

def build_product_text(title, brand="", description="", bullets=None, max_length=512):
    """Consistent product text formatting for SPLADE."""
    parts = []

    # Brand in brackets makes it a distinct signal
    if brand:
        parts.append(f"[{brand}]")

    parts.append(title)

    # Pipe separators help the model distinguish sections
    if description:
        parts.append(f"| {description[:200]}")

    if bullets:
        parts.append(f"| {' | '.join(bullets[:3])}")

    text = " ".join(parts)
    return text[:max_length]

# Example output:
# "[Sony] WH-1000XM5 Wireless Headphones | Industry-leading noise
#  cancellation | 30hr battery | Hi-Res Audio"

The bracket notation for brands, pipe separators between sections, and character limits are deliberate. They preserve lexical signals that SPLADE can learn from: brand names, product attributes, and key features remain as distinct tokens rather than blurring into a wall of text.

Training stack: Modal for GPU compute, Sentence Transformers for training, Qdrant for evaluation

Setting Up the Modal App

Modal gives us serverless GPUs. No provisioning, no idle hardware, pay-per-second billing. Here’s the app configuration:

import modal

app = modal.App("esci-sparse-encoder")

# Persistent storage for checkpoints and datasets
checkpoint_volume = modal.Volume.from_name(
    "esci-sparse-checkpoints", create_if_missing=True
)
dataset_volume = modal.Volume.from_name(
    "esci-datasets", create_if_missing=True
)

# Docker image with dependencies
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "sentence-transformers>=5.0.0",
        "torch>=2.2.0",
        "transformers>=4.45.0",
        "datasets>=2.20.0",
        "qdrant-client>=1.12.0",
        "accelerate>=0.30.0",
    )
)

Two things matter here:

Persistent volumes. Training runs can take hours. If your SSH connection drops or a container restarts, you don’t want to lose checkpoints. Modal volumes persist data across runs. Mount them at a path and write to them like a local filesystem.

Detached runs. For long training jobs, launch with --detach and walk away:

# Start training and disconnect
uv run modal run --detach modal_app.py --mode train

# Come back later, check your checkpoints
uv run modal volume ls esci-sparse-checkpoints /checkpoints/

No S3 uploads, no checkpoint management code, no lost training runs.

Creating the SPLADE Model

Sentence Transformers v5 introduced SparseEncoder, making SPLADE training straightforward. The model has two components:

  1. MLMTransformer: A transformer with a masked language model head that outputs logits over the full vocabulary
  2. SpladePooling: Max-pools the token-level logits and applies ReLU + log saturation

from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.models import (
    MLMTransformer,
    SpladePooling,
)

def create_sparse_encoder(base_model="distilbert/distilbert-base-uncased"):
    """Create a SPLADE model from a base transformer."""

    # MLM transformer outputs logits over vocabulary
    mlm = MLMTransformer(base_model)

    # SPLADE pooling: max over tokens, ReLU activation
    pooling = SpladePooling(pooling_strategy="max")

    return SparseEncoder(modules=[mlm, pooling])
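
Under the hood, the pooling step computes something like the following. This is a simplified sketch of the SPLADE activation (log-saturated ReLU, then max over token positions), not the library's exact implementation:

```python
import torch

def splade_pool(logits: torch.Tensor) -> torch.Tensor:
    """Simplified SPLADE pooling: ReLU + log saturation, then max over tokens."""
    # logits: (batch, seq_len, vocab_size) from the MLM head
    saturated = torch.log1p(torch.relu(logits))  # dampen very large logits
    return saturated.max(dim=1).values           # (batch, vocab_size)

# Fake MLM logits for a batch of 2 sequences, 16 tokens, 30522-term vocab
vecs = splade_pool(torch.randn(2, 16, 30522))
print(vecs.shape)  # torch.Size([2, 30522])
```

The max over tokens is what lets any token in the input "vote" for a vocabulary term, while the log saturation keeps a single confident token from dominating the whole vector.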

We start from DistilBERT rather than a pre-trained SPLADE checkpoint (like naver/splade-v3). This is a deliberate choice. We want to measure how much domain-specific fine-tuning helps when starting from a general language model, not from a model already trained on web search data.

The Training Function

Here’s the core training logic, decorated as a Modal function:

@app.function(
    image=image,
    gpu="A100",
    volumes={
        "/checkpoints": checkpoint_volume,
        "/datasets": dataset_volume,
    },
    timeout=3600 * 6,
)
def train_sparse_encoder(config: dict):
    from sentence_transformers.sparse_encoder import (
        SparseEncoder,
        SparseEncoderTrainer,
        SparseEncoderTrainingArguments,
    )
    from sentence_transformers.sparse_encoder.losses import (
        SpladeLoss,
        SparseMultipleNegativesRankingLoss,
    )

    # Create model
    model = create_sparse_encoder(config["base_model"])

    # Load ESCI pairs and wrap them in a HF Dataset, which the trainer expects
    from datasets import Dataset
    train_dataset = Dataset.from_list(
        load_esci_training_data(max_samples=config.get("max_samples"))
    )

    # SPLADE loss combines contrastive learning with sparsity regularization
    loss = SpladeLoss(
        model=model,
        loss=SparseMultipleNegativesRankingLoss(model=model),
        query_regularizer_weight=float(config.get("query_regularizer_weight", 5e-5)),
        document_regularizer_weight=float(config.get("document_regularizer_weight", 3e-5)),
    )

    # Training arguments
    args = SparseEncoderTrainingArguments(
        output_dir=f"/checkpoints/{config['run_name']}",
        num_train_epochs=config.get("num_epochs", 1),
        per_device_train_batch_size=config.get("batch_size", 32),
        learning_rate=float(config.get("learning_rate", 2e-5)),
        warmup_ratio=0.1,
        fp16=True,
        save_steps=1000,
        logging_steps=100,
    )

    # Train
    trainer = SparseEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # Save final model
    model.save_pretrained(f"/checkpoints/{config['run_name']}/final")

    return f"/checkpoints/{config['run_name']}/final"

Understanding SpladeLoss

SpladeLoss wraps two objectives:

Contrastive loss (SparseMultipleNegativesRankingLoss): Given a batch of (query, product) pairs, treat other products in the batch as negatives. Push relevant query-product pairs together, push irrelevant ones apart. This is the same in-batch negative approach used for dense embedding training, and it works because most random products are irrelevant to a given query.
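
The objective reduces to cross-entropy over a batch-by-batch similarity matrix. Here's a minimal InfoNCE-style sketch of the idea (not the library's exact code):

```python
import torch
import torch.nn.functional as F

def in_batch_negatives_loss(q: torch.Tensor, d: torch.Tensor, scale: float = 20.0):
    """Row i of `d` is the positive for query i; every other row is a negative."""
    scores = q @ d.T * scale          # (batch, batch) similarity matrix
    labels = torch.arange(q.size(0))  # positives sit on the diagonal
    return F.cross_entropy(scores, labels)

loss = in_batch_negatives_loss(torch.randn(8, 64), torch.randn(8, 64))
```

Larger batches give more negatives per query for free, which is one reason batch size matters for this loss.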

Sparsity regularization: Penalizes dense outputs to maintain efficiency. Without it, the model would activate all 30,000 vocabulary dimensions for every input. That’s technically optimal for matching but useless for retrieval speed and storage.

The regularization weights control this tradeoff:

Parameter                      Value   Effect
query_regularizer_weight       5e-5    Higher = sparser queries
document_regularizer_weight    3e-5    Higher = sparser documents

The sweet spot is 100-300 active terms per vector. Set the weights too high and you get nearly empty vectors (fast but low recall); too low and you get thousands of active terms (slow queries, a huge index).

Document regularization is lower than query regularization because product descriptions need more terms to capture all relevant attributes. A product listing for headphones should activate terms like “audio”, “wireless”, “bluetooth”, “noise”, “canceling” - more than the 3-4 words in a typical query.
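
In simplified form, the sparsity penalty is a FLOPS-style regularizer: the squared mean activation of each vocabulary dimension, summed over the vocabulary. A rough sketch (not SpladeLoss's exact internals), along with the active-term count you'd watch while tuning:

```python
import torch

def flops_penalty(activations: torch.Tensor) -> torch.Tensor:
    """FLOPS-style sparsity penalty over a batch of SPLADE vectors."""
    # activations: (batch, vocab_size), non-negative SPLADE outputs.
    # Penalizing the squared per-dimension mean pushes the model to
    # concentrate mass on few vocabulary dimensions per input.
    return torch.sum(torch.mean(activations, dim=0) ** 2)

batch = torch.relu(torch.randn(8, 30522))   # stand-in sparse outputs
penalty = flops_penalty(batch)
active_terms = (batch > 0).sum(dim=1)       # active terms per vector
```

Multiplying this penalty by the query and document weights from the table above and adding it to the contrastive loss is, conceptually, what SpladeLoss does.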

Configuration via YAML

We keep hyperparameters in YAML files for easy experimentation:

# configs/splade_standard.yaml
run_name: splade_standard
base_model: distilbert/distilbert-base-uncased
architecture: splade
batch_size: 32
learning_rate: 2e-5
num_epochs: 1
query_regularizer_weight: 5e-5
document_regularizer_weight: 3e-5
max_samples: 100000

100K samples trains in about 6 minutes on an A100 and costs less than $1 on Modal. The full 1.2M dataset with multiple epochs takes a few hours, still cheap compared to reserved GPU instances.

Parallel Hyperparameter Sweeps

One of Modal’s strengths is embarrassingly parallel workloads. Hyperparameter sweeps are a natural fit. spawn() launches one GPU per configuration:

@app.function(gpu="A100")
def train_single_experiment(config: dict):
    """Train one configuration."""
    model = create_sparse_encoder(config["base_model"])
    # ... training code ...
    return {"config": config, "ndcg": evaluate(model)}

@app.local_entrypoint()
def run_hyperparameter_sweep():
    """Launch all experiments in parallel."""
    configs = [
        {"learning_rate": 1e-5, "regularizer_weight": 3e-5},
        {"learning_rate": 2e-5, "regularizer_weight": 3e-5},
        {"learning_rate": 2e-5, "regularizer_weight": 5e-5},
        {"learning_rate": 5e-5, "regularizer_weight": 5e-5},
        # ... more configurations ...
    ]

    # Launch all experiments simultaneously
    handles = [train_single_experiment.spawn(c) for c in configs]

    # Collect results as they complete
    results = [h.get() for h in handles]
    best = max(results, key=lambda r: r["ndcg"])
    print(f"Best config: {best}")

A 24-experiment sweep finishes in the time of a single training run. Each experiment gets its own A100. You pay only for the compute time actually used, not for idle GPUs waiting in a queue.

What NOT to Do: The Inference-Free SPLADE Trap

We tried replacing the query-side transformer with a static embedding lookup to save latency. The idea is appealing: queries are short, so why run a full transformer?

# DON'T DO THIS (for e-commerce)
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.models import (
    MLMTransformer,
    SparseStaticEmbedding,
    SpladePooling,
)

mlm = MLMTransformer("distilbert/distilbert-base-uncased")

router = Router.for_query_document(
    query_modules=[
        SparseStaticEmbedding(tokenizer=mlm.tokenizer)  # Fast but weak
    ],
    document_modules=[
        mlm,
        SpladePooling(pooling_strategy="max"),
    ],
)

The results were disastrous:

Architecture                      nDCG@10
Standard SPLADE (contextual)      0.389
Inference-Free (static)           0.065

That’s 6x worse without contextual encoding.

The static embedding completely failed because e-commerce queries are highly contextual. “Apple” means different things in “apple iphone” vs “apple fruit”. The static embedding can’t disambiguate. It looks up “apple” and returns the same vector regardless of context.
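
A toy illustration of the failure mode, with hypothetical two-dimensional vectors purely to show the mechanics:

```python
# Each token maps to one fixed vector, regardless of its neighbors.
static_vocab = {
    "apple": [1.0, 0.0],
    "iphone": [0.0, 1.0],
    "fruit": [0.5, 0.5],
}

def static_encode(query: str) -> list[float]:
    """Sum of per-token vectors: no context, so no disambiguation."""
    vec = [0.0, 0.0]
    for tok in query.split():
        for i, v in enumerate(static_vocab.get(tok, [0.0, 0.0])):
            vec[i] += v
    return vec

# "apple" contributes the identical vector to both queries:
v_tech = static_encode("apple iphone")  # [1.0, 1.0]
v_food = static_encode("apple fruit")   # [1.5, 0.5]
```

The "apple" component is frozen; only the surrounding tokens differ. A contextual encoder, by contrast, produces a different representation for "apple" in each query.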

The transformer is the bottleneck at ~15ms per query, but 15ms is perfectly acceptable for search. Don’t prematurely optimize away the component that makes the model work.

Modal detached training with persistent volumes

Running Training

With everything in place, launch training:

# Quick test run (100K samples)
uv run modal run modal_app.py \
    --config-path configs/splade_standard.yaml \
    --mode train

# Full dataset, detached
uv run modal run --detach modal_app.py \
    --config-path configs/splade_standard.yaml \
    --mode train

The model checkpoint gets saved to the persistent volume at /checkpoints/splade_standard/final. We’ve also published the trained model on HuggingFace as splade-ecommerce-esci so you can skip training and use it directly. In the next article, we’ll load this model, index products into Qdrant, and run retrieval benchmarks to see exactly how much we’ve improved over BM25.

Key Takeaways

  • ESCI’s graded relevance (Exact, Substitute, Complement, Irrelevant) teaches the model nuanced matching, not just binary relevant/not-relevant.
  • Product text formatting matters for sparse models. Keep lexical signals distinct with structured formatting.
  • SpladeLoss balances two objectives: contrastive learning for relevance and regularization for sparsity. The regularization weights are the main knob to tune.
  • Modal’s persistent volumes solve the checkpoint management problem. Detached runs survive SSH drops.
  • Don’t skip the query transformer. The 15ms of latency buys you a 6x quality improvement over static embeddings.

Next: Part 3 - Evaluation, Hard Negatives, and Results
