NumPy provides an indirect way to compute the indices of the smallest (or largest) values of an array: `numpy.argpartition`.
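Here is a minimal sketch of that indirect approach, assuming a 1-D input. `numpy.argpartition` places the indices of the k smallest values first but in arbitrary order, so a second sort over only those k entries is needed to get a ranked result. The helper name `argtopk_smallest` is hypothetical and is not a NumPy function.

```python
import numpy as np


def argtopk_smallest(x, k):
    """Indices of the k smallest entries of a 1-D array, ascending by value.

    Hypothetical helper for illustration; not part of NumPy.
    """
    x = np.asarray(x)
    if x.ndim != 1 or not 0 < k <= x.size:
        raise ValueError("expected a 1-D array and 0 < k <= x.size")
    # O(n): the k smallest entries land in the first k slots, unordered.
    idx = np.argpartition(x, k - 1)[:k]
    # O(k log k): rank only the k selected entries by value.
    return idx[np.argsort(x[idx], kind="stable")]


x = np.array([7.0, 1.0, 5.0, 3.0, 9.0, 2.0])
print(argtopk_smallest(x, 3))  # [1 5 3]
```

For the k largest values, `np.argpartition(x, -k)[-k:]` selects the last k slots instead, which then need a descending sort.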
There is also a proposal to provide a higher-level API, namely (arg)topk, in numpy:
This PR relies on `numpy.argpartition` internally, but it can probably be optimized later to avoid allocating a result array the size of the input array when `k` is small.
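To illustrate the memory concern, the sketch below swaps in a different technique: a bounded heap via the standard library's `heapq.nsmallest`. It keeps only k candidates while scanning, so its extra memory is O(k) rather than the O(n) index array that `numpy.argpartition` allocates. It runs in O(n log k) time. The helper name is hypothetical. Because of its pure-Python loop it is much slower than `argpartition` in practice, so it demonstrates the memory bound only and is not a proposed implementation.

```python
import heapq
from operator import itemgetter

import numpy as np


def argtopk_smallest_lowmem(x, k):
    """Indices of the k smallest entries of a 1-D array, ascending by value.

    Hypothetical helper: a bounded heap keeps at most k (index, value) pairs,
    so the extra memory is O(k) instead of O(n).
    """
    x = np.asarray(x)
    if x.ndim != 1 or not 0 < k <= x.size:
        raise ValueError("expected a 1-D array and 0 < k <= x.size")
    # nsmallest scans lazily and returns the k best pairs sorted by value
    # (ties keep their original order, like sorted()).
    best = heapq.nsmallest(k, enumerate(x), key=itemgetter(1))
    return np.array([i for i, _ in best], dtype=np.intp)


x = np.array([7.0, 1.0, 5.0, 3.0, 9.0, 2.0])
print(argtopk_smallest_lowmem(x, 3))  # [1 5 3]
```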
Here is a quick review of some available implementations in related libraries:
- `torch.topk` (no such thing as `torch.argpartition`)
  - returns a tuple of values and indices
- `jax.lax.top_k`
  - returns a tuple of values and indices
  - apparently it is quite slow on GPU
- `dask.array.topk`
  - returns only the values; the indices are available through the separate `dask.array.argtopk` function
- `cupy.argpartition`, but it internally computes a full `cupy.argsort`, which makes it very inefficient for large arrays and small `k`: O(n log n) instead of O(n).
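Both `torch.topk` and `jax.lax.top_k` return a `(values, indices)` tuple. The following NumPy-only sketch, built on `numpy.argpartition`, illustrates that convention. The name `topk` and its `axis`/`largest` parameters are hypothetical (loosely modeled on the `dim`/`largest` arguments of `torch.topk`) and are not an existing NumPy API.

```python
import numpy as np


def topk(x, k, axis=-1, largest=True):
    """Return (values, indices) of the k largest (or smallest) entries along axis.

    Hypothetical NumPy-only helper; results are sorted best-first.
    """
    x = np.asarray(x)
    n = x.shape[axis]
    if not 0 < k <= n:
        raise ValueError("k must satisfy 0 < k <= x.shape[axis]")
    if largest:
        # The k largest end up in the last k slots along axis, unordered.
        idx = np.argpartition(x, n - k, axis=axis)
        idx = np.take(idx, np.arange(n - k, n), axis=axis)
    else:
        # The k smallest end up in the first k slots along axis, unordered.
        idx = np.argpartition(x, k - 1, axis=axis)
        idx = np.take(idx, np.arange(k), axis=axis)
    vals = np.take_along_axis(x, idx, axis=axis)
    # Sort only the k selected entries; flip for descending order.
    order = np.argsort(vals, axis=axis, kind="stable")
    if largest:
        order = np.flip(order, axis=axis)
    return (np.take_along_axis(vals, order, axis=axis),
            np.take_along_axis(idx, order, axis=axis))


x = np.array([[3, 1, 4, 1, 5], [9, 2, 6, 5, 3]])
vals, idx = topk(x, 2)  # vals == [[5, 4], [9, 6]], idx == [[4, 2], [0, 2]]
```

The partition step costs O(n) per slice and the final sort only O(k log k). This is the asymptotic advantage over a full argsort that the `cupy.argpartition` remark above is about.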
Motivation: (arg)topk is needed by popular baseline data-science workloads (e.g. k-nearest neighbors classification in scikit-learn) and is surprisingly non-trivial to implement efficiently. For instance, on GPUs the fastest implementations are based on some kind of partial radix sort, while CPU implementations would use more traditional partial sorting algorithms (as implemented in `std::partial_sort` or `std::nth_element`).