1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
|
# Utilities for expression parsing.
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
from __future__ import annotations
from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, cast
from narwhals._utils import is_compliant_expr
from narwhals.dependencies import is_narwhals_series, is_numpy_array
from narwhals.exceptions import (
InvalidOperationError,
LengthChangingExprError,
MultiOutputExpressionError,
ShapeError,
)
if TYPE_CHECKING:
from typing_extensions import Never, TypeIs
from narwhals._compliant import CompliantExpr, CompliantFrameT
from narwhals._compliant.typing import (
AliasNames,
CompliantExprAny,
CompliantFrameAny,
CompliantNamespaceAny,
EagerNamespaceAny,
EvalNames,
)
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray
# Module-level generic type variable for use in annotations.
T = TypeVar("T")
def is_expr(obj: Any) -> TypeIs[Expr]:
    """Return True when `obj` is an instance of the Narwhals `Expr` class.

    The import happens inside the function, presumably to avoid an import
    cycle at module load time.
    """
    from narwhals.expr import Expr as NarwhalsExpr

    return isinstance(obj, NarwhalsExpr)
def is_series(obj: Any) -> TypeIs[Series[Any]]:
    """Check whether `obj` is a Narwhals Series.

    The import happens inside the function, presumably to avoid an import
    cycle at module load time.
    """
    from narwhals.series import Series

    return isinstance(obj, Series)
def combine_evaluate_output_names(
    *exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
    """Build the name-evaluation callback for an expression combining `exprs`.

    Naming follows the left-hand rule: the combined expression takes only the
    first output name of the first expression, e.g.
    `nw.sum_horizontal(expr1, expr2)` is named after `expr1`.

    Raises:
        AssertionError: If the first argument is not a compliant expression.
    """
    leftmost = exprs[0]
    if not is_compliant_expr(leftmost):  # pragma: no cover
        msg = f"Safety assertion failed, expected expression, got: {type(leftmost)}. Please report a bug."
        raise AssertionError(msg)

    def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
        names_of_leftmost = leftmost._evaluate_output_names(df)
        return names_of_leftmost[:1]

    return evaluate_output_names
def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
    """Build the alias callback for an expression combining `exprs`, if any.

    Left-hand rule: for e.g. `nw.sum_horizontal(expr1.alias(alias), expr2)`,
    the aliasing function of `expr1` is applied and only its first output
    name is kept. Returns None when the first expression has no aliasing
    function.
    """
    leftmost = exprs[0]
    if leftmost._alias_output_names is not None:

        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            return leftmost._alias_output_names(names)[:1]  # type: ignore[misc]

        return alias_output_names
    return None
def extract_compliant(
    plx: CompliantNamespaceAny,
    other: IntoExpr | NonNestedLiteral | _1DArray,
    *,
    str_as_lit: bool,
) -> CompliantExprAny | NonNestedLiteral:
    """Convert `other` into something the compliant namespace `plx` can consume.

    - Narwhals expressions are lowered with `_to_compliant_expr(plx)`.
    - Strings become `plx.col(...)` selections unless `str_as_lit` is set.
    - Narwhals Series and NumPy arrays are converted to compliant expressions.
    - Anything else is returned unchanged.
    """
    if is_expr(other):
        result = other._to_compliant_expr(plx)
    elif not str_as_lit and isinstance(other, str):
        result = plx.col(other)
    elif is_narwhals_series(other):
        result = other._compliant_series._to_expr()
    elif is_numpy_array(other):
        # Building a series from NumPy needs the eager namespace API.
        eager_ns = cast("EagerNamespaceAny", plx)
        result = eager_ns._series.from_numpy(other, context=eager_ns)._to_expr()
    else:
        result = other
    return result
def evaluate_output_names_and_aliases(
    expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
) -> tuple[Sequence[str], Sequence[str]]:
    """Evaluate the output names of `expr` on `df`, together with their aliases.

    Arguments:
        expr: Compliant expression whose output names should be resolved.
        df: Compliant frame the names are evaluated against.
        exclude: Output names to drop from the result. Only applied when the
            expression is multi-output without explicitly named columns
            (e.g. `nw.all()`); explicitly named outputs are always kept.

    Returns:
        An `(output_names, aliases)` pair of equal-length sequences. When every
        output name is excluded, both sequences are empty.

    Raises:
        AssertionError: If `exclude` is non-empty and `expr` has no metadata.
    """
    output_names = expr._evaluate_output_names(df)
    aliases = (
        output_names
        if expr._alias_output_names is None
        else expr._alias_output_names(output_names)
    )
    if exclude:
        assert expr._metadata is not None  # noqa: S101
        if expr._metadata.expansion_kind.is_multi_unnamed():
            excluded = set(exclude)
            kept = [
                (name, alias)
                for name, alias in zip(output_names, aliases)
                if name not in excluded
            ]
            if not kept:
                # `zip(*[])` yields nothing, so unpacking it into two names
                # would raise `ValueError`; return two empty sequences instead.
                return (), ()
            output_names, aliases = zip(*kept)
    return output_names, aliases
class ExprKind(Enum):
    """Classify an expression node by how it relates to rows and row order."""

    LITERAL = auto()
    """A literal value, e.g. `nw.lit(1)`"""

    AGGREGATION = auto()
    """Reduces to one value regardless of row order, e.g. `nw.col('a').mean()`"""

    ORDERABLE_AGGREGATION = auto()
    """Reduces to one value and depends on row order, e.g. `nw.col('a').arg_max()`"""

    ELEMENTWISE = auto()
    """Keeps the length and needs no surrounding rows, e.g. `nw.col('a').abs()`."""

    ORDERABLE_WINDOW = auto()
    """Needs the surrounding rows and their order, e.g. `diff`."""

    UNORDERABLE_WINDOW = auto()
    """Needs the surrounding rows but not their order, e.g. `rank`."""

    FILTRATION = auto()
    """Changes the length regardless of row order, e.g. `drop_nulls`."""

    ORDERABLE_FILTRATION = auto()
    """Changes the length and depends on row order, e.g. `tail`."""

    NARY = auto()
    """Combination of several expressions."""

    OVER = auto()
    """Produced by calling `.over` on an expression."""

    UNKNOWN = auto()
    """The available information doesn't determine the kind."""

    @property
    def is_scalar_like(self) -> bool:
        """Whether this kind yields a single value (a literal or an aggregation)."""
        return self is ExprKind.LITERAL or self is ExprKind.AGGREGATION

    @property
    def is_orderable_window(self) -> bool:
        """Whether this kind is an orderable window or an orderable aggregation."""
        return (
            self is ExprKind.ORDERABLE_WINDOW
            or self is ExprKind.ORDERABLE_AGGREGATION
        )

    @classmethod
    def from_expr(cls, obj: Expr) -> ExprKind:
        """Infer the kind of a Narwhals expression from its metadata flags."""
        metadata = obj._metadata
        if metadata.is_literal:
            kind = ExprKind.LITERAL
        elif metadata.is_scalar_like:
            kind = ExprKind.AGGREGATION
        elif metadata.is_elementwise:
            kind = ExprKind.ELEMENTWISE
        else:
            kind = ExprKind.UNKNOWN
        return kind

    @classmethod
    def from_into_expr(
        cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool
    ) -> ExprKind:
        """Infer the kind of anything that can be turned into an expression.

        Expressions defer to `from_expr`. Narwhals Series, NumPy arrays and
        column names (strings, unless `str_as_lit`) are elementwise. Anything
        else is a literal.
        """
        if is_expr(obj):
            return cls.from_expr(obj)
        is_column_name = isinstance(obj, str) and not str_as_lit
        if is_narwhals_series(obj) or is_numpy_array(obj) or is_column_name:
            return ExprKind.ELEMENTWISE
        return ExprKind.LITERAL
def is_scalar_like(
    obj: ExprKind,
) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]:
    """Type-narrowing wrapper around the `ExprKind.is_scalar_like` property."""
    scalar_like: bool = obj.is_scalar_like
    return scalar_like
class ExpansionKind(Enum):
    """Describe whether an expression has one output or expands to several."""

    SINGLE = auto()
    """One output, e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`"""

    MULTI_NAMED = auto()
    """Several outputs selected by explicit name, e.g. `nw.col('a', 'b')`"""

    MULTI_UNNAMED = auto()
    """Several outputs not selected by name, e.g. `nw.all()`, nw.nth(0, 1)"""

    def is_multi_unnamed(self) -> bool:
        """Whether this is the `MULTI_UNNAMED` kind."""
        return self is ExpansionKind.MULTI_UNNAMED

    def is_multi_output(self) -> bool:
        """Whether this kind produces more than one output (any non-`SINGLE` kind)."""
        return self is not ExpansionKind.SINGLE

    def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
        """Combine two expansion kinds; only `MULTI_UNNAMED & MULTI_UNNAMED` is supported.

        Anything else is deliberately rejected with an `AssertionError` rather
        than guessing at an ambiguous combination.
        """
        both_unnamed = (
            self is ExpansionKind.MULTI_UNNAMED and other is ExpansionKind.MULTI_UNNAMED
        )
        if not both_unnamed:
            msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug."  # pragma: no cover
            raise AssertionError(msg)  # pragma: no cover
        # e.g. nw.selectors.all() - nw.selectors.numeric().
        return ExpansionKind.MULTI_UNNAMED
class ExprMetadata:
    """Static facts about an expression, tracked while the expression is built.

    The `with_*` methods never modify `self`. Each one checks that the next
    operation is allowed (raising `InvalidOperationError` otherwise) and returns
    a new instance describing the expression after that operation. The static
    methods build metadata for starting points (selectors, literals,
    aggregations), and the class methods combine metadata via `combine_metadata`.
    """

    __slots__ = (
        "expansion_kind",
        "has_windows",
        "is_elementwise",
        "is_literal",
        "is_scalar_like",
        "last_node",
        "n_orderable_ops",
        "preserves_length",
    )

    def __init__(
        self,
        expansion_kind: ExpansionKind,
        last_node: ExprKind,
        *,
        has_windows: bool = False,
        n_orderable_ops: int = 0,
        preserves_length: bool = True,
        is_elementwise: bool = True,
        is_scalar_like: bool = False,
        is_literal: bool = False,
    ) -> None:
        # Debug-only invariants: a literal is scalar-like, and an elementwise
        # expression preserves length.
        if is_literal:
            assert is_scalar_like  # noqa: S101 # debug assertion
        if is_elementwise:
            assert preserves_length  # noqa: S101 # debug assertion
        self.expansion_kind: ExpansionKind = expansion_kind
        self.last_node: ExprKind = last_node
        self.has_windows: bool = has_windows
        self.n_orderable_ops: int = n_orderable_ops
        self.is_elementwise: bool = is_elementwise
        self.preserves_length: bool = preserves_length
        self.is_scalar_like: bool = is_scalar_like
        self.is_literal: bool = is_literal

    def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never:  # pragma: no cover
        msg = f"Cannot subclass {cls.__name__!r}"
        raise TypeError(msg)

    def __repr__(self) -> str:  # pragma: no cover
        shown = (
            "expansion_kind",
            "last_node",
            "has_windows",
            "n_orderable_ops",
            "is_elementwise",
            "preserves_length",
            "is_scalar_like",
            "is_literal",
        )
        body = "".join(f" {name}: {getattr(self, name)},\n" for name in shown)
        return f"ExprMetadata(\n{body})"

    @property
    def is_filtration(self) -> bool:
        """Whether the expression changes length without reducing to a scalar."""
        return not (self.preserves_length or self.is_scalar_like)

    # ---- private helpers -------------------------------------------------

    def _derive(self, last_node: ExprKind, **overrides: Any) -> ExprMetadata:
        """Copy of `self` with `last_node` replaced and the given flags overridden."""
        flags: dict[str, Any] = {
            "has_windows": self.has_windows,
            "n_orderable_ops": self.n_orderable_ops,
            "preserves_length": self.preserves_length,
            "is_elementwise": self.is_elementwise,
            "is_scalar_like": self.is_scalar_like,
            "is_literal": self.is_literal,
        }
        flags.update(overrides)
        return ExprMetadata(self.expansion_kind, last_node, **flags)

    def _reject_scalar_like(self, msg: str) -> None:
        """Raise `InvalidOperationError(msg)` if the expression is scalar-like."""
        if self.is_scalar_like:
            raise InvalidOperationError(msg)

    def _check_over_allowed(self) -> None:
        """Validation shared by both flavours of `over`."""
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)

    # ---- transitions -----------------------------------------------------

    def with_aggregation(self) -> ExprMetadata:
        self._reject_scalar_like("Can't apply aggregations to scalar-like expressions.")
        return self._derive(
            ExprKind.AGGREGATION,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
        )

    def with_orderable_aggregation(self) -> ExprMetadata:
        self._reject_scalar_like("Can't apply aggregations to scalar-like expressions.")
        return self._derive(
            ExprKind.ORDERABLE_AGGREGATION,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
        )

    def with_elementwise_op(self) -> ExprMetadata:
        # Every flag carries over unchanged; only the last node differs.
        return self._derive(ExprKind.ELEMENTWISE)

    def with_unorderable_window(self) -> ExprMetadata:
        self._reject_scalar_like(
            "Can't apply unorderable window (`rank`, `is_unique`) to scalar-like expression."
        )
        return self._derive(
            ExprKind.UNORDERABLE_WINDOW,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_orderable_window(self) -> ExprMetadata:
        self._reject_scalar_like(
            "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
        )
        return self._derive(
            ExprKind.ORDERABLE_WINDOW,
            n_orderable_ops=self.n_orderable_ops + 1,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_ordered_over(self) -> ExprMetadata:
        self._check_over_allowed()
        if not self.n_orderable_ops:
            msg = "Cannot use `order_by` in `over` on expression which isn't orderable."
            raise InvalidOperationError(msg)
        # When the wrapped node is itself order-dependent, one orderable op is
        # dropped (presumably because `order_by` supplies the ordering it needs).
        remaining = self.n_orderable_ops
        if self.last_node.is_orderable_window:
            remaining = remaining - 1
        return self._derive(
            ExprKind.OVER,
            has_windows=True,
            n_orderable_ops=remaining,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_partitioned_over(self) -> ExprMetadata:
        self._check_over_allowed()
        return self._derive(
            ExprKind.OVER,
            has_windows=True,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_filtration(self) -> ExprMetadata:
        self._reject_scalar_like(
            "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
        )
        return self._derive(
            ExprKind.FILTRATION,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_orderable_filtration(self) -> ExprMetadata:
        self._reject_scalar_like(
            "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
        )
        return self._derive(
            ExprKind.ORDERABLE_FILTRATION,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    # ---- starting points -------------------------------------------------

    @staticmethod
    def aggregation() -> ExprMetadata:
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.AGGREGATION,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
        )

    @staticmethod
    def literal() -> ExprMetadata:
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.LITERAL,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=True,
        )

    @staticmethod
    def selector_single() -> ExprMetadata:
        # e.g. `nw.col('a')`, `nw.nth(0)`
        return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_named() -> ExprMetadata:
        # e.g. `nw.col('a', 'b')`
        return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_unnamed() -> ExprMetadata:
        # e.g. `nw.all()`
        return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE)

    # ---- combinations ----------------------------------------------------

    @classmethod
    def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata:
        # Multi-output rhs may be allowed in the future:
        # https://github.com/narwhals-dev/narwhals/issues/2244.
        return combine_metadata(
            lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False
        )

    @classmethod
    def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
        return combine_metadata(
            *exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True
        )
def combine_metadata(  # noqa: C901, PLR0912
    *args: IntoExpr | object | None,
    str_as_lit: bool,
    allow_multi_output: bool,
    to_single_output: bool,
) -> ExprMetadata:
    """Combine metadata from `args`.

    Arguments:
        args: Arguments, maybe expressions, literals, or Series.
        str_as_lit: Whether to interpret strings as literals or as column names.
        allow_multi_output: Whether to allow multi-output inputs after the
            left-most argument (the left-most one may always be multi-output).
        to_single_output: Whether the result is always single-output, regardless
            of the inputs (e.g. `nw.sum_horizontal`).

    Returns:
        Metadata of an n-ary node (`ExprKind.NARY`) combining all the inputs.

    Raises:
        MultiOutputExpressionError: If a non-left-most argument is multi-output
            and `allow_multi_output` is False.
        LengthChangingExprError: If more than one argument changes length.
        ShapeError: If a length-changing argument is combined with a
            length-preserving one.
    """
    n_filtrations = 0
    result_expansion_kind = ExpansionKind.SINGLE
    result_has_windows = False
    result_n_orderable_ops = 0
    # result preserves length if at least one input does
    result_preserves_length = False
    # result is elementwise if all inputs are elementwise
    result_is_not_elementwise = False
    # result is scalar-like if all inputs are scalar-like
    result_is_not_scalar_like = False
    # result is literal if all inputs are literal
    result_is_not_literal = False

    for i, arg in enumerate(args):  # noqa: PLR1702
        if (isinstance(arg, str) and not str_as_lit) or is_series(arg):
            # Column names and Series are length-preserving, non-scalar inputs.
            result_preserves_length = True
            result_is_not_scalar_like = True
            result_is_not_literal = True
        elif is_expr(arg):
            metadata = arg._metadata
            if metadata.expansion_kind.is_multi_output():
                expansion_kind = metadata.expansion_kind
                if i > 0 and not allow_multi_output:
                    # Left-most argument is always allowed to be multi-output.
                    msg = (
                        "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) "
                        "are not supported in this context."
                    )
                    raise MultiOutputExpressionError(msg)
                if not to_single_output:
                    if i == 0:
                        result_expansion_kind = expansion_kind
                    else:
                        result_expansion_kind = result_expansion_kind & expansion_kind

            if metadata.has_windows:
                result_has_windows = True
            result_n_orderable_ops += metadata.n_orderable_ops
            if metadata.preserves_length:
                result_preserves_length = True
            if not metadata.is_elementwise:
                result_is_not_elementwise = True
            if not metadata.is_scalar_like:
                result_is_not_scalar_like = True
            if not metadata.is_literal:
                result_is_not_literal = True
            if metadata.is_filtration:
                n_filtrations += 1
        # Any other argument is a plain literal and leaves every flag untouched.

    if n_filtrations > 1:
        msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
        raise LengthChangingExprError(msg)
    if result_preserves_length and n_filtrations:
        msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
        raise ShapeError(msg)

    # `ExprMetadata` requires elementwise expressions to preserve length. The
    # extra `result_preserves_length` condition only matters when no argument
    # was an expression, column name or Series (e.g. only plain literals):
    # the result is then literal-like and, like `ExprMetadata.literal()`, not
    # elementwise. Previously this case tripped the invariant assertion.
    result_is_elementwise = result_preserves_length and not result_is_not_elementwise

    return ExprMetadata(
        result_expansion_kind,
        ExprKind.NARY,
        has_windows=result_has_windows,
        n_orderable_ops=result_n_orderable_ops,
        preserves_length=result_preserves_length,
        is_elementwise=result_is_elementwise,
        is_scalar_like=not result_is_not_scalar_like,
        is_literal=not result_is_not_literal,
    )
def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
    """Raise `ShapeError` if any of `args` aggregates or changes length.

    Strings (column names) and Series are accepted without checking. This
    function works lazily and cannot evaluate lengths, so Series lengths are
    left for later checks.
    """
    from narwhals.series import Series

    for candidate in args:
        keeps_length = is_expr(candidate) and candidate._metadata.preserves_length
        if not (keeps_length or isinstance(candidate, (str, Series))):
            msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
            raise ShapeError(msg)
def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
    """Return whether every positional and keyword argument is a scalar-like expression.

    An argument counts only if it is a Narwhals expression whose metadata
    marks it as scalar-like (e.g. an aggregation or a literal expression).
    Any non-expression input (column name, Series, plain Python value) makes
    the result False. With no arguments at all, the result is True. This
    function never raises on its own; callers decide what to do with the result.
    """
    return all(
        is_expr(arg) and arg._metadata.is_scalar_like
        for arg in chain(args, kwargs.values())
    )
def apply_n_ary_operation(
    plx: CompliantNamespaceAny,
    function: Any,
    *comparands: IntoExpr | NonNestedLiteral | _1DArray,
    str_as_lit: bool,
) -> CompliantExprAny:
    """Call `function` on `comparands` after converting them for namespace `plx`.

    If at least one comparand is not scalar-like, every scalar-like comparand
    that converts to a compliant expression is first passed through its
    `broadcast(kind)` method. Plain literals are passed through as-is.
    """
    kinds = [
        ExprKind.from_into_expr(item, str_as_lit=str_as_lit) for item in comparands
    ]
    needs_broadcast = not all(kind.is_scalar_like for kind in kinds)

    def _prepare(item: Any, kind: ExprKind) -> Any:
        # Convert one comparand, broadcasting it if required.
        converted = extract_compliant(plx, item, str_as_lit=str_as_lit)
        if needs_broadcast and is_compliant_expr(converted) and is_scalar_like(kind):
            return converted.broadcast(kind)
        return converted

    return function(*(_prepare(item, kind) for item, kind in zip(comparands, kinds)))
|