diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 6c63f2c..e6f7336 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -33,6 +33,8 @@ def __init__( alias: str | None = None, metadata: dict[str, Any] | None = None, description: str | None = None, + sqlalchemy_use_enum: bool = False, + sqlalchemy_enum_name: str | None = None, ): """ Args: @@ -68,6 +70,15 @@ def __init__( names, the specified alias is the only valid name. metadata: A dictionary of metadata to attach to the column. description: A human-readable description of the column. + sqlalchemy_use_enum: When ``True``, map this column to :class:`sqlalchemy.Enum` + in :meth:`~dataframely.Schema.to_sqlalchemy_columns` instead of + ``CHAR`` / ``VARCHAR``. + sqlalchemy_enum_name: Optional name for the SQLAlchemy / database enum type + when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a + Python :class:`enum.Enum` subclass, the lowercased enum class is used. + Otherwise, the name of the column is used. + The persisted values are the enum members' ``.value`` strings (not + member names), matching :attr:`categories`. """ super().__init__( nullable=nullable, @@ -78,8 +89,26 @@ def __init__( metadata=metadata, description=description, ) + if sqlalchemy_enum_name and not sqlalchemy_use_enum: + raise ValueError( + "`sqlalchemy_enum_name` has no effect when `sqlalchemy_use_enum=False`." + ) + + self.sqlalchemy_use_enum = sqlalchemy_use_enum + self.sqlalchemy_enum_name = sqlalchemy_enum_name if isclass(categories) and issubclass(categories, enum.Enum): + # If the user passed an Enum type, we want to determine a default name + # based on the Enum class name, which is also what sqlalchemy does. + # One could instead keep a reference to the Enum class around and pass it + # to sqlalchemy later on, but that will interfere with the base-class implementations + # of `matches` and `to_dict` / `from_dict`. 
+ if self.sqlalchemy_use_enum: + self.sqlalchemy_enum_name = ( + self.sqlalchemy_enum_name or categories.__name__.lower() + ) + categories = (item.value for item in categories) + self.categories = list(categories) @property @@ -92,6 +121,10 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool: return self.categories == dtype.categories.to_list() def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: + if self.sqlalchemy_use_enum: + return sa.Enum( + *self.categories, name=self.sqlalchemy_enum_name or self._name + ) category_lengths = [len(c) for c in self.categories] if all(length == category_lengths[0] for length in category_lengths): return sa.CHAR(category_lengths[0]) diff --git a/docs/guides/features/sql-generation.md b/docs/guides/features/sql-generation.md index e84d2fd..d186eb6 100644 --- a/docs/guides/features/sql-generation.md +++ b/docs/guides/features/sql-generation.md @@ -81,6 +81,41 @@ the maximal length of the string is inferred from the regular expression if poss maximal lengths can be particularly important for primary key columns. Some database systems, such as Microsoft SQL Server, do not allow `VARCHAR(max)` columns (unbounded strings) to be used as primary keys. ``` +## Native SQL enums + +By default, {class}`~dataframely.Enum` maps to `sa.CHAR` or `sa.String` columns so stored values remain plain strings. 
You may set `sqlalchemy_use_enum=True` to instead generate native enums: + +```python +from enum import Enum, auto + +import sqlalchemy as sa +import dataframely as dy +from sqlalchemy.dialects.postgresql import dialect as pg_dialect +from sqlalchemy.dialects.mssql import dialect as mssql_dialect + + +class Status(str, Enum): + PENDING = auto() + APPROVED = auto() + + +class Staged(dy.Schema): + status = dy.Enum(Status, sqlalchemy_use_enum=True) +``` + +This will translate the {class}`~dataframely.Enum` to a `sqlalchemy.Enum`: + +```python +>>> Staged.to_sqlalchemy_columns(pg_dialect()) +[Column('status', Enum('1', '2', name='status'), table=None, nullable=False)] +``` + +Depending on the database dialect you use, `sqlalchemy` will render this accordingly. +For example, `postgresql` supports native enums, and `sqlalchemy` will create a native enum column, while in MSSQL, where this is not supported, it will fall back to `VARCHAR`. + +When `categories` is a Python `enum.Enum` subclass, the lowercased enum class name is used as the database enum type name by default, mirroring what `sqlalchemy` does for Python enums. +For string category lists, the SQL column name is used by default; override it with `sqlalchemy_enum_name` if needed. + ## Collections of multiple tables If you have an entire `dy.Collection`, it's also easy to generate one table for each member table of the collection. 
diff --git a/tests/column_types/test_enum.py b/tests/column_types/test_enum.py index 85078a3..0fd9e6e 100644 --- a/tests/column_types/test_enum.py +++ b/tests/column_types/test_enum.py @@ -108,3 +108,69 @@ def test_sequences_and_enums( S = create_schema("test", {"x": dy.Enum(categories1)}) df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))}) S.validate(df) + + +def test_matches_sqlalchemy_use_enum() -> None: + expr = pl.element() + assert dy.Enum(["a", "b"]).matches(dy.Enum(["a", "b"]), expr) + assert not dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( + dy.Enum(["a", "b"]), expr + ) + assert dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( + dy.Enum(["a", "b"], sqlalchemy_use_enum=True), expr + ) + + +def test_matches_sqlalchemy_use_enum_fails_on_internal_name_mismatch() -> None: + class MyEnum(str, Enum): + x = "x" + + assert not dy.Enum(MyEnum, sqlalchemy_use_enum=True).matches( + dy.Enum(["x"], sqlalchemy_use_enum=True), pl.element() + ) + + +def test_matches_sqlalchemy_enum_name() -> None: + expr = pl.element() + assert dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="one", + ).matches( + dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="one", + ), + expr, + ) + assert not dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="one", + ).matches( + dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="two", + ), + expr, + ) + + +def test_sqlalchemy_enum_name_without_use_enum_raises() -> None: + with pytest.raises(ValueError, match="`sqlalchemy_enum_name` has no effect"): + dy.Enum(["a", "b"], sqlalchemy_enum_name="my_enum") + + +def test_as_dict_from_dict_sqlalchemy_enum_flags() -> None: + column = dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="my_enum", + ) + data = column.as_dict(pl.element()) + restored = dy.Enum.from_dict(data) + assert restored.sqlalchemy_use_enum is True + assert 
restored.sqlalchemy_enum_name == "my_enum" + assert restored.categories == ["a", "b"] diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index 6731202..3a53ce8 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -1,10 +1,13 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +from enum import Enum +from typing import cast + import pytest import dataframely as dy -from dataframely._compat import Dialect, MSDialect_pyodbc, PGDialect_psycopg2 +from dataframely._compat import Dialect, MSDialect_pyodbc, PGDialect_psycopg2, sa from dataframely.columns import Column from dataframely.testing import COLUMN_TYPES, create_schema @@ -171,3 +174,62 @@ def test_raise_for_object_column(dialect: Dialect) -> None: NotImplementedError, match="SQL column cannot have 'Object' type." ): dy.Object().sqlalchemy_dtype(dialect) + + +class _Status(str, Enum): + PENDING = "pending" + APPROVED = "approved" + + +@pytest.mark.parametrize( + ("column", "dialect", "datatype"), + [ + ( + dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), + PGDialect_psycopg2(), + "a", + ), + ( + dy.Enum( + ["foo", "bar"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="my_status", + ), + PGDialect_psycopg2(), + "my_status", + ), + (dy.Enum(_Status, sqlalchemy_use_enum=True), PGDialect_psycopg2(), "_status"), + ( + dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), + MSDialect_pyodbc(), + "VARCHAR(3)", + ), + ], +) +def test_enum_sqlalchemy_native( + column: Column, dialect: Dialect, datatype: str +) -> None: + schema = create_schema("test", {"a": column}) + columns = schema.to_sqlalchemy_columns(dialect) + assert len(columns) == 1 + assert columns[0].type.compile(dialect) == datatype + + +def test_enum_sqlalchemy_native_python_enum_uses_member_values() -> None: + column = dy.Enum(_Status, sqlalchemy_use_enum=True) + schema = create_schema("test", {"a": column}) + sa_type = cast( + 
sa.sql.sqltypes.Enum, schema.to_sqlalchemy_columns(PGDialect_psycopg2())[0].type + ) + assert list(sa_type.enums) == column.categories + + +def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None: + class TestSchema(dy.Schema): + status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) + + column = TestSchema.columns()["status"] + assert ( + column.sqlalchemy_dtype(PGDialect_psycopg2()).compile(PGDialect_psycopg2()) + == "status" + )