pivot_table 源码解析

所以为什么pivot_table会快呢?

首先concat在不断拼接的过程耗时确实是越来越长的,这个底层暂时没有了解

但是pivot_table是怎么做长表宽表的转换呢?

其实pivot_table本质上是一个透视表,它之所以能做长宽表的数据转换,是因为透视本身会把列键展开成列,但这种用法还是有一定局限性的

源码:https://github.com/pandas-dev/pandas/blob/v1.0.4/pandas/core/reshape/pivot.py#L25-L186

from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

import numpy as np

from pandas.util._decorators import Appender, Substitution

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import is_integer_dtype, is_list_like, is_scalar
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries

import pandas.core.common as com
from pandas.core.frame import _shared_docs
from pandas.core.groupby import Grouper
#from pandas.core.indexes.api import Index, MultiIndex, get_objs_combined_axis
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import cartesian_product
from pandas.core.series import Series

if TYPE_CHECKING:
    from pandas import DataFrame

def pivot_table(
    data,
    values=None,
    index=None,
    columns=None,
    aggfunc="mean",
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name="All",
    observed=False,
) -> "DataFrame":
    """
    Build a pivot table from ``data`` using ``groupby`` followed by ``unstack``.

    This is an instrumented study copy of the pandas implementation: the
    ``print`` calls are kept on purpose so the intermediate results can be
    followed step by step.

    Parameters
    ----------
    data : DataFrame
        Input table; it is grouped by the ``index`` and ``columns`` keys.
    values : label or list of labels, optional
        Column(s) to aggregate. When omitted, every column of ``data`` that
        is not used as a grouping key is aggregated.
    index, columns : label, Grouper, array, callable or list of them, optional
        Grouping keys. ``index`` keys stay on the row axis, ``columns`` keys
        are unstacked onto the column axis. ``None`` means "no keys".
    aggfunc : function, str or list
        Passed to ``groupby(...).agg``. For a list, one table is built per
        function and the pieces are concatenated along axis 1, keyed by the
        function's ``__name__`` (or the entry itself when it has none).
    fill_value : scalar, optional
        When not ``None``, missing cells of the result are filled with it.
    margins : bool
        When true, the result is passed through ``_add_margins``.
    dropna : bool
        When true, all-NaN rows of the aggregated frame and all-NaN columns
        of the result are dropped. When false, multi-level axes are
        reindexed to the full cartesian product of their levels.
    margins_name : str
        Forwarded to ``_add_margins``.
    observed : bool
        Forwarded to ``data.groupby``.

    Returns
    -------
    DataFrame
        The pivoted table.

    Raises
    ------
    KeyError
        If a label in ``values`` is not present in ``data``.
    """
    # Normalize the keys to lists first: `index + columns` below would fail
    # for the default ``None`` and mis-concatenate plain string labels.
    index = _convert_by(index)
    columns = _convert_by(columns)

    if isinstance(aggfunc, list):
        print("juge is instance")
        pieces: List[DataFrame] = []
        keys = []
        for func in aggfunc:
            table = pivot_table(
                data,
                values=values,
                index=index,
                columns=columns,
                fill_value=fill_value,
                aggfunc=func,
                margins=margins,
                dropna=dropna,
                margins_name=margins_name,
                observed=observed,
            )
            pieces.append(table)
            keys.append(getattr(func, "__name__", func))

        return concat(pieces, keys=keys, axis=1)

    keys = index + columns

    # Normalize `values` to a list: a scalar label such as 'values'
    # becomes ['values'].
    values_passed = values is not None
    print(values_passed)
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        to_filter = []
        for x in keys + values:
            print(x)
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                # Unhashable keys (e.g. arrays) cannot be looked up as labels.
                pass
        if len(to_filter) < len(data.columns):
            # Narrow the frame to only the columns that are actually needed.
            data = data[to_filter]

    else:
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    grouped = data.groupby(keys, observed=observed)
    agged = grouped.agg(aggfunc)
    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")

        # gh-21133
        # we want to down cast if
        # the original values are ints
        # as we grouped with a NaN value
        # and then dropped, coercing to floats
        for v in values:
            if (
                v in data
                and is_integer_dtype(data[v])
                and v in agged
                and not is_integer_dtype(agged[v])
            ):
                agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

    table = agged
    print(table)
    # More than one index level means there are column keys to unstack.
    if table.index.nlevels > 1:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        # The first len(index) level names are the row keys; the remaining
        # levels are the column keys.
        index_names = agged.index.names[: len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        print(to_unstack)
        table = agged.unstack(to_unstack)
        print(table)

    if not dropna:
        # Imported here because this snippet has no module-level import
        # that provides MultiIndex.
        from pandas import MultiIndex

        if table.index.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0)

        if table.columns.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1)

    if isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        _table = table.fillna(fill_value, downcast="infer")
        assert _table is not None  # needed for mypy
        table = _table

    if margins:
        if dropna:
            data = data[data.notna().all(axis=1)]
        # NOTE(review): `_add_margins` is not defined in this snippet; it is
        # presumably pandas' private helper from the same upstream module —
        # confirm it is available before calling with margins=True.
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if (
        values_passed
        and not values_multi
        and not table.empty
        and (table.columns.nlevels > 1)
    ):
        table = table[values[0]]

    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table


def _convert_by(by):
    """Normalize an ``index``/``columns`` argument to a list of grouping keys."""
    # Imported here because this snippet has no module-level import that
    # provides Index.
    from pandas import Index

    if by is None:
        return []
    if (
        is_scalar(by)
        or isinstance(by, (np.ndarray, Index, ABCSeries, Grouper))
        or callable(by)
    ):
        # A single key (label, array-like, Grouper or function) -> one-item list.
        return [by]
    return list(by)

代码是我自己调试的时候打印了中间过程一步一步看的

pivot_table的核心在于透视表,基于groupby 和 unstack来处理

所以整体上没有concat的过程,效率确实高于筛选再进行concat (我竟然没有想到用groupby unstack也是太蠢了)

但是pivot_table这套逻辑做长宽表转换有一定局限

每个主键(index 与 columns 的组合)必须只对应一个值;如果一个主键对应多个值,或者这些值中有空值,groupby的聚合函数会把它们合并成一个结果,最后得到的就不是我们想要的原始数据

存在多个值的时候,可以自己灵活添加一个键让联合主键唯一,这样也能够利用pivot_table来处理

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章