Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions piccolo/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,9 @@ def _validate_default(
elif (
default is None
and None in allowed_types
or type(default) in allowed_types
or isinstance(
default, tuple(t for t in allowed_types if isinstance(t, type))
)
):
self._validated = True
return True
Expand All @@ -539,8 +541,9 @@ def _validate_default(
):
self._validated = True
return True
elif (
isinstance(default, Enum) and type(default.value) in allowed_types
elif isinstance(default, Enum) and isinstance(
default.value,
tuple(t for t in allowed_types if isinstance(t, type)),
):
self._validated = True
return True
Expand Down
11 changes: 11 additions & 0 deletions tests/columns/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
from piccolo.table import Table


def get_custom_default(base):
class CustomDefault(base):
pass

return CustomDefault()


class TestDefaults(TestCase):
"""
Columns check the type of the default argument.
Expand Down Expand Up @@ -66,27 +73,31 @@ def test_uuid(self):
UUID(default=None, null=True)
UUID(default=UUID4())
UUID(default=uuid.uuid4())
UUID(default=get_custom_default(UUID4))
with self.assertRaises(ValueError):
UUID(default="hello world")

def test_time(self):
Time(default=None, null=True)
Time(default=TimeNow())
Time(default=datetime.datetime.now().time())
Time(default=get_custom_default(TimeNow))
with self.assertRaises(ValueError):
Time(default="hello world") # type: ignore

def test_date(self):
Date(default=None, null=True)
Date(default=DateNow())
Date(default=datetime.datetime.now().date())
Date(default=get_custom_default(DateNow))
with self.assertRaises(ValueError):
Date(default="hello world") # type: ignore

def test_timestamp(self):
Timestamp(default=None, null=True)
Timestamp(default=TimestampNow())
Timestamp(default=datetime.datetime.now())
Timestamp(default=get_custom_default(TimestampNow))
with self.assertRaises(ValueError):
Timestamp(default="hello world") # type: ignore

Expand Down