
D-CAT 😼: Decoupled Cross-Attention Knowledge Transfer between Sensor Modalities for Unimodal Inference

This paper was accepted to ICRA 2026.

PDF, cite us, website

Train with Multiple Sensors, Deploy with One

D-CAT

You've built a high-performing model using camera, audio, and IMU data. Great! But when it's time to deploy, your industrial partners hit a roadblock: scaling to 60,000 sensors isn't feasible due to cost and complexity.

What if you could train with multimodal data but deploy with just one sensor, slashing costs and complexity without sacrificing performance?

Meet D-CAT: A framework for multimodal training with decoupled inference thanks to a novel cross-attention loss. Train robustly, deploy efficiently.

Installation

Install with uv:

uv sync --all-extras

Using the Cross-Attention Loss in Your Own Model

The cross_attention_loss function enables cross-modal knowledge transfer by aligning attention scores between different modalities. This is particularly useful for training robust unimodal models from multimodal supervision.

How to Use

The loss function compares attention patterns between two modalities by:

  1. Computing attention scores: K^T @ V where K=keys, V=values
  2. Normalizing the scores
  3. Computing the Frobenius norm between aligned attention patterns
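The three steps above can be sketched as follows. This is an illustrative implementation, not the exact `d_cat` code: the function name `attention_score_alignment`, the per-sample normalization, and the assumed `(batch, seq_len, hidden_dim)` shapes are assumptions for the sketch.

```python
import torch

def attention_score_alignment(kqv_a, kqv_b, eps=1e-8):
    """Illustrative sketch of the cross-attention alignment steps."""
    (k_a, _, v_a) = kqv_a
    (k_b, _, v_b) = kqv_b
    # 1. Attention scores: K^T @ V -> (batch, hidden_dim, hidden_dim),
    #    independent of each modality's sequence length
    s_a = k_a.transpose(-2, -1) @ v_a
    s_b = k_b.transpose(-2, -1) @ v_b
    # 2. Normalize per sample so the comparison is scale-agnostic
    s_a = s_a / (s_a.norm(dim=(-2, -1), keepdim=True) + eps)
    s_b = s_b / (s_b.norm(dim=(-2, -1), keepdim=True) + eps)
    # 3. Frobenius norm of the difference, averaged over the batch
    return (s_a - s_b).norm(dim=(-2, -1)).mean()
```

Note that because the scores are `hidden_dim x hidden_dim`, the two modalities may have different sequence lengths (e.g. audio frames vs. image patches) as long as they share a hidden dimension.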

Import the Loss Function

from d_cat.cross_attention_loss import cross_attention_loss

Example Usage

Assuming you have two models for different modalities (e.g., audio and image) that output key, query, value tensors:

import torch
from your_model import YourAudioModel, YourImageModel

# Initialize models
audio_model = YourAudioModel()
image_model = YourImageModel()

# Get data (batch_size, seq_len, hidden_dim)
audio_input = torch.randn(32, 100, 64)
image_input = torch.randn(32, 16, 16, 64)  # 16x16 patches, flattened to (32, 256, 64) before attention

# Forward pass to get KQV tuples (key, query, value)
audio_logits, (k_audio, q_audio, v_audio) = audio_model(audio_input)
image_logits, (k_image, q_image, v_image) = image_model(image_input)

# Compute loss - optional mask for better alignment
# Mask can be used to only align samples where one modality's predictions are correct
good_predictions = torch.ones(32, dtype=torch.bool)  # example: all correct
audio_loss = cross_attention_loss(
    (k_audio, q_audio, v_audio),
    (k_image, q_image, v_image),
    mask=good_predictions
)

# Combine with your main task loss (classification_loss computed separately)
total_loss = classification_loss + 0.1 * audio_loss  # weight the cross-attention loss

Key Parameters:

  • modality_a_kqv: Tuple of (keys, queries, values) for modality A
  • modality_b_kqv: Tuple of (keys, queries, values) for modality B
  • mask: Optional boolean tensor (shape: [batch_size]) to select which samples to align. Only samples where mask=True are used in the loss computation.

Integration Tips:

  1. Use the loss during training to align modality-specific attention mechanisms
  2. Apply the loss conditionally (e.g., only when one modality performs well)
  3. Tune the weight (cross_attention_weight) to balance between alignment and task-specific objectives
  4. Normalization ensures the loss is dimension-agnostic across different model architectures
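Integration tip 2 (applying the loss conditionally) can be sketched by building the `mask` argument from the "teacher" modality's predictions, so only samples that modality classifies correctly are aligned. The helper name `correct_prediction_mask` is hypothetical; only the resulting boolean tensor matters.

```python
import torch

def correct_prediction_mask(logits, labels):
    """Boolean mask: True where the teacher modality's prediction is correct."""
    return logits.argmax(dim=-1) == labels

# Toy example: 3 samples, 2 classes, all ground-truth labels are class 0
logits = torch.tensor([[2.0, 0.1],   # predicts class 0 -> correct
                       [0.2, 1.5],   # predicts class 1 -> wrong
                       [3.0, 0.5]])  # predicts class 0 -> correct
labels = torch.tensor([0, 0, 0])
mask = correct_prediction_mask(logits, labels)  # tensor([True, False, True])
```

The resulting tensor can be passed directly as the `mask` argument of `cross_attention_loss`.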

Run The Experiments

D-CAT provides complete training pipelines organized by dataset and modality transfer. The entry points are in the scripts/ directory.

Available Datasets

D-CAT supports the following standard datasets for multimodal learning:

Dataset    Modalities      Scripts Directory
VGGSound   Audio + Image   scripts/vggsound/
UESTC      Image + IMU     scripts/uestc/
Dryad      Audio + IMU     scripts/dryad/

Running Experiments

Each dataset has dedicated training scripts:

VGGSound Experiments

Audio-to-Image Transfer:

python scripts/vggsound/train_vgg_audio_to_image_model.py \
    --data_path /your/vggsound/data/path \
    --audio_model_path mlflow://path/to/audio/mlflow/model \
    --model_param.batch_size 32 \
    --model_param.learning_rate 0.001 \
    --model_param.max_epoch 50

Image-to-Audio Transfer:

python scripts/vggsound/train_vgg_image_to_audio_model.py \
    --data_path /your/vggsound/data/path \
    --image_model_path mlflow://path/to/image/mlflow/model \
    --model_param.batch_size 32 \
    --model_param.learning_rate 0.001 \
    --model_param.max_epoch 50

UESTC Dataset

IMU-to-Image Transfer:

python scripts/uestc/UESTC_imu_kqv_training.py \
    --data_root /path/to/UESTC \
    --batch_size 64 \
    --epochs 100 \
    --lr 0.01

Dryad Dataset

Audio+IMU multimodal training:

python scripts/dryad/train_dryad_audio_kqv_model.py \
    --data_root /path/to/Dryad \
    --batch_size 256 \
    --epochs 200 \
    --lr 0.0001

Command Line Parameters

Each script accepts configuration through command-line arguments. Use --help to see full options:

python scripts/vggsound/train_vgg_audio_to_image_model.py -- --help

Dataset Structure

D-CAT expects datasets structured with separate folders for each modality and splits:

vggsound/
├── original_data/
│   ├── mp4/           # video frames (for image modality)
│   └── wav/           # audio files (for audio modality)

Reproducing Paper Results

To reproduce ICRA 2026 paper results:

# Multi-modal training on Dryad dataset
python scripts/dryad/train_dryad_audio_kqv_model.py \
    --data_root /data/Dryad \
    --audio_len_per_window 0.5 \
    --batch_size 256 \
    --epochs 200 \
    --lr 0.0001

After training, you'll have:

  • A multimodal model that can make predictions using audio, IMU, or both
  • A unimodal model optimized for deployment (via positive transfer)
  • Performance metrics logged via MLflow and TensorBoard

Contribute

We welcome contributions!

The code in this repo is formatted and linted with Ruff. Code should follow PEP 8 and be type-annotated.