Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "socketdev"
version = "3.2.1"
version = "3.3.0"
requires-python = ">= 3.9"
dependencies = [
'requests',
Expand Down
18 changes: 9 additions & 9 deletions socketdev/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,26 @@ def format_headers(headers_dict):
path_str = f"\nPath: {url}"

if response.status_code == 401:
raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}")
raise APIAccessDenied(f"Unauthorized{path_str}{headers_str}", status_code=401)
if response.status_code == 403:
try:
error_message = response.json().get("error", {}).get("message", "")
if "Insufficient permissions for API method" in error_message:
log.error(f"{error_message}{path_str}{headers_str}")
raise APIInsufficientPermissions()
raise APIInsufficientPermissions(status_code=403)
elif "Organization not allowed" in error_message:
log.error(f"{error_message}{path_str}{headers_str}")
raise APIOrganizationNotAllowed()
raise APIOrganizationNotAllowed(status_code=403)
elif "Insufficient max quota" in error_message:
log.error(f"{error_message}{path_str}{headers_str}")
raise APIInsufficientQuota()
raise APIInsufficientQuota(status_code=403)
else:
raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}")
raise APIAccessDenied(f"{error_message or 'Access denied'}{path_str}{headers_str}", status_code=403)
except ValueError:
raise APIAccessDenied(f"Access denied{path_str}{headers_str}")
raise APIAccessDenied(f"Access denied{path_str}{headers_str}", status_code=403)
if response.status_code == 404:
log.error(f"Path not found {path}{path_str}{headers_str}")
raise APIResourceNotFound()
raise APIResourceNotFound(status_code=404)
if response.status_code == 429:
retry_after = response.headers.get("retry-after")
if retry_after:
Expand All @@ -109,7 +109,7 @@ def format_headers(headers_dict):
else:
time_msg = ""
log.error(f"Insufficient quota for API route.{time_msg}{path_str}{headers_str}")
raise APIInsufficientQuota()
raise APIInsufficientQuota(status_code=429)
if response.status_code == 502:
log.error(f"Upstream server error{path_str}{headers_str}")
raise APIBadGateway()
Expand All @@ -124,7 +124,7 @@ def format_headers(headers_dict):
f"Error message: {error_message}"
)
log.error(error)
raise APIFailure(error)
raise APIFailure(error, status_code=response.status_code)

return response
except Timeout:
Expand Down
127 changes: 80 additions & 47 deletions socketdev/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,80 @@
class APIFailure(Exception):
"""Base exception for all Socket API errors"""
pass


class APIKeyMissing(APIFailure):
"""Raised when the api key is not passed and the headers are empty"""


class APIAccessDenied(APIFailure):
"""Raised when access is denied to the API"""
pass


class APIInsufficientPermissions(APIFailure):
"""Raised when the API token doesn't have required permissions"""
pass


class APIOrganizationNotAllowed(APIFailure):
"""Raised when organization doesn't have access to the feature"""
pass


class APIInsufficientQuota(APIFailure):
"""Raised when access is denied to the API due to quota limits"""
pass


class APIResourceNotFound(APIFailure):
"""Raised when the requested resource is not found"""
pass


class APITimeout(APIFailure):
"""Raised when a request times out"""
pass


class APIConnectionError(APIFailure):
"""Raised when there's a connection error"""
pass


class APIBadGateway(APIFailure):
"""Raised when the upstream server returns a 502 Bad Gateway error"""
pass
from typing import Optional

# HTTP statuses classified as transient by APIFailure.is_transient_error(): gateway /
# availability failures where the request was dropped before the application produced a
# definitive response, so retrying the same request may succeed (408 Request Timeout,
# 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout).
TRANSIENT_HTTP_STATUS_CODES = frozenset({408, 502, 503, 504})


class APIFailure(Exception):
"""Base exception for all Socket API errors"""

def __init__(self, *args, status_code: Optional[int] = None):
super().__init__(*args)
self.status_code = status_code

def is_transient_error(self) -> bool:
"""Whether this failure is transient, i.e. retrying the same request may succeed.

Transient failures happen at the gateway/connection level - HTTP 408/502/503/504,
dropped or reset connections, and client-side timeouts - before the server produced
a definitive response. Deterministic errors (e.g. 400/401/403/404/429) are not
transient: retrying the same request fails the same way. Classification is based on
the HTTP status code recorded when the exception was raised (or overridden by
subclasses without an HTTP status, like timeouts), so it stays correct even if a
status code gains a dedicated exception subclass later.
"""
return self.status_code in TRANSIENT_HTTP_STATUS_CODES


class APIKeyMissing(APIFailure):
"""Raised when the api key is not passed and the headers are empty"""


class APIAccessDenied(APIFailure):
"""Raised when access is denied to the API"""
pass


class APIInsufficientPermissions(APIFailure):
"""Raised when the API token doesn't have required permissions"""
pass


class APIOrganizationNotAllowed(APIFailure):
"""Raised when organization doesn't have access to the feature"""
pass


class APIInsufficientQuota(APIFailure):
"""Raised when access is denied to the API due to quota limits"""
pass


class APIResourceNotFound(APIFailure):
"""Raised when the requested resource is not found"""
pass


class APITimeout(APIFailure):
"""Raised when a request times out"""

def is_transient_error(self) -> bool:
# No HTTP status: the request timed out client-side, so a retry may succeed.
return True


class APIConnectionError(APIFailure):
"""Raised when there's a connection error"""

def is_transient_error(self) -> bool:
# No HTTP status: the connection was dropped/reset mid-request, so a retry may succeed.
return True


class APIBadGateway(APIFailure):
"""Raised when the upstream server returns a 502 Bad Gateway error"""

def __init__(self, *args):
super().__init__(*args, status_code=502)
2 changes: 1 addition & 1 deletion socketdev/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.1"
__version__ = "3.3.0"
176 changes: 176 additions & 0 deletions tests/unit/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""
Unit tests for the SDK exception hierarchy and transient-error classification.

`APIFailure.is_transient_error()` tells consumers whether retrying the same request may
succeed (gateway/connection-level failures: HTTP 408/502/503/504, dropped or reset
connections, client-side timeouts) or whether the failure is deterministic (400/401/403/
404/429 and similar). Classification is based on the `status_code` recorded at raise time
inside `API.do_request`, so these tests cover both the exception classes themselves and
the status codes `do_request` attaches when raising them.

Run with: python -m pytest tests/unit/ -v
"""

import unittest
from unittest.mock import MagicMock, patch

import requests

from socketdev.core.api import API
from socketdev.exceptions import (
APIAccessDenied,
APIBadGateway,
APIConnectionError,
APIFailure,
APIInsufficientPermissions,
APIInsufficientQuota,
APIOrganizationNotAllowed,
APIResourceNotFound,
APITimeout,
)


class TestIsTransientError(unittest.TestCase):
"""Classification of exceptions constructed directly."""

def test_transient_statuses_on_catch_all_failure(self):
for status in (408, 502, 503, 504):
self.assertTrue(APIFailure("boom", status_code=status).is_transient_error())

def test_deterministic_statuses_on_catch_all_failure(self):
for status in (400, 401, 403, 404, 422, 429, 500):
self.assertFalse(APIFailure("boom", status_code=status).is_transient_error())

def test_no_status_code_is_not_transient(self):
# The wrapped-unexpected-error case: do_request raises a bare APIFailure().
self.assertFalse(APIFailure().is_transient_error())
self.assertFalse(APIFailure("boom").is_transient_error())

def test_connection_level_classes_are_transient(self):
self.assertTrue(APITimeout().is_transient_error())
self.assertTrue(APIConnectionError().is_transient_error())
self.assertTrue(APIBadGateway().is_transient_error())

def test_bad_gateway_carries_502_by_default(self):
self.assertEqual(APIBadGateway().status_code, 502)

def test_dedicated_4xx_classes_are_not_transient(self):
self.assertFalse(APIAccessDenied("denied", status_code=401).is_transient_error())
self.assertFalse(APIInsufficientPermissions(status_code=403).is_transient_error())
self.assertFalse(APIOrganizationNotAllowed(status_code=403).is_transient_error())
self.assertFalse(APIInsufficientQuota(status_code=429).is_transient_error())
self.assertFalse(APIResourceNotFound(status_code=404).is_transient_error())

def test_subclass_with_transient_status_follows_the_status(self):
# Classification is by recorded status, not class identity: if a transient status
# ever gains a dedicated subclass, is_transient_error() keeps working unchanged.
class APIServiceUnavailable(APIFailure):
pass

self.assertTrue(APIServiceUnavailable(status_code=503).is_transient_error())

def test_message_text_does_not_affect_classification(self):
self.assertFalse(
APIFailure("original_status_code:503 lookalike").is_transient_error()
)

def test_single_message_arg_is_preserved(self):
error = APIFailure("something broke", status_code=503)
self.assertEqual(str(error), "something broke")


def _mock_response(status_code, json_data=None, headers=None, text=""):
response = MagicMock()
response.status_code = status_code
response.headers = headers if headers is not None else {}
response.text = text
if json_data is None:
response.json.side_effect = ValueError("no json")
else:
response.json.return_value = json_data
return response


class TestDoRequestStatusCodes(unittest.TestCase):
"""do_request attaches the HTTP status to the exceptions it raises."""

def setUp(self):
self.api = API()
self.api.encode_key("test-token")

def _do_request_raising(self, expected_class, response=None, side_effect=None):
with patch("socketdev.core.api.requests.request") as mock_request:
if side_effect is not None:
mock_request.side_effect = side_effect
else:
mock_request.return_value = response
with self.assertRaises(expected_class) as ctx:
self.api.do_request("orgs/test/full-scans", method="POST")
return ctx.exception

def test_401_access_denied_is_not_transient(self):
error = self._do_request_raising(APIAccessDenied, _mock_response(401))
self.assertEqual(error.status_code, 401)
self.assertFalse(error.is_transient_error())

def test_403_insufficient_permissions_is_not_transient(self):
response = _mock_response(
403,
json_data={"error": {"message": "Insufficient permissions for API method"}},
)
error = self._do_request_raising(APIInsufficientPermissions, response)
self.assertEqual(error.status_code, 403)
self.assertFalse(error.is_transient_error())

def test_404_not_found_is_not_transient(self):
error = self._do_request_raising(APIResourceNotFound, _mock_response(404))
self.assertEqual(error.status_code, 404)
self.assertFalse(error.is_transient_error())

def test_429_quota_is_not_transient(self):
error = self._do_request_raising(APIInsufficientQuota, _mock_response(429))
self.assertEqual(error.status_code, 429)
self.assertFalse(error.is_transient_error())

def test_502_bad_gateway_is_transient(self):
error = self._do_request_raising(APIBadGateway, _mock_response(502))
self.assertEqual(error.status_code, 502)
self.assertTrue(error.is_transient_error())

def test_catch_all_transient_statuses(self):
for status in (408, 503, 504):
error = self._do_request_raising(APIFailure, _mock_response(status))
self.assertIs(type(error), APIFailure)
self.assertEqual(error.status_code, status)
self.assertTrue(error.is_transient_error())

def test_catch_all_deterministic_statuses(self):
for status in (400, 500):
error = self._do_request_raising(APIFailure, _mock_response(status))
self.assertIs(type(error), APIFailure)
self.assertEqual(error.status_code, status)
self.assertFalse(error.is_transient_error())

def test_timeout_is_transient(self):
error = self._do_request_raising(
APITimeout, side_effect=requests.exceptions.Timeout("timed out")
)
self.assertIsNone(error.status_code)
self.assertTrue(error.is_transient_error())

def test_connection_error_is_transient(self):
error = self._do_request_raising(
APIConnectionError,
side_effect=requests.exceptions.ConnectionError("reset"),
)
self.assertIsNone(error.status_code)
self.assertTrue(error.is_transient_error())

def test_unexpected_error_wrapped_without_status_is_not_transient(self):
error = self._do_request_raising(APIFailure, side_effect=RuntimeError("boom"))
self.assertIsNone(error.status_code)
self.assertFalse(error.is_transient_error())


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading