From 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e Mon Sep 17 00:00:00 2001 From: sotech117 Date: Thu, 31 Jul 2025 17:27:24 -0400 Subject: add code for analysis of data --- .../site-packages/narwhals/_arrow/__init__.py | 0 .../site-packages/narwhals/_arrow/dataframe.py | 771 +++++++++++++ .../site-packages/narwhals/_arrow/expr.py | 205 ++++ .../site-packages/narwhals/_arrow/group_by.py | 159 +++ .../site-packages/narwhals/_arrow/namespace.py | 283 +++++ .../site-packages/narwhals/_arrow/selectors.py | 29 + .../site-packages/narwhals/_arrow/series.py | 1183 ++++++++++++++++++++ .../site-packages/narwhals/_arrow/series_cat.py | 18 + .../site-packages/narwhals/_arrow/series_dt.py | 194 ++++ .../site-packages/narwhals/_arrow/series_list.py | 16 + .../site-packages/narwhals/_arrow/series_str.py | 62 + .../site-packages/narwhals/_arrow/series_struct.py | 15 + .../site-packages/narwhals/_arrow/typing.py | 72 ++ .../site-packages/narwhals/_arrow/utils.py | 470 ++++++++ 14 files changed, 3477 insertions(+) create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/__init__.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/dataframe.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/group_by.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/selectors.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/series.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/series_cat.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/series_dt.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/series_list.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/series_str.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/series_struct.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/typing.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/__init__.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/dataframe.py new file mode 100644 index 0000000..19763b9 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/dataframe.py @@ -0,0 +1,771 @@ +from __future__ import annotations + +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterator, + Literal, + Mapping, + Sequence, + cast, + overload, +) + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.series import ArrowSeries +from narwhals._arrow.utils import align_series_full_broadcast, native_to_narwhals_dtype +from narwhals._compliant import EagerDataFrame +from narwhals._expression_parsing import ExprKind +from narwhals._utils import ( + Implementation, + Version, + check_column_names_are_unique, + convert_str_slice_to_int_slice, + generate_temporary_column_name, + not_implemented, + parse_columns_to_drop, + parse_version, + scale_bytes, + supports_arrow_c_stream, + validate_backend_version, +) +from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import ShapeError + +if
TYPE_CHECKING: + from io import BytesIO + from pathlib import Path + from types import ModuleType + + import pandas as pd + import polars as pl + from typing_extensions import Self, TypeAlias, TypeIs + + from narwhals._arrow.expr import ArrowExpr + from narwhals._arrow.group_by import ArrowGroupBy + from narwhals._arrow.namespace import ArrowNamespace + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + ChunkedArrayAny, + Mask, + Order, + ) + from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny + from narwhals._translate import IntoArrowTable + from narwhals._utils import Version, _FullContext + from narwhals.dtypes import DType + from narwhals.schema import Schema + from narwhals.typing import ( + JoinStrategy, + SizedMultiIndexSelector, + SizedMultiNameSelector, + SizeUnit, + UniqueKeepStrategy, + _1DArray, + _2DArray, + _SliceIndex, + _SliceName, + ) + + JoinType: TypeAlias = Literal[ + "left semi", + "right semi", + "left anti", + "right anti", + "inner", + "left outer", + "right outer", + "full outer", + ] + PromoteOptions: TypeAlias = Literal["none", "default", "permissive"] + + +class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]): + def __init__( + self, + native_dataframe: pa.Table, + *, + backend_version: tuple[int, ...], + version: Version, + validate_column_names: bool, + ) -> None: + if validate_column_names: + check_column_names_are_unique(native_dataframe.column_names) + self._native_frame = native_dataframe + self._implementation = Implementation.PYARROW + self._backend_version = backend_version + self._version = version + validate_backend_version(self._implementation, self._backend_version) + + @classmethod + def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: + backend_version = context._backend_version + if cls._is_native(data): + native = data + elif backend_version >= (14,) or isinstance(data, Collection): + native = pa.table(data) + elif supports_arrow_c_stream(data): # pragma: no cover + msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}." + raise ModuleNotFoundError(msg) + else: # pragma: no cover + msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}." 
+ raise TypeError(msg) + return cls.from_native(native, context=context) + + @classmethod + def from_dict( + cls, + data: Mapping[str, Any], + /, + *, + context: _FullContext, + schema: Mapping[str, DType] | Schema | None, + ) -> Self: + from narwhals.schema import Schema + + pa_schema = Schema(schema).to_arrow() if schema is not None else schema + native = pa.Table.from_pydict(data, schema=pa_schema) + return cls.from_native(native, context=context) + + @staticmethod + def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]: + return isinstance(obj, pa.Table) + + @classmethod + def from_native(cls, data: pa.Table, /, *, context: _FullContext) -> Self: + return cls( + data, + backend_version=context._backend_version, + version=context._version, + validate_column_names=True, + ) + + @classmethod + def from_numpy( + cls, + data: _2DArray, + /, + *, + context: _FullContext, + schema: Mapping[str, DType] | Schema | Sequence[str] | None, + ) -> Self: + from narwhals.schema import Schema + + arrays = [pa.array(val) for val in data.T] + if isinstance(schema, (Mapping, Schema)): + native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow()) + else: + native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema)) + return cls.from_native(native, context=context) + + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._arrow.namespace import ArrowNamespace + + return ArrowNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __native_namespace__(self) -> ModuleType: + if self._implementation is Implementation.PYARROW: + return self._implementation.to_native_namespace() + + msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover + raise AssertionError(msg) + + def __narwhals_dataframe__(self) -> Self: + return self + + def __narwhals_lazyframe__(self) -> Self: + return self + + def _with_version(self, version: Version) -> Self: + return self.__class__( + self.native, + backend_version=self._backend_version, + version=version, + validate_column_names=False, + ) + + def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self: + return self.__class__( + df, + backend_version=self._backend_version, + version=self._version, + validate_column_names=validate_column_names, + ) + + @property + def shape(self) -> tuple[int, int]: + return self.native.shape + + def __len__(self) -> int: + return len(self.native) + + def row(self, index: int) -> tuple[Any, ...]: + return tuple(col[index] for col in self.native.itercolumns()) + + @overload + def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ... + + @overload + def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ... + + @overload + def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ... 
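+    # named=False goes through the buffered `iter_rows` path below and yields tuples; + # named=True uses pyarrow's `Table.to_pylist`, which produces one dict per row.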
+ + def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: + if not named: + return list(self.iter_rows(named=False, buffer_size=512)) # type: ignore[return-value] + return self.native.to_pylist() + + def iter_columns(self) -> Iterator[ArrowSeries]: + for name, series in zip(self.columns, self.native.itercolumns()): + yield ArrowSeries.from_native(series, context=self, name=name) + + _iter_columns = iter_columns + + def iter_rows( + self, *, named: bool, buffer_size: int + ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]: + df = self.native + num_rows = df.num_rows + + if not named: + for i in range(0, num_rows, buffer_size): + rows = df[i : i + buffer_size].to_pydict().values() + yield from zip(*rows) + else: + for i in range(0, num_rows, buffer_size): + yield from df[i : i + buffer_size].to_pylist() + + def get_column(self, name: str) -> ArrowSeries: + if not isinstance(name, str): + msg = f"Expected str, got: {type(name)}" + raise TypeError(msg) + return ArrowSeries.from_native(self.native[name], context=self, name=name) + + def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: + return self.native.__array__(dtype, copy=copy) + + def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self: + if len(rows) == 0: + return self._with_native(self.native.slice(0, 0)) + if self._backend_version < (18,) and isinstance(rows, tuple): + rows = list(rows) + return self._with_native(self.native.take(rows)) + + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + start = rows.start or 0 + stop = rows.stop if rows.stop is not None else len(self.native) + if start < 0: + start = len(self.native) + start + if stop < 0: + stop = len(self.native) + stop + if rows.step is not None and rows.step != 1: + msg = "Slicing with step is not supported on PyArrow tables" + raise NotImplementedError(msg) + return self._with_native(self.native.slice(start, stop - start)) + + def _select_slice_name(self, columns: _SliceName) -> Self: + start, stop, step = convert_str_slice_to_int_slice(columns, self.columns) + return self._with_native(self.native.select(self.columns[start:stop:step])) + + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: + return self._with_native( + self.native.select(self.columns[columns.start : columns.stop : columns.step]) + ) + + def _select_multi_index( + self, columns: SizedMultiIndexSelector[ChunkedArrayAny] + ) -> Self: + selector: Sequence[int] + if isinstance(columns, pa.ChunkedArray): + # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` + selector = cast("Sequence[int]", columns.to_pylist()) + # TODO @dangotbanned: Fix upstream, it is actually much narrower + # **Doesn't accept `ndarray`** + elif is_numpy_array_1d(columns): + selector = columns.tolist() + else: + selector = columns + return self._with_native(self.native.select(selector)) + + def _select_multi_name( + self, columns: SizedMultiNameSelector[ChunkedArrayAny] + ) -> Self: + selector: Sequence[str] | _1DArray + if isinstance(columns, pa.ChunkedArray): + # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` + selector = cast("Sequence[str]", columns.to_pylist()) + else: + selector = columns + # NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221 + return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType] + + @property + def schema(self) -> dict[str, DType]: + schema = self.native.schema + return { + name: 
native_to_narwhals_dtype(dtype, self._version) + for name, dtype in zip(schema.names, schema.types) + } + + def collect_schema(self) -> dict[str, DType]: + return self.schema + + def estimated_size(self, unit: SizeUnit) -> int | float: + sz = self.native.nbytes + return scale_bytes(sz, unit) + + explode = not_implemented() + + @property + def columns(self) -> list[str]: + return self.native.column_names + + def simple_select(self, *column_names: str) -> Self: + return self._with_native( + self.native.select(list(column_names)), validate_column_names=False + ) + + def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: + new_series = self._evaluate_into_exprs(*exprs) + if not new_series: + # return empty dataframe, like Polars does + return self._with_native( + self.native.__class__.from_arrays([]), validate_column_names=False + ) + names = [s.name for s in new_series] + reshaped = align_series_full_broadcast(*new_series) + df = pa.Table.from_arrays([s.native for s in reshaped], names=names) + return self._with_native(df, validate_column_names=True) + + def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny: + length = len(self) + if not other._broadcast: + if (len_other := len(other)) != length: + msg = f"Expected object of length {length}, got: {len_other}." + raise ShapeError(msg) + return other.native + + value = other.native[0] + return pa.chunked_array([pa.repeat(value, length)]) + + def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: + # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame) + # All `pyarrow` data is immutable, so this is fine + native_frame = self.native + new_columns = self._evaluate_into_exprs(*exprs) + columns = self.columns + + for col_value in new_columns: + col_name = col_value.name + column = self._extract_comparand(col_value) + native_frame = ( + native_frame.set_column(columns.index(col_name), col_name, column=column) + if col_name in columns + else native_frame.append_column(col_name, column=column) + ) + + return self._with_native(native_frame, validate_column_names=False) + + def group_by( + self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool + ) -> ArrowGroupBy: + from narwhals._arrow.group_by import ArrowGroupBy + + return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys) + + def join( + self, + other: Self, + *, + how: JoinStrategy, + left_on: Sequence[str] | None, + right_on: Sequence[str] | None, + suffix: str, + ) -> Self: + how_to_join_map: dict[str, JoinType] = { + "anti": "left anti", + "semi": "left semi", + "inner": "inner", + "left": "left outer", + "full": "full outer", + } + + if how == "cross": + plx = self.__narwhals_namespace__() + key_token = generate_temporary_column_name( + n_bytes=8, columns=[*self.columns, *other.columns] + ) + + return self._with_native( + self.with_columns( + plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) + ) + .native.join( + other.with_columns( + plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL) + ).native, + keys=key_token, + right_keys=key_token, + join_type="inner", + right_suffix=suffix, + ) + .drop([key_token]) + ) + + coalesce_keys = how != "full" # polars full join does not coalesce keys + return self._with_native( + self.native.join( + other.native, + keys=left_on or [], # type: ignore[arg-type] + right_keys=right_on, # type: ignore[arg-type] + join_type=how_to_join_map[how], + right_suffix=suffix, + coalesce_keys=coalesce_keys, + ) + ) + + join_asof = not_implemented() + + 
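+ # `strict=True` raises on unknown column names (via `parse_columns_to_drop`); + # with `strict=False` missing names are silently ignored. +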
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: + to_drop = parse_columns_to_drop(self, columns, strict=strict) + return self._with_native(self.native.drop(to_drop), validate_column_names=False) + + def drop_nulls(self: ArrowDataFrame, subset: Sequence[str] | None) -> ArrowDataFrame: + if subset is None: + return self._with_native(self.native.drop_null(), validate_column_names=False) + plx = self.__narwhals_namespace__() + return self.filter(~plx.any_horizontal(plx.col(*subset).is_null())) + + def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self: + if isinstance(descending, bool): + order: Order = "descending" if descending else "ascending" + sorting: list[tuple[str, Order]] = [(key, order) for key in by] + else: + sorting = [ + (key, "descending" if is_descending else "ascending") + for key, is_descending in zip(by, descending) + ] + + null_placement = "at_end" if nulls_last else "at_start" + + return self._with_native( + self.native.sort_by(sorting, null_placement=null_placement), + validate_column_names=False, + ) + + def to_pandas(self) -> pd.DataFrame: + return self.native.to_pandas() + + def to_polars(self) -> pl.DataFrame: + import polars as pl # ignore-banned-import + + return pl.from_arrow(self.native) # type: ignore[return-value] + + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: + import numpy as np # ignore-banned-import + + arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns]) + return arr + + @overload + def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ... + + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + + def to_dict( + self, *, as_series: bool + ) -> dict[str, ArrowSeries] | dict[str, list[Any]]: + it = self.iter_columns() + if as_series: + return {ser.name: ser for ser in it} + return {ser.name: ser.to_list() for ser in it} + + def with_row_index(self, name: str) -> Self: + df = self.native + cols = self.columns + + row_indices = pa.array(range(df.num_rows)) + return self._with_native( + df.append_column(name, row_indices).select([name, *cols]) + ) + + def filter( + self: ArrowDataFrame, predicate: ArrowExpr | list[bool | None] + ) -> ArrowDataFrame: + if isinstance(predicate, list): + mask_native: Mask | ChunkedArrayAny = predicate + else: + # `[0]` is safe as the predicate's expression only returns a single column + mask_native = self._evaluate_into_exprs(predicate)[0].native + return self._with_native( + self.native.filter(mask_native), validate_column_names=False + ) + + def head(self, n: int) -> Self: + df = self.native + if n >= 0: + return self._with_native(df.slice(0, n), validate_column_names=False) + else: + num_rows = df.num_rows + return self._with_native( + df.slice(0, max(0, num_rows + n)), validate_column_names=False + ) + + def tail(self, n: int) -> Self: + df = self.native + if n >= 0: + num_rows = df.num_rows + return self._with_native( + df.slice(max(0, num_rows - n)), validate_column_names=False + ) + else: + return self._with_native(df.slice(abs(n)), validate_column_names=False) + + def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny: + if backend is None: + return self + elif backend is Implementation.DUCKDB: + import duckdb # ignore-banned-import + + from narwhals._duckdb.dataframe import DuckDBLazyFrame + + df = self.native # noqa: F841 + return DuckDBLazyFrame( + duckdb.table("df"), + backend_version=parse_version(duckdb), + 
version=self._version, + ) + elif backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsLazyFrame + + return PolarsLazyFrame( + cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(), + backend_version=parse_version(pl), + version=self._version, + ) + elif backend is Implementation.DASK: + import dask # ignore-banned-import + import dask.dataframe as dd # ignore-banned-import + + from narwhals._dask.dataframe import DaskLazyFrame + + return DaskLazyFrame( + dd.from_pandas(self.native.to_pandas()), + backend_version=parse_version(dask), + version=self._version, + ) + raise AssertionError # pragma: no cover + + def collect( + self, backend: Implementation | None, **kwargs: Any + ) -> CompliantDataFrameAny: + if backend is Implementation.PYARROW or backend is None: + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + self.native, + backend_version=self._backend_version, + version=self._version, + validate_column_names=False, + ) + + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + self.native.to_pandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd), + version=self._version, + validate_column_names=False, + ) + + if backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + cast("pl.DataFrame", pl.from_arrow(self.native)), + backend_version=parse_version(pl), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise AssertionError(msg) # pragma: no cover + + def clone(self) -> Self: + return self._with_native(self.native, validate_column_names=False) + + def item(self, row: int | None, column: int | str | None) -> Any: + from narwhals._arrow.series import maybe_extract_py_scalar + + if row is None and column is None: + if self.shape != (1, 1): + msg = ( + "can only call `.item()` if the dataframe is of shape (1, 1)," + " or if explicit row/col values are provided;" + f" frame has shape {self.shape!r}" + ) + raise ValueError(msg) + return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True) + + elif row is None or column is None: + msg = "cannot call `.item()` with only one of `row` or `column`" + raise ValueError(msg) + + _col = self.columns.index(column) if isinstance(column, str) else column + return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True) + + def rename(self, mapping: Mapping[str, str]) -> Self: + names: dict[str, str] | list[str] + if self._backend_version >= (17,): + names = cast("dict[str, str]", mapping) + else: # pragma: no cover + names = [mapping.get(c, c) for c in self.columns] + return self._with_native(self.native.rename_columns(names)) + + def write_parquet(self, file: str | Path | BytesIO) -> None: + import pyarrow.parquet as pp + + pp.write_table(self.native, file) + + @overload + def write_csv(self, file: None) -> str: ... + + @overload + def write_csv(self, file: str | Path | BytesIO) -> None: ... 
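+    # `file=None` writes to an in-memory pyarrow buffer and returns the CSV as a str; + # otherwise the table is written to `file` and None is returned.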
+ + def write_csv(self, file: str | Path | BytesIO | None) -> str | None: + import pyarrow.csv as pa_csv + + if file is None: + csv_buffer = pa.BufferOutputStream() + pa_csv.write_csv(self.native, csv_buffer) + return csv_buffer.getvalue().to_pybytes().decode() + pa_csv.write_csv(self.native, file) + return None + + def is_unique(self) -> ArrowSeries: + col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) + row_index = pa.array(range(len(self))) + keep_idx = ( + self.native.append_column(col_token, row_index) + .group_by(self.columns) + .aggregate([(col_token, "min"), (col_token, "max")]) + ) + native = pa.chunked_array( + pc.and_( + pc.is_in(row_index, keep_idx[f"{col_token}_min"]), + pc.is_in(row_index, keep_idx[f"{col_token}_max"]), + ) + ) + return ArrowSeries.from_native(native, context=self) + + def unique( + self: ArrowDataFrame, + subset: Sequence[str] | None, + *, + keep: UniqueKeepStrategy, + maintain_order: bool | None = None, + ) -> ArrowDataFrame: + # The param `maintain_order` is only here for compatibility with the Polars API + # and has no effect on the output. + import numpy as np # ignore-banned-import + + if subset and (error := self._check_columns_exist(subset)): + raise error + subset = list(subset or self.columns) + + if keep in {"any", "first", "last"}: + from narwhals._arrow.group_by import ArrowGroupBy + + agg_func = ArrowGroupBy._REMAP_UNIQUE[keep] + col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) + keep_idx_native = ( + self.native.append_column(col_token, pa.array(np.arange(len(self)))) + .group_by(subset) + .aggregate([(col_token, agg_func)]) + .column(f"{col_token}_{agg_func}") + ) + return self._with_native( + self.native.take(keep_idx_native), validate_column_names=False + ) + + keep_idx = self.simple_select(*subset).is_unique() + plx = self.__narwhals_namespace__() + return self.filter(plx._expr._from_series(keep_idx)) + + def gather_every(self, n: int, offset: int) -> Self: + return self._with_native(self.native[offset::n], validate_column_names=False) + + def to_arrow(self) -> pa.Table: + return self.native + + def sample( + self, + n: int | None, + *, + fraction: float | None, + with_replacement: bool, + seed: int | None, + ) -> Self: + import numpy as np # ignore-banned-import + + num_rows = len(self) + if n is None and fraction is not None: + n = int(num_rows * fraction) + rng = np.random.default_rng(seed=seed) + idx = np.arange(0, num_rows) + mask = rng.choice(idx, size=n, replace=with_replacement) + return self._with_native(self.native.take(mask), validate_column_names=False) + + def unpivot( + self, + on: Sequence[str] | None, + index: Sequence[str] | None, + variable_name: str, + value_name: str, + ) -> Self: + n_rows = len(self) + index_ = [] if index is None else index + on_ = [c for c in self.columns if c not in index_] if on is None else on + concat = ( + partial(pa.concat_tables, promote_options="permissive") + if self._backend_version >= (14, 0, 0) + else pa.concat_tables + ) + names = [*index_, variable_name, value_name] + return self._with_native( + concat( + [ + pa.Table.from_arrays( + [ + *(self.native.column(idx_col) for idx_col in index_), + cast( + "ChunkedArrayAny", + pa.array([on_col] * n_rows, pa.string()), + ), + self.native.column(on_col), + ], + names=names, + ) + for on_col in on_ + ] + ) + ) + # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not + # upcast numeric to non-numeric (e.g. 
string) datatypes + + pivot = not_implemented() diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py new file mode 100644 index 0000000..af7993c --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +import pyarrow.compute as pc + +from narwhals._arrow.series import ArrowSeries +from narwhals._compliant import EagerExpr +from narwhals._expression_parsing import evaluate_output_names_and_aliases +from narwhals._utils import ( + Implementation, + generate_temporary_column_name, + not_implemented, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._arrow.dataframe import ArrowDataFrame + from narwhals._arrow.namespace import ArrowNamespace + from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs + from narwhals._expression_parsing import ExprMetadata + from narwhals._utils import Version, _FullContext + from narwhals.typing import RankMethod + + +class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]): + _implementation: Implementation = Implementation.PYARROW + + def __init__( + self, + call: EvalSeries[ArrowDataFrame, ArrowSeries], + *, + depth: int, + function_name: str, + evaluate_output_names: EvalNames[ArrowDataFrame], + alias_output_names: AliasNames | None, + backend_version: tuple[int, ...], + version: Version, + scalar_kwargs: ScalarKwargs | None = None, + implementation: Implementation | None = None, + ) -> None: + self._call = call + self._depth = depth + self._function_name = function_name + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names + self._backend_version = backend_version + self._version = version + self._scalar_kwargs = scalar_kwargs or {} + self._metadata: ExprMetadata | None = None + + @classmethod + def from_column_names( + cls: type[Self], + evaluate_column_names: EvalNames[ArrowDataFrame], + /, + *, + context: _FullContext, + function_name: str = "", + ) -> Self: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + try: + return [ + ArrowSeries( + df.native[column_name], + name=column_name, + backend_version=df._backend_version, + version=df._version, + ) + for column_name in evaluate_column_names(df) + ] + except KeyError as e: + if error := df._check_columns_exist(evaluate_column_names(df)): + raise error from e + raise + + return cls( + func, + depth=0, + function_name=function_name, + evaluate_output_names=evaluate_column_names, + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + ) + + @classmethod + def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + tbl = df.native + cols = df.columns + return [ + ArrowSeries.from_native(tbl[i], name=cols[i], context=df) + for i in column_indices + ] + + return cls( + func, + depth=0, + function_name="nth", + evaluate_output_names=cls._eval_names_indices(column_indices), + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + ) + + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._arrow.namespace import ArrowNamespace + + return ArrowNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __narwhals_expr__(self) -> None: ...
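+    # The element-wise/cumulative methods below (cum_sum, shift, rank, log, exp, ...) + # delegate to the `ArrowSeries` method of the same name via `_reuse_series`.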
+ + def _reuse_series_extra_kwargs( + self, *, returns_scalar: bool = False + ) -> dict[str, Any]: + return {"_return_py_scalar": False} if returns_scalar else {} + + def cum_sum(self, *, reverse: bool) -> Self: + return self._reuse_series("cum_sum", reverse=reverse) + + def shift(self, n: int) -> Self: + return self._reuse_series("shift", n=n) + + def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self: + assert self._metadata is not None # noqa: S101 + if partition_by and not self._metadata.is_scalar_like: + msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow." + raise NotImplementedError(msg) + + if not partition_by: + # e.g. `nw.col('a').cum_sum().order_by(key)` + # which we can always easily support, as it doesn't require grouping. + assert order_by # noqa: S101 + + def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: + token = generate_temporary_column_name(8, df.columns) + df = df.with_row_index(token).sort( + *order_by, descending=False, nulls_last=False + ) + result = self(df.drop([token], strict=True)) + # TODO(marco): is there a way to do this efficiently without + # doing 2 sorts? Here we're sorting the dataframe and then + # again calling `sort_indices`. `ArrowSeries.scatter` would also sort. + sorting_indices = pc.sort_indices(df.get_column(token).native) + return [s._with_native(s.native.take(sorting_indices)) for s in result] + else: + + def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: + output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + if overlap := set(output_names).intersection(partition_by): + # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined, + # we just don't support it yet. + msg = ( + f"Column names {overlap} appear in both expression output names and in `over` keys.\n" + "This is not yet supported." 
+ ) + raise NotImplementedError(msg) + + tmp = df.group_by(partition_by, drop_null_keys=False).agg(self) + tmp = df.simple_select(*partition_by).join( + tmp, + how="left", + left_on=partition_by, + right_on=partition_by, + suffix="_right", + ) + return [tmp.get_column(alias) for alias in aliases] + + return self.__class__( + func, + depth=self._depth + 1, + function_name=self._function_name + "->over", + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) + + def cum_count(self, *, reverse: bool) -> Self: + return self._reuse_series("cum_count", reverse=reverse) + + def cum_min(self, *, reverse: bool) -> Self: + return self._reuse_series("cum_min", reverse=reverse) + + def cum_max(self, *, reverse: bool) -> Self: + return self._reuse_series("cum_max", reverse=reverse) + + def cum_prod(self, *, reverse: bool) -> Self: + return self._reuse_series("cum_prod", reverse=reverse) + + def rank(self, method: RankMethod, *, descending: bool) -> Self: + return self._reuse_series("rank", method=method, descending=descending) + + def log(self, base: float) -> Self: + return self._reuse_series("log", base=base) + + def exp(self) -> Self: + return self._reuse_series("exp") + + ewm_mean = not_implemented() diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/group_by.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/group_by.py new file mode 100644 index 0000000..d61906a --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/group_by.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import collections +from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Mapping, Sequence + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar +from narwhals._compliant import EagerGroupBy +from narwhals._expression_parsing import evaluate_output_names_and_aliases +from narwhals._utils import generate_temporary_column_name + +if TYPE_CHECKING: + from narwhals._arrow.dataframe import ArrowDataFrame + from narwhals._arrow.expr import ArrowExpr + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + AggregateOptions, + Aggregation, + Incomplete, + ) + from narwhals._compliant.group_by import NarwhalsAggregation + from narwhals.typing import UniqueKeepStrategy + + +class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]): + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = { + "sum": "sum", + "mean": "mean", + "median": "approximate_median", + "max": "max", + "min": "min", + "std": "stddev", + "var": "variance", + "len": "count", + "n_unique": "count_distinct", + "count": "count", + } + _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = { + "any": "min", + "first": "min", + "last": "max", + } + + def __init__( + self, + df: ArrowDataFrame, + keys: Sequence[ArrowExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, + ) -> None: + self._df = df + frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys) + self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame + self._grouped = pa.TableGroupBy(self.compliant.native, self._keys) + self._drop_null_keys = drop_null_keys + + def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: + self._ensure_all_simple(exprs) + aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = [] + expected_pyarrow_column_names: list[str] = 
self._keys.copy() + new_column_names: list[str] = self._keys.copy() + exclude = (*self._keys, *self._output_key_names) + + for expr in exprs: + output_names, aliases = evaluate_output_names_and_aliases( + expr, self.compliant, exclude + ) + + if expr._depth == 0: + # e.g. `agg(nw.len())` + if expr._function_name != "len": # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + new_column_names.append(aliases[0]) + expected_pyarrow_column_names.append(f"{self._keys[0]}_count") + aggs.append((self._keys[0], "count", pc.CountOptions(mode="all"))) + continue + + function_name = self._leaf_name(expr) + if function_name in {"std", "var"}: + assert "ddof" in expr._scalar_kwargs # noqa: S101 + option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"]) + elif function_name in {"len", "n_unique"}: + option = pc.CountOptions(mode="all") + elif function_name == "count": + option = pc.CountOptions(mode="only_valid") + else: + option = None + + function_name = self._remap_expr_name(function_name) + new_column_names.extend(aliases) + expected_pyarrow_column_names.extend( + [f"{output_name}_{function_name}" for output_name in output_names] + ) + aggs.extend( + [(output_name, function_name, option) for output_name in output_names] + ) + + result_simple = self._grouped.aggregate(aggs) + + # Rename columns, being very careful + expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list) + for idx, item in enumerate(expected_pyarrow_column_names): + expected_old_names_indices[item].append(idx) + if not ( + set(result_simple.column_names) == set(expected_pyarrow_column_names) + and len(result_simple.column_names) == len(expected_pyarrow_column_names) + ): # pragma: no cover + msg = ( + f"Safety assertion failed, expected {expected_pyarrow_column_names} " + f"got {result_simple.column_names}, " + "please report a bug at https://github.com/narwhals-dev/narwhals/issues" + ) + raise AssertionError(msg) + index_map: list[int] = [ + expected_old_names_indices[item].pop(0) for item in result_simple.column_names + ] + new_column_names = [new_column_names[i] for i in index_map] + result_simple = result_simple.rename_columns(new_column_names) + if self.compliant._backend_version < (12, 0, 0): + columns = result_simple.column_names + result_simple = result_simple.select( + [*self._keys, *[col for col in columns if col not in self._keys]] + ) + + return self.compliant._with_native(result_simple).rename( + dict(zip(self._keys, self._output_key_names)) + ) + + def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: + col_token = generate_temporary_column_name( + n_bytes=8, columns=self.compliant.columns + ) + null_token: str = "__null_token_value__" # noqa: S105 + + table = self.compliant.native + it, separator_scalar = cast_to_comparable_string_types( + *(table[key] for key in self._keys), separator="" + ) + # NOTE: stubs indicate `separator` must also be a `ChunkedArray` + # Reality: `str` is fine + concat_str: Incomplete = pc.binary_join_element_wise + key_values = concat_str( + *it, separator_scalar, null_handling="replace", null_replacement=null_token + ) + table = table.add_column(i=0, field_=col_token, column=key_values) + + for v in pc.unique(key_values): + t = self.compliant._with_native( + table.filter(pc.equal(table[col_token], v)).drop([col_token]) + ) + row = t.simple_select(*self._keys).row(0) + yield ( + tuple(extract_py_scalar(el) for el in row), + 
t.simple_select(*self._df.columns), + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py new file mode 100644 index 0000000..02d4c69 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Literal, Sequence + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.dataframe import ArrowDataFrame +from narwhals._arrow.expr import ArrowExpr +from narwhals._arrow.selectors import ArrowSelectorNamespace +from narwhals._arrow.series import ArrowSeries +from narwhals._arrow.utils import ( + align_series_full_broadcast, + cast_to_comparable_string_types, +) +from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen +from narwhals._expression_parsing import ( + combine_alias_output_names, + combine_evaluate_output_names, +) +from narwhals._utils import Implementation + +if TYPE_CHECKING: + from narwhals._arrow.typing import Incomplete + from narwhals._utils import Version + from narwhals.typing import IntoDType, NonNestedLiteral + + +class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table]): + @property + def _dataframe(self) -> type[ArrowDataFrame]: + return ArrowDataFrame + + @property + def _expr(self) -> type[ArrowExpr]: + return ArrowExpr + + @property + def _series(self) -> type[ArrowSeries]: + return ArrowSeries + + # --- not in spec --- + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._implementation = Implementation.PYARROW + self._version = version + + def len(self) -> ArrowExpr: + # coverage bug? 
this is definitely hit + return self._expr( # pragma: no cover + lambda df: [ + ArrowSeries.from_iterable([len(df.native)], name="len", context=self) + ], + depth=0, + function_name="len", + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + + def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr: + def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: + arrow_series = ArrowSeries.from_iterable( + data=[value], name="literal", context=self + ) + if dtype: + return arrow_series.cast(dtype) + return arrow_series + + return self._expr( + lambda df: [_lit_arrow_series(df)], + depth=0, + function_name="lit", + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + + def all_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + series = chain.from_iterable(expr(df) for expr in exprs) + return [reduce(operator.and_, align_series_full_broadcast(*series))] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="all_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def any_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + series = chain.from_iterable(expr(df) for expr in exprs) + return [reduce(operator.or_, align_series_full_broadcast(*series))] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="any_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + it = chain.from_iterable(expr(df) for expr in exprs) + series = (s.fill_null(0, strategy=None, limit=None) for s in it) + return [reduce(operator.add, align_series_full_broadcast(*series))] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="sum_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + int_64 = self._version.dtypes.Int64() + + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + expr_results = list(chain.from_iterable(expr(df) for expr in exprs)) + series = align_series_full_broadcast( + *(s.fill_null(0, strategy=None, limit=None) for s in expr_results) + ) + non_na = align_series_full_broadcast( + *(1 - s.is_null().cast(int_64) for s in expr_results) + ) + return [reduce(operator.add, series) / reduce(operator.add, non_na)] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="mean_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs)) + init_series, *series = align_series_full_broadcast(init_series, *series) + native_series = reduce( + 
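# fold the element-wise minimum across the aligned series, + # seeded with the first series as the initial value +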
pc.min_element_wise, [s.native for s in series], init_series.native + ) + return [ + ArrowSeries( + native_series, + name=init_series.name, + backend_version=self._backend_version, + version=self._version, + ) + ] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="min_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs)) + init_series, *series = align_series_full_broadcast(init_series, *series) + native_series = reduce( + pc.max_element_wise, [s.native for s in series], init_series.native + ) + return [ + ArrowSeries( + native_series, + name=init_series.name, + backend_version=self._backend_version, + version=self._version, + ) + ] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="max_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: + if self._backend_version >= (14,): + return pa.concat_tables(dfs, promote_options="default") + return pa.concat_tables(dfs, promote=True) # pragma: no cover + + def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: + names = list(chain.from_iterable(df.column_names for df in dfs)) + arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) + return pa.Table.from_arrays(arrays, names=names) + + def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: + cols_0 = dfs[0].column_names + for i, df in enumerate(dfs[1:], start=1): + cols_current = df.column_names + if cols_current != cols_0: + msg = ( + "unable to vstack, column names don't match:\n" + f" - dataframe 0: {cols_0}\n" + f" - dataframe {i}: {cols_current}\n" + ) + raise TypeError(msg) + return pa.concat_tables(dfs) + + @property + def selectors(self) -> ArrowSelectorNamespace: + return ArrowSelectorNamespace.from_namespace(self) + + def when(self, predicate: ArrowExpr) -> ArrowWhen: + return ArrowWhen.from_expr(predicate, context=self) + + def concat_str( + self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool + ) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + compliant_series_list = align_series_full_broadcast( + *(chain.from_iterable(expr(df) for expr in exprs)) + ) + name = compliant_series_list[0].name + null_handling: Literal["skip", "emit_null"] = ( + "skip" if ignore_nulls else "emit_null" + ) + it, separator_scalar = cast_to_comparable_string_types( + *(s.native for s in compliant_series_list), separator=separator + ) + # NOTE: stubs indicate `separator` must also be a `ChunkedArray` + # Reality: `str` is fine + concat_str: Incomplete = pc.binary_join_element_wise + compliant = self._series( + concat_str(*it, separator_scalar, null_handling=null_handling), + name=name, + backend_version=self._backend_version, + version=self._version, + ) + return [compliant] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="concat_str", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + +class 
ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]): + @property + def _then(self) -> type[ArrowThen]: + return ArrowThen + + def _if_then_else( + self, when: ArrowSeries, then: ArrowSeries, otherwise: ArrowSeries | None, / + ) -> ArrowSeries: + if otherwise is None: + when, then = align_series_full_broadcast(when, then) + res_native = pc.if_else( + when.native, then.native, pa.nulls(len(when.native), then.native.type) + ) + else: + when, then, otherwise = align_series_full_broadcast(when, then, otherwise) + res_native = pc.if_else(when.native, then.native, otherwise.native) + return then._with_native(res_native) + + +class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ... diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/selectors.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/selectors.py new file mode 100644 index 0000000..d72da05 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/selectors.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._arrow.expr import ArrowExpr +from narwhals._compliant import CompliantSelector, EagerSelectorNamespace + +if TYPE_CHECKING: + from narwhals._arrow.dataframe import ArrowDataFrame # noqa: F401 + from narwhals._arrow.series import ArrowSeries # noqa: F401 + + +class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]): + @property + def _selector(self) -> type[ArrowSelector]: + return ArrowSelector + + +class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc] + def _to_expr(self) -> ArrowExpr: + return ArrowExpr( + self._call, + depth=self._depth, + function_name=self._function_name, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/series.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/series.py new file mode 100644 index 0000000..0259620 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/series.py @@ -0,0 +1,1183 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Iterator, + Mapping, + Sequence, + cast, + overload, +) + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.series_cat import ArrowSeriesCatNamespace +from narwhals._arrow.series_dt import ArrowSeriesDateTimeNamespace +from narwhals._arrow.series_list import ArrowSeriesListNamespace +from narwhals._arrow.series_str import ArrowSeriesStringNamespace +from narwhals._arrow.series_struct import ArrowSeriesStructNamespace +from narwhals._arrow.utils import ( + cast_for_truediv, + chunked_array, + extract_native, + floordiv_compat, + lit, + narwhals_to_native_dtype, + native_to_narwhals_dtype, + nulls_like, + pad_series, +) +from narwhals._compliant import EagerSeries +from narwhals._expression_parsing import ExprKind +from narwhals._utils import ( + Implementation, + generate_temporary_column_name, + is_list_of, + not_implemented, + requires, + validate_backend_version, +) +from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import InvalidOperationError + +if TYPE_CHECKING: + from types import ModuleType + + import pandas as pd + import polars as pl + from typing_extensions import Self, TypeIs + + from narwhals._arrow.dataframe import ArrowDataFrame + from narwhals._arrow.namespace 
import ArrowNamespace + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + ArrayAny, + ArrayOrChunkedArray, + ArrayOrScalar, + ChunkedArrayAny, + Incomplete, + NullPlacement, + Order, + TieBreaker, + _AsPyType, + _BasicDataType, + ) + from narwhals._utils import Version, _FullContext + from narwhals.dtypes import DType + from narwhals.typing import ( + ClosedInterval, + FillNullStrategy, + Into1DArray, + IntoDType, + NonNestedLiteral, + NumericLiteral, + PythonLiteral, + RankMethod, + RollingInterpolationMethod, + SizedMultiIndexSelector, + TemporalLiteral, + _1DArray, + _2DArray, + _SliceIndex, + ) + + +# TODO @dangotbanned: move into `_arrow.utils` +# Lots of modules are importing inline +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[_BasicDataType[_AsPyType]], + return_py_scalar: bool, # noqa: FBT001 +) -> _AsPyType: ... + + +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[pa.StructType], + return_py_scalar: bool, # noqa: FBT001 +) -> list[dict[str, Any]]: ... + + +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[pa.ListType[_BasicDataType[_AsPyType]]], + return_py_scalar: bool, # noqa: FBT001 +) -> list[_AsPyType]: ... + + +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[Any] | Any, + return_py_scalar: bool, # noqa: FBT001 +) -> Any: ... + + +def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001 + if TYPE_CHECKING: + return value.as_py() + if return_py_scalar: + return getattr(value, "as_py", lambda: value)() + return value + + +class ArrowSeries(EagerSeries["ChunkedArrayAny"]): + def __init__( + self, + native_series: ChunkedArrayAny, + *, + name: str, + backend_version: tuple[int, ...], + version: Version, + ) -> None: + self._name = name + self._native_series: ChunkedArrayAny = native_series + self._implementation = Implementation.PYARROW + self._backend_version = backend_version + self._version = version + validate_backend_version(self._implementation, self._backend_version) + self._broadcast = False + + @property + def native(self) -> ChunkedArrayAny: + return self._native_series + + def _with_version(self, version: Version) -> Self: + return self.__class__( + self.native, + name=self._name, + backend_version=self._backend_version, + version=version, + ) + + def _with_native( + self, series: ArrayOrScalar, *, preserve_broadcast: bool = False + ) -> Self: + result = self.from_native(chunked_array(series), name=self.name, context=self) + if preserve_broadcast: + result._broadcast = self._broadcast + return result + + @classmethod + def from_iterable( + cls, + data: Iterable[Any], + *, + context: _FullContext, + name: str = "", + dtype: IntoDType | None = None, + ) -> Self: + version = context._version + dtype_pa = narwhals_to_native_dtype(dtype, version) if dtype else None + return cls.from_native( + chunked_array([data], dtype_pa), name=name, context=context + ) + + def _from_scalar(self, value: Any) -> Self: + if self._backend_version < (13,) and hasattr(value, "as_py"): + value = value.as_py() + return super()._from_scalar(value) + + @staticmethod + def _is_native(obj: ChunkedArrayAny | Any) -> TypeIs[ChunkedArrayAny]: + return isinstance(obj, pa.ChunkedArray) + + @classmethod + def from_native( + cls, data: ChunkedArrayAny, /, *, context: _FullContext, name: str = "" + ) -> Self: + return cls( + data, + backend_version=context._backend_version, + version=context._version, + name=name, + ) + + @classmethod + def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self: + 
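# wrap anything that is not a 1-D numpy array in a list, so `from_iterable` + # always receives an iterable of values +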
return cls.from_iterable( + data if is_numpy_array_1d(data) else [data], context=context + ) + + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._arrow.namespace import ArrowNamespace + + return ArrowNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __eq__(self, other: object) -> Self: # type: ignore[override] + other = cast("PythonLiteral | ArrowSeries | None", other) + ser, rhs = extract_native(self, other) + return self._with_native(pc.equal(ser, rhs)) + + def __ne__(self, other: object) -> Self: # type: ignore[override] + other = cast("PythonLiteral | ArrowSeries | None", other) + ser, rhs = extract_native(self, other) + return self._with_native(pc.not_equal(ser, rhs)) + + def __ge__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.greater_equal(ser, other)) + + def __gt__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.greater(ser, other)) + + def __le__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.less_equal(ser, other)) + + def __lt__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.less(ser, other)) + + def __and__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.and_kleene(ser, other)) # type: ignore[arg-type] + + def __rand__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.and_kleene(other, ser)) # type: ignore[arg-type] + + def __or__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.or_kleene(ser, other)) # type: ignore[arg-type] + + def __ror__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.or_kleene(other, ser)) # type: ignore[arg-type] + + def __add__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.add(ser, other)) + + def __radd__(self, other: Any) -> Self: + return self + other + + def __sub__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.subtract(ser, other)) + + def __rsub__(self, other: Any) -> Self: + return (self - other) * (-1) + + def __mul__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.multiply(ser, other)) + + def __rmul__(self, other: Any) -> Self: + return self * other + + def __pow__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.power(ser, other)) + + def __rpow__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.power(other, ser)) + + def __floordiv__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(floordiv_compat(ser, other)) + + def __rfloordiv__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(floordiv_compat(other, ser)) + + def __truediv__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.divide(*cast_for_truediv(ser, other))) # type: ignore[type-var] + + def __rtruediv__(self, other: Any) -> Self: + ser, other = extract_native(self, other) + return self._with_native(pc.divide(*cast_for_truediv(other, ser))) # type: ignore[type-var] + + def __mod__(self, other: Any) -> Self: + floor_div = 
(self // other).native + ser, other = extract_native(self, other) + res = pc.subtract(ser, pc.multiply(floor_div, other)) + return self._with_native(res) + + def __rmod__(self, other: Any) -> Self: + floor_div = (other // self).native + ser, other = extract_native(self, other) + res = pc.subtract(other, pc.multiply(floor_div, ser)) + return self._with_native(res) + + def __invert__(self) -> Self: + return self._with_native(pc.invert(self.native)) + + @property + def _type(self) -> pa.DataType: + return self.native.type + + def len(self, *, _return_py_scalar: bool = True) -> int: + return maybe_extract_py_scalar(len(self.native), _return_py_scalar) + + def filter(self, predicate: ArrowSeries | list[bool | None]) -> Self: + other_native: Any + if not is_list_of(predicate, bool): + _, other_native = extract_native(self, predicate) + else: + other_native = predicate + return self._with_native(self.native.filter(other_native)) + + def mean(self, *, _return_py_scalar: bool = True) -> float: + return maybe_extract_py_scalar(pc.mean(self.native), _return_py_scalar) + + def median(self, *, _return_py_scalar: bool = True) -> float: + from narwhals.exceptions import InvalidOperationError + + if not self.dtype.is_numeric(): + msg = "`median` operation not supported for non-numeric input type." + raise InvalidOperationError(msg) + + return maybe_extract_py_scalar( + pc.approximate_median(self.native), _return_py_scalar + ) + + def min(self, *, _return_py_scalar: bool = True) -> Any: + return maybe_extract_py_scalar(pc.min(self.native), _return_py_scalar) + + def max(self, *, _return_py_scalar: bool = True) -> Any: + return maybe_extract_py_scalar(pc.max(self.native), _return_py_scalar) + + def arg_min(self, *, _return_py_scalar: bool = True) -> int: + index_min = pc.index(self.native, pc.min(self.native)) + return maybe_extract_py_scalar(index_min, _return_py_scalar) + + def arg_max(self, *, _return_py_scalar: bool = True) -> int: + index_max = pc.index(self.native, pc.max(self.native)) + return maybe_extract_py_scalar(index_max, _return_py_scalar) + + def sum(self, *, _return_py_scalar: bool = True) -> float: + return maybe_extract_py_scalar( + pc.sum(self.native, min_count=0), _return_py_scalar + ) + + def drop_nulls(self) -> Self: + return self._with_native(self.native.drop_null()) + + def shift(self, n: int) -> Self: + if n > 0: + arrays = [nulls_like(n, self), *self.native[:-n].chunks] + elif n < 0: + arrays = [*self.native[-n:].chunks, nulls_like(-n, self)] + else: + return self._with_native(self.native) + return self._with_native(pa.concat_arrays(arrays)) + + def std(self, ddof: int, *, _return_py_scalar: bool = True) -> float: + return maybe_extract_py_scalar( + pc.stddev(self.native, ddof=ddof), _return_py_scalar + ) + + def var(self, ddof: int, *, _return_py_scalar: bool = True) -> float: + return maybe_extract_py_scalar( + pc.variance(self.native, ddof=ddof), _return_py_scalar + ) + + def skew(self, *, _return_py_scalar: bool = True) -> float | None: + ser_not_null = self.native.drop_null() + if len(ser_not_null) == 0: + return None + elif len(ser_not_null) == 1: + return float("nan") + elif len(ser_not_null) == 2: + return 0.0 + else: + m = pc.subtract(ser_not_null, pc.mean(ser_not_null)) + m2 = pc.mean(pc.power(m, lit(2))) + m3 = pc.mean(pc.power(m, lit(3))) + biased_population_skewness = pc.divide(m3, pc.power(m2, lit(1.5))) + return maybe_extract_py_scalar(biased_population_skewness, _return_py_scalar) + + def count(self, *, _return_py_scalar: bool = True) -> int: + return 
maybe_extract_py_scalar(pc.count(self.native), _return_py_scalar) + + def n_unique(self, *, _return_py_scalar: bool = True) -> int: + return maybe_extract_py_scalar( + pc.count(self.native.unique(), mode="all"), _return_py_scalar + ) + + def __native_namespace__(self) -> ModuleType: + if self._implementation is Implementation.PYARROW: + return self._implementation.to_native_namespace() + + msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover + raise AssertionError(msg) + + @property + def name(self) -> str: + return self._name + + def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self: + if len(rows) == 0: + return self._with_native(self.native.slice(0, 0)) + if self._backend_version < (18,) and isinstance(rows, tuple): + rows = list(rows) + return self._with_native(self.native.take(rows)) + + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + start = rows.start or 0 + stop = rows.stop if rows.stop is not None else len(self.native) + if start < 0: + start = len(self.native) + start + if stop < 0: + stop = len(self.native) + stop + if rows.step is not None and rows.step != 1: + msg = "Slicing with step is not supported on PyArrow tables" + raise NotImplementedError(msg) + return self._with_native(self.native.slice(start, stop - start)) + + def scatter(self, indices: int | Sequence[int], values: Any) -> Self: + import numpy as np # ignore-banned-import + + values_native: ArrayAny + if isinstance(indices, int): + indices_native = pa.array([indices]) + values_native = pa.array([values]) + else: + # TODO(unassigned): we may also want to let `indices` be a Series. + # https://github.com/narwhals-dev/narwhals/issues/2155 + indices_native = pa.array(indices) + if isinstance(values, self.__class__): + values_native = values.native.combine_chunks() + else: + # NOTE: Requires fixes in https://github.com/zen-xu/pyarrow-stubs/pull/209 + pa_array: Incomplete = pa.array + values_native = pa_array(values) + + sorting_indices = pc.sort_indices(indices_native) + indices_native = indices_native.take(sorting_indices) + values_native = values_native.take(sorting_indices) + + mask: _1DArray = np.zeros(self.len(), dtype=bool) + mask[indices_native] = True + # NOTE: Multiple issues + # - Missing `values` type + # - `mask` accepts a `np.ndarray`, but not mentioned in stubs + # - Missing `replacements` type + # - Missing return type + pc_replace_with_mask: Incomplete = pc.replace_with_mask + return self._with_native( + pc_replace_with_mask(self.native, mask, values_native.take(indices_native)) + ) + + def to_list(self) -> list[Any]: + return self.native.to_pylist() + + def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: + return self.native.__array__(dtype=dtype, copy=copy) + + def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: + return self.native.to_numpy() + + def alias(self, name: str) -> Self: + result = self.__class__( + self.native, + name=name, + backend_version=self._backend_version, + version=self._version, + ) + result._broadcast = self._broadcast + return result + + @property + def dtype(self) -> DType: + return native_to_narwhals_dtype(self.native.type, self._version) + + def abs(self) -> Self: + return self._with_native(pc.abs(self.native)) + + def cum_sum(self, *, reverse: bool) -> Self: + cum_sum = pc.cumulative_sum + result = ( + cum_sum(self.native, skip_nulls=True) + if not reverse + else cum_sum(self.native[::-1], skip_nulls=True)[::-1] + ) + return self._with_native(result) + + def 
round(self, decimals: int) -> Self: + return self._with_native( + pc.round(self.native, decimals, round_mode="half_towards_infinity") + ) + + def diff(self) -> Self: + return self._with_native(pc.pairwise_diff(self.native.combine_chunks())) + + def any(self, *, _return_py_scalar: bool = True) -> bool: + return maybe_extract_py_scalar( + pc.any(self.native, min_count=0), _return_py_scalar + ) + + def all(self, *, _return_py_scalar: bool = True) -> bool: + return maybe_extract_py_scalar( + pc.all(self.native, min_count=0), _return_py_scalar + ) + + def is_between( + self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval + ) -> Self: + _, lower_bound = extract_native(self, lower_bound) + _, upper_bound = extract_native(self, upper_bound) + if closed == "left": + ge = pc.greater_equal(self.native, lower_bound) + lt = pc.less(self.native, upper_bound) + res = pc.and_kleene(ge, lt) + elif closed == "right": + gt = pc.greater(self.native, lower_bound) + le = pc.less_equal(self.native, upper_bound) + res = pc.and_kleene(gt, le) + elif closed == "none": + gt = pc.greater(self.native, lower_bound) + lt = pc.less(self.native, upper_bound) + res = pc.and_kleene(gt, lt) + elif closed == "both": + ge = pc.greater_equal(self.native, lower_bound) + le = pc.less_equal(self.native, upper_bound) + res = pc.and_kleene(ge, le) + else: # pragma: no cover + raise AssertionError + return self._with_native(res) + + def is_null(self) -> Self: + return self._with_native(self.native.is_null(), preserve_broadcast=True) + + def is_nan(self) -> Self: + return self._with_native(pc.is_nan(self.native), preserve_broadcast=True) + + def cast(self, dtype: IntoDType) -> Self: + data_type = narwhals_to_native_dtype(dtype, self._version) + return self._with_native(pc.cast(self.native, data_type), preserve_broadcast=True) + + def null_count(self, *, _return_py_scalar: bool = True) -> int: + return maybe_extract_py_scalar(self.native.null_count, _return_py_scalar) + + def head(self, n: int) -> Self: + if n >= 0: + return self._with_native(self.native.slice(0, n)) + else: + num_rows = len(self) + return self._with_native(self.native.slice(0, max(0, num_rows + n))) + + def tail(self, n: int) -> Self: + if n >= 0: + num_rows = len(self) + return self._with_native(self.native.slice(max(0, num_rows - n))) + else: + return self._with_native(self.native.slice(abs(n))) + + def is_in(self, other: Any) -> Self: + if self._is_native(other): + value_set: ArrayOrChunkedArray = other + else: + value_set = pa.array(other) + return self._with_native(pc.is_in(self.native, value_set=value_set)) + + def arg_true(self) -> Self: + import numpy as np # ignore-banned-import + + res = np.flatnonzero(self.native) + return self.from_iterable(res, name=self.name, context=self) + + def item(self, index: int | None = None) -> Any: + if index is None: + if len(self) != 1: + msg = ( + "can only call '.item()' if the Series is of length 1," + f" or an explicit index is provided (Series is of length {len(self)})" + ) + raise ValueError(msg) + return maybe_extract_py_scalar(self.native[0], return_py_scalar=True) + return maybe_extract_py_scalar(self.native[index], return_py_scalar=True) + + def value_counts( + self, *, sort: bool, parallel: bool, name: str | None, normalize: bool + ) -> ArrowDataFrame: + """Parallel is unused, exists for compatibility.""" + from narwhals._arrow.dataframe import ArrowDataFrame + + index_name_ = "index" if self._name is None else self._name + value_name_ = name or ("proportion" if normalize else "count") + + val_counts = 
pc.value_counts(self.native) + values = val_counts.field("values") + counts = cast("ChunkedArrayAny", val_counts.field("counts")) + + if normalize: + arrays = [values, pc.divide(*cast_for_truediv(counts, pc.sum(counts)))] + else: + arrays = [values, counts] + + val_count = pa.Table.from_arrays(arrays, names=[index_name_, value_name_]) + + if sort: + val_count = val_count.sort_by([(value_name_, "descending")]) + + return ArrowDataFrame( + val_count, + backend_version=self._backend_version, + version=self._version, + validate_column_names=True, + ) + + def zip_with(self, mask: Self, other: Self) -> Self: + cond = mask.native.combine_chunks() + return self._with_native(pc.if_else(cond, self.native, other.native)) + + def sample( + self, + n: int | None, + *, + fraction: float | None, + with_replacement: bool, + seed: int | None, + ) -> Self: + import numpy as np # ignore-banned-import + + num_rows = len(self) + if n is None and fraction is not None: + n = int(num_rows * fraction) + + rng = np.random.default_rng(seed=seed) + idx = np.arange(0, num_rows) + mask = rng.choice(idx, size=n, replace=with_replacement) + return self._with_native(self.native.take(mask)) + + def fill_null( + self, + value: Self | NonNestedLiteral, + strategy: FillNullStrategy | None, + limit: int | None, + ) -> Self: + import numpy as np # ignore-banned-import + + def fill_aux( + arr: ChunkedArrayAny, limit: int, direction: FillNullStrategy | None + ) -> ArrayAny: + # this algorithm first finds the indices of the valid values to fill all the null value positions + # then it calculates the distance of each new index and the original index + # if the distance is equal to or less than the limit and the original value is null, it is replaced + valid_mask = pc.is_valid(arr) + indices = pa.array(np.arange(len(arr)), type=pa.int64()) + if direction == "forward": + valid_index = np.maximum.accumulate(np.where(valid_mask, indices, -1)) + distance = indices - valid_index + else: + valid_index = np.minimum.accumulate( + np.where(valid_mask[::-1], indices[::-1], len(arr)) + )[::-1] + distance = valid_index - indices + return pc.if_else( + pc.and_(pc.is_null(arr), pc.less_equal(distance, lit(limit))), # pyright: ignore[reportArgumentType, reportCallIssue] + arr.take(valid_index), + arr, + ) + + if value is not None: + _, native_value = extract_native(self, value) + series: ArrayOrScalar = pc.fill_null(self.native, native_value) + elif limit is None: + fill_func = ( + pc.fill_null_forward if strategy == "forward" else pc.fill_null_backward + ) + series = fill_func(self.native) + else: + series = fill_aux(self.native, limit, strategy) + return self._with_native(series, preserve_broadcast=True) + + def to_frame(self) -> ArrowDataFrame: + from narwhals._arrow.dataframe import ArrowDataFrame + + df = pa.Table.from_arrays([self.native], names=[self.name]) + return ArrowDataFrame( + df, + backend_version=self._backend_version, + version=self._version, + validate_column_names=False, + ) + + def to_pandas(self) -> pd.Series[Any]: + import pandas as pd # ignore-banned-import() + + return pd.Series(self.native, name=self.name) + + def to_polars(self) -> pl.Series: + import polars as pl # ignore-banned-import + + return cast("pl.Series", pl.from_arrow(self.native)) + + def is_unique(self) -> ArrowSeries: + return self.to_frame().is_unique().alias(self.name) + + def is_first_distinct(self) -> Self: + import numpy as np # ignore-banned-import + + row_number = pa.array(np.arange(len(self))) + col_token = generate_temporary_column_name(n_bytes=8, 
columns=[self.name])
+        first_distinct_index = (
+            pa.Table.from_arrays([self.native], names=[self.name])
+            .append_column(col_token, row_number)
+            .group_by(self.name)
+            .aggregate([(col_token, "min")])
+            .column(f"{col_token}_min")
+        )
+
+        return self._with_native(pc.is_in(row_number, first_distinct_index))
+
+    def is_last_distinct(self) -> Self:
+        import numpy as np  # ignore-banned-import
+
+        row_number = pa.array(np.arange(len(self)))
+        col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
+        last_distinct_index = (
+            pa.Table.from_arrays([self.native], names=[self.name])
+            .append_column(col_token, row_number)
+            .group_by(self.name)
+            .aggregate([(col_token, "max")])
+            .column(f"{col_token}_max")
+        )
+
+        return self._with_native(pc.is_in(row_number, last_distinct_index))
+
+    def is_sorted(self, *, descending: bool) -> bool:
+        if not isinstance(descending, bool):
+            msg = f"argument 'descending' should be boolean, found {type(descending)}"
+            raise TypeError(msg)
+        if descending:
+            result = pc.all(pc.greater_equal(self.native[:-1], self.native[1:]))
+        else:
+            result = pc.all(pc.less_equal(self.native[:-1], self.native[1:]))
+        return maybe_extract_py_scalar(result, return_py_scalar=True)
+
+    def unique(self, *, maintain_order: bool) -> Self:
+        # TODO(marco): `pc.unique` seems to always maintain order, is that guaranteed?
+        return self._with_native(self.native.unique())
+
+    def replace_strict(
+        self,
+        old: Sequence[Any] | Mapping[Any, Any],
+        new: Sequence[Any],
+        *,
+        return_dtype: IntoDType | None,
+    ) -> Self:
+        # https://stackoverflow.com/a/79111029/4451315
+        idxs = pc.index_in(self.native, pa.array(old))
+        result_native = pc.take(pa.array(new), idxs)
+        if return_dtype is not None:
+            # `cast` returns a new array, so the result must be re-assigned
+            # for `return_dtype` to take effect.
+            result_native = result_native.cast(
+                narwhals_to_native_dtype(return_dtype, self._version)
+            )
+        result = self._with_native(result_native)
+        if result.is_null().sum() != self.is_null().sum():
+            msg = (
+                "replace_strict did not replace all non-null values.\n\n"
+                "The following did not get replaced: "
+                f"{self.filter(~self.is_null() & result.is_null()).unique(maintain_order=False).to_list()}"
+            )
+            raise ValueError(msg)
+        return result
+
+    def sort(self, *, descending: bool, nulls_last: bool) -> Self:
+        order: Order = "descending" if descending else "ascending"
+        null_placement: NullPlacement = "at_end" if nulls_last else "at_start"
+        sorted_indices = pc.array_sort_indices(
+            self.native, order=order, null_placement=null_placement
+        )
+        return self._with_native(self.native.take(sorted_indices))
+
+    def to_dummies(self, *, separator: str, drop_first: bool) -> ArrowDataFrame:
+        import numpy as np  # ignore-banned-import
+
+        from narwhals._arrow.dataframe import ArrowDataFrame
+
+        name = self._name
+        # NOTE: stub is missing attributes (https://arrow.apache.org/docs/python/generated/pyarrow.DictionaryArray.html)
+        da: Incomplete = self.native.combine_chunks().dictionary_encode("encode")
+
+        columns: _2DArray = np.zeros((len(da.dictionary), len(da)), np.int8)
+        columns[da.indices, np.arange(len(da))] = 1
+        null_col_pa, null_col_pl = f"{name}{separator}None", f"{name}{separator}null"
+        cols = [
+            {null_col_pa: null_col_pl}.get(
+                f"{name}{separator}{v}", f"{name}{separator}{v}"
+            )
+            for v in da.dictionary
+        ]
+
+        output_order = (
+            [
+                null_col_pl,
+                *sorted([c for c in cols if c != null_col_pl])[int(drop_first) :],
+            ]
+            if null_col_pl in cols
+            else sorted(cols)[int(drop_first) :]
+        )
+        return ArrowDataFrame(
+            pa.Table.from_arrays(columns, names=cols),
+            backend_version=self._backend_version,
version=self._version, + validate_column_names=True, + ).simple_select(*output_order) + + def quantile( + self, + quantile: float, + interpolation: RollingInterpolationMethod, + *, + _return_py_scalar: bool = True, + ) -> float: + return maybe_extract_py_scalar( + pc.quantile(self.native, q=quantile, interpolation=interpolation)[0], + _return_py_scalar, + ) + + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._with_native(self.native[offset::n]) + + def clip( + self, + lower_bound: Self | NumericLiteral | TemporalLiteral | None, + upper_bound: Self | NumericLiteral | TemporalLiteral | None, + ) -> Self: + _, lower = extract_native(self, lower_bound) if lower_bound else (None, None) + _, upper = extract_native(self, upper_bound) if upper_bound else (None, None) + + if lower is None: + return self._with_native(pc.min_element_wise(self.native, upper)) + if upper is None: + return self._with_native(pc.max_element_wise(self.native, lower)) + return self._with_native( + pc.max_element_wise(pc.min_element_wise(self.native, upper), lower) + ) + + def to_arrow(self) -> ArrayAny: + return self.native.combine_chunks() + + def mode(self) -> ArrowSeries: + plx = self.__narwhals_namespace__() + col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) + counts = self.value_counts( + name=col_token, normalize=False, sort=False, parallel=False + ) + return counts.filter( + plx.col(col_token) + == plx.col(col_token).max().broadcast(kind=ExprKind.AGGREGATION) + ).get_column(self.name) + + def is_finite(self) -> Self: + return self._with_native(pc.is_finite(self.native)) + + def cum_count(self, *, reverse: bool) -> Self: + dtypes = self._version.dtypes + return (~self.is_null()).cast(dtypes.UInt32()).cum_sum(reverse=reverse) + + @requires.backend_version((13,)) + def cum_min(self, *, reverse: bool) -> Self: + result = ( + pc.cumulative_min(self.native, skip_nulls=True) + if not reverse + else pc.cumulative_min(self.native[::-1], skip_nulls=True)[::-1] + ) + return self._with_native(result) + + @requires.backend_version((13,)) + def cum_max(self, *, reverse: bool) -> Self: + result = ( + pc.cumulative_max(self.native, skip_nulls=True) + if not reverse + else pc.cumulative_max(self.native[::-1], skip_nulls=True)[::-1] + ) + return self._with_native(result) + + @requires.backend_version((13,)) + def cum_prod(self, *, reverse: bool) -> Self: + result = ( + pc.cumulative_prod(self.native, skip_nulls=True) + if not reverse + else pc.cumulative_prod(self.native[::-1], skip_nulls=True)[::-1] + ) + return self._with_native(result) + + def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: + min_samples = min_samples if min_samples is not None else window_size + padded_series, offset = pad_series(self, window_size=window_size, center=center) + + cum_sum = padded_series.cum_sum(reverse=False).fill_null( + value=None, strategy="forward", limit=None + ) + rolling_sum = ( + cum_sum + - cum_sum.shift(window_size).fill_null(value=0, strategy=None, limit=None) + if window_size != 0 + else cum_sum + ) + + valid_count = padded_series.cum_count(reverse=False) + count_in_window = valid_count - valid_count.shift(window_size).fill_null( + value=0, strategy=None, limit=None + ) + + result = self._with_native( + pc.if_else((count_in_window >= min_samples).native, rolling_sum.native, None) + ) + return result._gather_slice(slice(offset, None)) + + def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: + min_samples = min_samples if 
min_samples is not None else window_size + padded_series, offset = pad_series(self, window_size=window_size, center=center) + + cum_sum = padded_series.cum_sum(reverse=False).fill_null( + value=None, strategy="forward", limit=None + ) + rolling_sum = ( + cum_sum + - cum_sum.shift(window_size).fill_null(value=0, strategy=None, limit=None) + if window_size != 0 + else cum_sum + ) + + valid_count = padded_series.cum_count(reverse=False) + count_in_window = valid_count - valid_count.shift(window_size).fill_null( + value=0, strategy=None, limit=None + ) + + result = ( + self._with_native( + pc.if_else( + (count_in_window >= min_samples).native, rolling_sum.native, None + ) + ) + / count_in_window + ) + return result._gather_slice(slice(offset, None)) + + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool, ddof: int + ) -> Self: + min_samples = min_samples if min_samples is not None else window_size + padded_series, offset = pad_series(self, window_size=window_size, center=center) + + cum_sum = padded_series.cum_sum(reverse=False).fill_null( + value=None, strategy="forward", limit=None + ) + rolling_sum = ( + cum_sum + - cum_sum.shift(window_size).fill_null(value=0, strategy=None, limit=None) + if window_size != 0 + else cum_sum + ) + + cum_sum_sq = ( + pow(padded_series, 2) + .cum_sum(reverse=False) + .fill_null(value=None, strategy="forward", limit=None) + ) + rolling_sum_sq = ( + cum_sum_sq + - cum_sum_sq.shift(window_size).fill_null(value=0, strategy=None, limit=None) + if window_size != 0 + else cum_sum_sq + ) + + valid_count = padded_series.cum_count(reverse=False) + count_in_window = valid_count - valid_count.shift(window_size).fill_null( + value=0, strategy=None, limit=None + ) + + result = self._with_native( + pc.if_else( + (count_in_window >= min_samples).native, + (rolling_sum_sq - (rolling_sum**2 / count_in_window)).native, + None, + ) + ) / self._with_native(pc.max_element_wise((count_in_window - ddof).native, 0)) + + return result._gather_slice(slice(offset, None, None)) + + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool, ddof: int + ) -> Self: + return ( + self.rolling_var( + window_size=window_size, min_samples=min_samples, center=center, ddof=ddof + ) + ** 0.5 + ) + + def rank(self, method: RankMethod, *, descending: bool) -> Self: + if method == "average": + msg = ( + "`rank` with `method='average' is not supported for pyarrow backend. " + "The available methods are {'min', 'max', 'dense', 'ordinal'}." 
+ ) + raise ValueError(msg) + + sort_keys: Order = "descending" if descending else "ascending" + tiebreaker: TieBreaker = "first" if method == "ordinal" else method + + native_series: ArrayOrChunkedArray + if self._backend_version < (14, 0, 0): # pragma: no cover + native_series = self.native.combine_chunks() + else: + native_series = self.native + + null_mask = pc.is_null(native_series) + + rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker) + + result = pc.if_else(null_mask, lit(None, native_series.type), rank) + return self._with_native(result) + + @requires.backend_version((13,)) + def hist( # noqa: C901, PLR0912, PLR0915 + self, + bins: list[float | int] | None, + *, + bin_count: int | None, + include_breakpoint: bool, + ) -> ArrowDataFrame: + import numpy as np # ignore-banned-import + + from narwhals._arrow.dataframe import ArrowDataFrame + + def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa: ANN202 + d = pc.min_max(self.native) + lower, upper = d["min"].as_py(), d["max"].as_py() + if lower == upper: + lower -= 0.5 + upper += 0.5 + bins = np.linspace(lower, upper, bin_count + 1) + return _hist_from_bins(bins) + + def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def] # noqa: ANN202 + bin_indices = np.searchsorted(bins, self.native, side="left") + bin_indices = pc.if_else( # lowest bin is inclusive + pc.equal(self.native, lit(bins[0])), 1, bin_indices + ) + + # align unique categories and counts appropriately + obs_cats, obs_counts = np.unique(bin_indices, return_counts=True) + obj_cats = np.arange(1, len(bins)) + counts = np.zeros_like(obj_cats) + counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)] + + bin_right = bins[1:] + return counts, bin_right + + counts: Sequence[int | float | pa.Scalar[Any]] | np.typing.ArrayLike + bin_right: Sequence[int | float | pa.Scalar[Any]] | np.typing.ArrayLike + + data_count = pc.sum( + pc.invert(pc.or_(pc.is_nan(self.native), pc.is_null(self.native))).cast( + pa.uint8() + ), + min_count=0, + ) + if bins is not None: + if len(bins) < 2: + counts, bin_right = [], [] + + elif data_count == pa.scalar(0, type=pa.uint64()): # type:ignore[comparison-overlap] + counts = np.zeros(len(bins) - 1) + bin_right = bins[1:] + + elif len(bins) == 2: + counts = [ + pc.sum( + pc.and_( + pc.greater_equal(self.native, lit(float(bins[0]))), + pc.less_equal(self.native, lit(float(bins[1]))), + ).cast(pa.uint8()) + ) + ] + bin_right = [bins[-1]] + else: + counts, bin_right = _hist_from_bins(bins) + + elif bin_count is not None: + if bin_count == 0: + counts, bin_right = [], [] + elif data_count == pa.scalar(0, type=pa.uint64()): # type:ignore[comparison-overlap] + counts, bin_right = ( + np.zeros(bin_count), + np.linspace(0, 1, bin_count + 1)[1:], + ) + elif bin_count == 1: + d = pc.min_max(self.native) + lower, upper = d["min"], d["max"] + if lower == upper: + counts, bin_right = [data_count], [pc.add(upper, pa.scalar(0.5))] + else: + counts, bin_right = [data_count], [upper] + else: + counts, bin_right = _hist_from_bin_count(bin_count) + + else: # pragma: no cover + # caller guarantees that either bins or bin_count is specified + msg = "must provide one of `bin_count` or `bins`" + raise InvalidOperationError(msg) + + data: dict[str, Any] = {} + if include_breakpoint: + data["breakpoint"] = bin_right + data["count"] = counts + + return ArrowDataFrame( + pa.Table.from_pydict(data), + backend_version=self._backend_version, + version=self._version, + 
validate_column_names=True, + ) + + def __iter__(self) -> Iterator[Any]: + for x in self.native: + yield maybe_extract_py_scalar(x, return_py_scalar=True) + + def __contains__(self, other: Any) -> bool: + from pyarrow import ( + ArrowInvalid, # ignore-banned-imports + ArrowNotImplementedError, # ignore-banned-imports + ArrowTypeError, # ignore-banned-imports + ) + + try: + other_ = lit(other) if other is not None else lit(None, type=self._type) + return maybe_extract_py_scalar( + pc.is_in(other_, self.native), return_py_scalar=True + ) + except (ArrowInvalid, ArrowNotImplementedError, ArrowTypeError) as exc: + from narwhals.exceptions import InvalidOperationError + + msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}." + raise InvalidOperationError(msg) from exc + + def log(self, base: float) -> Self: + return self._with_native(pc.logb(self.native, lit(base))) + + def exp(self) -> Self: + return self._with_native(pc.exp(self.native)) + + @property + def dt(self) -> ArrowSeriesDateTimeNamespace: + return ArrowSeriesDateTimeNamespace(self) + + @property + def cat(self) -> ArrowSeriesCatNamespace: + return ArrowSeriesCatNamespace(self) + + @property + def str(self) -> ArrowSeriesStringNamespace: + return ArrowSeriesStringNamespace(self) + + @property + def list(self) -> ArrowSeriesListNamespace: + return ArrowSeriesListNamespace(self) + + @property + def struct(self) -> ArrowSeriesStructNamespace: + return ArrowSeriesStructNamespace(self) + + ewm_mean = not_implemented() diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/series_cat.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_cat.py new file mode 100644 index 0000000..944f339 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_cat.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pyarrow as pa + +from narwhals._arrow.utils import ArrowSeriesNamespace + +if TYPE_CHECKING: + from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import Incomplete + + +class ArrowSeriesCatNamespace(ArrowSeriesNamespace): + def get_categories(self) -> ArrowSeries: + # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes + chunks: Incomplete = self.native.chunks + return self.with_native(pa.concat_arrays(x.dictionary for x in chunks).unique()) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/series_dt.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_dt.py new file mode 100644 index 0000000..75aaec5 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_dt.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Mapping, cast + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit +from narwhals._duration import parse_interval_string + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny + from narwhals.dtypes import Datetime + from narwhals.typing import TimeUnit + + UnitCurrent: TypeAlias = TimeUnit + UnitTarget: TypeAlias = TimeUnit + BinOpBroadcast: TypeAlias = Callable[[ChunkedArrayAny, ScalarAny], ChunkedArrayAny] + IntoRhs: TypeAlias = int + + +class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace): + _TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, 
int]] = { + "ns": 1_000_000_000, + "us": 1_000_000, + "ms": 1_000, + "s": 1, + } + _TIMESTAMP_DATETIME_OP_FACTOR: ClassVar[ + Mapping[tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]] + ] = { + ("ns", "us"): (floordiv_compat, 1_000), + ("ns", "ms"): (floordiv_compat, 1_000_000), + ("us", "ns"): (pc.multiply, 1_000), + ("us", "ms"): (floordiv_compat, 1_000), + ("ms", "ns"): (pc.multiply, 1_000_000), + ("ms", "us"): (pc.multiply, 1_000), + ("s", "ns"): (pc.multiply, 1_000_000_000), + ("s", "us"): (pc.multiply, 1_000_000), + ("s", "ms"): (pc.multiply, 1_000), + } + + @property + def unit(self) -> TimeUnit: # NOTE: Unsafe (native). + return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit + + @property + def time_zone(self) -> str | None: # NOTE: Unsafe (narwhals). + return cast("Datetime", self.compliant.dtype).time_zone + + def to_string(self, format: str) -> ArrowSeries: + # PyArrow differs from other libraries in that %S also prints out + # the fractional part of the second...:'( + # https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html + format = format.replace("%S.%f", "%S").replace("%S%.f", "%S") + return self.with_native(pc.strftime(self.native, format)) + + def replace_time_zone(self, time_zone: str | None) -> ArrowSeries: + if time_zone is not None: + result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone) + else: + result = pc.local_timestamp(self.native) + return self.with_native(result) + + def convert_time_zone(self, time_zone: str) -> ArrowSeries: + ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant + return self.with_native(ser.native.cast(pa.timestamp(self.unit, time_zone))) + + def timestamp(self, time_unit: TimeUnit) -> ArrowSeries: + ser = self.compliant + dtypes = ser._version.dtypes + if isinstance(ser.dtype, dtypes.Datetime): + current = ser.dtype.time_unit + s_cast = self.native.cast(pa.int64()) + if current == time_unit: + result = s_cast + elif item := self._TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)): + fn, factor = item + result = fn(s_cast, lit(factor)) + else: # pragma: no cover + msg = f"unexpected time unit {current}, please report an issue at https://github.com/narwhals-dev/narwhals" + raise AssertionError(msg) + return self.with_native(result) + elif isinstance(ser.dtype, dtypes.Date): + time_s = pc.multiply(self.native.cast(pa.int32()), lit(86_400)) + factor = self._TIMESTAMP_DATE_FACTOR[time_unit] + return self.with_native(pc.multiply(time_s, lit(factor))) + else: + msg = "Input should be either of Date or Datetime type" + raise TypeError(msg) + + def date(self) -> ArrowSeries: + return self.with_native(self.native.cast(pa.date32())) + + def year(self) -> ArrowSeries: + return self.with_native(pc.year(self.native)) + + def month(self) -> ArrowSeries: + return self.with_native(pc.month(self.native)) + + def day(self) -> ArrowSeries: + return self.with_native(pc.day(self.native)) + + def hour(self) -> ArrowSeries: + return self.with_native(pc.hour(self.native)) + + def minute(self) -> ArrowSeries: + return self.with_native(pc.minute(self.native)) + + def second(self) -> ArrowSeries: + return self.with_native(pc.second(self.native)) + + def millisecond(self) -> ArrowSeries: + return self.with_native(pc.millisecond(self.native)) + + def microsecond(self) -> ArrowSeries: + arr = self.native + result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr)) + return self.with_native(result) + + def nanosecond(self) -> ArrowSeries: + result 
= pc.add( + pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native) + ) + return self.with_native(result) + + def ordinal_day(self) -> ArrowSeries: + return self.with_native(pc.day_of_year(self.native)) + + def weekday(self) -> ArrowSeries: + return self.with_native(pc.day_of_week(self.native, count_from_zero=False)) + + def total_minutes(self) -> ArrowSeries: + unit_to_minutes_factor = { + "s": 60, # seconds + "ms": 60 * 1e3, # milli + "us": 60 * 1e6, # micro + "ns": 60 * 1e9, # nano + } + factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64()) + return self.with_native(pc.divide(self.native, factor).cast(pa.int64())) + + def total_seconds(self) -> ArrowSeries: + unit_to_seconds_factor = { + "s": 1, # seconds + "ms": 1e3, # milli + "us": 1e6, # micro + "ns": 1e9, # nano + } + factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64()) + return self.with_native(pc.divide(self.native, factor).cast(pa.int64())) + + def total_milliseconds(self) -> ArrowSeries: + unit_to_milli_factor = { + "s": 1e3, # seconds + "ms": 1, # milli + "us": 1e3, # micro + "ns": 1e6, # nano + } + factor = lit(unit_to_milli_factor[self.unit], type=pa.int64()) + if self.unit == "s": + return self.with_native(pc.multiply(self.native, factor).cast(pa.int64())) + return self.with_native(pc.divide(self.native, factor).cast(pa.int64())) + + def total_microseconds(self) -> ArrowSeries: + unit_to_micro_factor = { + "s": 1e6, # seconds + "ms": 1e3, # milli + "us": 1, # micro + "ns": 1e3, # nano + } + factor = lit(unit_to_micro_factor[self.unit], type=pa.int64()) + if self.unit in {"s", "ms"}: + return self.with_native(pc.multiply(self.native, factor).cast(pa.int64())) + return self.with_native(pc.divide(self.native, factor).cast(pa.int64())) + + def total_nanoseconds(self) -> ArrowSeries: + unit_to_nano_factor = { + "s": 1e9, # seconds + "ms": 1e6, # milli + "us": 1e3, # micro + "ns": 1, # nano + } + factor = lit(unit_to_nano_factor[self.unit], type=pa.int64()) + return self.with_native(pc.multiply(self.native, factor).cast(pa.int64())) + + def truncate(self, every: str) -> ArrowSeries: + multiple, unit = parse_interval_string(every) + return self.with_native( + pc.floor_temporal(self.native, multiple=multiple, unit=UNITS_DICT[unit]) + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/series_list.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_list.py new file mode 100644 index 0000000..aeb4315 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_list.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.utils import ArrowSeriesNamespace + +if TYPE_CHECKING: + from narwhals._arrow.series import ArrowSeries + + +class ArrowSeriesListNamespace(ArrowSeriesNamespace): + def len(self) -> ArrowSeries: + return self.with_native(pc.list_value_length(self.native).cast(pa.uint32())) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/series_str.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_str.py new file mode 100644 index 0000000..64dcce8 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_str.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import string +from typing import TYPE_CHECKING + +import pyarrow.compute as pc + +from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format + +if TYPE_CHECKING: + from narwhals._arrow.series import ArrowSeries + + 
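# --- Illustrative aside (not part of the vendored diff) ---------------------
# The `total_*` methods in series_dt.py above all follow one pattern: look up a
# unit factor keyed by the Duration column's time unit, then multiply (coarser
# to finer) or divide (finer to coarser) and cast to int64. A minimal
# standalone sketch of that divide-then-cast pattern, assuming only pyarrow;
# the sample values here are hypothetical:
import pyarrow as pa
import pyarrow.compute as pc

durations = pa.chunked_array([pa.array([1_500_000, 3_000_000], type=pa.duration("us"))])
# "us" to whole seconds: divide by 1_000_000, mirroring `total_seconds` above.
factor = pa.scalar(1_000_000, type=pa.int64())
total_seconds = pc.divide(durations, factor).cast(pa.int64())
print(total_seconds.to_pylist())  # [1, 3]
# -----------------------------------------------------------------------------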
+class ArrowSeriesStringNamespace(ArrowSeriesNamespace): + def len_chars(self) -> ArrowSeries: + return self.with_native(pc.utf8_length(self.native)) + + def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries: + fn = pc.replace_substring if literal else pc.replace_substring_regex + arr = fn(self.native, pattern, replacement=value, max_replacements=n) + return self.with_native(arr) + + def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries: + return self.replace(pattern, value, literal=literal, n=-1) + + def strip_chars(self, characters: str | None) -> ArrowSeries: + return self.with_native( + pc.utf8_trim(self.native, characters or string.whitespace) + ) + + def starts_with(self, prefix: str) -> ArrowSeries: + return self.with_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix))) + + def ends_with(self, suffix: str) -> ArrowSeries: + return self.with_native( + pc.equal(self.slice(-len(suffix), None).native, lit(suffix)) + ) + + def contains(self, pattern: str, *, literal: bool) -> ArrowSeries: + check_func = pc.match_substring if literal else pc.match_substring_regex + return self.with_native(check_func(self.native, pattern)) + + def slice(self, offset: int, length: int | None) -> ArrowSeries: + stop = offset + length if length is not None else None + return self.with_native( + pc.utf8_slice_codeunits(self.native, start=offset, stop=stop) + ) + + def split(self, by: str) -> ArrowSeries: + split_series = pc.split_pattern(self.native, by) # type: ignore[call-overload] + return self.with_native(split_series) + + def to_datetime(self, format: str | None) -> ArrowSeries: + format = parse_datetime_format(self.native) if format is None else format + timestamp_array = pc.strptime(self.native, format=format, unit="us") + return self.with_native(timestamp_array) + + def to_uppercase(self) -> ArrowSeries: + return self.with_native(pc.utf8_upper(self.native)) + + def to_lowercase(self) -> ArrowSeries: + return self.with_native(pc.utf8_lower(self.native)) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/series_struct.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_struct.py new file mode 100644 index 0000000..be5aa4b --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/series_struct.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pyarrow.compute as pc + +from narwhals._arrow.utils import ArrowSeriesNamespace + +if TYPE_CHECKING: + from narwhals._arrow.series import ArrowSeries + + +class ArrowSeriesStructNamespace(ArrowSeriesNamespace): + def field(self, name: str) -> ArrowSeries: + return self.with_native(pc.struct_field(self.native, name)).alias(name) diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/typing.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/typing.py new file mode 100644 index 0000000..3f79fec --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/typing.py @@ -0,0 +1,72 @@ +from __future__ import annotations # pragma: no cover + +from typing import ( + TYPE_CHECKING, # pragma: no cover + Any, # pragma: no cover + TypeVar, # pragma: no cover +) + +if TYPE_CHECKING: + import sys + from typing import Generic, Literal + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + import pyarrow as pa + from pyarrow.__lib_pxi.table import ( + AggregateOptions, # noqa: F401 + Aggregation, # noqa: F401 + ) + from pyarrow._stubs_typing import ( 
# pyright: ignore[reportMissingModuleSource]
+        Indices,  # noqa: F401
+        Mask,  # noqa: F401
+        Order,  # noqa: F401
+    )
+
+    from narwhals._arrow.expr import ArrowExpr
+    from narwhals._arrow.series import ArrowSeries
+
+    IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries"
+    TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"]
+    NullPlacement: TypeAlias = Literal["at_start", "at_end"]
+    NativeIntervalUnit: TypeAlias = Literal[
+        "year",
+        "quarter",
+        "month",
+        "week",
+        "day",
+        "hour",
+        "minute",
+        "second",
+        "millisecond",
+        "microsecond",
+        "nanosecond",
+    ]
+
+    ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
+    ArrayAny: TypeAlias = pa.Array[Any]
+    ArrayOrChunkedArray: TypeAlias = "ArrayAny | ChunkedArrayAny"
+    ScalarAny: TypeAlias = pa.Scalar[Any]
+    ArrayOrScalar: TypeAlias = "ArrayOrChunkedArray | ScalarAny"
+    ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrayAny, ChunkedArrayAny, ScalarAny)
+    ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrayAny, ChunkedArrayAny, ScalarAny)
+    _AsPyType = TypeVar("_AsPyType")
+
+    class _BasicDataType(pa.DataType, Generic[_AsPyType]): ...
+
+
+Incomplete: TypeAlias = Any  # pragma: no cover
+"""
+Marker for working code that fails on the stubs.
+
+Common issues:
+- Annotated for `Array`, but not `ChunkedArray`
+- Relies on typing information that the stubs don't provide statically
+- Missing attributes
+- Incorrect return types
+- Inconsistent use of generic/concrete types
+- `_clone_signature` used on signatures that are not identical
+"""
diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py
new file mode 100644
index 0000000..d100448
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py
@@ -0,0 +1,470 @@
+from __future__ import annotations
+
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, cast
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from narwhals._compliant.series import _SeriesNamespace
+from narwhals._utils import isinstance_or_issubclass
+from narwhals.exceptions import ShapeError
+
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias, TypeIs
+
+    from narwhals._arrow.series import ArrowSeries
+    from narwhals._arrow.typing import (
+        ArrayAny,
+        ArrayOrScalar,
+        ArrayOrScalarT1,
+        ArrayOrScalarT2,
+        ChunkedArrayAny,
+        NativeIntervalUnit,
+        ScalarAny,
+    )
+    from narwhals._duration import IntervalUnit
+    from narwhals._utils import Version
+    from narwhals.dtypes import DType
+    from narwhals.typing import IntoDType, PythonLiteral
+
+    # NOTE: stubs don't allow for `ChunkedArray[StructArray]`
+    # Intended to represent the `.chunks` property storing `list[pa.StructArray]`
+    ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny
+
+    def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
+    def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
+    def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
+    def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
+    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
+    def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
+    def extract_regex(
+        strings: ChunkedArrayAny,
+        /,
+        pattern: str,
+        *,
+        options: Any = None,
+        memory_pool: Any = None,
+    ) -> ChunkedArrayStructArray: ...
+else:
+    from pyarrow.compute import extract_regex
+    from pyarrow.types import (
+        is_dictionary,  # noqa: F401
+        is_duration,
+        is_fixed_size_list,
+        is_large_list,
+        is_list,
+        is_timestamp,
+    )
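# --- Illustrative aside (not part of the vendored diff) ---------------------
# The `Incomplete` alias defined in typing.py above (an alias of `Any`) is the
# project's escape hatch for pyarrow calls that the stubs reject even though
# the code works at runtime. A minimal sketch of the call-site pattern, as
# used by `scatter` and `to_dummies` earlier in this diff; the local
# `Incomplete` here stands in for `narwhals._arrow.typing.Incomplete`:
from typing import Any

import pyarrow.compute as pc

Incomplete = Any  # stand-in for narwhals._arrow.typing.Incomplete

# Re-bind the under-stubbed callable under `Incomplete` so the type checker
# stops flagging arguments the stubs don't know about; runtime is unchanged.
pc_replace_with_mask: Incomplete = pc.replace_with_mask
# -----------------------------------------------------------------------------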
+
+UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
+    "y": "year",
+    "q": "quarter",
+    "mo": "month",
+    "d": "day",
+    "h": "hour",
+    "m": "minute",
+    "s": "second",
+    "ms": "millisecond",
+    "us": "microsecond",
+    "ns": "nanosecond",
+}
+
+lit = pa.scalar
+"""Alias for `pyarrow.scalar`."""
+
+
+def extract_py_scalar(value: Any, /) -> Any:
+    from narwhals._arrow.series import maybe_extract_py_scalar
+
+    return maybe_extract_py_scalar(value, return_py_scalar=True)
+
+
+def chunked_array(
+    arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
+) -> ChunkedArrayAny:
+    if isinstance(arr, pa.ChunkedArray):
+        return arr
+    if isinstance(arr, list):
+        return pa.chunked_array(arr, dtype)
+    else:
+        return pa.chunked_array([arr], arr.type)
+
+
+def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
+    """Create a strongly-typed Array instance with all elements null.
+
+    Uses the type of `series`, without upsetting `mypy`.
+    """
+    return pa.nulls(n, series.native.type)
+
+
+@lru_cache(maxsize=16)
+def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
+    dtypes = version.dtypes
+    if pa.types.is_int64(dtype):
+        return dtypes.Int64()
+    if pa.types.is_int32(dtype):
+        return dtypes.Int32()
+    if pa.types.is_int16(dtype):
+        return dtypes.Int16()
+    if pa.types.is_int8(dtype):
+        return dtypes.Int8()
+    if pa.types.is_uint64(dtype):
+        return dtypes.UInt64()
+    if pa.types.is_uint32(dtype):
+        return dtypes.UInt32()
+    if pa.types.is_uint16(dtype):
+        return dtypes.UInt16()
+    if pa.types.is_uint8(dtype):
+        return dtypes.UInt8()
+    if pa.types.is_boolean(dtype):
+        return dtypes.Boolean()
+    if pa.types.is_float64(dtype):
+        return dtypes.Float64()
+    if pa.types.is_float32(dtype):
+        return dtypes.Float32()
+    # bug in coverage?
it shows `31->exit` (where `31` is currently the line number of + # the next line), even though both when the if condition is true and false are covered + if ( # pragma: no cover + pa.types.is_string(dtype) + or pa.types.is_large_string(dtype) + or getattr(pa.types, "is_string_view", lambda _: False)(dtype) + ): + return dtypes.String() + if pa.types.is_date32(dtype): + return dtypes.Date() + if is_timestamp(dtype): + return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz) + if is_duration(dtype): + return dtypes.Duration(time_unit=dtype.unit) + if pa.types.is_dictionary(dtype): + return dtypes.Categorical() + if pa.types.is_struct(dtype): + return dtypes.Struct( + [ + dtypes.Field( + dtype.field(i).name, + native_to_narwhals_dtype(dtype.field(i).type, version), + ) + for i in range(dtype.num_fields) + ] + ) + if is_list(dtype) or is_large_list(dtype): + return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version)) + if is_fixed_size_list(dtype): + return dtypes.Array( + native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size + ) + if pa.types.is_decimal(dtype): + return dtypes.Decimal() + if pa.types.is_time32(dtype) or pa.types.is_time64(dtype): + return dtypes.Time() + if pa.types.is_binary(dtype): + return dtypes.Binary() + return dtypes.Unknown() # pragma: no cover + + +def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType: # noqa: C901, PLR0912 + dtypes = version.dtypes + if isinstance_or_issubclass(dtype, dtypes.Decimal): + msg = "Casting to Decimal is not supported yet." + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Float64): + return pa.float64() + if isinstance_or_issubclass(dtype, dtypes.Float32): + return pa.float32() + if isinstance_or_issubclass(dtype, dtypes.Int64): + return pa.int64() + if isinstance_or_issubclass(dtype, dtypes.Int32): + return pa.int32() + if isinstance_or_issubclass(dtype, dtypes.Int16): + return pa.int16() + if isinstance_or_issubclass(dtype, dtypes.Int8): + return pa.int8() + if isinstance_or_issubclass(dtype, dtypes.UInt64): + return pa.uint64() + if isinstance_or_issubclass(dtype, dtypes.UInt32): + return pa.uint32() + if isinstance_or_issubclass(dtype, dtypes.UInt16): + return pa.uint16() + if isinstance_or_issubclass(dtype, dtypes.UInt8): + return pa.uint8() + if isinstance_or_issubclass(dtype, dtypes.String): + return pa.string() + if isinstance_or_issubclass(dtype, dtypes.Boolean): + return pa.bool_() + if isinstance_or_issubclass(dtype, dtypes.Categorical): + return pa.dictionary(pa.uint32(), pa.string()) + if isinstance_or_issubclass(dtype, dtypes.Datetime): + unit = dtype.time_unit + return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit) + if isinstance_or_issubclass(dtype, dtypes.Duration): + return pa.duration(dtype.time_unit) + if isinstance_or_issubclass(dtype, dtypes.Date): + return pa.date32() + if isinstance_or_issubclass(dtype, dtypes.List): + return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version)) + if isinstance_or_issubclass(dtype, dtypes.Struct): + return pa.struct( + [ + (field.name, narwhals_to_native_dtype(field.dtype, version=version)) + for field in dtype.fields + ] + ) + if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + inner = narwhals_to_native_dtype(dtype.inner, version=version) + list_size = dtype.size + return pa.list_(inner, list_size=list_size) + if isinstance_or_issubclass(dtype, dtypes.Time): + return pa.time64("ns") + if isinstance_or_issubclass(dtype, 
dtypes.Binary): + return pa.binary() + + msg = f"Unknown dtype: {dtype}" # pragma: no cover + raise AssertionError(msg) + + +def extract_native( + lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny +) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]: + """Extract native objects in binary operation. + + If the comparison isn't supported, return `NotImplemented` so that the + "right-hand-side" operation (e.g. `__radd__`) can be tried. + + If one of the two sides has a `_broadcast` flag, then extract the scalar + underneath it so that PyArrow can do its own broadcasting. + """ + from narwhals._arrow.dataframe import ArrowDataFrame + from narwhals._arrow.series import ArrowSeries + + if rhs is None: # pragma: no cover + return lhs.native, lit(None, type=lhs._type) + + if isinstance(rhs, ArrowDataFrame): + return NotImplemented + + if isinstance(rhs, ArrowSeries): + if lhs._broadcast and not rhs._broadcast: + return lhs.native[0], rhs.native + if rhs._broadcast: + return lhs.native, rhs.native[0] + return lhs.native, rhs.native + + if isinstance(rhs, list): + msg = "Expected Series or scalar, got list." + raise TypeError(msg) + + return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs) + + +def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]: + # Ensure all of `series` are of the same length. + lengths = [len(s) for s in series] + max_length = max(lengths) + fast_path = all(_len == max_length for _len in lengths) + + if fast_path: + return series + + reshaped = [] + for s in series: + if s._broadcast: + value = s.native[0] + if s._backend_version < (13,) and hasattr(value, "as_py"): + value = value.as_py() + reshaped.append(s._with_native(pa.array([value] * max_length, type=s._type))) + else: + if (actual_len := len(s)) != max_length: + msg = f"Expected object of length {max_length}, got {actual_len}." + raise ShapeError(msg) + reshaped.append(s) + + return reshaped + + +def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any: + # The following lines are adapted from pandas' pyarrow implementation. 
+    # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
+
+    if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
+        divided = pc.divide_checked(left, right)
+        # TODO @dangotbanned: Use a `TypeVar` in guards
+        # Narrowing to a `Union` isn't interacting well with the rest of the stubs
+        # https://github.com/zen-xu/pyarrow-stubs/pull/215
+        if pa.types.is_signed_integer(divided.type):
+            div_type = cast("pa._lib.Int64Type", divided.type)
+            has_remainder = pc.not_equal(pc.multiply(divided, right), left)
+            has_one_negative_operand = pc.less(
+                pc.bit_wise_xor(left, right), lit(0, div_type)
+            )
+            result = pc.if_else(
+                pc.and_(has_remainder, has_one_negative_operand),
+                pc.subtract(divided, lit(1, div_type)),
+                divided,
+            )
+        else:
+            result = divided  # pragma: no cover
+        result = result.cast(left.type)
+    else:
+        divided = pc.divide(left, right)
+        result = pc.floor(divided)
+    return result
+
+
+def cast_for_truediv(
+    arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
+) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
+    # Lifted from:
+    # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
+    # Ensure int / int -> float mirroring Python/Numpy behavior
+    # as pc.divide_checked(int, int) -> int
+    if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
+        # GH: 56645.  # noqa: ERA001
+        # https://github.com/apache/arrow/issues/35563
+        # NOTE: `pyarrow==11.*` doesn't allow keywords in `Array.cast`
+        return pc.cast(arrow_array, pa.float64(), safe=False), pc.cast(
+            pa_object, pa.float64(), safe=False
+        )
+
+    return arrow_array, pa_object
+
+
+# Regex for date, time, separator and timezone components
+DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
+SEP_RE = r"(?P<sep>\s|T)"
+TIME_RE = r"(?P<time>
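# --- Illustrative aside (not part of the vendored diff) ---------------------
# Why `floordiv_compat` above needs the remainder/sign fix-up: Arrow's integer
# division truncates toward zero, while Python's floor division rounds toward
# negative infinity, so results differ exactly when one operand is negative
# and there is a remainder (the `bit_wise_xor` sign test plus `has_remainder`).
# A quick check of the target semantics in plain Python:
assert -7 // 2 == -4  # floor division rounds toward negative infinity
assert int(-7 / 2) == -3  # truncation (what plain integer divide would give)
assert 7 // 2 == 2 and int(7 / 2) == 2  # same-sign operands: no adjustment
# -----------------------------------------------------------------------------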