Skip to content
33 changes: 33 additions & 0 deletions dataframely/columns/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(
alias: str | None = None,
metadata: dict[str, Any] | None = None,
description: str | None = None,
sqlalchemy_use_enum: bool = False,
sqlalchemy_enum_name: str | None = None,
):
"""
Args:
Expand Down Expand Up @@ -68,6 +70,15 @@ def __init__(
names, the specified alias is the only valid name.
metadata: A dictionary of metadata to attach to the column.
description: A human-readable description of the column.
sqlalchemy_use_enum: When ``True``, map this column to :class:`sqlalchemy.Enum`
in :meth:`~dataframely.Schema.to_sqlalchemy_columns` instead of
``CHAR`` / ``VARCHAR``.
sqlalchemy_enum_name: Optional name for the SQLAlchemy / database enum type
when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a
Python :class:`enum.Enum` subclass, the lowercased enum class is used.
Otherwise, the name of the column is used.
The persisted values are the enum members' ``.value`` strings (not
member names), matching :attr:`categories`.
"""
super().__init__(
nullable=nullable,
Expand All @@ -78,8 +89,26 @@ def __init__(
metadata=metadata,
description=description,
)
if sqlalchemy_enum_name and not sqlalchemy_use_enum:
raise ValueError(
"`sqlalchemy_enum_name` has no effect when `sqlalchemy_use_enum=False`."
)

self.sqlalchemy_use_enum = sqlalchemy_use_enum
self.sqlalchemy_enum_name = sqlalchemy_enum_name
Comment thread
AndreasAlbertQC marked this conversation as resolved.
if isclass(categories) and issubclass(categories, enum.Enum):
# If the user passed an Enum type, we want to determine a default name
# based on the Enum class name, which is also what sqlalchemy does.
# One could instead keep a reference to the Enum class around and pass it
# to sqlalchemy later on, but that will interfere with the base-class implementations
# of `matches` and `to_dict` / `from_dict`.
if self.sqlalchemy_use_enum:
self.sqlalchemy_enum_name = (
self.sqlalchemy_enum_name or categories.__name__.lower()
)

categories = (item.value for item in categories)

self.categories = list(categories)

@property
Expand All @@ -92,6 +121,10 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool:
return self.categories == dtype.categories.to_list()

def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
if self.sqlalchemy_use_enum:
return sa.Enum(
*self.categories, name=self.sqlalchemy_enum_name or self._name
Comment thread
jackoberman marked this conversation as resolved.
Comment thread
AndreasAlbertQC marked this conversation as resolved.
)
category_lengths = [len(c) for c in self.categories]
if all(length == category_lengths[0] for length in category_lengths):
return sa.CHAR(category_lengths[0])
Expand Down
35 changes: 35 additions & 0 deletions docs/guides/features/sql-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,41 @@ the maximal length of the string is inferred from the regular expression if poss
maximal lengths can be particularly important for primary key columns. Some database systems, such as Microsoft SQL Server, do not allow `VARCHAR(max)` columns (unbounded strings) to be used as primary keys.
```

## Native SQL enums

By default, {class}`~dataframely.Enum` maps to `sa.CHAR` or `sa.String` columns so stored values remain plain strings. You may set `sqlalchemy_use_enum=True` to instead generate native enums:

```python
from enum import Enum, auto

import sqlalchemy as sa
import dataframely as dy
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
from sqlalchemy.dialects.mssql import dialect as mssql_dialect


class Status(str, Enum):
PENDING = auto()
APPROVED = auto()


class Staged(dy.Schema):
status = dy.Enum(Status, sqlalchemy_use_enum=True)
Comment thread
AndreasAlbertQC marked this conversation as resolved.
```

This will translate the `~dataframely.Enum` to a `~sqlalchemy.Enum`:

```python
>>> Staged.to_sqlalchemy_columns(pg_dialect())
[Column('status', Enum('1', '2', name='status'), table=None, nullable=False)]
```

Depending on the database dialect you use, `sqlalchemy` will render this accordingly.
For example, `postgresql` supports native enums, and `sqlalchemy` will create a native enum column, while in MSSQL, where this is not supported, it will fall back to `VARCHAR`.

When `categories` is a Python `enum.Enum` subclass, `sqlalchemy` uses the enum class name (lowercased) as the database enum type name.
For string category lists, the SQL column name is used by default; override it with `sqlalchemy_enum_name` if needed.
Comment thread
AndreasAlbertQC marked this conversation as resolved.

## Collections of multiple tables

If you have an entire `dy.Collection`, it's also easy to generate one table for each member table of the collection.
Expand Down
66 changes: 66 additions & 0 deletions tests/column_types/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,69 @@ def test_sequences_and_enums(
S = create_schema("test", {"x": dy.Enum(categories1)})
df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))})
S.validate(df)


def test_matches_sqlalchemy_use_enum() -> None:
expr = pl.element()
assert dy.Enum(["a", "b"]).matches(dy.Enum(["a", "b"]), expr)
assert not dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches(
dy.Enum(["a", "b"]), expr
)
assert dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches(
dy.Enum(["a", "b"], sqlalchemy_use_enum=True), expr
)


def test_matches_sqlalchemy_use_enum_fails_on_internal_name_mismatch() -> None:
class MyEnum(str, Enum):
x = "x"

assert not dy.Enum(MyEnum, sqlalchemy_use_enum=True).matches(
dy.Enum(["x"], sqlalchemy_use_enum=True), pl.element()
)


def test_matches_sqlalchemy_enum_name() -> None:
expr = pl.element()
Comment thread
AndreasAlbertQC marked this conversation as resolved.
assert dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="one",
).matches(
dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="one",
),
expr,
)
assert not dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="one",
).matches(
dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="two",
),
expr,
)


def test_sqlalchemy_enum_name_without_use_enum_raises() -> None:
with pytest.raises(ValueError, match="`sqlalchemy_enum_name` has no effect"):
dy.Enum(["a", "b"], sqlalchemy_enum_name="my_enum")


def test_as_dict_from_dict_sqlalchemy_enum_flags() -> None:
column = dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="my_enum",
)
data = column.as_dict(pl.element())
restored = dy.Enum.from_dict(data)
assert restored.sqlalchemy_use_enum is True
assert restored.sqlalchemy_enum_name == "my_enum"
assert restored.categories == ["a", "b"]
64 changes: 63 additions & 1 deletion tests/columns/test_sqlalchemy_columns.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

from enum import Enum
from typing import cast

import pytest

import dataframely as dy
from dataframely._compat import Dialect, MSDialect_pyodbc, PGDialect_psycopg2
from dataframely._compat import Dialect, MSDialect_pyodbc, PGDialect_psycopg2, sa
from dataframely.columns import Column
from dataframely.testing import COLUMN_TYPES, create_schema

Expand Down Expand Up @@ -171,3 +174,62 @@ def test_raise_for_object_column(dialect: Dialect) -> None:
NotImplementedError, match="SQL column cannot have 'Object' type."
):
dy.Object().sqlalchemy_dtype(dialect)


class _Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"


@pytest.mark.parametrize(
("column", "dialect", "datatype"),
[
(
dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True),
PGDialect_psycopg2(),
"a",
),
(
dy.Enum(
["foo", "bar"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="my_status",
),
PGDialect_psycopg2(),
"my_status",
),
(dy.Enum(_Status, sqlalchemy_use_enum=True), PGDialect_psycopg2(), "_status"),
Comment thread
AndreasAlbertQC marked this conversation as resolved.
(
dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True),
MSDialect_pyodbc(),
"VARCHAR(3)",
),
],
)
def test_enum_sqlalchemy_native(
column: Column, dialect: Dialect, datatype: str
) -> None:
schema = create_schema("test", {"a": column})
columns = schema.to_sqlalchemy_columns(dialect)
assert len(columns) == 1
assert columns[0].type.compile(dialect) == datatype


def test_enum_sqlalchemy_native_python_enum_uses_member_values() -> None:
column = dy.Enum(_Status, sqlalchemy_use_enum=True)
schema = create_schema("test", {"a": column})
sa_type = cast(
sa.sql.sqltypes.Enum, schema.to_sqlalchemy_columns(PGDialect_psycopg2())[0].type
)
assert list(sa_type.enums) == column.categories


def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None:
class TestSchema(dy.Schema):
status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True)

column = TestSchema.columns()["status"]
assert (
column.sqlalchemy_dtype(PGDialect_psycopg2()).compile(PGDialect_psycopg2())
== "status"
)
Loading