Contrasting Multiple Representations with the Multi-Marginal Matching Gap

Zoe Piran*1,2, Michal Klein*1, James Thornton1, Marco Cuturi1
1Apple, 2The Hebrew University of Jerusalem
ICML 2024

*Indicates Equal Contribution

Abstract

Learning meaningful representations of complex objects that can be seen through multiple ($k\geq 3$) views or modalities is a core task in machine learning. Existing methods extend the InfoNCE loss, typically considered for paired views ($k=2$), either by instantiating $\tfrac12k(k-1)$ InfoNCE pairs, or by using reduced embeddings, following a one vs. average-of-rest strategy. We propose the multi-marginal matching gap ($\rm{M3G}$), a radically different loss that borrows tools from multi-marginal optimal transport (MM-OT) theory. Given $n$ points, each seen as a $k$-tuple of embeddings, our loss contrasts the cost of matching these $n$ ground-truth $k$-tuples with the MM-OT polymatching cost, which seeks $n$ optimally arranged $k$-tuples chosen within these $n\times k$ vectors. While the exponential complexity $O(n^k)$ of the MM-OT problem may seem daunting, our experiments show that the MM-Sinkhorn algorithm can be run for $k=3\sim 6$ views using mini-batches of size $64\sim 128$. Additionally, thanks to Danskin's theorem, computing the gradient of the $\rm{M3G}$ loss does not require backpropagating through the Sinkhorn iterations. Our experiments demonstrate performance improvements over multiview extensions of InfoNCE, for both self-supervised and multimodal tasks.
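
To make the recipe above concrete, here is a minimal, illustrative JAX sketch of the $\rm{M3G}$ loss. It assumes negative inner products between unit-norm embeddings as the pairwise ground cost and uniform marginal weights; the names cost_tensor, mm_sinkhorn and m3g_loss, as well as the default values of epsilon and num_iters, are our own choices for illustration and are not taken from the authors' implementation. The sketch builds the dense $O(n^k)$ cost tensor, runs a multi-marginal Sinkhorn loop, and uses stop-gradients so that, per Danskin's theorem, no backward pass through the Sinkhorn iterations is needed.

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def cost_tensor(views):
    """Sum of pairwise costs over all k-choose-2 view pairs.

    views: (k, n, d) array of embeddings. Returns an (n,)*k tensor whose
    entry (i_1, ..., i_k) is the cost of that k-tuple of embeddings.
    """
    k, n, _ = views.shape
    C = jnp.zeros((n,) * k)
    for u in range(k):
        for v in range(u + 1, k):
            c_uv = -views[u] @ views[v].T  # negative inner products, (n, n)
            shape = [n if ax in (u, v) else 1 for ax in range(k)]
            C = C + c_uv.reshape(shape)  # broadcast over the remaining axes
    return C


def mm_sinkhorn(C, epsilon, num_iters):
    """Entropic multi-marginal Sinkhorn with uniform marginals.

    Coordinate ascent on the k dual potentials; each sweep costs O(n^k).
    """
    k, n = C.ndim, C.shape[0]
    log_a = -jnp.log(n)  # log of the uniform marginal weight 1/n
    potentials = [jnp.zeros(n) for _ in range(k)]
    for _ in range(num_iters):
        for u in range(k):
            # Sum of all potentials except the u-th, broadcast to (n,)*k.
            others = jnp.zeros((n,) * k)
            for v in range(k):
                if v != u:
                    shape = [n if ax == v else 1 for ax in range(k)]
                    others = others + potentials[v].reshape(shape)
            axes = tuple(ax for ax in range(k) if ax != u)
            # Softmin over all indices but the u-th one.
            potentials[u] = epsilon * log_a - epsilon * logsumexp(
                (others - C) / epsilon, axis=axes)
    return potentials


def m3g_loss(views, epsilon=0.1, num_iters=50):
    """Matching gap: ground-truth k-tuple cost minus entropic MM-OT cost."""
    k, n, _ = views.shape
    C = cost_tensor(views)
    # Danskin's theorem: the optimal plan may be treated as a constant,
    # so the Sinkhorn loop runs on a gradient-free copy of the cost.
    C_frozen = jax.lax.stop_gradient(C)
    potentials = mm_sinkhorn(C_frozen, epsilon, num_iters)
    f_sum = jnp.zeros((n,) * k)
    for u in range(k):
        shape = [n if ax == u else 1 for ax in range(k)]
        f_sum = f_sum + potentials[u].reshape(shape)
    plan = jnp.exp((f_sum - C_frozen) / epsilon)  # gradient-free plan
    ot_cost = jnp.sum(plan * C)  # gradient w.r.t. C is the plan itself
    gt_cost = jnp.mean(jnp.einsum('i' * k + '->i', C))  # diagonal k-tuples
    return gt_cost - ot_cost


# Example: k=3 views of a batch of n=8 points, 16-dim unit-norm embeddings.
z = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 16))
z = z / jnp.linalg.norm(z, axis=-1, keepdims=True)
loss, grads = jax.value_and_grad(m3g_loss)(z)

The dense tensor makes the $O(n^k)$ memory footprint explicit, which is consistent with the abstract's report of mini-batches of size $64\sim 128$ for $k=3\sim 6$ views.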

Intuition for $\rm{M3G}$

[Figure: abstract visualization of $\rm{M3G}$]

BibTeX

@InProceedings{piran24m3g,
  title     = {Contrasting Multiple Representations with the Multi-Marginal Matching Gap},
  author    = {Piran, Zoe and Klein, Michal and Thornton, James and Cuturi, Marco},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {40827--40842},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/piran24a/piran24a.pdf},
  url       = {https://proceedings.mlr.press/v235/piran24a.html},
  abstract  = {Learning meaningful representations of complex objects that can be seen through multiple ($k\geq 3$) views or modalities is a core task in machine learning. Existing methods use losses originally intended for paired views, and extend them to $k$ views, either by instantiating $\tfrac12k(k-1)$ loss-pairs, or by using reduced embeddings, following a one vs. average-of-rest strategy. We propose the multi-marginal matching gap (M3G), a loss that borrows tools from multi-marginal optimal transport (MM-OT) theory to simultaneously incorporate all $k$ views. Given a batch of $n$ points, each seen as a $k$-tuple of views subsequently transformed into $k$ embeddings, our loss contrasts the cost of matching these $n$ ground-truth $k$-tuples with the MM-OT polymatching cost, which seeks $n$ optimally arranged $k$-tuples chosen within these $n\times k$ vectors. While the exponential complexity $O(n^k)$ of the MM-OT problem may seem daunting, we show in experiments that a suitable generalization of the Sinkhorn algorithm for that problem can scale to, e.g., $k=3\sim 6$ views using mini-batches of size $64\sim 128$. Our experiments demonstrate improved performance over multiview extensions of pairwise losses, for both self-supervised and multimodal tasks.}
}