Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 4 additions & 32 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,41 +149,13 @@ Option 2 Configuration

def ready(self):
if 'tos' in settings.INSTALLED_APPS:
cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
from tos.utils import add_staff_users_to_tos_cache, set_staff_in_cache_for_tos
tos_app = apps.get_app_config('tos')
TermsOfService = tos_app.get_model('TermsOfService')

@receiver(post_save, sender=get_user_model(), dispatch_uid='set_staff_in_cache_for_tos')
def set_staff_in_cache_for_tos(user, instance, **kwargs):
if kwargs.get('raw', False):
return

# Get the cache prefix
key_version = cache.get('django:tos:key_version')

# If the user is staff allow them to skip the TOS agreement check
if instance.is_staff or instance.is_superuser:
cache.set('django:tos:skip_tos_check:{}'.format(instance.id), version=key_version)

# But if they aren't make sure we invalidate them from the cache
elif cache.get('django:tos:skip_tos_check:{}'.format(instance.id), False):
cache.delete('django:tos:skip_tos_check:{}'.format(instance.id), version=key_version)

@receiver(post_save, sender=TermsOfService, dispatch_uid='add_staff_users_to_tos_cache')
def add_staff_users_to_tos_cache(*args, **kwargs):
if kwargs.get('raw', False):
return

# Get the cache prefix
key_version = cache.get('django:tos:key_version')

# Efficiently cache all of the users who are allowed to skip the TOS
# agreement check
cache.set_many({
'django:tos:skip_tos_check:{}'.format(staff_user.id): True
for staff_user in get_user_model().objects.filter(
Q(is_staff=True) | Q(is_superuser=True))
}, version=key_version)
post_save.connect(set_staff_in_cache_for_tos, sender=get_user_model(), dispatch_uid='set_staff_in_cache_for_tos')

post_save.connect(add_staff_users_to_tos_cache, sender=TermsOfService, dispatch_uid='add_staff_users_to_tos_cache')

===============
django-tos-i18n
Expand Down
6 changes: 3 additions & 3 deletions tos/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db.models.signals import pre_save

from .signal_handlers import invalidate_cached_agreements
from .utils import initialize_cache_version


class TOSConfig(AppConfig):
Expand All @@ -16,9 +17,8 @@ def ready(self):
if 'tos.middleware.UserAgreementMiddleware' in MIDDLEWARES: # pragma: no cover
TermsOfService = self.get_model('TermsOfService')

initialize_cache_version()

pre_save.connect(invalidate_cached_agreements,
sender=TermsOfService,
dispatch_uid='invalidate_cached_agreements')

# Create the TOS key version immediately
invalidate_cached_agreements(TermsOfService)
Empty file added tos/management/__init__.py
Empty file.
Empty file.
11 changes: 11 additions & 0 deletions tos/management/commands/add_staff_users_to_tos_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from django.core.management.base import BaseCommand

from tos.utils import add_staff_users_to_tos_cache


class Command(BaseCommand):
def handle(self, *args, **options):
add_staff_users_to_tos_cache()
self.stdout.write(
self.style.SUCCESS("Successfully added staff users to TOS staff")
)
3 changes: 2 additions & 1 deletion tos/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from django.utils.cache import add_never_cache_headers

from .models import UserAgreement
from .utils import get_tos_cache


cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
cache = get_tos_cache()
tos_check_url = reverse_lazy('tos_check_tos')


Expand Down
12 changes: 2 additions & 10 deletions tos/signal_handlers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
from django.conf import settings
from django.core.cache import caches


# Force the user to create a separate cache
cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
from tos.utils import invalidate_cached_agreements as invalidate_cached_agreements_func


def invalidate_cached_agreements(sender, **kwargs):
if kwargs.get('raw', False):
return

# Set the key version to 0 if it doesn't exist and leave it
# alone if it does
cache.add('django:tos:key_version', 0)

# This key will be used to version the rest of the TOS keys
# Incrementing it will effectively invalidate all previous keys
cache.incr('django:tos:key_version')
invalidate_cached_agreements_func(sender)
111 changes: 111 additions & 0 deletions tos/tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from io import StringIO

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.management import call_command
from django.test import TestCase
from django.urls import reverse

from tos.models import TermsOfService, UserAgreement, has_user_agreed_latest_tos
from tos.utils import (
add_staff_users_to_tos_cache,
get_tos_cache,
initialize_cache_version,
set_staff_in_cache_for_tos,
)


class CacheTestCase(TestCase):
def setUp(self):
self.cache = get_tos_cache()
self.cache.clear()

User = get_user_model()
User.objects.bulk_create([
User(
username=f"user{i}",
email=f"user{i}@example.com",
is_staff=True if i < 3 else False,
is_superuser=True if i % 2 == 0 else False,
)
for i in range(1, 10)
])

def get_skip_tos_check(self, i: int):
return self.cache.get(
f"django:tos:skip_tos_check:{i}",
None,
version=self.cache.get('django:tos:key_version'),
)

def call_command(self, cmd, *args, **kwargs):
out = StringIO()
err = StringIO()
call_command(
cmd,
*args,
stdout=out,
stderr=err,
**kwargs,
)
return out.getvalue(), err.getvalue()

def test_command(self):
for i in range(1, 10):
self.assertIsNone(self.get_skip_tos_check(i))

out, _ = self.call_command("add_staff_users_to_tos_cache")

self.assertIn("Successfully added staff users to TOS staff", out)

for i in range(1, 3):
self.assertIsNotNone(self.get_skip_tos_check(i))
for i in range(2, 10, 2):
self.assertIsNotNone(self.get_skip_tos_check(i))
for i in range(3, 10, 2):
self.assertIsNone(self.get_skip_tos_check(i))

def test_initialize_cache_version(self):
self.assertIsNone(self.cache.get('django:tos:key_version'))

initialize_cache_version()

self.assertIsNotNone(self.cache.get('django:tos:key_version', None))
self.assertEqual(self.cache.get('django:tos:key_version'), 1)

initialize_cache_version()

self.assertIsNotNone(self.cache.get('django:tos:key_version', None))
self.assertEqual(self.cache.get('django:tos:key_version'), 1)

def test_add_staff_users_to_tos_cache(self):
self.assertIsNone(add_staff_users_to_tos_cache(raw=True))

add_staff_users_to_tos_cache()

for i in range(1, 3):
self.assertIsNotNone(self.get_skip_tos_check(i))
for i in range(2, 10, 2):
self.assertIsNotNone(self.get_skip_tos_check(i))
for i in range(3, 10, 2):
self.assertIsNone(self.get_skip_tos_check(i))

def test_set_staff_in_cache_for_tos(self):
self.assertIsNone(set_staff_in_cache_for_tos(instance=None, raw=True))

User = get_user_model()
for i in range(1, 10):
set_staff_in_cache_for_tos(instance=User.objects.get(id=i))
if i < 3:
self.assertTrue(self.get_skip_tos_check(i))
if i % 2 == 0:
self.assertTrue(self.get_skip_tos_check(i))
if not (i < 3 or i % 2 == 0):
self.assertIsNone(self.get_skip_tos_check(i))

# Set it manually again, then run set again to ensure it
# removes the user from the skip cache when they are removed as
# a staff and superuser
self.cache.set(f"django:tos:skip_tos_check:{i}", True, version=self.cache.get("django:tos:key_version"))
set_staff_in_cache_for_tos(instance=User.objects.get(id=i))
self.assertIsNone(self.get_skip_tos_check(i))
7 changes: 4 additions & 3 deletions tos/tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tos.middleware import UserAgreementMiddleware
from tos.models import TermsOfService, UserAgreement
from tos.signal_handlers import invalidate_cached_agreements
from tos.utils import get_tos_cache


@modify_settings(
Expand All @@ -20,7 +21,7 @@ class TestMiddleware(TestCase):

def setUp(self):
# Clear cache between tests
cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
cache = get_tos_cache()
cache.clear()

# User that has agreed to TOS
Expand Down Expand Up @@ -145,7 +146,7 @@ def test_ajax_request(self):
self.assertEqual(response.status_code, 200)

def test_skip_for_user(self):
cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
cache = get_tos_cache()

key_version = cache.get('django:tos:key_version')

Expand Down Expand Up @@ -177,7 +178,7 @@ def test_use_cache(self):
self.assertEqual(response.request['PATH_INFO'], '/')

def test_invalidate_cached_agreements(self):
cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
cache = get_tos_cache()

invalidate_cached_agreements(TermsOfService)

Expand Down
69 changes: 69 additions & 0 deletions tos/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import TYPE_CHECKING

from django.apps import AppConfig, apps
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import caches
from django.core.management.base import BaseCommand
from django.db.models import Q
from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver

if TYPE_CHECKING:
from django.contrib.auth.models import AbstractUser


def get_tos_cache():
return caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]


cache = get_tos_cache()


def initialize_cache_version():
if not cache.get('django:tos:key_version', False):
# The function is a signal handler, so it needs a sender argument, but
# it doesn't actually use it at all, so it can be None
invalidate_cached_agreements(sender=None)


def invalidate_cached_agreements(sender, **kwargs):
# Set the key version to 0 if it doesn't exist and leave it
# alone if it does
cache.add('django:tos:key_version', 0)

# This key will be used to version the rest of the TOS keys
# Incrementing it will effectively invalidate all previous keys
cache.incr('django:tos:key_version')


def add_staff_users_to_tos_cache(*args, **kwargs):
if kwargs.get('raw', False):
return

# Get the cache prefix
key_version = cache.get('django:tos:key_version')

# Efficiently cache all of the users who are allowed to skip the TOS
# agreement check
cache.set_many({
f'django:tos:skip_tos_check:{staff_user.id}': True
for staff_user in get_user_model().objects.filter(
Q(is_staff=True) | Q(is_superuser=True))
}, version=key_version)


def set_staff_in_cache_for_tos(*, instance: 'AbstractUser', **kwargs):
if kwargs.get('raw', False):
return

# Get the cache prefix
key_version = cache.get('django:tos:key_version')

# If the user is staff allow them to skip the TOS agreement check
if instance.is_staff or instance.is_superuser:
cache.set(f'django:tos:skip_tos_check:{instance.id}', True, version=key_version)

# But if they aren't make sure we invalidate them from the cache
elif cache.get(f'django:tos:skip_tos_check:{instance.id}', False):
cache.delete(f'django:tos:skip_tos_check:{instance.id}', version=key_version)
3 changes: 2 additions & 1 deletion tos/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from django.views.generic import TemplateView

from tos.models import has_user_agreed_latest_tos, TermsOfService, UserAgreement
from .utils import get_tos_cache


cache = caches[getattr(settings, 'TOS_CACHE_NAME', 'default')]
cache = get_tos_cache()


class TosView(TemplateView):
Expand Down