Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions astrbot/core/tools/web_search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,23 @@ async def get(self, provider_settings: dict) -> str:
self.index = (self.index + 1) % len(keys)
return key

async def iter_keys(self, provider_settings: dict) -> list[str]:
"""Return every configured key in rotation order (current index first).

Used for failover: callers try the keys in turn and move on to the next
one when a key is invalid, out of quota, or rate-limited.
"""
keys = provider_settings.get(self.setting_name, [])
if not keys:
raise ValueError(
f"Error: {self.provider_name} API key is not configured in AstrBot."
)

async with self.lock:
start = self.index
self.index = (self.index + 1) % len(keys)
return [keys[(start + i) % len(keys)] for i in range(len(keys))]


_TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily")
_BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha")
Expand Down Expand Up @@ -147,11 +164,10 @@ def _search_result_payload(results: list[SearchResult]) -> str:
return json.dumps({"results": ret_ls}, ensure_ascii=False)


async def _tavily_search(
provider_settings: dict,
async def _tavily_request_once(
tavily_key: str,
payload: dict,
) -> list[SearchResult]:
tavily_key = await _TAVILY_KEY_ROTATOR.get(provider_settings)
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
Expand Down Expand Up @@ -179,6 +195,22 @@ async def _tavily_search(
]


async def _tavily_search(
provider_settings: dict,
payload: dict,
) -> list[SearchResult]:
keys = await _TAVILY_KEY_ROTATOR.iter_keys(provider_settings)
last_exc: Exception | None = None
for tavily_key in keys:
try:
return await _tavily_request_once(tavily_key, payload)
except Exception as e:
# Key invalid / out of quota / rate-limited: fall through to the next key.
last_exc = e
assert last_exc is not None # iter_keys raises when no keys are configured
raise last_exc
Comment on lines +198 to +211

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The failover mechanism is currently only implemented for _tavily_search. However, _tavily_extract (which powers the tavily_extract_web_page tool) also uses _TAVILY_KEY_ROTATOR and will fail immediately if the rotated key is invalid or exhausted, even if other valid keys are configured.

To ensure consistent failover behavior across all Tavily tools and avoid code duplication (per the general rules), we should refactor the failover loop into a generic helper function. This also makes it easy to reuse for other providers (like BoCha, Brave, Firecrawl) in the future.

Here is an example of how you can implement this:

async def _execute_with_failover(rotator, provider_settings, func, *args, **kwargs):
    keys = await rotator.iter_keys(provider_settings)
    last_exc: Exception | None = None
    for key in keys:
        try:
            return await func(key, *args, **kwargs)
        except Exception as e:
            last_exc = e
    assert last_exc is not None
    raise last_exc

Then, both _tavily_search and _tavily_extract can be simplified to:

async def _tavily_search(provider_settings: dict, payload: dict) -> list[SearchResult]:
    return await _execute_with_failover(
        _TAVILY_KEY_ROTATOR, provider_settings, _tavily_request_once, payload
    )

async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]:
    return await _execute_with_failover(
        _TAVILY_KEY_ROTATOR, provider_settings, _tavily_extract_once, payload
    )

Where _tavily_extract_once is extracted from the original _tavily_extract function:

async def _tavily_extract_once(tavily_key: str, payload: dict) -> list[dict]:
    header = {
        "Authorization": f"Bearer {tavily_key}",
        "Content-Type": "application/json",
    }
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.post(
            "https://api.tavily.com/extract",
            json=payload,
            headers=header,
        ) as response:
            if response.status != 200:
                reason = await response.text()
                raise Exception(
                    f"Tavily web search failed: {reason}, status: {response.status}",
                )
            data = await response.json()
            results: list[dict] = data.get("results", [])
            if not results:
                raise ValueError(
                    "Error: Tavily web searcher does not return any results."
                )
            return results
References
  1. When implementing similar functionality for different cases, refactor the logic into a shared helper function to avoid code duplication.



async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]:
tavily_key = await _TAVILY_KEY_ROTATOR.get(provider_settings)
header = {
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_web_search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,54 @@ def fake_client_session(*, trust_env):
{"websearch_exa_key": ["exa-key"]},
{"ids": ["https://example.com"]},
)


@pytest.mark.asyncio
async def test_iter_keys_returns_rotation_order():
rotator = tools._KeyRotator("websearch_tavily_key", "Tavily")
settings = {"websearch_tavily_key": ["k1", "k2", "k3"]}

assert await rotator.iter_keys(settings) == ["k1", "k2", "k3"]
# the starting point advances so a different key leads the next call
assert await rotator.iter_keys(settings) == ["k2", "k3", "k1"]


@pytest.mark.asyncio
async def test_iter_keys_raises_when_unconfigured():
rotator = tools._KeyRotator("websearch_tavily_key", "Tavily")
with pytest.raises(ValueError):
await rotator.iter_keys({"websearch_tavily_key": []})


@pytest.mark.asyncio
async def test_tavily_search_falls_over_to_next_key(monkeypatch):
tried = []

async def fake_request_once(key, payload):
tried.append(key)
if key == "bad-key":
raise Exception("Tavily web search failed: quota exceeded, status: 429")
return [tools.SearchResult(title="t", url="u", snippet="s")]

monkeypatch.setattr(tools, "_tavily_request_once", fake_request_once)
monkeypatch.setattr(tools._TAVILY_KEY_ROTATOR, "index", 0)
settings = {"websearch_tavily_key": ["bad-key", "good-key"]}

results = await tools._tavily_search(settings, {"query": "x"})

# the first key failed, so the search must have moved on to the second key
assert tried == ["bad-key", "good-key"]
assert results[0].title == "t"


@pytest.mark.asyncio
async def test_tavily_search_raises_last_error_when_all_keys_fail(monkeypatch):
async def fake_request_once(key, payload):
raise Exception(f"Tavily web search failed for {key}")

monkeypatch.setattr(tools, "_tavily_request_once", fake_request_once)
monkeypatch.setattr(tools._TAVILY_KEY_ROTATOR, "index", 0)
settings = {"websearch_tavily_key": ["k1", "k2"]}

with pytest.raises(Exception, match="Tavily web search failed for k2"):
await tools._tavily_search(settings, {"query": "x"})
Loading