As presented in my post Multimodal reranking, I decided to give DPO a try for fine-tuning a multimodal embedding model that had originally been trained with a contrastive loss. My approach is mainly inspired by how LLMs are fine-tuned with DPO. The reason for taking this path is that there is no need for a reward model: the loss is computed directly from the preferences in the dataset.
The model is fed triplets: a text query and a pair of pictures. Along with the two pictures come two scores, one per picture, ranging from 0 to 100, indicating how well each picture matches the text query (as I wrote in the other post, the scores were generated by an LLM). For example:
"image";"query";"score"
"1268119946.jpg";"A large living room with lots of light and a wooden table in the middle";"70"
"1221416997.jpg";"A large living room with lots of light and a wooden table in the middle";"85"
The main idea is that the model should eventually learn to prefer 1221416997.jpg over 1268119946.jpg; whenever it does not, back-propagating the loss corrects it. In more formal terms, the loss is a bit trickier than that.
Thus, given a preferred image $I^+$ and a less-preferred image $I^-$ for query $q$,

$$
\mathcal{L} \;=\; -\log \pi_\theta \;+\; \lambda\,\mathrm{KL}\!\left(\pi_{\mathrm{ref}} \,\Vert\, \pi_\theta\right),
\qquad
\pi_\theta = \sigma\!\left(\beta\,\big[s_\theta(q, I^+) - s_\theta(q, I^-)\big]\right),
\qquad
\pi_{\mathrm{ref}} = \sigma\!\left(\beta\,\big[s_{\mathrm{ref}}(q, I^+) - s_{\mathrm{ref}}(q, I^-)\big]\right)
$$

where $s_{\mathrm{ref}}$ denotes similarities from the frozen reference model, $s_\theta$ those from the model being tuned, $\sigma$ is the sigmoid, and $\beta$ and $\lambda$ are hyperparameters. KL is the Kullback-Leibler divergence between the preference distributions of the original and fine-tuned models.
Let's turn it into a piece of code.
For the layperson, what this means is that the loss is larger when the current model prefers the wrong picture, and the more confidently it prefers it, the larger the loss; the per-pair score gap from the dataset is passed along as gap in the code below. On top of that, a regularization term (the KL term) keeps the tuned model from drifting too far from the original. The two parts of the loss can be balanced with the $\beta$ and $\lambda$ hyperparameters.
def dpo_loss(s_pos, s_neg, s_pos_ref, s_neg_ref, gap, beta, kl_lambda):
    # gap: per-pair score gap from the data loader (shared signature with grpo_loss)
    # Preference probabilities from the tuned model and the frozen reference
    pi = torch.sigmoid(beta * (s_pos - s_neg))
    ref = torch.sigmoid(beta * (s_pos_ref - s_neg_ref))
    # DPO loss term: negative log-likelihood of the preference
    dpo = -torch.log(pi + 1e-12)
    # KL divergence term for regularization
    kl = kl_lambda * (ref * torch.log((ref + 1e-12) / (pi + 1e-12))
                      + (1 - ref) * torch.log((1 - ref + 1e-12) / (1 - pi + 1e-12)))
    loss = (dpo + kl).mean()
    return loss
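As a quick sanity check, here is the loss on a couple of made-up pairs (the similarity values and gaps below are invented, and I assume the score gaps have been normalized to [0, 1]). The first pair is already ranked correctly while the second is misranked, and it is the misranked pair that dominates the resulting loss:

import torch

# Made-up similarities: the first pair is ranked correctly, the second is not
s_pos     = torch.tensor([0.32, 0.05])   # tuned model, preferred images
s_neg     = torch.tensor([0.10, 0.25])   # tuned model, less-preferred images
s_pos_ref = torch.tensor([0.30, 0.12])   # frozen reference, preferred images
s_neg_ref = torch.tensor([0.12, 0.20])   # frozen reference, less-preferred images
gap       = torch.tensor([0.15, 0.15])   # score gaps from the dataset

print(dpo_loss(s_pos, s_neg, s_pos_ref, s_neg_ref, gap, beta=10.0, kl_lambda=0.1))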
Training the model is quite straightforward. Note that we will use two CLIP models: a frozen one (model_ref) and the one being tuned (model):
# Load the model being fine-tuned and an identical frozen copy used as reference
model, preprocess, _ = open_clip.create_model_and_transforms(
    'xlm-roberta-large-ViT-H-14', pretrained='frozen_laion5b_s13b_b90k')
model_ref = open_clip.create_model(
    'xlm-roberta-large-ViT-H-14', pretrained='frozen_laion5b_s13b_b90k')
...
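# The elided setup is not reproduced in the post. The loop below only relies on
# model_ref being frozen and on a clip_similarity helper returning one cosine
# similarity per query/image pair. The lines below are my own sketch of those
# two pieces (the helper's body is an assumption, not the original code).
tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')

# Freeze the reference model so only `model` receives gradient updates
model_ref.eval()
for p in model_ref.parameters():
    p.requires_grad_(False)

def clip_similarity(clip_model, tokens, images):
    # Encode queries and images, L2-normalize, and take the per-pair cosine similarity
    text_feat = clip_model.encode_text(tokens)
    img_feat = clip_model.encode_image(images)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    return (text_feat * img_feat).sum(dim=-1)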
# -------------------------------
# Training Loop
# -------------------------------
print(f"Starting {loss_function.upper()} fine-tuning...")
last_avg_loss = 1e+10
patience_steps = 0
loss_fn = dpo_loss if args.mode == "dpo" else grpo_loss
for epoch in range(args.epochs):
    gap_values = []
    start_time = time.time()
    epoch_loss = 0.0
    n_batches = 0
    pbar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
    for text_batch, pos_imgs, neg_imgs, gap_weights in pbar:
        # Tokenize the batch of queries
        tokens = tokenizer(list(text_batch)).to(device)
        pos_imgs = torch.stack(pos_imgs).to(device)
        neg_imgs = torch.stack(neg_imgs).to(device)
        gap_weights = torch.tensor(gap_weights, dtype=torch.float32).to(device)
        # Forward pass: compute similarities for policy model
        s_pos = clip_similarity(model, tokens, pos_imgs)
        s_neg = clip_similarity(model, tokens, neg_imgs)
        with torch.no_grad():
            s_pos_ref = clip_similarity(model_ref, tokens, pos_imgs)
            s_neg_ref = clip_similarity(model_ref, tokens, neg_imgs)
        loss = loss_fn(s_pos, s_neg, s_pos_ref, s_neg_ref, gap_weights,
                       beta=args.beta, kl_lambda=args.kl_lambda)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
        n_batches += 1
        avg_loss = epoch_loss / n_batches
        pbar.set_postfix(loss=f"{loss.item():.4f}", avg=f"{avg_loss:.4f}")
    end_time = time.time()
    loss_reduction = last_avg_loss - avg_loss
    print(f"Epoch {epoch+1}: Average Loss = {avg_loss:.4f}, Loss Reduction = {loss_reduction:.6f}")
    add_log_entry(model_file_log, epoch, epoch_loss, data_loader, sample_size, start_time, end_time)
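    # last_avg_loss and patience_steps are initialized earlier, which suggests the epoch
    # ends with an early-stopping check. The lines below are my own sketch of that
    # epilogue, not code from the post; args.patience and args.min_delta are hypothetical.
    if loss_reduction < args.min_delta:
        patience_steps += 1
        if patience_steps >= args.patience:
            print("Early stopping: average loss no longer improving")
            break
    else:
        patience_steps = 0
    last_avg_loss = avg_loss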