From c0ac850fbf6272be1ea6ac579a33075d9d603fac Mon Sep 17 00:00:00 2001 From: Martin Torp Date: Wed, 10 Jun 2026 13:20:55 +0200 Subject: [PATCH 1/2] Add transient-error classification to APIFailure API.do_request now records the HTTP status code on every exception it raises (status_code attribute), and APIFailure gains is_transient_error(): True for gateway/connection-level failures (HTTP 408/502/503/504, dropped or reset connections, client-side timeouts) where retrying the same request may succeed, False for deterministic errors (400/401/403/404/429, wrapped unexpected errors). Classification is based on the recorded status code rather than exception class identity or message text, so it stays correct if a status code gains a dedicated subclass later. Motivated by SocketDev/socket-python-cli#232: the CLI retries transient full-scan upload failures and previously had to parse the status code out of catch-all APIFailure message text. --- pyproject.toml | 2 +- socketdev/core/api.py | 20 ++-- socketdev/exceptions.py | 127 +++++++++++++++--------- socketdev/version.py | 2 +- tests/unit/test_exceptions.py | 176 ++++++++++++++++++++++++++++++++++ uv.lock | 2 +- 6 files changed, 269 insertions(+), 60 deletions(-) create mode 100644 tests/unit/test_exceptions.py diff --git a/pyproject.toml b/pyproject.toml index 7c30914..609daa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "socketdev" -version = "3.2.1" +version = "3.3.0" requires-python = ">= 3.9" dependencies = [ 'requests', diff --git a/socketdev/core/api.py b/socketdev/core/api.py index 575e086..92fdace 100644 --- a/socketdev/core/api.py +++ b/socketdev/core/api.py @@ -76,26 +76,26 @@ def format_headers(headers_dict): path_str = f"\nPath: {url}" if response.status_code == 401: - raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}") + raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}", status_code=401) if response.status_code == 403: try: error_message = response.json().get("error", {}).get("message", "") if "Insufficient permissions for API method" in error_message: log.error(f"{error_message}{path_str}{headers_str}") - raise APIInsufficientPermissions() + raise APIInsufficientPermissions(status_code=403) elif "Organization not allowed" in error_message: log.error(f"{error_message}{path_str}{headers_str}") - raise APIOrganizationNotAllowed() + raise APIOrganizationNotAllowed(status_code=403) elif "Insufficient max quota" in error_message: log.error(f"{error_message}{path_str}{headers_str}") - raise APIInsufficientQuota() + raise APIInsufficientQuota(status_code=403) else: - raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}") + raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}", status_code=403) except ValueError: - raise APIAccessDenied(f"Access denied{path_str}{headers_str}") + raise APIAccessDenied(f"Access denied{path_str}{headers_str}", status_code=403) if response.status_code == 404: log.error(f"Path not found {path}{path_str}{headers_str}") - raise APIResourceNotFound() + raise APIResourceNotFound(status_code=404) if response.status_code == 429: retry_after = response.headers.get("retry-after") if retry_after: @@ -109,10 +109,10 @@ def format_headers(headers_dict): else: time_msg = "" log.error(f"Insufficient quota for API route.{time_msg}{path_str}{headers_str}") - raise APIInsufficientQuota() + raise APIInsufficientQuota(status_code=429) if response.status_code == 502: log.error(f"Upstream server error{path_str}{headers_str}") - raise APIBadGateway() + raise APIBadGateway(status_code=502) if response.status_code >= 400: try: error_json = response.json() @@ -124,7 +124,7 @@ def format_headers(headers_dict): f"Error message: {error_message}" ) log.error(error) - raise APIFailure(error) + raise APIFailure(error, status_code=response.status_code) return response except Timeout: diff --git a/socketdev/exceptions.py b/socketdev/exceptions.py index d61b827..077ae07 100644 --- a/socketdev/exceptions.py +++ b/socketdev/exceptions.py @@ -1,47 +1,80 @@ -class APIFailure(Exception): - """Base exception for all Socket API errors""" - pass - - -class APIKeyMissing(APIFailure): - """Raised when the api key is not passed and the headers are empty""" - - -class APIAccessDenied(APIFailure): - """Raised when access is denied to the API""" - pass - - -class APIInsufficientPermissions(APIFailure): - """Raised when the API token doesn't have required permissions""" - pass - - -class APIOrganizationNotAllowed(APIFailure): - """Raised when organization doesn't have access to the feature""" - pass - - -class APIInsufficientQuota(APIFailure): - """Raised when access is denied to the API due to quota limits""" - pass - - -class APIResourceNotFound(APIFailure): - """Raised when the requested resource is not found""" - pass - - -class APITimeout(APIFailure): - """Raised when a request times out""" - pass - - -class APIConnectionError(APIFailure): - """Raised when there's a connection error""" - pass - - -class APIBadGateway(APIFailure): - """Raised when the upstream server returns a 502 Bad Gateway error""" - pass \ No newline at end of file +from typing import Optional + +# HTTP statuses classified as transient by APIFailure.is_transient_error(): gateway / +# availability failures where the request was dropped before the application produced a +# definitive response, so retrying the same request may succeed (408 Request Timeout, +# 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout). +TRANSIENT_HTTP_STATUS_CODES = frozenset({408, 502, 503, 504}) + + +class APIFailure(Exception): + """Base exception for all Socket API errors""" + + def __init__(self, *args, status_code: Optional[int] = None): + super().__init__(*args) + self.status_code = status_code + + def is_transient_error(self) -> bool: + """Whether this failure is transient, i.e. retrying the same request may succeed. + + Transient failures happen at the gateway/connection level - HTTP 408/502/503/504, + dropped or reset connections, and client-side timeouts - before the server produced + a definitive response. Deterministic errors (e.g. 400/401/403/404/429) are not + transient: retrying the same request fails the same way. Classification is based on + the HTTP status code recorded when the exception was raised (or overridden by + subclasses without an HTTP status, like timeouts), so it stays correct even if a + status code gains a dedicated exception subclass later. + """ + return self.status_code in TRANSIENT_HTTP_STATUS_CODES + + +class APIKeyMissing(APIFailure): + """Raised when the api key is not passed and the headers are empty""" + + +class APIAccessDenied(APIFailure): + """Raised when access is denied to the API""" + pass + + +class APIInsufficientPermissions(APIFailure): + """Raised when the API token doesn't have required permissions""" + pass + + +class APIOrganizationNotAllowed(APIFailure): + """Raised when organization doesn't have access to the feature""" + pass + + +class APIInsufficientQuota(APIFailure): + """Raised when access is denied to the API due to quota limits""" + pass + + +class APIResourceNotFound(APIFailure): + """Raised when the requested resource is not found""" + pass + + +class APITimeout(APIFailure): + """Raised when a request times out""" + + def is_transient_error(self) -> bool: + # No HTTP status: the request timed out client-side, so a retry may succeed. + return True + + +class APIConnectionError(APIFailure): + """Raised when there's a connection error""" + + def is_transient_error(self) -> bool: + # No HTTP status: the connection was dropped/reset mid-request, so a retry may succeed. + return True + + +class APIBadGateway(APIFailure): + """Raised when the upstream server returns a 502 Bad Gateway error""" + + def __init__(self, *args, status_code: Optional[int] = 502): + super().__init__(*args, status_code=status_code) diff --git a/socketdev/version.py b/socketdev/version.py index 1da6a55..88c513e 100644 --- a/socketdev/version.py +++ b/socketdev/version.py @@ -1 +1 @@ -__version__ = "3.2.1" +__version__ = "3.3.0" diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000..df621ad --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,176 @@ +""" +Unit tests for the SDK exception hierarchy and transient-error classification. + +`APIFailure.is_transient_error()` tells consumers whether retrying the same request may +succeed (gateway/connection-level failures: HTTP 408/502/503/504, dropped or reset +connections, client-side timeouts) or whether the failure is deterministic (400/401/403/ +404/429 and similar). Classification is based on the `status_code` recorded at raise time +inside `API.do_request`, so these tests cover both the exception classes themselves and +the status codes `do_request` attaches when raising them. + +Run with: python -m pytest tests/unit/ -v +""" + +import unittest +from unittest.mock import MagicMock, patch + +import requests + +from socketdev.core.api import API +from socketdev.exceptions import ( + APIAccessDenied, + APIBadGateway, + APIConnectionError, + APIFailure, + APIInsufficientPermissions, + APIInsufficientQuota, + APIOrganizationNotAllowed, + APIResourceNotFound, + APITimeout, +) + + +class TestIsTransientError(unittest.TestCase): + """Classification of exceptions constructed directly.""" + + def test_transient_statuses_on_catch_all_failure(self): + for status in (408, 502, 503, 504): + self.assertTrue(APIFailure("boom", status_code=status).is_transient_error()) + + def test_deterministic_statuses_on_catch_all_failure(self): + for status in (400, 401, 403, 404, 422, 429, 500): + self.assertFalse(APIFailure("boom", status_code=status).is_transient_error()) + + def test_no_status_code_is_not_transient(self): + # The wrapped-unexpected-error case: do_request raises a bare APIFailure(). + self.assertFalse(APIFailure().is_transient_error()) + self.assertFalse(APIFailure("boom").is_transient_error()) + + def test_connection_level_classes_are_transient(self): + self.assertTrue(APITimeout().is_transient_error()) + self.assertTrue(APIConnectionError().is_transient_error()) + self.assertTrue(APIBadGateway().is_transient_error()) + + def test_bad_gateway_carries_502_by_default(self): + self.assertEqual(APIBadGateway().status_code, 502) + + def test_dedicated_4xx_classes_are_not_transient(self): + self.assertFalse(APIAccessDenied("denied", status_code=401).is_transient_error()) + self.assertFalse(APIInsufficientPermissions(status_code=403).is_transient_error()) + self.assertFalse(APIOrganizationNotAllowed(status_code=403).is_transient_error()) + self.assertFalse(APIInsufficientQuota(status_code=429).is_transient_error()) + self.assertFalse(APIResourceNotFound(status_code=404).is_transient_error()) + + def test_subclass_with_transient_status_follows_the_status(self): + # Classification is by recorded status, not class identity: if a transient status + # ever gains a dedicated subclass, is_transient_error() keeps working unchanged. + class APIServiceUnavailable(APIFailure): + pass + + self.assertTrue(APIServiceUnavailable(status_code=503).is_transient_error()) + + def test_message_text_does_not_affect_classification(self): + self.assertFalse( + APIFailure("original_status_code:503 lookalike").is_transient_error() + ) + + def test_single_message_arg_is_preserved(self): + error = APIFailure("something broke", status_code=503) + self.assertEqual(str(error), "something broke") + + +def _mock_response(status_code, json_data=None, headers=None, text=""): + response = MagicMock() + response.status_code = status_code + response.headers = headers if headers is not None else {} + response.text = text + if json_data is None: + response.json.side_effect = ValueError("no json") + else: + response.json.return_value = json_data + return response + + +class TestDoRequestStatusCodes(unittest.TestCase): + """do_request attaches the HTTP status to the exceptions it raises.""" + + def setUp(self): + self.api = API() + self.api.encode_key("test-token") + + def _do_request_raising(self, expected_class, response=None, side_effect=None): + with patch("socketdev.core.api.requests.request") as mock_request: + if side_effect is not None: + mock_request.side_effect = side_effect + else: + mock_request.return_value = response + with self.assertRaises(expected_class) as ctx: + self.api.do_request("orgs/test/full-scans", method="POST") + return ctx.exception + + def test_401_access_denied_is_not_transient(self): + error = self._do_request_raising(APIAccessDenied, _mock_response(401)) + self.assertEqual(error.status_code, 401) + self.assertFalse(error.is_transient_error()) + + def test_403_insufficient_permissions_is_not_transient(self): + response = _mock_response( + 403, + json_data={"error": {"message": "Insufficient permissions for API method"}}, + ) + error = self._do_request_raising(APIInsufficientPermissions, response) + self.assertEqual(error.status_code, 403) + self.assertFalse(error.is_transient_error()) + + def test_404_not_found_is_not_transient(self): + error = self._do_request_raising(APIResourceNotFound, _mock_response(404)) + self.assertEqual(error.status_code, 404) + self.assertFalse(error.is_transient_error()) + + def test_429_quota_is_not_transient(self): + error = self._do_request_raising(APIInsufficientQuota, _mock_response(429)) + self.assertEqual(error.status_code, 429) + self.assertFalse(error.is_transient_error()) + + def test_502_bad_gateway_is_transient(self): + error = self._do_request_raising(APIBadGateway, _mock_response(502)) + self.assertEqual(error.status_code, 502) + self.assertTrue(error.is_transient_error()) + + def test_catch_all_transient_statuses(self): + for status in (408, 503, 504): + error = self._do_request_raising(APIFailure, _mock_response(status)) + self.assertIs(type(error), APIFailure) + self.assertEqual(error.status_code, status) + self.assertTrue(error.is_transient_error()) + + def test_catch_all_deterministic_statuses(self): + for status in (400, 500): + error = self._do_request_raising(APIFailure, _mock_response(status)) + self.assertIs(type(error), APIFailure) + self.assertEqual(error.status_code, status) + self.assertFalse(error.is_transient_error()) + + def test_timeout_is_transient(self): + error = self._do_request_raising( + APITimeout, side_effect=requests.exceptions.Timeout("timed out") + ) + self.assertIsNone(error.status_code) + self.assertTrue(error.is_transient_error()) + + def test_connection_error_is_transient(self): + error = self._do_request_raising( + APIConnectionError, + side_effect=requests.exceptions.ConnectionError("reset"), + ) + self.assertIsNone(error.status_code) + self.assertTrue(error.is_transient_error()) + + def test_unexpected_error_wrapped_without_status_is_not_transient(self): + error = self._do_request_raising(APIFailure, side_effect=RuntimeError("boom")) + self.assertIsNone(error.status_code) + self.assertFalse(error.is_transient_error()) + + +if __name__ == "__main__": + unittest.main() diff --git a/uv.lock b/uv.lock index 6331649..7397d8f 100644 --- a/uv.lock +++ b/uv.lock @@ -1353,7 +1353,7 @@ wheels = [ [[package]] name = "socketdev" -version = "3.2.1" +version = "3.3.0" source = { editable = "." } dependencies = [ { name = "requests" }, From 30fba8af55678f98616378b0d1a10ec2c1e53e9d Mon Sep 17 00:00:00 2001 From: Martin Torp Date: Wed, 10 Jun 2026 13:34:03 +0200 Subject: [PATCH 2/2] Hardcode 502 in APIBadGateway instead of accepting a status_code The class definitionally represents a 502, so there is no reason for construction sites to pass (or be able to override) the status. --- socketdev/core/api.py | 2 +- socketdev/exceptions.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/socketdev/core/api.py b/socketdev/core/api.py index 92fdace..8c35231 100644 --- a/socketdev/core/api.py +++ b/socketdev/core/api.py @@ -112,7 +112,7 @@ def format_headers(headers_dict): raise APIInsufficientQuota(status_code=429) if response.status_code == 502: log.error(f"Upstream server error{path_str}{headers_str}") - raise APIBadGateway(status_code=502) + raise APIBadGateway() if response.status_code >= 400: try: error_json = response.json() diff --git a/socketdev/exceptions.py b/socketdev/exceptions.py index 077ae07..980aaf9 100644 --- a/socketdev/exceptions.py +++ b/socketdev/exceptions.py @@ -76,5 +76,5 @@ def is_transient_error(self) -> bool: class APIBadGateway(APIFailure): """Raised when the upstream server returns a 502 Bad Gateway error""" - def __init__(self, *args, status_code: Optional[int] = 502): - super().__init__(*args, status_code=status_code) + def __init__(self, *args): + super().__init__(*args, status_code=502)