Setter de propriété pour la sous-classe de Pandas DataFrame

9

J'essaie de mettre en place une sous-classe pd.DataFramequi a deux arguments requis lors de l'initialisation ( groupet timestamp_col). Je veux exécuter la validation sur ces arguments groupet timestamp_colj'ai donc une méthode de définition pour chacune des propriétés. Tout cela fonctionne jusqu'à ce que j'essaye de le set_index()faire TypeError: 'NoneType' object is not iterable. Il semble qu'aucun argument ne soit transmis à ma fonction de définition dans test_set_indexet test_assignment_with_indexed_obj. Si j'ajoute if g == None: returnà ma fonction setter, je peux passer les cas de test mais je ne pense pas que ce soit la bonne solution.

Comment dois-je implémenter la validation des propriétés pour ces arguments requis?

Voici ma classe:

import pandas as pd
import numpy as np


class HistDollarGains(pd.DataFrame):
    @property
    def _constructor(self):
        return HistDollarGains._internal_ctor

    _metadata = ["group", "timestamp_col", "_group", "_timestamp_col"]

    @classmethod
    def _internal_ctor(cls, *args, **kwargs):
        kwargs["group"] = None
        kwargs["timestamp_col"] = None
        return cls(*args, **kwargs)

    def __init__(
        self,
        data,
        group,
        timestamp_col,
        index=None,
        columns=None,
        dtype=None,
        copy=True,
    ):
        super(HistDollarGains, self).__init__(
            data=data, index=index, columns=columns, dtype=dtype, copy=copy
        )

        self.group = group
        self.timestamp_col = timestamp_col

    @property
    def group(self):
        return self._group

    @group.setter
    def group(self, g):
        if g == None:
            return

        if isinstance(g, str):
            group_list = [g]
        else:
            group_list = g

        if not set(group_list).issubset(self.columns):
            raise ValueError("Data does not contain " + '[' + ', '.join(group_list) + ']')
        self._group = group_list

    @property
    def timestamp_col(self):
        return self._timestamp_col

    @timestamp_col.setter
    def timestamp_col(self, t):
        if t == None:
            return
        if not t in self.columns:
            raise ValueError("Data does not contain " + '[' + t + ']')
        self._timestamp_col = t

Voici mes cas de test:

import pytest

import pandas as pd
import numpy as np

from myclass import *


@pytest.fixture(scope="module")
def sample():
    samp = pd.DataFrame(
        [
            {"timestamp": "2020-01-01", "group": "a", "dollar_gains": 100},
            {"timestamp": "2020-01-01", "group": "b", "dollar_gains": 100},
            {"timestamp": "2020-01-01", "group": "c", "dollar_gains": 110},
            {"timestamp": "2020-01-01", "group": "a", "dollar_gains": 110},
            {"timestamp": "2020-01-01", "group": "b", "dollar_gains": 90},
            {"timestamp": "2020-01-01", "group": "d", "dollar_gains": 100},
        ]
    )

    return samp

@pytest.fixture(scope="module")
def sample_obj(sample):
    return HistDollarGains(sample, "group", "timestamp")

def test_constructor_without_args(sample):
    with pytest.raises(TypeError):
        HistDollarGains(sample)


def test_constructor_with_string_group(sample):
    hist_dg = HistDollarGains(sample, "group", "timestamp")
    assert hist_dg.group == ["group"]
    assert hist_dg.timestamp_col == "timestamp"


def test_constructor_with_list_group(sample):
    hist_dg = HistDollarGains(sample, ["group", "timestamp"], "timestamp")

def test_constructor_with_invalid_group(sample):
    with pytest.raises(ValueError):
        HistDollarGains(sample, "invalid_group", np.random.choice(sample.columns))

def test_constructor_with_invalid_timestamp(sample):
    with pytest.raises(ValueError):
        HistDollarGains(sample, np.random.choice(sample.columns), "invalid_timestamp")

def test_assignment_with_indexed_obj(sample_obj):
    b = sample_obj.set_index(sample_obj.group + [sample_obj.timestamp_col])

def test_set_index(sample_obj):
    # print(isinstance(a, pd.DataFrame))
    assert sample_obj.set_index(sample_obj.group + [sample_obj.timestamp_col]).index.names == ['group', 'timestamp']
cpage
la source
1
Si la Nonevaleur de la grouppropriété n'est pas valide , ne devriez-vous pas lever un ValueError?
chepner
1
Vous avez raison, il Nones'agit d'une valeur non valide, c'est pourquoi je n'aime pas l'instruction if. Mais en ajoutant que None, il réussit les tests. Je cherche comment résoudre ce problème sans l'instruction None if.
cpage
2
Le passeur doit relever a ValueError. Le problème est de comprendre ce qui essaie de définir l' groupattribut Noneen premier lieu.
chepner
@chepner oui, exactement.
cpage
Peut-être que le package Pandas Flavor peut vous aider.
Mykola Zotko le

Réponses:

3

La set_index()méthode appellera en self.copy()interne pour créer une copie de votre objet DataFrame (voir le code source ici ), à l'intérieur duquel elle utilise votre méthode constructeur personnalisée _internal_ctor(), pour créer le nouvel objet ( source ). Notez qu'il self._constructor()est identique à self._internal_ctor(), qui est une méthode interne commune à presque toutes les classes pandas pour créer de nouvelles instances lors d'opérations telles que la copie profonde ou le découpage. Votre problème provient en fait de cette fonction:

class HistDollarGains(pd.DataFrame):
    ...
    @classmethod
    def _internal_ctor(cls, *args, **kwargs):
        kwargs["group"]         = None
        kwargs["timestamp_col"] = None
        return cls(*args, **kwargs) # this is equivalent to calling
                                    # HistDollarGains(data, group=None, timestamp_col=None)

Je suppose que vous avez copié ce code à partir du problème github . Les lignes kwargs["**"] = Noneindiquent explicitement au constructeur de définir Noneà la fois groupet timestamp_col. Enfin, le setter / validateur obtient Nonela nouvelle valeur et déclenche une erreur.

Par conséquent, vous devez définir une valeur acceptable sur groupet timestamp_col.

    @classmethod
    def _internal_ctor(cls, *args, **kwargs):
        kwargs["group"]         = []
        kwargs["timestamp_col"] = 'timestamp' # or whatever name that makes your validator happy
        return cls(*args, **kwargs)

Ensuite, vous pouvez supprimer les if g == None: returnlignes dans le validateur.

gdlmx
la source