Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions PyViCare/PyViCareAbstractOAuthManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def renewToken(self) -> None:
def get(self, url: str) -> Any:
try:
logger.debug(self.__oauth)
response = self.__oauth.get(f"{API_BASE_URL}{url}", timeout=31).json()
raw_response = self.__oauth.get(f"{API_BASE_URL}{url}", timeout=31)
self.__raise_on_non_json_error(raw_response)
response = raw_response.json()
logger.debug("Response to get request: %s", response)
self.__handle_expired_token(response)
self.__handle_rate_limit(response)
Expand All @@ -58,6 +60,21 @@ def get(self, url: str) -> Any:
except InvalidTokenError:
self.renewToken()
return self.get(url)
except OSError as e:
raise PyViCareInternalServerError(
{"statusCode": 0,
"message": str(e),
"viErrorId": "n/a"}) from e

def __raise_on_non_json_error(self, response):
"""Guard against non-JSON error responses (e.g. 502 HTML pages from API gateway)."""
if response.status_code >= 500:
content_type = response.headers.get('content-type', '')
if 'application/json' not in content_type:
raise PyViCareInternalServerError(
{"statusCode": response.status_code,
"message": f"Non-JSON {response.status_code} response",
"viErrorId": "n/a"})

def __handle_expired_token(self, response):
if ("error" in response and response["error"] == "EXPIRED TOKEN"):
Expand All @@ -82,6 +99,13 @@ def __handle_server_error(self, response):
if ("statusCode" in response and response["statusCode"] >= 500):
raise PyViCareInternalServerError(response)

extended = response.get("extendedPayload", {})
if isinstance(extended, dict) and extended.get("code") in ("500", "502", "503"):
raise PyViCareInternalServerError(
{"statusCode": int(extended["code"]),
"message": extended.get("reason", ""),
"viErrorId": response.get("viErrorId", "n/a")})

def __handle_command_error(self, response):
if not Feature.raise_exception_on_command_failure:
return
Expand All @@ -106,8 +130,10 @@ def post(self, url, data) -> Any:
headers = {"Content-Type": "application/json",
"Accept": "application/vnd.siren+json"}
try:
response = self.__oauth.post(
f"{API_BASE_URL}{url}", data, headers=headers).json()
raw_response = self.__oauth.post(
f"{API_BASE_URL}{url}", data, headers=headers)
self.__raise_on_non_json_error(raw_response)
response = raw_response.json()
self.__handle_expired_token(response)
self.__handle_rate_limit(response)
self.__handle_command_error(response)
Expand All @@ -118,3 +144,8 @@ def post(self, url, data) -> Any:
except InvalidTokenError:
self.renewToken()
return self.post(url, data)
except OSError as e:
raise PyViCareInternalServerError(
{"statusCode": 0,
"message": str(e),
"viErrorId": "n/a"}) from e
53 changes: 52 additions & 1 deletion tests/test_ViCareOAuthManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ def renewToken(self):


class FakeResponse:
def __init__(self, file_name):
def __init__(self, file_name, status_code=200, content_type='application/json'):
self.file_name = file_name
self.status_code = status_code
self.headers = {'content-type': content_type}

def json(self):
return readJson(self.file_name)
Expand Down Expand Up @@ -87,6 +89,55 @@ def func():
return self.manager.post("/", "some")
self.assertRaises(PyViCareRateLimitError, func)

def test_get_raise_on_non_json_502(self):
response = Mock()
response.status_code = 502
response.headers = {'content-type': 'text/html'}
self.oauth_mock.get.return_value = response

def func():
return self.manager.get("/")
self.assertRaises(PyViCareInternalServerError, func)

def test_get_raise_on_extended_payload_timeout(self):
self.oauth_mock.get.return_value = FakeResponse.__new__(FakeResponse)
self.oauth_mock.get.return_value.status_code = 200
self.oauth_mock.get.return_value.headers = {'content-type': 'application/json'}
self.oauth_mock.get.return_value.json = lambda: {
'viErrorId': '00-abc-def-00',
'errorType': '',
'message': '',
'extendedPayload': {'code': '500', 'reason': 'TIMEOUT'}
}

def func():
return self.manager.get("/")
self.assertRaises(PyViCareInternalServerError, func)

def test_get_raise_on_connection_error(self):
self.oauth_mock.get.side_effect = OSError("Timeout while contacting DNS servers")

def func():
return self.manager.get("/")
self.assertRaises(PyViCareInternalServerError, func)

def test_post_raise_on_connection_error(self):
self.oauth_mock.post.side_effect = OSError("Connection refused")

def func():
return self.manager.post("/", {})
self.assertRaises(PyViCareInternalServerError, func)

def test_post_raise_on_non_json_502(self):
response = Mock()
response.status_code = 502
response.headers = {'content-type': 'text/html'}
self.oauth_mock.post.return_value = response

def func():
return self.manager.post("/", {})
self.assertRaises(PyViCareInternalServerError, func)

def test_post_renewtoken_ifexpired(self):
self.oauth_mock.post.side_effect = [
FakeResponse('response/errors/expired_token.json'), # first call expired
Expand Down
Loading