From d67178e44b3214fade080bc0799f30a706d68826 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 21 Mar 2025 17:04:23 -0800 Subject: [PATCH] test: add example test for double creation of enums --- .../tests/snapshots/snap_test_datatypes.py | 18 +++++++++++++ duckdb_engine/tests/test_datatypes.py | 26 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 duckdb_engine/tests/snapshots/snap_test_datatypes.py diff --git a/duckdb_engine/tests/snapshots/snap_test_datatypes.py b/duckdb_engine/tests/snapshots/snap_test_datatypes.py new file mode 100644 index 000000000..ced8c6379 --- /dev/null +++ b/duckdb_engine/tests/snapshots/snap_test_datatypes.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import Snapshot + +snapshots = Snapshot() + +snapshots["test_enum 1"] = """ +CREATE TYPE severity AS ENUM ( 'LOW', 'MEDIUM', 'HIGH' ); + + +CREATE TABLE bugs(severity ENUM('LOW', 'MEDIUM', 'HIGH'), PRIMARY KEY(severity)); + + + + +""" diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 623ebb5fe..aa1b66d16 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -1,9 +1,12 @@ import decimal +import enum import json import warnings +from pathlib import Path from typing import Any, Dict, Type from uuid import uuid4 +import sqlalchemy as sa from packaging.version import Version from pytest import importorskip, mark from snapshottest.module import SnapshotTest @@ -248,3 +251,26 @@ def test_div_is_floordiv(engine: Engine) -> None: stmt = test_table.c.value / test_table.c.eur2usd_rate assert str(stmt.compile(engine)) == "test_table.value / test_table.eur2usd_rate" + + +def test_enum( + engine: Engine, session: Session, tmp_path: Path, snapshot: SnapshotTest +) -> None: + base = declarative_base() + + class Severity(enum.Enum): + LOW = enum.auto() + MEDIUM = enum.auto() + HIGH = enum.auto() + + class Bug(base): + __tablename__ = "bugs" + + severity = sa.Column(sa.Enum(Severity), primary_key=True) + + base.metadata.create_all(bind=engine) + session.execute(sa.text(f"EXPORT DATABASE '{tmp_path}'")) + ddl_sql = (tmp_path / "schema.sql").read_text() + # n_enum_defs = ddl_sql.count("'LOW', 'MEDIUM', 'HIGH'") + # assert n_enum_defs == 1, ddl_sql + snapshot.assert_match(ddl_sql)