1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
|
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Iterable, Iterator, Sequence, TypeVar
from narwhals._expression_parsing import all_exprs_are_scalar_like
from narwhals._utils import flatten, tupleify
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import DataFrameT
if TYPE_CHECKING:
from narwhals._compliant.typing import CompliantExprAny
from narwhals.dataframe import LazyFrame
from narwhals.expr import Expr
# Type variable for the narwhals LazyFrame wrapped by `LazyGroupBy`. It is bound
# to `LazyFrame[Any]` so that `LazyGroupBy.agg` can be annotated to return the
# same frame type the group-by was constructed from.
LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]")
class GroupBy(Generic[DataFrameT]):
    """Grouped view over an eager narwhals DataFrame.

    On construction, the compliant frame backing ``df`` is grouped by
    ``keys``. :meth:`agg` then computes per-group aggregations, and
    iterating the instance yields each group key together with its sub-frame.
    """

    def __init__(
        self,
        df: DataFrameT,
        keys: Sequence[str] | Sequence[CompliantExprAny],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Group the compliant frame backing ``df`` by ``keys``.

        Arguments:
            df: The narwhals frame to group.
            keys: Column names, or compliant expressions, to group by.
            drop_null_keys: Forwarded unchanged to the compliant ``group_by``;
                presumably controls whether groups with a null key are dropped
                (confirm against the backend implementation).
        """
        self._df: DataFrameT = df
        self._keys = keys
        # The backend grouping is built once, eagerly, and shared by both
        # `agg` and `__iter__`.
        self._grouped = self._df._compliant_frame.group_by(
            self._keys, drop_null_keys=drop_null_keys
        )

    def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT:
        """Compute aggregations for each group of a group by operation.

        Arguments:
            aggs: Aggregations to compute for each group of the group by operation,
                specified as positional arguments.
            named_aggs: Additional aggregations, specified as keyword arguments.

        Returns:
            A new DataFrame.

        Raises:
            InvalidOperationError: If any expression does not aggregate.

        Examples:
            Group by one column or by multiple columns and call `agg` to compute
            the grouped sum of another column.

            >>> import pandas as pd
            >>> import narwhals as nw
            >>> df_native = pd.DataFrame(
            ...     {
            ...         "a": ["a", "b", "a", "b", "c"],
            ...         "b": [1, 2, 1, 3, 3],
            ...         "c": [5, 4, 3, 2, 1],
            ...     }
            ... )
            >>> df = nw.from_native(df_native)
            >>>
            >>> df.group_by("a").agg(nw.col("b").sum()).sort("a")
            ┌──────────────────┐
            |Narwhals DataFrame|
            |------------------|
            |        a  b      |
            |     0  a  2      |
            |     1  b  5      |
            |     2  c  3      |
            └──────────────────┘
            >>>
            >>> df.group_by("a", "b").agg(nw.col("c").sum()).sort("a", "b").to_native()
               a  b  c
            0  a  1  8
            1  b  2  4
            2  b  3  2
            3  c  3  1
        """
        flat_aggs = tuple(flatten(aggs))
        # Every expression (positional and keyword) must reduce to a single
        # value per group; reject the whole call otherwise.
        if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
            msg = (
                "Found expression which does not aggregate.\n\n"
                "All expressions passed to GroupBy.agg must aggregate.\n"
                "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
                "but `df.group_by('a').agg(nw.col('b'))` is not."
            )
            raise InvalidOperationError(msg)
        plx = self._df.__narwhals_namespace__()
        # Positional aggregations first, then keyword ones renamed to their key.
        compliant_aggs = (
            *(x._to_compliant_expr(plx) for x in flat_aggs),
            *(
                value.alias(key)._to_compliant_expr(plx)
                for key, value in named_aggs.items()
            ),
        )
        return self._df._with_compliant(self._grouped.agg(*compliant_aggs))

    def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
        """Yield ``(key, sub_frame)`` pairs, one per group.

        Each key is passed through ``tupleify`` (presumably so single-key
        groupings still yield tuples; confirm), and each compliant sub-frame
        is re-wrapped as a narwhals frame via ``_with_compliant``.
        """
        # Iterate the compliant group-by via the normal iterator protocol
        # rather than calling its `__iter__` dunder directly.
        for key, sub_frame in self._grouped:
            yield tupleify(key), self._df._with_compliant(sub_frame)
class LazyGroupBy(Generic[LazyFrameT]):
    """Grouped view over a narwhals LazyFrame.

    The compliant frame behind ``df`` is grouped at construction time, and
    :meth:`agg` turns narwhals aggregation expressions into a new LazyFrame.
    """

    def __init__(
        self,
        df: LazyFrameT,
        keys: Sequence[str] | Sequence[CompliantExprAny],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Store the frame and keys and build the compliant grouping.

        Arguments:
            df: The narwhals LazyFrame to group.
            keys: Column names, or compliant expressions, to group by.
            drop_null_keys: Passed straight through to the compliant
                ``group_by``; presumably decides whether null-key groups are
                kept (confirm against the backend implementation).
        """
        self._df: LazyFrameT = df
        self._keys = keys
        compliant_frame = df._compliant_frame
        self._grouped = compliant_frame.group_by(keys, drop_null_keys=drop_null_keys)

    def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT:
        """Compute aggregations for each group of a group by operation.

        Arguments:
            aggs: Aggregations to compute for each group of the group by operation,
                specified as positional arguments (nested iterables are flattened).
            named_aggs: Additional aggregations, specified as keyword arguments;
                each result column is named after its keyword.

        Returns:
            A new LazyFrame.

        Raises:
            InvalidOperationError: If any expression does not aggregate.

        Examples:
            Group by one column or by multiple columns and call `agg` to compute
            the grouped sum of another column.

            >>> import polars as pl
            >>> import narwhals as nw
            >>> lf_native = pl.LazyFrame(
            ...     {
            ...         "a": ["a", "b", "a", "b", "c"],
            ...         "b": [1, 2, 1, 3, 3],
            ...         "c": [5, 4, 3, 2, 1],
            ...     }
            ... )
            >>> lf = nw.from_native(lf_native)
            >>>
            >>> nw.to_native(lf.group_by("a").agg(nw.col("b").sum()).sort("a")).collect()
            shape: (3, 2)
            ┌─────┬─────┐
            │ a   ┆ b   │
            │ --- ┆ --- │
            │ str ┆ i64 │
            ╞═════╪═════╡
            │ a   ┆ 2   │
            │ b   ┆ 5   │
            │ c   ┆ 3   │
            └─────┴─────┘
            >>>
            >>> lf.group_by("a", "b").agg(nw.sum("c")).sort("a", "b").collect()
            ┌───────────────────┐
            |Narwhals DataFrame |
            |-------------------|
            |shape: (4, 3)      |
            |┌─────┬─────┬─────┐|
            |│ a   ┆ b   ┆ c   │|
            |│ --- ┆ --- ┆ --- │|
            |│ str ┆ i64 ┆ i64 │|
            |╞═════╪═════╪═════╡|
            |│ a   ┆ 1   ┆ 8   │|
            |│ b   ┆ 2   ┆ 4   │|
            |│ b   ┆ 3   ┆ 2   │|
            |│ c   ┆ 3   ┆ 1   │|
            |└─────┴─────┴─────┘|
            └───────────────────┘
        """
        positional = tuple(flatten(aggs))
        if not all_exprs_are_scalar_like(*positional, **named_aggs):
            message = (
                "Found expression which does not aggregate.\n\n"
                "All expressions passed to GroupBy.agg must aggregate.\n"
                "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
                "but `df.group_by('a').agg(nw.col('b'))` is not."
            )
            raise InvalidOperationError(message)

        namespace = self._df.__narwhals_namespace__()
        # Positional aggregations go first, in call order; keyword ones follow,
        # each renamed to its keyword via `alias`.
        compliant = [expr._to_compliant_expr(namespace) for expr in positional]
        compliant.extend(
            expr.alias(name)._to_compliant_expr(namespace)
            for name, expr in named_aggs.items()
        )
        aggregated = self._grouped.agg(*compliant)
        return self._df._with_compliant(aggregated)
|