Vérification rapide de NaN dans NumPy

120

Je recherche le moyen le plus rapide de vérifier l'occurrence de NaN ( np.nan) dans un tableau NumPy X. np.isnan(X)est hors de question, car il construit un tableau booléen de forme X.shape, ce qui est potentiellement gigantesque.

J'ai essayé np.nan in X, mais cela ne semble pas fonctionner parce que np.nan != np.nan. Existe-t-il un moyen rapide et efficace de faire cela?

(À ceux qui demanderaient «à quel point c'est gigantesque»: je ne peux pas le dire. C'est la validation d'entrée pour le code de la bibliothèque.)

Fred Foo
la source
la validation de l'entrée utilisateur ne fonctionne-t-elle pas dans ce scénario? Comme pour NaN avant l'insertion
Woot4Moo
@ Woot4Moo: non, la bibliothèque prend scipy.sparseen entrée des tableaux ou des matrices NumPy .
Fred Foo
2
Si vous faites beaucoup cela, j'ai entendu de bonnes choses à propos de Bottleneck ( pypi.python.org/pypi/Bottleneck )
mat

Réponses:

161

La solution de Ray est bonne. Cependant, sur ma machine, il est environ 2,5 fois plus rapide à utiliser numpy.sumà la place de numpy.min:

In [13]: %timeit np.isnan(np.min(x))
1000 loops, best of 3: 244 us per loop

In [14]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 97.3 us per loop

Contrairement à min, sumne nécessite pas de branchement, ce qui sur le matériel moderne a tendance à être assez coûteux. C'est probablement la raison pour laquelle sumc'est plus rapide.

edit Le test ci-dessus a été effectué avec un seul NaN en plein milieu du tableau.

Il est intéressant de noter que minc'est plus lent en présence de NaN qu'en leur absence. Il semble également devenir plus lent à mesure que les NaN se rapprochent du début du tableau. D'un autre côté, sumle débit semble constant, qu'il y ait ou non des NaN et où ils se trouvent:

In [40]: x = np.random.rand(100000)

In [41]: %timeit np.isnan(np.min(x))
10000 loops, best of 3: 153 us per loop

In [42]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 95.9 us per loop

In [43]: x[50000] = np.nan

In [44]: %timeit np.isnan(np.min(x))
1000 loops, best of 3: 239 us per loop

In [45]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 95.8 us per loop

In [46]: x[0] = np.nan

In [47]: %timeit np.isnan(np.min(x))
1000 loops, best of 3: 326 us per loop

In [48]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 95.9 us per loop
NPE
la source
1
np.minest plus rapide lorsque le tableau ne contient pas de NaN, ce qui est mon entrée attendue. Mais j'ai quand même décidé d'accepter celui-ci, car il attrape infet neginfaussi.
Fred Foo
2
Cela ne les prises infou -infsi l'entrée contient à la fois, et il a des problèmes si l'entrée contient de grandes valeurs , mais finies qui débordent lorsqu'ils sont additionnés.
user2357112 prend en charge Monica
4
min et max n'ont pas besoin de créer des branches pour les données en virgule flottante sur les puces x86 compatibles sse. Donc, à partir de numpy 1,8 min ne sera pas plus lent que la somme, sur mon phénomène amd, c'est encore 20% plus rapide.
jtaylor le
1
Sur mon Intel Core i5, avec numpy 1.9.2 sur OSX, np.sumest toujours environ 30% plus rapide que np.min.
Matthew Brett
np.isnan(x).any(0)est légèrement plus rapide que np.sumet np.minsur ma machine, bien qu'il puisse y avoir une mise en cache indésirable.
jsignell
28

Je pense np.isnan(np.min(X))que tu devrais faire ce que tu veux.

Rayon
la source
Hmmm ... c'est toujours O (n) alors que ça pourrait être O (1) (pour certains tableaux).
user48956
17

Même s'il existe une réponse acceptée, j'aimerais démontrer ce qui suit (avec Python 2.7.2 et Numpy 1.6.0 sur Vista):

In []: x= rand(1e5)
In []: %timeit isnan(x.min())
10000 loops, best of 3: 200 us per loop
In []: %timeit isnan(x.sum())
10000 loops, best of 3: 169 us per loop
In []: %timeit isnan(dot(x, x))
10000 loops, best of 3: 134 us per loop

In []: x[5e4]= NaN
In []: %timeit isnan(x.min())
100 loops, best of 3: 4.47 ms per loop
In []: %timeit isnan(x.sum())
100 loops, best of 3: 6.44 ms per loop
In []: %timeit isnan(dot(x, x))
10000 loops, best of 3: 138 us per loop

Ainsi, la manière vraiment efficace pourrait dépendre fortement du système d'exploitation. Quoi qu'il en soit, dot(.)basé semble être le plus stable.

manger
la source
1
Je soupçonne que cela ne dépend pas tant du système d'exploitation que de l'implémentation BLAS sous-jacente et du compilateur C. Merci, mais un produit scalaire est juste un peu plus susceptible de déborder lorsqu'il xcontient de grandes valeurs, et je souhaite également vérifier les inf.
Fred Foo
1
Eh bien, vous pouvez toujours faire le produit scalaire avec ceux et les utiliser isfinite(.). Je voulais juste souligner l'énorme écart de performance. Merci
manger le
La même chose sur ma machine.
kawing-chiu
1
Intelligent, non? Comme le suggère Fred Foo , les gains d'efficacité de l'approche basée sur le produit scalaire sont presque certainement dus à une installation NumPy locale liée à une implémentation BLAS optimisée comme ATLAS, MKL ou OpenBLAS. C'est le cas d'Anaconda, par exemple. Compte tenu de cela, ce produit scalaire sera parallélisé sur tous les cœurs disponibles. La même chose ne peut pas en dire autant pour les min- ou les sumapproches à base, qui vont limités à un seul noyau. Ergo, cet écart de performance.
Cecil Curry
16

Il existe deux approches générales ici:

  • Vérifiez chaque élément du tableau nanet prenez any.
  • Appliquez une opération cumulative qui préserve nans (comme sum) et vérifiez son résultat.

Alors que la première approche est certainement la plus propre, l'optimisation lourde de certaines des opérations cumulatives (en particulier celles qui sont exécutées dans BLAS, par exemple dot) peut les rendre assez rapides. Notez que dot, comme certaines autres opérations BLAS, sont multithread sous certaines conditions. Ceci explique la différence de vitesse entre les différentes machines.

entrez la description de l'image ici

import numpy
import perfplot


def min(a):
    return numpy.isnan(numpy.min(a))


def sum(a):
    return numpy.isnan(numpy.sum(a))


def dot(a):
    return numpy.isnan(numpy.dot(a, a))


def any(a):
    return numpy.any(numpy.isnan(a))


def einsum(a):
    return numpy.isnan(numpy.einsum("i->", a))


perfplot.show(
    setup=lambda n: numpy.random.rand(n),
    kernels=[min, sum, dot, any, einsum],
    n_range=[2 ** k for k in range(20)],
    logx=True,
    logy=True,
    xlabel="len(a)",
)
Nico Schlömer
la source
4
  1. utiliser .any ()

    if numpy.isnan(myarray).any()

  2. numpy.isfinite peut-être mieux que isnan pour vérifier

    if not np.isfinite(prop).all()

woso
la source
3

Si vous êtes à l'aise avec il permet de créer une fonction de court-circuit rapide (s'arrête dès qu'un NaN est trouvé):

import numba as nb
import math

@nb.njit
def anynan(array):
    array = array.ravel()
    for i in range(array.size):
        if math.isnan(array[i]):
            return True
    return False

S'il n'y a pas, NaNla fonction pourrait en fait être plus lente que np.min, je pense que c'est parce que np.minutilise le multitraitement pour les grands tableaux:

import numpy as np
array = np.random.random(2000000)

%timeit anynan(array)          # 100 loops, best of 3: 2.21 ms per loop
%timeit np.isnan(array.sum())  # 100 loops, best of 3: 4.45 ms per loop
%timeit np.isnan(array.min())  # 1000 loops, best of 3: 1.64 ms per loop

Mais au cas où il y aurait un NaN dans le tableau, surtout si sa position est à des indices bas, alors c'est beaucoup plus rapide:

array = np.random.random(2000000)
array[100] = np.nan

%timeit anynan(array)          # 1000000 loops, best of 3: 1.93 µs per loop
%timeit np.isnan(array.sum())  # 100 loops, best of 3: 4.57 ms per loop
%timeit np.isnan(array.min())  # 1000 loops, best of 3: 1.65 ms per loop

Des résultats similaires peuvent être obtenus avec Cython ou une extension C, ceux-ci sont un peu plus compliqués (ou facilement disponibles en tant que bottleneck.anynan) mais font finalement la même chose que ma anynanfonction.

MSeifert
la source
1

Liée à ceci est la question de savoir comment trouver la première occurrence de NaN. C'est le moyen le plus rapide de gérer ce que je connais:

index = next((i for (i,n) in enumerate(iterable) if n!=n), None)
vitiral
la source