1515"""Firebase Cloud Messaging module."""
1616
1717from __future__ import annotations
18- from typing import Callable , List , Optional , TypeVar
18+ from typing import Callable , List , Optional
1919import concurrent .futures
2020import json
2121import warnings
22+ import asyncio
2223import requests
2324import httpx
24- import asyncio
2525
26- from google .auth import credentials , transport
26+ from google .auth import credentials
27+ from google .auth .transport import requests as auth_requests
2728from googleapiclient import http
2829from googleapiclient import _auth
2930
3031import firebase_admin
31- from firebase_admin import _http_client
32- from firebase_admin import _messaging_encoder
33- from firebase_admin import _messaging_utils
34- from firebase_admin import _gapic_utils
35- from firebase_admin import _utils
36- from firebase_admin import exceptions
32+ from firebase_admin import (
33+ _http_client ,
34+ _messaging_encoder ,
35+ _messaging_utils ,
36+ _gapic_utils ,
37+ _utils ,
38+ exceptions ,
39+ App
40+ )
3741
3842
3943_MESSAGING_ATTRIBUTE = '_messaging'
6771 'WebpushNotification' ,
6872 'WebpushNotificationAction' ,
6973
70- 'async_send_each'
7174 'send' ,
7275 'send_all' ,
7376 'send_multicast' ,
7477 'send_each' ,
78+ 'send_each_async' ,
7579 'send_each_for_multicast' ,
7680 'subscribe_to_topic' ,
7781 'unsubscribe_from_topic' ,
78- ] # type: ignore
82+ ]
7983
80- TFirebaseError = TypeVar ('TFirebaseError' , bound = exceptions .FirebaseError )
8184
8285AndroidConfig = _messaging_utils .AndroidConfig
8386AndroidFCMOptions = _messaging_utils .AndroidFCMOptions
104107UnregisteredError = _messaging_utils .UnregisteredError
105108
106109
107- def _get_messaging_service (app ) -> _MessagingService :
110+ def _get_messaging_service (app : Optional [ App ] ) -> _MessagingService :
108111 return _utils .get_app_service (app , _MESSAGING_ATTRIBUTE , _MessagingService )
109112
110- def send (message , dry_run = False , app = None ):
113+ def send (message , dry_run = False , app : Optional [ App ] = None ):
111114 """Sends the given message via Firebase Cloud Messaging (FCM).
112115
113116 If the ``dry_run`` mode is enabled, the message will not be actually delivered to the
@@ -147,8 +150,8 @@ def send_each(messages, dry_run=False, app=None):
147150 """
148151 return _get_messaging_service (app ).send_each (messages , dry_run )
149152
150- async def async_send_each (messages , dry_run = True , app : firebase_admin . App | None = None ) -> BatchResponse :
151- return await _get_messaging_service (app ).async_send_each (messages , dry_run )
153+ async def send_each_async (messages , dry_run = True , app : Optional [ App ] = None ) -> BatchResponse :
154+ return await _get_messaging_service (app ).send_each_async (messages , dry_run )
152155
153156def send_each_for_multicast (multicast_message , dry_run = False , app = None ):
154157 """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM).
@@ -374,48 +377,53 @@ def exception(self):
374377 return self ._exception
375378
376379# Auth Flow
380+ # TODO: Remove comments
377381# The aim here is to be able to get auth credentials right before the request is sent.
378382# This is similar to what is done in transport.requests.AuthorizedSession().
379383# We can then pass this in at the client level.
380- class CustomGoogleAuth (httpx .Auth ):
381- def __init__ (self , credentials : credentials .Credentials ):
382- self ._credential = credentials
384+
385+ # Notes:
386+ # - This implementations does not cover timeouts on requests sent to refresh credentials.
387+ # - Uses HTTP/1 and a blocking credential for refreshing.
388+ class GoogleAuthCredentialFlow (httpx .Auth ):
389+ """Google Auth Credential Auth Flow"""
390+ def __init__ (self , credential : credentials .Credentials ):
391+ self ._credential = credential
383392 self ._max_refresh_attempts = 2
384393 self ._refresh_status_codes = (401 ,)
385-
394+
386395 def apply_auth_headers (self , request : httpx .Request ):
387396 # Build request used to refresh credentials if needed
388- auth_request = transport . requests . Request () # type: ignore
389- # This refreshes the credentials if needed and mutates the request headers to contain access token
390- # and any other google auth headers
397+ auth_request = auth_requests . Request ()
398+ # This refreshes the credentials if needed and mutates the request headers to
399+ # contain access token and any other google auth headers
391400 self ._credential .before_request (auth_request , request .method , request .url , request .headers )
392401
393402
394403 def auth_flow (self , request : httpx .Request ):
395404 # Keep original headers since `credentials.before_request` mutates the passed headers and we
396405 # want to keep the original in cause we need an auth retry.
397406 _original_headers = request .headers .copy ()
398-
407+
399408 _credential_refresh_attempt = 0
400- while (
401- _credential_refresh_attempt < self ._max_refresh_attempts
402- ):
409+ while _credential_refresh_attempt <= self ._max_refresh_attempts :
403410 # copy original headers
404411 request .headers = _original_headers .copy ()
405412 # mutates request headers
406413 self .apply_auth_headers (request )
407-
414+
408415 # Continue to perform the request
409416 # yield here dispatches the request and returns with the response
410417 response : httpx .Response = yield request
411-
412- # We can check the result of the response and determine in we need to retry on refreshable status codes.
413- # Current transport.requests.AuthorizedSession() only does this on 401 errors. We should do the same.
418+
419+ # We can check the result of the response and determine in we need to retry
420+ # on refreshable status codes. Current transport.requests.AuthorizedSession()
421+ # only does this on 401 errors. We should do the same.
414422 if response .status_code in self ._refresh_status_codes :
415423 _credential_refresh_attempt += 1
416- print (response .status_code , response .reason_phrase , _credential_refresh_attempt )
417424 else :
418- break ;
425+ break
426+ # Last yielded response is auto returned.
419427
420428
421429
@@ -453,7 +461,7 @@ def __init__(self, app) -> None:
453461 self ._client = _http_client .JsonHttpClient (credential = self ._credential , timeout = timeout )
454462 self ._async_client = httpx .AsyncClient (
455463 http2 = True ,
456- auth = CustomGoogleAuth (self ._credential ),
464+ auth = GoogleAuthCredentialFlow (self ._credential ),
457465 timeout = timeout ,
458466 transport = HttpxRetryTransport ()
459467 )
@@ -509,13 +517,13 @@ def send_data(data):
509517 message = 'Unknown error while making remote service calls: {0}' .format (error ),
510518 cause = error )
511519
512- async def async_send_each (self , messages : List [Message ], dry_run : bool = True ) -> BatchResponse :
520+ async def send_each_async (self , messages : List [Message ], dry_run : bool = True ) -> BatchResponse :
513521 """Sends the given messages to FCM via the FCM v1 API."""
514522 if not isinstance (messages , list ):
515523 raise ValueError ('messages must be a list of messaging.Message instances.' )
516524 if len (messages ) > 1000 :
517525 raise ValueError ('messages must not contain more than 500 elements.' )
518-
526+
519527 async def send_data (data ):
520528 try :
521529 resp = await self ._async_client .request (
@@ -661,7 +669,8 @@ def _handle_batch_error(self, error):
661669 """Handles errors received from the googleapiclient while making batch requests."""
662670 return _gapic_utils .handle_platform_error_from_googleapiclient (
663671 error , _MessagingService ._build_fcm_error_googleapiclient )
664-
672+
673+ # TODO: Remove comments
665674 # We should be careful to clean up the httpx clients.
666675 # Since we are using an async client we must also close in async. However we can sync wrap this.
667676 # The close method is called by the app on shutdown/clean-up of each service. We don't seem to
@@ -677,14 +686,16 @@ def _build_fcm_error_requests(cls, error, message, error_dict):
677686 return exc_type (message , cause = error , http_response = error .response ) if exc_type else None
678687
679688 @classmethod
680- def _build_fcm_error_httpx (cls , error : httpx .HTTPError , message , error_dict ) -> Optional [exceptions .FirebaseError ]:
689+ def _build_fcm_error_httpx (
690+ cls , error : httpx .HTTPError , message , error_dict
691+ ) -> Optional [exceptions .FirebaseError ]:
681692 """Parses a httpx error response from the FCM API and creates a FCM-specific exception if
682693 appropriate."""
683694 exc_type = cls ._build_fcm_error (error_dict )
684695 if isinstance (error , httpx .HTTPStatusError ):
685- return exc_type (message , cause = error , http_response = error . response ) if exc_type else None
686- else :
687- return exc_type (message , cause = error ) if exc_type else None
696+ return exc_type (
697+ message , cause = error , http_response = error . response ) if exc_type else None
698+ return exc_type (message , cause = error ) if exc_type else None
688699
689700
690701 @classmethod
@@ -706,42 +717,43 @@ def _build_fcm_error(cls, error_dict) -> Optional[Callable[..., exceptions.Fireb
706717 return _MessagingService .FCM_ERROR_TYPES .get (fcm_code ) if fcm_code else None
707718
708719
720+ # TODO: Remove comments
721+ # Notes:
722+ # This implementation currently only covers basic retires for pre-defined status errors
709723class HttpxRetryTransport (httpx .AsyncBaseTransport ):
724+ """HTTPX transport with retry logic."""
710725 # We could also support passing kwargs here
711- def __init__ (self ) -> None :
726+ def __init__ (self , ** kwargs ) -> None :
727+ # Hardcoded settings for now
712728 self ._retryable_status_codes = (500 , 503 ,)
713729 self ._max_retry_count = 4
714730
715- # We should use a full AsyncHTTPTransport under the hood since that is
716- # fully implemented. We could consider making this class extend a
717- # AsyncHTTPTransport instead and use the parent class's methods to handle
718- # requests. We sould also ensure that that transport's internal retry is
719- # not enabled.
720- self ._wrapped_transport = httpx .AsyncHTTPTransport (retries = 0 , http2 = True )
721-
722- # Checklist:
723- # - Do we want to disable built in retries
724- # - Can we dispatch the same request multiple times? Is there any side effects?
725-
726- # Two types of retries
727- # - Status code (500s, redirect)
728- # - Error code (read, connect, other)
729- # - more ???
730-
731+ # - We use a full AsyncHTTPTransport under the hood to make use of it's
732+ # fully implemented `handle_async_request()`.
733+ # - We could consider making the `HttpxRetryTransport`` class extend a
734+ # `AsyncHTTPTransport` instead and use the parent class's methods to handle
735+ # requests.
736+ # - We should also ensure that that transport's internal retry is
737+ # not enabled.
738+ transport_kwargs = kwargs .copy ()
739+ transport_kwargs .update ({'retries' : 0 , 'http2' : True })
740+ self ._wrapped_transport = httpx .AsyncHTTPTransport (** transport_kwargs )
741+
742+
731743 async def handle_async_request (self , request : httpx .Request ) -> httpx .Response :
732744 _retry_count = 0
733-
745+
734746 while True :
735747 # Dispatch request
748+ # Let exceptions pass through for now
736749 response = await self ._wrapped_transport .handle_async_request (request )
737-
750+
738751 # Check if request is retryable
739752 if response .status_code in self ._retryable_status_codes :
740753 _retry_count += 1
741-
742- # Figure out how we want to handle 0 here
754+
755+ # Return if retries exhausted
743756 if _retry_count > self ._max_retry_count :
744757 return response
745758 else :
746759 return response
747- # break;
0 commit comments