# Optimizing Contrastive/Rank/Triplet Loss in TensorFlow for Neural Information Retrieval

*ex*Senior Research Scientist @ Zalando Research

## Background

Recently, my team has been applying deep neural networks to improve the search experience of customers. Researchers often call this type of application *neural information retrieval*. The input to the model is a full-text query and a set of documents. A search query typically contains a few terms, while a document, depending on the scenario, may contain hundreds of sentences, a set of images, or structured key-value pairs. The goal is to use a DNN to rank the documents in response to the given query. The output of the model is a list in which the preferred search results are at the top and the irrelevant results are at the bottom.

In this article, I will explain the key to the success of such models, namely the **contrastive/rank loss**, and then show you how to implement it in an efficient and generic way.

## Contrastive/Rank Loss

To achieve this goal, let’s first define a metric function $g$ that measures the similarity between the query $q$ and a document $d$. Denote a relevant document as $d^{+}$ and an irrelevant one as $d^{-}$. Ranking relevant items *before* irrelevant ones then means requiring $g(q, d^{+}) > g(q, d^{-})$. In other words, we do not care about the absolute value of $g(q, d^{+})$ or $g(q, d^{-})$; whether $g(q, d^{+})$ is 7.3 or -10.5 does not matter much. We *only* care about the relative difference between positive and negative pairs. In fact, a larger difference is better, as a clearer separation between positive and negative pairs can improve generalization.
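To make the "only relative values matter" point concrete, here is a minimal sketch in plain Python. The document ids and scores are made up for illustration, and `rank` is a hypothetical helper, not part of any library.

```python
def rank(scores):
    """Return document ids ordered from highest to lowest score."""
    return sorted(scores, key=scores.get, reverse=True)

# Hypothetical similarity scores g(q, d) for three documents.
scores = {"d1": 7.3, "d2": -10.5, "d3": 0.2}

# Shifting every score by the same constant changes the absolute values
# but leaves the ranking untouched.
shifted = {doc: s - 100.0 for doc, s in scores.items()}

print(rank(scores))
print(rank(shifted))
```

Both calls produce the same ordering, which is why the loss below is defined on score differences rather than on the scores themselves.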

That being said, minimizing $g(q, d^{-}) - g(q, d^{+})$ seems to be an obvious objective. However, this objective is problematic: it is unbounded below, so training will not converge. The optimizer can keep pushing $g(q, d^{-})$ down and $g(q, d^{+})$ up to make the loss arbitrarily negative. So how about minimizing $\max(0, g(q, d^{-}) - g(q, d^{+}))$? Now the loss is bounded below (at zero), and the network can ignore triplets whose difference is already negative. However, when $g(q, d^{-}) = g(q, d^{+})$ the loss is already zero and the optimizer receives no gradient. This is bad because learning stops there, even though the positive and negative documents are not separated at all. To solve this, we add a margin threshold $\epsilon$ that forces $g(q, d^{+})$ to be at least $g(q, d^{-}) + \epsilon$. Finally, we arrive at the following objective function:

$$\min \max\left(0, \epsilon + g(q, d^{-}) - g(q, d^{+})\right).$$

One may notice that it is basically a hinge loss. In fact, we could use any loss function besides the hinge loss, e.g. logistic loss, exponential loss. As for the metric, we also have plenty of options, e.g. cosine, $\ell_1$/$\ell_2$-norm. We could even parametrize the metric function with a multi-layer perceptron and learn it from the data.
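The effect of the margin is easy to check on a single triplet. The following is a minimal sketch in plain Python, assuming a margin of 1.0; the function names are hypothetical, and the logistic variant shown is one common pairwise form, not necessarily the exact one used in my model.

```python
import math

def hinge_triplet_loss(g_pos, g_neg, margin=1.0):
    """max(0, margin + g(q, d-) - g(q, d+)) for a single triplet."""
    return max(0.0, margin + g_neg - g_pos)

def logistic_triplet_loss(g_pos, g_neg):
    """A smooth alternative: log(1 + exp(g(q, d-) - g(q, d+)))."""
    return math.log1p(math.exp(g_neg - g_pos))

# Positive scored well above the negative: the margin is satisfied.
print(hinge_triplet_loss(g_pos=3.0, g_neg=0.5))
# Equal scores: the margin keeps the loss (and the gradient) non-zero.
print(hinge_triplet_loss(g_pos=1.0, g_neg=1.0))
# Negative ranked above the positive: penalized linearly.
print(hinge_triplet_loss(g_pos=0.0, g_neg=2.0))
# The logistic loss never reaches exactly zero, so it always gives a gradient.
print(logistic_triplet_loss(g_pos=3.0, g_neg=0.5))
```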

Formally, given a training set of $N$ triplets $\left\{ q_i, \{d^{+}_{i,j},\ldots\}, \{d^{-}_{i,k},\ldots\} \right\}$, the final loss function of the model has the following general form:

$$\min \sum_{i=1}^{N}\sum_{j=1}^{|d_{i}^{+}|}\sum_{k=1}^{|d_{i}^{-}|}\,w_{i,j}\,\ell\left(g(q_i, d_{i,j}^{+}), g(q_i, d_{i,k}^{-})\right),$$

where $w_{i,j}$ is the weight of the positive query-document pair. In practice, this could be the click-through rate or the log of the number of clicks mined from the query log. $|d_{i}^{+}|$ and $|d_{i}^{-}|$ are the numbers of positive and negative documents associated with query $i$, respectively. Negative documents can be obtained via random sampling. For the functions $\ell$ and $g$, the options are:

- **Loss function $\ell$**: logistic, exponential, hinge loss, etc.
- **Metric function $g$**: cosine similarity, Euclidean distance (i.e. $\ell_2$-norm), MLP, etc.
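Read literally, the general form above is a triple loop over queries, positives, and negatives. Here is a deliberately naive sketch in plain Python that assumes the scores $g$ are already computed. The data layout (one dict per query with `pos`, `neg`, and `weight` keys) and the function names are my own illustration, and the hinge margin of 1.0 is an arbitrary choice.

```python
def hinge(g_pos, g_neg, margin=1.0):
    """Hinge loss on one (positive, negative) score pair."""
    return max(0.0, margin + g_neg - g_pos)

def total_loss(batch, loss_fn=hinge):
    """Sum of w_ij * loss(g(q_i, d+_ij), g(q_i, d-_ik)) over all i, j, k.

    `batch` holds one dict per query: 'pos' and 'neg' are lists of
    precomputed scores, 'weight' has one weight per positive document.
    """
    total = 0.0
    for query in batch:                                        # sum over i
        for g_pos, w in zip(query["pos"], query["weight"]):    # sum over j
            for g_neg in query["neg"]:                         # sum over k
                total += w * loss_fn(g_pos, g_neg)
    return total

batch = [
    {"pos": [2.0, 0.5], "weight": [1.0, 0.5], "neg": [0.0, 1.0]},
    {"pos": [1.0], "weight": [2.0], "neg": [3.0]},
]
print(total_loss(batch))
```

The loops are only there to mirror the formula; the tensor version discussed later computes all triplets at once.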

The network has the following structure:

The query and document encoders transform their respective inputs into vector representations. For example, one can use an RNN, LSTM, or GRU to encode the query and a CNN or VAE to encode documents or images. The training data is fed in from the left. In the next section, I will focus on the implementation of the metric layer and the loss layer. The encoders and negative sampling are also important to the model performance, but I leave them as a story for another day.

## Metric Layer Implementation

The code is fairly straightforward, but one needs to pay attention to the dimensions of the tensors.

```python
with tf.variable_scope('Metric_Layer'):
    ...
```

In the network, the variable `query` is the output of the query encoder and has size `[batch_size, num_hidden]`. The document encoder’s output has a different size: `d_pos_norm` is a tensor of size `[batch_size, num_pos, num_hidden]` and `d_neg_norm` is of size `[batch_size, num_neg, num_hidden]`. Therefore, we need to apply `tf.expand_dims` to `query` to make the sizes compatible. For the cosine and $\ell_2$ cases, it is unnecessary to tile the tensor, as the multiplication and subtraction operations support broadcasting. The figure below explains the data flow when using an MLP as the metric.

In the above code, the MLP is implemented as a 64-32-16 feedforward network with softplus activation on each layer.
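To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the same idea. It is not the TensorFlow code of this article: indexing with `np.newaxis` stands in for `tf.expand_dims` and `np.tile` for `tf.tile`. All function names are hypothetical, and negating the Euclidean distance (so that a larger value still means "more similar") is my assumption about the sign convention. The MLP itself is omitted; only the construction of its input is shown.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors along `axis` to unit length."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def cosine_metric(query, docs):
    """query: [batch, hidden]; docs: [batch, num_docs, hidden] -> [batch, num_docs]."""
    q = l2_normalize(query)[:, np.newaxis, :]   # [batch, 1, hidden]
    d = l2_normalize(docs)                      # [batch, num_docs, hidden]
    return np.sum(q * d, axis=2)                # multiplication broadcasts over num_docs

def neg_l2_metric(query, docs):
    """Negative Euclidean distance (assumed sign: larger means more similar)."""
    q = query[:, np.newaxis, :]                 # subtraction broadcasts as well
    return -np.sqrt(np.sum((q - docs) ** 2, axis=2))

def mlp_input(query, docs):
    """Pair every document with its query; concatenation needs an explicit tile."""
    num_docs = docs.shape[1]
    q = np.tile(query[:, np.newaxis, :], (1, num_docs, 1))  # [batch, num_docs, hidden]
    return np.concatenate([q, docs], axis=2)                # [batch, num_docs, 2 * hidden]

rng = np.random.default_rng(0)
batch_size, num_pos, num_neg, num_hidden = 4, 2, 5, 8
query = rng.normal(size=(batch_size, num_hidden))
d_pos = rng.normal(size=(batch_size, num_pos, num_hidden))
d_neg = rng.normal(size=(batch_size, num_neg, num_hidden))

metric_p = cosine_metric(query, d_pos)   # shape (4, 2)
metric_n = cosine_metric(query, d_neg)   # shape (4, 5)
print(metric_p.shape, metric_n.shape, mlp_input(query, d_pos).shape)
```

The cosine and $\ell_2$ metrics only need the extra axis on the query, whereas the MLP input has to be tiled because concatenation does not broadcast.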

## Loss Layer Implementation

To compute the loss, we must unify the sizes of `metric_p` and `metric_n`. This can again be achieved by using the expand and tile operations on the corresponding axes, as described below. Note how I apply `expand_dims` and `tile` on different axes for `metric_p` and `metric_n`, so that the difference `delta` can be computed in one shot.

```python
with tf.variable_scope('Loss_layer'):
    ...
```

The idea behind it is also quite straightforward. First, we compute the difference for each triplet (query, positive document, negative document). Then, we feed `delta` to the loss function and aggregate over all negative documents via `tf.reduce_sum(..., axis=2)`. Finally, we rescale the loss of each query-(positive) document pair by `weight` and reduce them to a scalar.
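The three steps can be sketched in NumPy as follows. Again, this is an illustration under assumptions rather than the TensorFlow code of this article: `rank_loss` is a hypothetical function, the margin of 1.0 is arbitrary, and the per-triplet logistic and exponential forms are common choices that I picked for the example.

```python
import numpy as np

def rank_loss(metric_p, metric_n, weight, loss="hinge", margin=1.0):
    """metric_p, weight: [batch, num_pos]; metric_n: [batch, num_neg] -> scalar."""
    num_pos, num_neg = metric_p.shape[1], metric_n.shape[1]
    # Expand on different axes so that both become [batch, num_pos, num_neg].
    p = np.tile(metric_p[:, :, np.newaxis], (1, 1, num_neg))
    n = np.tile(metric_n[:, np.newaxis, :], (1, num_pos, 1))
    delta = n - p  # g(q, d-) - g(q, d+) for every triplet at once
    if loss == "hinge":
        per_triplet = np.maximum(0.0, margin + delta)
    elif loss == "logistic":
        per_triplet = np.log1p(np.exp(delta))
    elif loss == "exp":
        per_triplet = np.exp(delta)
    else:
        raise ValueError(f"unknown loss: {loss}")
    loss_q_pos = per_triplet.sum(axis=2)       # aggregate over negatives: [batch, num_pos]
    return float(np.sum(weight * loss_q_pos))  # weight each query-positive pair, reduce to scalar

metric_p = np.array([[2.0, 0.5]])        # one query, two positive documents
metric_n = np.array([[0.0, 1.0, 3.0]])   # three negative documents
weight = np.array([[1.0, 0.5]])
print(rank_loss(metric_p, metric_n, weight))
```

In NumPy the explicit `np.tile` calls are redundant, since `metric_n[:, np.newaxis, :] - metric_p[:, :, np.newaxis]` broadcasts to the same `delta`. I keep them here only to mirror the expand-and-tile description above.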

## Summary

Besides search, contrastive/rank loss enjoys a wide range of applications. It also appears in question answering, context-aware chatbots, and person re-identification tasks. In this article, I showed that such a loss generally consists of two parts, i.e. the metric and the loss. By leveraging `tf.expand_dims`, `tf.tile`, and the broadcasting feature of arithmetic operators in TensorFlow, it is fairly straightforward to implement it correctly.