From 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e Mon Sep 17 00:00:00 2001 From: sotech117 Date: Thu, 31 Jul 2025 17:27:24 -0400 Subject: add code for analysis of data --- .../site-packages/narwhals/_spark_like/__init__.py | 0 .../narwhals/_spark_like/dataframe.py | 531 ++++++++++++ .../site-packages/narwhals/_spark_like/expr.py | 930 +++++++++++++++++++++ .../site-packages/narwhals/_spark_like/expr_dt.py | 193 +++++ .../narwhals/_spark_like/expr_list.py | 14 + .../site-packages/narwhals/_spark_like/expr_str.py | 115 +++ .../narwhals/_spark_like/expr_struct.py | 19 + .../site-packages/narwhals/_spark_like/group_by.py | 35 + .../narwhals/_spark_like/namespace.py | 290 +++++++ .../narwhals/_spark_like/selectors.py | 29 + .../site-packages/narwhals/_spark_like/utils.py | 285 +++++++ 11 files changed, 2441 insertions(+) create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/__init__.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/dataframe.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/expr.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_dt.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_list.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_str.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_struct.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/group_by.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/namespace.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/selectors.py create mode 100644 venv/lib/python3.8/site-packages/narwhals/_spark_like/utils.py (limited to 'venv/lib/python3.8/site-packages/narwhals/_spark_like') diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/__init__.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/__init__.py new file mode 100644 
index 0000000..e69de29 diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/dataframe.py new file mode 100644 index 0000000..c4ea73f --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/dataframe.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import warnings +from functools import reduce +from operator import and_ +from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence + +from narwhals._namespace import is_native_spark_like +from narwhals._spark_like.utils import ( + evaluate_exprs, + import_functions, + import_native_dtypes, + import_window, + native_to_narwhals_dtype, +) +from narwhals._utils import ( + Implementation, + find_stacklevel, + generate_temporary_column_name, + not_implemented, + parse_columns_to_drop, + parse_version, + validate_backend_version, +) +from narwhals.exceptions import InvalidOperationError +from narwhals.typing import CompliantLazyFrame + +if TYPE_CHECKING: + from types import ModuleType + + import pyarrow as pa + from sqlframe.base.column import Column + from sqlframe.base.dataframe import BaseDataFrame + from sqlframe.base.window import Window + from typing_extensions import Self, TypeAlias, TypeIs + + from narwhals._compliant.typing import CompliantDataFrameAny + from narwhals._spark_like.expr import SparkLikeExpr + from narwhals._spark_like.group_by import SparkLikeLazyGroupBy + from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals._utils import Version, _FullContext + from narwhals.dataframe import LazyFrame + from narwhals.dtypes import DType + from narwhals.typing import JoinStrategy, LazyUniqueKeepStrategy + + SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any] + +Incomplete: TypeAlias = Any # pragma: no cover +"""Marker for working code that fails type checking.""" + + +class SparkLikeLazyFrame( + CompliantLazyFrame[ + "SparkLikeExpr", "SQLFrameDataFrame", 
"LazyFrame[SQLFrameDataFrame]" + ] +): + def __init__( + self, + native_dataframe: SQLFrameDataFrame, + *, + backend_version: tuple[int, ...], + version: Version, + implementation: Implementation, + ) -> None: + self._native_frame: SQLFrameDataFrame = native_dataframe + self._backend_version = backend_version + self._implementation = implementation + self._version = version + self._cached_schema: dict[str, DType] | None = None + self._cached_columns: list[str] | None = None + validate_backend_version(self._implementation, self._backend_version) + + @property + def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202, N802 + if TYPE_CHECKING: + from sqlframe.base import functions + + return functions + else: + return import_functions(self._implementation) + + @property + def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202 + if TYPE_CHECKING: + from sqlframe.base import types + + return types + else: + return import_native_dtypes(self._implementation) + + @property + def _Window(self) -> type[Window]: # noqa: N802 + if TYPE_CHECKING: + from sqlframe.base.window import Window + + return Window + else: + return import_window(self._implementation) + + @staticmethod + def _is_native(obj: SQLFrameDataFrame | Any) -> TypeIs[SQLFrameDataFrame]: + return is_native_spark_like(obj) + + @classmethod + def from_native(cls, data: SQLFrameDataFrame, /, *, context: _FullContext) -> Self: + return cls( + data, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, + ) + + def to_narwhals(self) -> LazyFrame[SQLFrameDataFrame]: + return self._version.lazyframe(self, level="lazy") + + def __native_namespace__(self) -> ModuleType: # pragma: no cover + return self._implementation.to_native_namespace() + + def __narwhals_namespace__(self) -> SparkLikeNamespace: + from narwhals._spark_like.namespace import SparkLikeNamespace + + return SparkLikeNamespace( + backend_version=self._backend_version, + 
version=self._version, + implementation=self._implementation, + ) + + def __narwhals_lazyframe__(self) -> Self: + return self + + def _with_version(self, version: Version) -> Self: + return self.__class__( + self.native, + backend_version=self._backend_version, + version=version, + implementation=self._implementation, + ) + + def _with_native(self, df: SQLFrameDataFrame) -> Self: + return self.__class__( + df, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def _to_arrow_schema(self) -> pa.Schema: # pragma: no cover + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.utils import narwhals_to_native_dtype + + schema: list[tuple[str, pa.DataType]] = [] + nw_schema = self.collect_schema() + native_schema = self.native.schema + for key, value in nw_schema.items(): + try: + native_dtype = narwhals_to_native_dtype(value, self._version) + except Exception as exc: # noqa: BLE001,PERF203 + native_spark_dtype = native_schema[key].dataType # type: ignore[index] + # If we can't convert the type, just set it to `pa.null`, and warn. + # Avoid the warning if we're starting from PySpark's void type. + # We can avoid the check when we introduce `nw.Null` dtype. 
+ null_type = self._native_dtypes.NullType # pyright: ignore[reportAttributeAccessIssue] + if not isinstance(native_spark_dtype, null_type): + warnings.warn( + f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}", + stacklevel=find_stacklevel(), + ) + schema.append((key, pa.null())) + else: + schema.append((key, native_dtype)) + return pa.schema(schema) + + def _collect_to_arrow(self) -> pa.Table: + if self._implementation.is_pyspark() and self._backend_version < (4,): + import pyarrow as pa # ignore-banned-import + + try: + return pa.Table.from_batches(self.native._collect_as_arrow()) + except ValueError as exc: + if "at least one RecordBatch" in str(exc): + # Empty dataframe + + data: dict[str, list[Any]] = {k: [] for k in self.columns} + pa_schema = self._to_arrow_schema() + return pa.Table.from_pydict(data, schema=pa_schema) + else: # pragma: no cover + raise + elif self._implementation.is_pyspark_connect() and self._backend_version < (4,): + import pyarrow as pa # ignore-banned-import + + pa_schema = self._to_arrow_schema() + return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema) + else: + return self.native.toArrow() + + def _iter_columns(self) -> Iterator[Column]: + for col in self.columns: + yield self._F.col(col) + + @property + def columns(self) -> list[str]: + if self._cached_columns is None: + self._cached_columns = ( + list(self.schema) + if self._cached_schema is not None + else self.native.columns + ) + return self._cached_columns + + def collect( + self, backend: ModuleType | Implementation | str | None, **kwargs: Any + ) -> CompliantDataFrameAny: + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + self.native.toPandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd), + version=self._version, + validate_column_names=True, + ) + + elif backend is None or 
backend is Implementation.PYARROW: + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + self._collect_to_arrow(), + backend_version=parse_version(pa), + version=self._version, + validate_column_names=True, + ) + + elif backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + import pyarrow as pa # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + pl.from_arrow(self._collect_to_arrow()), # type: ignore[arg-type] + backend_version=parse_version(pl), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover + + def simple_select(self, *column_names: str) -> Self: + return self._with_native(self.native.select(*column_names)) + + def aggregate(self, *exprs: SparkLikeExpr) -> Self: + new_columns = evaluate_exprs(self, *exprs) + + new_columns_list = [col.alias(col_name) for col_name, col in new_columns] + return self._with_native(self.native.agg(*new_columns_list)) + + def select(self, *exprs: SparkLikeExpr) -> Self: + new_columns = evaluate_exprs(self, *exprs) + new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns] + return self._with_native(self.native.select(*new_columns_list)) + + def with_columns(self, *exprs: SparkLikeExpr) -> Self: + new_columns = evaluate_exprs(self, *exprs) + return self._with_native(self.native.withColumns(dict(new_columns))) + + def filter(self, predicate: SparkLikeExpr) -> Self: + # `[0]` is safe as the predicate's expression only returns a single column + condition = predicate._call(self)[0] + spark_df = self.native.where(condition) + return self._with_native(spark_df) + + @property + def schema(self) -> dict[str, DType]: + if self._cached_schema is None: + self._cached_schema = { + field.name: native_to_narwhals_dtype( + field.dataType, + self._version, + self._native_dtypes, 
+ self.native.sparkSession, + ) + for field in self.native.schema + } + return self._cached_schema + + def collect_schema(self) -> dict[str, DType]: + return self.schema + + def drop(self, columns: Sequence[str], *, strict: bool) -> Self: + columns_to_drop = parse_columns_to_drop(self, columns, strict=strict) + return self._with_native(self.native.drop(*columns_to_drop)) + + def head(self, n: int) -> Self: + return self._with_native(self.native.limit(n)) + + def group_by( + self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool + ) -> SparkLikeLazyGroupBy: + from narwhals._spark_like.group_by import SparkLikeLazyGroupBy + + return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) + + def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self: + if isinstance(descending, bool): + descending = [descending] * len(by) + + if nulls_last: + sort_funcs = ( + self._F.desc_nulls_last if d else self._F.asc_nulls_last + for d in descending + ) + else: + sort_funcs = ( + self._F.desc_nulls_first if d else self._F.asc_nulls_first + for d in descending + ) + + sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)] + return self._with_native(self.native.sort(*sort_cols)) + + def drop_nulls(self, subset: Sequence[str] | None) -> Self: + subset = list(subset) if subset else None + return self._with_native(self.native.dropna(subset=subset)) + + def rename(self, mapping: Mapping[str, str]) -> Self: + rename_mapping = { + colname: mapping.get(colname, colname) for colname in self.columns + } + return self._with_native( + self.native.select( + [self._F.col(old).alias(new) for old, new in rename_mapping.items()] + ) + ) + + def unique( + self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy + ) -> Self: + if subset and (error := self._check_columns_exist(subset)): + raise error + subset = list(subset) if subset else None + if keep == "none": + tmp = generate_temporary_column_name(8, self.columns) + 
window = self._Window.partitionBy(subset or self.columns) + df = ( + self.native.withColumn(tmp, self._F.count("*").over(window)) + .filter(self._F.col(tmp) == self._F.lit(1)) + .drop(self._F.col(tmp)) + ) + return self._with_native(df) + return self._with_native(self.native.dropDuplicates(subset=subset)) + + def join( + self, + other: Self, + how: JoinStrategy, + left_on: Sequence[str] | None, + right_on: Sequence[str] | None, + suffix: str, + ) -> Self: + left_columns = self.columns + right_columns = other.columns + + right_on_: list[str] = list(right_on) if right_on is not None else [] + left_on_: list[str] = list(left_on) if left_on is not None else [] + + # create a mapping for columns on other + # `right_on` columns will be renamed as `left_on` + # the remaining columns will be either added the suffix or left unchanged. + right_cols_to_rename = ( + [c for c in right_columns if c not in right_on_] + if how != "full" + else right_columns + ) + + rename_mapping = { + **dict(zip(right_on_, left_on_)), + **{ + colname: f"{colname}{suffix}" if colname in left_columns else colname + for colname in right_cols_to_rename + }, + } + other_native = other.native.select( + [self._F.col(old).alias(new) for old, new in rename_mapping.items()] + ) + + # If how in {"semi", "anti"}, then resulting columns are same as left columns + # Otherwise, we add the right columns with the new mapping, while keeping the + # original order of right_columns. 
+ col_order = left_columns.copy() + + if how in {"inner", "left", "cross"}: + col_order.extend( + rename_mapping[colname] + for colname in right_columns + if colname not in right_on_ + ) + elif how == "full": + col_order.extend(rename_mapping.values()) + + right_on_remapped = [rename_mapping[c] for c in right_on_] + on_ = ( + reduce( + and_, + ( + getattr(self.native, left_key) == getattr(other_native, right_key) + for left_key, right_key in zip(left_on_, right_on_remapped) + ), + ) + if how == "full" + else None + if how == "cross" + else left_on_ + ) + how_native = "full_outer" if how == "full" else how + return self._with_native( + self.native.join(other_native, on=on_, how=how_native).select(col_order) + ) + + def explode(self, columns: Sequence[str]) -> Self: + dtypes = self._version.dtypes + + schema = self.collect_schema() + for col_to_explode in columns: + dtype = schema[col_to_explode] + + if dtype != dtypes.List: + msg = ( + f"`explode` operation not supported for dtype `{dtype}`, " + "expected List type" + ) + raise InvalidOperationError(msg) + + column_names = self.columns + + if len(columns) != 1: + msg = ( + "Exploding on multiple columns is not supported with SparkLike backend since " + "we cannot guarantee that the exploded columns have matching element counts." 
+ ) + raise NotImplementedError(msg) + + if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect(): + return self._with_native( + self.native.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != columns[0] + else self._F.explode_outer(col_name).alias(col_name) + for col_name in column_names + ] + ) + ) + elif self._implementation.is_sqlframe(): + # Not every sqlframe dialect supports `explode_outer` function + # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289) + # therefore we simply explode the array column which will ignore nulls and + # zero sized arrays, and append these specific condition with nulls (to + # match polars behavior). + + def null_condition(col_name: str) -> Column: + return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0) + + return self._with_native( + self.native.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != columns[0] + else self._F.explode(col_name).alias(col_name) + for col_name in column_names + ] + ).union( + self.native.filter(null_condition(columns[0])).select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != columns[0] + else self._F.lit(None).alias(col_name) + for col_name in column_names + ] + ) + ) + ) + else: # pragma: no cover + msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + def unpivot( + self, + on: Sequence[str] | None, + index: Sequence[str] | None, + variable_name: str, + value_name: str, + ) -> Self: + if self._implementation.is_sqlframe(): + if variable_name == "": + msg = "`variable_name` cannot be empty string for sqlframe backend." + raise NotImplementedError(msg) + + if value_name == "": + msg = "`value_name` cannot be empty string for sqlframe backend." 
+ raise NotImplementedError(msg) + else: # pragma: no cover + pass + + ids = tuple(index) if index else () + values = ( + tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on) + ) + unpivoted_native_frame = self.native.unpivot( + ids=ids, + values=values, + variableColumnName=variable_name, + valueColumnName=value_name, + ) + if index is None: + unpivoted_native_frame = unpivoted_native_frame.drop(*ids) + return self._with_native(unpivoted_native_frame) + + gather_every = not_implemented.deprecated( + "`LazyFrame.gather_every` is deprecated and will be removed in a future version." + ) + join_asof = not_implemented() + tail = not_implemented.deprecated( + "`LazyFrame.tail` is deprecated and will be removed in a future version." + ) + with_row_index = not_implemented() diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr.py new file mode 100644 index 0000000..5c42dbb --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr.py @@ -0,0 +1,930 @@ +from __future__ import annotations + +import operator +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Iterator, + Literal, + Mapping, + Sequence, + cast, +) + +from narwhals._compliant import LazyExpr +from narwhals._compliant.window import WindowInputs +from narwhals._expression_parsing import ExprKind +from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace +from narwhals._spark_like.expr_list import SparkLikeExprListNamespace +from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace +from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace +from narwhals._spark_like.utils import ( + import_functions, + import_native_dtypes, + import_window, + narwhals_to_native_dtype, +) +from narwhals._utils import Implementation, not_implemented, parse_version +from narwhals.dependencies import get_pyspark + +if TYPE_CHECKING: + from 
sqlframe.base.column import Column + from sqlframe.base.window import Window, WindowSpec + from typing_extensions import Self, TypeAlias + + from narwhals._compliant.typing import ( + AliasNames, + EvalNames, + EvalSeries, + WindowFunction, + ) + from narwhals._expression_parsing import ExprMetadata + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals._utils import Version, _FullContext + from narwhals.typing import ( + FillNullStrategy, + IntoDType, + NonNestedLiteral, + NumericLiteral, + RankMethod, + TemporalLiteral, + ) + + NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"] + SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column] + SparkWindowInputs = WindowInputs[Column] + + +class SparkLikeExpr(LazyExpr["SparkLikeLazyFrame", "Column"]): + _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = { + "min": "rank", + "max": "rank", + "average": "rank", + "dense": "dense_rank", + "ordinal": "row_number", + } + + def __init__( + self, + call: EvalSeries[SparkLikeLazyFrame, Column], + window_function: SparkWindowFunction | None = None, + *, + evaluate_output_names: EvalNames[SparkLikeLazyFrame], + alias_output_names: AliasNames | None, + backend_version: tuple[int, ...], + version: Version, + implementation: Implementation, + ) -> None: + self._call = call + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names + self._backend_version = backend_version + self._version = version + self._implementation = implementation + self._metadata: ExprMetadata | None = None + self._window_function: SparkWindowFunction | None = window_function + + @property + def window_function(self) -> SparkWindowFunction: + def default_window_func( + df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs + ) -> list[Column]: + assert not window_inputs.order_by # noqa: S101 + return [ + 
expr.over(self.partition_by(*window_inputs.partition_by)) + for expr in self(df) + ] + + return self._window_function or default_window_func + + def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]: + return self._call(df) + + def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + if kind is ExprKind.LITERAL: + return self + return self.over([self._F.lit(1)], []) + + @property + def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202, N802 + if TYPE_CHECKING: + from sqlframe.base import functions + + return functions + else: + return import_functions(self._implementation) + + @property + def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202 + if TYPE_CHECKING: + from sqlframe.base import types + + return types + else: + return import_native_dtypes(self._implementation) + + @property + def _Window(self) -> type[Window]: # noqa: N802 + if TYPE_CHECKING: + from sqlframe.base.window import Window + + return Window + else: + return import_window(self._implementation) + + def _sort( + self, *cols: Column | str, descending: bool = False, nulls_last: bool = False + ) -> Iterator[Column]: + F = self._F # noqa: N806 + mapping = { + (False, False): F.asc_nulls_first, + (False, True): F.asc_nulls_last, + (True, False): F.desc_nulls_first, + (True, True): F.desc_nulls_last, + } + sort = mapping[(descending, nulls_last)] + yield from (sort(col) for col in cols) + + def partition_by(self, *cols: Column | str) -> WindowSpec: + """Wraps `Window().paritionBy`, with default and `WindowInputs` handling.""" + return self._Window.partitionBy(*cols or [self._F.lit(1)]) + + def __narwhals_expr__(self) -> None: ... 
+ + def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover + # Unused, just for compatibility with PandasLikeExpr + from narwhals._spark_like.namespace import SparkLikeNamespace + + return SparkLikeNamespace( + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def _with_window_function(self, window_function: SparkWindowFunction) -> Self: + return self.__class__( + self._call, + window_function, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + @classmethod + def _alias_native(cls, expr: Column, name: str) -> Column: + return expr.alias(name) + + def _cum_window_func( + self, + *, + reverse: bool, + func_name: Literal["sum", "max", "min", "count", "product"], + ) -> SparkWindowFunction: + def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]: + window = ( + self.partition_by(*inputs.partition_by) + .orderBy( + *self._sort(*inputs.order_by, descending=reverse, nulls_last=reverse) + ) + .rowsBetween(self._Window.unboundedPreceding, 0) + ) + return [ + getattr(self._F, func_name)(expr).over(window) for expr in self._call(df) + ] + + return func + + def _rolling_window_func( + self, + *, + func_name: Literal["sum", "mean", "std", "var"], + center: bool, + window_size: int, + min_samples: int, + ddof: int | None = None, + ) -> SparkWindowFunction: + supported_funcs = ["sum", "mean", "std", "var"] + if center: + half = (window_size - 1) // 2 + remainder = (window_size - 1) % 2 + start = self._Window.currentRow - half - remainder + end = self._Window.currentRow + half + else: + start = self._Window.currentRow - window_size + 1 + end = self._Window.currentRow + + def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]: + window = ( + self.partition_by(*inputs.partition_by) + 
.orderBy(*self._sort(*inputs.order_by)) + .rowsBetween(start, end) + ) + if func_name in {"sum", "mean"}: + func_: str = func_name + elif func_name == "var" and ddof == 0: + func_ = "var_pop" + elif func_name in "var" and ddof == 1: + func_ = "var_samp" + elif func_name == "std" and ddof == 0: + func_ = "stddev_pop" + elif func_name == "std" and ddof == 1: + func_ = "stddev_samp" + elif func_name in {"var", "std"}: # pragma: no cover + msg = f"Only ddof=0 and ddof=1 are currently supported for rolling_{func_name}." + raise ValueError(msg) + else: # pragma: no cover + msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}." + raise ValueError(msg) + return [ + self._F.when( + self._F.count(expr).over(window) >= min_samples, + getattr(self._F, func_)(expr).over(window), + ) + for expr in self._call(df) + ] + + return func + + @classmethod + def from_column_names( + cls: type[Self], + evaluate_column_names: EvalNames[SparkLikeLazyFrame], + /, + *, + context: _FullContext, + ) -> Self: + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [df._F.col(col_name) for col_name in evaluate_column_names(df)] + + return cls( + func, + evaluate_output_names=evaluate_column_names, + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, + ) + + @classmethod + def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: + def func(df: SparkLikeLazyFrame) -> list[Column]: + columns = df.columns + return [df._F.col(columns[i]) for i in column_indices] + + return cls( + func, + evaluate_output_names=cls._eval_names_indices(column_indices), + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, + ) + + def _callable_to_eval_series( + self, call: Callable[..., Column], /, **expressifiable_args: Self | Any + ) -> EvalSeries[SparkLikeLazyFrame, Column]: + 
def func(df: SparkLikeLazyFrame) -> list[Column]: + native_series_list = self(df) + other_native_series = { + key: df._evaluate_expr(value) + if self._is_expr(value) + else self._F.lit(value) + for key, value in expressifiable_args.items() + } + return [ + call(native_series, **other_native_series) + for native_series in native_series_list + ] + + return func + + def _push_down_window_function( + self, call: Callable[..., Column], /, **expressifiable_args: Self | Any + ) -> SparkWindowFunction: + def window_f( + df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs + ) -> Sequence[Column]: + # If a function `f` is elementwise, and `g` is another function, then + # - `f(g) over (window)` + # - `f(g over (window)) + # are equivalent. + # Make sure to only use with if `call` is elementwise! + native_series_list = self.window_function(df, window_inputs) + other_native_series = { + key: df._evaluate_window_expr(value, window_inputs) + if self._is_expr(value) + else self._F.lit(value) + for key, value in expressifiable_args.items() + } + return [ + call(native_series, **other_native_series) + for native_series in native_series_list + ] + + return window_f + + def _with_callable( + self, call: Callable[..., Column], /, **expressifiable_args: Self | Any + ) -> Self: + return self.__class__( + self._callable_to_eval_series(call, **expressifiable_args), + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def _with_elementwise( + self, call: Callable[..., Column], /, **expressifiable_args: Self | Any + ) -> Self: + return self.__class__( + self._callable_to_eval_series(call, **expressifiable_args), + self._push_down_window_function(call, **expressifiable_args), + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + 
version=self._version, + implementation=self._implementation, + ) + + def _with_binary(self, op: Callable[..., Column], other: Self | Any) -> Self: + return self.__class__( + self._callable_to_eval_series(op, other=other), + self._push_down_window_function(op, other=other), + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def _with_alias_output_names(self, func: AliasNames | None, /) -> Self: + return type(self)( + self._call, + self._window_function, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=func, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override] + return self._with_binary(lambda expr, other: expr.__eq__(other), other) + + def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override] + return self._with_binary(lambda expr, other: expr.__ne__(other), other) + + def __add__(self, other: SparkLikeExpr) -> Self: + return self._with_binary(lambda expr, other: expr.__add__(other), other) + + def __sub__(self, other: SparkLikeExpr) -> Self: + return self._with_binary(lambda expr, other: expr.__sub__(other), other) + + def __rsub__(self, other: SparkLikeExpr) -> Self: + return self._with_binary(lambda expr, other: other.__sub__(expr), other).alias( + "literal" + ) + + def __mul__(self, other: SparkLikeExpr) -> Self: + return self._with_binary(lambda expr, other: expr.__mul__(other), other) + + def __truediv__(self, other: SparkLikeExpr) -> Self: + return self._with_binary(lambda expr, other: expr.__truediv__(other), other) + + def __rtruediv__(self, other: SparkLikeExpr) -> Self: + return self._with_binary( + lambda expr, other: other.__truediv__(expr), other + ).alias("literal") + + def __floordiv__(self, other: SparkLikeExpr) -> 
Self:
        # Floor division: Spark has no `//` on columns, so divide then floor.
        def _floordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(expr / other)

        return self._with_binary(_floordiv, other)

    def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
        def _rfloordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(other / expr)

        # Reflected ops produce a column named "literal", matching polars.
        return self._with_binary(_rfloordiv, other).alias("literal")

    def __pow__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__pow__(other), other)

    def __rpow__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: other.__pow__(expr), other).alias(
            "literal"
        )

    def __mod__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__mod__(other), other)

    def __rmod__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: other.__mod__(expr), other).alias(
            "literal"
        )

    def __ge__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__ge__(other), other)

    def __gt__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr > other, other)

    def __le__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__le__(other), other)

    def __lt__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__lt__(other), other)

    def __and__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__and__(other), other)

    def __or__(self, other: SparkLikeExpr) -> Self:
        return self._with_binary(lambda expr, other: expr.__or__(other), other)

    def __invert__(self) -> Self:
        invert = cast("Callable[..., Column]", operator.invert)
        return self._with_elementwise(invert)

    def abs(self) -> Self:
        return self._with_elementwise(self._F.abs)

    def all(self) -> Self:
        # `bool_and` over an empty/all-null frame yields NULL; coalesce to True
        # to match polars' `all` on empty input.
        def f(expr: Column) -> Column:
            return self._F.coalesce(self._F.bool_and(expr), self._F.lit(True))  # noqa: FBT003

        def window_f(
            df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            return [
                self._F.coalesce(
                    self._F.bool_and(expr).over(
                        self.partition_by(*window_inputs.partition_by)
                    ),
                    self._F.lit(True),  # noqa: FBT003
                )
                for expr in self(df)
            ]

        return self._with_callable(f)._with_window_function(window_f)

    def any(self) -> Self:
        # Mirror image of `all`: empty/all-null input coalesces to False.
        def f(expr: Column) -> Column:
            return self._F.coalesce(self._F.bool_or(expr), self._F.lit(False))  # noqa: FBT003

        def window_f(
            df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            return [
                self._F.coalesce(
                    self._F.bool_or(expr).over(
                        self.partition_by(*window_inputs.partition_by)
                    ),
                    self._F.lit(False),  # noqa: FBT003
                )
                for expr in self(df)
            ]

        return self._with_callable(f)._with_window_function(window_f)

    def cast(self, dtype: IntoDType) -> Self:
        def _cast(expr: Column) -> Column:
            # Translate the narwhals dtype into the backend's native type.
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes
            )
            return expr.cast(spark_dtype)

        return self._with_elementwise(_cast)

    def count(self) -> Self:
        return self._with_callable(self._F.count)

    def max(self) -> Self:
        return self._with_callable(self._F.max)

    def mean(self) -> Self:
        return self._with_callable(self._F.mean)

    def median(self) -> Self:
        def _median(expr: Column) -> Column:
            # `F.median` only exists from PySpark 3.4; older versions fall back
            # to an approximate percentile.
            if (
                self._implementation
                in {Implementation.PYSPARK, Implementation.PYSPARK_CONNECT}
                and (pyspark := get_pyspark()) is not None
                and parse_version(pyspark) < (3, 4)
            ):  # pragma: no cover
                # Use percentile_approx with default accuracy parameter (10000)
                return self._F.percentile_approx(expr.cast("double"), 0.5)

            return self._F.median(expr)

        return self._with_callable(_median)

    def min(self) -> Self:
        return self._with_callable(self._F.min)

    def null_count(self) -> Self:
        def _null_count(expr: Column) -> Column:
            return self._F.count_if(self._F.isnull(expr))

        return self._with_callable(_null_count)

    def sum(self) -> Self:
        # Spark's SUM of an empty/all-null frame is NULL; coalesce to 0 to
        # match polars semantics.
        def f(expr: Column) -> Column:
            return self._F.coalesce(self._F.sum(expr), self._F.lit(0))

        def window_f(
            df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            return [
                self._F.coalesce(
                    self._F.sum(expr).over(
                        self.partition_by(*window_inputs.partition_by)
                    ),
                    self._F.lit(0),
                )
                for expr in self(df)
            ]

        return self._with_callable(f)._with_window_function(window_f)

    def std(self, ddof: int) -> Self:
        F = self._F  # noqa: N806
        if ddof == 0:
            return self._with_callable(F.stddev_pop)
        if ddof == 1:
            return self._with_callable(F.stddev_samp)

        # General ddof: rescale the sample std by sqrt((n-1)/(n-ddof)).
        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))

        return self._with_callable(func)

    def var(self, ddof: int) -> Self:
        F = self._F  # noqa: N806
        if ddof == 0:
            return self._with_callable(F.var_pop)
        if ddof == 1:
            return self._with_callable(F.var_samp)

        # General ddof: rescale the sample variance by (n-1)/(n-ddof).
        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)

        return self._with_callable(func)

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None = None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None = None,
    ) -> Self:
        def _clip_lower(expr: Column, lower_bound: Column) -> Column:
            result = expr
            return self._F.when(result < lower_bound, lower_bound).otherwise(result)

        def _clip_upper(expr: Column, upper_bound: Column) -> Column:
            result = expr
            return self._F.when(result > upper_bound, upper_bound).otherwise(result)

        def _clip_both(expr: Column, lower_bound: Column, upper_bound: Column) -> Column:
            return (
                self._F.when(expr < lower_bound, lower_bound)
                .when(expr > upper_bound, upper_bound)
                .otherwise(expr)
            )

        # Pick the variant matching which bounds were provided.
        if lower_bound is None:
            return self._with_elementwise(_clip_upper, upper_bound=upper_bound)
        if upper_bound is None:
            return self._with_elementwise(_clip_lower, lower_bound=lower_bound)
        return self._with_elementwise(
            _clip_both, lower_bound=lower_bound, upper_bound=upper_bound
        )

    def is_finite(self) -> Self:
        def _is_finite(expr: Column) -> Column:
            # A value is finite if it's not NaN, and not infinite, while NULLs should be
            # preserved
            is_finite_condition = (
                ~self._F.isnan(expr)
                & (expr != self._F.lit(float("inf")))
                & (expr != self._F.lit(float("-inf")))
            )
            return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
                None
            )

        return self._with_elementwise(_is_finite)

    def is_in(self, values: Sequence[Any]) -> Self:
        def _is_in(expr: Column) -> Column:
            # `isin` with an empty list is invalid in Spark; short-circuit to False.
            return expr.isin(values) if values else self._F.lit(False)  # noqa: FBT003

        return self._with_elementwise(_is_in)

    def is_unique(self) -> Self:
        # A value is unique iff its group (by value, plus any partition keys)
        # contains exactly one row.
        def _is_unique(expr: Column, *partition_by: str | Column) -> Column:
            return self._F.count("*").over(self.partition_by(expr, *partition_by)) == 1

        def _unpartitioned_is_unique(expr: Column) -> Column:
            return _is_unique(expr)

        def _partitioned_is_unique(
            df: SparkLikeLazyFrame, inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            assert not inputs.order_by  # noqa: S101
            return [_is_unique(expr, *inputs.partition_by) for expr in self(df)]

        return self._with_callable(_unpartitioned_is_unique)._with_window_function(
            _partitioned_is_unique
        )

    def len(self) -> Self:
        def _len(_expr: Column) -> Column:
            # Use count(*) to count all rows including nulls
            return self._F.count("*")

        return self._with_callable(_len)

    def round(self, decimals: int) -> Self:
        def _round(expr: Column) -> Column:
            return self._F.round(expr, decimals)

        return self._with_elementwise(_round)

    def skew(self) -> Self:
        return self._with_callable(self._F.skewness)

    def n_unique(self) -> Self:
        # count_distinct ignores NULLs; add 1 when any NULL is present.
        def _n_unique(expr: Column) -> Column:
            return self._F.count_distinct(expr) + self._F.max(
self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
            )

        return self._with_callable(_n_unique)

    def over(self, partition_by: Sequence[str | Column], order_by: Sequence[str]) -> Self:
        # Materialise the attached window function with concrete window inputs.
        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
            return self.window_function(df, WindowInputs(partition_by, order_by))

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def is_null(self) -> Self:
        return self._with_elementwise(self._F.isnull)

    def is_nan(self) -> Self:
        # NULL input stays NULL (rather than False) to match polars.
        def _is_nan(expr: Column) -> Column:
            return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))

        return self._with_elementwise(_is_nan)

    def shift(self, n: int) -> Self:
        def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
            window = self.partition_by(*inputs.partition_by).orderBy(
                *self._sort(*inputs.order_by)
            )
            return [self._F.lag(expr, n).over(window) for expr in self(df)]

        return self._with_window_function(func)

    def is_first_distinct(self) -> Self:
        # First occurrence of a value <=> row_number() == 1 within the
        # (partition keys, value) group, ordered ascending.
        def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
            return [
                self._F.row_number().over(
                    self.partition_by(*inputs.partition_by, expr).orderBy(
                        *self._sort(*inputs.order_by)
                    )
                )
                == 1
                for expr in self(df)
            ]

        return self._with_window_function(func)

    def is_last_distinct(self) -> Self:
        # Same as is_first_distinct but ordered descending.
        def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
            return [
                self._F.row_number().over(
                    self.partition_by(*inputs.partition_by, expr).orderBy(
                        *self._sort(*inputs.order_by, descending=True, nulls_last=True)
                    )
                )
                == 1
                for expr in self(df)
            ]

        return self._with_window_function(func)

    def diff(self) -> Self:
        def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
            window = self.partition_by(*inputs.partition_by).orderBy(
                *self._sort(*inputs.order_by)
            )
            return [expr - self._F.lag(expr).over(window) for expr in self(df)]

        return self._with_window_function(func)

    def cum_sum(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="sum")
        )

    def cum_max(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="max")
        )

    def cum_min(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="min")
        )

    def cum_count(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="count")
        )

    def cum_prod(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="product")
        )

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        if strategy is not None:

            # Forward fill looks back (last non-null up to current row);
            # backward fill looks ahead (first non-null from current row).
            # `limit` bounds how far the fill may reach.
            def _fill_with_strategy(
                df: SparkLikeLazyFrame, inputs: SparkWindowInputs
            ) -> Sequence[Column]:
                fn = self._F.last_value if strategy == "forward" else self._F.first_value
                if strategy == "forward":
                    start = self._Window.unboundedPreceding if limit is None else -limit
                    end = self._Window.currentRow
                else:
                    start = self._Window.currentRow
                    end = self._Window.unboundedFollowing if limit is None else limit
                return [
                    fn(expr, ignoreNulls=True).over(
                        self.partition_by(*inputs.partition_by)
                        .orderBy(*self._sort(*inputs.order_by))
                        .rowsBetween(start, end)
                    )
                    for expr in self(df)
                ]

            return self._with_window_function(_fill_with_strategy)

        def _fill_constant(expr: Column, value: Column) -> Column:
            return self._F.ifnull(expr, value)

        return self._with_elementwise(_fill_constant, value=value)

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="sum",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
            )
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="mean",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
            )
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="var",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
                ddof=ddof,
            )
        )

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="std",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
                ddof=ddof,
            )
        )

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        func_name = self._REMAP_RANK_METHOD[method]

        def _rank(
            expr: Column,
            *,
            descending: bool,
            partition_by: Sequence[str | Column] | None = None,
        ) -> Column:
            order_by = self._sort(expr, descending=descending, nulls_last=True)
            if partition_by is not None:
                window = self.partition_by(*partition_by).orderBy(*order_by)
                count_window = self.partition_by(*partition_by, expr)
            else:
                window = self.partition_by().orderBy(*order_by)
                count_window = self.partition_by(expr)
            # "max"/"average" adjust the base rank by the size of each tie
            # group (counted over `count_window`).
            if method == "max":
                rank_expr = (
                    getattr(self._F, func_name)().over(window)
                    + self._F.count(expr).over(count_window)
                    - self._F.lit(1)
                )

            elif method == "average":
                rank_expr = getattr(self._F, func_name)().over(window) + (
                    self._F.count(expr).over(count_window) - self._F.lit(1)
                ) / self._F.lit(2)

            else:
                rank_expr = getattr(self._F, func_name)().over(window)

            # NULL values get NULL rank.
            return self._F.when(expr.isNotNull(), rank_expr)

        def _unpartitioned_rank(expr: Column) -> Column:
            return _rank(expr, descending=descending)

        def _partitioned_rank(
            df: SparkLikeLazyFrame, inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            assert not inputs.order_by  # noqa: S101
            return [
                _rank(expr, descending=descending, partition_by=inputs.partition_by)
                for expr in self(df)
            ]

        return self._with_callable(_unpartitioned_rank)._with_window_function(
            _partitioned_rank
        )

    def log(self, base: float) -> Self:
        # Match polars: log of negative is NaN, log of zero is -inf.
        def _log(expr: Column) -> Column:
            return (
                self._F.when(expr < 0, self._F.lit(float("nan")))
                .when(expr == 0, self._F.lit(float("-inf")))
                .otherwise(self._F.log(float(base), expr))
            )

        return self._with_elementwise(_log)

    def exp(self) -> Self:
        def _exp(expr: Column) -> Column:
            return self._F.exp(expr)

        return self._with_elementwise(_exp)

    @property
    def str(self) -> SparkLikeExprStringNamespace:
        return SparkLikeExprStringNamespace(self)

    @property
    def dt(self) -> SparkLikeExprDateTimeNamespace:
        return SparkLikeExprDateTimeNamespace(self)

    @property
    def list(self) -> SparkLikeExprListNamespace:
        return SparkLikeExprListNamespace(self)

    @property
    def struct(self) -> SparkLikeExprStructNamespace:
        return SparkLikeExprStructNamespace(self)

    drop_nulls = not_implemented()
    unique = not_implemented()
    quantile = not_implemented()
diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_dt.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_dt.py
new file mode 100644
index 0000000..c5c76e3
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_dt.py
@@ -0,0 +1,193 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

from narwhals._duration import parse_interval_string
from narwhals._spark_like.utils import (
    UNITS_DICT,
    fetch_session_time_zone,
    strptime_to_pyspark_format,
)

if TYPE_CHECKING:
    from
sqlframe.base.column import Column + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.expr import SparkLikeExpr + + +class SparkLikeExprDateTimeNamespace: + def __init__(self, expr: SparkLikeExpr) -> None: + self._compliant_expr = expr + + def to_string(self, format: str) -> SparkLikeExpr: + F = self._compliant_expr._F # noqa: N806 + + def _to_string(_input: Column) -> Column: + # Handle special formats + if format == "%G-W%V": + return self._format_iso_week(_input) + if format == "%G-W%V-%u": + return self._format_iso_week_with_day(_input) + + format_, suffix = self._format_microseconds(_input, format) + + # Convert Python format to PySpark format + pyspark_fmt = strptime_to_pyspark_format(format_) + + result = F.date_format(_input, pyspark_fmt) + if "T" in format_: + # `strptime_to_pyspark_format` replaces "T" with " " since pyspark + # does not support the literal "T" in `date_format`. + # If no other spaces are in the given format, then we can revert this + # operation, otherwise we raise an exception. + if " " not in format_: + result = F.replace(result, F.lit(" "), F.lit("T")) + else: # pragma: no cover + msg = ( + "`dt.to_string` with a format that contains both spaces and " + " the literal 'T' is not supported for spark-like backends." 
+ ) + raise NotImplementedError(msg) + + return F.concat(result, *suffix) + + return self._compliant_expr._with_callable(_to_string) + + def date(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.to_date) + + def year(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.year) + + def month(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.month) + + def day(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.day) + + def hour(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.hour) + + def minute(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.minute) + + def second(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.second) + + def millisecond(self) -> SparkLikeExpr: + def _millisecond(expr: Column) -> Column: + return self._compliant_expr._F.floor( + (self._compliant_expr._F.unix_micros(expr) % 1_000_000) / 1000 + ) + + return self._compliant_expr._with_callable(_millisecond) + + def microsecond(self) -> SparkLikeExpr: + def _microsecond(expr: Column) -> Column: + return self._compliant_expr._F.unix_micros(expr) % 1_000_000 + + return self._compliant_expr._with_callable(_microsecond) + + def nanosecond(self) -> SparkLikeExpr: + def _nanosecond(expr: Column) -> Column: + return (self._compliant_expr._F.unix_micros(expr) % 1_000_000) * 1000 + + return self._compliant_expr._with_callable(_nanosecond) + + def ordinal_day(self) -> SparkLikeExpr: + return self._compliant_expr._with_callable(self._compliant_expr._F.dayofyear) + + def weekday(self) -> SparkLikeExpr: + def _weekday(expr: Column) -> Column: + # PySpark's dayofweek returns 1-7 for Sunday-Saturday + return (self._compliant_expr._F.dayofweek(expr) + 6) % 7 + + return self._compliant_expr._with_callable(_weekday) + + def 
truncate(self, every: str) -> SparkLikeExpr: + multiple, unit = parse_interval_string(every) + if multiple != 1: + msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}." + raise ValueError(msg) + if unit == "ns": + msg = "Truncating to nanoseconds is not yet supported for Spark-like." + raise NotImplementedError(msg) + format = UNITS_DICT[unit] + + def _truncate(expr: Column) -> Column: + return self._compliant_expr._F.date_trunc(format, expr) + + return self._compliant_expr._with_callable(_truncate) + + def _no_op_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover + def func(df: SparkLikeLazyFrame) -> Sequence[Column]: + native_series_list = self._compliant_expr(df) + conn_time_zone = fetch_session_time_zone(df.native.sparkSession) + if conn_time_zone != time_zone: + msg = ( + "PySpark stores the time zone in the session, rather than in the " + f"data type, so changing the timezone to anything other than {conn_time_zone} " + " (the current session time zone) is not supported." 
+ ) + raise NotImplementedError(msg) + return native_series_list + + return self._compliant_expr.__class__( + func, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=self._compliant_expr._alias_output_names, + backend_version=self._compliant_expr._backend_version, + version=self._compliant_expr._version, + implementation=self._compliant_expr._implementation, + ) + + def convert_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover + return self._no_op_time_zone(time_zone) + + def replace_time_zone( + self, time_zone: str | None + ) -> SparkLikeExpr: # pragma: no cover + if time_zone is None: + return self._compliant_expr._with_callable( + lambda _input: _input.cast("timestamp_ntz") + ) + else: + return self._no_op_time_zone(time_zone) + + def _format_iso_week_with_day(self, _input: Column) -> Column: + """Format datetime as ISO week string with day.""" + F = self._compliant_expr._F # noqa: N806 + + year = F.date_format(_input, "yyyy") + week = F.lpad(F.weekofyear(_input).cast("string"), 2, "0") + day = F.dayofweek(_input) + # Adjust Sunday from 1 to 7 + day = F.when(day == 1, 7).otherwise(day - 1) + return F.concat(year, F.lit("-W"), week, F.lit("-"), day.cast("string")) + + def _format_iso_week(self, _input: Column) -> Column: + """Format datetime as ISO week string.""" + F = self._compliant_expr._F # noqa: N806 + + year = F.date_format(_input, "yyyy") + week = F.lpad(F.weekofyear(_input).cast("string"), 2, "0") + return F.concat(year, F.lit("-W"), week) + + def _format_microseconds( + self, _input: Column, format: str + ) -> tuple[str, tuple[Column, ...]]: + """Format microseconds if present in format, else it's a no-op.""" + F = self._compliant_expr._F # noqa: N806 + + suffix: tuple[Column, ...] 
        if format.endswith((".%f", "%.f")):
            import re

            # Render microseconds (zero-padded to 6 digits) as a separate
            # suffix and strip the "%f" directive from the format string.
            micros = F.unix_micros(_input) % 1_000_000
            micros_str = F.lpad(micros.cast("string"), 6, "0")
            suffix = (F.lit("."), micros_str)
            format_ = re.sub(r"(.%|%.)f$", "", format)
            return format_, suffix

        return format, ()
diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_list.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_list.py
new file mode 100644
index 0000000..b59eb83
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_list.py
@@ -0,0 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprListNamespace:
    """`Expr.list` namespace for Spark-like backends."""

    def __init__(self, expr: SparkLikeExpr) -> None:
        self._compliant_expr = expr

    def len(self) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(self._compliant_expr._F.array_size)
diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_str.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_str.py
new file mode 100644
index 0000000..7c65952
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_str.py
@@ -0,0 +1,115 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

from narwhals._spark_like.utils import strptime_to_pyspark_format
from narwhals._utils import _is_naive_format

if TYPE_CHECKING:
    from sqlframe.base.column import Column

    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprStringNamespace:
    """`Expr.str` namespace for Spark-like backends."""

    def __init__(self, expr: SparkLikeExpr) -> None:
        self._compliant_expr = expr

    def len_chars(self) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(self._compliant_expr._F.char_length)

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> SparkLikeExpr:
        # `literal=True` does a plain substring replace; otherwise `pattern`
        # is treated as a regular expression.
        def func(expr: Column) -> Column:
            replace_all_func = (
                self._compliant_expr._F.replace
                if literal
                else self._compliant_expr._F.regexp_replace
            )
            return replace_all_func(
                expr,
                self._compliant_expr._F.lit(pattern),  # pyright: ignore[reportArgumentType]
                self._compliant_expr._F.lit(value),  # pyright: ignore[reportArgumentType]
            )

        return self._compliant_expr._with_callable(func)

    def strip_chars(self, characters: str | None) -> SparkLikeExpr:
        import string

        # Default (None) strips ASCII whitespace, matching `str.strip()`.
        def func(expr: Column) -> Column:
            to_remove = characters if characters is not None else string.whitespace
            return self._compliant_expr._F.btrim(
                expr, self._compliant_expr._F.lit(to_remove)
            )

        return self._compliant_expr._with_callable(func)

    def starts_with(self, prefix: str) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(
            lambda expr: self._compliant_expr._F.startswith(
                expr, self._compliant_expr._F.lit(prefix)
            )
        )

    def ends_with(self, suffix: str) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(
            lambda expr: self._compliant_expr._F.endswith(
                expr, self._compliant_expr._F.lit(suffix)
            )
        )

    def contains(self, pattern: str, *, literal: bool) -> SparkLikeExpr:
        def func(expr: Column) -> Column:
            contains_func = (
                self._compliant_expr._F.contains
                if literal
                else self._compliant_expr._F.regexp
            )
            return contains_func(expr, self._compliant_expr._F.lit(pattern))

        return self._compliant_expr._with_callable(func)

    def slice(self, offset: int, length: int | None) -> SparkLikeExpr:
        # From the docs: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.substring.html
        # The position is not zero based, but 1 based index.
        def func(expr: Column) -> Column:
            col_length = self._compliant_expr._F.char_length(expr)

            # Negative offsets count from the end of the string.
            _offset = (
                col_length + self._compliant_expr._F.lit(offset + 1)
                if offset < 0
                else self._compliant_expr._F.lit(offset + 1)
            )
            _length = (
                self._compliant_expr._F.lit(length) if length is not None else col_length
            )
            return expr.substr(_offset, _length)

        return self._compliant_expr._with_callable(func)

    def split(self, by: str) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(
            lambda expr: self._compliant_expr._F.split(expr, by)
        )

    def to_uppercase(self) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(self._compliant_expr._F.upper)

    def to_lowercase(self) -> SparkLikeExpr:
        return self._compliant_expr._with_callable(self._compliant_expr._F.lower)

    def to_datetime(self, format: str | None) -> SparkLikeExpr:
        F = self._compliant_expr._F  # noqa: N806
        # Naive formats go through `to_timestamp_ntz`; zone-aware ones through
        # `to_timestamp`. Literal "T" separators are normalised to spaces first.
        if not format:
            function = F.to_timestamp
        elif _is_naive_format(format):
            function = partial(
                F.to_timestamp_ntz, format=F.lit(strptime_to_pyspark_format(format))
            )
        else:
            format = strptime_to_pyspark_format(format)
            function = partial(F.to_timestamp, format=format)
        return self._compliant_expr._with_callable(
            lambda expr: function(F.replace(expr, F.lit("T"), F.lit(" ")))
        )
diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_struct.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_struct.py
new file mode 100644
index 0000000..03e6d71
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/expr_struct.py
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sqlframe.base.column import Column

    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprStructNamespace:
    """`Expr.struct` namespace for Spark-like backends."""

    def __init__(self, expr: SparkLikeExpr) -> None:
        self._compliant_expr = expr

    def field(self, name: str) -> SparkLikeExpr:
        def func(expr:
Column) -> Column: + return expr.getField(name) + + return self._compliant_expr._with_callable(func).alias(name) diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/group_by.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/group_by.py new file mode 100644 index 0000000..4c63d77 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/group_by.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from narwhals._compliant import LazyGroupBy + +if TYPE_CHECKING: + from sqlframe.base.column import Column # noqa: F401 + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.expr import SparkLikeExpr + + +class SparkLikeLazyGroupBy(LazyGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]): + def __init__( + self, + df: SparkLikeLazyFrame, + keys: Sequence[SparkLikeExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, + ) -> None: + frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys) + self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame + + def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: + result = ( + self.compliant.native.groupBy(*self._keys).agg(*agg_columns) + if (agg_columns := list(self._evaluate_exprs(exprs))) + else self.compliant.native.select(*self._keys).dropDuplicates() + ) + + return self.compliant._with_native(result).rename( + dict(zip(self._keys, self._output_key_names)) + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/namespace.py new file mode 100644 index 0000000..7ad42ff --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/namespace.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import operator +from functools import reduce +from typing import TYPE_CHECKING, Callable, Iterable, Sequence + +from narwhals._compliant import 
from narwhals._compliant import LazyNamespace, LazyThen, LazyWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.selectors import SparkLikeSelectorNamespace
from narwhals._spark_like.utils import (
    import_functions,
    import_native_dtypes,
    narwhals_to_native_dtype,
)

if TYPE_CHECKING:
    from sqlframe.base.column import Column

    from narwhals._spark_like.dataframe import SQLFrameDataFrame  # noqa: F401
    from narwhals._spark_like.expr import SparkWindowInputs
    from narwhals._utils import Implementation, Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class SparkLikeNamespace(
    LazyNamespace[SparkLikeLazyFrame, SparkLikeExpr, "SQLFrameDataFrame"]
):
    """Top-level (`nw.*`) namespace for spark-like backends."""

    def __init__(
        self,
        *,
        backend_version: tuple[int, ...],
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._backend_version = backend_version
        self._version = version
        self._implementation = implementation

    @property
    def selectors(self) -> SparkLikeSelectorNamespace:
        """Column selectors (`nw.selectors`)."""
        return SparkLikeSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[SparkLikeExpr]:
        return SparkLikeExpr

    @property
    def _lazyframe(self) -> type[SparkLikeLazyFrame]:
        return SparkLikeLazyFrame

    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
        # Resolve the backend's `functions` module lazily per implementation;
        # under TYPE_CHECKING pretend it is sqlframe's so the file checks
        # against a single API surface.
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        # Same lazy resolution as `_F`, for the native dtype classes.
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        return import_native_dtypes(self._implementation)
SparkLikeLazyFrame) -> list[Column]: + cols = (col for _expr in exprs for col in _expr(df)) + return [func(cols)] + + def window_function( + df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs + ) -> list[Column]: + cols = ( + col for _expr in exprs for col in _expr.window_function(df, window_inputs) + ) + return [func(cols)] + + return self._expr( + call=call, + window_function=window_function, + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr: + def _lit(df: SparkLikeLazyFrame) -> list[Column]: + column = df._F.lit(value) + if dtype: + native_dtype = narwhals_to_native_dtype( + dtype, version=self._version, spark_types=df._native_dtypes + ) + column = column.cast(native_dtype) + + return [column] + + return self._expr( + call=_lit, + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def len(self) -> SparkLikeExpr: + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [df._F.count("*")] + + return self._expr( + func, + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def all_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr: + def func(cols: Iterable[Column]) -> Column: + return reduce(operator.and_, cols) + + return self._with_elementwise(func, *exprs) + + def any_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr: + def func(cols: Iterable[Column]) -> Column: + return reduce(operator.or_, cols) + + return self._with_elementwise(func, *exprs) + + def max_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr: + 
def func(cols: Iterable[Column]) -> Column: + return self._F.greatest(*cols) + + return self._with_elementwise(func, *exprs) + + def min_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr: + def func(cols: Iterable[Column]) -> Column: + return self._F.least(*cols) + + return self._with_elementwise(func, *exprs) + + def sum_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr: + def func(cols: Iterable[Column]) -> Column: + return reduce( + operator.add, (self._F.coalesce(col, self._F.lit(0)) for col in cols) + ) + + return self._with_elementwise(func, *exprs) + + def mean_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr: + def func(cols: Iterable[Column]) -> Column: + cols = list(cols) + F = exprs[0]._F # noqa: N806 + # PySpark before 3.5 doesn't have `try_divide`, SQLFrame doesn't have it. + divide = getattr(F, "try_divide", operator.truediv) + return divide( + reduce( + operator.add, (self._F.coalesce(col, self._F.lit(0)) for col in cols) + ), + reduce( + operator.add, + ( + col.isNotNull().cast(self._native_dtypes.IntegerType()) + for col in cols + ), + ), + ) + + return self._with_elementwise(func, *exprs) + + def concat( + self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod + ) -> SparkLikeLazyFrame: + dfs = [item._native_frame for item in items] + if how == "vertical": + cols_0 = dfs[0].columns + for i, df in enumerate(dfs[1:], start=1): + cols_current = df.columns + if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)): + msg = ( + "unable to vstack, column names don't match:\n" + f" - dataframe 0: {cols_0}\n" + f" - dataframe {i}: {cols_current}\n" + ) + raise TypeError(msg) + + return SparkLikeLazyFrame( + native_dataframe=reduce(lambda x, y: x.union(y), dfs), + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + if how == "diagonal": + return SparkLikeLazyFrame( + native_dataframe=reduce( + lambda x, y: x.unionByName(y, allowMissingColumns=True), 
dfs + ), + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + raise NotImplementedError + + def concat_str( + self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool + ) -> SparkLikeExpr: + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [s for _expr in exprs for s in _expr(df)] + cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols] + null_mask = [df._F.isnull(s) for s in cols] + + if not ignore_nulls: + null_mask_result = reduce(operator.or_, null_mask) + result = df._F.when( + ~null_mask_result, + reduce( + lambda x, y: df._F.format_string(f"%s{separator}%s", x, y), + cols_casted, + ), + ).otherwise(df._F.lit(None)) + else: + init_value, *values = [ + df._F.when(~nm, col).otherwise(df._F.lit("")) + for col, nm in zip(cols_casted, null_mask) + ] + + separators = ( + df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator)) + for nm in null_mask[:-1] + ) + result = reduce( + lambda x, y: df._F.format_string("%s%s", x, y), + ( + df._F.format_string("%s%s", s, v) + for s, v in zip(separators, values) + ), + init_value, + ) + + return [result] + + return self._expr( + call=func, + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + + def when(self, predicate: SparkLikeExpr) -> SparkLikeWhen: + return SparkLikeWhen.from_expr(predicate, context=self) + + +class SparkLikeWhen(LazyWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]): + @property + def _then(self) -> type[SparkLikeThen]: + return SparkLikeThen + + def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]: + self.when = df._F.when + self.lit = df._F.lit + return super().__call__(df) + + def _window_function( + self, df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs + ) -> Sequence[Column]: + self.when = df._F.when + self.lit = 
df._F.lit + return super()._window_function(df, window_inputs) + + +class SparkLikeThen( + LazyThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr +): ... diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/selectors.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/selectors.py new file mode 100644 index 0000000..013bb9d --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/selectors.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._compliant import CompliantSelector, LazySelectorNamespace +from narwhals._spark_like.expr import SparkLikeExpr + +if TYPE_CHECKING: + from sqlframe.base.column import Column # noqa: F401 + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame # noqa: F401 + + +class SparkLikeSelectorNamespace(LazySelectorNamespace["SparkLikeLazyFrame", "Column"]): + @property + def _selector(self) -> type[SparkLikeSelector]: + return SparkLikeSelector + + +class SparkLikeSelector(CompliantSelector["SparkLikeLazyFrame", "Column"], SparkLikeExpr): # type: ignore[misc] + def _to_expr(self) -> SparkLikeExpr: + return SparkLikeExpr( + self._call, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) diff --git a/venv/lib/python3.8/site-packages/narwhals/_spark_like/utils.py b/venv/lib/python3.8/site-packages/narwhals/_spark_like/utils.py new file mode 100644 index 0000000..95fdc96 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_spark_like/utils.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from functools import lru_cache +from importlib import import_module +from typing import TYPE_CHECKING, Any, overload + +from narwhals._utils import Implementation, isinstance_or_issubclass +from narwhals.exceptions import UnsupportedDTypeError + +if TYPE_CHECKING: + 
if TYPE_CHECKING:
    from types import ModuleType

    import sqlframe.base.types as sqlframe_types
    from sqlframe.base.column import Column
    from sqlframe.base.session import _BaseSession as Session
    from typing_extensions import TypeAlias

    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals._utils import Version
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

    _NativeDType: TypeAlias = sqlframe_types.DataType
    SparkSession = Session[Any, Any, Any, Any, Any, Any, Any]

# Narwhals duration-unit suffix -> Spark date/interval unit name.
UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

# Python strptime directive -> Spark datetime pattern letters.
# see https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
# and https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior
DATETIME_PATTERNS_MAPPING = {
    "%Y": "yyyy",  # Year with century (4 digits)
    "%y": "yy",  # Year without century (2 digits)
    "%m": "MM",  # Month (01-12)
    "%d": "dd",  # Day of the month (01-31)
    "%H": "HH",  # Hour (24-hour clock) (00-23)
    "%I": "hh",  # Hour (12-hour clock) (01-12)
    "%M": "mm",  # Minute (00-59)
    "%S": "ss",  # Second (00-59)
    "%f": "S",  # Microseconds -> Milliseconds (deliberate precision downgrade)
    "%p": "a",  # AM/PM
    "%a": "E",  # Abbreviated weekday name (1-3 pattern letters = short text form)
    # Fix: was "E" (short form). In Spark's pattern rules, 4+ text letters
    # select the full text form, matching Python's %A (full weekday name).
    "%A": "EEEE",
    "%j": "D",  # Day of the year
    "%z": "Z",  # Timezone offset
    "%s": "X",  # Unix timestamp
}


# NOTE: don't lru_cache this as `ModuleType` isn't hashable
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: _NativeDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> DType:
    """Map a native Spark/SQLFrame dtype instance to a narwhals dtype."""
    dtypes = version.dtypes
    # Under TYPE_CHECKING alias the module so one API surface type-checks.
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types

    if isinstance(dtype, native.DoubleType):
        return dtypes.Float64()
    if isinstance(dtype, native.FloatType):
        return dtypes.Float32()
    if isinstance(dtype, native.LongType):
        return dtypes.Int64()
    if isinstance(dtype, native.IntegerType):
        return dtypes.Int32()
    if isinstance(dtype, native.ShortType):
        return dtypes.Int16()
    if isinstance(dtype, native.ByteType):
        return dtypes.Int8()
    if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
        return dtypes.String()
    if isinstance(dtype, native.BooleanType):
        return dtypes.Boolean()
    if isinstance(dtype, native.DateType):
        return dtypes.Date()
    if isinstance(dtype, native.TimestampNTZType):
        # TODO(marco): cover this
        return dtypes.Datetime()  # pragma: no cover
    if isinstance(dtype, native.TimestampType):
        # Zoned timestamps carry the session time zone.
        return dtypes.Datetime(time_zone=fetch_session_time_zone(session))
    if isinstance(dtype, native.DecimalType):
        # TODO(marco): cover this
        return dtypes.Decimal()  # pragma: no cover
    if isinstance(dtype, native.ArrayType):
        return dtypes.List(
            inner=native_to_narwhals_dtype(
                dtype.elementType, version, spark_types, session
            )
        )
    if isinstance(dtype, native.StructType):
        return dtypes.Struct(
            fields=[
                dtypes.Field(
                    name=field.name,
                    dtype=native_to_narwhals_dtype(
                        field.dataType, version, spark_types, session
                    ),
                )
                for field in dtype
            ]
        )
    if isinstance(dtype, native.BinaryType):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover


@lru_cache(maxsize=4)
def fetch_session_time_zone(session: SparkSession) -> str:
    """Return the session's time zone (cacheable: it can't change within a
    PySpark session)."""
    try:
        return session.conf.get("spark.sql.session.timeZone")  # type: ignore[attr-defined]
    except Exception:  # noqa: BLE001
        # https://github.com/eakmanrq/sqlframe/issues/406
        return ""


def narwhals_to_native_dtype(  # noqa: C901, PLR0912
    dtype: IntoDType, version: Version, spark_types: ModuleType
) -> _NativeDType:
    """Map a narwhals dtype (instance or class) to a native Spark dtype.

    Raises:
        UnsupportedDTypeError: for unsigned ints, Enum, Categorical, Time.
        ValueError: for zoned Datetime with a non-UTC time zone.
    """
    dtypes = version.dtypes
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types

    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return native.DoubleType()
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return native.FloatType()
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return native.LongType()
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return native.IntegerType()
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return native.ShortType()
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return native.ByteType()
    if isinstance_or_issubclass(dtype, dtypes.String):
        return native.StringType()
    if isinstance_or_issubclass(dtype, dtypes.Boolean):
        return native.BooleanType()
    if isinstance_or_issubclass(dtype, dtypes.Date):
        return native.DateType()
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        dt_time_zone = dtype.time_zone
        if dt_time_zone is None:
            return native.TimestampNTZType()
        if dt_time_zone != "UTC":  # pragma: no cover
            msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
            raise ValueError(msg)
        return native.TimestampType()
    if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
        return native.ArrayType(
            elementType=narwhals_to_native_dtype(
                dtype.inner, version=version, spark_types=native
            )
        )
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        return native.StructType(
            fields=[
                native.StructField(
                    name=field.name,
                    dataType=narwhals_to_native_dtype(
                        field.dtype, version=version, spark_types=native
                    ),
                )
                for field in dtype.fields
            ]
        )
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return native.BinaryType()

    if isinstance_or_issubclass(
        dtype,
        (
            dtypes.UInt64,
            dtypes.UInt32,
            dtypes.UInt16,
            dtypes.UInt8,
            dtypes.Enum,
            dtypes.Categorical,
            dtypes.Time,
        ),
    ):  # pragma: no cover
        msg = "Unsigned integer, Enum, Categorical and Time types are not supported by spark-like backend"
        raise UnsupportedDTypeError(msg)

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def evaluate_exprs(
    df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
) -> list[tuple[str, Column]]:
    """Evaluate `exprs` against `df`, pairing each output name with its column."""
    native_results: list[tuple[str, Column]] = []

    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))

    return native_results


def import_functions(implementation: Implementation, /) -> ModuleType:
    """Return the `functions` module of the concrete backend."""
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import functions

        return functions
    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect import functions

        return functions
    # SQLFrame: the module depends on the session's execution dialect.
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.functions")


def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
    """Return the `types` module of the concrete backend."""
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import types

        return types
    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect import types

        return types
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.types")
def import_window(implementation: Implementation, /) -> type[Any]:
    """Return the `Window` class of the concrete backend."""
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import Window

        return Window

    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect.window import Window

        return Window
    # SQLFrame: the module depends on the session's execution dialect.
    from sqlframe.base.session import _BaseSession

    dialect = _BaseSession().execution_dialect_name
    return import_module(f"sqlframe.{dialect}.window").Window


@overload
def strptime_to_pyspark_format(format: None) -> None: ...


@overload
def strptime_to_pyspark_format(format: str) -> str: ...


def strptime_to_pyspark_format(format: str | None) -> str | None:
    """Converts a Python strptime datetime format string to a PySpark datetime format string.

    `None` passes through unchanged; otherwise each known ``%`` directive is
    substituted via `DATETIME_PATTERNS_MAPPING`, and any literal "T"
    date/time separator becomes a space.
    """
    if format is None:  # pragma: no cover
        return None

    pattern = format
    for strptime_directive, spark_pattern in DATETIME_PATTERNS_MAPPING.items():
        pattern = pattern.replace(strptime_directive, spark_pattern)
    return pattern.replace("T", " ")