Stéphan Tulkens

Fast way to find the rank of an item in pytorch

I recently read an interesting paper by Gehrmann et al. in which the rank of a language model's predictions is used as a feature vector to distinguish machine-generated text from regular text. While implementing this method in pytorch, I ran into an interesting problem that I first solved in a really slow way, and subsequently made much faster. This blog post shows you how not to do it, and how to speed it up.

First, some preliminaries. A language model (LM) is a statistical model that tries to predict a token, usually a word, given some context. As such, it defines a probability distribution over a fixed vocabulary V. The context usually consists of other tokens from V, typically the preceding ones. So, at a given time step, an LM produces a vector of probabilities over the entire vocabulary. Given the true targets, we can then assess the quality of the LM by, for example, calculating the cross entropy.
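To make this concrete, here is a minimal sketch (the shapes and numbers are made up for illustration): we turn a batch of raw model outputs into probability distributions, and compute the cross entropy of the true targets.

import torch

# Toy example: a batch of 4 predictions over a vocabulary of 10 words.
logits = torch.randn(4, 10)
probas = torch.softmax(logits, 1)     # one distribution over V per row
targets = torch.randint(0, 10, (4,))  # the true next words

# Cross entropy: the mean negative log probability of the true words.
loss = -torch.log(probas[range(4), targets]).mean()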

But, as said before, instead of calculating the loss, we want to find the rank of the true word. A naive solution is to use argsort and where. Given a B × V tensor and B targets, where B is the batch size, this leads to the following code.

import torch

# Make up some random numbers with batch size 32 and |V| = 50K
torch.manual_seed(44)
probas = torch.rand(32, 50000)
probas /= probas.sum(1)[:, None]
targets = torch.randint(0, 50000, (32,))

def argsort_solution(x, targets):
    # Sort each row along the vocabulary axis (dim 1).
    # Negate x because argsort sorts ascending and we want the largest items first.
    sort = torch.argsort(-x, 1)
    # The rank of each target is its position in the sorted row.
    return torch.where(sort == targets[:, None])[1]

Testing this with timeit:

%timeit argsort_solution(probas, targets)
117 ms ± 93.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

This takes about 117 ms on my macbook. That sounds fine, right? But when I implemented this in my LM, it took up a substantial share of the total inference time. Think about it: if you want to do anything with these ranks, you need to perform this calculation for every token in your corpus. This means that, for a 1M token corpus, you will spend about an hour on the calculation of the ranks alone, as the quick calculation below shows.
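As a quick back-of-the-envelope check, assuming batch size 32 and the 117 ms per batch measured above:

# 1M tokens at batch size 32 means 31,250 batches of 117 ms each.
n_batches = 1_000_000 / 32
total_minutes = n_batches * 0.117 / 60
print(total_minutes)  # roughly 61 minutes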

As it turns out, almost all of the time is spent in the argsort call. So, in order to make a faster alternative, we should avoid argsort altogether.
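A quick way to check this yourself is to time the argsort call in isolation and compare it to the timing of the full function above:

%timeit torch.argsort(-probas, 1)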

There is a really easy way to solve this. First, assuming that we rank items from large to small, observe that the rank of an item is simply the number of items with a larger value. Given that we already know which item we're interested in, i.e., we know the word the model is trying to predict, retrieving its value is trivial. So, instead of sorting and then finding our item in the sorted list, we can just look up the value of the item we're trying to predict, and count how many items have a larger value. This count is exactly the rank of our predicted item. Note that this assumes there are no ties, which, with real-valued model outputs, is virtually always the case. Implementing this in pytorch is trivial thanks to broadcasting semantics, as shown below.

def get_rank(x, indices):
    # Look up the value of the target item in each row.
    vals = x[range(len(x)), indices]
    # The rank is the number of items with a strictly larger value.
    return (x > vals[:, None]).long().sum(1)

Short, right? But how much faster is this? As it turns out, it is roughly 34 times faster:

%timeit get_rank(probas, targets)
3.4 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

It also gives the same answer as the argsort-based solution:

a = argsort_solution(probas, targets)
b = get_rank(probas, targets)
assert (a == b).all()
print(a, b)

I must admit I was surprised when I saw this. Usually, using out-of-the-box tensor functions is the best solution, and more efficient than anything you can come up with yourself. But in this specific case, rolling my own solution turned out to be much more efficient.