pivot_table Source Code Walkthrough

So why is pivot_table fast?

First, concat really does take longer and longer as you keep appending pieces one at a time; I have not dug into the underlying reason yet.
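
As a rough illustration (my own toy example, not from the code discussed here), concatenating inside a loop copies all of the accumulated rows on every iteration, while collecting the pieces and concatenating once avoids that:

import pandas as pd

pieces = [pd.DataFrame({"a": range(1000)}) for _ in range(100)]

# slow: the result grows inside the loop, and each pd.concat copies every row seen so far
slow = pieces[0]
for piece in pieces[1:]:
    slow = pd.concat([slow, piece])

# faster: collect the pieces first and concatenate once
fast = pd.concat(pieces)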

But how does pivot_table actually convert a long table into a wide one?

pivot_table is, after all, a pivot-table function, so the way it performs long-to-wide conversion comes with certain limitations.

Source: https://github.com/pandas-dev/pandas/blob/v1.0.4/pandas/core/reshape/pivot.py#L25-L186

from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

import numpy as np

from pandas.util._decorators import Appender, Substitution

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import is_integer_dtype, is_list_like, is_scalar
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries

import pandas.core.common as com
from pandas.core.frame import _shared_docs
from pandas.core.groupby import Grouper
from pandas.core.indexes.api import Index, MultiIndex, get_objs_combined_axis  # restored: MultiIndex is used in the dropna=False branch
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import cartesian_product
from pandas.core.series import Series

if TYPE_CHECKING:
    from pandas import DataFrame

def pivot_table(
    data,
    values=None,
    index=None,
    columns=None,
    aggfunc="mean",
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name="All",
    observed=False,
) -> "DataFrame":
    # _convert_by (defined elsewhere in pivot.py) normally wraps a scalar index/columns
    # argument into a list here; it is commented out in this debug copy, so index and
    # columns must already be passed as lists
    #index = _convert_by(index)
    #columns = _convert_by(columns)

    if isinstance(aggfunc, list):
        print("aggfunc is a list")
        pieces: List[DataFrame] = []
        keys = []
        for func in aggfunc:
            table = pivot_table(
                data,
                values=values,
                index=index,
                columns=columns,
                fill_value=fill_value,
                aggfunc=func,
                margins=margins,
                dropna=dropna,
                margins_name=margins_name,
                observed=observed,
            )
            pieces.append(table)
            keys.append(getattr(func, "__name__", func))

        return concat(pieces, keys=keys, axis=1)

    keys = index + columns
    #print(keys)
    
    # if values is a scalar, wrap it in a list: "values" becomes ["values"]; list-likes are kept as lists
    values_passed = values is not None
    print(values_passed)
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        to_filter = []
        for x in keys + values:
            print(x)
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                pass
        if len(to_filter) < len(data.columns):
            # narrow data down to just the needed key and value columns
            data = data[to_filter]

    else:
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    # the heart of pivot_table: a single groupby over all index + columns keys,
    # followed by one aggregation
    grouped = data.groupby(keys, observed=observed)
    agged = grouped.agg(aggfunc)
    #print(agged)
    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")
        #print(agged)

        # gh-21133
        # we want to down cast if
        # the original values are ints
        # as we grouped with a NaN value
        # and then dropped, coercing to floats
        for v in values:
#             print(v in data)
#             print(data[v])
#             print(is_integer_dtype(data[v]))
#             print(v in agged)
#             print(agged[v])
#             print(not is_integer_dtype(agged[v]))
            if (
                v in data
                and is_integer_dtype(data[v])
                and v in agged
                and not is_integer_dtype(agged[v])
            ):
                agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)
                #print(agged[v])

    table = agged
    print(table)
    #print(table.index.nlevels)
    # check how many levels the aggregated index has
    if table.index.nlevels > 1:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[: len(index)]
        # the first len(index) level names belong to the row index; the remaining levels are the columns keys
        #print(index_names)
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        print(to_unstack)
        table = agged.unstack(to_unstack)
        print(table)

    # with dropna=False, reindex against the cartesian product of the level values
    # so that empty index/column combinations are kept in the result
    if not dropna:
        if table.index.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0)

        if table.columns.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1)

    if isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        _table = table.fillna(fill_value, downcast="infer")
        assert _table is not None  # needed for mypy
        table = _table

    # margins=True appends the "All" totals row/column via _add_margins
    # (defined further down in pivot.py, not shown in this excerpt)
    if margins:
        if dropna:
            data = data[data.notna().all(axis=1)]
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if (
        values_passed
        and not values_multi
        and not table.empty
        and (table.columns.nlevels > 1)
    ):
        table = table[values[0]]

    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table

The code above is my own debug copy: I added print statements so I could inspect the intermediate results step by step.

The core of pivot_table is the pivot itself, built on groupby and unstack.

So there is no concat step at all, and it is indeed more efficient than filtering and then concatenating (I can't believe I never thought of using groupby + unstack myself).
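
A minimal sketch of that equivalence, with made-up column names (id, field, value) rather than the original data:

import pandas as pd

df = pd.DataFrame({
    "id": ["a", "a", "b", "b"],
    "field": ["x", "y", "x", "y"],
    "value": [1.0, 2.0, 3.0, 4.0],
})

# roughly what pivot_table does internally: one groupby over index + columns keys, then unstack
wide_manual = df.groupby(["id", "field"])["value"].mean().unstack("field")

# the equivalent pivot_table call
wide_pivot = df.pivot_table(values="value", index="id", columns="field", aggfunc="mean")

assert wide_manual.equals(wide_pivot)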

However, this pivot_table approach to long-to-wide conversion has its limitations.

Each composite key in the data must map to exactly one value; if a key maps to multiple values, or some of those values are missing, the groupby aggregation will not produce the result we want.
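
For example, with a toy DataFrame of my own (not from the original post) where one key appears twice:

import pandas as pd

df = pd.DataFrame({
    "id": ["a", "a", "b"],        # the key ("a", "x") appears twice
    "field": ["x", "x", "x"],
    "value": [1.0, 3.0, 5.0],
})

# the duplicated ("a", "x") rows are silently collapsed to their mean (2.0),
# which is probably not the long-to-wide result we were after
print(df.pivot_table(values="value", index="id", columns="field", aggfunc="mean"))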

When a key does map to multiple values, you can add an extra column yourself to make the composite key unique, and then pivot_table can still handle the conversion.
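
One possible way to do that, assuming a simple per-key sequence number built with groupby().cumcount() is acceptable (my own sketch, not from the original post):

import pandas as pd

df = pd.DataFrame({
    "id": ["a", "a", "b"],
    "field": ["x", "x", "x"],
    "value": [1.0, 3.0, 5.0],
})

# number the repeated rows within each (id, field) group so the composite key becomes unique
df["seq"] = df.groupby(["id", "field"]).cumcount()

# with ("id", "seq") as the index, every key maps to exactly one value and nothing is collapsed
print(df.pivot_table(values="value", index=["id", "seq"], columns="field", aggfunc="mean"))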
