Convertir la colonne Spark DataFrame en liste Python

104

Je travaille sur un dataframe avec deux colonnes, mvv et count.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

Je voudrais obtenir deux listes contenant les valeurs mvv et la valeur de comptage. Quelque chose comme

mvv = [1,2,3,4]
count = [5,9,3,1]

Donc, j'ai essayé le code suivant: La première ligne devrait renvoyer une liste de lignes python. Je voulais voir la première valeur:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

Mais j'obtiens un message d'erreur avec la deuxième ligne:

AttributeError: getInt

a.moussa
la source
Comme Spark 2.3, ce code est le plus rapide et moins susceptible de causer des exceptions OutOfMemory: list(df.select('mvv').toPandas()['mvv']). Arrow a été intégré à PySpark, ce qui a considérablement accéléré toPandas. N'utilisez pas les autres approches si vous utilisez Spark 2.3+. Voir ma réponse pour plus de détails sur l'analyse comparative.
Pouvoirs le

Réponses:

140

Voyez, pourquoi cette façon que vous faites ne fonctionne pas. Tout d'abord, vous essayez d'obtenir un entier à partir d'un type de ligne , la sortie de votre collecte est comme ceci:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

Si vous prenez quelque chose comme ça:

>>> firstvalue = mvv_list[0].mvv
Out: 1

Vous obtiendrez la mvvvaleur. Si vous voulez toutes les informations du tableau, vous pouvez prendre quelque chose comme ceci:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

Mais si vous essayez la même chose pour l'autre colonne, vous obtenez:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

Cela se produit car il counts'agit d'une méthode intégrée. Et la colonne porte le même nom que count. Une solution de contournement pour ce faire est de modifier le nom de la colonne de counten _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

Mais cette solution de contournement n'est pas nécessaire, car vous pouvez accéder à la colonne à l'aide de la syntaxe du dictionnaire:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

Et cela fonctionnera enfin!

Thiago Baldim
la source
cela fonctionne très bien pour la première colonne, mais cela ne fonctionne pas pour le nombre de colonnes que je pense à cause de (le nombre de fonctions d'étincelle)
a.moussa
Pouvez-vous ajouter ce que vous faites avec le décompte? Ajoutez ici dans les commentaires.
Thiago Baldim
merci pour votre réponse Donc cette ligne fonctionne mvv_list = [int (i.mvv) pour i dans mvv_count.select ('mvv'). collect ()] mais pas celle-ci count_list = [int (i.count) pour i dans mvv_count .select ('count'). collect ()] renvoie une syntaxe invalide
a.moussa
Vous n'avez pas besoin d'ajouter cette select('count')utilisation comme ceci: count_list = [int(i.count) for i in mvv_list.collect()]j'ajouterai l'exemple à la réponse.
Thiago Baldim
1
@ a.moussa [i.['count'] for i in mvv_list.collect()]travaille pour rendre explicite l'utilisation de la colonne nommée 'count' et non la countfonction
user989762
103

Suivre une ligne donne la liste que vous voulez.

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()
Néo
la source
3
En termes de performances, cette solution est beaucoup plus rapide que votre solution mvv_list = [int (i.mvv) for i in mvv_count.select ('mvv'). Collect ()]
Chanaka Fernando
C'est de loin la meilleure solution que j'ai vue. Merci.
hui chen
22

Cela vous donnera tous les éléments sous forme de liste.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)
Muhammad Raihan Muhaimin
la source
1
Il s'agit de la solution la plus rapide et la plus efficace pour Spark 2.3+. Voir les résultats de l'analyse comparative dans ma réponse.
Pouvoirs le
16

Le code suivant vous aidera

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()
Itachi
la source
3
Cela devrait être la réponse acceptée. la raison en est que vous restez dans un contexte d'étincelle tout au long du processus et que vous collectez à la fin plutôt que de sortir du contexte d'étincelle plus tôt, ce qui peut entraîner une collecte plus importante en fonction de ce que vous faites.
AntiPawn79
15

Sur mes données, j'ai ces points de repère:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0,52 seconde

>>> [row[col] for row in data.collect()]

0,271 seconde

>>> list(data.select(col).toPandas()[col])

0,427 seconde

le résultat est le même

lumineux
la source
1
Si vous utilisez à la toLocalIteratorplace, collectcela devrait même être plus efficace en mémoire[row[col] for row in data.toLocalIterator()]
oglop
6

Si vous obtenez l'erreur ci-dessous:

AttributeError: l'objet 'list' n'a pas d'attribut 'collect'

Ce code résoudra vos problèmes:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]
anirban sen
la source
J'ai eu cette erreur aussi et cette solution a résolu le problème. Mais pourquoi ai-je eu l'erreur? (Beaucoup d'autres ne semblent pas comprendre cela!)
bikashg
2

J'ai effectué une analyse comparative et list(mvv_count_df.select('mvv').toPandas()['mvv'])c'est la méthode la plus rapide. Je suis très surpris.

J'ai exécuté les différentes approches sur 100 mille / 100 millions d'ensembles de données de lignes en utilisant un cluster i3.xlarge à 5 nœuds (chaque nœud a 30,5 Go de RAM et 4 cœurs) avec Spark 2.4.5. Les données ont été uniformément réparties sur 20 fichiers Parquet compressés avec une seule colonne.

Voici les résultats de l'analyse comparative (durées d'exécution en secondes):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Règles d'or à suivre lors de la collecte de données sur le nœud du pilote:

  • Essayez de résoudre le problème avec d'autres approches. La collecte de données vers le nœud de pilote est coûteuse, n'utilise pas la puissance du cluster Spark et doit être évitée autant que possible.
  • Collectez le moins de lignes possible. Agréger, dédupliquer, filtrer et élaguer les colonnes avant de collecter les données. Envoyez le moins de données possible au nœud du pilote.

toPandas a été considérablement amélioré dans Spark 2.3 . Ce n'est probablement pas la meilleure approche si vous utilisez une version Spark antérieure à 2.3.

Voir ici pour plus de détails / résultats d'analyse comparative.

Pouvoirs
la source
2

Une solution possible consiste à utiliser la collect_list()fonction de pyspark.sql.functions. Cela agrégera toutes les valeurs de colonne dans un tableau pyspark qui est converti en une liste python lors de la collecte:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
phgui
la source