"""Polars-backed compliant DataFrame and LazyFrame wrappers."""

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Iterator,
    Literal,
    Mapping,
    Sequence,
    Sized,
    cast,
    overload,
)

import polars as pl

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import (
    catch_polars_exception,
    extract_args_kwargs,
    native_to_narwhals_dtype,
)
from narwhals._utils import (
    Implementation,
    _into_arrow_table,
    check_columns_exist,
    convert_str_slice_to_int_slice,
    is_compliant_series,
    is_index_selector,
    is_range,
    is_sequence_like,
    is_slice_index,
    is_slice_none,
    parse_columns_to_drop,
    parse_version,
    requires,
    validate_backend_version,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
    from types import ModuleType
    from typing import Callable, TypeVar

    import pandas as pd
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
    from narwhals._translate import IntoArrowTable
    from narwhals._utils import Version, _FullContext
    from narwhals.dataframe import DataFrame, LazyFrame
    from narwhals.dtypes import DType
    from narwhals.schema import Schema
    from narwhals.typing import (
        JoinStrategy,
        MultiColSelector,
        MultiIndexSelector,
        PivotAgg,
        SingleIndexSelector,
        _2DArray,
    )

    T = TypeVar("T")
    R = TypeVar("R")

    Method: TypeAlias = "Callable[..., R]"
    """Generic alias representing all methods implemented via `__getattr__`.

    Where `R` is the return type.
    """

# DataFrame methods where PolarsDataFrame just defers to Polars.DataFrame directly.
# NOTE: `PolarsLazyFrame.__getattr__` consults this same set.
INHERITED_METHODS = frozenset(
    [
        "clone",
        "drop_nulls",
        "estimated_size",
        "explode",
        "filter",
        "gather_every",
        "head",
        "is_unique",
        "item",
        "iter_rows",
        "join_asof",
        "rename",
        "row",
        "rows",
        "sample",
        "select",
        "sort",
        "tail",
        "to_arrow",
        "to_pandas",
        "unique",
        "with_columns",
        "write_csv",
        "write_parquet",
    ]
)


class PolarsDataFrame:
    """Wrapper around a native `pl.DataFrame`.

    Methods named in `INHERITED_METHODS` that are not defined explicitly on the
    class are resolved by `__getattr__`, which forwards the call to the native
    frame and re-wraps the result. The bare annotations below only describe
    those dynamically-resolved methods for type checkers.
    """

    clone: Method[Self]
    collect: Method[CompliantDataFrameAny]
    drop_nulls: Method[Self]
    estimated_size: Method[int | float]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    item: Method[Any]
    iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
    is_unique: Method[PolarsSeries]
    join_asof: Method[Self]
    rename: Method[Self]
    row: Method[tuple[Any, ...]]
    rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
    sample: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    to_arrow: Method[pa.Table]
    to_pandas: Method[pd.DataFrame]
    unique: Method[Self]
    with_columns: Method[Self]
    # NOTE: `write_csv` requires an `@overload` for `str | None`
    # Can't do that here 😟
    write_csv: Method[Any]
    write_parquet: Method[None]

    # CompliantDataFrame
    _evaluate_aliases: Any

    def __init__(
        self, df: pl.DataFrame, *, backend_version: tuple[int, ...], version: Version
    ) -> None:
        """Store the native frame together with backend and narwhals versions."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
        """Build a frame from Arrow-compatible `data`.

        For polars >= 1.3 the data is handed to `pl.DataFrame` directly;
        otherwise it is first converted with `_into_arrow_table` and passed to
        `pl.from_arrow`.
        """
        if context._backend_version >= (1, 3):
            native = pl.DataFrame(data)
        else:
            native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self:
        """Build a frame from a mapping, converting `schema` to polars if given."""
        from narwhals.schema import Schema

        pl_schema = Schema(schema).to_polars() if schema is not None else schema
        return cls.from_native(pl.from_dict(data, pl_schema), context=context)

    @staticmethod
    def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
        """Return whether `obj` is a `pl.DataFrame`."""
        return isinstance(obj, pl.DataFrame)

    @classmethod
    def from_native(cls, data: pl.DataFrame, /, *, context: _FullContext) -> Self:
        """Wrap `data`, taking backend and narwhals versions from `context`."""
        return cls(
            data, backend_version=context._backend_version, version=context._version
        )

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,  # NOTE: Maybe only `Implementation`?
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self:
        """Build a frame from a 2D array.

        A mapping/`Schema` is converted to a polars schema; any other value
        (column names or `None`) is passed to `pl.from_numpy` unchanged.
        """
        from narwhals.schema import Schema

        pl_schema = (
            Schema(schema).to_polars()
            if isinstance(schema, (Mapping, Schema))
            else schema
        )
        return cls.from_native(pl.from_numpy(data, pl_schema), context=context)

    def to_narwhals(self) -> DataFrame[pl.DataFrame]:
        """Wrap this compliant frame in a narwhals `DataFrame` (level "full")."""
        return self._version.dataframe(self, level="full")

    @property
    def native(self) -> pl.DataFrame:
        """The wrapped `pl.DataFrame`."""
        return self._native_frame

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self) -> Self:
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_version(self, version: Version) -> Self:
        """Return a new wrapper around the same native frame with `version`."""
        return self.__class__(
            self.native, backend_version=self._backend_version, version=version
        )

    def _with_native(self, df: pl.DataFrame) -> Self:
        """Return a new wrapper around `df`, keeping the current versions."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    @overload
    def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...

    @overload
    def _from_native_object(self, obj: pl.DataFrame) -> Self: ...

    @overload
    def _from_native_object(self, obj: T) -> T: ...

    def _from_native_object(
        self, obj: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsSeries | T:
        """Wrap a native result: Series -> `PolarsSeries`, DataFrame -> `Self`.

        Anything else is returned unchanged.
        """
        if isinstance(obj, pl.Series):
            return PolarsSeries.from_native(obj, context=self)
        if self._is_native(obj):
            return self._with_native(obj)
        # scalar
        return obj

    def __len__(self) -> int:
        return len(self.native)

    def head(self, n: int) -> Self:
        """Return the first `n` rows."""
        return self._with_native(self.native.head(n))

    def tail(self, n: int) -> Self:
        """Return the last `n` rows."""
        return self._with_native(self.native.tail(n))

    def __getattr__(self, attr: str) -> Any:
        """Resolve names in `INHERITED_METHODS` by deferring to the native frame.

        Returns a closure that unwraps its arguments via `extract_args_kwargs`,
        calls the native method of the same name, and re-wraps the result.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
                raise ColumnNotFoundError(msg) from e
            except Exception as e:  # noqa: BLE001
                # Any other failure is re-raised via `catch_polars_exception`.
                raise catch_polars_exception(e, self._backend_version) from None

        return func

    def __array__(
        self, dtype: Any | None = None, *, copy: bool | None = None
    ) -> _2DArray:
        """Convert to a 2D array via the native `__array__`.

        Raises:
            NotImplementedError: If `copy` is given and polars < 0.20.28.
        """
        if self._backend_version < (0, 20, 28):
            if copy is not None:
                msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
                raise NotImplementedError(msg)
            return self.native.__array__(dtype)
        # Forward `copy` so the caller's request is honoured rather than dropped.
        return self.native.__array__(dtype, copy=copy)

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        """Convert to a 2D array via the native `to_numpy`.

        NOTE(review): `dtype` and `copy` are accepted but not forwarded —
        presumably for interface compatibility; confirm against callers.
        """
        return self.native.to_numpy()

    def collect_schema(self) -> dict[str, DType]:
        """Return a mapping of column name to narwhals dtype.

        Uses the native `schema` attribute on polars < 1 and the native
        `collect_schema()` otherwise.
        """
        if self._backend_version < (1,):
            return {
                name: native_to_narwhals_dtype(
                    dtype, self._version, self._backend_version
                )
                for name, dtype in self.native.schema.items()
            }
        else:
            collected_schema = self.native.collect_schema()
            return {
                name: native_to_narwhals_dtype(
                    dtype, self._version, self._backend_version
                )
                for name, dtype in collected_schema.items()
            }

    @property
    def shape(self) -> tuple[int, int]:
        """The native frame's `shape`."""
        return self.native.shape

    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[PolarsSeries],
            MultiColSelector[PolarsSeries],
        ],
    ) -> Any:
        """Select by a `(rows, columns)` pair.

        On polars > 0.20.30 the selectors (unwrapped if they are compliant
        series) are passed straight to the native `__getitem__`. On older
        versions columns are resolved first, then rows.
        """
        rows, columns = item
        if self._backend_version > (0, 20, 30):
            rows_native = rows.native if is_compliant_series(rows) else rows
            columns_native = columns.native if is_compliant_series(columns) else columns
            selector = rows_native, columns_native
            selected = self.native.__getitem__(selector)  # type: ignore[index]
            return self._from_native_object(selected)
        else:  # pragma: no cover
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            # This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
            rows = list(rows) if isinstance(rows, tuple) else rows
            columns = list(columns) if isinstance(columns, tuple) else columns
            if is_numpy_array_1d(columns):
                columns = columns.tolist()

            native = self.native
            if not is_slice_none(columns):
                if isinstance(columns, Sized) and len(columns) == 0:
                    return self.select()
                if is_index_selector(columns):
                    if is_slice_index(columns) or is_range(columns):
                        native = native.select(
                            self.columns[slice(columns.start, columns.stop, columns.step)]
                        )
                    elif is_compliant_series(columns):
                        native = native[:, columns.native.to_list()]
                    else:
                        native = native[:, columns]
                elif isinstance(columns, slice):
                    native = native.select(
                        self.columns[
                            slice(*convert_str_slice_to_int_slice(columns, self.columns))
                        ]
                    )
                elif is_compliant_series(columns):
                    native = native.select(columns.native.to_list())
                elif is_sequence_like(columns):
                    native = native.select(columns)
                else:
                    msg = f"Unreachable code, got unexpected type: {type(columns)}"
                    raise AssertionError(msg)

            if not is_slice_none(rows):
                if isinstance(rows, int):
                    # Wrap in a list so the native indexing receives a sequence.
                    native = native[[rows], :]
                elif isinstance(rows, (slice, range)):
                    native = native[rows, :]
                elif is_compliant_series(rows):
                    native = native[rows.native, :]
                elif is_sequence_like(rows):
                    native = native[rows, :]
                else:
                    msg = f"Unreachable code, got unexpected type: {type(rows)}"
                    raise AssertionError(msg)

            return self._with_native(native)

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name using the native `select`."""
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: Any) -> Self:
        """Delegate to `select` with the given expressions."""
        return self.select(*exprs)

    def get_column(self, name: str) -> PolarsSeries:
        """Return column `name` as a `PolarsSeries`."""
        return PolarsSeries.from_native(self.native.get_column(name), context=self)

    def iter_columns(self) -> Iterator[PolarsSeries]:
        """Yield each native column wrapped as a `PolarsSeries`."""
        for series in self.native.iter_columns():
            yield PolarsSeries.from_native(series, context=self)

    @property
    def columns(self) -> list[str]:
        """The native frame's column names."""
        return self.native.columns

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to narwhals dtype, from the native `schema`."""
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in self.native.schema.items()
        }

    def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
        """Convert to a lazy frame for the requested backend.

        `None` and POLARS give a `PolarsLazyFrame`; DUCKDB and DASK are also
        handled. Any other value raises `AssertionError`.
        """
        if backend is None or backend is Implementation.POLARS:
            return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
        elif backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            # NOTE: (F841) is a false positive
            df = self.native  # noqa: F841
            return DuckDBLazyFrame(
                duckdb.table("df"),
                backend_version=parse_version(duckdb),
                version=self._version,
            )
        elif backend is Implementation.DASK:
            import dask  # ignore-banned-import
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                backend_version=parse_version(dask),
                version=self._version,
            )
        raise AssertionError  # pragma: no cover

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...

    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
        """Return columns as a dict of `PolarsSeries` or of plain lists."""
        if as_series:
            return {
                name: PolarsSeries.from_native(col, context=self)
                for name, col in self.native.to_dict().items()
            }
        else:
            return self.native.to_dict(as_series=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsGroupBy:
        """Create a `PolarsGroupBy` over `keys`."""
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Self:
        """Add a row-index column called `name`.

        Uses the native `with_row_count` on polars < 0.20.4 and
        `with_row_index` otherwise.
        """
        if self._backend_version < (0, 20, 4):
            return self._with_native(self.native.with_row_count(name))
        return self._with_native(self.native.with_row_index(name))

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop the columns resolved by `parse_columns_to_drop`."""
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot the frame, using the native `melt` on polars < 1.0.0."""
        if self._backend_version < (1, 0, 0):
            return self._with_native(
                self.native.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._with_native(
            self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )

    @requires.backend_version((1,))
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self:
        """Pivot via the native `pivot`; errors go through `catch_polars_exception`."""
        try:
            result = self.native.pivot(
                on,
                index=index,
                values=values,
                aggregate_function=aggregate_function,
                sort_columns=sort_columns,
                separator=separator,
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None
        return self._from_native_object(result)

    def to_polars(self) -> pl.DataFrame:
        """Return the wrapped `pl.DataFrame`."""
        return self.native

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other` using the native `join`.

        On polars < 0.20.29 a "full" join is requested as "outer".
        """
        how_native = (
            "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
        )
        try:
            return self._with_native(
                self.native.join(
                    other=other.native,
                    how=how_native,  # type: ignore[arg-type]
                    left_on=left_on,
                    right_on=right_on,
                    suffix=suffix,
                )
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        """Check `subset` against this frame's columns via `check_columns_exist`."""
        return check_columns_exist(subset, available=self.columns)


class PolarsLazyFrame:
    """Wrapper around a native `pl.LazyFrame`.

    Methods not defined explicitly are resolved by `__getattr__`, which
    forwards to the native lazy frame. The bare annotations below describe
    those methods for type checkers.
    """

    drop_nulls: Method[Self]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    join_asof: Method[Self]
    rename: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    with_columns: Method[Self]

    # CompliantLazyFrame
    _evaluate_expr: Any
    _evaluate_window_expr: Any
    _evaluate_aliases: Any

    def __init__(
        self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], version: Version
    ) -> None:
        """Store the native lazy frame together with backend and narwhals versions."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    @staticmethod
    def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
        """Return whether `obj` is a `pl.LazyFrame`."""
        return isinstance(obj, pl.LazyFrame)

    @classmethod
    def from_native(cls, data: pl.LazyFrame, /, *, context: _FullContext) -> Self:
        """Wrap `data`, taking backend and narwhals versions from `context`."""
        return cls(
            data, backend_version=context._backend_version, version=context._version
        )

    def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
        """Wrap this compliant frame in a narwhals `LazyFrame` (level "lazy")."""
        return self._version.lazyframe(self, level="lazy")

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_native(self, df: pl.LazyFrame) -> Self:
        """Return a new wrapper around `df`, keeping the current versions."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    def _with_version(self, version: Version) -> Self:
        """Return a new wrapper around the same native frame with `version`."""
        return self.__class__(
            self.native, backend_version=self._backend_version, version=version
        )

    def __getattr__(self, attr: str) -> Any:
        """Resolve names in `INHERITED_METHODS` by deferring to the native frame.

        Returns a closure that unwraps its arguments via `extract_args_kwargs`,
        calls the native method of the same name, and wraps the result with
        `_with_native`.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._with_native(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                raise ColumnNotFoundError(str(e)) from e

        return func

    def _iter_columns(self) -> Iterator[PolarsSeries]:  # pragma: no cover
        """Collect the frame and yield its columns."""
        yield from self.collect(self._implementation).iter_columns()

    @property
    def native(self) -> pl.LazyFrame:
        """The wrapped `pl.LazyFrame`."""
        return self._native_frame

    @property
    def columns(self) -> list[str]:
        """The native lazy frame's column names."""
        return self.native.columns

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to narwhals dtype, from the native `schema`."""
        schema = self.native.schema
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in schema.items()
        }

    def collect_schema(self) -> dict[str, DType]:
        """Return a mapping of column name to narwhals dtype.

        Uses the native `schema` attribute on polars < 1 and the native
        `collect_schema()` otherwise; failures of the latter are re-raised via
        `catch_polars_exception`.
        """
        if self._backend_version < (1,):
            return {
                name: native_to_narwhals_dtype(
                    dtype, self._version, self._backend_version
                )
                for name, dtype in self.native.schema.items()
            }
        else:
            try:
                collected_schema = self.native.collect_schema()
            except Exception as e:  # noqa: BLE001
                raise catch_polars_exception(e, self._backend_version) from None
            return {
                name: native_to_narwhals_dtype(
                    dtype, self._version, self._backend_version
                )
                for name, dtype in collected_schema.items()
            }

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Collect the lazy frame and wrap the result for `backend`.

        `None`/POLARS, PANDAS and PYARROW are supported.

        Raises:
            ValueError: For any other `backend` value.
        """
        try:
            result = self.native.collect(**kwargs)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None

        if backend is None or backend is Implementation.POLARS:
            return PolarsDataFrame.from_native(result, context=self)

        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result.to_pandas(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                result.to_arrow(),
                backend_version=parse_version(pa),
                version=self._version,
                validate_column_names=False,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsLazyGroupBy:
        """Create a `PolarsLazyGroupBy` over `keys`."""
        from narwhals._polars.group_by import PolarsLazyGroupBy

        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Self:
        """Add a row-index column called `name`.

        Uses the native `with_row_count` on polars < 0.20.4 and
        `with_row_index` otherwise.
        """
        if self._backend_version < (0, 20, 4):
            return self._with_native(self.native.with_row_count(name))
        return self._with_native(self.native.with_row_index(name))

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop `columns`; `strict` is only forwarded on polars >= 1.0.0."""
        if self._backend_version < (1, 0, 0):
            return self._with_native(self.native.drop(columns))
        return self._with_native(self.native.drop(columns, strict=strict))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot the frame, using the native `melt` on polars < 1.0.0."""
        if self._backend_version < (1, 0, 0):
            return self._with_native(
                self.native.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._with_native(
            self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name using the native `select`."""
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: Any) -> Self:
        """Delegate to `select` with the given expressions."""
        return self.select(*exprs)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other` using the native `join`.

        On polars < 0.20.29 a "full" join is requested as "outer".
        """
        how_native = (
            "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
        )
        return self._with_native(
            self.native.join(
                other=other.native,
                how=how_native,  # type: ignore[arg-type]
                left_on=left_on,
                right_on=right_on,
                suffix=suffix,
            )
        )

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        """Check `subset` against this frame's columns via `check_columns_exist`."""
        return check_columns_exist(  # pragma: no cover
            subset, available=self.columns
        )