diff options
| author | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
|---|---|---|
| committer | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
| commit | 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch) | |
| tree | 8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/_polars | |
| parent | b832d364da8c2efe09e3f75828caf73c50d01ce3 (diff) | |
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_polars')
8 files changed, 2604 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/__init__.py b/venv/lib/python3.8/site-packages/narwhals/_polars/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/__init__.py diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_polars/dataframe.py new file mode 100644 index 0000000..9a270ff --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/dataframe.py @@ -0,0 +1,770 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Iterator, + Literal, + Mapping, + Sequence, + Sized, + cast, + overload, +) + +import polars as pl + +from narwhals._polars.namespace import PolarsNamespace +from narwhals._polars.series import PolarsSeries +from narwhals._polars.utils import ( + catch_polars_exception, + extract_args_kwargs, + native_to_narwhals_dtype, +) +from narwhals._utils import ( + Implementation, + _into_arrow_table, + check_columns_exist, + convert_str_slice_to_int_slice, + is_compliant_series, + is_index_selector, + is_range, + is_sequence_like, + is_slice_index, + is_slice_none, + parse_columns_to_drop, + parse_version, + requires, + validate_backend_version, +) +from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import ColumnNotFoundError + +if TYPE_CHECKING: + from types import ModuleType + from typing import Callable, TypeVar + + import pandas as pd + import pyarrow as pa + from typing_extensions import Self, TypeAlias, TypeIs + + from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny + from narwhals._polars.expr import PolarsExpr + from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy + from narwhals._translate import IntoArrowTable + from narwhals._utils import Version, _FullContext + from narwhals.dataframe import DataFrame, LazyFrame + from narwhals.dtypes import DType + from narwhals.schema import Schema + from narwhals.typing import ( + JoinStrategy, + MultiColSelector, + MultiIndexSelector, + PivotAgg, + SingleIndexSelector, + _2DArray, + ) + + T = TypeVar("T") + R = TypeVar("R") + +Method: TypeAlias = "Callable[..., R]" +"""Generic alias representing all methods implemented via `__getattr__`. + +Where `R` is the return type. +""" + +# DataFrame methods where PolarsDataFrame just defers to Polars.DataFrame directly. +INHERITED_METHODS = frozenset( + [ + "clone", + "drop_nulls", + "estimated_size", + "explode", + "filter", + "gather_every", + "head", + "is_unique", + "item", + "iter_rows", + "join_asof", + "rename", + "row", + "rows", + "sample", + "select", + "sort", + "tail", + "to_arrow", + "to_pandas", + "unique", + "with_columns", + "write_csv", + "write_parquet", + ] +) + + +class PolarsDataFrame: + clone: Method[Self] + collect: Method[CompliantDataFrameAny] + drop_nulls: Method[Self] + estimated_size: Method[int | float] + explode: Method[Self] + filter: Method[Self] + gather_every: Method[Self] + item: Method[Any] + iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]] + is_unique: Method[PolarsSeries] + join_asof: Method[Self] + rename: Method[Self] + row: Method[tuple[Any, ...]] + rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]] + sample: Method[Self] + select: Method[Self] + sort: Method[Self] + to_arrow: Method[pa.Table] + to_pandas: Method[pd.DataFrame] + unique: Method[Self] + with_columns: Method[Self] + # NOTE: `write_csv` requires an `@overload` for `str | None` + # Can't do that here 😟 + write_csv: Method[Any] + write_parquet: Method[None] + + # CompliantDataFrame + _evaluate_aliases: Any + + def __init__( + self, df: pl.DataFrame, *, backend_version: tuple[int, ...], version: Version + ) -> None: + self._native_frame = df + self._backend_version = backend_version + self._implementation = Implementation.POLARS + self._version = version + validate_backend_version(self._implementation, self._backend_version) + + @classmethod + def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: + if context._backend_version >= (1, 3): + native = pl.DataFrame(data) + else: + native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context))) + return cls.from_native(native, context=context) + + @classmethod + def from_dict( + cls, + data: Mapping[str, Any], + /, + *, + context: _FullContext, + schema: Mapping[str, DType] | Schema | None, + ) -> Self: + from narwhals.schema import Schema + + pl_schema = Schema(schema).to_polars() if schema is not None else schema + return cls.from_native(pl.from_dict(data, pl_schema), context=context) + + @staticmethod + def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]: + return isinstance(obj, pl.DataFrame) + + @classmethod + def from_native(cls, data: pl.DataFrame, /, *, context: _FullContext) -> Self: + return cls( + data, backend_version=context._backend_version, version=context._version + ) + + @classmethod + def from_numpy( + cls, + data: _2DArray, + /, + *, + context: _FullContext, # NOTE: Maybe only `Implementation`? + schema: Mapping[str, DType] | Schema | Sequence[str] | None, + ) -> Self: + from narwhals.schema import Schema + + pl_schema = ( + Schema(schema).to_polars() + if isinstance(schema, (Mapping, Schema)) + else schema + ) + return cls.from_native(pl.from_numpy(data, pl_schema), context=context) + + def to_narwhals(self) -> DataFrame[pl.DataFrame]: + return self._version.dataframe(self, level="full") + + @property + def native(self) -> pl.DataFrame: + return self._native_frame + + def __repr__(self) -> str: # pragma: no cover + return "PolarsDataFrame" + + def __narwhals_dataframe__(self) -> Self: + return self + + def __narwhals_namespace__(self) -> PolarsNamespace: + return PolarsNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __native_namespace__(self) -> ModuleType: + if self._implementation is Implementation.POLARS: + return self._implementation.to_native_namespace() + + msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover + raise AssertionError(msg) + + def _with_version(self, version: Version) -> Self: + return self.__class__( + self.native, backend_version=self._backend_version, version=version + ) + + def _with_native(self, df: pl.DataFrame) -> Self: + return self.__class__( + df, backend_version=self._backend_version, version=self._version + ) + + @overload + def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ... + + @overload + def _from_native_object(self, obj: pl.DataFrame) -> Self: ... + + @overload + def _from_native_object(self, obj: T) -> T: ... + + def _from_native_object( + self, obj: pl.Series | pl.DataFrame | T + ) -> Self | PolarsSeries | T: + if isinstance(obj, pl.Series): + return PolarsSeries.from_native(obj, context=self) + if self._is_native(obj): + return self._with_native(obj) + # scalar + return obj + + def __len__(self) -> int: + return len(self.native) + + def head(self, n: int) -> Self: + return self._with_native(self.native.head(n)) + + def tail(self, n: int) -> Self: + return self._with_native(self.native.tail(n)) + + def __getattr__(self, attr: str) -> Any: + if attr not in INHERITED_METHODS: # pragma: no cover + msg = f"{self.__class__.__name__} has not attribute '{attr}'." + raise AttributeError(msg) + + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + try: + return self._from_native_object(getattr(self.native, attr)(*pos, **kwds)) + except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover + msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?" + raise ColumnNotFoundError(msg) from e + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + + return func + + def __array__( + self, dtype: Any | None = None, *, copy: bool | None = None + ) -> _2DArray: + if self._backend_version < (0, 20, 28) and copy is not None: + msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'" + raise NotImplementedError(msg) + if self._backend_version < (0, 20, 28): + return self.native.__array__(dtype) + return self.native.__array__(dtype) + + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: + return self.native.to_numpy() + + def collect_schema(self) -> dict[str, DType]: + if self._backend_version < (1,): + return { + name: native_to_narwhals_dtype( + dtype, self._version, self._backend_version + ) + for name, dtype in self.native.schema.items() + } + else: + collected_schema = self.native.collect_schema() + return { + name: native_to_narwhals_dtype( + dtype, self._version, self._backend_version + ) + for name, dtype in collected_schema.items() + } + + @property + def shape(self) -> tuple[int, int]: + return self.native.shape + + def __getitem__( # noqa: C901, PLR0912 + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector[PolarsSeries], + MultiColSelector[PolarsSeries], + ], + ) -> Any: + rows, columns = item + if self._backend_version > (0, 20, 30): + rows_native = rows.native if is_compliant_series(rows) else rows + columns_native = columns.native if is_compliant_series(columns) else columns + selector = rows_native, columns_native + selected = self.native.__getitem__(selector) # type: ignore[index] + return self._from_native_object(selected) + else: # pragma: no cover + # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum + # Polars version we support + # This mostly mirrors the logic in `EagerDataFrame.__getitem__`. + rows = list(rows) if isinstance(rows, tuple) else rows + columns = list(columns) if isinstance(columns, tuple) else columns + if is_numpy_array_1d(columns): + columns = columns.tolist() + + native = self.native + if not is_slice_none(columns): + if isinstance(columns, Sized) and len(columns) == 0: + return self.select() + if is_index_selector(columns): + if is_slice_index(columns) or is_range(columns): + native = native.select( + self.columns[slice(columns.start, columns.stop, columns.step)] + ) + elif is_compliant_series(columns): + native = native[:, columns.native.to_list()] + else: + native = native[:, columns] + elif isinstance(columns, slice): + native = native.select( + self.columns[ + slice(*convert_str_slice_to_int_slice(columns, self.columns)) + ] + ) + elif is_compliant_series(columns): + native = native.select(columns.native.to_list()) + elif is_sequence_like(columns): + native = native.select(columns) + else: + msg = f"Unreachable code, got unexpected type: {type(columns)}" + raise AssertionError(msg) + + if not is_slice_none(rows): + if isinstance(rows, int): + native = native[[rows], :] + elif isinstance(rows, (slice, range)): + native = native[rows, :] + elif is_compliant_series(rows): + native = native[rows.native, :] + elif is_sequence_like(rows): + native = native[rows, :] + else: + msg = f"Unreachable code, got unexpected type: {type(rows)}" + raise AssertionError(msg) + + return self._with_native(native) + + def simple_select(self, *column_names: str) -> Self: + return self._with_native(self.native.select(*column_names)) + + def aggregate(self, *exprs: Any) -> Self: + return self.select(*exprs) + + def get_column(self, name: str) -> PolarsSeries: + return PolarsSeries.from_native(self.native.get_column(name), context=self) + + def iter_columns(self) -> Iterator[PolarsSeries]: + for series in self.native.iter_columns(): + yield PolarsSeries.from_native(series, context=self) + + @property + def columns(self) -> list[str]: + return self.native.columns + + @property + def schema(self) -> dict[str, DType]: + return { + name: native_to_narwhals_dtype(dtype, self._version, self._backend_version) + for name, dtype in self.native.schema.items() + } + + def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny: + if backend is None or backend is Implementation.POLARS: + return PolarsLazyFrame.from_native(self.native.lazy(), context=self) + elif backend is Implementation.DUCKDB: + import duckdb # ignore-banned-import + + from narwhals._duckdb.dataframe import DuckDBLazyFrame + + # NOTE: (F841) is a false positive + df = self.native # noqa: F841 + return DuckDBLazyFrame( + duckdb.table("df"), + backend_version=parse_version(duckdb), + version=self._version, + ) + elif backend is Implementation.DASK: + import dask # ignore-banned-import + import dask.dataframe as dd # ignore-banned-import + + from narwhals._dask.dataframe import DaskLazyFrame + + return DaskLazyFrame( + dd.from_pandas(self.native.to_pandas()), + backend_version=parse_version(dask), + version=self._version, + ) + raise AssertionError # pragma: no cover + + @overload + def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ... + + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + + def to_dict( + self, *, as_series: bool + ) -> dict[str, PolarsSeries] | dict[str, list[Any]]: + if as_series: + return { + name: PolarsSeries.from_native(col, context=self) + for name, col in self.native.to_dict().items() + } + else: + return self.native.to_dict(as_series=False) + + def group_by( + self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool + ) -> PolarsGroupBy: + from narwhals._polars.group_by import PolarsGroupBy + + return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys) + + def with_row_index(self, name: str) -> Self: + if self._backend_version < (0, 20, 4): + return self._with_native(self.native.with_row_count(name)) + return self._with_native(self.native.with_row_index(name)) + + def drop(self, columns: Sequence[str], *, strict: bool) -> Self: + to_drop = parse_columns_to_drop(self, columns, strict=strict) + return self._with_native(self.native.drop(to_drop)) + + def unpivot( + self, + on: Sequence[str] | None, + index: Sequence[str] | None, + variable_name: str, + value_name: str, + ) -> Self: + if self._backend_version < (1, 0, 0): + return self._with_native( + self.native.melt( + id_vars=index, + value_vars=on, + variable_name=variable_name, + value_name=value_name, + ) + ) + return self._with_native( + self.native.unpivot( + on=on, index=index, variable_name=variable_name, value_name=value_name + ) + ) + + @requires.backend_version((1,)) + def pivot( + self, + on: Sequence[str], + *, + index: Sequence[str] | None, + values: Sequence[str] | None, + aggregate_function: PivotAgg | None, + sort_columns: bool, + separator: str, + ) -> Self: + try: + result = self.native.pivot( + on, + index=index, + values=values, + aggregate_function=aggregate_function, + sort_columns=sort_columns, + separator=separator, + ) + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + return self._from_native_object(result) + + def to_polars(self) -> pl.DataFrame: + return self.native + + def join( + self, + other: Self, + *, + how: JoinStrategy, + left_on: Sequence[str] | None, + right_on: Sequence[str] | None, + suffix: str, + ) -> Self: + how_native = ( + "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how + ) + try: + return self._with_native( + self.native.join( + other=other.native, + how=how_native, # type: ignore[arg-type] + left_on=left_on, + right_on=right_on, + suffix=suffix, + ) + ) + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + + def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: + return check_columns_exist(subset, available=self.columns) + + +class PolarsLazyFrame: + drop_nulls: Method[Self] + explode: Method[Self] + filter: Method[Self] + gather_every: Method[Self] + head: Method[Self] + join_asof: Method[Self] + rename: Method[Self] + select: Method[Self] + sort: Method[Self] + tail: Method[Self] + unique: Method[Self] + with_columns: Method[Self] + + # CompliantLazyFrame + _evaluate_expr: Any + _evaluate_window_expr: Any + _evaluate_aliases: Any + + def __init__( + self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], version: Version + ) -> None: + self._native_frame = df + self._backend_version = backend_version + self._implementation = Implementation.POLARS + self._version = version + validate_backend_version(self._implementation, self._backend_version) + + @staticmethod + def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]: + return isinstance(obj, pl.LazyFrame) + + @classmethod + def from_native(cls, data: pl.LazyFrame, /, *, context: _FullContext) -> Self: + return cls( + data, backend_version=context._backend_version, version=context._version + ) + + def to_narwhals(self) -> LazyFrame[pl.LazyFrame]: + return self._version.lazyframe(self, level="lazy") + + def __repr__(self) -> str: # pragma: no cover + return "PolarsLazyFrame" + + def __narwhals_lazyframe__(self) -> Self: + return self + + def __narwhals_namespace__(self) -> PolarsNamespace: + return PolarsNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __native_namespace__(self) -> ModuleType: + if self._implementation is Implementation.POLARS: + return self._implementation.to_native_namespace() + + msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover + raise AssertionError(msg) + + def _with_native(self, df: pl.LazyFrame) -> Self: + return self.__class__( + df, backend_version=self._backend_version, version=self._version + ) + + def _with_version(self, version: Version) -> Self: + return self.__class__( + self.native, backend_version=self._backend_version, version=version + ) + + def __getattr__(self, attr: str) -> Any: + if attr not in INHERITED_METHODS: # pragma: no cover + msg = f"{self.__class__.__name__} has not attribute '{attr}'." + raise AttributeError(msg) + + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + try: + return self._with_native(getattr(self.native, attr)(*pos, **kwds)) + except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover + raise ColumnNotFoundError(str(e)) from e + + return func + + def _iter_columns(self) -> Iterator[PolarsSeries]: # pragma: no cover + yield from self.collect(self._implementation).iter_columns() + + @property + def native(self) -> pl.LazyFrame: + return self._native_frame + + @property + def columns(self) -> list[str]: + return self.native.columns + + @property + def schema(self) -> dict[str, DType]: + schema = self.native.schema + return { + name: native_to_narwhals_dtype(dtype, self._version, self._backend_version) + for name, dtype in schema.items() + } + + def collect_schema(self) -> dict[str, DType]: + if self._backend_version < (1,): + return { + name: native_to_narwhals_dtype( + dtype, self._version, self._backend_version + ) + for name, dtype in self.native.schema.items() + } + else: + try: + collected_schema = self.native.collect_schema() + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + return { + name: native_to_narwhals_dtype( + dtype, self._version, self._backend_version + ) + for name, dtype in collected_schema.items() + } + + def collect( + self, backend: Implementation | None, **kwargs: Any + ) -> CompliantDataFrameAny: + try: + result = self.native.collect(**kwargs) + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + + if backend is None or backend is Implementation.POLARS: + return PolarsDataFrame.from_native(result, context=self) + + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + result.to_pandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd), + version=self._version, + validate_column_names=False, + ) + + if backend is Implementation.PYARROW: + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + result.to_arrow(), + backend_version=parse_version(pa), + version=self._version, + validate_column_names=False, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover + + def group_by( + self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool + ) -> PolarsLazyGroupBy: + from narwhals._polars.group_by import PolarsLazyGroupBy + + return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) + + def with_row_index(self, name: str) -> Self: + if self._backend_version < (0, 20, 4): + return self._with_native(self.native.with_row_count(name)) + return self._with_native(self.native.with_row_index(name)) + + def drop(self, columns: Sequence[str], *, strict: bool) -> Self: + if self._backend_version < (1, 0, 0): + return self._with_native(self.native.drop(columns)) + return self._with_native(self.native.drop(columns, strict=strict)) + + def unpivot( + self, + on: Sequence[str] | None, + index: Sequence[str] | None, + variable_name: str, + value_name: str, + ) -> Self: + if self._backend_version < (1, 0, 0): + return self._with_native( + self.native.melt( + id_vars=index, + value_vars=on, + variable_name=variable_name, + value_name=value_name, + ) + ) + return self._with_native( + self.native.unpivot( + on=on, index=index, variable_name=variable_name, value_name=value_name + ) + ) + + def simple_select(self, *column_names: str) -> Self: + return self._with_native(self.native.select(*column_names)) + + def aggregate(self, *exprs: Any) -> Self: + return self.select(*exprs) + + def join( + self, + other: Self, + *, + how: JoinStrategy, + left_on: Sequence[str] | None, + right_on: Sequence[str] | None, + suffix: str, + ) -> Self: + how_native = ( + "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how + ) + return self._with_native( + self.native.join( + other=other.native, + how=how_native, # type: ignore[arg-type] + left_on=left_on, + right_on=right_on, + suffix=suffix, + ) + ) + + def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None: + return check_columns_exist( # pragma: no cover + subset, available=self.columns + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py b/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py new file mode 100644 index 0000000..eb5b5f2 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Sequence + +import polars as pl + +from narwhals._duration import parse_interval_string +from narwhals._polars.utils import ( + extract_args_kwargs, + extract_native, + narwhals_to_native_dtype, +) +from narwhals._utils import Implementation, requires + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._expression_parsing import ExprKind, ExprMetadata + from narwhals._polars.dataframe import Method + from narwhals._polars.namespace import PolarsNamespace + from narwhals._utils import Version + from narwhals.typing import IntoDType + + +class PolarsExpr: + def __init__( + self, expr: pl.Expr, version: Version, backend_version: tuple[int, ...] + ) -> None: + self._native_expr = expr + self._implementation = Implementation.POLARS + self._version = version + self._backend_version = backend_version + self._metadata: ExprMetadata | None = None + + @property + def native(self) -> pl.Expr: + return self._native_expr + + def __repr__(self) -> str: # pragma: no cover + return "PolarsExpr" + + def _with_native(self, expr: pl.Expr) -> Self: + return self.__class__(expr, self._version, self._backend_version) + + @classmethod + def _from_series(cls, series: Any) -> Self: + return cls(series.native, series._version, series._backend_version) + + def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + # Let Polars do its thing. + return self + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._with_native(getattr(self.native, attr)(*pos, **kwds)) + + return func + + def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]: + name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples" + return {name: min_samples} + + def cast(self, dtype: IntoDType) -> Self: + dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) + return self._with_native(self.native.cast(dtype_pl)) + + def ewm_mean( + self, + *, + com: float | None, + span: float | None, + half_life: float | None, + alpha: float | None, + adjust: bool, + min_samples: int, + ignore_nulls: bool, + ) -> Self: + native = self.native.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + ignore_nulls=ignore_nulls, + **self._renamed_min_periods(min_samples), + ) + if self._backend_version < (1,): # pragma: no cover + native = pl.when(~self.native.is_null()).then(native).otherwise(None) + return self._with_native(native) + + def is_nan(self) -> Self: + if self._backend_version >= (1, 18): + native = self.native.is_nan() + else: # pragma: no cover + native = pl.when(self.native.is_not_null()).then(self.native.is_nan()) + return self._with_native(native) + + def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: + if self._backend_version < (1, 9): + if order_by: + msg = "`order_by` in Polars requires version 1.10 or greater" + raise NotImplementedError(msg) + native = self.native.over(partition_by or pl.lit(1)) + else: + native = self.native.over( + partition_by or pl.lit(1), order_by=order_by or None + ) + return self._with_native(native) + + @requires.backend_version((1,)) + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool, ddof: int + ) -> Self: + kwds = self._renamed_min_periods(min_samples) + native = self.native.rolling_var( + window_size=window_size, center=center, ddof=ddof, **kwds + ) + return self._with_native(native) + + @requires.backend_version((1,)) + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool, ddof: int + ) -> Self: + kwds = self._renamed_min_periods(min_samples) + native = self.native.rolling_std( + window_size=window_size, center=center, ddof=ddof, **kwds + ) + return self._with_native(native) + + def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: + kwds = self._renamed_min_periods(min_samples) + native = self.native.rolling_sum(window_size=window_size, center=center, **kwds) + return self._with_native(native) + + def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: + kwds = self._renamed_min_periods(min_samples) + native = self.native.rolling_mean(window_size=window_size, center=center, **kwds) + return self._with_native(native) + + def map_batches( + self, function: Callable[[Any], Any], return_dtype: IntoDType | None + ) -> Self: + return_dtype_pl = ( + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) + if return_dtype + else None + ) + native = self.native.map_batches(function, return_dtype_pl) + return self._with_native(native) + + @requires.backend_version((1,)) + def replace_strict( + self, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any], + *, + return_dtype: IntoDType | None, + ) -> Self: + return_dtype_pl = ( + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) + if return_dtype + else None + ) + native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl) + return self._with_native(native) + + def __eq__(self, other: object) -> Self: # type: ignore[override] + return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator] + + def __ne__(self, other: object) -> Self: # type: ignore[override] + return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator] + + def __ge__(self, other: Any) -> Self: + return self._with_native(self.native.__ge__(extract_native(other))) + + def __gt__(self, other: Any) -> Self: + return self._with_native(self.native.__gt__(extract_native(other))) + + def __le__(self, other: Any) -> Self: + return self._with_native(self.native.__le__(extract_native(other))) + + def __lt__(self, other: Any) -> Self: + return self._with_native(self.native.__lt__(extract_native(other))) + + def __and__(self, other: PolarsExpr | bool | Any) -> Self: + return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator] + + def __or__(self, other: PolarsExpr | bool | Any) -> Self: + return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator] + + def __add__(self, other: Any) -> Self: + return self._with_native(self.native.__add__(extract_native(other))) + + def __sub__(self, other: Any) -> Self: + return self._with_native(self.native.__sub__(extract_native(other))) + + def __mul__(self, other: Any) -> Self: + return self._with_native(self.native.__mul__(extract_native(other))) + + def __pow__(self, other: Any) -> Self: + return self._with_native(self.native.__pow__(extract_native(other))) + + def __truediv__(self, other: Any) -> Self: + return self._with_native(self.native.__truediv__(extract_native(other))) + + def __floordiv__(self, other: Any) -> Self: + return self._with_native(self.native.__floordiv__(extract_native(other))) + + def __mod__(self, other: Any) -> Self: + return self._with_native(self.native.__mod__(extract_native(other))) + + def __invert__(self) -> Self: + return self._with_native(self.native.__invert__()) + + def cum_count(self, *, reverse: bool) -> Self: + if self._backend_version < (0, 20, 4): + result = (~self.native.is_null()).cum_sum(reverse=reverse) + else: + result = self.native.cum_count(reverse=reverse) + return self._with_native(result) + + def __narwhals_expr__(self) -> None: ... + def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover + from narwhals._polars.namespace import PolarsNamespace + + return PolarsNamespace( + backend_version=self._backend_version, version=self._version + ) + + @property + def dt(self) -> PolarsExprDateTimeNamespace: + return PolarsExprDateTimeNamespace(self) + + @property + def str(self) -> PolarsExprStringNamespace: + return PolarsExprStringNamespace(self) + + @property + def cat(self) -> PolarsExprCatNamespace: + return PolarsExprCatNamespace(self) + + @property + def name(self) -> PolarsExprNameNamespace: + return PolarsExprNameNamespace(self) + + @property + def list(self) -> PolarsExprListNamespace: + return PolarsExprListNamespace(self) + + @property + def struct(self) -> PolarsExprStructNamespace: + return PolarsExprStructNamespace(self) + + # CompliantExpr + _alias_output_names: Any + _evaluate_aliases: Any + _evaluate_output_names: Any + _is_multi_output_unnamed: Any + __call__: Any + from_column_names: Any + from_column_indices: Any + _eval_names_indices: Any + + # Polars + abs: Method[Self] + all: Method[Self] + any: Method[Self] + alias: Method[Self] + arg_max: Method[Self] + arg_min: Method[Self] + arg_true: Method[Self] + clip: Method[Self] + count: Method[Self] + cum_max: Method[Self] + cum_min: Method[Self] + cum_prod: Method[Self] + cum_sum: Method[Self] + diff: Method[Self] + drop_nulls: Method[Self] + exp: Method[Self] + fill_null: Method[Self] + gather_every: Method[Self] + head: Method[Self] + is_finite: Method[Self] + is_first_distinct: Method[Self] + is_in: Method[Self] + is_last_distinct: Method[Self] + is_null: Method[Self] + is_unique: Method[Self] + len: Method[Self] + log: Method[Self] + max: Method[Self] + mean: Method[Self] + median: Method[Self] + min: Method[Self] + mode: Method[Self] + n_unique: Method[Self] + null_count: Method[Self] + quantile: Method[Self] + rank: Method[Self] + round: Method[Self] + sample: Method[Self] + shift: Method[Self] + skew: Method[Self] + std: Method[Self] + sum: Method[Self] + sort: Method[Self] + tail: Method[Self] + unique: Method[Self] + var: Method[Self] + + +class PolarsExprDateTimeNamespace: + def __init__(self, expr: PolarsExpr) -> None: + self._compliant_expr = expr + + def truncate(self, every: str) -> PolarsExpr: + parse_interval_string(every) # Ensure consistent error message is raised. + return self._compliant_expr._with_native( + self._compliant_expr.native.dt.truncate(every) + ) + + def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def func(*args: Any, **kwargs: Any) -> PolarsExpr: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_expr._with_native( + getattr(self._compliant_expr.native.dt, attr)(*pos, **kwds) + ) + + return func + + +class PolarsExprStringNamespace: + def __init__(self, expr: PolarsExpr) -> None: + self._compliant_expr = expr + + def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def func(*args: Any, **kwargs: Any) -> PolarsExpr: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_expr._with_native( + getattr(self._compliant_expr.native.str, attr)(*pos, **kwds) + ) + + return func + + +class PolarsExprCatNamespace: + def __init__(self, expr: PolarsExpr) -> None: + self._compliant_expr = expr + + def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def func(*args: Any, **kwargs: Any) -> PolarsExpr: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_expr._with_native( + getattr(self._compliant_expr.native.cat, attr)(*pos, **kwds) + ) + + return func + + +class PolarsExprNameNamespace: + def __init__(self, expr: PolarsExpr) -> None: + self._compliant_expr = expr + + def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def func(*args: Any, **kwargs: Any) -> PolarsExpr: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_expr._with_native( + getattr(self._compliant_expr.native.name, attr)(*pos, **kwds) + ) + + return func + + +class PolarsExprListNamespace: + def __init__(self, expr: PolarsExpr) -> None: + self._expr = expr + + def len(self) -> PolarsExpr: + native_expr = self._expr._native_expr + native_result = native_expr.list.len() + + if self._expr._backend_version < (1, 16): # pragma: no cover + native_result = ( + pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32()) + ) + elif self._expr._backend_version < (1, 17): # pragma: no cover + native_result = native_result.cast(pl.UInt32()) + + return self._expr._with_native(native_result) + + # TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added + def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover + def func(*args: Any, **kwargs: Any) -> PolarsExpr: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._expr._with_native( + getattr(self._expr.native.list, attr)(*pos, **kwds) + ) + + return func + + +class PolarsExprStructNamespace: + def __init__(self, expr: PolarsExpr) -> None: + self._expr = expr + + def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover + def func(*args: Any, **kwargs: Any) -> PolarsExpr: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._expr._with_native( + getattr(self._expr.native.struct, attr)(*pos, **kwds) + ) + + return func diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/group_by.py b/venv/lib/python3.8/site-packages/narwhals/_polars/group_by.py new file mode 100644 index 0000000..e29c3e2 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/group_by.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterator, Sequence, cast + +from narwhals._utils import is_sequence_of + +if TYPE_CHECKING: + from polars.dataframe.group_by import GroupBy as NativeGroupBy + from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy + + from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame + from narwhals._polars.expr import PolarsExpr + + +class PolarsGroupBy: + _compliant_frame: PolarsDataFrame + _grouped: NativeGroupBy + _drop_null_keys: bool + _output_names: Sequence[str] + + @property + def compliant(self) -> PolarsDataFrame: + return self._compliant_frame + + def __init__( + self, + df: PolarsDataFrame, + keys: Sequence[PolarsExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, + ) -> None: + self._keys = list(keys) + self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df + self._grouped = ( + self.compliant.native.group_by(keys) + if is_sequence_of(keys, str) + else self.compliant.native.group_by(arg.native for arg in keys) + ) + + def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame: + agg_result = self._grouped.agg(arg.native for arg in aggs) + return self.compliant._with_native(agg_result) + + def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: + for key, df in self._grouped: + yield tuple(cast("str", key)), self.compliant._with_native(df) + + +class PolarsLazyGroupBy: + _compliant_frame: PolarsLazyFrame + _grouped: NativeLazyGroupBy + _drop_null_keys: bool + _output_names: Sequence[str] + + @property + def compliant(self) -> PolarsLazyFrame: + return self._compliant_frame + + def __init__( + self, + df: PolarsLazyFrame, + keys: Sequence[PolarsExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, + ) -> None: + self._keys = list(keys) + self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df + self._grouped = ( + self.compliant.native.group_by(keys) + if is_sequence_of(keys, str) + else self.compliant.native.group_by(arg.native for arg in keys) + ) + + def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame: + agg_result = self._grouped.agg(arg.native for arg in aggs) + return self.compliant._with_native(agg_result) diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py new file mode 100644 index 0000000..4dec34c --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import operator +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Literal, + Mapping, + Sequence, + cast, + overload, +) + +import polars as pl + +from narwhals._polars.expr import PolarsExpr +from narwhals._polars.series import PolarsSeries +from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype +from narwhals._utils import Implementation, requires +from narwhals.dependencies import is_numpy_array_2d +from narwhals.dtypes import DType + +if TYPE_CHECKING: + from datetime import timezone + + from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen + from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame + from narwhals._polars.typing import FrameT + from narwhals._utils import Version, _FullContext + from narwhals.schema import Schema + from narwhals.typing import Into1DArray, IntoDType, TimeUnit, _2DArray + + +class PolarsNamespace: + all: Method[PolarsExpr] + col: Method[PolarsExpr] + exclude: Method[PolarsExpr] + all_horizontal: Method[PolarsExpr] + any_horizontal: Method[PolarsExpr] + sum_horizontal: Method[PolarsExpr] + min_horizontal: Method[PolarsExpr] + max_horizontal: Method[PolarsExpr] + + # NOTE: `pyright` accepts, `mypy` doesn't highlight the issue + # error: Type argument "PolarsExpr" of "CompliantWhen" must be a subtype of "CompliantExpr[Any, Any]" + when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] # type: ignore[type-var] + + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._implementation = Implementation.POLARS + self._version = version + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._expr( + getattr(pl, attr)(*pos, **kwds), + version=self._version, + backend_version=self._backend_version, + ) + + return func + + @property + def _dataframe(self) -> type[PolarsDataFrame]: + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame + + @property + def _lazyframe(self) -> type[PolarsLazyFrame]: + from narwhals._polars.dataframe import PolarsLazyFrame + + return PolarsLazyFrame + + @property + def _expr(self) -> type[PolarsExpr]: + return PolarsExpr + + @property + def _series(self) -> type[PolarsSeries]: + return PolarsSeries + + @overload + def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ... + @overload + def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ... + @overload + def from_native(self, data: pl.Series, /) -> PolarsSeries: ... + def from_native( + self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, / + ) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries: + if self._dataframe._is_native(data): + return self._dataframe.from_native(data, context=self) + elif self._series._is_native(data): + return self._series.from_native(data, context=self) + elif self._lazyframe._is_native(data): + return self._lazyframe.from_native(data, context=self) + else: # pragma: no cover + msg = f"Unsupported type: {type(data).__name__!r}" + raise TypeError(msg) + + @overload + def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ... + + @overload + def from_numpy( + self, + data: _2DArray, + /, + schema: Mapping[str, DType] | Schema | Sequence[str] | None, + ) -> PolarsDataFrame: ... + + def from_numpy( + self, + data: Into1DArray | _2DArray, + /, + schema: Mapping[str, DType] | Schema | Sequence[str] | None = None, + ) -> PolarsDataFrame | PolarsSeries: + if is_numpy_array_2d(data): + return self._dataframe.from_numpy(data, schema=schema, context=self) + return self._series.from_numpy(data, context=self) # pragma: no cover + + @requires.backend_version( + (1, 0, 0), "Please use `col` for columns selection instead." + ) + def nth(self, *indices: int) -> PolarsExpr: + return self._expr( + pl.nth(*indices), version=self._version, backend_version=self._backend_version + ) + + def len(self) -> PolarsExpr: + if self._backend_version < (0, 20, 5): + return self._expr( + pl.count().alias("len"), + version=self._version, + backend_version=self._backend_version, + ) + return self._expr( + pl.len(), version=self._version, backend_version=self._backend_version + ) + + def concat( + self, + items: Iterable[FrameT], + *, + how: Literal["vertical", "horizontal", "diagonal"], + ) -> PolarsDataFrame | PolarsLazyFrame: + result = pl.concat((item.native for item in items), how=how) + if isinstance(result, pl.DataFrame): + return self._dataframe( + result, backend_version=self._backend_version, version=self._version + ) + return self._lazyframe.from_native(result, context=self) + + def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr: + if dtype is not None: + return self._expr( + pl.lit( + value, + dtype=narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ), + ), + version=self._version, + backend_version=self._backend_version, + ) + return self._expr( + pl.lit(value), version=self._version, backend_version=self._backend_version + ) + + def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr: + if self._backend_version < (0, 20, 8): + return self._expr( + pl.sum_horizontal(e._native_expr for e in exprs) + / pl.sum_horizontal(1 - e.is_null()._native_expr for e in exprs), + version=self._version, + backend_version=self._backend_version, + ) + + return self._expr( + pl.mean_horizontal(e._native_expr for e in exprs), + version=self._version, + backend_version=self._backend_version, + ) + + def concat_str( + self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool + ) -> PolarsExpr: + pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs] + + if self._backend_version < (0, 20, 6): + null_mask = [expr.is_null() for expr in pl_exprs] + sep = pl.lit(separator) + + if not ignore_nulls: + null_mask_result = pl.any_horizontal(*null_mask) + output_expr = pl.reduce( + lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()), # type: ignore[arg-type,return-value] + pl_exprs, + ) + result = pl.when(~null_mask_result).then(output_expr) + else: + init_value, *values = [ + pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String())) + for expr, nm in zip(pl_exprs, null_mask) + ] + separators = [ + pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1] + ] + + result = pl.fold( # type: ignore[assignment] + acc=init_value, + function=operator.add, + exprs=[s + v for s, v in zip(separators, values)], + ) + + return self._expr( + result, version=self._version, backend_version=self._backend_version + ) + + return self._expr( + pl.concat_str(pl_exprs, separator=separator, ignore_nulls=ignore_nulls), + version=self._version, + backend_version=self._backend_version, + ) + + # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) + # 1. Others have lots of private stuff for code reuse + # i. None of that is useful here + # 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr` + @property + def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: + return cast( + "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", + PolarsSelectorNamespace(self), + ) + + +class PolarsSelectorNamespace: + def __init__(self, context: _FullContext, /) -> None: + self._implementation = context._implementation + self._backend_version = context._backend_version + self._version = context._version + + def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr: + native_dtypes = [ + narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ).__class__ + if isinstance(dtype, type) and issubclass(dtype, DType) + else narwhals_to_native_dtype(dtype, self._version, self._backend_version) + for dtype in dtypes + ] + return PolarsExpr( + pl.selectors.by_dtype(native_dtypes), + version=self._version, + backend_version=self._backend_version, + ) + + def matches(self, pattern: str) -> PolarsExpr: + return PolarsExpr( + pl.selectors.matches(pattern=pattern), + version=self._version, + backend_version=self._backend_version, + ) + + def numeric(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.numeric(), + version=self._version, + backend_version=self._backend_version, + ) + + def boolean(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.boolean(), + version=self._version, + backend_version=self._backend_version, + ) + + def string(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.string(), + version=self._version, + backend_version=self._backend_version, + ) + + def categorical(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.categorical(), + version=self._version, + backend_version=self._backend_version, + ) + + def all(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.all(), + version=self._version, + backend_version=self._backend_version, + ) + + def datetime( + self, + time_unit: TimeUnit | Iterable[TimeUnit] | None, + time_zone: str | timezone | Iterable[str | timezone | None] | None, + ) -> PolarsExpr: + return PolarsExpr( + pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone), # type: ignore[arg-type] + version=self._version, + backend_version=self._backend_version, + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/series.py b/venv/lib/python3.8/site-packages/narwhals/_polars/series.py new file mode 100644 index 0000000..2a4325e --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/series.py @@ -0,0 +1,757 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Iterator, + Mapping, + Sequence, + cast, + overload, +) + +import polars as pl + +from narwhals._polars.utils import ( + catch_polars_exception, + extract_args_kwargs, + extract_native, + narwhals_to_native_dtype, + native_to_narwhals_dtype, +) +from narwhals._utils import Implementation, requires, validate_backend_version +from narwhals.dependencies import is_numpy_array_1d + +if TYPE_CHECKING: + from types import ModuleType + from typing import TypeVar + + import pandas as pd + import pyarrow as pa + from typing_extensions import Self, TypeIs + + from narwhals._polars.dataframe import Method, PolarsDataFrame + from narwhals._polars.expr import PolarsExpr + from narwhals._polars.namespace import PolarsNamespace + from narwhals._utils import Version, _FullContext + from narwhals.dtypes import DType + from narwhals.series import Series + from narwhals.typing import Into1DArray, IntoDType, MultiIndexSelector, _1DArray + + T = TypeVar("T") + + +# Series methods where PolarsSeries just defers to Polars.Series directly. +INHERITED_METHODS = frozenset( + [ + "__add__", + "__and__", + "__floordiv__", + "__invert__", + "__iter__", + "__mod__", + "__mul__", + "__or__", + "__pow__", + "__radd__", + "__rand__", + "__rfloordiv__", + "__rmod__", + "__rmul__", + "__ror__", + "__rsub__", + "__rtruediv__", + "__sub__", + "__truediv__", + "abs", + "all", + "any", + "arg_max", + "arg_min", + "arg_true", + "clip", + "count", + "cum_max", + "cum_min", + "cum_prod", + "cum_sum", + "diff", + "drop_nulls", + "exp", + "fill_null", + "filter", + "gather_every", + "head", + "is_between", + "is_finite", + "is_first_distinct", + "is_in", + "is_last_distinct", + "is_null", + "is_sorted", + "is_unique", + "item", + "len", + "log", + "max", + "mean", + "min", + "mode", + "n_unique", + "null_count", + "quantile", + "rank", + "round", + "sample", + "shift", + "skew", + "std", + "sum", + "tail", + "to_arrow", + "to_frame", + "to_list", + "to_pandas", + "unique", + "var", + "zip_with", + ] +) + + +class PolarsSeries: + def __init__( + self, series: pl.Series, *, backend_version: tuple[int, ...], version: Version + ) -> None: + self._native_series: pl.Series = series + self._backend_version = backend_version + self._implementation = Implementation.POLARS + self._version = version + validate_backend_version(self._implementation, self._backend_version) + + def __repr__(self) -> str: # pragma: no cover + return "PolarsSeries" + + def __narwhals_namespace__(self) -> PolarsNamespace: + from narwhals._polars.namespace import PolarsNamespace + + return PolarsNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __narwhals_series__(self) -> Self: + return self + + def __native_namespace__(self) -> ModuleType: + if self._implementation is Implementation.POLARS: + return self._implementation.to_native_namespace() + + msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover + raise AssertionError(msg) + + def _with_version(self, version: Version) -> Self: + return self.__class__( + self.native, backend_version=self._backend_version, version=version + ) + + @classmethod + def from_iterable( + cls, + data: Iterable[Any], + *, + context: _FullContext, + name: str = "", + dtype: IntoDType | None = None, + ) -> Self: + version = context._version + backend_version = context._backend_version + dtype_pl = ( + narwhals_to_native_dtype(dtype, version, backend_version) if dtype else None + ) + # NOTE: `Iterable` is fine, annotation is overly narrow + # https://github.com/pola-rs/polars/blob/82d57a4ee41f87c11ca1b1af15488459727efdd7/py-polars/polars/series/series.py#L332-L333 + native = pl.Series(name=name, values=cast("Sequence[Any]", data), dtype=dtype_pl) + return cls.from_native(native, context=context) + + @staticmethod + def _is_native(obj: pl.Series | Any) -> TypeIs[pl.Series]: + return isinstance(obj, pl.Series) + + @classmethod + def from_native(cls, data: pl.Series, /, *, context: _FullContext) -> Self: + return cls( + data, backend_version=context._backend_version, version=context._version + ) + + @classmethod + def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self: + native = pl.Series(data if is_numpy_array_1d(data) else [data]) + return cls.from_native(native, context=context) + + def to_narwhals(self) -> Series[pl.Series]: + return self._version.series(self, level="full") + + def _with_native(self, series: pl.Series) -> Self: + return self.__class__( + series, backend_version=self._backend_version, version=self._version + ) + + @overload + def _from_native_object(self, series: pl.Series) -> Self: ... + + @overload + def _from_native_object(self, series: pl.DataFrame) -> PolarsDataFrame: ... + + @overload + def _from_native_object(self, series: T) -> T: ... + + def _from_native_object( + self, series: pl.Series | pl.DataFrame | T + ) -> Self | PolarsDataFrame | T: + if self._is_native(series): + return self._with_native(series) + if isinstance(series, pl.DataFrame): + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame.from_native(series, context=self) + # scalar + return series + + def _to_expr(self) -> PolarsExpr: + return self.__narwhals_namespace__()._expr._from_series(self) + + def __getattr__(self, attr: str) -> Any: + if attr not in INHERITED_METHODS: + msg = f"{self.__class__.__name__} has not attribute '{attr}'." + raise AttributeError(msg) + + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._from_native_object(getattr(self.native, attr)(*pos, **kwds)) + + return func + + def __len__(self) -> int: + return len(self.native) + + @property + def name(self) -> str: + return self.native.name + + @property + def dtype(self) -> DType: + return native_to_narwhals_dtype( + self.native.dtype, self._version, self._backend_version + ) + + @property + def native(self) -> pl.Series: + return self._native_series + + def alias(self, name: str) -> Self: + return self._from_native_object(self.native.alias(name)) + + def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self: + if isinstance(item, PolarsSeries): + return self._from_native_object(self.native.__getitem__(item.native)) + return self._from_native_object(self.native.__getitem__(item)) + + def cast(self, dtype: IntoDType) -> Self: + dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) + return self._with_native(self.native.cast(dtype_pl)) + + @requires.backend_version((1,)) + def replace_strict( + self, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any], + *, + return_dtype: IntoDType | None, + ) -> Self: + ser = self.native + dtype = ( + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) + if return_dtype + else None + ) + return self._with_native(ser.replace_strict(old, new, return_dtype=dtype)) + + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: + return self.__array__(dtype, copy=copy) + + def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: + if self._backend_version < (0, 20, 29): + return self.native.__array__(dtype=dtype) + return self.native.__array__(dtype=dtype, copy=copy) + + def __eq__(self, other: object) -> Self: # type: ignore[override] + return self._with_native(self.native.__eq__(extract_native(other))) + + def __ne__(self, other: object) -> Self: # type: ignore[override] + return self._with_native(self.native.__ne__(extract_native(other))) + + # NOTE: `pyright` is being reasonable here + def __ge__(self, other: Any) -> Self: + return self._with_native(self.native.__ge__(extract_native(other))) # pyright: ignore[reportArgumentType] + + def __gt__(self, other: Any) -> Self: + return self._with_native(self.native.__gt__(extract_native(other))) # pyright: ignore[reportArgumentType] + + def __le__(self, other: Any) -> Self: + return self._with_native(self.native.__le__(extract_native(other))) # pyright: ignore[reportArgumentType] + + def __lt__(self, other: Any) -> Self: + return self._with_native(self.native.__lt__(extract_native(other))) # pyright: ignore[reportArgumentType] + + def __rpow__(self, other: PolarsSeries | Any) -> Self: + result = self.native.__rpow__(extract_native(other)) + if self._backend_version < (1, 16, 1): + # Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071 + result = result.alias(self.name) + return self._with_native(result) + + def is_nan(self) -> Self: + try: + native_is_nan = self.native.is_nan() + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + if self._backend_version < (1, 18): # pragma: no cover + select = pl.when(self.native.is_not_null()).then(native_is_nan) + return self._with_native(pl.select(select)[self.name]) + return self._with_native(native_is_nan) + + def median(self) -> Any: + from narwhals.exceptions import InvalidOperationError + + if not self.dtype.is_numeric(): + msg = "`median` operation not supported for non-numeric input type." + raise InvalidOperationError(msg) + + return self.native.median() + + def to_dummies(self, *, separator: str, drop_first: bool) -> PolarsDataFrame: + from narwhals._polars.dataframe import PolarsDataFrame + + if self._backend_version < (0, 20, 15): + has_nulls = self.native.is_null().any() + result = self.native.to_dummies(separator=separator) + output_columns = result.columns + if drop_first: + _ = output_columns.pop(int(has_nulls)) + + result = result.select(output_columns) + else: + result = self.native.to_dummies(separator=separator, drop_first=drop_first) + result = result.with_columns(pl.all().cast(pl.Int8)) + return PolarsDataFrame.from_native(result, context=self) + + def ewm_mean( + self, + *, + com: float | None, + span: float | None, + half_life: float | None, + alpha: float | None, + adjust: bool, + min_samples: int, + ignore_nulls: bool, + ) -> Self: + extra_kwargs = ( + {"min_periods": min_samples} + if self._backend_version < (1, 21, 0) + else {"min_samples": min_samples} + ) + + native_result = self.native.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + ignore_nulls=ignore_nulls, + **extra_kwargs, + ) + if self._backend_version < (1,): # pragma: no cover + return self._with_native( + pl.select( + pl.when(~self.native.is_null()).then(native_result).otherwise(None) + )[self.native.name] + ) + + return self._with_native(native_result) + + @requires.backend_version((1,)) + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool, ddof: int + ) -> Self: + extra_kwargs: dict[str, Any] = ( + {"min_periods": min_samples} + if self._backend_version < (1, 21, 0) + else {"min_samples": min_samples} + ) + return self._with_native( + self.native.rolling_var( + window_size=window_size, center=center, ddof=ddof, **extra_kwargs + ) + ) + + @requires.backend_version((1,)) + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool, ddof: int + ) -> Self: + extra_kwargs: dict[str, Any] = ( + {"min_periods": min_samples} + if self._backend_version < (1, 21, 0) + else {"min_samples": min_samples} + ) + return self._with_native( + self.native.rolling_std( + window_size=window_size, center=center, ddof=ddof, **extra_kwargs + ) + ) + + def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: + extra_kwargs: dict[str, Any] = ( + {"min_periods": min_samples} + if self._backend_version < (1, 21, 0) + else {"min_samples": min_samples} + ) + return self._with_native( + self.native.rolling_sum( + window_size=window_size, center=center, **extra_kwargs + ) + ) + + def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: + extra_kwargs: dict[str, Any] = ( + {"min_periods": min_samples} + if self._backend_version < (1, 21, 0) + else {"min_samples": min_samples} + ) + return self._with_native( + self.native.rolling_mean( + window_size=window_size, center=center, **extra_kwargs + ) + ) + + def sort(self, *, descending: bool, nulls_last: bool) -> Self: + if self._backend_version < (0, 20, 6): + result = self.native.sort(descending=descending) + + if nulls_last: + is_null = result.is_null() + result = pl.concat([result.filter(~is_null), result.filter(is_null)]) + else: + result = self.native.sort(descending=descending, nulls_last=nulls_last) + + return self._with_native(result) + + def scatter(self, indices: int | Sequence[int], values: Any) -> Self: + s = self.native.clone().scatter(indices, extract_native(values)) + return self._with_native(s) + + def value_counts( + self, *, sort: bool, parallel: bool, name: str | None, normalize: bool + ) -> PolarsDataFrame: + from narwhals._polars.dataframe import PolarsDataFrame + + if self._backend_version < (1, 0, 0): + value_name_ = name or ("proportion" if normalize else "count") + + result = self.native.value_counts(sort=sort, parallel=parallel).select( + **{ + (self.native.name): pl.col(self.native.name), + value_name_: pl.col("count") / pl.sum("count") + if normalize + else pl.col("count"), + } + ) + else: + result = self.native.value_counts( + sort=sort, parallel=parallel, name=name, normalize=normalize + ) + return PolarsDataFrame.from_native(result, context=self) + + def cum_count(self, *, reverse: bool) -> Self: + if self._backend_version < (0, 20, 4): + not_null_series = ~self.native.is_null() + result = not_null_series.cum_sum(reverse=reverse) + else: + result = self.native.cum_count(reverse=reverse) + + return self._with_native(result) + + def __contains__(self, other: Any) -> bool: + try: + return self.native.__contains__(other) + except Exception as e: # noqa: BLE001 + raise catch_polars_exception(e, self._backend_version) from None + + def hist( # noqa: C901, PLR0912 + self, + bins: list[float | int] | None, + *, + bin_count: int | None, + include_breakpoint: bool, + ) -> PolarsDataFrame: + from narwhals._polars.dataframe import PolarsDataFrame + + if (bins is not None and len(bins) <= 1) or (bin_count == 0): # pragma: no cover + data: list[pl.Series] = [] + if include_breakpoint: + data.append(pl.Series("breakpoint", [], dtype=pl.Float64)) + data.append(pl.Series("count", [], dtype=pl.UInt32)) + return PolarsDataFrame.from_native(pl.DataFrame(data), context=self) + + if self.native.count() < 1: + data_dict: dict[str, Sequence[Any] | pl.Series] + if bins is not None: + data_dict = { + "breakpoint": bins[1:], + "count": pl.zeros(n=len(bins) - 1, dtype=pl.Int64, eager=True), + } + elif (bin_count is not None) and bin_count == 1: + data_dict = {"breakpoint": [1.0], "count": [0]} + elif (bin_count is not None) and bin_count > 1: + data_dict = { + "breakpoint": pl.int_range(1, bin_count + 1, eager=True) / bin_count, + "count": pl.zeros(n=bin_count, dtype=pl.Int64, eager=True), + } + else: # pragma: no cover + msg = ( + "congratulations, you entered unreachable code - please report a bug" + ) + raise AssertionError(msg) + if not include_breakpoint: + del data_dict["breakpoint"] + return PolarsDataFrame.from_native(pl.DataFrame(data_dict), context=self) + + # polars <1.15 does not adjust the bins when they have equivalent min/max + # polars <1.5 with bin_count=... + # returns bins that range from -inf to +inf and has bin_count + 1 bins. + # for compat: convert `bin_count=` call to `bins=` + if (self._backend_version < (1, 15)) and ( + bin_count is not None + ): # pragma: no cover + lower = cast("float", self.native.min()) + upper = cast("float", self.native.max()) + if lower == upper: + width = 1 / bin_count + lower -= 0.5 + upper += 0.5 + else: + width = (upper - lower) / bin_count + + bins = (pl.int_range(0, bin_count + 1, eager=True) * width + lower).to_list() + bin_count = None + + # Polars inconsistently handles NaN values when computing histograms + # against predefined bins: https://github.com/pola-rs/polars/issues/21082 + series = self.native + if self._backend_version < (1, 15) or bins is not None: + series = series.set(series.is_nan(), None) + + df = series.hist( + bins, + bin_count=bin_count, + include_category=False, + include_breakpoint=include_breakpoint, + ) + + if not include_breakpoint: + df.columns = ["count"] + + if self._backend_version < (1, 0) and include_breakpoint: + df = df.rename({"break_point": "breakpoint"}) + + # polars<1.15 implicitly adds -inf and inf to either end of bins + if self._backend_version < (1, 15) and bins is not None: # pragma: no cover + r = pl.int_range(0, len(df)) + df = df.filter((r > 0) & (r < len(df) - 1)) + + # polars<1.27 makes the lowest bin a left/right closed interval. + if self._backend_version < (1, 27) and bins is not None: + df[0, "count"] += (series == bins[0]).sum() + + return PolarsDataFrame.from_native(df, context=self) + + def to_polars(self) -> pl.Series: + return self.native + + @property + def dt(self) -> PolarsSeriesDateTimeNamespace: + return PolarsSeriesDateTimeNamespace(self) + + @property + def str(self) -> PolarsSeriesStringNamespace: + return PolarsSeriesStringNamespace(self) + + @property + def cat(self) -> PolarsSeriesCatNamespace: + return PolarsSeriesCatNamespace(self) + + @property + def struct(self) -> PolarsSeriesStructNamespace: + return PolarsSeriesStructNamespace(self) + + __add__: Method[Self] + __and__: Method[Self] + __floordiv__: Method[Self] + __invert__: Method[Self] + __iter__: Method[Iterator[Any]] + __mod__: Method[Self] + __mul__: Method[Self] + __or__: Method[Self] + __pow__: Method[Self] + __radd__: Method[Self] + __rand__: Method[Self] + __rfloordiv__: Method[Self] + __rmod__: Method[Self] + __rmul__: Method[Self] + __ror__: Method[Self] + __rsub__: Method[Self] + __rtruediv__: Method[Self] + __sub__: Method[Self] + __truediv__: Method[Self] + abs: Method[Self] + all: Method[bool] + any: Method[bool] + arg_max: Method[int] + arg_min: Method[int] + arg_true: Method[Self] + clip: Method[Self] + count: Method[int] + cum_max: Method[Self] + cum_min: Method[Self] + cum_prod: Method[Self] + cum_sum: Method[Self] + diff: Method[Self] + drop_nulls: Method[Self] + exp: Method[Self] + fill_null: Method[Self] + filter: Method[Self] + gather_every: Method[Self] + head: Method[Self] + is_between: Method[Self] + is_finite: Method[Self] + is_first_distinct: Method[Self] + is_in: Method[Self] + is_last_distinct: Method[Self] + is_null: Method[Self] + is_sorted: Method[bool] + is_unique: Method[Self] + item: Method[Any] + len: Method[int] + log: Method[Self] + max: Method[Any] + mean: Method[float] + min: Method[Any] + mode: Method[Self] + n_unique: Method[int] + null_count: Method[int] + quantile: Method[float] + rank: Method[Self] + round: Method[Self] + sample: Method[Self] + shift: Method[Self] + skew: Method[float | None] + std: Method[float] + sum: Method[float] + tail: Method[Self] + to_arrow: Method[pa.Array[Any]] + to_frame: Method[PolarsDataFrame] + to_list: Method[list[Any]] + to_pandas: Method[pd.Series[Any]] + unique: Method[Self] + var: Method[float] + zip_with: Method[Self] + + @property + def list(self) -> PolarsSeriesListNamespace: + return PolarsSeriesListNamespace(self) + + +class PolarsSeriesDateTimeNamespace: + def __init__(self, series: PolarsSeries) -> None: + self._compliant_series = series + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_series._with_native( + getattr(self._compliant_series.native.dt, attr)(*pos, **kwds) + ) + + return func + + +class PolarsSeriesStringNamespace: + def __init__(self, series: PolarsSeries) -> None: + self._compliant_series = series + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_series._with_native( + getattr(self._compliant_series.native.str, attr)(*pos, **kwds) + ) + + return func + + +class PolarsSeriesCatNamespace: + def __init__(self, series: PolarsSeries) -> None: + self._compliant_series = series + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_series._with_native( + getattr(self._compliant_series.native.cat, attr)(*pos, **kwds) + ) + + return func + + +class PolarsSeriesListNamespace: + def __init__(self, series: PolarsSeries) -> None: + self._series = series + + def len(self) -> PolarsSeries: + native_series = self._series.native + native_result = native_series.list.len() + + if self._series._backend_version < (1, 16): # pragma: no cover + native_result = pl.select( + pl.when(~native_series.is_null()).then(native_result).otherwise(None) + )[native_series.name].cast(pl.UInt32()) + + elif self._series._backend_version < (1, 17): # pragma: no cover + native_result = native_series.cast(pl.UInt32()) + + return self._series._with_native(native_result) + + # TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added + def __getattr__(self, attr: str) -> Any: # pragma: no cover + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._series._with_native( + getattr(self._series.native.list, attr)(*pos, **kwds) + ) + + return func + + +class PolarsSeriesStructNamespace: + def __init__(self, series: PolarsSeries) -> None: + self._compliant_series = series + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._compliant_series._with_native( + getattr(self._compliant_series.native.struct, attr)(*pos, **kwds) + ) + + return func diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/typing.py b/venv/lib/python3.8/site-packages/narwhals/_polars/typing.py new file mode 100644 index 0000000..88a6f75 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/typing.py @@ -0,0 +1,22 @@ +from __future__ import annotations # pragma: no cover + +from typing import ( + TYPE_CHECKING, # pragma: no cover + Union, # pragma: no cover +) + +if TYPE_CHECKING: + import sys + from typing import TypeVar + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame + from narwhals._polars.expr import PolarsExpr + from narwhals._polars.series import PolarsSeries + + IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries] + FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame) diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py b/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py new file mode 100644 index 0000000..bb15dfb --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Iterator, + Mapping, + TypeVar, + cast, + overload, +) + +import polars as pl + +from narwhals._utils import Version, _DeferredIterable, isinstance_or_issubclass +from narwhals.exceptions import ( + ColumnNotFoundError, + ComputeError, + DuplicateError, + InvalidOperationError, + NarwhalsError, + ShapeError, +) + +if TYPE_CHECKING: + from typing_extensions import TypeIs + + from narwhals._utils import _StoresNative + from narwhals.dtypes import DType + from narwhals.typing import IntoDType + + T = TypeVar("T") + NativeT = TypeVar( + "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr" + ) + + +@overload +def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ... +@overload +def extract_native(obj: T) -> T: ... +def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T: + return obj.native if _is_compliant_polars(obj) else obj + + +def _is_compliant_polars( + obj: _StoresNative[NativeT] | Any, +) -> TypeIs[_StoresNative[NativeT]]: + from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame + from narwhals._polars.expr import PolarsExpr + from narwhals._polars.series import PolarsSeries + + return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)) + + +def extract_args_kwargs( + args: Iterable[Any], kwds: Mapping[str, Any], / +) -> tuple[Iterator[Any], dict[str, Any]]: + it_args = (extract_native(arg) for arg in args) + return it_args, {k: extract_native(v) for k, v in kwds.items()} + + +@lru_cache(maxsize=16) +def native_to_narwhals_dtype( # noqa: C901, PLR0912 + dtype: pl.DataType, version: Version, backend_version: tuple[int, ...] +) -> DType: + dtypes = version.dtypes + if dtype == pl.Float64: + return dtypes.Float64() + if dtype == pl.Float32: + return dtypes.Float32() + if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover + # Not available for Polars pre 1.8.0 + return dtypes.Int128() + if dtype == pl.Int64: + return dtypes.Int64() + if dtype == pl.Int32: + return dtypes.Int32() + if dtype == pl.Int16: + return dtypes.Int16() + if dtype == pl.Int8: + return dtypes.Int8() + if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover + # Not available for Polars pre 1.8.0 + return dtypes.UInt128() + if dtype == pl.UInt64: + return dtypes.UInt64() + if dtype == pl.UInt32: + return dtypes.UInt32() + if dtype == pl.UInt16: + return dtypes.UInt16() + if dtype == pl.UInt8: + return dtypes.UInt8() + if dtype == pl.String: + return dtypes.String() + if dtype == pl.Boolean: + return dtypes.Boolean() + if dtype == pl.Object: + return dtypes.Object() + if dtype == pl.Categorical: + return dtypes.Categorical() + if isinstance_or_issubclass(dtype, pl.Enum): + if version is Version.V1: + return dtypes.Enum() # type: ignore[call-arg] + categories = _DeferredIterable( + dtype.categories.to_list + if backend_version >= (0, 20, 4) + else lambda: cast("list[str]", dtype.categories) + ) + return dtypes.Enum(categories) + if dtype == pl.Date: + return dtypes.Date() + if isinstance_or_issubclass(dtype, pl.Datetime): + return ( + dtypes.Datetime() + if dtype is pl.Datetime + else dtypes.Datetime(dtype.time_unit, dtype.time_zone) + ) + if isinstance_or_issubclass(dtype, pl.Duration): + return ( + dtypes.Duration() + if dtype is pl.Duration + else dtypes.Duration(dtype.time_unit) + ) + if isinstance_or_issubclass(dtype, pl.Struct): + fields = [ + dtypes.Field(name, native_to_narwhals_dtype(tp, version, backend_version)) + for name, tp in dtype + ] + return dtypes.Struct(fields) + if isinstance_or_issubclass(dtype, pl.List): + return dtypes.List( + native_to_narwhals_dtype(dtype.inner, version, backend_version) + ) + if isinstance_or_issubclass(dtype, pl.Array): + outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size + return dtypes.Array( + native_to_narwhals_dtype(dtype.inner, version, backend_version), outer_shape + ) + if dtype == pl.Decimal: + return dtypes.Decimal() + if dtype == pl.Time: + return dtypes.Time() + if dtype == pl.Binary: + return dtypes.Binary() + return dtypes.Unknown() + + +def narwhals_to_native_dtype( # noqa: C901, PLR0912 + dtype: IntoDType, version: Version, backend_version: tuple[int, ...] +) -> pl.DataType: + dtypes = version.dtypes + if dtype == dtypes.Float64: + return pl.Float64() + if dtype == dtypes.Float32: + return pl.Float32() + if dtype == dtypes.Int128 and hasattr(pl, "Int128"): + # Not available for Polars pre 1.8.0 + return pl.Int128() + if dtype == dtypes.Int64: + return pl.Int64() + if dtype == dtypes.Int32: + return pl.Int32() + if dtype == dtypes.Int16: + return pl.Int16() + if dtype == dtypes.Int8: + return pl.Int8() + if dtype == dtypes.UInt64: + return pl.UInt64() + if dtype == dtypes.UInt32: + return pl.UInt32() + if dtype == dtypes.UInt16: + return pl.UInt16() + if dtype == dtypes.UInt8: + return pl.UInt8() + if dtype == dtypes.String: + return pl.String() + if dtype == dtypes.Boolean: + return pl.Boolean() + if dtype == dtypes.Object: # pragma: no cover + return pl.Object() + if dtype == dtypes.Categorical: + return pl.Categorical() + if isinstance_or_issubclass(dtype, dtypes.Enum): + if version is Version.V1: + msg = "Converting to Enum is not supported in narwhals.stable.v1" + raise NotImplementedError(msg) + if isinstance(dtype, dtypes.Enum): + return pl.Enum(dtype.categories) + msg = "Can not cast / initialize Enum without categories present" + raise ValueError(msg) + if dtype == dtypes.Date: + return pl.Date() + if dtype == dtypes.Time: + return pl.Time() + if dtype == dtypes.Binary: + return pl.Binary() + if dtype == dtypes.Decimal: + msg = "Casting to Decimal is not supported yet." + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Datetime): + return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type] + if isinstance_or_issubclass(dtype, dtypes.Duration): + return pl.Duration(dtype.time_unit) # type: ignore[arg-type] + if isinstance_or_issubclass(dtype, dtypes.List): + return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) + if isinstance_or_issubclass(dtype, dtypes.Struct): + fields = [ + pl.Field( + field.name, + narwhals_to_native_dtype(field.dtype, version, backend_version), + ) + for field in dtype.fields + ] + return pl.Struct(fields) + if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + size = dtype.size + kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size} + return pl.Array( + narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs + ) + return pl.Unknown() # pragma: no cover + + +def catch_polars_exception( + exception: Exception, backend_version: tuple[int, ...] +) -> NarwhalsError | Exception: + if isinstance(exception, pl.exceptions.ColumnNotFoundError): + return ColumnNotFoundError(str(exception)) + elif isinstance(exception, pl.exceptions.ShapeError): + return ShapeError(str(exception)) + elif isinstance(exception, pl.exceptions.InvalidOperationError): + return InvalidOperationError(str(exception)) + elif isinstance(exception, pl.exceptions.DuplicateError): + return DuplicateError(str(exception)) + elif isinstance(exception, pl.exceptions.ComputeError): + return ComputeError(str(exception)) + if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError): + # Old versions of Polars didn't have PolarsError. + return NarwhalsError(str(exception)) # pragma: no cover + elif backend_version < (1,) and "polars.exceptions" in str( + type(exception) + ): # pragma: no cover + # Last attempt, for old Polars versions. + return NarwhalsError(str(exception)) + # Just return exception as-is. + return exception |
