EmbeddingMetric
EmbeddingMetric 跟踪对比学习(InfoNCE)训练中的嵌入质量,报告锚点-正样本余弦相似度和批内负样本相似度。
使用方法
from twinkle.metric import EmbeddingMetric
metric = EmbeddingMetric(device_mesh=device_mesh, process_group=process_group)
# 训练中
metric.accumulate(inputs, outputs)
# 日志间隔时
results = metric.calculate()
# results: {'pos_sim': '0.8523', 'neg_sim': '0.2134', 'loss': '0.3412', ...}
输出指标
| 指标 | 说明 |
|---|---|
pos_sim | 锚点与正样本的平均余弦相似度 |
pos_sim_min | 批内最小正样本相似度 |
pos_sim_max | 批内最大正样本相似度 |
neg_sim | 锚点与其他正样本(批内负样本)的平均相似度 |
loss | 平均对比损失值 |
grad_norm | 梯度范数 |
此指标与
InfonceLoss配合使用,适用于嵌入/检索模型训练。