The hardware and bandwidth for this mirror are donated by dogado GmbH, the Webhosting and Full Service-Cloud Provider. Check out our WordPress Tutorial.
If you wish to report a bug, or if you are interested in having us mirror your free-software or open-source project, please feel free to contact us at mirror[@]dogado.de.

Optimal Tensor Transport

Koki Tsuyuzaki

Laboratory for Bioinformatics Research, RIKEN Center for Biosystems Dynamics Research
k.t.the-answer@hotmail.co.jp

2026-05-08

Introduction

In this vignette, we consider optimal tensor transport (OTT), an extension of optimal transport (OT) that handles tensors of any order by learning one or possibly multiple transport plans.

Here, we reproduce the experiments in the original paper (Kerdoncuff 2022). For the details of the methodology, see the original paper.

library("otTensor")

# Draw a matrix as a grayscale image.
#
# The rows of `mat` are reversed and the result is transposed before it is
# handed to image(), so that the picture has the same orientation as the
# printed matrix (see ?image for how image() lays out its input).
#
# mat:  a matrix (or 2-d array) of numeric values.
# main: plot title; defaults to no title.
#
# Called for its plotting side effect; the value is whatever image() returns.
.show_matrix <- function(mat, main = ""){
    # Flip the row order by indexing with drop = FALSE: unlike
    # apply(mat, 2, rev), this keeps a single-row input a matrix, so the
    # transpose below always has dimensions ncol(mat) x nrow(mat).
    flipped <- mat[rev(seq_len(nrow(mat))), , drop = FALSE]
    mat_rev <- t(flipped)

    # grayscale, 256 levels from black to white; no axes or axis labels
    image(mat_rev, col = gray((0:255)/255), xaxt = "n", yaxt = "n",
        xlab = "", ylab = "", axes = FALSE, main = main)
}

OTT_1 (OT)

# OTT_1: order-1 tensors (vectors) with a single transport plan.
D <- 1        # number of modes of X and Y
A <- 1        # number of transport plans
Is <- c(4)    # mode sizes of X
Ks <- c(7)    # mode sizes of Y
f <- c(1)     # f[d] is the index of the plan associated with mode d

# Toy inputs: every entry equals its own index.
arrX <- array(0, dim = Is)
arrY <- array(0, dim = Ks)

for (i1 in seq_len(Is[1])) {
    arrX[i1] <- i1
}
for (k1 in seq_len(Ks[1])) {
    arrY[k1] <- k1
}

# Weights on the X side, one vector per plan: 0.01 everywhere except
# entries 1 and 3, which get weight 1; normalised to sum to 1. The vector
# length is the size of the first mode of X assigned to that plan.
ps <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    ps[[a]] <- p / sum(p)
}

# Weights on the Y side: uniform with entries 2 and 3 zeroed, normalised.
qs <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    qs[[a]] <- q / sum(q)
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) abs(x - y),
    num.iter = 200, epsilon = 1e-10)

# Weights, the estimated plan out$Ts[[1]], and the two inputs.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
plot(arrX, type = "h", col = "black", main = "arrX")
plot(arrY, type = "h", col = "black", main = "arrY")

OTT_12 (Co-OT)

# OTT_12: order-2 tensors (matrices) with one transport plan per mode.
D <- 2            # number of modes of X and Y
A <- 2            # number of transport plans
Is <- c(4, 5)     # mode sizes of X
Ks <- c(7, 8)     # mode sizes of Y
f <- c(1, 2)      # f[d] is the index of the plan associated with mode d

# Toy inputs: every entry equals the sum of its indices.
arrX <- array(0, dim = Is)
arrY <- array(0, dim = Ks)

for (i1 in seq_len(Is[1])) {
    for (i2 in seq_len(Is[2])) {
        arrX[i1, i2] <- i1 + i2
    }
}
for (k1 in seq_len(Ks[1])) {
    for (k2 in seq_len(Ks[2])) {
        arrY[k1, k2] <- k1 + k2
    }
}

# Weights on the X side, one vector per plan: 0.01 everywhere except
# entries 1 and 3, which get weight 1; normalised to sum to 1. The vector
# length is the size of the first mode of X assigned to that plan.
ps <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    ps[[a]] <- p / sum(p)
}

# Weights on the Y side: uniform with entries 2 and 3 zeroed, normalised.
qs <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    qs[[a]] <- q / sum(q)
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) abs(x - y),
    num.iter = 200, epsilon = 1e-10)

# One panel per plan: its weights, the estimated plan out$Ts[[a]], and the
# two input matrices.
options(repr.plot.width = 6, repr.plot.height = 10)
for (a in seq_len(A)) {
    par(mfrow = c(3, 2))
    plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = paste0("ps[[", a, "]]"))
    plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = paste0("qs[[", a, "]]"))
    .show_matrix(out$Ts[[a]], main = paste0("Ts[[", a, "]]"))
    .show_matrix(arrX, main = "arrX")
    .show_matrix(arrY, main = "arrY")
}

OTT_11 (GW)

# OTT_11: order-2 tensors (matrices) whose two modes share one plan.
D <- 2            # number of modes of X and Y
A <- 1            # number of transport plans
Is <- c(4, 4)     # mode sizes of X
Ks <- c(6, 6)     # mode sizes of Y
f <- c(1, 1)      # f[d] is the index of the plan associated with mode d

# Toy inputs: every entry equals the sum of its indices.
arrX <- array(0, dim = Is)
arrY <- array(0, dim = Ks)

for (i1 in seq_len(Is[1])) {
    for (i2 in seq_len(Is[2])) {
        arrX[i1, i2] <- i1 + i2
    }
}
for (k1 in seq_len(Ks[1])) {
    for (k2 in seq_len(Ks[2])) {
        arrY[k1, k2] <- k1 + k2
    }
}

# Weights on the X side, one vector per plan: 0.01 everywhere except
# entries 1 and 3, which get weight 1; normalised to sum to 1. The vector
# length is the size of the first mode of X assigned to that plan.
ps <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    ps[[a]] <- p / sum(p)
}

# Weights on the Y side: uniform with entries 2 and 3 zeroed, normalised.
qs <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    qs[[a]] <- q / sum(q)
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) abs(x - y),
    num.iter = 200, epsilon = 1e-10)

# Weights, the estimated plan out$Ts[[1]], and the two input matrices.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")

OTT_111 (triplets)

# OTT_111: order-3 tensors whose three modes share one plan.
D <- 3              # number of modes of X and Y
A <- 1              # number of transport plans
Is <- c(4, 4, 4)    # mode sizes of X
Ks <- c(6, 6, 6)    # mode sizes of Y
f <- c(1, 1, 1)     # f[d] is the index of the plan associated with mode d

# Toy inputs: every entry equals the sum of its indices.
arrX <- array(0, dim = Is)
arrY <- array(0, dim = Ks)

for (i1 in seq_len(Is[1])) {
    for (i2 in seq_len(Is[2])) {
        for (i3 in seq_len(Is[3])) {
            arrX[i1, i2, i3] <- i1 + i2 + i3
        }
    }
}
for (k1 in seq_len(Ks[1])) {
    for (k2 in seq_len(Ks[2])) {
        for (k3 in seq_len(Ks[3])) {
            arrY[k1, k2, k3] <- k1 + k2 + k3
        }
    }
}

# Weights on the X side, one vector per plan: 0.01 everywhere except
# entries 1 and 3, which get weight 1; normalised to sum to 1. The vector
# length is the size of the first mode of X assigned to that plan.
ps <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    ps[[a]] <- p / sum(p)
}

# Weights on the Y side: uniform with entries 2 and 3 zeroed, normalised.
qs <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    qs[[a]] <- q / sum(q)
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) abs(x - y),
    num.iter = 200, epsilon = 1e-10)

# Weights, the estimated plan out$Ts[[1]], and the first slice along the
# third mode of each input.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX[,,1], main = "arrX[,,1]")
.show_matrix(arrY[,,1], main = "arrY[,,1]")

OTT_123 (triCo-OT)

# OTT_123: order-3 tensors with one transport plan per mode.
D <- 3              # number of modes of X and Y
A <- 3              # number of transport plans
Is <- c(4, 5, 6)    # mode sizes of X
Ks <- c(7, 8, 9)    # mode sizes of Y
f <- c(1, 2, 3)     # f[d] is the index of the plan associated with mode d

# Toy inputs: every entry equals the sum of its indices.
arrX <- array(0, dim = Is)
arrY <- array(0, dim = Ks)

for (i1 in seq_len(Is[1])) {
    for (i2 in seq_len(Is[2])) {
        for (i3 in seq_len(Is[3])) {
            arrX[i1, i2, i3] <- i1 + i2 + i3
        }
    }
}
for (k1 in seq_len(Ks[1])) {
    for (k2 in seq_len(Ks[2])) {
        for (k3 in seq_len(Ks[3])) {
            arrY[k1, k2, k3] <- k1 + k2 + k3
        }
    }
}

# Weights on the X side, one vector per plan: 0.01 everywhere except
# entries 1 and 3, which get weight 1; normalised to sum to 1. The vector
# length is the size of the first mode of X assigned to that plan.
ps <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    ps[[a]] <- p / sum(p)
}

# Weights on the Y side: uniform with entries 2 and 3 zeroed, normalised.
qs <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    qs[[a]] <- q / sum(q)
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) abs(x - y),
    num.iter = 200, epsilon = 1e-10)

# One panel per plan: its weights, the estimated plan out$Ts[[a]], and the
# a-th slice along the third mode of each input.
options(repr.plot.width = 6, repr.plot.height = 10)
for (a in seq_len(A)) {
    par(mfrow = c(3, 2))
    plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = paste0("ps[[", a, "]]"))
    plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = paste0("qs[[", a, "]]"))
    .show_matrix(out$Ts[[a]], main = paste0("Ts[[", a, "]]"))
    .show_matrix(arrX[,,a], main = paste0("arrX[,,", a, "]"))
    .show_matrix(arrY[,,a], main = paste0("arrY[,,", a, "]"))
}

OTT_112 (GW Collection)

# OTT_112: order-3 tensors; modes 1 and 2 share plan 1, mode 3 uses plan 2.
D <- 3              # number of modes of X and Y
A <- 2              # number of transport plans
Is <- c(4, 4, 5)    # mode sizes of X
Ks <- c(6, 6, 7)    # mode sizes of Y
f <- c(1, 1, 2)     # f[d] is the index of the plan associated with mode d

# Toy inputs: every entry equals the sum of its indices.
arrX <- array(0, dim = Is)
arrY <- array(0, dim = Ks)

for (i1 in seq_len(Is[1])) {
    for (i2 in seq_len(Is[2])) {
        for (i3 in seq_len(Is[3])) {
            arrX[i1, i2, i3] <- i1 + i2 + i3
        }
    }
}
for (k1 in seq_len(Ks[1])) {
    for (k2 in seq_len(Ks[2])) {
        for (k3 in seq_len(Ks[3])) {
            arrY[k1, k2, k3] <- k1 + k2 + k3
        }
    }
}

# Weights on the X side, one vector per plan: 0.01 everywhere except
# entries 1 and 3, which get weight 1; normalised to sum to 1. The vector
# length is the size of the first mode of X assigned to that plan.
ps <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    ps[[a]] <- p / sum(p)
}

# Weights on the Y side: uniform with entries 2 and 3 zeroed, normalised.
qs <- vector("list", A)
for (a in seq_len(A)) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    qs[[a]] <- q / sum(q)
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) abs(x - y),
    num.iter = 200, epsilon = 1e-10)

# One panel per plan: its weights, the estimated plan out$Ts[[a]], and the
# a-th slice along the third mode of each input.
options(repr.plot.width = 6, repr.plot.height = 10)
for (a in seq_len(A)) {
    par(mfrow = c(3, 2))
    plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = paste0("ps[[", a, "]]"))
    plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = paste0("qs[[", a, "]]"))
    .show_matrix(out$Ts[[a]], main = paste0("Ts[[", a, "]]"))
    .show_matrix(arrX[,,a], main = paste0("arrX[,,", a, "]"))
    .show_matrix(arrY[,,a], main = paste0("arrY[,,", a, "]"))
}

Session Information

## R version 3.6.3 (2020-02-29)
## Platform: x86_64-conda-linux-gnu (64-bit)
## Running under: Rocky Linux 9.5 (Blue Onyx)
## 
## Matrix products: default
## BLAS:   /home/koki/miniconda3/lib/libblas.so.3.9.0
## LAPACK: /home/koki/miniconda3/lib/liblapack.so.3.9.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] rTensor_1.4.8   otTensor_0.99.0
## 
## loaded via a namespace (and not attached):
##  [1] digest_0.6.31   R6_2.5.1        jsonlite_1.8.4  evaluate_0.20  
##  [5] highr_0.10      rlang_0.4.11    jquerylib_0.1.4 bslib_0.3.1    
##  [9] rmarkdown_2.11  tools_3.6.3     xfun_0.38       yaml_2.3.7     
## [13] fastmap_1.1.1   compiler_3.6.3  htmltools_0.5.5 knitr_1.42     
## [17] sass_0.4.0

References

Kerdoncuff, T. et al. 2022. “Optimal Tensor Transport.” Proceedings of the AAAI Conference on Artificial Intelligence 36(7): 7124–32.

These binaries (installable software) and packages are in development.
They may not be fully stable and should be used with caution. We make no claims about them.
Health stats visible at Monitor.