diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_duckdb/namespace.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_duckdb/namespace.py | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_duckdb/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_duckdb/namespace.py new file mode 100644 index 0000000..3616f2e --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_duckdb/namespace.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Callable, Iterable, Sequence + +import duckdb +from duckdb import CoalesceOperator, Expression, FunctionExpression +from duckdb.typing import BIGINT, VARCHAR + +from narwhals._compliant import LazyNamespace, LazyThen, LazyWhen +from narwhals._duckdb.dataframe import DuckDBLazyFrame +from narwhals._duckdb.expr import DuckDBExpr +from narwhals._duckdb.selectors import DuckDBSelectorNamespace +from narwhals._duckdb.utils import concat_str, lit, narwhals_to_native_dtype, when +from narwhals._expression_parsing import ( + combine_alias_output_names, + combine_evaluate_output_names, +) +from narwhals._utils import Implementation + +if TYPE_CHECKING: + from narwhals._duckdb.expr import DuckDBWindowInputs + from narwhals._utils import Version + from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral + + +class DuckDBNamespace( + LazyNamespace[DuckDBLazyFrame, DuckDBExpr, duckdb.DuckDBPyRelation] +): + _implementation: Implementation = Implementation.DUCKDB + + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._version = version + + @property + def selectors(self) -> DuckDBSelectorNamespace: + return DuckDBSelectorNamespace.from_namespace(self) + + @property + def _expr(self) -> type[DuckDBExpr]: + return DuckDBExpr + + @property + def _lazyframe(self) -> type[DuckDBLazyFrame]: + return DuckDBLazyFrame + + def _with_elementwise( + self, func: Callable[[Iterable[Expression]], Expression], *exprs: DuckDBExpr + ) -> DuckDBExpr: + def call(df: DuckDBLazyFrame) -> list[Expression]: + cols = (col for _expr in exprs for col in _expr(df)) + return [func(cols)] + + def window_function( + df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs + ) -> list[Expression]: + cols = ( + col for _expr in exprs for col in _expr.window_function(df, window_inputs) + ) + return [func(cols)] + + return self._expr( + call=call, + window_function=window_function, + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + backend_version=self._backend_version, + version=self._version, + ) + + def concat( + self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod + ) -> DuckDBLazyFrame: + native_items = [item._native_frame for item in items] + items = list(items) + first = items[0] + schema = first.schema + if how == "vertical" and not all(x.schema == schema for x in items[1:]): + msg = "inputs should all have the same schema" + raise TypeError(msg) + res = reduce(lambda x, y: x.union(y), native_items) + return first._with_native(res) + + def concat_str( + self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool + ) -> DuckDBExpr: + def func(df: DuckDBLazyFrame) -> list[Expression]: + cols = list(chain.from_iterable(expr(df) for expr in exprs)) + if not ignore_nulls: + null_mask_result = reduce(operator.or_, (s.isnull() for s in cols)) + cols_separated = [ + y + for x in [ + (col.cast(VARCHAR),) + if i == len(cols) - 1 + else (col.cast(VARCHAR), lit(separator)) + for i, col in enumerate(cols) + ] + for y in x + ] + return [when(~null_mask_result, concat_str(*cols_separated))] + else: + return [concat_str(*cols, separator=separator)] + + return self._expr( + call=func, + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + backend_version=self._backend_version, + version=self._version, + ) + + def all_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: + def func(cols: Iterable[Expression]) -> Expression: + return reduce(operator.and_, cols) + + return self._with_elementwise(func, *exprs) + + def any_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: + def func(cols: Iterable[Expression]) -> Expression: + return reduce(operator.or_, cols) + + return self._with_elementwise(func, *exprs) + + def max_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: + def func(cols: Iterable[Expression]) -> Expression: + return FunctionExpression("greatest", *cols) + + return self._with_elementwise(func, *exprs) + + def min_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: + def func(cols: Iterable[Expression]) -> Expression: + return FunctionExpression("least", *cols) + + return self._with_elementwise(func, *exprs) + + def sum_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: + def func(cols: Iterable[Expression]) -> Expression: + return reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols)) + + return self._with_elementwise(func, *exprs) + + def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr: + def func(cols: Iterable[Expression]) -> Expression: + cols = list(cols) + return reduce( + operator.add, (CoalesceOperator(col, lit(0)) for col in cols) + ) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols)) + + return self._with_elementwise(func, *exprs) + + def when(self, predicate: DuckDBExpr) -> DuckDBWhen: + return DuckDBWhen.from_expr(predicate, context=self) + + def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr: + def func(_df: DuckDBLazyFrame) -> list[Expression]: + if dtype is not None: + return [ + lit(value).cast( + narwhals_to_native_dtype(dtype, version=self._version) # type: ignore[arg-type] + ) + ] + return [lit(value)] + + return self._expr( + func, + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + + def len(self) -> DuckDBExpr: + def func(_df: DuckDBLazyFrame) -> list[Expression]: + return [FunctionExpression("count")] + + return self._expr( + call=func, + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + + +class DuckDBWhen(LazyWhen["DuckDBLazyFrame", Expression, DuckDBExpr]): + @property + def _then(self) -> type[DuckDBThen]: + return DuckDBThen + + def __call__(self, df: DuckDBLazyFrame) -> Sequence[Expression]: + self.when = when + self.lit = lit + return super().__call__(df) + + def _window_function( + self, df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs + ) -> Sequence[Expression]: + self.when = when + self.lit = lit + return super()._window_function(df, window_inputs) + + +class DuckDBThen(LazyThen["DuckDBLazyFrame", Expression, DuckDBExpr], DuckDBExpr): ... |