import numpy as np
import torch
Sampling words by frequency
We want to sample words from a vocabulary with a probability proportional to their counts (absolute frequencies) in some given text. That is, if we have two words \(w_1\) and \(w_2\), where \(w_2\) appears \(k\) times as often as \(w_1\), then the expected number of times we sample \(w_2\) should be \(k\) times the expected number of times we sample \(w_1\).
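One way to state this requirement as a formula, using notation that is introduced here only for illustration: write \(c_i\) for the count of word \(w_i\) and \(N\) for the sum of all counts. The target distribution is then
\[ P(w_i) = \frac{c_i}{N}, \qquad N = \sum_j c_j , \]
so that \(c_2 = k \, c_1\) implies \(P(w_2) = k \, P(w_1)\).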
Sampling recipe
Imagine all the words in the vocabulary laid out along a line that runs from 0 to the sum of all word counts, where each word covers an interval whose length equals its count. To sample a word, we choose a point on that line uniformly at random and return the word whose interval contains it. In doing so, we sample each word with a probability proportional to its count.
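The recipe can be sketched in a few lines of plain Python before turning to tensors. This is only an illustration under stated assumptions: the function name `sample_word`, the toy vocabulary, and its counts are made up for this sketch, and `bisect.bisect_left` is used here as a standard-library stand-in for the interval lookup.

```python
import bisect
import itertools
import random


def sample_word(words, counts, rng=random):
    """Return one word, chosen with probability proportional to its count.

    Hypothetical helper for illustration; not part of any library.
    """
    # Right-hand end of each word's interval on the counts line.
    boundaries = list(itertools.accumulate(counts))
    # A uniformly random point between 0 and the sum of all counts.
    point = rng.random() * boundaries[-1]
    # Index of the first boundary that is >= point, i.e. the interval
    # that contains the point.
    return words[bisect.bisect_left(boundaries, point)]


# Toy vocabulary (assumed values): "c" should come up about 60% of the time.
toy_words = ["a", "b", "c"]
toy_counts = [3, 1, 6]
toy_rng = random.Random(0)  # fixed seed so that runs are repeatable
toy_draws = [sample_word(toy_words, toy_counts, toy_rng) for _ in range(10_000)]
print({w: toy_draws.count(w) / len(toy_draws) for w in toy_words})
```

Recomputing the boundaries on every call keeps the sketch short; when drawing many samples, it would be more natural to compute them once and reuse them.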
Example
We illustrate the sampling recipe with a concrete example.
Here is a list of counts for words in a ten-word vocabulary:
counts = np.array([14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271])
To implement the sampling recipe, we need the cumulative sums of these counts. We can get them with the function torch.cumsum().
cumulative_sums = torch.cumsum(torch.from_numpy(counts), dim=0)
cumulative_sums
tensor([14507, 19521, 24123, 28652, 32652, 35871, 38881, 41839, 44064, 45335])
To choose a random point on the counts line, we sample a random number between 0 and 1 and multiply it by the sum of all counts, which is the last entry in the list of cumulative sums. Here we choose \(5\) such points.
random_points = torch.rand(5) * cumulative_sums[-1]
random_points
tensor([ 4442.1890, 10967.1973, 7999.8721, 9633.9512, 44469.0352])
To return the word whose interval on the counts line includes a chosen point, we use the function torch.searchsorted(). This function takes a sorted sequence and a tensor of values and, for each value, returns the index in the sorted sequence at which the value could be inserted while keeping the sequence sorted. Applied to the cumulative sums, this index is the index of the word whose interval contains the point.
torch.searchsorted(cumulative_sums, random_points)
tensor([0, 0, 0, 0, 9])
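The steps of the example can be collected into a single function and checked empirically. The following is a minimal sketch, not a reference solution: the name `sample_indices` and its parameters are assumptions made here, and the counts and random points are cast to `float64` so that both arguments of `torch.searchsorted()` share one dtype.

```python
import torch


def sample_indices(counts, num_samples, generator=None):
    """Draw word indices with probability proportional to `counts`.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    counts = torch.as_tensor(counts, dtype=torch.float64)
    # Right-hand ends of the intervals on the counts line.
    cumulative_sums = torch.cumsum(counts, dim=0)
    # Uniform random points between 0 and the sum of all counts.
    points = torch.rand(num_samples, dtype=torch.float64, generator=generator)
    points = points * cumulative_sums[-1]
    # Index of the interval that contains each point.
    return torch.searchsorted(cumulative_sums, points)


example_counts = [14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271]
gen = torch.Generator().manual_seed(0)  # fixed seed so that runs are repeatable
indices = sample_indices(example_counts, 200_000, generator=gen)

# Compare the empirical sampling frequencies with the relative counts.
empirical = torch.bincount(indices, minlength=len(example_counts)) / indices.numel()
expected = torch.tensor(example_counts, dtype=torch.float64)
expected = expected / expected.sum()
print(empirical)
print(expected)
```

With a large number of samples, the two printed tensors should agree closely, which is what sampling in proportion to the counts means in practice.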
Good luck with the lab!