Numpy première occurrence de valeur supérieure à la valeur existante

145

J'ai un tableau 1D dans numpy et je veux trouver la position de l'index où une valeur dépasse la valeur dans le tableau numpy.

Par exemple

aa = range(-10,10)

Trouvez la position à l' aaendroit où la valeur 5est dépassée.

user308827
la source
2
Il faut savoir s'il n'y a pas de solution (puisque par exemple la réponse argmax ne fonctionnera pas dans ce cas (max de (0,0,0,0) = 0) comme l'a commenté
ambrus

Réponses:

201

C'est un peu plus rapide (et c'est plus joli)

np.argmax(aa>5)

Puisque argmaxs'arrêtera à la première True("En cas d'occurrences multiples des valeurs maximales, les indices correspondant à la première occurrence sont retournés.") Et n'enregistre pas une autre liste.

In [2]: N = 10000

In [3]: aa = np.arange(-N,N)

In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop

In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop

In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop
Askewchan
la source
103
Juste un mot d'avertissement: s'il n'y a pas de valeur True dans son tableau d'entrée, np.argmax renverra volontiers 0 (ce qui n'est pas ce que vous voulez dans ce cas).
ambrus
8
Les résultats sont corrects, mais je trouve l'explication un peu suspecte. argmaxne semble pas s'arrêter au premier True. (Cela peut être testé en créant des tableaux booléens avec un seul Trueà différentes positions.) La vitesse s'explique probablement par le fait qu'il argmaxn'est pas nécessaire de créer une liste de sortie.
DrV
1
Je pense que vous avez raison, @DrV. Mon explication visait à savoir pourquoi cela donne le résultat correct malgré l'intention originale de ne pas vraiment chercher un maximum, pas pourquoi il est plus rapide car je ne peux pas prétendre comprendre les détails intérieurs de argmax.
askewchan
1
@George, j'ai peur de ne pas savoir pourquoi exactement. Je peux seulement dire que c'est plus rapide dans l'exemple particulier que j'ai montré, donc je ne le considérerais généralement pas plus rapidement sans (i) savoir pourquoi (voir le commentaire de @ DrV) ou (ii) tester plus de cas (par exemple, si aaest trié, comme dans la réponse de @ Michael).
askewchan
3
@DrV, je viens de courir argmaxsur des tableaux booléens de 10 millions d'éléments avec un seul Trueà différentes positions en utilisant NumPy 1.11.2, et la position de l' élément Trueimporté. Donc la 1.11.2 argmaxsemble "court-circuiter" sur les tableaux booléens.
Ulrich Stern
96

étant donné le contenu trié de votre tableau, il existe une méthode encore plus rapide: searchsorted .

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]

# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop
MichaelKaisers
la source
19
C'est vraiment la meilleure réponse en supposant que le tableau est trié (ce qui n'est pas réellement spécifié dans la question). Vous pouvez éviter les embarrassants +1avecnp.searchsorted(..., side='right')
askewchan
3
Je pense que l' sideargument ne fait une différence que s'il y a des valeurs répétées dans le tableau trié. Cela ne change pas la signification de l'index renvoyé, qui est toujours l'index dans lequel vous pouvez insérer la valeur de la requête, en déplaçant toutes les entrées suivantes vers la droite et en conservant un tableau trié.
Gus
@Gus, sidea un effet lorsque la même valeur est à la fois dans le tableau trié et inséré, quelles que soient les valeurs répétées dans l'un ou l'autre. Les valeurs répétées dans le tableau trié exagèrent simplement l'effet (la différence entre les côtés est le nombre de fois où la valeur insérée apparaît dans le tableau trié). side ne change le sens de l'indice retourné, même si elle ne change pas le tableau résultant de l' insertion des valeurs dans le tableau trié à ces indices. Une distinction subtile mais importante; en fait, cette réponse donne le mauvais index si ce N/2n'est pas le cas aa.
askewchan
Comme indiqué dans le commentaire ci-dessus, cette réponse est désactivée si elle N/2ne l'est pas aa. La forme correcte serait np.searchsorted(aa, N/2, side='right')(sans le +1). Dans le cas contraire, les deux formes donnent le même indice. Considérez le cas de test comme Nétant impair (et N/2.0pour forcer float si vous utilisez python 2).
askewchan
21

Cela m'intéressait également et j'ai comparé toutes les réponses suggérées avec perfplot . (Avertissement: je suis l'auteur de perfplot.)

Si vous savez que le tableau que vous regardez est déjà trié , alors

numpy.searchsorted(a, alpha)

est pour toi. C'est une opération à temps constant, c'est-à-dire que la vitesse ne dépend pas de la taille du tableau. Vous ne pouvez pas aller plus vite que cela.

Si vous ne savez rien de votre baie, vous ne vous trompez pas avec

numpy.argmax(a > alpha)

Déjà trié:

entrez la description de l'image ici

Non trié:

entrez la description de l'image ici

Code pour reproduire le tracé:

import numpy
import perfplot


alpha = 0.5

def argmax(data):
    return numpy.argmax(data > alpha)

def where(data):
    return numpy.where(data > alpha)[0][0]

def nonzero(data):
    return numpy.nonzero(data > alpha)[0][0]

def searchsorted(data):
    return numpy.searchsorted(data, alpha)

out = perfplot.show(
    # setup=numpy.random.rand,
    setup=lambda n: numpy.sort(numpy.random.rand(n)),
    kernels=[
        argmax, where,
        nonzero,
        searchsorted
        ],
    n_range=[2**k for k in range(2, 20)],
    logx=True,
    logy=True,
    xlabel='len(array)'
    )
Nico Schlömer
la source
4
np.searchsortedn'est pas à temps constant. C'est en fait O(log(n)). Mais votre scénario de test compare en fait le meilleur des cas searchsorted(qui est O(1)).
MSeifert
@MSeifert De quel type de tableau d'entrée / alpha avez-vous besoin pour voir O (log (n))?
Nico Schlömer
1
L'obtention de l'élément à l'index sqrt (longueur) a conduit à de très mauvaises performances. J'ai également écrit une réponse ici, y compris ce point de repère.
MSeifert
Je doute que searchsorted(ou n'importe quel algorithme) puisse battre la O(log(n))recherche binaire de données triées uniformément distribuées. EDIT: searchsorted est une recherche binaire.
Mateen Ulhaq
16
In [34]: a=np.arange(-10,10)

In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)

In [37]: np.where(a>5)[0][0]
Out[37]: 16
Moj
la source
8

Tableaux qui ont un pas constant entre les éléments

Dans le cas d'un rangeou de tout autre tableau à croissance linéaire, vous pouvez simplement calculer l'index par programme, pas besoin d'itérer du tout sur le tableau:

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('no value greater than {}'.format(val))
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    # For linearly decreasing arrays or constant arrays we only need to check
    # the first element, because if that does not satisfy the condition
    # no other element will.
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

On pourrait probablement améliorer cela un peu. Je me suis assuré que cela fonctionnait correctement pour quelques exemples de tableaux et de valeurs, mais cela ne signifie pas qu'il ne pourrait pas y avoir d'erreurs, d'autant plus qu'il utilise des flottants ...

>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16]  # double check
6

>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15

Étant donné qu'il peut calculer la position sans aucune itération, ce sera un temps constant ( O(1)) et peut probablement battre toutes les autres approches mentionnées. Cependant, cela nécessite une étape constante dans le tableau, sinon cela produira des résultats erronés.

Solution générale utilisant numba

Une approche plus générale consisterait à utiliser une fonction numba:

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

Cela fonctionnera pour n'importe quel tableau, mais il doit itérer sur le tableau, donc dans le cas moyen, ce sera O(n):

>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16

Référence

Même si Nico Schlömer a déjà fourni quelques points de repère, j'ai pensé qu'il pourrait être utile d'inclure mes nouvelles solutions et de tester différentes «valeurs».

La configuration du test:

import numpy as np
import math
import numba as nb

def first_index_using_argmax(val, arr):
    return np.argmax(arr > val)

def first_index_using_where(val, arr):
    return np.where(arr > val)[0][0]

def first_index_using_nonzero(val, arr):
    return np.nonzero(arr > val)[0][0]

def first_index_using_searchsorted(val, arr):
    return np.searchsorted(arr, val) + 1

def first_index_using_min(val, arr):
    return np.min(np.where(arr > val))

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('empty array')
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

funcs = [
    first_index_using_argmax, 
    first_index_using_min, 
    first_index_using_nonzero,
    first_index_calculate_range_like, 
    first_index_numba, 
    first_index_using_searchsorted, 
    first_index_using_where
]

from simple_benchmark import benchmark, MultiArgument

et les graphiques ont été générés en utilisant:

%matplotlib notebook
b.plot()

l'élément est au début

b = benchmark(
    funcs,
    {2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

entrez la description de l'image ici

La fonction numba fonctionne mieux, suivie de la fonction de calcul et de la fonction triée par recherche. Les autres solutions fonctionnent bien moins bien.

l'article est à la fin

b = benchmark(
    funcs,
    {2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

entrez la description de l'image ici

Pour les petits tableaux, la fonction numba est incroyablement rapide, mais pour les grands tableaux, elle est surperformée par la fonction de calcul et la fonction triée par recherche.

l'élément est à sqrt (len)

b = benchmark(
    funcs,
    {2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

entrez la description de l'image ici

C'est plus intéressant. Encore une fois, numba et la fonction de calcul fonctionnent très bien, mais cela déclenche en fait le pire des cas de recherche triée qui ne fonctionne vraiment pas bien dans ce cas.

Comparaison des fonctions lorsqu'aucune valeur ne satisfait la condition

Un autre point intéressant est le comportement de ces fonctions s'il n'y a pas de valeur dont l'index doit être retourné:

arr = np.ones(100)
value = 2

for func in funcs:
    print(func.__name__)
    try:
        print('-->', func(value, arr))
    except Exception as e:
        print('-->', e)

Avec ce résultat:

first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0

Searchsorted, argmax et numba renvoient simplement une valeur incorrecte. Cependant searchsortedetnumba retournez un index qui n'est pas un index valide pour le tableau.

Les fonctions where, min, nonzeroet calculatejettent une exception. Cependant, seule l'exception pour calculatedit réellement quelque chose d'utile.

Cela signifie qu'il faut en fait encapsuler ces appels dans une fonction wrapper appropriée qui intercepte les exceptions ou les valeurs de retour non valides et les gère de manière appropriée, du moins si vous n'êtes pas sûr que la valeur puisse être dans le tableau.


Remarque: les searchsortedoptions de calcul et de calcul ne fonctionnent que dans des conditions spéciales. La fonction "calculer" nécessite une étape constante et la recherche triée nécessite que le tableau soit trié. Donc, ceux-ci pourraient être utiles dans les bonnes circonstances, mais ne sont pas des solutions générales à ce problème. Si vous avez affaire à triées listes Python vous pouvez jeter un oeil à la bissectrice module au lieu d'utiliser Numpys searchsorted.

MSeifert
la source
3

Je voudrais proposer

np.min(np.append(np.where(aa>5)[0],np.inf))

Cela retournera le plus petit index où la condition est remplie, tout en retournant l'infini si la condition n'est jamais remplie (et whererenvoie un tableau vide).

Mfeldt
la source
1

J'irais avec

i = np.min(np.where(V >= x))

Vest vector (tableau 1d), xest la valeur et iest l'index résultant.

sivic
la source