diff --git a/docs/user/management_commands.rst b/docs/user/management_commands.rst index 575c3511..68686620 100644 --- a/docs/user/management_commands.rst +++ b/docs/user/management_commands.rst @@ -129,6 +129,10 @@ Following is an example: ./manage.py delete_unverified_users --older-than-days 1 --exclude-methods mobile_phone,email +If a user has multiple ``RegisteredUser`` rows across organizations, the +command keeps that user when **any** related row uses one of the excluded +methods. + ``upgrade_from_django_freeradius`` ---------------------------------- diff --git a/docs/user/rest-api.rst b/docs/user/rest-api.rst index 0efc1882..f4664342 100644 --- a/docs/user/rest-api.rst +++ b/docs/user/rest-api.rst @@ -803,6 +803,48 @@ Param Description phone_number string ============ =========== +Update user registration method ++++++++++++++++++++++++++++++++ + +**Requires the user auth token (Bearer Token)**. + +Allows users to update their registered user method for an organization. +The method can only be updated when it is currently set to +``pending_verification``. Once updated, it cannot be changed again via +this endpoint. + +This endpoint is used during cross-organization login when a user +authenticates to a new organization. The user must complete verification +for that organization before they can create account with the new +organization. + +.. code-block:: text + + /api/v1/radius/organization//account/registration-method/ + +Responds only to **POST**. + +Parameters: + +====== =========== +Param Description +====== =========== +method string (\*) +====== =========== + +(\*) ``method`` must be one of the available +:ref:`registration/verification methods +`, excluding +``pending_verification``. + +**Success Response (200 OK)**: + +.. code-block:: json + + { + "method": "mobile_phone" + } + .. _radius_batch_user_creation: Batch user creation diff --git a/docs/user/settings.rst b/docs/user/settings.rst index 4df3d1c2..ebbe8d64 100644 --- a/docs/user/settings.rst +++ b/docs/user/settings.rst @@ -696,6 +696,9 @@ verification method. The following choices are available by default: - ``mobile_phone``: Mobile phone number :ref:`verification via SMS ` - ``social_login``: :doc:`social login feature ` +- ``pending_verification``: Transitional state used when a user + authenticates to a new organization but has not yet completed + verification for that organization. .. note:: @@ -714,6 +717,33 @@ verification method. The following choices are available by default: **Disclaimer:** these are just suggestions on possible configurations of OpenWISP RADIUS and must not be considered as legal advice. +``OPENWISP_RADIUS_USER_SETTABLE_REGISTRATION_METHODS`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Default**: ``["", "email", "mobile_phone"]`` + +Defines which ``RegisteredUser.method`` values can be written by users +through the public registration APIs. + +Methods not included in this setting cannot be selected by users through +those APIs, even if they are present in the full list returned by +``get_registration_choices()``. + +This is especially useful to keep server-assigned provenance methods such +as ``saml``, ``social_login`` or ``manual`` out of user-controlled API +input. These methods may still be assigned internally by server-side +authentication or integration flows when appropriate. + +Example: + +.. code-block:: python + + OPENWISP_RADIUS_USER_SETTABLE_REGISTRATION_METHODS = [ + "", + "email", + "mobile_phone", + ] + .. _openwisp_radius_register_registration_method: Adding support for more registration/verification methods diff --git a/openwisp_radius/admin.py b/openwisp_radius/admin.py index ac095a54..e505d7dd 100644 --- a/openwisp_radius/admin.py +++ b/openwisp_radius/admin.py @@ -7,6 +7,8 @@ from django.contrib.admin.utils import model_ngettext from django.contrib.auth import get_user_model from django.core.exceptions import PermissionDenied +from django.db.models import Prefetch +from django.forms.models import BaseInlineFormSet from django.http import HttpResponseRedirect from django.templatetags.static import static from django.urls import reverse @@ -534,11 +536,31 @@ def has_change_permission(self, request, obj=None): return False +class RegisteredUserFormset(BaseInlineFormSet): + def get_unique_error_message(self, unique_check): + # Django inline formsets perform their own uniqueness validation + # (BaseModelFormSet.validate_unique) *before* model-level validation runs. + # Because of this, the custom `violation_error_message` defined on + # `UniqueConstraint` is never surfaced in the admin UI. + # + # Overriding this method allows us to replace Django’s generic + # "Please correct the duplicate data for ." message with a + # domain-specific, user-friendly error that matches our constraint. + if unique_check == ("user", "organization"): + return _( + "A user cannot have more than one registration record in the" + " same organization." + ) + + class RegisteredUserInline(StackedInline): model = RegisteredUser form = AlwaysHasChangedForm + formset = RegisteredUserFormset extra = 0 readonly_fields = ("modified",) + fields = ("organization", "method", "is_verified", "modified") + autocomplete_fields = ("organization",) def has_delete_permission(self, request, obj=None): return False @@ -549,22 +571,50 @@ def has_delete_permission(self, request, obj=None): RadiusUserGroupInline, PhoneTokenInline, ] -UserAdmin.list_filter += (RegisteredUserFilter, "registered_user__method") +UserAdmin.list_filter += (RegisteredUserFilter, "registered_users__method") +user_admin_get_queryset = UserAdmin.get_queryset + + +def get_queryset(self, request): + queryset = user_admin_get_queryset(self, request) + registered_users = RegisteredUser.objects.only( + "user_id", "organization_id", "is_verified" + ) + if not request.user.is_superuser: + registered_users = registered_users.filter( + organization__in=request.user.organizations_managed + ) + return queryset.prefetch_related( + Prefetch( + "registered_users", + queryset=registered_users, + to_attr="prefetched_registered_users", + ) + ) def get_is_verified(self, obj): - try: - value = "yes" if obj.registered_user.is_verified else "no" - except Exception: + prefetched_registered_users = getattr(obj, "prefetched_registered_users", None) + if prefetched_registered_users is not None: + is_verifieds = [ + reg_user.is_verified for reg_user in prefetched_registered_users + ] + else: + is_verifieds = [] + if not is_verifieds: value = "unknown" + elif any(is_verifieds): + value = "yes" + else: + value = "no" icon_url = static(f"admin/img/icon-{value}.svg") return mark_safe(f'{value}') +UserAdmin.get_queryset = get_queryset UserAdmin.get_is_verified = get_is_verified UserAdmin.get_is_verified.short_description = _("Verified") UserAdmin.list_display.insert(3, "get_is_verified") -UserAdmin.list_select_related = ("registered_user",) class OrganizationRadiusSettingsInline(admin.StackedInline): diff --git a/openwisp_radius/api/freeradius_views.py b/openwisp_radius/api/freeradius_views.py index b69232e5..acf10621 100644 --- a/openwisp_radius/api/freeradius_views.py +++ b/openwisp_radius/api/freeradius_views.py @@ -57,6 +57,7 @@ RadiusToken = load_model("RadiusToken") RadiusAccounting = load_model("RadiusAccounting") +RegisteredUser = load_model("RegisteredUser") OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") Organization = swapper.load_model("openwisp_users", "Organization") @@ -290,7 +291,7 @@ def get_user(self, request, username, password): """ conditions = self._get_user_query_conditions(request) try: - user = auth_backend.get_users(username).filter(conditions)[0] + user = auth_backend.get_users(username).filter(conditions).distinct()[0] except IndexError: return None # ensure user is member of the authenticated org @@ -409,19 +410,21 @@ def _get_user_query_conditions(self, request): # just ensure user is active if not needs_verification: return is_active - # if identity verification is enabled - is_verified = Q(registered_user__is_verified=True) + organization_id = request._auth + registered_user = Q(registered_users__organization_id=organization_id) + is_verified = Q(registered_users__is_verified=True) AUTHORIZE_UNVERIFIED = registration.AUTHORIZE_UNVERIFIED - # and no method should authorize unverified users - # ensure user is active AND verified if not AUTHORIZE_UNVERIFIED: - return is_active & is_verified + return is_active & registered_user & is_verified # in case some methods are allowed to authorize unverified users # ensure user is active AND # (user is verified OR user uses one of these methods) else: - authorize_unverified = Q(registered_user__method__in=AUTHORIZE_UNVERIFIED) - return is_active & (is_verified | authorize_unverified) + return ( + is_active + & registered_user + & (is_verified | Q(registered_users__method__in=AUTHORIZE_UNVERIFIED)) + ) def authenticate_user(self, request, user, password): """ diff --git a/openwisp_radius/api/serializers.py b/openwisp_radius/api/serializers.py index b9b01165..4da53580 100644 --- a/openwisp_radius/api/serializers.py +++ b/openwisp_radius/api/serializers.py @@ -36,7 +36,6 @@ from .. import settings as app_settings from ..base.forms import PasswordResetForm from ..counters.exceptions import SkipCheck -from ..registration import REGISTRATION_METHOD_CHOICES from ..utils import ( get_group_checks, get_organization_radius_settings, @@ -571,9 +570,13 @@ class RegisterSerializer( 'verification in its "Organization RADIUS Settings."' ), default="", - choices=REGISTRATION_METHOD_CHOICES, + choices=(), ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields["method"].choices = app_settings.USER_SETTABLE_REGISTRATION_METHODS + def validate_phone_number(self, phone_number): org = self.context["view"].organization if get_organization_radius_settings(org, "sms_verification"): @@ -688,9 +691,11 @@ def save(self, request): # the custom_signup method contains the openwisp specific logic self.custom_signup(request, user) # create a RegisteredUser object for every user that registers through API - RegisteredUser.objects.create( + org = self.context["view"].organization + RegisteredUser.get_or_create_for_user_and_org( user=user, - method=self.validated_data["method"], + organization=org, + defaults={"method": self.validated_data["method"]}, ) setup_user_email(request, user, []) return user @@ -753,8 +758,55 @@ def save(self): # yet, tha will be done by the phone token validation view # once the phone number has been validated # at this point we flag the user as unverified again - self.user.registered_user.is_verified = False - self.user.registered_user.save() + org = self.context["view"].organization + reg_user, _ = RegisteredUser.get_or_create_for_user_and_org( + user=self.user, + organization=org, + defaults={"is_verified": False, "method": ""}, + ) + reg_user.is_verified = False + reg_user.save() + + +class UpdateRegisteredUserMethodSerializer(ValidatedModelSerializer): + method = serializers.ChoiceField( + choices=app_settings.USER_SETTABLE_REGISTRATION_METHODS, + help_text=_( + "The registration method to set for the user. " + "Cannot be 'pending_verification'." + ), + ) + + class Meta: + model = RegisteredUser + fields = ["method"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields["method"].choices = app_settings.USER_SETTABLE_REGISTRATION_METHODS + + def validate_method(self, value): + if value == "pending_verification": + raise serializers.ValidationError( + _("'pending_verification' cannot be set as a registration method.") + ) + return value + + def validate(self, attrs): + if self.instance.method != "pending_verification": + raise serializers.ValidationError( + { + "method": _( + "Method can only be updated from pending verification state." + ) + } + ) + return attrs + + def update(self, instance, validated_data): + instance.method = validated_data["method"] + instance.save() + return instance class RadiusUserSerializer(serializers.ModelSerializer): @@ -762,11 +814,8 @@ class RadiusUserSerializer(serializers.ModelSerializer): Used to return information about the logged in user """ - is_verified = serializers.BooleanField(source="registered_user.is_verified") - method = serializers.CharField( - source="registered_user.method", - allow_null=True, - ) + is_verified = serializers.SerializerMethodField() + method = serializers.SerializerMethodField() password_expired = serializers.BooleanField(source="has_password_expired") radius_user_token = serializers.CharField(source="radius_token.key", default=None) @@ -786,3 +835,30 @@ class Meta: "password_expired", "radius_user_token", ] + + def _get_registered_user(self, obj): + if not hasattr(self, "_registered_user_cache"): + self._registered_user_cache = {} + if obj.pk not in self._registered_user_cache: + view = self.context.get("view") + organization = getattr(view, "organization", None) + reg_user = None + # We iterate over .all() instead of using .filter() because callers + # of this serializer (e.g. validate_auth_token) prefetch + # "registered_users" via prefetch_related. Using .all() hits the + # in-memory prefetch cache (0 DB queries), whereas .filter() would + # bypass the cache and issue a new query every time. + for ru in obj.registered_users.all(): + if organization and ru.organization_id == organization.pk: + reg_user = ru + break + self._registered_user_cache[obj.pk] = reg_user + return self._registered_user_cache[obj.pk] + + def get_is_verified(self, obj): + reg_user = self._get_registered_user(obj) + return reg_user.is_verified if reg_user else None + + def get_method(self, obj): + reg_user = self._get_registered_user(obj) + return reg_user.method if reg_user else None diff --git a/openwisp_radius/api/urls.py b/openwisp_radius/api/urls.py index 88d02572..3ca6407d 100644 --- a/openwisp_radius/api/urls.py +++ b/openwisp_radius/api/urls.py @@ -77,6 +77,11 @@ def get_api_urls(api_views=None): api_views.change_phone_number, name="phone_number_change", ), + path( + "radius/organization//account/registration-method/", + api_views.update_registered_user_registration_method, + name="update_registered_user_registration_method", + ), path("radius/batch/", api_views.batch, name="batch"), path( "radius/organization//batch//pdf/", diff --git a/openwisp_radius/api/utils.py b/openwisp_radius/api/utils.py index 6d742c57..0e81d63a 100644 --- a/openwisp_radius/api/utils.py +++ b/openwisp_radius/api/utils.py @@ -9,6 +9,7 @@ Organization = load_model("openwisp_users", "Organization") OrganizationRadiusSettings = load_model("openwisp_radius", "OrganizationRadiusSettings") +RegisteredUser = load_model("openwisp_radius", "RegisteredUser") class ErrorDictMixin(object): @@ -30,8 +31,15 @@ def _needs_identity_verification(self, organization_filter_kwargs={}, org=None): except ObjectDoesNotExist: return app_settings.NEEDS_IDENTITY_VERIFICATION - def is_identity_verified_strong(self, user): - try: - return user.registered_user.is_identity_verified_strong - except ObjectDoesNotExist: + def is_identity_verified_strong(self, user, organization): + reg_user = None + # We use all() to utilize the prefetch cache, otherwise + # it would cause an additional query to fetch the organization-specific + # registration record + for ru in user.registered_users.all(): + if organization and ru.organization_id == organization.pk: + reg_user = ru + break + if reg_user is None: return False + return reg_user.is_identity_verified_strong diff --git a/openwisp_radius/api/views.py b/openwisp_radius/api/views.py index 07c4bd37..b87505cd 100644 --- a/openwisp_radius/api/views.py +++ b/openwisp_radius/api/views.py @@ -11,8 +11,8 @@ from django.contrib.sites.shortcuts import get_current_site from django.core.cache import cache from django.core.exceptions import ValidationError +from django.db import IntegrityError, transaction from django.db.models import Q -from django.db.utils import IntegrityError from django.http import Http404, HttpResponse from django.utils import timezone from django.utils.decorators import method_decorator @@ -35,6 +35,7 @@ ListCreateAPIView, RetrieveAPIView, RetrieveUpdateDestroyAPIView, + get_object_or_404, ) from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import ( @@ -74,6 +75,7 @@ RadiusBatchSerializer, RadiusGroupSerializer, RadiusUserGroupSerializer, + UpdateRegisteredUserMethodSerializer, UserRadiusUsageSerializer, ValidatePhoneTokenSerializer, ) @@ -92,6 +94,7 @@ Organization = swapper.load_model("openwisp_users", "Organization") OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") PhoneToken = load_model("PhoneToken") +RegisteredUser = load_model("RegisteredUser") RadiusAccounting = load_model("RadiusAccounting") RadiusToken = load_model("RadiusToken") RadiusBatch = load_model("RadiusBatch") @@ -315,13 +318,13 @@ def post(self, request, *args, **kwargs): self.update_user_details(user) context = {"view": self, "request": request} serializer = self.serializer_class(instance=token, context=context) - response = RadiusUserSerializer(user).data + response = RadiusUserSerializer(user, context=context).data response.update(serializer.data) status_code = 200 if user.is_active else 401 # If identity verification is required, check if user is verified if self._needs_identity_verification( {"slug": kwargs["slug"]} - ) and not self.is_identity_verified_strong(user): + ) and not self.is_identity_verified_strong(user, self.organization): status_code = 401 return Response(response, status=status_code) @@ -335,24 +338,23 @@ def validate_membership(self, user): if get_organization_radius_settings( self.organization, "registration_enabled" ): - if self._needs_identity_verification( - org=self.organization - ) and not self.is_identity_verified_strong(user): - raise PermissionDenied try: - org_user = OrganizationUser( - user=user, organization=self.organization - ) - org_user.full_clean() - org_user.save() + with transaction.atomic(): + OrganizationUser.objects.get_or_create( + user=user, organization=self.organization + ) + RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=self.organization, + defaults={"method": "pending_verification"}, + ) except ValidationError as error: raise serializers.ValidationError( {"non_field_errors": error.message_dict.pop("__all__")} ) else: message = _( - "{organization} does not allow self registration " - "of new accounts." + "{organization} does not allow self registration of new accounts." ).format(organization=self.organization.name) raise PermissionDenied(message) @@ -383,9 +385,15 @@ def post(self, request, *args, **kwargs): response = {"response_code": "BLANK_OR_INVALID_TOKEN"} if request_token: try: - token = UserToken.objects.select_related( - "user", "user__registered_user" - ).get(key=request_token) + token = ( + UserToken.objects.select_related( + "user", + ) + .prefetch_related( + "user__registered_users", + ) + .get(key=request_token) + ) except UserToken.DoesNotExist: pass else: @@ -395,7 +403,7 @@ def post(self, request, *args, **kwargs): ) # user may be in the process of changing the phone number # in that case show the new phone number (which is not verified yet) - if not self.is_identity_verified_strong(user): + if not self.is_identity_verified_strong(user, self.organization): phone_token = ( PhoneToken.objects.filter(user=user) .order_by("-created") @@ -404,8 +412,8 @@ def post(self, request, *args, **kwargs): user.phone_number = ( phone_token.phone_number if phone_token else user.phone_number ) - response = RadiusUserSerializer(user).data context = {"view": self, "request": request} + response = RadiusUserSerializer(user, context=context).data token_data = rest_auth_settings.api_settings.TOKEN_SERIALIZER( token, context=context ).data @@ -632,13 +640,14 @@ def create(self, *args, **kwargs): phone_number = request.data.get("phone_number", request.user.phone_number) phone_token = PhoneToken( user=request.user, + organization=self.organization, ip=self.get_ident(request), phone_number=phone_number, ) try: phone_token.full_clean() if kwargs.get("enforce_unverified", True): - phone_token._validate_already_verified() + phone_token._validate_already_verified(organization=self.organization) except ValidationError as e: error_dict = self._get_error_dict(e) raise serializers.ValidationError(error_dict) @@ -663,6 +672,7 @@ def enforce_sms_request_cooldown(self, cooldown, phone_number): last_phone_token = ( PhoneToken.objects.filter( user=self.request.user, + organization=self.organization, phone_number=phone_number, created__gt=datetime_now - timezone.timedelta(seconds=cooldown), ) @@ -705,6 +715,7 @@ def get(self, request, *args, **kwargs): self.validate_membership(user) is_active = PhoneToken.objects.filter( user=request.user, + organization=self.organization, phone_number=user.phone_number, valid_until__gte=timezone.now(), verified=False, @@ -741,21 +752,45 @@ def post(self, request, *args, **kwargs): self.validate_membership(user) serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - phone_token = PhoneToken.objects.filter(user=user).order_by("-created").first() + phone_token = ( + PhoneToken.objects.filter(user=user, organization=self.organization) + .order_by("-created") + .first() + ) if not phone_token: return self._error_response( _("No verification code found in the system for this user.") ) try: - is_valid = phone_token.is_valid(serializer.data["code"]) + is_valid = phone_token.is_valid( + serializer.data["code"], organization=self.organization + ) except PhoneTokenException as e: return self._error_response(str(e)) if not is_valid: return self._error_response(_("Invalid code.")) else: - user.registered_user.is_verified = True - user.registered_user.method = "mobile_phone" - user.is_active = True + old_phone_number = str(user.phone_number) if user.phone_number else None + phone_number_changed = old_phone_number and old_phone_number != str( + phone_token.phone_number + ) + if phone_number_changed: + # The shipped registration methods only tie identity verification + # to the stored phone number for mobile_phone entries. + RegisteredUser.objects.filter( + user=user, + method="mobile_phone", + ).exclude(organization=self.organization,).update(is_verified=False) + reg_user, __ = RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=self.organization, + defaults={ + "is_verified": True, + "method": "mobile_phone", + }, + ) + reg_user.is_verified = True + reg_user.method = "mobile_phone" # Update username if phone_number is used as username if user.username == user.phone_number: user.username = phone_token.phone_number @@ -763,9 +798,14 @@ def post(self, request, *args, **kwargs): # we can write it to the user field user.phone_number = phone_token.phone_number user.save() - user.registered_user.save() - # delete any radius token cache key if present - cache.delete(f"rt-{phone_token.phone_number}") + reg_user.save() + # Delete any cached radius token for either the previous or current + # phone number so callers cannot keep using stale cached entries. + cache_keys = {f"rt-{phone_token.phone_number}"} + if old_phone_number: + cache_keys.add(f"rt-{old_phone_number}") + for cache_key in cache_keys: + cache.delete(cache_key) return Response(None, status=200) @@ -813,6 +853,51 @@ def create_phone_token(self, *args, **kwargs): change_phone_number = ChangePhoneNumberView.as_view() +class UpdateRegisteredUserMethodView(DispatchOrgMixin, GenericAPIView): + authentication_classes = (BearerAuthentication, SessionAuthentication) + permission_classes = (IsAuthenticated,) + serializer_class = UpdateRegisteredUserMethodSerializer + + @swagger_auto_schema( + operation_description=(""" + **Requires the user auth token (Bearer Token).** + Allows users to update their organization-specific registration + method. + The method can only be updated when it is currently + set to 'pending_verification'. + Once updated, it cannot be changed again via this endpoint. + """), + responses={ + 200: "Method updated successfully", + 400: ( + "Invalid request (method is not 'pending_verification' " + "or invalid method value)" + ), + 401: "Authentication required", + 404: "RegisteredUser not found for this user and organization", + }, + ) + def post(self, request, slug): + user = request.user + self.validate_membership(user) + reg_user = get_object_or_404( + RegisteredUser, + user_id=user.pk, + organization=self.organization, + ) + serializer = self.get_serializer( + instance=reg_user, data=request.data, partial=True + ) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response( + {"method": serializer.instance.method}, status=status.HTTP_200_OK + ) + + +update_registered_user_registration_method = UpdateRegisteredUserMethodView.as_view() + + class RadiusAccountingFilter(AccountingFilter): called_station_id = CharFilter( field_name="called_station_id", method="filter_mac_address" diff --git a/openwisp_radius/base/admin_filters.py b/openwisp_radius/base/admin_filters.py index 5fd73991..e0e1b87a 100644 --- a/openwisp_radius/base/admin_filters.py +++ b/openwisp_radius/base/admin_filters.py @@ -1,6 +1,11 @@ from django.contrib.admin import SimpleListFilter +from django.db.models import Exists, OuterRef, Q from django.utils.translation import gettext_lazy as _ +from ..utils import load_model + +RegisteredUser = load_model("RegisteredUser") + class RegisteredUserFilter(SimpleListFilter): title = _("Verified") @@ -14,8 +19,28 @@ def lookups(self, request, model_admin): ) def queryset(self, request, queryset): + if self.value() is None: + return queryset + where = Q() + if not request.user.is_superuser: + where &= Q( + registered_users__organization__in=request.user.organizations_managed + ) if self.value() == "unknown": - return queryset.filter(registered_user__isnull=True) + if not request.user.is_superuser: + # Restrict the "unknown" check to organizations managed by the + # current admin. A plain `registered_users__isnull=True` filter + # would treat users registered in other organizations as known + # and incorrectly exclude them from the results. + registered_users = RegisteredUser.objects.filter( + user=OuterRef("pk"), + organization__in=request.user.organizations_managed, + ) + return queryset.annotate( + has_managed_registered_user=Exists(registered_users) + ).filter(has_managed_registered_user=False) + + where &= Q(registered_users__isnull=True) elif self.value(): - return queryset.filter(registered_user__is_verified=self.value() == "true") - return queryset + where &= Q(registered_users__is_verified=self.value() == "true") + return queryset.filter(where).distinct() diff --git a/openwisp_radius/base/models.py b/openwisp_radius/base/models.py index c78cbb08..9eb70b0a 100644 --- a/openwisp_radius/base/models.py +++ b/openwisp_radius/base/models.py @@ -14,7 +14,7 @@ from django.conf import settings from django.contrib.auth import get_user_model from django.core.cache import cache -from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.exceptions import ValidationError from django.core.mail import send_mail from django.core.serializers.json import DjangoJSONEncoder from django.db import models, transaction @@ -174,6 +174,9 @@ _LOGIN_URL_HELP_TEXT = _("Enter the URL where users can log in to the wifi service") _STATUS_URL_HELP_TEXT = _("Enter the URL where users can log out from the wifi service") _PASSWORD_RESET_URL_HELP_TEXT = _("Enter the URL where users can reset their password") +_REGISTRATION_UNIQUE_VALIDATION_ERROR = _( + "A user cannot have more than one registration record in the same organization." +) OPTIONAL_SETTINGS = app_settings.OPTIONAL_REGISTRATION_FIELDS @@ -1058,10 +1061,22 @@ def save_user(self, user): OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") RegisteredUser = swapper.load_model("openwisp_radius", "RegisteredUser") user.save() - registered_user = RegisteredUser(user=user, method="manual") - if self.organization.radius_settings.needs_identity_verification: + radius_settings = self.organization.radius_settings + registered_user, created = RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=self.organization, + defaults={ + "method": "manual", + "is_verified": radius_settings.needs_identity_verification, + }, + ) + if ( + not created + and self.organization.radius_settings.needs_identity_verification + ): + registered_user.method = "manual" registered_user.is_verified = True - registered_user.save() + registered_user.save() self.users.add(user) if OrganizationUser.objects.filter( user=user, organization=self.organization @@ -1452,7 +1467,7 @@ def delete_cache(self, *args, **kwargs): cache.delete(f"ip-{self.organization.pk}") -class AbstractPhoneToken(TimeStampedEditableModel): +class AbstractPhoneToken(OrgMixin, TimeStampedEditableModel): """ Phone Verification Token (sent via SMS) """ @@ -1541,15 +1556,13 @@ def save(self, *args, **kwargs): return result def send_token(self): - OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") - org_user = OrganizationUser.objects.filter(user=self.user).first() - if not org_user: + if self.organization is None: raise exceptions.NoOrgException( _("The user {user} is not member of any organization").format( user=self.user ) ) - org_radius_settings = org_user.organization.radius_settings + org_radius_settings = self.organization.radius_settings message = _(org_radius_settings.sms_message).format( organization=org_radius_settings.organization.name, code=self.token ) @@ -1560,28 +1573,33 @@ def send_token(self): ) sms_message.send(meta_data=org_radius_settings.sms_meta_data) - def is_valid(self, token): + def is_valid(self, token, organization=None): self.attempts += 1 try: - self.verified = self.__check(token) + self.verified = self.__check(token, organization=organization) except exceptions.PhoneTokenException as phone_error: self.save() raise phone_error self.save() return self.verified - def _validate_already_verified(self): - try: - if self.user.registered_user.is_verified: - logger.warning(f"User {self.user.pk} is already verified") - raise exceptions.UserAlreadyVerified( - _("This user has been already verified.") - ) - except ObjectDoesNotExist: - pass + def _validate_already_verified(self, organization=None): + RegisteredUser = swapper.load_model("openwisp_radius", "RegisteredUser") + if organization is not None: + reg_user = RegisteredUser.get_for_user_and_org(self.user, organization) + is_verified = reg_user is not None and reg_user.is_verified + else: + is_verified = RegisteredUser.objects.filter( + user=self.user, is_verified=True + ).exists() + if is_verified: + logger.warning(f"User {self.user.pk} is already verified") + raise exceptions.UserAlreadyVerified( + _("This user has been already verified.") + ) - def __check(self, token): - self._validate_already_verified() + def __check(self, token, organization=None): + self._validate_already_verified(organization=organization) if self.attempts > app_settings.SMS_TOKEN_MAX_ATTEMPTS: logger.warning( f"User {self.user} has reached the max " @@ -1603,12 +1621,11 @@ def __check(self, token): return token == self.token -class AbstractRegisteredUser(models.Model): - user = models.OneToOneField( +class AbstractRegisteredUser(UUIDModel, OrgMixin): + user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="registered_user", - primary_key=True, + related_name="registered_users", ) method = models.CharField( _("registration method"), @@ -1640,7 +1657,7 @@ class AbstractRegisteredUser(models.Model): default=False, ) modified = AutoLastModifiedField(_("Last verification change"), editable=True) - _weak_verification_methods = {"", "email"} + _weak_verification_methods = {"", "email", "pending_verification"} @property def is_identity_verified_strong(self): @@ -1650,6 +1667,39 @@ class Meta: abstract = True verbose_name = _("Registration Information") verbose_name_plural = verbose_name + constraints = [ + models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=_REGISTRATION_UNIQUE_VALIDATION_ERROR, + ), + ] + + @classmethod + def get_or_create_for_user_and_org(cls, user, organization, defaults=None): + defaults = defaults or {} + return cls.objects.get_or_create( + user=user, organization=organization, defaults=defaults + ) + + @classmethod + def get_for_user_and_org(cls, user, organization): + prefetched_registered_users = getattr(user, "prefetched_registered_users", None) + if prefetched_registered_users is None: + prefetched_registered_users = getattr( + user, + "_prefetched_objects_cache", + {}, + ).get("registered_users") + if prefetched_registered_users is not None: + for registered_user in prefetched_registered_users: + if registered_user.organization_id == organization.pk: + return registered_user + return None + try: + return cls.objects.get(user=user, organization=organization) + except cls.DoesNotExist: + return None @classmethod def unverify_inactive_users(cls): diff --git a/openwisp_radius/checks.py b/openwisp_radius/checks.py index 58e55048..2bf1d325 100644 --- a/openwisp_radius/checks.py +++ b/openwisp_radius/checks.py @@ -1,6 +1,8 @@ from django.core import checks +from django.core.exceptions import ImproperlyConfigured from . import settings as app_settings +from .registration import validate_user_settable_registration_methods @checks.register @@ -49,3 +51,21 @@ def check_social_registration_enabled(app_configs, **kwargs): ) ) return errors + + +@checks.register +def check_user_settable_registration_methods(app_configs, **kwargs): + errors = [] + try: + validate_user_settable_registration_methods( + app_settings.USER_SETTABLE_REGISTRATION_METHODS + ) + except ImproperlyConfigured as error: + errors.append( + checks.Error( + msg="Improperly Configured", + hint=str(error), + obj="Settings", + ) + ) + return errors diff --git a/openwisp_radius/integrations/monitoring/tasks.py b/openwisp_radius/integrations/monitoring/tasks.py index f251edc3..d986afde 100644 --- a/openwisp_radius/integrations/monitoring/tasks.py +++ b/openwisp_radius/integrations/monitoring/tasks.py @@ -3,7 +3,7 @@ from celery import shared_task from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType -from django.db.models import Count, Q +from django.db.models import Count, F, Q from django.utils import timezone from swapper import load_model @@ -75,9 +75,9 @@ def _write_user_signup_metric_for_all(metric_key): ) ) # Some manually created users, like superuser may not have a - # RegisteredUser object. We would could them with "unspecified" method + # RegisteredUser object. We would count them with "unspecified" method users_without_registereduser_query = User.objects.filter( - registered_user__isnull=True + registered_users__isnull=True ) if metric_key == "user_signups": users_without_registereduser_query = users_without_registereduser_query.filter( @@ -97,6 +97,8 @@ def _write_user_signup_metric_for_all(metric_key): for method, count in total_registered_users.items(): method = clean_registration_method(method) + if method is None: + continue metric = get_metric_func(organization_id="__all__", registration_method=method) metric_data.append((metric, {"value": count})) Metric.batch_write(metric_data) @@ -111,27 +113,33 @@ def _write_user_signup_metrics_for_orgs(metric_key): else: get_metric_func = _get_total_user_signup_metric - # Get the registration data for the past hour. - # The query returns a tuple of organization_id, registration_method and - # count of users who registered with that organization and method. + # Get registration data grouped by organization and registration method. + # Scope OrganizationUser joins to the same organization as the + # RegisteredUser to avoid memberships from other organizations affecting + # this organization's signup metrics. registered_users_query = RegisteredUser.objects.exclude( + method="pending_verification" + ).exclude( + user__openwisp_users_organizationuser__organization_id=F("organization_id"), user__openwisp_users_organizationuser__created__gt=end_time, ) if metric_key == "user_signups": registered_users_query = registered_users_query.filter( + user__openwisp_users_organizationuser__organization_id=F("organization_id"), user__openwisp_users_organizationuser__created__gt=start_time, user__openwisp_users_organizationuser__created__lte=end_time, ) registered_users = registered_users_query.values_list( - "user__openwisp_users_organizationuser__organization_id", "method" + "organization_id", + "method", ).annotate(count=Count("user_id", distinct=True)) - # There could be users which were manually created (e.g. superuser) - # which do not have related RegisteredUser object. Add the count - # of such users with the "unspecified" method. - users_without_registereduser_query = OrganizationUser.objects.filter( - user__registered_user__isnull=True + # Count users without a RegisteredUser for this organization. + # A simple ``registered_users__isnull=True`` check would incorrectly + # exclude users having RegisteredUser rows only in other organizations. + users_without_registereduser_query = OrganizationUser.objects.exclude( + user__registered_users__organization_id=F("organization_id") ) if metric_key == "user_signups": users_without_registereduser_query = users_without_registereduser_query.filter( @@ -146,11 +154,22 @@ def _write_user_signup_metrics_for_orgs(metric_key): for org_id, registration_method, count in registered_users: registration_method = clean_registration_method(registration_method) if registration_method == "unspecified": - count += users_without_registereduser.get(org_id, 0) + count += users_without_registereduser.pop(org_id, 0) metric = get_metric_func( organization_id=org_id, registration_method=registration_method ) metric_data.append((metric, {"value": count})) + + # Write metrics for organizations having only users without a + # RegisteredUser for that organization. These organizations are not + # present in ``registered_users`` because they have no matching + # RegisteredUser rows. + for org_id, count in users_without_registereduser.items(): + metric = get_metric_func( + organization_id=org_id, registration_method="unspecified" + ) + metric_data.append((metric, {"value": count})) + Metric.batch_write(metric_data) @@ -182,18 +201,22 @@ def post_save_radiusaccounting( called_station_id, time=None, ): - try: - registration_method = ( - RegisteredUser.objects.only("method").get(user__username=username).method - ) - except RegisteredUser.DoesNotExist: + registration_method = ( + RegisteredUser.objects.only("method") + .filter(user__username=username, organization_id=organization_id) + .first() + ) + if registration_method is None: logger.info( f'RegisteredUser object not found for "{username}".' ' The metric will be written with "unspecified" registration method!' ) registration_method = "unspecified" else: - registration_method = clean_registration_method(registration_method) + registration_method = registration_method.method + registration_method = clean_registration_method(registration_method) + if registration_method is None: + registration_method = "unspecified" device_lookup = Q(mac_address__iexact=called_station_id.replace("-", ":")) extra_tags = { "method": registration_method, diff --git a/openwisp_radius/integrations/monitoring/tests/test_metrics.py b/openwisp_radius/integrations/monitoring/tests/test_metrics.py index 8a3f6dd7..4c4450ca 100644 --- a/openwisp_radius/integrations/monitoring/tests/test_metrics.py +++ b/openwisp_radius/integrations/monitoring/tests/test_metrics.py @@ -4,6 +4,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.cache import cache from django.test import tag +from django.utils import timezone from swapper import load_model from openwisp_radius.tests import _RADACCT @@ -16,13 +17,50 @@ TASK_PATH = "openwisp_radius.integrations.monitoring.tasks" RegisteredUser = load_model("openwisp_radius", "RegisteredUser") +OrganizationUser = load_model("openwisp_users", "OrganizationUser") User = get_user_model() @tag("radius_monitoring") class TestMetrics(CreateDeviceMonitoringMixin, BaseTransactionTestCase): + def _read_chart(self, chart, **kwargs): + return chart.read( + additional_query_kwargs={"additional_params": kwargs}, + ) + + def _get_metric_traces(self, metric_key, organization_id): + chart = self.metric_model.objects.get(key=metric_key).chart_set.first() + points = self._read_chart( + chart, + organization_id=[str(organization_id)], + ) + return {trace_name: values[-1] for trace_name, values in points["traces"]} + + def _assert_pending_verification_excluded(self, points): + """ + Ensure that pending_verification users do not contribute + to metric outputs. + + This validates both: + - trace-level values (time series data) + - summary-level aggregation + """ + self.assertEqual(points["traces"][0][1][-1], 0) + summary = points.get("summary", {}) + # Summary should not contain any positive counts + for key, value in summary.items(): + self.assertEqual( + value, + 0, + f"pending_verification leaked into summary for key={key}", + ) + def _create_registered_user(self, **kwargs): - options = {"is_verified": False, "method": "mobile_phone"} + options = { + "is_verified": False, + "method": "mobile_phone", + "organization": self.default_org, + } options.update(**kwargs) if "user" not in options: options["user"] = self._create_user() @@ -238,6 +276,7 @@ def test_post_save_radius_accounting_device_not_found(self, mocked_logger): convert_called_station_id feature, but it is not configured properly leaving all called_station_id unconverted. """ + cache.clear() user = self._create_user() reg_user = self._create_registered_user(user=user) options = _RADACCT.copy() @@ -254,7 +293,6 @@ def test_post_save_radius_accounting_device_not_found(self, mocked_logger): options["stop_time"] = options["start_time"] # Remove calls for user registration from mocked logger mocked_logger.reset_mock() - self._create_radius_accounting(**options) self.assertEqual( self.metric_model.objects.filter( @@ -368,14 +406,102 @@ def test_post_save_radius_accounting_registereduser_not_found(self, mocked_logge ' The metric will be written with "unspecified" registration method!' ) + def test_post_save_radiusaccounting_pending_verification(self): + """ + Test that when a user has a RegisteredUser with method="pending_verification", + the metric is written with "unspecified" instead of None. + """ + user = self._create_user() + self._create_registered_user(user=user, method="pending_verification") + device = self._create_device() + device_loc = self._create_device_location( + content_object=device, + location=self._create_location(organization=device.organization), + ) + options = _RADACCT.copy() + options.update( + { + "unique_id": "pending_001", + "username": user.username, + "called_station_id": device.mac_address.replace("-", ":").upper(), + "calling_station_id": "00:00:00:00:00:00", + "input_octets": "8000000000", + "output_octets": "9000000000", + } + ) + options["stop_time"] = options["start_time"] + self._create_radius_accounting(**options) + self.assertEqual( + self.metric_model.objects.filter( + configuration="radius_acc", + name="RADIUS Accounting", + key="radius_acc", + object_id=str(device.id), + content_type=ContentType.objects.get_for_model(self.device_model), + extra_tags={ + "called_station_id": device.mac_address, + "calling_station_id": sha1_hash("00:00:00:00:00:00"), + "location_id": str(device_loc.location.id), + "method": "unspecified", + "organization_id": str(self.default_org.id), + }, + ).count(), + 1, + ) + + def test_post_save_radiusaccounting_does_not_fallback_to_other_org( + self, + ): + """ + Test that a RegisteredUser from another organization is not used + when accounting is written for the current organization. + """ + user = self._create_user() + self._create_registered_user( + user=user, organization=self.default_org, method="mobile_phone" + ) + org2 = self._create_org(name="metrics-org-2", slug="metrics-org-2") + self._create_org_user(user=user, organization=org2) + self._create_registered_user(user=user, organization=org2, method="email") + device = self._create_device() + device_loc = self._create_device_location( + content_object=device, + location=self._create_location(organization=device.organization), + ) + options = _RADACCT.copy() + options.update( + { + "unique_id": "org_spec_001", + "username": user.username, + "called_station_id": device.mac_address.replace("-", ":").upper(), + "calling_station_id": "00:00:00:00:00:00", + "input_octets": "8000000000", + "output_octets": "9000000000", + } + ) + options["stop_time"] = options["start_time"] + self._create_radius_accounting(**options) + self.assertEqual( + self.metric_model.objects.filter( + configuration="radius_acc", + name="RADIUS Accounting", + key="radius_acc", + object_id=str(device.id), + content_type=ContentType.objects.get_for_model(self.device_model), + extra_tags={ + "called_station_id": device.mac_address, + "calling_station_id": sha1_hash("00:00:00:00:00:00"), + "location_id": str(device_loc.location.id), + "method": "mobile_phone", + "organization_id": str(self.default_org.id), + }, + ).count(), + 1, + ) + def test_write_user_registration_metrics(self): from ..tasks import write_user_registration_metrics - def _read_chart(chart, **kwargs): - return chart.read( - additional_query_kwargs={"additional_params": kwargs}, - ) - # The TransactionTestCase truncates all the data after each test. # The general metrics and charts which are created by migrations # get deleted after each test. Therefore, we create them again here. @@ -393,21 +519,25 @@ def _read_chart(chart, **kwargs): write_user_registration_metrics.delay() user_signup_chart = user_signup_metric.chart_set.first() - all_points = _read_chart(user_signup_chart, organization_id=["__all__"]) + all_points = self._read_chart( + user_signup_chart, organization_id=["__all__"] + ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart(user_signup_chart, organization_id=[str(org.id)]) + org_points = self._read_chart( + user_signup_chart, organization_id=[str(org.id)] + ) self.assertEqual(len(org_points["traces"]), 0) total_user_signup_chart = total_user_signup_metric.chart_set.first() - all_points = _read_chart( + all_points = self._read_chart( total_user_signup_chart, organization_id=["__all__"] ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=[str(org.id)] ) self.assertEqual(len(org_points["traces"]), 0) @@ -421,23 +551,27 @@ def _read_chart(chart, **kwargs): write_user_registration_metrics.delay() user_signup_chart = user_signup_metric.chart_set.first() - all_points = _read_chart(user_signup_chart, organization_id=["__all__"]) + all_points = self._read_chart( + user_signup_chart, organization_id=["__all__"] + ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart(user_signup_chart, organization_id=[str(org.id)]) + org_points = self._read_chart( + user_signup_chart, organization_id=[str(org.id)] + ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) total_user_signup_chart = total_user_signup_metric.chart_set.first() - all_points = _read_chart( + all_points = self._read_chart( total_user_signup_chart, organization_id=["__all__"] ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=[str(org.id)] ) self.assertEqual(all_points["traces"][0][0], "unspecified") @@ -454,13 +588,17 @@ def _read_chart(chart, **kwargs): write_user_registration_metrics.delay() user_signup_chart = user_signup_metric.chart_set.first() - all_points = _read_chart(user_signup_chart, organization_id=["__all__"]) + all_points = self._read_chart( + user_signup_chart, organization_id=["__all__"] + ) self.assertEqual(all_points["traces"][0][0], "mobile_phone") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual( all_points["summary"], {"mobile_phone": 1, "unspecified": 0} ) - org_points = _read_chart(user_signup_chart, organization_id=[str(org.id)]) + org_points = self._read_chart( + user_signup_chart, organization_id=[str(org.id)] + ) self.assertEqual(all_points["traces"][0][0], "mobile_phone") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual( @@ -468,7 +606,7 @@ def _read_chart(chart, **kwargs): ) total_user_signup_chart = total_user_signup_metric.chart_set.first() - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=["__all__"] ) self.assertEqual(org_points["traces"][0][0], "mobile_phone") @@ -476,7 +614,7 @@ def _read_chart(chart, **kwargs): self.assertEqual( org_points["summary"], {"mobile_phone": 1, "unspecified": 0} ) - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=[str(org.id)] ) self.assertEqual(all_points["traces"][0][0], "mobile_phone") @@ -484,3 +622,147 @@ def _read_chart(chart, **kwargs): self.assertEqual( all_points["summary"], {"mobile_phone": 1, "unspecified": 0} ) + + def test_pending_verification_excluded_from_metrics(self): + from ..tasks import write_user_registration_metrics + + cache.clear() + create_general_metrics(None, None) + org = self._create_org(name="pending_verification_test_org") + user_signup_metric = self.metric_model.objects.get(key="user_signups") + total_user_signup_metric = self.metric_model.objects.get(key="tot_user_signups") + user = self._create_org_user(organization=org).user + self._create_registered_user( + user=user, organization=org, method="pending_verification" + ) + write_user_registration_metrics.delay() + + user_signup_chart = user_signup_metric.chart_set.first() + org_points = self._read_chart(user_signup_chart, organization_id=[str(org.pk)]) + all_points = self._read_chart(user_signup_chart, organization_id=["__all__"]) + self.assertEqual(len(org_points["traces"]), 0) + self._assert_pending_verification_excluded(all_points) + + total_user_signup_chart = total_user_signup_metric.chart_set.first() + org_points = self._read_chart( + total_user_signup_chart, organization_id=[str(org.pk)] + ) + all_points = self._read_chart( + total_user_signup_chart, organization_id=["__all__"] + ) + self.assertEqual(len(org_points["traces"]), 0) + self._assert_pending_verification_excluded(all_points) + + def test_write_user_registration_metrics_uses_org_specific_methods(self): + """ + Ensure organization metrics use the registration method associated + with that specific organization membership. + + Scenario: + - One user belongs to two organizations. + - The user has one RegisteredUser row per organization. + - Each RegisteredUser uses a different registration method. + + Expected behavior: + - Global metrics aggregate both methods. + - Each organization only counts its own method. + """ + from ..tasks import write_user_registration_metrics + + def _get_metric_traces(metric_key, organization_id): + chart = self.metric_model.objects.get(key=metric_key).chart_set.first() + points = self._read_chart( + chart, + organization_id=[str(organization_id)], + ) + return {trace_name: values[-1] for trace_name, values in points["traces"]} + + cache.clear() + create_general_metrics(None, None) + org1 = self._get_org() + org2 = self._create_org(name="org2", slug="org2") + user = self._create_user() + self._create_org_user(user=user, organization=org1) + self._create_org_user(user=user, organization=org2) + self._create_registered_user( + user=user, + organization=org1, + method="mobile_phone", + ) + self._create_registered_user( + user=user, + organization=org2, + method="email", + ) + write_user_registration_metrics.delay() + for metric_key in ["user_signups", "tot_user_signups"]: + all_points = _get_metric_traces(metric_key, "__all__") + org1_points = _get_metric_traces(metric_key, org1.pk) + org2_points = _get_metric_traces(metric_key, org2.pk) + + # Global metrics aggregate registrations from all organizations. + self.assertEqual(all_points.get("mobile_phone", 0), 1) + self.assertEqual(all_points.get("email", 0), 1) + + # org1 only counts its own registration method. + self.assertEqual(org1_points.get("mobile_phone", 0), 1) + self.assertEqual(org1_points.get("email", 0), 0) + + # org2 only counts its own registration method. + self.assertEqual(org2_points.get("email", 0), 1) + self.assertEqual(org2_points.get("mobile_phone", 0), 0) + + def test_write_user_registration_metrics_scopes_membership_window_per_org( + self, + ): + """ + Ensure signup metrics scope organization membership windows per organization. + + Scenario: + - One user belongs to two organizations. + - The membership in org1 was created before the metric window. + - The membership in org2 was created within the metric window. + - The user has a RegisteredUser only for org1. + + Expected behavior: + - org1 does not count the user in ``user_signups`` because the + membership is outside the current window. + - org2 counts the user as ``unspecified`` in ``user_signups`` because + the membership is within the current window and no RegisteredUser + exists for org2. + - ``tot_user_signups`` still counts org1 with its registration method. + - org2 must not inherit org1's registration method. + """ + from ..tasks import write_user_registration_metrics + + cache.clear() + create_general_metrics(None, None) + org1 = self._get_org() + org2 = self._create_org(name="org2-window-scope", slug="org2-window-scope") + old_time = timezone.now() - timezone.timedelta(hours=2) + user = self._create_user(date_joined=old_time) + org1_membership = self._create_org_user( + user=user, + organization=org1, + ) + OrganizationUser.objects.filter(pk=org1_membership.pk).update(created=old_time) + self._create_registered_user( + user=user, + organization=org1, + method="mobile_phone", + ) + self._create_org_user( + user=user, + organization=org2, + ) + + write_user_registration_metrics.delay() + + org1_user_signups = self._get_metric_traces("user_signups", org1.pk) + org2_user_signups = self._get_metric_traces("user_signups", org2.pk) + org1_total_signups = self._get_metric_traces("tot_user_signups", org1.pk) + org2_total_signups = self._get_metric_traces("tot_user_signups", org2.pk) + self.assertEqual(org1_user_signups.get("mobile_phone", 0), 0) + self.assertEqual(org2_user_signups.get("unspecified", 0), 1) + self.assertEqual(org1_total_signups.get("mobile_phone", 0), 1) + self.assertEqual(org2_total_signups.get("unspecified", 0), 1) diff --git a/openwisp_radius/integrations/monitoring/utils.py b/openwisp_radius/integrations/monitoring/utils.py index 6fdb8eee..2528f479 100644 --- a/openwisp_radius/integrations/monitoring/utils.py +++ b/openwisp_radius/integrations/monitoring/utils.py @@ -51,4 +51,6 @@ def sha1_hash(input_string): def clean_registration_method(method): if method == "": method = "unspecified" + elif method == "pending_verification": + return None return method diff --git a/openwisp_radius/management/commands/base/delete_unverified_users.py b/openwisp_radius/management/commands/base/delete_unverified_users.py index ebefc038..eceb2ce7 100644 --- a/openwisp_radius/management/commands/base/delete_unverified_users.py +++ b/openwisp_radius/management/commands/base/delete_unverified_users.py @@ -2,6 +2,7 @@ from django.contrib.auth import get_user_model from django.core.management import BaseCommand +from django.db.models import Count, Q from django.utils.timezone import now from openwisp_radius.utils import load_model @@ -33,14 +34,23 @@ def handle(self, *args, **options): if exclude_methods: exclude_methods = exclude_methods.split(",") - qs = User.objects.filter( - date_joined__lt=days, - registered_user__isnull=False, - registered_user__is_verified=False, - is_staff=False, + qs = ( + User.objects.filter( + date_joined__lt=days, + registered_users__isnull=False, + is_staff=False, + ) + .annotate( + num_verified=Count( + "registered_users", + filter=Q(registered_users__is_verified=True), + ) + ) + .filter(num_verified=0) + .distinct() ) if exclude_methods: - qs = qs.exclude(registered_user__method__in=exclude_methods) + qs = qs.exclude(registered_users__method__in=exclude_methods) for user in qs.iterator(): if not RadiusAccounting.objects.filter(username=user.username).exists(): diff --git a/openwisp_radius/migrations/0045_registereduser_add_uuid.py b/openwisp_radius/migrations/0045_registereduser_add_uuid.py new file mode 100644 index 00000000..fb780d41 --- /dev/null +++ b/openwisp_radius/migrations/0045_registereduser_add_uuid.py @@ -0,0 +1,190 @@ +import uuid + +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +import swapper +from django.conf import settings +from django.db import migrations, models + +from . import copy_registered_users_ctcr_forward, copy_registered_users_ctcr_reverse + + +def copy_registered_users_forward(apps, schema_editor): + copy_registered_users_ctcr_forward(apps, schema_editor, app_label="openwisp_radius") + + +def copy_registered_users_reverse(apps, schema_editor): + copy_registered_users_ctcr_reverse(apps, schema_editor, app_label="openwisp_radius") + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + swapper.dependency("openwisp_users", "Organization"), + ("openwisp_radius", "0044_convert_user_credentials_data"), + ] + + operations = [ + migrations.AddField( + model_name="phonetoken", + name="organization", + field=models.ForeignKey( + blank=True, + help_text="Organization associated with this phone token.", + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="phone_tokens", + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + migrations.SeparateDatabaseAndState( + state_operations=[ + migrations.AddField( + model_name="registereduser", + name="id", + field=models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + migrations.AlterField( + model_name="registereduser", + name="user", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="registered_users", + to=settings.AUTH_USER_MODEL, + ), + ), + migrations.AddField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + blank=True, + help_text=( + "Organization associated with this registered user entry." + ), + null=True, + related_name="registered_users", + on_delete=django.db.models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + migrations.AddConstraint( + model_name="registereduser", + constraint=models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=( + "A user cannot have more than one registration record " + "in the same organization." + ), + ), + ), + ], + database_operations=[ + migrations.CreateModel( + name="RegisteredUserNew", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "method", + models.CharField( + blank=True, + default="", + help_text=( + "users can sign up in different ways, some " + "methods are valid as indirect identity " + "verification (eg: mobile phone SIM card in " + "most countries)" + ), + max_length=64, + verbose_name="registration method", + ), + ), + ( + "is_verified", + models.BooleanField( + default=False, + help_text=( + "whether the user has completed any identity " + "verification process sucessfully" + ), + verbose_name="verified", + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="Last verification change", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "organization", + models.ForeignKey( + blank=True, + help_text=( + "Organization associated with this registered user" + " entry." + ), + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=swapper.get_model_name( + "openwisp_users", "Organization" + ), + verbose_name="organization", + ), + ), + ], + options={ + "verbose_name": "Registration Information", + "verbose_name_plural": "Registration Information", + "constraints": [ + models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=( + "A user cannot have more than one " + "registration record in the same " + "organization." + ), + ) + ], + }, + ), + migrations.RunPython( + copy_registered_users_forward, + copy_registered_users_reverse, + ), + migrations.DeleteModel(name="RegisteredUser"), + migrations.RenameModel( + old_name="RegisteredUserNew", + new_name="RegisteredUser", + ), + ], + ), + ] diff --git a/openwisp_radius/migrations/0046_registered_user_multitenant_data.py b/openwisp_radius/migrations/0046_registered_user_multitenant_data.py new file mode 100644 index 00000000..7c453bd3 --- /dev/null +++ b/openwisp_radius/migrations/0046_registered_user_multitenant_data.py @@ -0,0 +1,40 @@ +import swapper +from django.conf import settings +from django.db import migrations + +from . import ( + migrate_registered_users_multitenant_forward, + migrate_registered_users_multitenant_reverse, + populate_phonetoken_organization, +) + + +def migrate_registered_users_forward(apps, schema_editor): + migrate_registered_users_multitenant_forward( + apps, schema_editor, app_label="openwisp_radius" + ) + + +def migrate_registered_users_reverse(apps, schema_editor): + migrate_registered_users_multitenant_reverse( + apps, schema_editor, app_label="openwisp_radius" + ) + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + swapper.dependency("openwisp_users", "Organization"), + ("openwisp_radius", "0045_registereduser_add_uuid"), + ] + + operations = [ + migrations.RunPython( + populate_phonetoken_organization, + migrations.RunPython.noop, + ), + migrations.RunPython( + migrate_registered_users_forward, + migrate_registered_users_reverse, + ), + ] diff --git a/openwisp_radius/migrations/0047_registered_user_multitenant_constraints.py b/openwisp_radius/migrations/0047_registered_user_multitenant_constraints.py new file mode 100644 index 00000000..b1090cde --- /dev/null +++ b/openwisp_radius/migrations/0047_registered_user_multitenant_constraints.py @@ -0,0 +1,54 @@ +import swapper +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + swapper.dependency("openwisp_users", "Organization"), + ("openwisp_radius", "0046_registered_user_multitenant_data"), + ] + + operations = [ + migrations.RemoveIndex( + model_name="phonetoken", + name="openwisp_ra_user_id_9fe207_idx", + ), + migrations.RemoveIndex( + model_name="phonetoken", + name="openwisp_ra_user_id_d4dd52_idx", + ), + migrations.AlterField( + model_name="phonetoken", + name="organization", + field=models.ForeignKey( + on_delete=models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + migrations.AddIndex( + model_name="phonetoken", + index=models.Index( + fields=["user", "created"], + name="openwisp_ra_user_id_9fe207_idx", + ), + ), + migrations.AddIndex( + model_name="phonetoken", + index=models.Index( + fields=["user", "created", "ip"], + name="openwisp_ra_user_id_d4dd52_idx", + ), + ), + migrations.AlterField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + on_delete=models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + ] diff --git a/openwisp_radius/migrations/__init__.py b/openwisp_radius/migrations/__init__.py index 45c9abf2..08810ba9 100644 --- a/openwisp_radius/migrations/__init__.py +++ b/openwisp_radius/migrations/__init__.py @@ -4,9 +4,12 @@ from django.conf import settings from django.contrib.auth.management import create_permissions from django.contrib.auth.models import Permission +from django.db.models import Case, IntegerField, Prefetch, Value, When from ..utils import create_default_groups +BATCH_SIZE = 1000 + def get_swapped_model(apps, app_name, model_name): model_path = swapper.get_model_name(app_name, model_name) @@ -14,6 +17,285 @@ def get_swapped_model(apps, app_name, model_name): return apps.get_model(app, model) +def _batched_iterator(iterator, batch_size=BATCH_SIZE): + batch = [] + for item in iterator: + batch.append(item) + if len(batch) >= batch_size: + yield batch + batch = [] + if batch: + yield batch + + +def _flush_bulk_create(model, objects, batch_size=BATCH_SIZE): + if objects: + model.objects.bulk_create(objects, batch_size=batch_size) + objects.clear() + + +def _registered_user_extra_kwargs(registered_user, extra_fields=()): + return { + field_name: getattr(registered_user, field_name) for field_name in extra_fields + } + + +def _registered_user_method_priority_case(): + # Strong methods (anything that is not '' or 'email') must rank above the + # weak fallbacks so rollback restores the strongest verification state. + return Case( + When(method="pending_verification", then=Value(-1)), + When(method="", then=Value(0)), + When(method="email", then=Value(1)), + default=Value(2), + output_field=IntegerField(), + ) + + +def copy_registered_users_ctcr_forward( + apps, + schema_editor, + app_label, + new_model_name="RegisteredUserNew", + extra_fields=(), +): + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + RegisteredUserNew = apps.get_model(app_label, new_model_name) + if RegisteredUser._meta.swapped: + return + + new_objects = [] + queryset = RegisteredUser.objects.order_by("user_id") + for registered_user in queryset.iterator(chunk_size=BATCH_SIZE): + copied = RegisteredUserNew( + id=uuid.uuid4(), + user_id=registered_user.user_id, + organization=None, + method=registered_user.method, + is_verified=registered_user.is_verified, + **_registered_user_extra_kwargs(registered_user, extra_fields), + ) + copied.modified = registered_user.modified + new_objects.append(copied) + if len(new_objects) >= BATCH_SIZE: + _flush_bulk_create(RegisteredUserNew, new_objects) + _flush_bulk_create(RegisteredUserNew, new_objects) + + +def copy_registered_users_ctcr_reverse( + apps, + schema_editor, + app_label, + new_model_name="RegisteredUserNew", + extra_fields=(), +): + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + RegisteredUserNew = apps.get_model(app_label, new_model_name) + if RegisteredUser._meta.swapped: + return + + restored_objects = [] + previous_user_id = None + # Annotate each row with an explicit verification priority so that stronger + # methods (anything that is not '' or 'email') sort before weaker ones. + # Lexical ordering of 'method' would place '' first, picking the weakest. + method_priority = _registered_user_method_priority_case() + queryset = RegisteredUserNew.objects.annotate( + method_priority=method_priority + ).order_by("user_id", "-is_verified", "-method_priority", "-modified") + for registered_user in queryset.iterator(chunk_size=BATCH_SIZE): + if registered_user.user_id == previous_user_id: + continue + previous_user_id = registered_user.user_id + restored = RegisteredUser( + user_id=registered_user.user_id, + method=registered_user.method, + is_verified=registered_user.is_verified, + **_registered_user_extra_kwargs(registered_user, extra_fields), + ) + restored.modified = registered_user.modified + restored_objects.append(restored) + if len(restored_objects) >= BATCH_SIZE: + _flush_bulk_create(RegisteredUser, restored_objects) + _flush_bulk_create(RegisteredUser, restored_objects) + + +def migrate_registered_users_multitenant_forward( + apps, schema_editor, app_label, extra_fields=() +): + """ + Expand legacy org-less RegisteredUser rows into organization-specific rows. + + Before this migration, RegisteredUser is effectively single-tenant and users + are expected to have at most one row where organization IS NULL. That row is + treated as the template for all organization-specific rows created during the + migration. + + For each user, the migration: + 1. Finds the org-less RegisteredUser row. + 2. Creates one RegisteredUser per OrganizationUser membership. + 3. Deletes the original org-less row. + + Implementation notes: + - Assumes each user has at most one org-less RegisteredUser row. + - Prioritizes readability and explicit control flow over aggressive SQL/JOIN + optimization. + - Avoids JOIN-based filtering to keep migration assumptions visible in Python + and reduce duplicate-row/DISTINCT complexity. + - Uses iterator(), prefetch_related(), bulk_create(), and batched deletes to + remain memory bounded while processing large datasets. + """ + User = apps.get_model(settings.AUTH_USER_MODEL) + RegisteredUser = apps.get_model( + app_label, + "RegisteredUser", + ) + OrganizationUser = get_swapped_model( + apps, + "openwisp_users", + "OrganizationUser", + ) + + queryset = User.objects.prefetch_related( + Prefetch( + "registered_users", + queryset=RegisteredUser.objects.only( + "id", + "user_id", + "organization_id", + "method", + "is_verified", + "modified", + *extra_fields, + ), + # Store prefetched objects directly as a Python list to avoid + # additional queryset evaluation during iteration. + to_attr="prefetched_registered_users", + ), + Prefetch( + "openwisp_users_organizationuser", + queryset=OrganizationUser.objects.only( + "user_id", + "organization_id", + ), + to_attr="organization_memberships", + ), + ).order_by("id") + + to_create = [] + for user in queryset.iterator(chunk_size=BATCH_SIZE): + # Locate the legacy org-less RegisteredUser row that acts as the source + # template for new organization-specific rows. + # + # We intentionally do this in Python instead of SQL because: + # + # - the prefetched list is expected to be extremely small + # (ideally, it will contain at most one item due to the migration invariant) + # - it keeps migration assumptions explicit, + # - and avoids introducing JOIN + DISTINCT complexity. + base_registered_user = next( + ( + registered_user + for registered_user in user.prefetched_registered_users + if registered_user.organization_id is None + ), + None, + ) + # Users without a legacy org-less RegisteredUser row require no work. + if not base_registered_user: + continue + + # Create one RegisteredUser row per organization membership. + for membership in user.organization_memberships: + copied = RegisteredUser( + id=uuid.uuid4(), + user_id=user.id, + organization_id=membership.organization_id, + method=base_registered_user.method, + is_verified=base_registered_user.is_verified, + **_registered_user_extra_kwargs( + base_registered_user, + extra_fields, + ), + ) + # Preserve the original modification timestamp because this migration + # reshapes existing data rather than creating a logically new + # verification state. + copied.modified = base_registered_user.modified + to_create.append(copied) + + # Flush inserts in batches to avoid holding too many unsaved model + # instances in memory. + if len(to_create) >= BATCH_SIZE: + _flush_bulk_create( + RegisteredUser, + to_create, + ) + + _flush_bulk_create( + RegisteredUser, + to_create, + ) + + # Delete all remaining legacy org-less RegisteredUser rows. + # + # This covers: + # 1. Users whose org-less row was expanded into org-specific rows above. + # 2. Users with an org-less row but zero organization memberships. + # These users have no org-specific rows to migrate to, and keeping + # an org-less row would violate the new (user, organization) unique + # constraint, so the row is intentionally cleaned up here. + RegisteredUser.objects.filter( + organization__isnull=True, + ).delete() + + +def migrate_registered_users_multitenant_reverse( + apps, schema_editor, app_label, extra_fields=() +): + # Keep the strongest RegisteredUser per user and delete the weaker duplicates. + # Ranking is by: verified over unverified, stronger method over weaker method, + # then newer modified timestamps over older ones. + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + # Process users in batches so the migration scales to large tables without + # issuing one query per user. + user_ids_qs = ( + RegisteredUser.objects.order_by().values_list("user_id", flat=True).distinct() + ) + for user_id_batch in _batched_iterator( + user_ids_qs.iterator(chunk_size=BATCH_SIZE), BATCH_SIZE + ): + # Annotate each row with an explicit verification priority so that stronger + # methods (anything that is not '' or 'email') sort before weaker ones. + method_priority = _registered_user_method_priority_case() + ranked_registered_users = ( + RegisteredUser.objects.filter( + user_id__in=user_id_batch, + ) + .annotate(method_priority=method_priority) + .order_by("user_id", "-is_verified", "-method_priority", "-modified") + ) + to_delete_pks = [] + current_user_id = None + for registered_user in ranked_registered_users.iterator(chunk_size=BATCH_SIZE): + # Rows for the same user are consecutive because of the ordering + # above, and the first row in each group is the strongest one. + # Every later row for that user is therefore a weaker duplicate. + is_duplicate_for_user = registered_user.user_id == current_user_id + if is_duplicate_for_user: + to_delete_pks.append(registered_user.pk) + else: + current_user_id = registered_user.user_id + if len(to_delete_pks) >= BATCH_SIZE: + RegisteredUser.objects.filter(pk__in=to_delete_pks).delete() + to_delete_pks.clear() + + # Delete all weaker rows for the batch at once rather than issuing a + # separate delete for each user. + if to_delete_pks: + RegisteredUser.objects.filter(pk__in=to_delete_pks).delete() + + def delete_old_radius_token(apps, schema_editor): RadiusToken = get_swapped_model(apps, "openwisp_radius", "RadiusToken") RadiusToken.objects.all().delete() @@ -145,3 +427,69 @@ def populate_phonetoken_phone_number(apps, schema_editor): for phone_token in PhoneToken.objects.all(): phone_token.phone_number = phone_token.user.phone_number phone_token.save(update_fields=["phone_number"]) + + +def _get_first_membership_organization_id( + user_id, + OrganizationUser, +): + return ( + OrganizationUser.objects.filter( + user_id=user_id, + ) + .order_by("created", "pk") + .values_list("organization_id", flat=True) + .first() + ) + + +def populate_phonetoken_organization( + apps, + schema_editor, + app_label="openwisp_radius", +): + """Populate PhoneToken.organization_id from the user's first organization. + + For each user that has PhoneToken rows with a null organization_id, + find the user's first OrganizationUser membership (ordered by created, pk) + and set that organization_id on all their PhoneToken records that are + still null. + + Any rows that cannot be resolved to an organization are + discarded before the later NOT NULL migration step. + + Operates using the provided apps registry (for migrations). + + Args: + apps: Django apps registry passed to migrations functions. + schema_editor: Schema editor passed to migrations functions (unused). + app_label: App label to load the PhoneToken model from. + """ + PhoneToken = apps.get_model(app_label, "PhoneToken") + OrganizationUser = get_swapped_model( + apps, + "openwisp_users", + "OrganizationUser", + ) + user_ids = ( + PhoneToken.objects.filter( + organization_id__isnull=True, + ) + .order_by() + .values_list("user_id", flat=True) + .distinct() + ) + for user_id in user_ids.iterator(chunk_size=BATCH_SIZE): + organization_id = _get_first_membership_organization_id( + user_id, + OrganizationUser, + ) + if organization_id is None: + continue + PhoneToken.objects.filter( + user_id=user_id, + organization_id__isnull=True, + ).update( + organization_id=organization_id, + ) + PhoneToken.objects.filter(organization_id__isnull=True).delete() diff --git a/openwisp_radius/registration.py b/openwisp_radius/registration.py index e376232d..4e75b333 100644 --- a/openwisp_radius/registration.py +++ b/openwisp_radius/registration.py @@ -10,6 +10,7 @@ ("manual", _("Manually created")), ("email", _("Email")), ("mobile_phone", _("Mobile phone")), + ("pending_verification", _("Pending Verification")), ] AUTHORIZE_UNVERIFIED = [] @@ -58,3 +59,31 @@ def unregister_registration_method(name, fail_loud=True): def get_registration_choices(): return REGISTRATION_METHOD_CHOICES + + +def validate_user_settable_registration_methods(methods): + if not isinstance(methods, (list, tuple)): + raise ImproperlyConfigured( + "OPENWISP_RADIUS_USER_SETTABLE_REGISTRATION_METHODS must be a list or tuple" + ) + methods = list(methods) + duplicates = [] + seen = set() + for method in methods: + if method in seen and method not in duplicates: + duplicates.append(method) + seen.add(method) + if duplicates: + raise ImproperlyConfigured( + "OPENWISP_RADIUS_USER_SETTABLE_REGISTRATION_METHODS contains duplicate " + f"values: {', '.join(repr(method) for method in duplicates)}" + ) + available_choices = dict(get_registration_choices()) + invalid_methods = [method for method in methods if method not in available_choices] + if invalid_methods: + raise ImproperlyConfigured( + "OPENWISP_RADIUS_USER_SETTABLE_REGISTRATION_METHODS contains unknown " + f"values: {', '.join(repr(method) for method in invalid_methods)}" + ) + + return [(method, available_choices[method]) for method in methods] diff --git a/openwisp_radius/saml/backends.py b/openwisp_radius/saml/backends.py index f61d5d55..3c55a657 100644 --- a/openwisp_radius/saml/backends.py +++ b/openwisp_radius/saml/backends.py @@ -1,4 +1,3 @@ -from django.core.exceptions import ObjectDoesNotExist from djangosaml2.backends import Saml2Backend from .. import settings as app_settings @@ -12,20 +11,27 @@ def _update_user(self, user, attributes, attribute_mapping, force_save=False): ): # Skip updating user's username if the user didn't signed up # with SAML registration method. - try: - attribute_mapping = attribute_mapping.copy() - if user.registered_user.method != "saml": - for key, value in attribute_mapping.items(): - if "username" in value: - break - if len(value) == 1: - attribute_mapping.pop(key, None) - else: - attribute_mapping[key] = [] - for attr in value: - if attr != "username": - attribute_mapping[key].append(attr) - - except ObjectDoesNotExist: - pass + attribute_mapping = attribute_mapping.copy() + # Check if any of the user's registered_users records + # were NOT created via SAML. + # NOTE: This uses a global check (any org) rather than org-specific. + # This is intentionally conservative: if a user has ever signed up + # via a non-SAML method in any org, their username won't be updated + # during SAML login in any org. This prevents the SAML identity + # provider from overwriting a username set or preferred by the user + # elsewhere. Since the User model is shared across organizations, + # updating the username based solely on one org's SAML flow could + # unexpectedly change the user's identity in other orgs. + has_non_saml = user.registered_users.exclude(method="saml").exists() + if has_non_saml: + for key, value in attribute_mapping.items(): + if "username" in value: + break + if len(value) == 1: + attribute_mapping.pop(key, None) + else: + attribute_mapping[key] = [] + for attr in value: + if attr != "username": + attribute_mapping[key].append(attr) return super()._update_user(user, attributes, attribute_mapping, force_save) diff --git a/openwisp_radius/saml/views.py b/openwisp_radius/saml/views.py index 95bf5a25..74e6beb4 100644 --- a/openwisp_radius/saml/views.py +++ b/openwisp_radius/saml/views.py @@ -9,6 +9,7 @@ from django.contrib.auth import get_user_model, logout from django.contrib.auth.mixins import LoginRequiredMixin from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError +from django.db import transaction from django.shortcuts import get_object_or_404, redirect, render from django.urls import reverse from django.views.generic import UpdateView @@ -67,30 +68,64 @@ def post_login_hook(self, request, user, session_info): org = self.get_organization_from_relay_state() is_member = user.is_member(org) # add user to organization - if not is_member: - orgUser = OrganizationUser(organization=org, user=user) - orgUser.full_clean() - orgUser.save() - try: - user.registered_user - except ObjectDoesNotExist: - registered_user = RegisteredUser( - user=user, method="saml", is_verified=app_settings.SAML_IS_VERIFIED + with transaction.atomic(): + if not is_member: + orgUser = OrganizationUser(organization=org, user=user) + orgUser.full_clean() + orgUser.save() + registered_user, created = RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=org, + defaults={ + "method": "saml", + "is_verified": app_settings.SAML_IS_VERIFIED, + }, ) - registered_user.full_clean() - registered_user.save() - # The user is just created, it will not have an email address + if ( + not created + and registered_user.method == "pending_verification" + and not registered_user.is_verified + ): + registered_user.method = "saml" + registered_user.is_verified = app_settings.SAML_IS_VERIFIED + registered_user.full_clean() + registered_user.save() if user.email: try: - email_address = EmailAddress( - user=user, email=user.email, primary=True, verified=True + user_has_primary_email = EmailAddress.objects.filter( + user=user, primary=True ) - email_address.full_clean() - email_address.save() + try: + email_address = EmailAddress.objects.get( + user=user, email=user.email + ) + except EmailAddress.DoesNotExist: + email_address = EmailAddress( + user=user, + email=user.email, + verified=True, + primary=not user_has_primary_email.exists(), + ) + email_address.full_clean() + email_address.save() + else: + changed_fields = [] + if not email_address.verified: + email_address.verified = True + changed_fields.append("verified") + if ( + not email_address.primary + and not user_has_primary_email.exists() + ): + email_address.primary = True + changed_fields.append("primary") + if changed_fields: + email_address.full_clean() + email_address.save(update_fields=changed_fields) except ValidationError: logger.exception( - f'Failed email validation for "{user}"' - " during SAML user creation" + f'Failed email validation for "{user}" during' + " SAML user creation" ) def customize_relay_state(self, relay_state): diff --git a/openwisp_radius/settings.py b/openwisp_radius/settings.py index e7f908dd..6d67f7e9 100644 --- a/openwisp_radius/settings.py +++ b/openwisp_radius/settings.py @@ -95,6 +95,9 @@ def get_default_password_reset_url(urls): ALLOW_FIXED_LINE_OR_MOBILE = get_settings_value("ALLOW_FIXED_LINE_OR_MOBILE", False) REGISTRATION_API_ENABLED = get_settings_value("REGISTRATION_API_ENABLED", True) NEEDS_IDENTITY_VERIFICATION = get_settings_value("NEEDS_IDENTITY_VERIFICATION", False) +USER_SETTABLE_REGISTRATION_METHODS = get_settings_value( + "USER_SETTABLE_REGISTRATION_METHODS", ["", "email", "mobile_phone"] +) SMS_MESSAGE_TEMPLATE = get_settings_value( "SMS_MESSAGE_TEMPLATE", _("{organization} verification code: {code}") ) @@ -232,10 +235,13 @@ def get_default_password_reset_url(urls): if not hasattr(settings, "OPENWISP_USERS_EXPORT_USERS_COMMAND_CONFIG"): from openwisp_users import settings as ow_users_settings - ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["fields"].extend( - ["registered_user.method", "registered_user.is_verified"] + ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["fields"].append( + { + "name": "registered_users", + "fields": ("organization_id", "method", "is_verified"), + } ) - ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["select_related"].extend( - ["registered_user"] + ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["prefetch_related"].extend( + ["registered_users"] ) BATCH_ASYNC_THRESHOLD = get_settings_value("BATCH_ASYNC_THRESHOLD", 15) diff --git a/openwisp_radius/social/views.py b/openwisp_radius/social/views.py index cc50a3f8..ba9f84ed 100644 --- a/openwisp_radius/social/views.py +++ b/openwisp_radius/social/views.py @@ -1,5 +1,6 @@ import swapper -from django.core.exceptions import ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import PermissionDenied +from django.db import transaction from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import get_object_or_404 from django.utils.translation import gettext_lazy as _ @@ -42,18 +43,22 @@ def authorize(self, request, org, *args, **kwargs): user = request.user is_member = user.is_member(org) # add user to organization - if not is_member: - orgUser = OrganizationUser(organization=org, user=user) - orgUser.full_clean() - orgUser.save() - try: - user.registered_user - except ObjectDoesNotExist: - registered_user = RegisteredUser( - user=user, method="social_login", is_verified=False + with transaction.atomic(): + if not is_member: + orgUser = OrganizationUser(organization=org, user=user) + orgUser.full_clean() + orgUser.save() + registered_user, created = RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=org, + defaults={"method": "social_login", "is_verified": False}, ) - registered_user.full_clean() - registered_user.save() + if not created: + if registered_user.method == "pending_verification": + registered_user.method = "social_login" + registered_user.is_verified = False + registered_user.full_clean() + registered_user.save() def get_redirect_url(self, request, organization): """ diff --git a/openwisp_radius/tests/mixins.py b/openwisp_radius/tests/mixins.py index 1852116d..01e39c19 100644 --- a/openwisp_radius/tests/mixins.py +++ b/openwisp_radius/tests/mixins.py @@ -97,10 +97,10 @@ def _get_user_edit_form_inline_params(self, user, organization): "phonetoken_set-MIN_NUM_FORMS": 0, "phonetoken_set-MAX_NUM_FORMS": 0, # registered user inline - "registered_user-TOTAL_FORMS": 0, - "registered_user-INITIAL_FORMS": 0, - "registered_user-MIN_NUM_FORMS": 0, - "registered_user-MAX_NUM_FORMS": 0, + "registered_users-TOTAL_FORMS": 0, + "registered_users-INITIAL_FORMS": 0, + "registered_users-MIN_NUM_FORMS": 0, + "registered_users-MAX_NUM_FORMS": 0, # radius token inline "radius_token-TOTAL_FORMS": "0", "radius_token-INITIAL_FORMS": "0", diff --git a/openwisp_radius/tests/test_admin.py b/openwisp_radius/tests/test_admin.py index bc829810..4bc8d1da 100644 --- a/openwisp_radius/tests/test_admin.py +++ b/openwisp_radius/tests/test_admin.py @@ -2,10 +2,12 @@ import lxml.html as lxml_html import swapper +from django.contrib import admin from django.contrib.auth import get_user_model from django.contrib.auth.models import Permission from django.core.cache import cache from django.core.exceptions import ImproperlyConfigured +from django.test import RequestFactory from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -671,16 +673,19 @@ def test_backward_compatible_default_password_reset_url(self): f"admin:{self.app_label_users}_organization_add", ) PASSWORD_RESET_URLS = {"default": default_password_reset_url} - with mock.patch.object( - app_settings, - "DEFAULT_PASSWORD_RESET_URL", - app_settings.get_default_password_reset_url(PASSWORD_RESET_URLS), - ), mock.patch.object( - # The default value is set on project startup, hence - # it also requires mocking. - OrganizationRadiusSettings._meta.get_field("password_reset_url"), - "fallback", - app_settings.DEFAULT_PASSWORD_RESET_URL, + with ( + mock.patch.object( + app_settings, + "DEFAULT_PASSWORD_RESET_URL", + app_settings.get_default_password_reset_url(PASSWORD_RESET_URLS), + ), + mock.patch.object( + # The default value is set on project startup, hence + # it also requires mocking. + OrganizationRadiusSettings._meta.get_field("password_reset_url"), + "fallback", + app_settings.DEFAULT_PASSWORD_RESET_URL, + ), ): response = self.client.get(url) self.assertContains(response, default_password_reset_url) @@ -1359,7 +1364,7 @@ def test_inline_registered_user(self): with self.subTest("Inline exists"): response = self.client.get(url) - self.assertContains(response, "id_registered_user-TOTAL_FORMS") + self.assertContains(response, "id_registered_users-TOTAL_FORMS") with self.subTest("Register new choice"): register_registration_method("national_id", "National ID") @@ -1407,6 +1412,66 @@ def test_inline_registered_user(self): register_registration_method("github", "GitHub", strong_identity=False) self.assertIn("github", RegisteredUser._weak_verification_methods) + def test_admin_prevents_duplicate_registered_user_same_org(self): + user = self._create_user(username="dup_test_user", email="dup@test.org") + reg_user = RegisteredUser.objects.create( + user=user, organization=self.default_org, is_verified=True + ) + user_change_url = reverse( + f"admin:{User._meta.app_label}_user_change", args=[user.pk] + ) + response = self.client.get(user_change_url) + self.assertEqual(response.status_code, 200) + data = { + "username": "dup_test_user", + "email": "dup@test.org", + "registered_users-TOTAL_FORMS": "2", + "registered_users-INITIAL_FORMS": "1", + "registered_users-MIN_NUM_FORMS": "0", + "registered_users-MAX_NUM_FORMS": "1000", + "registered_users-0-id": str(reg_user.pk), + "registered_users-0-user": str(user.pk), + "registered_users-0-organization": str(self.default_org.pk), + "registered_users-0-method": "", + "registered_users-0-is_verified": "on", + "registered_users-1-id": "", + "registered_users-1-user": str(user.pk), + "registered_users-1-organization": str(self.default_org.pk), + "registered_users-1-method": "", + "registered_users-1-is_verified": "on", + } + response = self.client.post(user_change_url, data) + self.assertContains(response, "errors") + self.assertContains( + response, + "A user cannot have more than one registration record in the" + " same organization.", + ) + self.assertEqual( + RegisteredUser.objects.filter( + user=user, organization=self.default_org + ).count(), + 1, + ) + + def test_user_admin_shows_multiple_registered_user_records(self): + user = self._create_user(username="multiuser", email="multi@test.org") + org2 = self._create_org(name="org2", slug="org2") + RegisteredUser.objects.create( + user=user, organization=self.default_org, is_verified=True + ) + RegisteredUser.objects.create(user=user, organization=org2, is_verified=False) + user_url = reverse(f"admin:{User._meta.app_label}_user_change", args=[user.pk]) + response = self.client.get(user_url) + self.assertEqual(response.status_code, 200) + self.assertContains( + response, + ( + '' + ), + ) + def test_get_is_verified_user_admin_list(self): unknown = User.objects.first() self.assertIsNotNone(unknown) @@ -1416,7 +1481,10 @@ def test_get_is_verified_user_admin_list(self): verified.full_clean() verified.save() RegisteredUser.objects.create( - user=verified, method="mobile_phone", is_verified=True + user=verified, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) unverified = User.objects.create( username="unverified", password="unverified", email="unverified@test.com" @@ -1424,7 +1492,10 @@ def test_get_is_verified_user_admin_list(self): unverified.full_clean() unverified.save() RegisteredUser.objects.create( - user=unverified, method="mobile_phone", is_verified=False + user=unverified, + organization=self.default_org, + method="mobile_phone", + is_verified=False, ) app_label = User._meta.app_label url = reverse(f"admin:{app_label}_user_changelist") @@ -1440,6 +1511,22 @@ def get_expected_html(value): self.assertContains(response, get_expected_html("no")) self.assertContains(response, get_expected_html("unknown")) + def test_get_is_verified_user_admin_list_avoids_nplus1_queries(self): + app_label = User._meta.app_label + path = reverse(f"admin:{app_label}_user_changelist") + # Create users + for i in range(5): + user = self._create_user(username=f"user-{i}", email=f"user-{i}@test.com") + RegisteredUser.objects.create( + user=user, + organization=self.default_org, + method="mobile_phone", + is_verified=(i % 2 == 0), + ) + with self.assertNumQueries(8): + response = self.client.get(path) + self.assertEqual(response.status_code, 200) + def test_registered_user_filter(self): unknown = User.objects.first() self.assertIsNotNone(unknown) @@ -1449,7 +1536,10 @@ def test_registered_user_filter(self): verified.full_clean() verified.save() RegisteredUser.objects.create( - user=verified, method="mobile_phone", is_verified=True + user=verified, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) unverified = User.objects.create( username="unverified", password="unverified", email="unverified@test.com" @@ -1457,7 +1547,10 @@ def test_registered_user_filter(self): unverified.full_clean() unverified.save() RegisteredUser.objects.create( - user=unverified, method="mobile_phone", is_verified=False + user=unverified, + organization=self.default_org, + method="mobile_phone", + is_verified=False, ) app_label = User._meta.app_label url = reverse(f"admin:{app_label}_user_changelist") @@ -1486,6 +1579,130 @@ def get_expected_html(value): self.assertNotContains(response, get_expected_html("no")) self.assertContains(response, get_expected_html("unknown")) + def test_get_is_verified_scoped_to_managed_organizations(self): + org1 = self._create_org(name="org-1", slug="org-1") + org2 = self._create_org(name="org-2", slug="org-2") + manager = self._create_administrator([org1]) + scoped_user = self._create_user( + username="scoped-user", + email="scoped-user@test.com", + ) + other_org_user = self._create_user( + username="other-org-user", + email="other-org-user@test.com", + ) + self._create_org_user(user=scoped_user, organization=org1) + self._create_org_user(user=other_org_user, organization=org1) + RegisteredUser.objects.create( + user=scoped_user, + organization=org1, + method="mobile_phone", + is_verified=False, + ) + RegisteredUser.objects.create( + user=scoped_user, + organization=org2, + method="mobile_phone", + is_verified=True, + ) + RegisteredUser.objects.create( + user=other_org_user, + organization=org2, + method="mobile_phone", + is_verified=True, + ) + request = RequestFactory().get( + reverse(f"admin:{User._meta.app_label}_user_changelist") + ) + request.user = manager + user_admin = admin.site._registry[User] + queryset = user_admin.get_queryset(request) + scoped_user = queryset.get(pk=scoped_user.pk) + other_org_user = queryset.get(pk=other_org_user.pk) + # The scoped user should show as unverified since the user admin + # should only consider the registration record from the managed + # organization (org1), while the other org user should show as + # unknown since their registration record in the managed + # organization is missing + self.assertIn("icon-no.svg", user_admin.get_is_verified(scoped_user)) + self.assertIn("icon-unknown.svg", user_admin.get_is_verified(other_org_user)) + + def test_registered_user_filter_scoped_to_managed_organizations(self): + org1 = self._create_org(name="org-1", slug="org-1") + org2 = self._create_org(name="org-2", slug="org-2") + manager = self._create_administrator([org1]) + org1_verified = self._create_user( + username="org1-verified", + email="org1-verified@test.com", + ) + common_user_unverified = self._create_user( + username="common-user-unverified", + email="common-user-unverified@test.com", + ) + org2_registered = self._create_user( + username="org2-only", + email="org2-only@test.com", + ) + self._create_org_user(user=org1_verified, organization=org1) + self._create_org_user(user=common_user_unverified, organization=org1) + self._create_org_user(user=org2_registered, organization=org1) + RegisteredUser.objects.create( + user=org1_verified, + organization=org1, + method="mobile_phone", + is_verified=True, + ) + RegisteredUser.objects.create( + user=common_user_unverified, + organization=org1, + method="mobile_phone", + is_verified=False, + ) + RegisteredUser.objects.create( + user=common_user_unverified, + organization=org2, + method="mobile_phone", + is_verified=True, + ) + RegisteredUser.objects.create( + user=org2_registered, + organization=org2, + method="mobile_phone", + is_verified=True, + ) + self.client.force_login(manager) + app_label = User._meta.app_label + url = reverse(f"admin:{app_label}_user_changelist") + + response = self.client.get(url, {"is_verified": "true"}) + self.assertContains(response, org1_verified.username) + self.assertNotContains(response, common_user_unverified.username) + self.assertNotContains(response, org2_registered.username) + + response = self.client.get(url, {"is_verified": "false"}) + self.assertContains(response, common_user_unverified.username) + self.assertNotContains(response, org1_verified.username) + self.assertNotContains(response, org2_registered.username) + + response = self.client.get(url, {"is_verified": "unknown"}) + self.assertContains(response, org2_registered.username) + self.assertNotContains(response, org1_verified.username) + self.assertNotContains(response, common_user_unverified.username) + + def test_registered_user_filter_does_not_limit_default_changelist(self): + org = self._create_org(name="org-filter-default", slug="org-filter-default") + manager = self._create_administrator([org]) + user = self._create_user( + username="no-registered-user", + email="no-registered-user@test.com", + ) + self._create_org_user(user=user, organization=org) + self.client.force_login(manager) + app_label = User._meta.app_label + url = reverse(f"admin:{app_label}_user_changelist") + response = self.client.get(url) + self.assertContains(response, user.username) + def test_admin_menu_groups(self): # Test menu group (openwisp-utils menu group) for RadiusAccounting, RadiusBatch, # RadiusCheck, RadiusGroup, Nas, RadiusPostAuth, RadiusToken, and RadiusReply diff --git a/openwisp_radius/tests/test_api/test_api.py b/openwisp_radius/tests/test_api/test_api.py index d0a6f3d5..81810e4d 100644 --- a/openwisp_radius/tests/test_api/test_api.py +++ b/openwisp_radius/tests/test_api/test_api.py @@ -28,6 +28,8 @@ from openwisp_radius.api.serializers import ( RadiusUserGroupSerializer, RadiusUserSerializer, + RegisterSerializer, + UpdateRegisteredUserMethodSerializer, UserGroupCheckSerializer, ) from openwisp_utils.tests import capture_any_output, capture_stderr @@ -41,6 +43,7 @@ RadiusBatch = load_model("RadiusBatch") RadiusUserGroup = load_model("RadiusUserGroup") RadiusGroup = load_model("RadiusGroup") +RegisteredUser = load_model("RegisteredUser") OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") Organization = swapper.load_model("openwisp_users", "Organization") OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") @@ -60,10 +63,34 @@ def _radius_batch_post_request(self, data, username="admin", password="tester"): login_payload = {"username": username, "password": password} login_url = reverse("radius:user_auth_token", args=[self.default_org.slug]) login_response = self.client.post(login_url, data=login_payload) - header = f'Bearer {login_response.json()["key"]}' + header = f"Bearer {login_response.json()['key']}" url = reverse("radius:batch") return self.client.post(url, data, HTTP_AUTHORIZATION=header) + def _get_update_method_url(self, org=None): + if org is None: + org = self.default_org + return reverse( + "radius:update_registered_user_registration_method", args=[org.slug] + ) + + def _create_pending_verification_user(self, username_suffix=""): + user = self._create_user( + username=f"pendinguser{username_suffix}", + password="tester", + email=f"pendinguser{username_suffix}@test.com", + ) + org2 = self._create_org(name="org2") + OrganizationUser.objects.create(user=user, organization=org2) + RegisteredUser.objects.create( + user=user, + organization=org2, + method="pending_verification", + is_verified=False, + ) + user_token = Token.objects.create(user=user) + return user, org2, user_token + def test_batch_bad_request_400(self): self.assertEqual(RadiusBatch.objects.count(), 0) data = self._radius_batch_prefix_data(number_of_users=-1) @@ -159,7 +186,10 @@ def test_register_201(self): user = User.objects.get(email=self._test_email) self.assertTrue(user.is_member(self.default_org)) self.assertTrue(user.is_active) - self.assertFalse(user.registered_user.is_verified) + self.assertEqual( + user.registered_users.get(organization=self.default_org).is_verified, + False, + ) def test_register_400_password(self): response = self._register_user( @@ -319,19 +349,27 @@ def test_register_duplicate_different_org(self): def test_radius_user_serializer(self): self._register_user() try: - user = User.objects.select_related("radius_token", "registered_user").get( - email=self._test_email + user = ( + User.objects.select_related("radius_token") + .prefetch_related("registered_users") + .get(email=self._test_email) ) - admin = User.objects.select_related("radius_token", "registered_user").get( - username="admin" + admin = ( + User.objects.select_related("radius_token") + .prefetch_related("registered_users") + .get(username="admin") ) except User.DoesNotExist as e: self.fail(f"user not found: {e}") with self.assertNumQueries(0): - data = RadiusUserSerializer(user).data + # Organization is required to get the RegisteredUser object + view = mock.MagicMock() + view.organization = self.default_org + data = RadiusUserSerializer(user, context={"view": view}).data with self.subTest("test full data"): + registered_user = user.registered_users.get(organization=self.default_org) self.assertEqual( data, { @@ -343,9 +381,9 @@ def test_radius_user_serializer(self): "birth_date": user.birth_date, "location": user.location, "is_active": user.is_active, - "is_verified": user.registered_user.is_verified, "password_expired": user.has_password_expired(), - "method": user.registered_user.method, + "is_verified": registered_user.is_verified, + "method": registered_user.method, "radius_user_token": user.radius_token.key, }, ) @@ -370,6 +408,44 @@ def test_radius_user_serializer(self): }, ) + with self.subTest("org-specific record is returned for the current org"): + user2 = self._create_user(username="user2", email="user2@test.com") + self._create_org_user(user=user2, organization=self.default_org) + RegisteredUser.objects.create( + user=user2, + organization=self.default_org, + is_verified=True, + method="mobile_phone", + ) + url = reverse("radius:user_auth_token", args=[self.default_org.slug]) + r = self.client.post(url, {"username": "user2", "password": "tester"}) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.data["is_verified"], True) + self.assertEqual(r.data["method"], "mobile_phone") + + with self.subTest("other-organization record is not used as fallback"): + user3 = self._create_user(username="user3", email="user3@test.com") + self._create_org_user(user=user3, organization=self.default_org) + org2 = self._create_org(name="serializer-org2", slug="serializer-org2") + self._create_org_user(user=user3, organization=org2) + RegisteredUser.objects.create( + user=user3, organization=org2, is_verified=True, method="email" + ) + url = reverse("radius:user_auth_token", args=[self.default_org.slug]) + r = self.client.post(url, {"username": "user3", "password": "tester"}) + self.assertEqual(r.status_code, 200) + self.assertIsNone(r.data["is_verified"]) + self.assertIsNone(r.data["method"]) + + with self.subTest("returns None when no RegisteredUser records exist"): + user4 = self._create_user(username="user4", email="user4@test.com") + self._create_org_user(user=user4, organization=self.default_org) + url = reverse("radius:user_auth_token", args=[self.default_org.slug]) + r = self.client.post(url, {"username": "user4", "password": "tester"}) + self.assertEqual(r.status_code, 200) + self.assertIsNone(r.data["is_verified"]) + self.assertIsNone(r.data["method"]) + # The fallback value is set on project startup, hence it also requires mocking. @mock.patch.object( OrganizationRadiusSettings._meta.get_field("first_name"), @@ -490,6 +566,54 @@ def test_register_verification_field(self): self.assertEqual(r.status_code, 201) self.assertEqual(User.objects.count(), 2) + def test_register_serializer_user_settable_methods(self): + url = reverse("radius:rest_register", args=[self.default_org.slug]) + for method in ["saml", "social_login"]: + with self.subTest(f"RegisterSerializer rejects {method}"): + response = self.client.post( + url, + { + "username": f"{method}@example.com", + "email": f"{method}@example.com", + "password1": "password", + "password2": "password", + "method": method, + }, + ) + self.assertEqual(response.status_code, 400) + self.assertIn( + '"{input}" is not a valid choice.'.format(input=method), + response.data["method"], + ) + + with self.subTest("custom configured method is accepted"): + with mock.patch.object( + app_settings, + "USER_SETTABLE_REGISTRATION_METHODS", + ["", "email", "manual"], + ): + serializer = RegisterSerializer(context={"view": mock.MagicMock()}) + self.assertEqual( + list(serializer.fields["method"].choices.keys()), + ["", "email", "manual"], + ) + response = self.client.post( + url, + { + "username": "manual@example.com", + "email": "manual@example.com", + "password1": "password", + "password2": "password", + "method": "manual", + }, + ) + self.assertEqual(response.status_code, 201) + registered_user = RegisteredUser.objects.get( + user__username="manual@example.com", + organization=self.default_org, + ) + self.assertEqual(registered_user.method, "manual") + @override_settings( ACCOUNT_EMAIL_VERIFICATION="mandatory", ACCOUNT_EMAIL_REQUIRED=True ) @@ -916,7 +1040,7 @@ def test_user_accounting_list_200(self): response = self.client.post( auth_url, {"username": "tester", "password": "tester"} ) - authorization = f'Bearer {response.data["key"]}' + authorization = f"Bearer {response.data['key']}" stop_time = "2018-03-02T11:43:24.020460+01:00" data1 = self.acct_post_data data1.update( @@ -1556,6 +1680,201 @@ def test_radius_user_group_serializer_without_view_context(self): self.assertEqual(serializer._user, None) self.assertEqual(serializer.fields["group"].queryset.count(), 0) + def test_update_registered_user_method_success(self): + user, org2, user_token = self._create_pending_verification_user( + username_suffix="_success" + ) + url = self._get_update_method_url(org2) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["method"], "mobile_phone") + registered_user = RegisteredUser.objects.get(user=user, organization=org2) + self.assertEqual(registered_user.method, "mobile_phone") + self.assertEqual(registered_user.is_verified, False) + + def test_update_registered_user_method_with_valid_methods(self): + user, org2, user_token = self._create_pending_verification_user( + username_suffix="_valid" + ) + url = self._get_update_method_url(org2) + for method in ["", "email", "mobile_phone"]: + with self.subTest(method=method): + registered_user = RegisteredUser.objects.get( + user=user, organization=org2 + ) + registered_user.method = "pending_verification" + registered_user.save() + response = self.client.post( + url, + {"method": method}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["method"], method) + + @mock.patch.object( + app_settings, + "USER_SETTABLE_REGISTRATION_METHODS", + ["", "email", "mobile_phone"], + ) + def test_update_registered_user_method_user_settable_methods(self): + _, org2, user_token = self._create_pending_verification_user( + username_suffix="_choices" + ) + url = self._get_update_method_url(org2) + + with self.subTest("default field choices"): + serializer = UpdateRegisteredUserMethodSerializer() + self.assertEqual( + list(serializer.fields["method"].choices.keys()), + ["", "email", "mobile_phone"], + ) + + for method in ["saml", "social_login"]: + with self.subTest(f"UpdateRegisteredUserMethodSerializer rejects {method}"): + response = self.client.post( + url, + {"method": method}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + self.assertIn( + '"{input}" is not a valid choice.'.format(input=method), + response.data["method"], + ) + + with self.subTest("custom configured method is accepted"): + with mock.patch.object( + app_settings, + "USER_SETTABLE_REGISTRATION_METHODS", + ["", "email", "manual"], + ): + serializer = UpdateRegisteredUserMethodSerializer() + self.assertEqual( + list(serializer.fields["method"].choices.keys()), + ["", "email", "manual"], + ) + response = self.client.post( + url, + {"method": "manual"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["method"], "manual") + + def test_update_registered_user_method_validation_errors(self): + user, org2, user_token = self._create_pending_verification_user() + url = self._get_update_method_url(org2) + with self.subTest("reject_pending_verification_as_input"): + response = self.client.post( + url, + {"method": "pending_verification"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + + with self.subTest("reject_invalid_method"): + response = self.client.post( + url, + {"method": "invalid_method"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + + with self.subTest("reject_non_pending_state"): + registered_user = RegisteredUser.objects.get(user=user, organization=org2) + registered_user.method = "mobile_phone" + registered_user.save() + response = self.client.post( + url, + {"method": "email"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + self.assertIn("pending verification", response.data["method"][0]) + + def test_update_registered_user_method_404_cases(self): + with self.subTest("non member without registered user"): + user = self._create_user(username="noreguser", password="tester") + user_token = Token.objects.create(user=user) + url = self._get_update_method_url() + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + self.assertIn("non_field_errors", response.data) + self.assertIn("is not member", str(response.data["non_field_errors"])) + + with self.subTest("non member cannot update other users record"): + user, org2, user_token = self._create_pending_verification_user( + username_suffix="_owner" + ) + other_user = self._create_user( + username="otheruser", password="tester", email="otheruser@test.com" + ) + other_user_token = Token.objects.create(user=other_user) + url = self._get_update_method_url(org2) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {other_user_token.key}", + ) + self.assertEqual(response.status_code, 400) + self.assertIn("non_field_errors", response.data) + self.assertIn("is not member", str(response.data["non_field_errors"])) + + with self.subTest("invalid_org"): + user, _, user_token = self._create_pending_verification_user( + username_suffix="_invalid_org" + ) + url = reverse( + "radius:update_registered_user_registration_method", + args=["nonexistent-org-slug"], + ) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 404) + + def test_update_registered_user_method_rejects_non_member_with_registered_user( + self, + ): + user = self._create_user( + username="nonmember-update", + password="tester", + email="nonmember-update@test.com", + ) + org = self._create_org(name="org-update", slug="org-update") + RegisteredUser.objects.create( + user=user, + organization=org, + method="pending_verification", + is_verified=False, + ) + user_token = Token.objects.create(user=user) + url = self._get_update_method_url(org) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + self.assertIn("non_field_errors", response.data) + self.assertIn("is not member", str(response.data["non_field_errors"])) + + def test_update_registered_user_method_requires_authentication(self): + url = self._get_update_method_url() + response = self.client.post(url, {"method": "mobile_phone"}) + self.assertEqual(response.status_code, 401) + class TestTransactionApi(AcctMixin, ApiTokenMixin, BaseTransactionTestCase): def test_user_radius_usage_view(self): @@ -1565,7 +1884,7 @@ def test_user_radius_usage_view(self): response = self.client.post( auth_url, {"username": "tester", "password": "tester"} ) - authorization = f'Bearer {response.data["key"]}' + authorization = f"Bearer {response.data['key']}" self.assertEqual(response.status_code, 200) with self.subTest("Test user has not used any data"): response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization) diff --git a/openwisp_radius/tests/test_api/test_freeradius_api.py b/openwisp_radius/tests/test_api/test_freeradius_api.py index 64968edd..55ea5e49 100644 --- a/openwisp_radius/tests/test_api/test_freeradius_api.py +++ b/openwisp_radius/tests/test_api/test_freeradius_api.py @@ -172,7 +172,7 @@ def test_authorize_fail_auth_details_incomplete(self): f"?uuid={str(self.default_org.pk)}", ]: with self.subTest(querystring): - post_url = f'{reverse("radius:authorize")}{querystring}' + post_url = f"{reverse('radius:authorize')}{querystring}" response = self.client.post( post_url, {"username": "tester", "password": "tester"} ) @@ -206,6 +206,134 @@ def test_authorize_unverified_user(self): self.assertEqual(response.status_code, 200) self.assertIsNone(response.data) + def test_authorize_verified_user(self): + org_user = self._get_org_user() + user = org_user.user + org_settings = OrganizationRadiusSettings.objects.get( + organization=self._get_org() + ) + org_settings.needs_identity_verification = True + org_settings.save() + + with self.subTest("org-specific verified record passes authorization"): + RegisteredUser.objects.create( + user=user, organization=self._get_org(), is_verified=True + ) + response = self._authorize_user(auth_header=self.auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, {"control:Auth-Type": "Accept"}) + + with self.subTest("other-organization record does not pass authorization"): + RegisteredUser.objects.filter(user=user).delete() + org2 = self._create_org(name="verified-org-2", slug="verified-org-2") + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create( + user=user, organization=org2, is_verified=True + ) + response = self._authorize_user(auth_header=self.auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, None) + + def test_multi_org_user_different_verification_states(self): + org1 = self._get_org() + org_settings = OrganizationRadiusSettings.objects.get(organization=org1) + org_settings.needs_identity_verification = True + org_settings.save() + org2 = self._create_org(name="org2", slug="org2") + org2_settings = OrganizationRadiusSettings.objects.get_or_create( + organization=org2 + )[0] + org2_settings.needs_identity_verification = True + org2_settings.full_clean() + org2_settings.save() + user = self._get_user_with_org() + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create(user=user, organization=org1, is_verified=True) + auth_header_org1 = f"Bearer {org1.pk} {org1.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org1 + ) + self.assertEqual(response.data["control:Auth-Type"], "Accept") + + auth_header_org2 = f"Bearer {org2.pk} {org2.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org2 + ) + self.assertIsNone(response.data) + + def test_other_org_record_is_not_used_as_fallback(self): + org1 = self._get_org() + org2 = self._create_org(name="org2", slug="org2") + org2_settings = OrganizationRadiusSettings.objects.get_or_create( + organization=org2 + )[0] + org2_settings.needs_identity_verification = True + org2_settings.full_clean() + org2_settings.save() + user = self._get_user_with_org() + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create(user=user, organization=org2, is_verified=True) + org_settings = OrganizationRadiusSettings.objects.get(organization=org1) + org_settings.needs_identity_verification = True + org_settings.save() + + auth_header_org1 = f"Bearer {org1.pk} {org1.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org1 + ) + self.assertEqual(response.data, None) + + auth_header_org2 = f"Bearer {org2.pk} {org2.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org2 + ) + self.assertEqual(response.data["control:Auth-Type"], "Accept") + + def test_other_org_verified_with_org_unverified(self): + """ + A user with a verified record in another org should not be + authorized for an org where they have an org-specific unverified record. + """ + org = self._get_org() + org_settings = OrganizationRadiusSettings.objects.get(organization=org) + org_settings.needs_identity_verification = True + org_settings.save() + user = self._get_user_with_org() + org2 = self._create_org(name="org2-priority", slug="org2-priority") + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create(user=user, organization=org, is_verified=False) + RegisteredUser.objects.create(user=user, organization=org2, is_verified=True) + auth_header = f"Bearer {org.pk} {org.radius_settings.token}" + response = self._authorize_user(username=user.username, auth_header=auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, None) + + @mock.patch.object(registration, "AUTHORIZE_UNVERIFIED", ["mobile_phone"]) + def test_other_org_special_method_with_org_unverified_not_authorized(self): + """ + When AUTHORIZE_UNVERIFIED is set, the org-specific + record still takes precedence. A user with org-specific unverified record + using a non-special method should NOT be authorized even if they have a + verified record in another organization with a special method. + """ + org = self._get_org() + org_settings = OrganizationRadiusSettings.objects.get(organization=org) + org_settings.needs_identity_verification = True + org_settings.save() + user = self._get_user_with_org() + org2 = self._create_org(name="org2-special", slug="org2-special") + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create( + user=user, organization=org, method="email", is_verified=False + ) + RegisteredUser.objects.create( + user=user, organization=org2, method="mobile_phone", is_verified=True + ) + auth_header = f"Bearer {org.pk} {org.radius_settings.token}" + response = self._authorize_user(username=user.username, auth_header=auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, None) + def test_authorize_radius_token_unverified_user(self): user = self._get_org_user() org_settings = OrganizationRadiusSettings.objects.get( @@ -258,7 +386,7 @@ def test_postauth_radius_token_accept_201(self): def test_postauth_accept_201_querystring(self): self.assertEqual(RadiusPostAuth.objects.all().count(), 0) params = self._get_postauth_params() - post_url = f'{reverse("radius:postauth")}{self.token_querystring}' + post_url = f"{reverse('radius:postauth')}{self.token_querystring}" response = self.client.post(post_url, params) params["password"] = "" self.assertEqual(RadiusPostAuth.objects.filter(**params).count(), 1) @@ -1227,7 +1355,7 @@ def test_accounting_when_nas_using_pfsense_started(self): self.assertIsNone(response.data) def test_get_authorize_view(self): - url = f'{reverse("radius:authorize")}{self.token_querystring}' + url = f"{reverse('radius:authorize')}{self.token_querystring}" r = self.client.get(url, HTTP_ACCEPT="text/html") self.assertEqual(r.status_code, 405) expected = f'
email > empty. + """ + user = self._create_user(username="method-priority-user") + org1 = self._create_org(name="method-org-1", slug="method-org-1") + org2 = self._create_org(name="method-org-2", slug="method-org-2") + org3 = self._create_org(name="method-org-3", slug="method-org-3") + modified_base = timezone.now() + # All unverified, same timestamp - method should decide + with freeze_time(modified_base): + RegisteredUser.objects.create( + user=user, + organization=org1, + is_verified=False, + method="", + ) + RegisteredUser.objects.create( + user=user, + organization=org2, + is_verified=False, + method="email", + ) + RegisteredUser.objects.create( + user=user, + organization=org3, + is_verified=False, + method="mobile_phone", + ) + # Rollback: mobile_phone should win (highest method priority) + migrate_registered_users_multitenant_reverse( + apps, None, app_label=self._get_app_label() + ) + surviving_record = RegisteredUser.objects.get(user=user) + self.assertEqual(surviving_record.organization, org3) + self.assertEqual(surviving_record.method, "mobile_phone") + self.assertEqual(RegisteredUser.objects.filter(user=user).count(), 1) + + def test_multitenant_reverse_pending_verification_method_ignored( + self, + ): + user = self._create_user( + username="pending-vs-strong", + email="pending-vs-strong@example.com", + ) + org1 = self._create_org( + name="pending-org-1", + slug="pending-org-1", + ) + org2 = self._create_org( + name="pending-org-2", + slug="pending-org-2", + ) + modified_base = timezone.now() + with freeze_time(modified_base): + RegisteredUser.objects.create( + user=user, + organization=org1, + is_verified=False, + method="pending_verification", + ) + strong_record = RegisteredUser.objects.create( + user=user, + organization=org2, + is_verified=False, + method="mobile_phone", + ) + migrate_registered_users_multitenant_reverse( + apps, None, app_label=self._get_app_label() + ) + surviving_record = RegisteredUser.objects.get(user=user) + self.assertEqual(surviving_record.pk, strong_record.pk) + self.assertEqual(surviving_record.method, "mobile_phone") + self.assertEqual(RegisteredUser.objects.filter(user=user).count(), 1) + + def test_multitenant_reverse_full_cleanup(self): + """ + Test that duplicate org-scoped records are reduced to one per user. + """ + user1 = self._create_user( + username="cleanup-user-1", email="cleanup1@example.com" + ) + user2 = self._create_user( + username="cleanup-user-2", email="cleanup2@example.com" + ) + org1 = self._create_org(name="cleanup-org-1", slug="cleanup-org-1") + org2 = self._create_org(name="cleanup-org-2", slug="cleanup-org-2") + # Create multiple org-scoped records for multiple users + for user, org in [(user1, org1), (user1, org2), (user2, org1)]: + RegisteredUser.objects.create( + user=user, + organization=org, + is_verified=False, + method="email", + ) + self.assertEqual( + RegisteredUser.objects.filter(user=user1).count(), + 2, + ) + migrate_registered_users_multitenant_reverse( + apps, None, app_label=self._get_app_label() + ) + self.assertEqual( + RegisteredUser.objects.filter(user=user1).count(), + 1, + ) + self.assertEqual( + RegisteredUser.objects.filter(user=user2).count(), + 1, + ) + + +class TestPhoneTokenOrganizationPopulateResolution(BaseTestCase): + def _set_org_user_created(self, org_user, created): + OrganizationUser.objects.filter(pk=org_user.pk).update(created=created) + org_user.refresh_from_db(fields=["created"]) + return org_user + + def test_get_first_membership_returns_earliest_membership(self): + user = self._create_user(username="phone-membership-user") + org1 = self.default_org + org2 = self._create_org( + name="phone-membership-org2", slug="phone-membership-org2" + ) + org1_user = OrganizationUser.objects.create(user=user, organization=org1) + org2_user = OrganizationUser.objects.create(user=user, organization=org2) + base_time = timezone.now() + self._set_org_user_created(org1_user, base_time + timedelta(days=1)) + self._set_org_user_created(org2_user, base_time) + organization_id = _get_first_membership_organization_id( + user.pk, + OrganizationUser, + ) + self.assertEqual(organization_id, org2.pk) + + def test_get_first_membership_returns_none_without_membership(self): + user = self._create_user(username="phone-unresolved-user") + organization_id = _get_first_membership_organization_id( + user.pk, + OrganizationUser, + ) + self.assertEqual(organization_id, None) class TestMigrationRadiusBatchJsonField(TestOrganizationMixin, TestCase): - app_label = "openwisp_radius" migration_path = "openwisp_radius.migrations.0044_convert_user_credentials_data" - radius_batch_model = load_model("RadiusBatch") + + def _get_app_label(self): + return RadiusBatch._meta.app_label def _get_convert_user_credentials_data(self): migration_module = importlib.import_module(self.migration_path) return migration_module.convert_user_credentials_data def _get_model(self, app_label, model_name): - self.assertEqual(app_label, self.app_label) + self.assertEqual(app_label, self._get_app_label()) self.assertEqual(model_name, "RadiusBatch") - return self.radius_batch_model + return RadiusBatch def _get_apps(self): apps = MagicMock() @@ -40,28 +425,29 @@ def _convert_user_credentials_data(self): def test_convert_user_credentials_data(self): org = self._get_org() - batch = self.radius_batch_model.objects.create( + batch = RadiusBatch.objects.create( name="test_batch_migration", strategy="prefix", prefix="test", organization=org, ) - self.radius_batch_model.objects.filter(pk=batch.pk).update( + RadiusBatch.objects.filter(pk=batch.pk).update( user_credentials=json.dumps({"user1": "pass1"}) ) self._convert_user_credentials_data() batch.refresh_from_db() self.assertEqual(batch.user_credentials, {"user1": "pass1"}) + @capture_any_output() def test_convert_user_credentials_data_invalid_json(self): org = self._get_org() - batch = self.radius_batch_model.objects.create( + batch = RadiusBatch.objects.create( name="test_batch_invalid", strategy="prefix", prefix="test2", organization=org, ) - self.radius_batch_model.objects.filter(pk=batch.pk).update( + RadiusBatch.objects.filter(pk=batch.pk).update( user_credentials="invalid_json_string" ) self._convert_user_credentials_data() diff --git a/openwisp_radius/tests/test_models.py b/openwisp_radius/tests/test_models.py index c618a29e..3388b151 100644 --- a/openwisp_radius/tests/test_models.py +++ b/openwisp_radius/tests/test_models.py @@ -42,6 +42,7 @@ RadiusBatch = load_model("RadiusBatch") OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") Organization = swapper.load_model("openwisp_users", "Organization") +RegisteredUser = load_model("RegisteredUser") class TestNas(BaseTestCase): @@ -1218,5 +1219,62 @@ def test_sessions_with_multiple_orgs(self, mocked_radclient): self.assertEqual(org2_session.groupname, f"{org2.slug}-users") +class TestRegisteredUser(BaseTestCase): + def test_get_for_user_and_org(self): + user = self._create_user() + org1 = self._create_org(name="ru-test-org-1", slug="ru-test-org-1") + org2 = self._create_org(name="ru-test-org-2", slug="ru-test-org-2") + + with self.subTest("returns None when no records exist"): + result = RegisteredUser.get_for_user_and_org(user, org1) + self.assertEqual(result, None) + + with self.subTest("returns only the requested organization record"): + org2_ru = RegisteredUser.objects.create( + user=user, organization=org2, is_verified=True + ) + result = RegisteredUser.get_for_user_and_org(user, org1) + self.assertEqual(result, None) + result = RegisteredUser.get_for_user_and_org(user, org2) + self.assertEqual(result, org2_ru) + self.assertEqual(result.is_verified, True) + + with self.subTest("uses prefetched registered_users without extra queries"): + org1_ru = RegisteredUser.objects.create( + user=user, + organization=org1, + is_verified=False, + ) + prefetched_user = ( + get_user_model() + .objects.prefetch_related("registered_users") + .get(pk=user.pk) + ) + with self.assertNumQueries(0): + result = RegisteredUser.get_for_user_and_org(prefetched_user, org1) + self.assertEqual(result, org1_ru) + + def test_clean_requires_unique_org_specific_registered_user(self): + user = self._create_user() + org = self._create_org(name="dup-test-org", slug="dup-test-org") + other_org = self._create_org(name="dup-test-org-2", slug="dup-test-org-2") + + with self.subTest("duplicate org-specific raises ValidationError"): + RegisteredUser.objects.create(user=user, organization=org) + duplicate = RegisteredUser(user=user, organization=org) + with self.assertRaises(ValidationError): + duplicate.full_clean() + + with self.subTest("different organizations are allowed"): + record = RegisteredUser(user=user, organization=other_org) + record.full_clean() + + def test_clean_requires_organization(self): + user = self._create_user() + + with self.assertRaises(ValidationError): + RegisteredUser(user=user).full_clean() + + del BaseTestCase del BaseTransactionTestCase diff --git a/openwisp_radius/tests/test_saml/test_views.py b/openwisp_radius/tests/test_saml/test_views.py index 0c662970..d35aae32 100644 --- a/openwisp_radius/tests/test_saml/test_views.py +++ b/openwisp_radius/tests/test_saml/test_views.py @@ -3,6 +3,7 @@ from urllib.parse import parse_qs, urlparse import swapper +from allauth.account.models import EmailAddress from django.conf import settings from django.contrib.auth import SESSION_KEY, get_user_model from django.core import mail @@ -12,6 +13,7 @@ from djangosaml2.utils import get_session_id_from_saml2, saml2_from_httpredirect_request from rest_framework.authtoken.models import Token +from openwisp_radius import settings as app_settings from openwisp_radius.saml.utils import get_url_or_path from openwisp_users.tests.utils import TestOrganizationMixin from openwisp_utils.tests import capture_any_output @@ -150,10 +152,39 @@ def test_relay_state_relative_path(self): query_params = parse_qs(urlparse(response.url).query) self._post_successful_auth_assertions(query_params, org_slug) + @capture_any_output() + def test_pending_verification_registered_user_updated_for_org(self): + org = Organization.objects.get(slug="default") + user = self._create_user(username="test-user", email="org_user@example.com") + registered_user = RegisteredUser.objects.create( + user=user, + organization=org, + method="pending_verification", + is_verified=False, + ) + relay_state = self._get_relay_state( + redirect_url="https://captive-portal.example.com", org_slug="default" + ) + saml_response, relay_state = self._get_saml_response_for_acs_view(relay_state) + response = self.client.post( + reverse("radius:saml2_acs"), + { + "SAMLResponse": self.b64_for_post(saml_response), + "RelayState": relay_state, + }, + ) + self.assertEqual(response.status_code, 302) + registered_users = RegisteredUser.objects.filter(user=user, organization=org) + self.assertEqual(registered_users.count(), 1) + registered_user.refresh_from_db() + self.assertEqual(registered_user.method, "saml") + self.assertEqual(registered_user.is_verified, app_settings.SAML_IS_VERIFIED) + @capture_any_output() def test_user_registered_with_non_saml_method(self): + org = Organization.objects.get(slug="default") user = self._create_user(username="test-user", email="org_user@example.com") - RegisteredUser.objects.create(user=user, method="manual") + RegisteredUser.objects.create(user=user, method="manual", organization=org) relay_state = self._get_relay_state( redirect_url="https://captive-portal.example.com", org_slug="default" ) @@ -194,6 +225,135 @@ def test_user_registered_with_non_saml_method(self): user.refresh_from_db() self.assertEqual(user.username, "org_user@example.com") + @capture_any_output() + def test_saml_login_marks_existing_email_verified(self): + org = Organization.objects.get(slug="default") + user = self._create_user(username="test-user", email="org_user@example.com") + user.emailaddress_set.all().delete() + email_address = EmailAddress.objects.create( + user=user, + email="org_user@example.com", + primary=True, + verified=False, + ) + registered_user = RegisteredUser.objects.create( + user=user, + organization=org, + method="pending_verification", + is_verified=False, + ) + relay_state = self._get_relay_state( + redirect_url="https://captive-portal.example.com", org_slug="default" + ) + saml_response, relay_state = self._get_saml_response_for_acs_view(relay_state) + response = self.client.post( + reverse("radius:saml2_acs"), + { + "SAMLResponse": self.b64_for_post(saml_response), + "RelayState": relay_state, + }, + ) + self.assertEqual(response.status_code, 302) + email_address.refresh_from_db() + registered_user.refresh_from_db() + self.assertTrue(email_address.verified) + self.assertTrue(email_address.primary) + self.assertEqual(EmailAddress.objects.filter(user=user).count(), 1) + self.assertEqual(registered_user.method, "saml") + self.assertEqual(registered_user.is_verified, app_settings.SAML_IS_VERIFIED) + self.assertEqual( + RegisteredUser.objects.filter(user=user, organization=org).count(), 1 + ) + + @capture_any_output() + def test_saml_login_existing_email_already_verified(self): + org = Organization.objects.get(slug="default") + user = self._create_user(username="test-user", email="org_user@example.com") + user.emailaddress_set.all().delete() + email_address = EmailAddress.objects.create( + user=user, + email="org_user@example.com", + primary=True, + verified=True, + ) + registered_user = RegisteredUser.objects.create( + user=user, + organization=org, + method="pending_verification", + is_verified=False, + ) + relay_state = self._get_relay_state( + redirect_url="https://captive-portal.example.com", org_slug="default" + ) + saml_response, relay_state = self._get_saml_response_for_acs_view(relay_state) + response = self.client.post( + reverse("radius:saml2_acs"), + { + "SAMLResponse": self.b64_for_post(saml_response), + "RelayState": relay_state, + }, + ) + self.assertEqual(response.status_code, 302) + email_address.refresh_from_db() + registered_user.refresh_from_db() + self.assertEqual(email_address.verified, True) + self.assertEqual(email_address.primary, True) + self.assertEqual(EmailAddress.objects.filter(user=user).count(), 1) + self.assertEqual(registered_user.method, "saml") + self.assertEqual(registered_user.is_verified, app_settings.SAML_IS_VERIFIED) + self.assertEqual( + RegisteredUser.objects.filter(user=user, organization=org).count(), 1 + ) + + @override_settings(SAML_DJANGO_USER_MAIN_ATTRIBUTE="username") + @capture_any_output() + def test_saml_login_preserves_existing_primary_email_different_uid(self): + org = Organization.objects.get(slug="default") + user = self._create_user( + username="saml-user@example.com", + email="existing-primary@example.com", + ) + user.emailaddress_set.all().delete() + existing_primary = EmailAddress.objects.create( + user=user, + email="existing-primary@example.com", + primary=True, + verified=True, + ) + registered_user = RegisteredUser.objects.create( + user=user, + organization=org, + method="pending_verification", + is_verified=False, + ) + relay_state = self._get_relay_state( + redirect_url="https://captive-portal.example.com", org_slug="default" + ) + saml_response, relay_state = self._get_saml_response_for_acs_view( + relay_state, uid="saml-user@example.com" + ) + response = self.client.post( + reverse("radius:saml2_acs"), + { + "SAMLResponse": self.b64_for_post(saml_response), + "RelayState": relay_state, + }, + ) + self.assertEqual(response.status_code, 302) + existing_primary.refresh_from_db() + self.assertEqual(existing_primary.primary, True) + self.assertEqual(existing_primary.verified, True) + new_email = EmailAddress.objects.get(user=user, email="saml-user@example.com") + self.assertEqual(new_email.primary, False) + self.assertEqual(new_email.verified, True) + self.assertEqual(EmailAddress.objects.filter(user=user).count(), 2) + registered_user.refresh_from_db() + self.assertEqual(registered_user.method, "saml") + self.assertEqual(registered_user.is_verified, app_settings.SAML_IS_VERIFIED) + self.assertEqual( + RegisteredUser.objects.filter(user=user, organization=org).count(), 1 + ) + @override_settings(SAML_ALLOWED_HOSTS=["captive-portal.example.com"]) class TestAdditionInfoView(TestSamlMixin, TestCase): @@ -294,12 +454,15 @@ def test_saml_login_disabled(self): org.radius_settings.save() redirect_url = "https://captive-portal.example.com" with self.subTest("SAML authentication is disabled site-wide"): - with patch( - "openwisp_radius.settings.SAML_REGISTRATION_ENABLED", False - ), patch.object( - OrganizationRadiusSettings._meta.get_field("saml_registration_enabled"), - "fallback", - False, + with ( + patch("openwisp_radius.settings.SAML_REGISTRATION_ENABLED", False), + patch.object( + OrganizationRadiusSettings._meta.get_field( + "saml_registration_enabled" + ), + "fallback", + False, + ), ): response = self.client.get( self.login_url, diff --git a/openwisp_radius/tests/test_selenium.py b/openwisp_radius/tests/test_selenium.py index 7b059345..8291f46d 100644 --- a/openwisp_radius/tests/test_selenium.py +++ b/openwisp_radius/tests/test_selenium.py @@ -21,6 +21,7 @@ @tag("selenium_tests") +@tag("no_parallel") class BasicTest( SeleniumTestMixin, FileMixin, StaticLiveServerTestCase, TestOrganizationMixin ): diff --git a/openwisp_radius/tests/test_social.py b/openwisp_radius/tests/test_social.py index 19ceafdb..da663faf 100644 --- a/openwisp_radius/tests/test_social.py +++ b/openwisp_radius/tests/test_social.py @@ -2,7 +2,6 @@ from allauth.socialaccount.models import SocialAccount from django.contrib.auth import get_user_model -from django.core.exceptions import ObjectDoesNotExist from django.urls import reverse from rest_framework.authtoken.models import Token from swapper import load_model @@ -14,6 +13,7 @@ from .mixins import ApiTokenMixin, BaseTestCase RadiusToken = load_model("openwisp_radius", "RadiusToken") +RegisteredUser = load_model("openwisp_radius", "RegisteredUser") OrganizationRadiusSettings = load_model("openwisp_radius", "OrganizationRadiusSettings") Organization = load_model("openwisp_users", "Organization") User = get_user_model() @@ -102,13 +102,33 @@ def test_redirect_cp_301(self): user = User.objects.filter(username="socialuser").first() self.assertTrue(user.is_member(self.default_org)) try: - reg_user = user.registered_user - except ObjectDoesNotExist: + reg_user = user.registered_users.get(organization=self.default_org) + except RegisteredUser.DoesNotExist: self.fail("RegisteredUser instance not found") self.assertEqual(reg_user.method, "social_login") # social login is not a legally valid identity verification method # so this should be always False when users sign up with this method - self.assertFalse(reg_user.is_verified) + self.assertEqual(reg_user.is_verified, False) + + def test_pending_verification_registered_user_updated_for_org(self): + user = self._create_social_user() + registered_user = RegisteredUser.objects.create( + user=user, + organization=self.default_org, + method="pending_verification", + is_verified=False, + ) + self.client.force_login(user) + url = self.get_url() + response = self.client.get(url, {"cp": "http://wifi.openwisp.org/cp"}) + self.assertEqual(response.status_code, 302) + registered_users = RegisteredUser.objects.filter( + user=user, organization=self.default_org + ) + self.assertEqual(registered_users.count(), 1) + registered_user.refresh_from_db() + self.assertEqual(registered_user.method, "social_login") + self.assertEqual(registered_user.is_verified, False) def test_authorize_using_radius_user_token_200(self): self.test_redirect_cp_301() diff --git a/openwisp_radius/tests/test_tasks.py b/openwisp_radius/tests/test_tasks.py index 8aadb051..230d2335 100644 --- a/openwisp_radius/tests/test_tasks.py +++ b/openwisp_radius/tests/test_tasks.py @@ -139,9 +139,7 @@ def test_delete_unverified_users(self): management.call_command("batch_add_users", **options) User.objects.update(date_joined=now() - timedelta(days=3)) for user in User.objects.all(): - user.registered_user.is_verified = False - user.registered_user.method = "email" - user.registered_user.save(update_fields=["is_verified", "method"]) + user.registered_users.update(is_verified=False, method="email") self.assertEqual(User.objects.count(), 3) tasks.delete_unverified_users.delay(older_than_days=2) self.assertEqual(User.objects.count(), 0) @@ -320,19 +318,35 @@ def test_unverify_inactive_users(self, *args): User.objects.exclude(id=active_user.id).update( last_login=today - timedelta(days=60) ) - RegisteredUser.objects.create(user=admin, is_verified=True) - RegisteredUser.objects.create(user=active_user, is_verified=True) RegisteredUser.objects.create( - user=unspecified_user, method="", is_verified=True + user=admin, organization=self.default_org, is_verified=True ) RegisteredUser.objects.create( - user=manually_registered_user, method="manual", is_verified=True + user=active_user, organization=self.default_org, is_verified=True ) RegisteredUser.objects.create( - user=email_registered_user, method="email", is_verified=True + user=unspecified_user, + organization=self.default_org, + method="", + is_verified=True, ) RegisteredUser.objects.create( - user=mobile_registered_user, method="mobile_phone", is_verified=True + user=manually_registered_user, + organization=self.default_org, + method="manual", + is_verified=True, + ) + RegisteredUser.objects.create( + user=email_registered_user, + organization=self.default_org, + method="email", + is_verified=True, + ) + RegisteredUser.objects.create( + user=mobile_registered_user, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) tasks.unverify_inactive_users.delay() @@ -342,12 +356,38 @@ def test_unverify_inactive_users(self, *args): manually_registered_user.refresh_from_db() email_registered_user.refresh_from_db() mobile_registered_user.refresh_from_db() - self.assertEqual(admin.registered_user.is_verified, True) - self.assertEqual(active_user.registered_user.is_verified, True) - self.assertEqual(unspecified_user.registered_user.is_verified, True) - self.assertEqual(manually_registered_user.registered_user.is_verified, True) - self.assertEqual(email_registered_user.registered_user.is_verified, True) - self.assertEqual(mobile_registered_user.registered_user.is_verified, False) + self.assertEqual( + admin.registered_users.get(organization=self.default_org).is_verified, + True, + ) + self.assertEqual( + active_user.registered_users.get(organization=self.default_org).is_verified, + True, + ) + self.assertEqual( + unspecified_user.registered_users.get( + organization=self.default_org + ).is_verified, + True, + ) + self.assertEqual( + manually_registered_user.registered_users.get( + organization=self.default_org + ).is_verified, + True, + ) + self.assertEqual( + email_registered_user.registered_users.get( + organization=self.default_org + ).is_verified, + True, + ) + self.assertEqual( + mobile_registered_user.registered_users.get( + organization=self.default_org + ).is_verified, + False, + ) @mock.patch.object(app_settings, "DELETE_INACTIVE_USERS", 30) def test_delete_inactive_users(self, *args): diff --git a/openwisp_radius/tests/test_token.py b/openwisp_radius/tests/test_token.py index 3a03115b..57366fb0 100644 --- a/openwisp_radius/tests/test_token.py +++ b/openwisp_radius/tests/test_token.py @@ -41,7 +41,12 @@ def setUp(self): radius_settings.save() def _create_token( - self, user=None, ip="127.0.0.1", phone_number="+393664351808", created=None + self, + user=None, + organization=None, + ip="127.0.0.1", + phone_number="+393664351808", + created=None, ): if not user: opts = { @@ -53,7 +58,14 @@ def _create_token( } user = self._create_user(**opts) self._create_org_user(**{"user": user}) - token = PhoneToken(user=user, ip=ip, phone_number=phone_number) + if organization is None: + organization = self.default_org + token = PhoneToken( + user=user, + organization=organization, + ip=ip, + phone_number=phone_number, + ) if created: token.created = created token.modified = created @@ -65,7 +77,10 @@ def _create_token( def test_is_already_verified(self): token = self._create_token() RegisteredUser.objects.create( - user=token.user, method="mobile_phone", is_verified=True + user=token.user, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) token.refresh_from_db() diff --git a/openwisp_radius/tests/test_users_integration.py b/openwisp_radius/tests/test_users_integration.py index 30b02f6d..2a0010ce 100644 --- a/openwisp_radius/tests/test_users_integration.py +++ b/openwisp_radius/tests/test_users_integration.py @@ -96,11 +96,23 @@ def test_radiustoken_inline(self): @capture_stdout() def test_export_users_command(self): temp_file = NamedTemporaryFile(delete=False) - user = self._create_org_user().user - RegisteredUser.objects.create( - user=user, method="mobile_phone", is_verified=False + org_user = self._create_org_user() + user = org_user.user + org2 = self._create_org(name="Test Organization 2") + self._create_org_user(organization=org2, user=user) + org1_reg_user = RegisteredUser.objects.create( + user=user, + organization=org_user.organization, + method="mobile_phone", + is_verified=False, ) - with self.assertNumQueries(2): + org2_reg_user = RegisteredUser.objects.create( + user=user, + organization=org2, + method="mobile_phone", + is_verified=True, + ) + with self.assertNumQueries(3): call_command("export_users", filename=temp_file.name) with open(temp_file.name, "r") as file: @@ -108,10 +120,19 @@ def test_export_users_command(self): csv_data = list(csv_reader) self.assertEqual(len(csv_data), 2) - self.assertIn("registered_user.method", csv_data[0]) - self.assertIn("registered_user.is_verified", csv_data[0]) - self.assertEqual(csv_data[1][-2], "mobile_phone") - self.assertEqual(csv_data[1][-1], "False") + self.assertIn( + "registered_users (organization_id, method, is_verified)", csv_data[0] + ) + self.assertEqual( + csv_data[1][-1], + ( + f"({org1_reg_user.organization_id},{org1_reg_user.method}," + f"{org1_reg_user.is_verified})" + "\n" + f"({org2_reg_user.organization_id},{org2_reg_user.method}," + f"{org2_reg_user.is_verified})" + ), + ) def test_radiususergroup_inline(self): """ diff --git a/runtests b/runtests index 60761e1d..188b58c4 100755 --- a/runtests +++ b/runtests @@ -3,7 +3,7 @@ set -e # Standard tests coverage run runtests.py --parallel \ - --exclude-tag=no_parallel >/dev/null 2>&1 \ + --exclude-tag=no_parallel 2>&1 \ || ./runtests.py --exclude-tag=no_parallel # Test extensibility diff --git a/tests/openwisp2/sample_radius/api/views.py b/tests/openwisp2/sample_radius/api/views.py index 6bb68b99..d5b468ef 100644 --- a/tests/openwisp2/sample_radius/api/views.py +++ b/tests/openwisp2/sample_radius/api/views.py @@ -24,6 +24,9 @@ RadiusUserGroupListCreateView, ) from openwisp_radius.api.views import RegisterView as BaseRegisterView +from openwisp_radius.api.views import ( + UpdateRegisteredUserMethodView as BaseUpdateRegisteredUserMethodView, +) from openwisp_radius.api.views import UserAccountingView as BaseUserAccountingView from openwisp_radius.api.views import UserRadiusUsageView as BaseUserRadiusUsageView from openwisp_radius.api.views import ValidateAuthTokenView as BaseValidateAuthTokenView @@ -104,6 +107,10 @@ class RadiusAccountingView(BaseRadiusAccountingView): pass +class UpdateRegisteredUserMethodView(BaseUpdateRegisteredUserMethodView): + pass + + authorize = AuthorizeView.as_view() postauth = PostAuthView.as_view() accounting = AccountingView.as_view() @@ -126,3 +133,4 @@ class RadiusAccountingView(BaseRadiusAccountingView): radius_group_detail = RadiusGroupDetailView.as_view() radius_user_group_list = RadiusUserGroupListCreateView.as_view() radius_user_group_detail = RadiusUserGroupDetailView.as_view() +update_registered_user_registration_method = UpdateRegisteredUserMethodView.as_view() diff --git a/tests/openwisp2/sample_radius/migrations/0033_registered_user_multitenant.py b/tests/openwisp2/sample_radius/migrations/0033_registered_user_multitenant.py new file mode 100644 index 00000000..ab916b26 --- /dev/null +++ b/tests/openwisp2/sample_radius/migrations/0033_registered_user_multitenant.py @@ -0,0 +1,276 @@ +import uuid + +import django +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +import swapper +from django.conf import settings +from django.db import migrations, models + +from openwisp_radius.migrations import ( + copy_registered_users_ctcr_forward, + copy_registered_users_ctcr_reverse, + migrate_registered_users_multitenant_forward, + migrate_registered_users_multitenant_reverse, + populate_phonetoken_organization, +) + + +def copy_registered_users_forward(apps, schema_editor): + copy_registered_users_ctcr_forward( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def copy_registered_users_reverse(apps, schema_editor): + copy_registered_users_ctcr_reverse( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def migrate_registered_users_forward(apps, schema_editor): + migrate_registered_users_multitenant_forward( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def migrate_registered_users_reverse(apps, schema_editor): + migrate_registered_users_multitenant_reverse( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def populate_sample_phonetoken_organization(apps, schema_editor): + populate_phonetoken_organization(apps, schema_editor, app_label="sample_radius") + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + swapper.dependency("openwisp_users", "Organization"), + ( + "sample_radius", + "0032_alter_organizationradiussettings_sms_meta_data_and_more", + ), + ] + + operations = [ + migrations.AddField( + model_name="phonetoken", + name="organization", + field=models.ForeignKey( + blank=True, + help_text="Organization associated with this phone token.", + null=True, + on_delete=models.deletion.CASCADE, + related_name="phone_tokens", + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + migrations.RunPython( + populate_sample_phonetoken_organization, + migrations.RunPython.noop, + ), + migrations.RemoveIndex( + model_name="phonetoken", + name="sample_radi_user_id_b748c7_idx", + ), + migrations.RemoveIndex( + model_name="phonetoken", + name="sample_radi_user_id_044fca_idx", + ), + migrations.AlterField( + model_name="phonetoken", + name="organization", + field=models.ForeignKey( + on_delete=models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + migrations.AddIndex( + model_name="phonetoken", + index=models.Index( + fields=["user", "created"], + name="sample_radi_user_id_b748c7_idx", + ), + ), + migrations.AddIndex( + model_name="phonetoken", + index=models.Index( + fields=["user", "created", "ip"], + name="sample_radi_user_id_044fca_idx", + ), + ), + migrations.SeparateDatabaseAndState( + database_operations=[ + migrations.CreateModel( + name="RegisteredUserNew", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "details", + models.CharField( + blank=True, + max_length=64, + null=True, + ), + ), + ( + "method", + models.CharField( + blank=True, + default="", + help_text=( + "users can sign up in different ways, some " + "methods are valid as indirect identity " + "verification (eg: mobile phone SIM card in " + "most countries)" + ), + max_length=64, + verbose_name="registration method", + ), + ), + ( + "is_verified", + models.BooleanField( + default=False, + help_text=( + "whether the user has completed any identity " + "verification process sucessfully" + ), + verbose_name="verified", + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="Last verification change", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "organization", + models.ForeignKey( + blank=True, + help_text=( + "The organization this registration info belongs" + " to." + ), + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=swapper.get_model_name( + "openwisp_users", "Organization" + ), + verbose_name="organization", + ), + ), + ], + options={ + "verbose_name": "Registration Information", + "verbose_name_plural": "Registration Information", + }, + ), + migrations.RunPython( + copy_registered_users_forward, + copy_registered_users_reverse, + ), + migrations.DeleteModel(name="RegisteredUser"), + migrations.RenameModel( + old_name="RegisteredUserNew", + new_name="RegisteredUser", + ), + ], + state_operations=[ + migrations.AddField( + model_name="registereduser", + name="id", + field=models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + migrations.AlterField( + model_name="registereduser", + name="user", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="registered_users", + to=settings.AUTH_USER_MODEL, + ), + ), + migrations.AddField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + help_text=( + "Organization associated with this registered user entry." + ), + related_name="registered_users", + on_delete=django.db.models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + ], + ), + migrations.RunPython( + migrate_registered_users_forward, + migrate_registered_users_reverse, + ), + migrations.AddConstraint( + model_name="registereduser", + constraint=models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=( + "A user cannot have more than one registration" + " record in the same organization." + ), + ), + ), + migrations.AlterField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + ] diff --git a/tests/openwisp2/sample_radius/tests.py b/tests/openwisp2/sample_radius/tests.py index 6878dcd0..7e509c87 100644 --- a/tests/openwisp2/sample_radius/tests.py +++ b/tests/openwisp2/sample_radius/tests.py @@ -1,3 +1,4 @@ +from openwisp_radius.tests import test_migrations as base_migration_tests from openwisp_radius.tests.test_admin import TestAdmin as BaseTestAdmin from openwisp_radius.tests.test_api.test_api import TestApi as BaseTestApi from openwisp_radius.tests.test_api.test_freeradius_api import ( @@ -185,6 +186,18 @@ class TestLoginView(BaseTestLoginView): pass +class TestMigrationRegisteredUserMultitenancy( + base_migration_tests.TestMigrationRegisteredUserMultitenancy, +): + pass + + +class TestPhoneTokenOrganizationPopulateResolution( + base_migration_tests.TestPhoneTokenOrganizationPopulateResolution, +): + pass + + del BaseTestAdmin del BaseTestApi del BaseTestFreeradiusApi @@ -214,3 +227,4 @@ class TestLoginView(BaseTestLoginView): del BaseTestUpgradeFromDjangoFreeradius del BaseTestAssertionConsumerServiceView del BaseTestLoginView +del base_migration_tests