from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any

import duckdb

from narwhals._utils import Version, isinstance_or_issubclass

if TYPE_CHECKING:
    from duckdb import DuckDBPyRelation, Expression
    from duckdb.typing import DuckDBPyType

    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

# Maps short time-unit abbreviations to the spelled-out unit names.
UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""
lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""
when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""


def concat_str(*exprs: Expression, separator: str = "") -> Expression:
    """Concatenate many strings, NULL inputs are skipped.

    Wraps [concat] and [concat_ws] `FunctionExpression`(s).

    Arguments:
        exprs: Native columns.
        separator: String that will be used to separate the values of each column.

    Returns:
        A new native expression.

    [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
    [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
    """
    # An empty separator is routed to plain `concat`; anything else to `concat_ws`.
    return (
        duckdb.FunctionExpression("concat_ws", lit(separator), *exprs)
        if separator
        else duckdb.FunctionExpression("concat", *exprs)
    )


def evaluate_exprs(
    df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, Expression]]:
    """Evaluate expressions against `df`, pairing each result with its output name.

    Arguments:
        df: Frame that every expression is called with.
        exprs: Expressions to evaluate.

    Returns:
        A flat list of `(output_name, native_expression)` pairs, in the order
        the expressions were given.

    Raises:
        AssertionError: If an expression yields a different number of output
            names than native results.
    """
    native_results: list[tuple[str, Expression]] = []
    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        # Apply the aliasing callable (if any) on top of the evaluated names.
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))
    return native_results


class DeferredTimeZone:
    """Object which gets passed between `native_to_narwhals_dtype` calls.

    DuckDB stores the time zone in the connection, rather than in the dtypes, so
    this ensures that when calculating the schema of a dataframe with multiple
    timezone-aware columns, that the connection's time zone is only fetched once.

    Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
    the time zone can be modified after `DuckDBLazyFrame` creation:

    ```python
    df = nw.from_native(rel)
    print(df.collect_schema())
    rel.query("set timezone = 'Asia/Kolkata'")
    print(df.collect_schema())  # should change to reflect new time zone
    ```
    """

    # Class-level default; the first `time_zone` access stores the fetched
    # value on the instance.
    _cached_time_zone: str | None = None

    def __init__(self, rel: DuckDBPyRelation) -> None:
        self._rel = rel

    @property
    def time_zone(self) -> str:
        """Fetch relation time zone (if it wasn't calculated already)."""
        if self._cached_time_zone is None:
            self._cached_time_zone = fetch_rel_time_zone(self._rel)
        return self._cached_time_zone


def native_to_narwhals_dtype(
    duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DType:
    """Convert a DuckDB type into the corresponding narwhals dtype.

    Arguments:
        duckdb_dtype: Native DuckDB type to convert.
        version: Narwhals API version whose `dtypes` namespace is used.
        deferred_time_zone: Lazily resolved time zone, only consulted for
            `"timestamp with time zone"` columns.

    Returns:
        The narwhals dtype. Type ids not handled here are delegated to
        `_non_nested_native_to_narwhals_dtype`.
    """
    duckdb_dtype_id = duckdb_dtype.id
    dtypes = version.dtypes

    # Handle nested data types first
    if duckdb_dtype_id == "list":
        return dtypes.List(
            native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
        )

    if duckdb_dtype_id == "struct":
        children = duckdb_dtype.children
        return dtypes.Struct(
            [
                dtypes.Field(
                    name=child[0],
                    dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
                )
                for child in children
            ]
        )

    if duckdb_dtype_id == "array":
        # `children` unpacks into two pairs; index 1 of each pair holds the
        # child type and the size respectively.
        child, size = duckdb_dtype.children
        shape: list[int] = [size[1]]

        # Descend through directly nested arrays; each deeper level's size is
        # inserted at the front of `shape`.
        while child[1].id == "array":
            child, size = child[1].children
            shape.insert(0, size[1])

        inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
        return dtypes.Array(inner=inner, shape=tuple(shape))

    if duckdb_dtype_id == "enum":
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = duckdb_dtype.children[0][1]
        return dtypes.Enum(categories=categories)

    if duckdb_dtype_id == "timestamp with time zone":
        return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)

    return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)


def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
    """Return the value of the `TimeZone` setting as seen through `rel`.

    Arguments:
        rel: Relation used to run the settings query.

    Returns:
        The first column of the single row fetched from `duckdb_settings()`.
    """
    result = rel.query(
        "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
    ).fetchone()
    assert result is not None  # noqa: S101
    return result[0]  # type: ignore[no-any-return]


@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
    """Look up the narwhals dtype for a non-nested DuckDB type id.

    Arguments:
        duckdb_dtype_id: Lower-case DuckDB type id, e.g. `"bigint"`.
        version: Narwhals API version whose `dtypes` namespace is used.

    Returns:
        The matching dtype instance, or `dtypes.Unknown()` when the id is not
        in the table.
    """
    dtypes = version.dtypes
    return {
        "hugeint": dtypes.Int128(),
        "bigint": dtypes.Int64(),
        "integer": dtypes.Int32(),
        "smallint": dtypes.Int16(),
        "tinyint": dtypes.Int8(),
        "uhugeint": dtypes.UInt128(),
        "ubigint": dtypes.UInt64(),
        "uinteger": dtypes.UInt32(),
        "usmallint": dtypes.UInt16(),
        "utinyint": dtypes.UInt8(),
        "double": dtypes.Float64(),
        "float": dtypes.Float32(),
        "varchar": dtypes.String(),
        "date": dtypes.Date(),
        "timestamp": dtypes.Datetime(),
        "boolean": dtypes.Boolean(),
        "interval": dtypes.Duration(),
        "decimal": dtypes.Decimal(),
        "time": dtypes.Time(),
        "blob": dtypes.Binary(),
    }.get(duckdb_dtype_id, dtypes.Unknown())


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> str:  # noqa: C901, PLR0912, PLR0915
    """Render a narwhals dtype as a DuckDB type string.

    Arguments:
        dtype: Narwhals dtype (class or instance) to convert.
        version: Narwhals API version whose `dtypes` namespace is used.

    Returns:
        The DuckDB type spelled as a string, e.g. `"BIGINT"` or `"VARCHAR[]"`.

    Raises:
        NotImplementedError: For Decimal, Categorical, Datetime and Duration,
            and for Enum under `Version.V1`.
        ValueError: If an Enum is passed as a class, i.e. without categories.
        AssertionError: If the dtype matches none of the known dtypes.
    """
    dtypes = version.dtypes
    if isinstance_or_issubclass(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return "DOUBLE"
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return "FLOAT"
    if isinstance_or_issubclass(dtype, dtypes.Int128):
        return "INT128"
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return "BIGINT"
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return "INTEGER"
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return "SMALLINT"
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return "TINYINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt128):
        return "UINT128"
    if isinstance_or_issubclass(dtype, dtypes.UInt64):
        return "UBIGINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt32):
        return "UINTEGER"
    if isinstance_or_issubclass(dtype, dtypes.UInt16):  # pragma: no cover
        return "USMALLINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt8):  # pragma: no cover
        return "UTINYINT"
    if isinstance_or_issubclass(dtype, dtypes.String):
        return "VARCHAR"
    if isinstance_or_issubclass(dtype, dtypes.Boolean):  # pragma: no cover
        return "BOOLEAN"
    if isinstance_or_issubclass(dtype, dtypes.Time):
        return "TIME"
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return "BLOB"
    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        msg = "Categorical not supported by DuckDB"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            # Double embedded single quotes so a category containing `'`
            # cannot terminate its SQL string literal early.
            categories = ", ".join(
                "'" + category.replace("'", "''") + "'" for category in dtype.categories
            )
            return f"ENUM ({categories})"
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)

    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        # Attributes are read but not used yet: the conversion is unimplemented.
        _time_unit = dtype.time_unit
        _time_zone = dtype.time_zone
        msg = "todo"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Duration):  # pragma: no cover
        _time_unit = dtype.time_unit
        msg = "todo"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Date):  # pragma: no cover
        return "DATE"
    if isinstance_or_issubclass(dtype, dtypes.List):
        inner = narwhals_to_native_dtype(dtype.inner, version)
        return f"{inner}[]"
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        # Double embedded double quotes so a field name containing `"`
        # cannot terminate its quoted identifier early.
        inner = ", ".join(
            '"'
            + field.name.replace('"', '""')
            + '" '
            + narwhals_to_native_dtype(field.dtype, version)
            for field in dtype.fields
        )
        return f"STRUCT({inner})"
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        shape = dtype.shape
        duckdb_shape_fmt = "".join(f"[{item}]" for item in shape)
        # Step through one `.inner` per dimension to reach the element dtype.
        inner_dtype: Any = dtype
        for _ in shape:
            inner_dtype = inner_dtype.inner
        duckdb_inner = narwhals_to_native_dtype(inner_dtype, version)
        return f"{duckdb_inner}{duckdb_shape_fmt}"
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def generate_partition_by_sql(*partition_by: str | Expression) -> str:
    """Build a `partition by ...` SQL fragment.

    Arguments:
        partition_by: Column names (wrapped with `col`) or native expressions
            (formatted as-is).

    Returns:
        The fragment, or an empty string when nothing is passed.
    """
    if not partition_by:
        return ""
    by_sql = ", ".join([f"{col(x) if isinstance(x, str) else x}" for x in partition_by])
    return f"partition by {by_sql}"


def generate_order_by_sql(*order_by: str, ascending: bool) -> str:
    """Build an `order by ...` SQL fragment.

    Arguments:
        order_by: Column names to order by.
        ascending: If true, every column is rendered `asc nulls first`,
            otherwise `desc nulls last`.

    Returns:
        The fragment, or an empty string when no columns are passed (mirroring
        `generate_partition_by_sql`) instead of a dangling `order by `.
    """
    if not order_by:
        return ""
    direction = "asc nulls first" if ascending else "desc nulls last"
    by_sql = ", ".join([f"{col(x)} {direction}" for x in order_by])
    return f"order by {by_sql}"