From fec1abf3de2200f4893081be320ea9c95cd2c253 Mon Sep 17 00:00:00 2001 From: ZouZhang <61148270+Zzzzzzouhang@users.noreply.github.com> Date: Wed, 24 Jun 2026 22:15:12 +0800 Subject: [PATCH 1/7] fix: preserve At components when sending messages on qq_official platform (#8983) - Add At component handling in _parse_to_qqofficial method - Convert At(qq=openid) to <@openid> plain_text format - Maintain original message chain order by appending - Skip At(qq='all') since QQ Official API may not support it Closes #8982 --- astrbot/core/db/migration/migra_45_to_46.py | 60 +- .../db/migration/migra_webchat_session.py | 124 +++- astrbot/core/db/sqlite.py | 28 +- astrbot/core/log.py | 25 + .../sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 596 +++++++++--------- .../platform/sources/wecom_ai_bot/ierror.py | 58 +- .../core/provider/sources/edge_tts_source.py | 252 ++++---- .../core/provider/sources/gsvi_tts_source.py | 152 ++--- astrbot/core/umop_config_router.py | 139 +++- astrbot/core/utils/astrbot_path.py | 10 +- astrbot/core/utils/migra_helper.py | 92 ++- astrbot/core/utils/t2i/network_strategy.py | 231 ++++++- astrbot/core/utils/t2i/renderer.py | 77 ++- main.py | 6 +- openspec/openapi-v1.yaml | 3 + tests/test_kook/data/kook_card_data.json | 198 +++--- .../data/kook_ws_event_group_message.json | 236 +++---- ...k_ws_event_group_message_with_mention.json | 174 ++--- tests/test_kook/data/kook_ws_event_hello.json | 14 +- .../kook_ws_event_message_with_card_1.json | 142 ++--- .../kook_ws_event_message_with_card_2.json | 156 ++--- tests/test_kook/data/kook_ws_event_ping.json | 6 +- tests/test_kook/data/kook_ws_event_pong.json | 4 +- .../data/kook_ws_event_private_message.json | 126 ++-- .../kook_ws_event_private_system_message.json | 60 +- .../data/kook_ws_event_reconnect_err.json | 12 +- .../test_kook/data/kook_ws_event_resume.json | 6 +- .../data/kook_ws_event_resume_ack.json | 10 +- 28 files changed, 1740 insertions(+), 1257 deletions(-) diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index 58736ab51f..10f18d2b5c 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -1,44 +1,80 @@ +# 导入全局日志记录器和共享偏好设置实例 from astrbot.api import logger, sp +# 导入 AstrBot 配置管理器 from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +# 导入 UMOP 配置路由器 from astrbot.core.umop_config_router import UmopConfigRouter async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: + """ + 执行从版本 4.5 到 4.6 的数据迁移。 + + 主要变更:在 4.5 版本中,UMOP(统一消息对象标识符)路由信息存储在 + abconf_data 的每个配置项内部(作为 'umop' 字段);在 4.6 版本中, + UMOP 路由被提取到独立的 UmopConfigRouter 中进行管理。 + + 迁移过程: + 1. 检测是否需要迁移(检查配置中是否存在 'umop' 字段) + 2. 提取所有 umop 到 conf_id 的映射关系 + 3. 从原配置中移除 'umop' 字段 + 4. 更新配置存储和 UMOP 路由器 + + Args: + acm: AstrBot 配置管理器实例,包含旧版本的配置数据 + ucr: UMOP 配置路由器实例,用于存储迁移后的路由数据 + """ + # 获取当前的配置数据(包含可能的旧版本 umop 字段) abconf_data = acm.abconf_data + # 验证配置数据类型是否为字典 if not isinstance(abconf_data, dict): - # should be unreachable + # 理论上不应该到达这里,但如果数据格式异常则记录警告并退出 logger.warning( f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}", ) - return + return # 数据类型异常,无法进行迁移 - # 如果任何一项带有 umop,则说明需要迁移 - need_migration = False + # 检查是否需要执行迁移: + # 遍历所有配置项,查找是否包含旧版本的 'umop' 字段 + need_migration = False # 迁移标志,默认为不需要 for conf_id, conf_info in abconf_data.items(): + # 检查配置项是否为字典且包含 'umop' 键 if isinstance(conf_info, dict) and "umop" in conf_info: - need_migration = True - break + need_migration = True # 发现需要迁移的数据 + break # 找到一个就足够了,跳出循环 + # 如果没有需要迁移的数据,直接返回 if not need_migration: return + # 记录迁移开始日志 logger.info("Starting migration from version 4.5 to 4.6") - # extract umo->conf_id mapping - umo_to_conf_id = {} + # 第一步:提取 umo 到 conf_id 的映射关系 + umo_to_conf_id = {} # 初始化 UMOP 到配置 ID 的映射字典 for conf_id, conf_info in abconf_data.items(): + # 只处理包含 'umop' 字段的字典类型配置项 if isinstance(conf_info, dict) and "umop" in conf_info: + # 从配置项中取出并删除 'umop' 字段(pop 方法会同时删除该字段) umop_ls = conf_info.pop("umop") + # 验证 umop 字段是否为列表类型 if not isinstance(umop_ls, list): - continue + continue # 如果不是列表,跳过该配置项 + # 遍历 umop 列表中的每个 UMO 字符串 for umo in umop_ls: + # 确保 umo 是字符串类型且尚未存在于映射中 if isinstance(umo, str) and umo not in umo_to_conf_id: + # 建立 UMO 到配置 ID 的映射关系 umo_to_conf_id[umo] = conf_id - # update the abconf data + # 第二步:更新配置数据到持久化存储 + # 将移除了 umop 字段的配置数据保存到 SharedPreferences await sp.global_put("abconf_mapping", abconf_data) - # update the umop config router + + # 第三步:更新 UMOP 配置路由器 + # 将提取的映射关系批量更新到路由器中 await ucr.update_routing_data(umo_to_conf_id) - logger.info("Migration from version 45 to 46 completed successfully") + # 记录迁移完成日志 + logger.info("Migration from version 45 to 46 completed successfully") \ No newline at end of file diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py index 46025fc646..e72a255ba4 100644 --- a/astrbot/core/db/migration/migra_webchat_session.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -1,131 +1,185 @@ -"""Migration script for WebChat sessions. +""" +WebChat 会话数据迁移脚本。 -This migration creates PlatformSession from existing platform_message_history records. +此迁移从现有的 platform_message_history 记录中创建 PlatformSession 记录。 -Changes: -- Creates platform_sessions table -- Adds platform_id field (default: 'webchat') -- Adds display_name field -- Session_id format: {platform_id}_{uuid} +变更内容: +- 创建 platform_sessions 表 +- 添加 platform_id 字段(默认值:'webchat') +- 添加 display_name 字段 +- Session_id 格式:{platform_id}_{uuid} """ +# 导入 SQLAlchemy 的聚合函数和查询构建函数 from sqlalchemy import func, select +# 导入 SQLModel 的列选择函数,用于在查询中引用模型列 from sqlmodel import col +# 导入全局日志记录器和共享偏好设置实例 from astrbot.api import logger, sp +# 导入数据库基础操作类 from astrbot.core.db import BaseDatabase +# 导入数据库持久化对象(PO)模型 from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession async def migrate_webchat_session(db_helper: BaseDatabase) -> None: - """Create PlatformSession records from platform_message_history. - - This migration extracts all unique user_ids from platform_message_history - where platform_id='webchat' and creates corresponding PlatformSession records. """ - # 检查是否已经完成迁移 + 从 platform_message_history 表创建 PlatformSession 记录。 + + 此迁移提取 platform_message_history 中所有 platform_id='webchat' 的 + 唯一 user_id,并为每个用户创建对应的 PlatformSession 记录。 + 同时从 Conversations 表中获取对话标题作为会话的显示名称。 + + 迁移过程: + 1. 检查是否已完成迁移(通过偏好设置标记) + 2. 查询所有 WebChat 用户的聊天历史记录 + 3. 检查已存在的会话,避免重复创建 + 4. 从 Conversations 表获取对话标题 + 5. 批量创建 PlatformSession 记录 + 6. 标记迁移完成 + + Args: + db_helper: 数据库操作助手实例,用于数据库访问和偏好设置管理 + """ + # 检查迁移是否已经完成 + # 从偏好设置中读取迁移完成标记 migration_done = await db_helper.get_preference( "global", "global", "migration_done_webchat_session_1" ) + # 如果已经完成迁移,直接返回 if migration_done: return + # 记录迁移开始日志 logger.info("开始执行数据库迁移(WebChat 会话迁移)...") try: + # 获取数据库会话上下文管理器 async with db_helper.get_db() as session: - # 从 platform_message_history 创建 PlatformSession + # 构建查询:从 platform_message_history 中提取 WebChat 用户数据 query = ( select( + # 选择 user_id 字段(作为会话标识) col(PlatformMessageHistory.user_id), + # 选择 sender_name 字段(发送者名称) col(PlatformMessageHistory.sender_name), + # 使用聚合函数获取最早的创建时间 func.min(PlatformMessageHistory.created_at).label("earliest"), + # 使用聚合函数获取最晚的更新时间 func.max(PlatformMessageHistory.updated_at).label("latest"), ) + # 过滤条件:只查询 WebChat 平台的消息 .where(col(PlatformMessageHistory.platform_id) == "webchat") + # 过滤条件:排除机器人自己发送的消息 .where(col(PlatformMessageHistory.sender_id) != "bot") + # 按 user_id 分组,获取每个用户的聚合数据 .group_by(col(PlatformMessageHistory.user_id)) ) + # 执行查询并获取结果 result = await session.execute(query) + # 获取所有查询结果行 webchat_users = result.all() + # 如果没有找到需要迁移的用户数据 if not webchat_users: logger.info("没有找到需要迁移的 WebChat 数据") + # 直接标记迁移完成并返回 await sp.put_async( "global", "global", "migration_done_webchat_session_1", True ) return + # 记录找到的待迁移会话数量 logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移") - # 检查已存在的会话 + # 查询已存在的 PlatformSession,避免重复创建 existing_query = select(col(PlatformSession.session_id)) existing_result = await session.execute(existing_query) + # 将已存在的 session_id 转换为集合,方便快速查找 existing_session_ids = {row[0] for row in existing_result.fetchall()} # 查询 Conversations 表中的 title,用于设置 display_name - # 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id} + # 构建 Conversations 的 user_id 列表 + # 格式: webchat:FriendMessage:webchat!astrbot!{user_id} user_ids_to_query = [ f"webchat:FriendMessage:webchat!astrbot!{user_id}" for user_id, _, _, _ in webchat_users ] + # 构建查询:获取 Conversations 的标题信息 conv_query = select( - col(ConversationV2.user_id), col(ConversationV2.title) - ).where(col(ConversationV2.user_id).in_(user_ids_to_query)) + col(ConversationV2.user_id), # 对话的用户 ID + col(ConversationV2.title) # 对话的标题 + ).where( + # 筛选出属于这些 WebChat 用户的对话 + col(ConversationV2.user_id).in_(user_ids_to_query) + ) + # 执行对话查询 conv_result = await session.execute(conv_query) # 创建 user_id -> title 的映射字典 + # 从 Conversations 的复合 user_id 中提取原始 user_id 作为键 title_map = { user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title for user_id, title in conv_result.fetchall() } - # 批量创建 PlatformSession 记录 + # 准备批量创建的 PlatformSession 记录列表 sessions_to_add = [] + # 记录跳过的会话数量(已存在的会话) skipped_count = 0 + # 遍历每个 WebChat 用户,创建对应的 PlatformSession for user_id, sender_name, created_at, updated_at in webchat_users: - # user_id 就是 webchat_conv_id (session_id) + # user_id 直接作为 session_id 使用 session_id = user_id - # sender_name 通常是 username,但可能为 None + # 设置创建者名称,如果 sender_name 为空则使用 "guest" creator = sender_name if sender_name else "guest" - # 检查是否已经存在该会话 + # 检查该会话是否已经存在 if session_id in existing_session_ids: logger.debug(f"会话 {session_id} 已存在,跳过") - skipped_count += 1 - continue + skipped_count += 1 # 增加跳过计数 + continue # 跳过已存在的会话 - # 从 Conversations 表中获取 display_name + # 从 Conversations 表的映射中获取 display_name display_name = title_map.get(user_id) - # 创建新的 PlatformSession(保留原有的时间戳) + # 创建新的 PlatformSession 对象,保留原始的时间戳信息 new_session = PlatformSession( - session_id=session_id, - platform_id="webchat", - creator=creator, - is_group=0, - created_at=created_at, - updated_at=updated_at, - display_name=display_name, + session_id=session_id, # 会话唯一标识 + platform_id="webchat", # 平台标识 + creator=creator, # 创建者名称 + is_group=0, # 非群组会话(0 表示私聊) + created_at=created_at, # 原始创建时间 + updated_at=updated_at, # 原始更新时间 + display_name=display_name, # 显示名称(从对话标题获取) ) + # 添加到待插入列表 sessions_to_add.append(new_session) - # 批量插入 + # 批量插入新创建的会话记录 if sessions_to_add: + # 使用 add_all 批量添加所有新会话 session.add_all(sessions_to_add) + # 提交事务,将所有更改写入数据库 await session.commit() + # 记录迁移完成统计信息 logger.info( f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}", ) else: + # 没有新会话需要创建 logger.info("没有新会话需要迁移") - # 标记迁移完成 + # 迁移成功完成,在偏好设置中标记完成状态 await sp.put_async("global", "global", "migration_done_webchat_session_1", True) except Exception as e: + # 捕获迁移过程中的任何异常 + # exc_info=True 会记录完整的异常堆栈信息 logger.error(f"迁移过程中发生错误: {e}", exc_info=True) - raise + # 重新抛出异常,让上层调用者处理 + raise \ No newline at end of file diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index b7706cc513..581a70fd17 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -50,21 +50,21 @@ def __init__(self, db_path: str) -> None: async def initialize(self) -> None: """Initialize the database by creating tables if they do not exist.""" - async with self.engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - await conn.execute(text("PRAGMA journal_mode=WAL")) - await conn.execute(text("PRAGMA busy_timeout=30000")) - await conn.execute(text("PRAGMA synchronous=NORMAL")) - await conn.execute(text("PRAGMA cache_size=20000")) - await conn.execute(text("PRAGMA temp_store=MEMORY")) - await conn.execute(text("PRAGMA mmap_size=134217728")) - await conn.execute(text("PRAGMA optimize")) + async with self.engine.begin() as conn: # 开启数据库事务 + await conn.run_sync(SQLModel.metadata.create_all) # 同步执行 SQLModel.metadata.create_all, 根据所有 SQLModel 模型定义,自动创建不存在的数据库表, create_all 是同步方法,需要用 run_sync 在异步环境中执行 + await conn.execute(text("PRAGMA journal_mode=WAL")) # 设置日志模式为 WAL (Write-Ahead Logging), 允许并发读写,写入不阻塞读取,提升并发性能 + await conn.execute(text("PRAGMA busy_timeout=30000")) # 设置忙等待超时为 30 秒(30000 毫秒),当数据库被锁定时,等待最多 30 秒而不是立即报错 + await conn.execute(text("PRAGMA synchronous=NORMAL")) # 设置同步模式为 NORMAL,在安全性和性能间取得平衡,比 FULL 模式快,比 OFF 模式安全 + await conn.execute(text("PRAGMA cache_size=20000")) # 设置缓存大小为 20000 页(约 80MB),增加内存缓存,减少磁盘 I/O,提升查询速度 + await conn.execute(text("PRAGMA temp_store=MEMORY")) # 将临时表和索引存储在内存中,避免创建临时文件,提升临时操作的速度, + await conn.execute(text("PRAGMA mmap_size=134217728")) # 设置内存映射大小为 128MB (134217728 字节),将数据库文件映射到内存,减少 read/write 系统调用 + await conn.execute(text("PRAGMA optimize")) # 执行数据库优, 分析表并更新查询优化器统计信息,提升查询性能化 # 确保 personas 表有 folder_id、sort_order、skills 列(前向兼容) - await self._ensure_persona_folder_columns(conn) - await self._ensure_persona_skills_column(conn) - await self._ensure_persona_custom_error_message_column(conn) - await self._ensure_platform_message_history_checkpoint_column(conn) - await conn.commit() + await self._ensure_persona_folder_columns(conn) # 确保 personas 表有 folder_id 和 sort_order 列,向前兼容,为旧版本数据库添加新字段 + await self._ensure_persona_skills_column(conn) # 确保 personas 表有 skills 列,向前兼容,为旧数据库补充技能字段 + await self._ensure_persona_custom_error_message_column(conn) # 确保 personas 表有 custom_error_message 列,向前兼容,为旧数据库添加自定义错误消息字段 + await self._ensure_platform_message_history_checkpoint_column(conn) # 确保 platform_message_history 表有 llm_checkpoint_id 列,向前兼容,为旧数据库添加 llm_checkpoint_id 字段 + await conn.commit() # 提交所有更改,确保表创建和 PRAGMA 设置持久化到数据库 async def _ensure_persona_folder_columns(self, conn) -> None: """确保 personas 表有 folder_id 和 sort_order 列。 diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 3dd0719b11..e18713ecae 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -264,6 +264,31 @@ def GetLogger(cls, log_name: str = "default") -> logging.Logger: @classmethod def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None: + """ + 这段代码是一个类方法,用于给指定的日志记录器配置一个队列日志处理器。具体作用如下: + + 1. **确保日志增强过滤器存在**:调用 `_ensure_logger_enricher_filter(logger)` 为日志记录器添加必要的信息增强过滤器。 + + 2. **避免重复添加**:遍历日志记录器已有的处理器,如果已经存在 `LogQueueHandler` 类型的处理器,则直接返回,防止重复配置。 + + 3. **创建并配置队列日志处理器**: + - 创建一个 `LogQueueHandler` 实例,传入 `log_broker`(日志代理/中间件) + - 设置日志级别为 `DEBUG`,接收所有级别的日志 + - 添加 ANSI 颜色过滤器,用于过滤控制台颜色代码 + - 设置日志格式,包含: + - `ansi_prefix`:ANSI 颜色前缀 + - 时间戳(精确到毫秒) + - 插件标签 + - 日志级别缩写 + - 版本标签 + - 源文件名和行号 + - 日志消息内容 + - `ansi_reset`:ANSI 颜色重置 + + 4. **激活处理器**:将配置好的处理器添加到日志记录器中,使其开始工作。 + + 整体来说,这是一个**队列化的日志输出配置**,将日志消息通过代理/中间件异步发送,常用于分布式系统或需要统一处理日志的场景。 + """ cls._ensure_logger_enricher_filter(logger) for handler in logger.handlers: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 260b950d19..e24a68d054 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -1,298 +1,298 @@ -#!/usr/bin/env python - -"""对企业微信发送给企业后台的消息加解密示例代码. -@copyright: Copyright (c) 1998-2020 Tencent Inc. - -""" -# ------------------------------------------------------------------------ - -import base64 -import hashlib -import json -import logging -import secrets -import socket -import struct -import time -from typing import NoReturn - -from Crypto.Cipher import AES - -from . import ierror - -""" -关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 -请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。 -下载后,按照README中的“Installation”小节的提示进行pycrypto安装。 -""" - - -class FormatException(Exception): - pass - - -def throw_exception(message, exception_class=FormatException) -> NoReturn: - """My define raise exception function""" - raise exception_class(message) - - -class SHA1: - """计算企业微信的消息签名接口""" - - def getSHA1(self, token, timestamp, nonce, encrypt): - """用SHA1算法生成安全签名 - @param token: 票据 - @param timestamp: 时间戳 - @param encrypt: 密文 - @param nonce: 随机字符串 - @return: 安全签名 - """ - try: - # 确保所有输入都是字符串类型 - if isinstance(encrypt, bytes): - encrypt = encrypt.decode("utf-8") - - sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)] - sortlist.sort() - sha = hashlib.sha1() - sha.update("".join(sortlist).encode("utf-8")) - return ierror.WXBizMsgCrypt_OK, sha.hexdigest() - - except Exception as e: - print(e) - return ierror.WXBizMsgCrypt_ComputeSignature_Error, None - - -class JsonParse: - """提供提取消息格式中的密文及生成回复消息格式的接口""" - - # json消息模板 - AES_TEXT_RESPONSE_TEMPLATE = """{ - "encrypt": "%(msg_encrypt)s", - "msgsignature": "%(msg_signaturet)s", - "timestamp": "%(timestamp)s", - "nonce": "%(nonce)s" - }""" - - def extract(self, jsontext): - """提取出json数据包中的加密消息 - @param jsontext: 待提取的json字符串 - @return: 提取出的加密消息字符串 - """ - try: - json_dict = json.loads(jsontext) - return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"] - except Exception as e: - print(e) - return ierror.WXBizMsgCrypt_ParseJson_Error, None - - def generate(self, encrypt, signature, timestamp, nonce): - """生成json消息 - @param encrypt: 加密后的消息密文 - @param signature: 安全签名 - @param timestamp: 时间戳 - @param nonce: 随机字符串 - @return: 生成的json字符串 - """ - resp_dict = { - "msg_encrypt": encrypt, - "msg_signaturet": signature, - "timestamp": timestamp, - "nonce": nonce, - } - resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict - return resp_json - - -class PKCS7Encoder: - """提供基于PKCS7算法的加解密接口""" - - block_size = 32 - - def encode(self, text): - """对需要加密的明文进行填充补位 - @param text: 需要进行填充补位操作的明文(bytes类型) - @return: 补齐明文字符串(bytes类型) - """ - text_length = len(text) - # 计算需要填充的位数 - amount_to_pad = self.block_size - (text_length % self.block_size) - if amount_to_pad == 0: - amount_to_pad = self.block_size - # 获得补位所用的字符 - pad = bytes([amount_to_pad]) - # 确保text是bytes类型 - if isinstance(text, str): - text = text.encode("utf-8") - return text + pad * amount_to_pad - - def decode(self, decrypted): - """删除解密后明文的补位字符 - @param decrypted: 解密后的明文 - @return: 删除补位字符后的明文 - """ - pad = ord(decrypted[-1]) - if pad < 1 or pad > 32: - pad = 0 - return decrypted[:-pad] - - -class Prpcrypt: - """提供接收和推送给企业微信消息的加解密接口""" - - # 16位随机字符串的范围常量 - # randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999](两端都包含,即包含0和8999999999999999) - # 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999](两端都包含)即16位数字 - MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位) - RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位) - - def __init__(self, key) -> None: - # self.key = base64.b64decode(key+"=") - self.key = key - # 设置加解密模式为AES的CBC模式 - self.mode = AES.MODE_CBC - - def encrypt(self, text, receiveid): - """对明文进行加密 - @param text: 需要加密的明文 - @return: 加密得到的字符串 - """ - # 16位随机字符串添加到明文开头 - text = text.encode() - text = ( - self.get_random_str() - + struct.pack("I", socket.htonl(len(text))) - + text - + receiveid.encode() - ) - - # 使用自定义的填充方式对明文进行补位填充 - pkcs7 = PKCS7Encoder() - text = pkcs7.encode(text) - # 加密 - cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore - try: - ciphertext = cryptor.encrypt(text) - # 使用BASE64对加密后的字符串进行编码 - return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext) - except Exception as e: - logger = logging.getLogger("astrbot") - logger.error(e) - return ierror.WXBizMsgCrypt_EncryptAES_Error, None - - def decrypt(self, text, receiveid): - """对解密后的明文进行补位删除 - @param text: 密文 - @return: 删除填充补位后的明文 - """ - try: - cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore - # 使用BASE64对密文进行解码,然后AES-CBC解密 - plain_text = cryptor.decrypt(base64.b64decode(text)) - except Exception as e: - print(e) - return ierror.WXBizMsgCrypt_DecryptAES_Error, None - try: - pad = plain_text[-1] - # 去掉补位字符串 - # pkcs7 = PKCS7Encoder() - # plain_text = pkcs7.encode(plain_text) - # 去除16位随机字符串 - content = plain_text[16:-pad] - json_len = socket.ntohl(struct.unpack("I", content[:4])[0]) - json_content = content[4 : json_len + 4].decode("utf-8") - from_receiveid = content[json_len + 4 :].decode("utf-8") - except Exception as e: - print(e) - return ierror.WXBizMsgCrypt_IllegalBuffer, None - if from_receiveid != receiveid: - print("receiveid not match", receiveid, from_receiveid) - return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None - return 0, json_content - - def get_random_str(self): - """随机生成16位字符串 - @return: 16位字符串 - """ - return str( - secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE - ).encode() - - -class WXBizJsonMsgCrypt: - # 构造函数 - def __init__(self, sToken, sEncodingAESKey, sReceiveId) -> None: - try: - self.key = base64.b64decode(sEncodingAESKey + "=") - assert len(self.key) == 32 - except Exception as e: - throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException) - # return ierror.WXBizMsgCrypt_IllegalAesKey,None - self.m_sToken = sToken - self.m_sReceiveId = sReceiveId - - # 验证URL - # @param sMsgSignature: 签名串,对应URL参数的msg_signature - # @param sTimeStamp: 时间戳,对应URL参数的timestamp - # @param sNonce: 随机串,对应URL参数的nonce - # @param sEchoStr: 随机串,对应URL参数的echostr - # @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效 - # @return:成功0,失败返回对应的错误码 - - def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): - sha1 = SHA1() - ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) - if ret != 0: - return ret, None - if not signature == sMsgSignature: - return ierror.WXBizMsgCrypt_ValidateSignature_Error, None - pc = Prpcrypt(self.key) - ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId) - return ret, sReplyEchoStr - - def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): - # 将企业回复用户的消息加密打包 - # @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串 - # @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间 - # @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce - # sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串, - # return:成功0,sEncryptMsg,失败返回对应的错误码None - pc = Prpcrypt(self.key) - ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId) - encrypt = encrypt.decode("utf-8") # type: ignore - if ret != 0: - return ret, None - if timestamp is None: - timestamp = str(int(time.time())) - # 生成安全签名 - sha1 = SHA1() - ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt) - if ret != 0: - return ret, None - jsonParse = JsonParse() - return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce) - - def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): - # 检验消息的真实性,并且获取解密后的明文 - # @param sMsgSignature: 签名串,对应URL参数的msg_signature - # @param sTimeStamp: 时间戳,对应URL参数的timestamp - # @param sNonce: 随机串,对应URL参数的nonce - # @param sPostData: 密文,对应POST请求的数据 - # json_content: 解密后的原文,当return返回0时有效 - # @return: 成功0,失败返回对应的错误码 - # 验证安全签名 - jsonParse = JsonParse() - ret, encrypt = jsonParse.extract(sPostData) - if ret != 0: - return ret, None - sha1 = SHA1() - ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt) - if ret != 0: - return ret, None - if not signature == sMsgSignature: - print("signature not match") - print(signature) - return ierror.WXBizMsgCrypt_ValidateSignature_Error, None - pc = Prpcrypt(self.key) - ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId) - return ret, json_content +#!/usr/bin/env python + +"""对企业微信发送给企业后台的消息加解密示例代码. +@copyright: Copyright (c) 1998-2020 Tencent Inc. + +""" +# ------------------------------------------------------------------------ + +import base64 +import hashlib +import json +import logging +import secrets +import socket +import struct +import time +from typing import NoReturn + +from Crypto.Cipher import AES + +from . import ierror + +""" +关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 +请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。 +下载后,按照README中的“Installation”小节的提示进行pycrypto安装。 +""" + + +class FormatException(Exception): + pass + + +def throw_exception(message, exception_class=FormatException) -> NoReturn: + """My define raise exception function""" + raise exception_class(message) + + +class SHA1: + """计算企业微信的消息签名接口""" + + def getSHA1(self, token, timestamp, nonce, encrypt): + """用SHA1算法生成安全签名 + @param token: 票据 + @param timestamp: 时间戳 + @param encrypt: 密文 + @param nonce: 随机字符串 + @return: 安全签名 + """ + try: + # 确保所有输入都是字符串类型 + if isinstance(encrypt, bytes): + encrypt = encrypt.decode("utf-8") + + sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)] + sortlist.sort() + sha = hashlib.sha1() + sha.update("".join(sortlist).encode("utf-8")) + return ierror.WXBizMsgCrypt_OK, sha.hexdigest() + + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_ComputeSignature_Error, None + + +class JsonParse: + """提供提取消息格式中的密文及生成回复消息格式的接口""" + + # json消息模板 + AES_TEXT_RESPONSE_TEMPLATE = """{ + "encrypt": "%(msg_encrypt)s", + "msgsignature": "%(msg_signaturet)s", + "timestamp": "%(timestamp)s", + "nonce": "%(nonce)s" + }""" + + def extract(self, jsontext): + """提取出json数据包中的加密消息 + @param jsontext: 待提取的json字符串 + @return: 提取出的加密消息字符串 + """ + try: + json_dict = json.loads(jsontext) + return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"] + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_ParseJson_Error, None + + def generate(self, encrypt, signature, timestamp, nonce): + """生成json消息 + @param encrypt: 加密后的消息密文 + @param signature: 安全签名 + @param timestamp: 时间戳 + @param nonce: 随机字符串 + @return: 生成的json字符串 + """ + resp_dict = { + "msg_encrypt": encrypt, + "msg_signaturet": signature, + "timestamp": timestamp, + "nonce": nonce, + } + resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict + return resp_json + + +class PKCS7Encoder: + """提供基于PKCS7算法的加解密接口""" + + block_size = 32 + + def encode(self, text): + """对需要加密的明文进行填充补位 + @param text: 需要进行填充补位操作的明文(bytes类型) + @return: 补齐明文字符串(bytes类型) + """ + text_length = len(text) + # 计算需要填充的位数 + amount_to_pad = self.block_size - (text_length % self.block_size) + if amount_to_pad == 0: + amount_to_pad = self.block_size + # 获得补位所用的字符 + pad = bytes([amount_to_pad]) + # 确保text是bytes类型 + if isinstance(text, str): + text = text.encode("utf-8") + return text + pad * amount_to_pad + + def decode(self, decrypted): + """删除解密后明文的补位字符 + @param decrypted: 解密后的明文 + @return: 删除补位字符后的明文 + """ + pad = ord(decrypted[-1]) + if pad < 1 or pad > 32: + pad = 0 + return decrypted[:-pad] + + +class Prpcrypt: + """提供接收和推送给企业微信消息的加解密接口""" + + # 16位随机字符串的范围常量 + # randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999](两端都包含,即包含0和8999999999999999) + # 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999](两端都包含)即16位数字 + MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位) + RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位) + + def __init__(self, key) -> None: + # self.key = base64.b64decode(key+"=") + self.key = key + # 设置加解密模式为AES的CBC模式 + self.mode = AES.MODE_CBC + + def encrypt(self, text, receiveid): + """对明文进行加密 + @param text: 需要加密的明文 + @return: 加密得到的字符串 + """ + # 16位随机字符串添加到明文开头 + text = text.encode() + text = ( + self.get_random_str() + + struct.pack("I", socket.htonl(len(text))) + + text + + receiveid.encode() + ) + + # 使用自定义的填充方式对明文进行补位填充 + pkcs7 = PKCS7Encoder() + text = pkcs7.encode(text) + # 加密 + cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore + try: + ciphertext = cryptor.encrypt(text) + # 使用BASE64对加密后的字符串进行编码 + return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext) + except Exception as e: + logger = logging.getLogger("astrbot") + logger.error(e) + return ierror.WXBizMsgCrypt_EncryptAES_Error, None + + def decrypt(self, text, receiveid): + """对解密后的明文进行补位删除 + @param text: 密文 + @return: 删除填充补位后的明文 + """ + try: + cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore + # 使用BASE64对密文进行解码,然后AES-CBC解密 + plain_text = cryptor.decrypt(base64.b64decode(text)) + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_DecryptAES_Error, None + try: + pad = plain_text[-1] + # 去掉补位字符串 + # pkcs7 = PKCS7Encoder() + # plain_text = pkcs7.encode(plain_text) + # 去除16位随机字符串 + content = plain_text[16:-pad] + json_len = socket.ntohl(struct.unpack("I", content[:4])[0]) + json_content = content[4 : json_len + 4].decode("utf-8") + from_receiveid = content[json_len + 4 :].decode("utf-8") + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_IllegalBuffer, None + if from_receiveid != receiveid: + print("receiveid not match", receiveid, from_receiveid) + return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None + return 0, json_content + + def get_random_str(self): + """随机生成16位字符串 + @return: 16位字符串 + """ + return str( + secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE + ).encode() + + +class WXBizJsonMsgCrypt: + # 构造函数 + def __init__(self, sToken, sEncodingAESKey, sReceiveId) -> None: + try: + self.key = base64.b64decode(sEncodingAESKey + "=") + assert len(self.key) == 32 + except Exception as e: + throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException) + # return ierror.WXBizMsgCrypt_IllegalAesKey,None + self.m_sToken = sToken + self.m_sReceiveId = sReceiveId + + # 验证URL + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sEchoStr: 随机串,对应URL参数的echostr + # @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效 + # @return:成功0,失败返回对应的错误码 + + def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId) + return ret, sReplyEchoStr + + def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): + # 将企业回复用户的消息加密打包 + # @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串 + # @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间 + # @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce + # sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串, + # return:成功0,sEncryptMsg,失败返回对应的错误码None + pc = Prpcrypt(self.key) + ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId) + encrypt = encrypt.decode("utf-8") # type: ignore + if ret != 0: + return ret, None + if timestamp is None: + timestamp = str(int(time.time())) + # 生成安全签名 + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt) + if ret != 0: + return ret, None + jsonParse = JsonParse() + return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce) + + def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): + # 检验消息的真实性,并且获取解密后的明文 + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sPostData: 密文,对应POST请求的数据 + # json_content: 解密后的原文,当return返回0时有效 + # @return: 成功0,失败返回对应的错误码 + # 验证安全签名 + jsonParse = JsonParse() + ret, encrypt = jsonParse.extract(sPostData) + if ret != 0: + return ret, None + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + print("signature not match") + print(signature) + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId) + return ret, json_content diff --git a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py index 0df14a5059..febf321939 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py @@ -1,19 +1,39 @@ -#!/usr/bin/env python -######################################################################### -# Author: jonyqin -# Created Time: Thu 11 Sep 2014 01:53:58 PM CST -# File Name: ierror.py -# Description:定义错误码含义 -######################################################################### -WXBizMsgCrypt_OK = 0 -WXBizMsgCrypt_ValidateSignature_Error = -40001 -WXBizMsgCrypt_ParseJson_Error = -40002 -WXBizMsgCrypt_ComputeSignature_Error = -40003 -WXBizMsgCrypt_IllegalAesKey = -40004 -WXBizMsgCrypt_ValidateCorpid_Error = -40005 -WXBizMsgCrypt_EncryptAES_Error = -40006 -WXBizMsgCrypt_DecryptAES_Error = -40007 -WXBizMsgCrypt_IllegalBuffer = -40008 -WXBizMsgCrypt_EncodeBase64_Error = -40009 -WXBizMsgCrypt_DecodeBase64_Error = -40010 -WXBizMsgCrypt_GenReturnJson_Error = -40011 +#!/usr/bin/env python +######################################################################### +# Author: jonyqin +# Created Time: Thu 11 Sep 2014 01:53:58 PM CST +# File Name: ierror.py +# Description:定义错误码含义 +######################################################################### +WXBizMsgCrypt_OK = 0 +WXBizMsgCrypt_ValidateSignature_Error = -40001 +WXBizMsgCrypt_ParseJson_Error = -40002 +WXBizMsgCrypt_ComputeSignature_Error = -40003 +WXBizMsgCrypt_IllegalAesKey = -40004 +WXBizMsgCrypt_ValidateCorpid_Error = -40005 +WXBizMsgCrypt_EncryptAES_Error = -40006 +WXBizMsgCrypt_DecryptAES_Error = -40007 +WXBizMsgCrypt_IllegalBuffer = -40008 +WXBizMsgCrypt_EncodeBase64_Error = -40009 +WXBizMsgCrypt_DecodeBase64_Error = -40010 +WXBizMsgCrypt_GenReturnJson_Error = -40011 + +#!/usr/bin/env python +######################################################################### +# Author: jonyqin +# Created Time: Thu 11 Sep 2014 01:53:58 PM CST +# File Name: ierror.py +# Description:定义错误码含义 +######################################################################### +WXBizMsgCrypt_OK = 0 +WXBizMsgCrypt_ValidateSignature_Error = -40001 +WXBizMsgCrypt_ParseJson_Error = -40002 +WXBizMsgCrypt_ComputeSignature_Error = -40003 +WXBizMsgCrypt_IllegalAesKey = -40004 +WXBizMsgCrypt_ValidateCorpid_Error = -40005 +WXBizMsgCrypt_EncryptAES_Error = -40006 +WXBizMsgCrypt_DecryptAES_Error = -40007 +WXBizMsgCrypt_IllegalBuffer = -40008 +WXBizMsgCrypt_EncodeBase64_Error = -40009 +WXBizMsgCrypt_DecodeBase64_Error = -40010 +WXBizMsgCrypt_GenReturnJson_Error = -40011 diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 503bd275b4..2ccec27b57 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -1,126 +1,126 @@ -import asyncio -import os -import subprocess -import uuid - -import edge_tts - -from astrbot.core import logger -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - -""" -edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 -``` -pip install edge_tts -``` -Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot -""" - - -@register_provider_adapter( - "edge_tts", - "Microsoft Edge TTS", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderEdgeTTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - - # 设置默认语音,如果没有指定则使用中文小萱 - self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") - self.rate = provider_config.get("rate") - self.volume = provider_config.get("volume") - self.pitch = provider_config.get("pitch") - self.timeout = provider_config.get("timeout", 30) - - self.proxy = os.getenv("https_proxy", None) - - self.set_model("edge_tts") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") - wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") - - # 构建 Edge TTS 参数 - kwargs = {"text": text, "voice": self.voice} - if self.rate: - kwargs["rate"] = self.rate - if self.volume: - kwargs["volume"] = self.volume - if self.pitch: - kwargs["pitch"] = self.pitch - - try: - communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) - await communicate.save(mp3_path) - - try: - from pyffmpeg import FFmpeg - - ff = FFmpeg() - ff.convert(input_file=mp3_path, output_file=wav_path) - except Exception as e: - logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") - # use ffmpeg command line - - # 使用ffmpeg将MP3转换为标准WAV格式 - p = await asyncio.create_subprocess_exec( - "ffmpeg", - "-y", # 覆盖输出文件 - "-i", - mp3_path, # 输入文件 - "-acodec", - "pcm_s16le", # 16位PCM编码 - "-ar", - "24000", # 采样率24kHz (适合微信语音) - "-ac", - "1", # 单声道 - "-af", - "apad=pad_dur=2", # 确保输出时长准确 - "-fflags", - "+genpts", # 强制生成时间戳 - "-hide_banner", # 隐藏版本信息 - wav_path, # 输出文件 - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # 等待进程完成并获取输出 - stdout, stderr = await p.communicate() - logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") - logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") - logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") - - os.remove(mp3_path) - if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: - return wav_path - logger.error("生成的WAV文件不存在或为空") - raise RuntimeError("生成的WAV文件不存在或为空") - - except subprocess.CalledProcessError as e: - logger.error( - f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", - ) - try: - if os.path.exists(mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"FFmpeg 转换失败: {e!s}") - - except Exception as e: - logger.error(f"音频生成失败: {e!s}") - try: - if os.path.exists(mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"音频生成失败: {e!s}") +import asyncio +import os +import subprocess +import uuid + +import edge_tts + +from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + +""" +edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 +``` +pip install edge_tts +``` +Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot +""" + + +@register_provider_adapter( + "edge_tts", + "Microsoft Edge TTS", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderEdgeTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + + # 设置默认语音,如果没有指定则使用中文小萱 + self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") + self.rate = provider_config.get("rate") + self.volume = provider_config.get("volume") + self.pitch = provider_config.get("pitch") + self.timeout = provider_config.get("timeout", 30) + + self.proxy = os.getenv("https_proxy", None) + + self.set_model("edge_tts") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") + wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") + + # 构建 Edge TTS 参数 + kwargs = {"text": text, "voice": self.voice} + if self.rate: + kwargs["rate"] = self.rate + if self.volume: + kwargs["volume"] = self.volume + if self.pitch: + kwargs["pitch"] = self.pitch + + try: + communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) + await communicate.save(mp3_path) + + try: + from pyffmpeg import FFmpeg + + ff = FFmpeg() + ff.convert(input_file=mp3_path, output_file=wav_path) + except Exception as e: + logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") + # use ffmpeg command line + + # 使用ffmpeg将MP3转换为标准WAV格式 + p = await asyncio.create_subprocess_exec( + "ffmpeg", + "-y", # 覆盖输出文件 + "-i", + mp3_path, # 输入文件 + "-acodec", + "pcm_s16le", # 16位PCM编码 + "-ar", + "24000", # 采样率24kHz (适合微信语音) + "-ac", + "1", # 单声道 + "-af", + "apad=pad_dur=2", # 确保输出时长准确 + "-fflags", + "+genpts", # 强制生成时间戳 + "-hide_banner", # 隐藏版本信息 + wav_path, # 输出文件 + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # 等待进程完成并获取输出 + stdout, stderr = await p.communicate() + logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") + logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") + logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") + + os.remove(mp3_path) + if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: + return wav_path + logger.error("生成的WAV文件不存在或为空") + raise RuntimeError("生成的WAV文件不存在或为空") + + except subprocess.CalledProcessError as e: + logger.error( + f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", + ) + try: + if os.path.exists(mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"FFmpeg 转换失败: {e!s}") + + except Exception as e: + logger.error(f"音频生成失败: {e!s}") + try: + if os.path.exists(mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"音频生成失败: {e!s}") diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index 55a0975de6..edcc39af5f 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,76 +1,76 @@ -import uuid -from pathlib import Path - -import aiohttp - -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - - -@register_provider_adapter( - "gsvi_tts_api", - "GSVI TTS API", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderGSVITTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - self.api_key = provider_config.get("api_key", "") - self.api_base = provider_config.get("api_base", "http://127.0.0.1:8000") - self.api_base = self.api_base.removesuffix("/") - self.version = provider_config.get("version", "v4") - self.character = provider_config.get("character") - self.prompt_text_lang = provider_config.get("prompt_text_lang", "中文") - self.emotion = provider_config.get("emotion", "默认") - self.text_lang = provider_config.get("text_lang", "中文") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - path = Path(temp_dir) / f"gsvi_tts_{uuid.uuid4()}.wav" - url = f"{self.api_base}/infer_single" - - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - data = { - "dl_url": self.api_base, - "version": self.version, - "model_name": self.character, - "prompt_text_lang": self.prompt_text_lang, - "emotion": self.emotion, - "text": text, - "text_lang": self.text_lang, - } - - async with aiohttp.ClientSession() as session: - async with session.post(url, json=data, headers=headers) as response: - if response.status == 200: - resp_json = await response.json() - msg = resp_json.get("msg") - audio_url = resp_json.get("audio_url") - if not msg or msg != "合成成功": - raise Exception(f"GSVI TTS API 合成失败: {msg}") - async with session.get(audio_url) as audio_response: - if audio_response.status == 200: - with open(path, "wb") as f: - f.write(await audio_response.read()) - else: - error_text = await audio_response.text() - raise Exception( - f"GSVI TTS API 下载音频失败,状态码: {audio_response.status},错误: {error_text}", - ) - else: - error_text = await response.text() - raise Exception( - f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", - ) - - return str(path) +import uuid +from pathlib import Path + +import aiohttp + +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "gsvi_tts_api", + "GSVI TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderGSVITTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.api_key = provider_config.get("api_key", "") + self.api_base = provider_config.get("api_base", "http://127.0.0.1:8000") + self.api_base = self.api_base.removesuffix("/") + self.version = provider_config.get("version", "v4") + self.character = provider_config.get("character") + self.prompt_text_lang = provider_config.get("prompt_text_lang", "中文") + self.emotion = provider_config.get("emotion", "默认") + self.text_lang = provider_config.get("text_lang", "中文") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + path = Path(temp_dir) / f"gsvi_tts_{uuid.uuid4()}.wav" + url = f"{self.api_base}/infer_single" + + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + data = { + "dl_url": self.api_base, + "version": self.version, + "model_name": self.character, + "prompt_text_lang": self.prompt_text_lang, + "emotion": self.emotion, + "text": text, + "text_lang": self.text_lang, + } + + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data, headers=headers) as response: + if response.status == 200: + resp_json = await response.json() + msg = resp_json.get("msg") + audio_url = resp_json.get("audio_url") + if not msg or msg != "合成成功": + raise Exception(f"GSVI TTS API 合成失败: {msg}") + async with session.get(audio_url) as audio_response: + if audio_response.status == 200: + with open(path, "wb") as f: + f.write(await audio_response.read()) + else: + error_text = await audio_response.text() + raise Exception( + f"GSVI TTS API 下载音频失败,状态码: {audio_response.status},错误: {error_text}", + ) + else: + error_text = await response.text() + raise Exception( + f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", + ) + + return str(path) diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index c2588e6c29..f3b6d644ef 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -1,119 +1,206 @@ +# 导入 fnmatch 模块,用于 Unix 风格的文件名模式匹配(支持 * 通配符) import fnmatch +# 导入共享偏好设置类,用于持久化存储配置数据 from astrbot.core.utils.shared_preferences import SharedPreferences class UmopConfigRouter: - """UMOP 配置路由器""" + """ + UMOP 配置路由器。 + 负责管理 UMO(统一消息对象标识符)到配置文件 ID 的路由映射。 + 支持通配符匹配,实现灵活的消息路由配置。 + + UMO 格式: [platform_id]:[message_type]:[session_id] + 通配符规则: + - "::" 匹配所有消息 + - "[platform_id]::" 匹配指定平台的所有消息 + - 使用 fnmatch 的 * 通配符进行模式匹配 + """ def __init__(self, sp: SharedPreferences) -> None: + """ + 初始化 UMOP 配置路由器。 + + Args: + sp: SharedPreferences 实例,用于持久化存储路由配置 + """ + # 初始化 UMOP 到配置文件 ID 的映射字典 self.umop_to_conf_id: dict[str, str] = {} """UMOP 到配置文件 ID 的映射""" + # 保存 SharedPreferences 实例引用 self.sp = sp async def initialize(self) -> None: + """ + 异步初始化路由器。 + 从持久化存储中加载路由表数据。 + """ + # 调用内部方法加载路由表 await self._load_routing_table() async def _load_routing_table(self) -> None: - """加载路由表""" - # 从 SharedPreferences 中加载 umop_to_conf_id 映射 + """ + 从 SharedPreferences 加载路由表。 + 读取持久化存储中的 UMOP 到配置 ID 的映射数据。 + """ + # 从 SharedPreferences 中异步获取路由配置数据 sp_data = await self.sp.get_async( - key="umop_config_routing", - default={}, - scope="global", - scope_id="global", + key="umop_config_routing", # 存储键名 + default={}, # 默认值为空字典 + scope="global", # 全局作用域 + scope_id="global", # 全局作用域 ID ) + # 更新内存中的路由映射 self.umop_to_conf_id = sp_data @staticmethod def _split_umo(umo: str) -> tuple[str, str, str] | None: - """将 UMO 拆分为 3 个部分,同时保留 session_id 中的 ':'""" + """ + 将 UMO 字符串拆分为三个部分。 + 保留 session_id 中可能存在的 ':' 字符。 + + Args: + umo: UMO 字符串,格式为 platform_id:message_type:session_id + + Returns: + tuple[str, str, str] | None: 成功返回三元素元组 (platform_id, message_type, session_id), + 如果输入不是字符串或格式不正确则返回 None + """ + # 检查输入是否为字符串类型 if not isinstance(umo, str): return None + # 按 ':' 分割,最多分割 2 次(保留 session_id 中的 ':') parts = umo.split(":", 2) + # 验证分割结果是否为 3 个部分 if len(parts) != 3: return None + # 返回三个部分的元组 return parts[0], parts[1], parts[2] def _is_umo_match(self, p1: str, p2: str) -> bool: - """判断 p2 umo 是否逻辑包含于 p1 umo""" + """ + 判断 p2 的 UMO 是否与 p1 的模式匹配。 + p1 作为模式(可包含通配符),p2 作为待匹配的完整 UMO。 + + 匹配规则: + - 空字符串表示匹配所有 + - 支持 fnmatch 的通配符 * 和 ? + - 三个部分分别匹配,全部匹配成功才返回 True + + Args: + p1: 模式 UMO 字符串(可能包含通配符) + p2: 完整的 UMO 字符串 + + Returns: + bool: 如果 p2 匹配 p1 的模式则返回 True,否则返回 False + """ + # 将模式 UMO 拆分为三部分 p1_ls = self._split_umo(p1) + # 将待匹配 UMO 拆分为三部分 p2_ls = self._split_umo(p2) + # 如果任一 UMO 格式非法,返回 False if p1_ls is None or p2_ls is None: return False # 非法格式 + # 逐部分比较: + # p 是模式,t 是待匹配的目标 + # p == "" 表示该部分匹配所有 + # 否则使用 fnmatch 进行通配符匹配(区分大小写) return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls)) def get_conf_id_for_umop(self, umo: str) -> str | None: - """根据 UMO 获取对应的配置文件 ID + """ + 根据 UMO 获取对应的配置文件 ID。 + 遍历所有路由规则,返回第一个匹配的配置 ID。 Args: - umo (str): UMO 字符串 + umo: 完整的 UMO 字符串 Returns: - str | None: 配置文件 ID,如果没有找到则返回 None - + str | None: 匹配的配置文件 ID,如果没有找到匹配则返回 None """ + # 遍历所有路由映射 for pattern, conf_id in self.umop_to_conf_id.items(): + # 检查当前 UMO 是否匹配该模式 if self._is_umo_match(pattern, umo): + # 返回匹配的配置 ID return conf_id + # 没有匹配的路由,返回 None return None async def update_routing_data(self, new_routing: dict[str, str]) -> None: - """更新路由表 + """ + 批量更新整个路由表。 + 用新的路由映射替换所有现有路由。 Args: - new_routing (dict[str, str]): 新的 UMOP 到配置文件 ID 的映射。umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。 - umop 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。 + new_routing: 新的 UMOP 到配置文件 ID 的映射字典。 + UMO 格式: [platform_id]:[message_type]:[session_id] + 支持通配符: + - "::" 代表匹配所有消息 + - "[platform_id]::" 代表匹配指定平台下的所有类型消息和会话 Raises: - ValueError: 如果 new_routing 中的 key 格式不正确 - + ValueError: 如果 new_routing 中的任何 key 格式不正确 """ + # 验证所有新路由的 UMO 格式 for part in new_routing: if self._split_umo(part) is None: + # 格式不正确,抛出异常 raise ValueError( "umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) + # 更新内存中的路由映射 self.umop_to_conf_id = new_routing + # 持久化保存到 SharedPreferences await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) async def update_route(self, umo: str, conf_id: str) -> None: - """更新一条路由 + """ + 更新或添加单条路由规则。 Args: - umo (str): UMO 字符串 - conf_id (str): 配置文件 ID + umo: UMO 模式字符串 + conf_id: 对应的配置文件 ID Raises: ValueError: 如果 umo 格式不正确 - """ + # 验证 UMO 格式 if self._split_umo(umo) is None: + # 格式不正确,抛出异常 raise ValueError( "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) + # 添加或更新路由映射 self.umop_to_conf_id[umo] = conf_id + # 持久化保存更新后的路由表 await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) async def delete_route(self, umo: str) -> None: - """删除一条路由 + """ + 删除一条路由规则。 Args: - umo (str): 需要删除的 UMO 字符串 + umo: 需要删除的 UMO 模式字符串 Raises: ValueError: 当 umo 格式不正确时抛出 """ - + # 验证 UMO 格式 if self._split_umo(umo) is None: + # 格式不正确,抛出异常 raise ValueError( "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) + # 检查路由是否存在 if umo in self.umop_to_conf_id: + # 从映射中删除该路由 del self.umop_to_conf_id[umo] - await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) + # 持久化保存更新后的路由表 + await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) \ No newline at end of file diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py index c7771c1a64..440031bf08 100644 --- a/astrbot/core/utils/astrbot_path.py +++ b/astrbot/core/utils/astrbot_path.py @@ -42,12 +42,12 @@ def get_astrbot_data_path() -> str: def get_astrbot_config_path() -> str: """Return the AstrBot config directory path.""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "config")) + return os.path.realpath(os.path.join(get_astrbot_data_path(), "config")) # 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot\\data\\config' def get_astrbot_plugin_path() -> str: """Return the AstrBot plugin directory path.""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) + return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) # 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot\\data\\plugins' def get_astrbot_plugin_data_path() -> str: @@ -67,7 +67,7 @@ def get_astrbot_webchat_path() -> str: def get_astrbot_temp_path() -> str: """Return the AstrBot temporary data directory path.""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp")) + return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp")) # 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot\\data\\temp' def get_astrbot_skills_path() -> str: @@ -87,12 +87,12 @@ def get_astrbot_system_tmp_path() -> str: def get_astrbot_site_packages_path() -> str: """Return the AstrBot third-party site-packages directory path.""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "site-packages")) + return os.path.realpath(os.path.join(get_astrbot_data_path(), "site-packages")) # 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot\\data\\site-packages' def get_astrbot_knowledge_base_path() -> str: """Return the AstrBot knowledge base root path.""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base")) + return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base")) # 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot\\data\\knowledge_base' def get_astrbot_backups_path() -> str: diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 40b899620d..ebe4bbd094 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -44,90 +44,130 @@ def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None: def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: """ - Migrate old provider structure to new provider-source separation. - Provider only keeps: id, provider_source_id, model, modalities, custom_extra_body - All other fields move to provider_sources. + 将旧的 provider 结构迁移到新的 provider-source 分离结构。 + + 迁移规则: + - Provider 只保留核心字段:id, provider_source_id, model, modalities, + custom_extra_body, enable + - 所有其他字段移动到 provider_sources 中 + - 从 model_config 中提取 model 字段 + - 将 model_config 中的其他字段合并到 custom_extra_body + + 迁移过程: + 1. 遍历所有 provider 配置 + 2. 跳过已有 provider_source_id 的 provider(已迁移) + 3. 跳过非 chat_completion 类型的 provider + 4. 提取需要迁移的字段创建新的 provider_source + 5. 更新 provider 保留必要字段 + 6. 处理 model_config 字段 + 7. 将新的 source 添加到 provider_sources 列表 + + Args: + conf: AstrBot 配置对象,包含 provider 和 provider_sources 配置 """ + # 获取当前的 provider 配置列表 providers = conf.get("provider", []) + # 获取当前的 provider_sources 配置列表 provider_sources = conf.get("provider_sources", []) - # Track if any migration happened + # 跟踪是否发生了任何迁移 migrated = False - # Provider-only fields that should stay in provider + # 定义应保留在 provider 中的核心字段集合 + # 这些字段不会迁移到 provider_source provider_only_fields = { - "id", - "provider_source_id", - "model", - "modalities", - "custom_extra_body", - "enable", + "id", # Provider 唯一标识 + "provider_source_id", # 关联的 provider_source ID + "model", # 使用的模型名称 + "modalities", # 支持的模式(如文本、图像等) + "custom_extra_body", # 自定义额外请求体参数 + "enable", # 是否启用 } - # Fields that should not go to source + # 定义不应迁移到 source 的字段集合 + # 包含 provider_only_fields 和 model_config(需要特殊处理) source_exclude_fields = provider_only_fields | {"model_config"} + # 遍历所有 provider 配置 for provider in providers: - # Skip if already has provider_source_id + # 如果 provider 已经有 provider_source_id,说明已经迁移过,跳过 if provider.get("provider_source_id"): continue - # Skip non-chat-completion types (they don't need source separation) + # 检查 provider 类型 provider_type = provider.get("provider_type", "") + # 如果不是 chat_completion 类型,不需要 source 分离,跳过 if provider_type != "chat_completion": - # For old types without provider_type, check type field + # 对于没有 provider_type 字段的旧配置,检查 type 字段 old_type = provider.get("type", "") + # 如果 type 字段中不包含 chat_completion,也跳过 if "chat_completion" not in old_type: continue + # 标记发生了迁移 migrated = True + # 记录迁移日志 logger.info(f"Migrating provider {provider.get('id')} to new structure") - # Extract source fields from provider + # 第一步:从 provider 中提取需要迁移到 source 的字段 source_fields = {} + # 遍历 provider 的所有键值对(使用 list 避免迭代时修改字典) for key, value in list(provider.items()): + # 如果字段不在排除列表中,说明需要迁移 if key not in source_exclude_fields: source_fields[key] = value - # Create new provider_source + # 第二步:创建新的 provider_source 对象 + # source_id 格式:{provider_id}_source source_id = provider.get("id", "") + "_source" + # 构建新的 source 对象,包含 id 和所有提取的字段 new_source = {"id": source_id, **source_fields} - # Update provider to only keep necessary fields + # 第三步:更新 provider,设置 provider_source_id 关联 provider["provider_source_id"] = source_id - # Extract model from model_config if exists + # 第四步:处理 model_config 字段(特殊迁移逻辑) if "model_config" in provider and isinstance(provider["model_config"], dict): model_config = provider["model_config"] + # 从 model_config 中提取 model 字段 provider["model"] = model_config.get("model", "") - # Put other model_config fields into custom_extra_body + # 将 model_config 中除 model 外的其他字段合并到 custom_extra_body + # 构建额外字段字典(排除 model 字段) extra_body_fields = {k: v for k, v in model_config.items() if k != "model"} + # 如果存在额外字段 if extra_body_fields: + # 确保 custom_extra_body 字段存在 if "custom_extra_body" not in provider: provider["custom_extra_body"] = {} + # 将额外字段合并到 custom_extra_body provider["custom_extra_body"].update(extra_body_fields) - # Initialize new fields if not present + # 第五步:初始化缺失的核心字段(设置默认值) if "modalities" not in provider: - provider["modalities"] = [] + provider["modalities"] = [] # 默认为空列表 if "custom_extra_body" not in provider: - provider["custom_extra_body"] = {} + provider["custom_extra_body"] = {} # 默认为空字典 - # Remove fields that should be in source + # 第六步:从 provider 中移除所有不应保留的字段 + # 生成需要删除的字段列表 keys_to_remove = [k for k in provider.keys() if k not in provider_only_fields] + # 逐个删除不应保留的字段 for key in keys_to_remove: del provider[key] - # Add source to provider_sources + # 第七步:将新创建的 source 添加到 provider_sources 列表 provider_sources.append(new_source) + # 如果发生了迁移,保存配置 if migrated: + # 更新配置中的 provider_sources conf["provider_sources"] = provider_sources + # 保存配置到持久化存储 conf.save_config() + # 记录迁移完成日志 logger.info("Provider-source structure migration completed") - async def migra( db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager ) -> None: diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 1191e154a9..e13843e248 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -14,115 +14,218 @@ from . import RenderStrategy +# 默认的文转图服务端点 URL ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" +# Shiki 运行时脚本的唯一标识 ID,用于检测模板中是否已注入 SHIKI_RUNTIME_SCRIPT_ID = "astrbot-t2i-shiki-runtime" +# 匹配 Shiki 运行时模板变量的正则表达式,用于检测模板是否需要注入 SHIKI_RUNTIME_TEMPLATE_PATTERN = re.compile(r"\{\{\s*shiki_runtime\s*\|\s*safe\s*\}\}") +# 匹配 Jinja2 语法标记的正则表达式,用于检测字符串是否包含 Jinja2 语法 JINJA_SYNTAX_PATTERN = re.compile(r"\{[{%#]") +# 匹配 Jinja2 raw 块开始标记的正则表达式 JINJA_RAW_OPEN_PATTERN = re.compile(r"{%-?\s*raw\s*-?%}") +# 匹配 Jinja2 raw 块结束标记的正则表达式 JINJA_RAW_CLOSE_PATTERN = re.compile(r"{%-?\s*endraw\s*-?%}") +# 获取日志记录器实例 logger = logging.getLogger("astrbot") @lru_cache(maxsize=1) def get_shiki_runtime() -> str: + """ + 获取 Shiki 运行时 JavaScript 代码。 + 使用 LRU 缓存避免重复读取文件,提升性能。 + + Returns: + str: Shiki 运行时的 IIFE JavaScript 代码,如果文件不存在或读取失败则返回空字符串。 + """ + # 构建 Shiki 运行时文件的完整路径 runtime_path = ( Path(__file__).resolve().parent / "template" / "shiki_runtime.iife.js" ) + # 检查运行时文件是否存在 if not runtime_path.exists(): logger.error( "T2I Shiki runtime not found at %s. Run `cd dashboard && pnpm run build:t2i-shiki-runtime` to regenerate it. Continuing without code highlighting.", runtime_path, ) - return "" + return "" # 文件不存在,返回空字符串 try: + # 读取 JavaScript 文件内容 runtime = runtime_path.read_text(encoding="utf-8") except (OSError, UnicodeDecodeError) as err: + # 捕获文件读取或编码错误 logger.warning( "Failed to load T2I Shiki runtime from %s: %s. Continuing without code highlighting.", runtime_path, err, ) - return "" + return "" # 读取失败,返回空字符串 + # 转义 标签,防止在 HTML 中提前闭合,确保脚本完整注入 return re.sub(r" bool: + """ + 检查指定索引位置是否位于 Jinja2 的 raw 块内部。 + 在 raw 块内部的代码不会被 Jinja2 引擎解析,因此不需要特殊处理。 + + Args: + tmpl_str: 模板字符串 + index: 要检查的字符位置索引 + + Returns: + bool: 如果在 raw 块内返回 True,否则返回 False + """ + # 查找指定位置之前最后一个 raw 开始标记的位置 raw_open_index = -1 for match in JINJA_RAW_OPEN_PATTERN.finditer(tmpl_str, 0, index): - raw_open_index = match.start() + raw_open_index = match.start() # 更新为最新的 raw 开始位置 + # 查找指定位置之前最后一个 raw 结束标记的位置 raw_close_index = -1 for match in JINJA_RAW_CLOSE_PATTERN.finditer(tmpl_str, 0, index): - raw_close_index = match.start() + raw_close_index = match.start() # 更新为最新的 raw 结束位置 + # 如果最近的 raw 开始标记在结束标记之后,说明当前位置在 raw 块内 return raw_open_index > raw_close_index def _wrap_runtime_for_jinja(tmpl_str: str, script: str, index: int) -> str: + """ + 根据模板字符串的上下文,决定是否需要用 Jinja2 raw 标签包裹脚本内容。 + 如果脚本包含 Jinja2 语法且不在 raw 块内,则包裹以避免被模板引擎错误解析。 + + Args: + tmpl_str: 模板字符串 + script: 要注入的脚本字符串 + index: 注入位置的字符索引 + + Returns: + str: 可能被 raw 标签包裹的脚本字符串 + """ + # 如果脚本不包含 Jinja2 语法,或者注入位置已经在 raw 块内,则直接返回脚本 if not JINJA_SYNTAX_PATTERN.search(script) or _is_inside_jinja_raw_block( tmpl_str, index, ): return script + # 否则用 raw 标签包裹,防止 Jinja2 模板引擎错误解析脚本内容 return f"{{% raw %}}{script}{{% endraw %}}" def inject_shiki_runtime(tmpl_str: str) -> str: + """ + 将 Shiki 运行时代码注入到 HTML 模板中。 + 会检查模板是否已包含运行时,避免重复注入。 + 优先将脚本插入到 标签之前,如果没有 head 标签则插入到模板开头。 + + Args: + tmpl_str: HTML 模板字符串 + + Returns: + str: 注入后的模板字符串 + """ + # 检查模板是否已经包含了 Shiki 运行时脚本或模板变量 if SHIKI_RUNTIME_SCRIPT_ID in tmpl_str or SHIKI_RUNTIME_TEMPLATE_PATTERN.search( tmpl_str, ): - return tmpl_str + return tmpl_str # 已存在,无需重复注入 + # 获取 Shiki 运行时 JavaScript 代码 runtime = get_shiki_runtime() if not runtime: - return tmpl_str + return tmpl_str # 无法获取运行时,返回原模板 + # 构建包含唯一 ID 的 script 标签 script = f'' + # 查找 标签的位置 head_close = re.search(r"", tmpl_str, flags=re.IGNORECASE) if head_close: + # 如果找到 ,根据上下文决定是否需要 raw 包裹 script = _wrap_runtime_for_jinja(tmpl_str, script, head_close.start()) + # 将脚本插入到 之前 return f"{tmpl_str[: head_close.start()]} {script}\n{tmpl_str[head_close.start() :]}" + # 没有找到 ,在模板开头插入 script = _wrap_runtime_for_jinja(tmpl_str, script, 0) return f"{script}\n{tmpl_str}" class NetworkRenderStrategy(RenderStrategy): + """ + 网络渲染策略类。 + 通过远程文转图服务(Text-to-Image)将 HTML 模板渲染为图像。 + 支持多端点负载均衡和故障转移。 + """ + def __init__(self, base_url: str | None = None) -> None: - super().__init__() + """ + 初始化网络渲染策略。 + + Args: + base_url: 自定义的文转图服务基础 URL,如果为 None 则使用默认端点 + """ + super().__init__() # 调用父类构造函数 if not base_url: + # 使用默认的文转图服务端点 self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT else: + # 使用自定义 URL,并进行清理和格式化 self.BASE_RENDER_URL = self._clean_url(base_url) + # 初始化端点列表,默认只有基础 URL self.endpoints = [self.BASE_RENDER_URL] + # 创建模板管理器实例 self.template_manager = TemplateManager() async def initialize(self) -> None: + """ + 异步初始化方法。 + 如果使用的是官方默认端点,则异步获取官方提供的所有可用端点列表。 + """ if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT: + # 创建异步任务获取官方端点列表,不阻塞主流程 asyncio.create_task(self.get_official_endpoints()) async def get_template(self, name: str = "base") -> str: - """通过名称获取文转图 HTML 模板""" + """ + 通过名称获取文转图 HTML 模板。 + + Args: + name: 模板名称,默认为 "base" + + Returns: + str: HTML 模板字符串 + """ + # 从模板管理器中获取指定名称的模板 return self.template_manager.get_template(name) async def get_official_endpoints(self) -> None: - """获取官方的 t2i 端点列表。""" + """ + 获取官方的 T2I(文转图)端点列表。 + 从官方 API 获取所有活跃的端点 URL,用于负载均衡和故障转移。 + """ try: + # 创建 HTTP 客户端会话 async with aiohttp.ClientSession( - trust_env=True, - connector=build_tls_connector(), + trust_env=True, # 信任环境变量中的代理设置 + connector=build_tls_connector(), # 使用自定义 TLS 连接器 ) as session: + # 发送 GET 请求获取端点列表 async with session.get( "https://api.soulter.top/astrbot/t2i-endpoints", ) as resp: if resp.status == 200: + # 解析 JSON 响应 data = await resp.json() + # 获取端点数据列表 all_endpoints: list[dict] = data.get("data", []) + # 过滤出活跃且有 URL 的端点 self.endpoints = [ ep.get("url") for ep in all_endpoints @@ -132,10 +235,23 @@ async def get_official_endpoints(self) -> None: f"Successfully got {len(self.endpoints)} official T2I endpoints.", ) except Exception as e: + # 捕获所有异常,避免影响主流程 logger.error(f"Failed to get official endpoints: {e}") def _clean_url(self, url: str): + """ + 清理和格式化 URL。 + 移除末尾斜杠,确保 URL 以 /text2img 结尾。 + + Args: + url: 原始 URL + + Returns: + str: 清理后的 URL + """ + # 移除 URL 末尾的斜杠 url = url.removesuffix("/") + # 如果 URL 不以 text2img 结尾,则添加 if not url.endswith("text2img"): url += "/text2img" return url @@ -147,59 +263,87 @@ async def render_custom_template( return_url: bool = True, options: dict | None = None, ) -> str: - """使用自定义文转图模板""" + """ + 使用自定义文转图模板进行渲染。 + + Args: + tmpl_str: HTML 模板字符串 + tmpl_data: 模板数据字典 + return_url: 是否返回图片 URL,False 则返回图片文件路径 + options: 渲染选项,如页面大小、图片格式、质量等 + + Returns: + str: 图片的 URL 或本地文件路径 + + Raises: + RuntimeError: 当所有端点都渲染失败时抛出 + """ + # 设置默认的渲染选项 default_options = { - "full_page": True, - "type": "jpeg", - "quality": 40, + "full_page": True, # 渲染整个页面 + "type": "jpeg", # 输出格式为 JPEG + "quality": 40, # 图片质量 40% } + # 合并用户自定义选项 if options: default_options |= options - # 在线程池中执行 Shiki 注入,避免 1.2MB JS 处理阻塞事件循环 + # 获取当前事件循环 loop = asyncio.get_running_loop() + # 在线程池中执行模板预处理,避免 1.2MB 的 JS 处理阻塞事件循环 tmpl_str, tmpl_data = await loop.run_in_executor( None, self._prepare_template_sync, tmpl_str, tmpl_data ) + + # 构建 POST 请求数据 post_data = { - "tmpl": tmpl_str, - "json": return_url, - "tmpldata": tmpl_data, - "options": default_options, + "tmpl": tmpl_str, # 模板字符串 + "json": return_url, # 是否返回 JSON 格式(包含图片 URL) + "tmpldata": tmpl_data, # 模板渲染数据 + "options": default_options, # 渲染选项 } + # 复制端点列表并随机打乱,实现负载均衡 endpoints = self.endpoints.copy() if self.endpoints else [self.BASE_RENDER_URL] random.shuffle(endpoints) - last_exception = None + last_exception = None # 记录最后一个异常,用于全部失败时抛出 + + # 遍历所有端点进行故障转移 for endpoint in endpoints: try: if return_url: + # 需要返回图片 URL 时,使用异步 HTTP 请求 async with ( aiohttp.ClientSession( trust_env=True, connector=build_tls_connector(), ) as session, session.post( - f"{endpoint}/generate", + f"{endpoint}/generate", # 发送渲染请求 json=post_data, ) as resp, ): if resp.status == 200: + # 请求成功,解析返回的 JSON 数据 ret = await resp.json() + # 返回完整的图片 URL return f"{endpoint}/{ret['data']['id']}" + # HTTP 状态码非 200,抛出异常进入故障转移 raise Exception(f"HTTP {resp.status}") else: - # download_image_by_url 失败时抛异常 + # 直接下载图片到本地,返回本地文件路径 return await download_image_by_url( f"{endpoint}/generate", - post=True, + post=True, # 使用 POST 请求 post_data=post_data, ) except Exception as e: + # 记录异常并尝试下一个端点 last_exception = e logger.warning(f"Endpoint {endpoint} failed: {e}, trying next...") continue - # 全部失败 + + # 所有端点都失败了 logger.error(f"All endpoints failed: {last_exception}") raise RuntimeError(f"All endpoints failed: {last_exception}") @@ -209,23 +353,50 @@ async def render( return_url: bool = False, template_name: str | None = "base", ) -> str: - """返回图像的文件路径""" + """ + 渲染文本为图像。 + + Args: + text: 要渲染的文本内容 + return_url: 是否返回图片 URL,默认返回本地文件路径 + template_name: 使用的模板名称,默认为 "base" + + Returns: + str: 图像的文件路径或 URL + """ + # 如果未指定模板名称,使用默认模板 if not template_name: template_name = "base" + # 获取模板字符串 tmpl_str = await self.get_template(name=template_name) + # 使用自定义模板渲染,传入文本内容和版本信息 return await self.render_custom_template( tmpl_str, { - "text": text, - "version": f"v{VERSION}", + "text": text, # 要渲染的文本 + "version": f"v{VERSION}", # 当前版本号 }, return_url, ) @staticmethod def _prepare_template_sync(tmpl_str: str, tmpl_data: dict) -> tuple[str, dict]: - """在线程池中执行的同步模板预处理(避免阻塞事件循环)""" + """ + 同步方法,在线程池中执行的模板预处理。 + 处理 Shiki 运行时注入和模板变量设置,避免阻塞事件循环。 + + Args: + tmpl_str: 模板字符串 + tmpl_data: 模板数据字典 + + Returns: + tuple[str, dict]: 处理后的模板字符串和数据字典 + """ + # 检查模板中是否包含 Shiki 运行时模板变量 if SHIKI_RUNTIME_TEMPLATE_PATTERN.search(tmpl_str): + # 将 Shiki 运行时添加到模板数据的最前面 tmpl_data = {"shiki_runtime": get_shiki_runtime()} | tmpl_data + # 将 Shiki 运行时脚本注入到 HTML 模板中 tmpl_str = inject_shiki_runtime(tmpl_str) - return tmpl_str, tmpl_data + # 返回处理后的模板和数据 + return tmpl_str, tmpl_data \ No newline at end of file diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 995c3d2443..df66630dc5 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -1,17 +1,40 @@ +# 导入日志管理器,用于获取统一的日志记录器 from astrbot.core.log import LogManager +# 导入本地渲染策略类 from .local_strategy import LocalRenderStrategy +# 导入网络渲染策略类 from .network_strategy import NetworkRenderStrategy +# 获取名为 "astrbot" 的日志记录器实例 logger = LogManager.GetLogger(log_name="astrbot") class HtmlRenderer: + """ + HTML 渲染器主类。 + 整合了网络渲染和本地渲染两种策略,提供统一的文转图渲染接口。 + 优先使用网络渲染,失败时自动降级到本地渲染。 + """ + def __init__(self, endpoint_url: str | None = None) -> None: + """ + 初始化 HTML 渲染器。 + + Args: + endpoint_url: 自定义的网络渲染端点 URL,为 None 时使用默认官方端点 + """ + # 创建网络渲染策略实例,用于远程 API 渲染 self.network_strategy = NetworkRenderStrategy(endpoint_url) + # 创建本地渲染策略实例,用于本地后备渲染 self.local_strategy = LocalRenderStrategy() async def initialize(self) -> None: + """ + 异步初始化渲染器。 + 主要初始化网络渲染策略,包括获取可用端点列表等。 + """ + # 调用网络策略的初始化方法 await self.network_strategy.initialize() async def render_custom_template( @@ -21,20 +44,27 @@ async def render_custom_template( return_url: bool = False, options: dict | None = None, ): - """使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。 - @param tmpl_str: HTML Jinja2 模板。 - @param tmpl_data: jinja2 模板数据。 - @param options: 渲染选项。 + """ + 使用自定义文转图模板进行渲染。 + 该方法会通过网络调用 t2i 终结点图文渲染 API。 + + Args: + tmpl_str: HTML Jinja2 模板字符串。 + tmpl_data: Jinja2 模板数据字典,用于填充模板变量。 + options: 渲染选项字典,如页面大小、图片格式、质量等。 - @return: 图片 URL 或者文件路径,取决于 return_url 参数。 + Returns: + str: 图片 URL 或文件路径,取决于 return_url 参数。 - @example: 参见 https://docs.astrbot.app 插件开发部分。 + Note: + 使用示例可参考 https://docs.astrbot.app 插件开发部分文档。 """ + # 委托给网络渲染策略处理自定义模板渲染 return await self.network_strategy.render_custom_template( - tmpl_str, - tmpl_data, - return_url, - options, + tmpl_str, # 传入模板字符串 + tmpl_data, # 传入模板渲染数据 + return_url, # 传入返回类型标志 + options, # 传入渲染选项 ) async def render_t2i( @@ -44,18 +74,35 @@ async def render_t2i( return_url: bool = False, template_name: str | None = None, ): - """使用默认文转图模板。""" + """ + 使用默认文转图模板将文本渲染为图像。 + 支持网络渲染和本地渲染两种方式,网络渲染失败时自动降级。 + + Args: + text: 要渲染的文本内容。 + use_network: 是否优先使用网络渲染,默认为 True。 + return_url: 是否返回图片 URL,False 则返回本地文件路径。 + template_name: 使用的模板名称,为 None 时使用默认模板。 + + Returns: + str: 渲染后的图像文件路径或 URL。 + """ + # 判断是否使用网络渲染 if use_network: try: + # 尝试使用网络策略进行渲染 return await self.network_strategy.render( - text, - return_url=return_url, - template_name=template_name, + text, # 要渲染的文本 + return_url=return_url, # 返回类型 + template_name=template_name, # 模板名称 ) except BaseException as e: + # 网络渲染失败,记录错误日志 logger.error( f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.", ) + # 降级到本地渲染策略 return await self.local_strategy.render(text) else: - return await self.local_strategy.render(text) + # 直接使用本地渲染策略 + return await self.local_strategy.render(text) \ No newline at end of file diff --git a/main.py b/main.py index 01d3167f53..c5a1808076 100644 --- a/main.py +++ b/main.py @@ -74,11 +74,11 @@ def check_env() -> None: logger.error("请使用 Python3.10+ 运行本项目。") exit() - astrbot_root = get_astrbot_root() + astrbot_root = get_astrbot_root() # 项目所在的路径 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot' if astrbot_root not in sys.path: sys.path.insert(0, astrbot_root) - site_packages_path = get_astrbot_site_packages_path() + site_packages_path = get_astrbot_site_packages_path() # 'C:\\Users\\Admin\\Master\\projects\\github\\AstrBot\\data\\site-packages' if not is_packaged_desktop_runtime() and site_packages_path not in sys.path: sys.path.append(site_packages_path) @@ -112,7 +112,7 @@ async def check_dashboard_files(webui_dir: str | None = None): logger.warning("WebUI directory not found: %s. Using default.", webui_dir) data_dist_path = Path(get_astrbot_data_path()) / "dist" - bundled_dist = get_bundled_dashboard_dist_path() + bundled_dist = get_bundled_dashboard_dist_path() # WindowsPath('C:/Users/Admin/Master/projects/github/AstrBot/astrbot/dashboard/dist') if data_dist_path.exists(): v = get_dashboard_dist_version(data_dist_path) if is_dashboard_dist_compatible(data_dist_path, VERSION): diff --git a/openspec/openapi-v1.yaml b/openspec/openapi-v1.yaml index 3fe8c6e13f..2ce9977aac 100644 --- a/openspec/openapi-v1.yaml +++ b/openspec/openapi-v1.yaml @@ -4135,6 +4135,7 @@ paths: "200": $ref: "#/components/responses/Ok" +<<<<<<< HEAD /api/v1/stats/versions: get: tags: [Stats] @@ -4145,6 +4146,8 @@ paths: "200": $ref: "#/components/responses/Ok" +======= +>>>>>>> 40eeb785 (First commit) /api/v1/stats/first-notice: get: tags: [Stats] diff --git a/tests/test_kook/data/kook_card_data.json b/tests/test_kook/data/kook_card_data.json index a142318e46..a9b7fe6354 100644 --- a/tests/test_kook/data/kook_card_data.json +++ b/tests/test_kook/data/kook_card_data.json @@ -1,100 +1,100 @@ -{ - "type": "card", - "theme": "info", - "size": "lg", - "modules": [ - { - "type": "header", - "text": { - "type": "plain-text", - "content": "test1", - "emoji": true - } - }, - { - "type": "section", - "text": { - "type": "kmarkdown", - "content": "test2" - }, - "mode": "left" - }, - { - "type": "divider" - }, - { - "type": "section", - "text": { - "type": "paragraph", - "fields": [ - { - "type": "kmarkdown", - "content": "test3" - }, - { - "type": "kmarkdown", - "content": "**test4**" - } - ], - "cols": 2 - }, - "mode": "left" - }, - { - "type": "image-group", - "elements": [ - { - "type": "image", - "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", - "alt": "", - "size": "lg", - "circle": false - } - ] - }, - { - "type": "file", - "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", - "title": "test5" - }, - { - "type": "countdown", - "endTime": 1772343427360, - "startTime": 1772343378259, - "mode": "second" - }, - { - "type": "action-group", - "elements": [ - { - "type": "button", - "text": "点我测试回调", - "theme": "primary", - "value": "btn_clicked", - "click": "return-val" - }, - { - "type": "button", - "text": "访问官网", - "theme": "danger", - "value": "https://www.kookapp.cn", - "click": "link" - } - ] - }, - { - "type": "context", - "elements": [ - { - "type": "plain-text", - "content": "test6", - "emoji": true - } - ] - }, - { - "type": "invite", - "code": "test7" - } - ] +{ + "type": "card", + "theme": "info", + "size": "lg", + "modules": [ + { + "type": "header", + "text": { + "type": "plain-text", + "content": "test1", + "emoji": true + } + }, + { + "type": "section", + "text": { + "type": "kmarkdown", + "content": "test2" + }, + "mode": "left" + }, + { + "type": "divider" + }, + { + "type": "section", + "text": { + "type": "paragraph", + "fields": [ + { + "type": "kmarkdown", + "content": "test3" + }, + { + "type": "kmarkdown", + "content": "**test4**" + } + ], + "cols": 2 + }, + "mode": "left" + }, + { + "type": "image-group", + "elements": [ + { + "type": "image", + "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", + "alt": "", + "size": "lg", + "circle": false + } + ] + }, + { + "type": "file", + "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", + "title": "test5" + }, + { + "type": "countdown", + "endTime": 1772343427360, + "startTime": 1772343378259, + "mode": "second" + }, + { + "type": "action-group", + "elements": [ + { + "type": "button", + "text": "点我测试回调", + "theme": "primary", + "value": "btn_clicked", + "click": "return-val" + }, + { + "type": "button", + "text": "访问官网", + "theme": "danger", + "value": "https://www.kookapp.cn", + "click": "link" + } + ] + }, + { + "type": "context", + "elements": [ + { + "type": "plain-text", + "content": "test6", + "emoji": true + } + ] + }, + { + "type": "invite", + "code": "test7" + } + ] } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_group_message.json b/tests/test_kook/data/kook_ws_event_group_message.json index 53bd50481a..ce8f531137 100644 --- a/tests/test_kook/data/kook_ws_event_group_message.json +++ b/tests/test_kook/data/kook_ws_event_group_message.json @@ -1,119 +1,119 @@ -{ - "s": 0, - "d": { - "channel_type": "GROUP", - "type": 9, - "target_id": "2732467349811313213", - "author_id": "7324688132731983", - "content": "done!", - "extra": { - "quote": { - "id": "69a788adb0cfb9ece50eae1c", - "rong_id": "7baef72c-0cd7-49ad-9592-1615236136cb", - "type": 9, - "content": "/am 1", - "interact_res": null, - "create_at": 1772587180973, - "author": { - "id": "2701973210937821093781", - "username": "some_username", - "identify_num": "4198", - "online": true, - "os": "Websocket", - "status": 1, - "avatar": "https://example.com", - "vip_avatar": "https://example.com", - "banner": "", - "nickname": "some_username", - "roles": [ - 63423577 - ], - "is_vip": false, - "vip_amp": false, - "bot": false, - "nameplate": [], - "kpm_vip": null, - "wealth_level": 0, - "decorations_id_map": null, - "mobile_verified": true, - "is_sys": false, - "joined_at": 1772259607000, - "active_time": 1772587181304 - }, - "can_jump": true, - "preview_content": null, - "kmarkdown": { - "mention_part": [], - "mention_role_part": [], - "channel_part": [], - "item_part": [] - } - }, - "type": 9, - "code": "", - "guild_id": "273902183210983210983", - "guild_type": 0, - "channel_name": "聊天大厅", - "author": { - "id": "7324688132731983", - "username": "Bot_Test", - "identify_num": "9561", - "online": true, - "os": "Websocket", - "status": 0, - "avatar": "https://example.com", - "vip_avatar": "https://example.com", - "banner": "", - "nickname": "Bot_Test", - "roles": [ - 63725384 - ], - "is_vip": false, - "vip_amp": false, - "bot": true, - "nameplate": [], - "kpm_vip": null, - "wealth_level": 0, - "bot_status": 0, - "tag_info": { - "color": "#0096FF", - "bg_color": "#0096FF33", - "text": "机器人" - }, - "is_sys": false, - "client_id": "sAdiIHoGhdSFUOA", - "verified": false - }, - "visible_only": "", - "mention": [], - "mention_no_at": [], - "mention_all": false, - "mention_roles": [], - "mention_here": false, - "nav_channels": [], - "kmarkdown": { - "raw_content": "done!", - "mention_part": [], - "mention_role_part": [], - "channel_part": [], - "spl": [] - }, - "emoji": [], - "preview_content": "", - "channel_type": 1, - "last_msg_content": "Bot_Test:done!", - "send_msg_device": 0 - }, - "msg_id": "c51a8761-63bv-5l2a-5681-0ac16e140a1b", - "msg_timestamp": 1772587182234, - "nonce": "", - "from_type": 1 - }, - "extra": { - "verifyToken": "kW4FH_ASHio1hosd", - "encryptKey": "", - "callbackUrl": "", - "intent": 255 - }, - "sn": 3 +{ + "s": 0, + "d": { + "channel_type": "GROUP", + "type": 9, + "target_id": "2732467349811313213", + "author_id": "7324688132731983", + "content": "done!", + "extra": { + "quote": { + "id": "69a788adb0cfb9ece50eae1c", + "rong_id": "7baef72c-0cd7-49ad-9592-1615236136cb", + "type": 9, + "content": "/am 1", + "interact_res": null, + "create_at": 1772587180973, + "author": { + "id": "2701973210937821093781", + "username": "some_username", + "identify_num": "4198", + "online": true, + "os": "Websocket", + "status": 1, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "some_username", + "roles": [ + 63423577 + ], + "is_vip": false, + "vip_amp": false, + "bot": false, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "decorations_id_map": null, + "mobile_verified": true, + "is_sys": false, + "joined_at": 1772259607000, + "active_time": 1772587181304 + }, + "can_jump": true, + "preview_content": null, + "kmarkdown": { + "mention_part": [], + "mention_role_part": [], + "channel_part": [], + "item_part": [] + } + }, + "type": 9, + "code": "", + "guild_id": "273902183210983210983", + "guild_type": 0, + "channel_name": "聊天大厅", + "author": { + "id": "7324688132731983", + "username": "Bot_Test", + "identify_num": "9561", + "online": true, + "os": "Websocket", + "status": 0, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "Bot_Test", + "roles": [ + 63725384 + ], + "is_vip": false, + "vip_amp": false, + "bot": true, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "bot_status": 0, + "tag_info": { + "color": "#0096FF", + "bg_color": "#0096FF33", + "text": "机器人" + }, + "is_sys": false, + "client_id": "sAdiIHoGhdSFUOA", + "verified": false + }, + "visible_only": "", + "mention": [], + "mention_no_at": [], + "mention_all": false, + "mention_roles": [], + "mention_here": false, + "nav_channels": [], + "kmarkdown": { + "raw_content": "done!", + "mention_part": [], + "mention_role_part": [], + "channel_part": [], + "spl": [] + }, + "emoji": [], + "preview_content": "", + "channel_type": 1, + "last_msg_content": "Bot_Test:done!", + "send_msg_device": 0 + }, + "msg_id": "c51a8761-63bv-5l2a-5681-0ac16e140a1b", + "msg_timestamp": 1772587182234, + "nonce": "", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 3 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_group_message_with_mention.json b/tests/test_kook/data/kook_ws_event_group_message_with_mention.json index 50f598a9e9..03a9539a34 100644 --- a/tests/test_kook/data/kook_ws_event_group_message_with_mention.json +++ b/tests/test_kook/data/kook_ws_event_group_message_with_mention.json @@ -1,88 +1,88 @@ -{ - "s": 0, - "d": { - "channel_type": "GROUP", - "type": 9, - "target_id": "2732467349811313213", - "author_id": "7324688132731983", - "content": "(rol)25555643(rol) /help (met)3351526782(met) (met)all(met) ", - "msg_id": "9b047d81-40fe-41af-ad39-916ae77e6b20", - "msg_timestamp": 1776405840600, - "nonce": "r2oQkO7kRpNSgmsv7TNl2zOA", - "from_type": 1, - "extra": { - "type": 9, - "code": "", - "author": { - "id": "3351526782", - "username": "some_username", - "identify_num": "4198", - "nickname": "some_username", - "bot": false, - "online": true, - "avatar": "https://example.com", - "vip_avatar": "https://example.com", - "status": 1, - "roles": [ - 63423577 - ], - "os": "Websocket", - "banner": "", - "is_vip": false, - "vip_amp": false, - "nameplate": [], - "wealth_level": 0, - "is_sys": false - }, - "kmarkdown": { - "raw_content": "@some_role /help @some_username @全体成员", - "mention_part": [ - { - "id": "3351526782", - "username": "some_username", - "full_name": "some_username#4198", - "avatar": "https://example.com", - "wealth_level": 0 - } - ], - "mention_role_part": [ - { - "role_id": 25555643, - "name": "some_role", - "color": 0, - "color_type": 1, - "color_map": [] - } - ], - "channel_part": [], - "spl": [] - }, - "last_msg_content": "some_username:@some_role /help @some_username @全体成员", - "mention": [ - "3351526782" - ], - "mention_all": true, - "mention_here": false, - "guild_id": "1239678456780469", - "guild_type": 0, - "channel_name": "聊天大厅", - "visible_only": "", - "mention_no_at": [], - "mention_roles": [ - 25555643 - ], - "nav_channels": [], - "emoji": [], - "preview_content": "", - "channel_type": 1, - "send_msg_device": 0 - } - }, - "extra": { - "verifyToken": "kW4FH_ASHio1hosd", - "encryptKey": "", - "callbackUrl": "", - "intent": 255 - }, - "sn": 1 +{ + "s": 0, + "d": { + "channel_type": "GROUP", + "type": 9, + "target_id": "2732467349811313213", + "author_id": "7324688132731983", + "content": "(rol)25555643(rol) /help (met)3351526782(met) (met)all(met) ", + "msg_id": "9b047d81-40fe-41af-ad39-916ae77e6b20", + "msg_timestamp": 1776405840600, + "nonce": "r2oQkO7kRpNSgmsv7TNl2zOA", + "from_type": 1, + "extra": { + "type": 9, + "code": "", + "author": { + "id": "3351526782", + "username": "some_username", + "identify_num": "4198", + "nickname": "some_username", + "bot": false, + "online": true, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "status": 1, + "roles": [ + 63423577 + ], + "os": "Websocket", + "banner": "", + "is_vip": false, + "vip_amp": false, + "nameplate": [], + "wealth_level": 0, + "is_sys": false + }, + "kmarkdown": { + "raw_content": "@some_role /help @some_username @全体成员", + "mention_part": [ + { + "id": "3351526782", + "username": "some_username", + "full_name": "some_username#4198", + "avatar": "https://example.com", + "wealth_level": 0 + } + ], + "mention_role_part": [ + { + "role_id": 25555643, + "name": "some_role", + "color": 0, + "color_type": 1, + "color_map": [] + } + ], + "channel_part": [], + "spl": [] + }, + "last_msg_content": "some_username:@some_role /help @some_username @全体成员", + "mention": [ + "3351526782" + ], + "mention_all": true, + "mention_here": false, + "guild_id": "1239678456780469", + "guild_type": 0, + "channel_name": "聊天大厅", + "visible_only": "", + "mention_no_at": [], + "mention_roles": [ + 25555643 + ], + "nav_channels": [], + "emoji": [], + "preview_content": "", + "channel_type": 1, + "send_msg_device": 0 + } + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 1 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_hello.json b/tests/test_kook/data/kook_ws_event_hello.json index a6ab68d984..908a6eac32 100644 --- a/tests/test_kook/data/kook_ws_event_hello.json +++ b/tests/test_kook/data/kook_ws_event_hello.json @@ -1,8 +1,8 @@ -{ - "s": 1, - "d": { - "sessionId": "67d7d497-2b10-4849-9c2c-dda2fe58ed60", - "session_id": "67d7d497-2b10-4849-9c2c-dda2fe58ed60", - "code": 0 - } +{ + "s": 1, + "d": { + "sessionId": "67d7d497-2b10-4849-9c2c-dda2fe58ed60", + "session_id": "67d7d497-2b10-4849-9c2c-dda2fe58ed60", + "code": 0 + } } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_message_with_card_1.json b/tests/test_kook/data/kook_ws_event_message_with_card_1.json index d4456651e5..f46087ae57 100644 --- a/tests/test_kook/data/kook_ws_event_message_with_card_1.json +++ b/tests/test_kook/data/kook_ws_event_message_with_card_1.json @@ -1,72 +1,72 @@ -{ - "s": 0, - "d": { - "channel_type": "PERSON", - "type": 10, - "target_id": "2732467349811313213", - "author_id": "7324688132731983", - "content": "[{\"theme\":\"primary\",\"color\":\"\",\"size\":\"lg\",\"expand\":false,\"modules\":[{\"type\":\"audio\",\"cover\":\"\",\"duration\":0,\"title\":\"dancing_shot5.wav\",\"src\":\"https:\\/\\/img.kookapp.cn\\/attachments\\/2026-03\\/03\\/69a6841c3125d.wav\",\"external\":false,\"size\":443414,\"canDownload\":true,\"elements\":[]}],\"type\":\"card\"}]", - "extra": { - "type": 10, - "code": "1738914789hd8fd91098he809h19y491", - "author": { - "id": "7324688132731983", - "username": "Bot_Test", - "identify_num": "9561", - "online": true, - "os": "Websocket", - "status": 0, - "avatar": "https://example.com", - "vip_avatar": "https://example.com", - "banner": "", - "nickname": "Bot_Test", - "roles": [], - "is_vip": false, - "vip_amp": false, - "bot": true, - "nameplate": [], - "kpm_vip": null, - "wealth_level": 0, - "bot_status": 0, - "tag_info": { - "color": "#0096FF", - "bg_color": "#0096FF33", - "text": "机器人" - }, - "is_sys": false, - "client_id": "u109u3108h8ds0qsdaHUIOS", - "verified": false - }, - "visible_only": "", - "mention": [], - "mention_no_at": [], - "mention_all": false, - "mention_roles": [], - "mention_here": false, - "nav_channels": [], - "emoji": [], - "kmarkdown": { - "raw_content": "[音频]dancing_shot5.wav", - "mention_part": [], - "mention_role_part": [], - "channel_part": [] - }, - "editable": false, - "preview_content": "[音频]dancing_shot5.wav", - "preview_content_search": "[音频]dancing_shot5.wav", - "last_msg_content": "[音频]dancing_shot5.wav", - "send_msg_device": 0 - }, - "msg_id": "82c0b042-79b4-4066-a0f4-6c7a95c74e67", - "msg_timestamp": 1772587223043, - "nonce": "", - "from_type": 1 - }, - "extra": { - "verifyToken": "kW4FH_ASHio1hosd", - "encryptKey": "", - "callbackUrl": "", - "intent": 255 - }, - "sn": 5 +{ + "s": 0, + "d": { + "channel_type": "PERSON", + "type": 10, + "target_id": "2732467349811313213", + "author_id": "7324688132731983", + "content": "[{\"theme\":\"primary\",\"color\":\"\",\"size\":\"lg\",\"expand\":false,\"modules\":[{\"type\":\"audio\",\"cover\":\"\",\"duration\":0,\"title\":\"dancing_shot5.wav\",\"src\":\"https:\\/\\/img.kookapp.cn\\/attachments\\/2026-03\\/03\\/69a6841c3125d.wav\",\"external\":false,\"size\":443414,\"canDownload\":true,\"elements\":[]}],\"type\":\"card\"}]", + "extra": { + "type": 10, + "code": "1738914789hd8fd91098he809h19y491", + "author": { + "id": "7324688132731983", + "username": "Bot_Test", + "identify_num": "9561", + "online": true, + "os": "Websocket", + "status": 0, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "Bot_Test", + "roles": [], + "is_vip": false, + "vip_amp": false, + "bot": true, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "bot_status": 0, + "tag_info": { + "color": "#0096FF", + "bg_color": "#0096FF33", + "text": "机器人" + }, + "is_sys": false, + "client_id": "u109u3108h8ds0qsdaHUIOS", + "verified": false + }, + "visible_only": "", + "mention": [], + "mention_no_at": [], + "mention_all": false, + "mention_roles": [], + "mention_here": false, + "nav_channels": [], + "emoji": [], + "kmarkdown": { + "raw_content": "[音频]dancing_shot5.wav", + "mention_part": [], + "mention_role_part": [], + "channel_part": [] + }, + "editable": false, + "preview_content": "[音频]dancing_shot5.wav", + "preview_content_search": "[音频]dancing_shot5.wav", + "last_msg_content": "[音频]dancing_shot5.wav", + "send_msg_device": 0 + }, + "msg_id": "82c0b042-79b4-4066-a0f4-6c7a95c74e67", + "msg_timestamp": 1772587223043, + "nonce": "", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 5 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_message_with_card_2.json b/tests/test_kook/data/kook_ws_event_message_with_card_2.json index fd122391e3..9ac2b7bd58 100644 --- a/tests/test_kook/data/kook_ws_event_message_with_card_2.json +++ b/tests/test_kook/data/kook_ws_event_message_with_card_2.json @@ -1,79 +1,79 @@ -{ - "s": 0, - "d": { - "channel_type": "GROUP", - "type": 10, - "target_id": "2723723449021809", - "author_id": "1237198731983", - "content": "[{\"theme\":\"invisible\",\"color\":\"\",\"size\":\"lg\",\"expand\":false,\"modules\":[{\"type\":\"section\",\"mode\":\"left\",\"accessory\":null,\"text\":{\"type\":\"kmarkdown\",\"content\":\"(met)(met) (met)all(met) #hello \\\\*\\\\*world\\\\*\\\\* \",\"elements\":[]},\"elements\":[]},{\"type\":\"audio\",\"cover\":\"\",\"duration\":0,\"title\":\"dancing_shot5.wav\",\"src\":\"https:\\/\\/img.kookapp.cn\\/attachments\\/2026-03\\/03\\/69a6841c3125d.wav\",\"external\":false,\"size\":443414,\"canDownload\":true,\"elements\":[]},{\"type\":\"section\",\"mode\":\"left\",\"accessory\":null,\"text\":{\"type\":\"kmarkdown\",\"content\":\"\\n😆 \",\"elements\":[]},\"elements\":[]}],\"type\":\"card\"}]", - "msg_id": "ec4046e9-ea43-4907-9fc3-8c6d0bd4ec56", - "msg_timestamp": 1772600762056, - "nonce": "sy8f91y248yda", - "from_type": 1, - "extra": { - "type": 10, - "code": "", - "author": { - "id": "1237198731983", - "username": "some_username", - "identify_num": "4198", - "nickname": "some_username", - "bot": false, - "online": true, - "avatar": "https://example.com", - "vip_avatar": "https://example.com", - "status": 1, - "roles": [ - 12783219731984 - ], - "os": "Websocket", - "banner": "", - "is_vip": false, - "vip_amp": false, - "nameplate": [], - "wealth_level": 0, - "is_sys": false - }, - "kmarkdown": { - "raw_content": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", - "mention_part": [ - { - "id": "", - "username": "Bot_Test", - "full_name": "Bot_Test#9561", - "avatar": "https://example.com", - "wealth_level": 0 - } - ], - "mention_role_part": [], - "channel_part": [] - }, - "last_msg_content": "some_username:@Bot_Test @ 全体成员 #hello **world**[音频]dancing_shot5.wav😆", - "mention": [ - "" - ], - "mention_all": true, - "mention_here": false, - "guild_id": "28321098321093", - "guild_type": 0, - "channel_name": "聊天大厅", - "visible_only": "", - "mention_no_at": [], - "mention_roles": [], - "nav_channels": [], - "emoji": [], - "editable": true, - "preview_content": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", - "preview_content_search": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", - "channel_type": 1, - "send_msg_device": 0 - } - }, - "extra": { - "verifyToken": "kW4FH_ASHio1hosd", - "encryptKey": "", - "callbackUrl": "", - "intent": 255 - }, - "sn": 5 +{ + "s": 0, + "d": { + "channel_type": "GROUP", + "type": 10, + "target_id": "2723723449021809", + "author_id": "1237198731983", + "content": "[{\"theme\":\"invisible\",\"color\":\"\",\"size\":\"lg\",\"expand\":false,\"modules\":[{\"type\":\"section\",\"mode\":\"left\",\"accessory\":null,\"text\":{\"type\":\"kmarkdown\",\"content\":\"(met)(met) (met)all(met) #hello \\\\*\\\\*world\\\\*\\\\* \",\"elements\":[]},\"elements\":[]},{\"type\":\"audio\",\"cover\":\"\",\"duration\":0,\"title\":\"dancing_shot5.wav\",\"src\":\"https:\\/\\/img.kookapp.cn\\/attachments\\/2026-03\\/03\\/69a6841c3125d.wav\",\"external\":false,\"size\":443414,\"canDownload\":true,\"elements\":[]},{\"type\":\"section\",\"mode\":\"left\",\"accessory\":null,\"text\":{\"type\":\"kmarkdown\",\"content\":\"\\n😆 \",\"elements\":[]},\"elements\":[]}],\"type\":\"card\"}]", + "msg_id": "ec4046e9-ea43-4907-9fc3-8c6d0bd4ec56", + "msg_timestamp": 1772600762056, + "nonce": "sy8f91y248yda", + "from_type": 1, + "extra": { + "type": 10, + "code": "", + "author": { + "id": "1237198731983", + "username": "some_username", + "identify_num": "4198", + "nickname": "some_username", + "bot": false, + "online": true, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "status": 1, + "roles": [ + 12783219731984 + ], + "os": "Websocket", + "banner": "", + "is_vip": false, + "vip_amp": false, + "nameplate": [], + "wealth_level": 0, + "is_sys": false + }, + "kmarkdown": { + "raw_content": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "mention_part": [ + { + "id": "", + "username": "Bot_Test", + "full_name": "Bot_Test#9561", + "avatar": "https://example.com", + "wealth_level": 0 + } + ], + "mention_role_part": [], + "channel_part": [] + }, + "last_msg_content": "some_username:@Bot_Test @ 全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "mention": [ + "" + ], + "mention_all": true, + "mention_here": false, + "guild_id": "28321098321093", + "guild_type": 0, + "channel_name": "聊天大厅", + "visible_only": "", + "mention_no_at": [], + "mention_roles": [], + "nav_channels": [], + "emoji": [], + "editable": true, + "preview_content": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "preview_content_search": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "channel_type": 1, + "send_msg_device": 0 + } + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 5 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_ping.json b/tests/test_kook/data/kook_ws_event_ping.json index 1b4e8e7cfd..2b19f1c7bc 100644 --- a/tests/test_kook/data/kook_ws_event_ping.json +++ b/tests/test_kook/data/kook_ws_event_ping.json @@ -1,4 +1,4 @@ -{ - "s": 2, - "sn": 0 +{ + "s": 2, + "sn": 0 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_pong.json b/tests/test_kook/data/kook_ws_event_pong.json index da07a35c6c..d908bb6e78 100644 --- a/tests/test_kook/data/kook_ws_event_pong.json +++ b/tests/test_kook/data/kook_ws_event_pong.json @@ -1,3 +1,3 @@ -{ - "s": 3 +{ + "s": 3 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_private_message.json b/tests/test_kook/data/kook_ws_event_private_message.json index 13b0180282..813be5cb8a 100644 --- a/tests/test_kook/data/kook_ws_event_private_message.json +++ b/tests/test_kook/data/kook_ws_event_private_message.json @@ -1,64 +1,64 @@ -{ - "s": 0, - "d": { - "channel_type": "PERSON", - "type": 9, - "target_id": "7324688132731983", - "author_id": "2732467349811313213", - "content": "/help", - "extra": { - "type": 9, - "code": "1738914789hd8fd91098he809h19y491", - "author": { - "id": "2732467349811313213", - "username": "shuiping233", - "identify_num": "4198", - "online": true, - "os": "Websocket", - "status": 1, - "avatar": "https://example.com", - "vip_avatar": "https://example.com", - "banner": "", - "nickname": "shuiping233", - "roles": [], - "is_vip": false, - "vip_amp": false, - "bot": false, - "nameplate": [], - "kpm_vip": null, - "wealth_level": 0, - "decorations_id_map": null, - "is_sys": false - }, - "visible_only": "", - "mention": [], - "mention_no_at": [], - "mention_all": false, - "mention_roles": [], - "mention_here": false, - "nav_channels": [], - "kmarkdown": { - "raw_content": "/help", - "mention_part": [], - "mention_role_part": [], - "channel_part": [], - "spl": [] - }, - "emoji": [], - "preview_content": "", - "last_msg_content": "/help", - "send_msg_device": 0 - }, - "msg_id": "b0f57b9e-2cd4-4e07-8f0e-9c1ecfeaa837", - "msg_timestamp": 1772587358662, - "nonce": "6AwzUe5YjgyC8pAfxcLGjewL", - "from_type": 1 - }, - "extra": { - "verifyToken": "kW4FH_ASHio1hosd", - "encryptKey": "", - "callbackUrl": "", - "intent": 255 - }, - "sn": 19 +{ + "s": 0, + "d": { + "channel_type": "PERSON", + "type": 9, + "target_id": "7324688132731983", + "author_id": "2732467349811313213", + "content": "/help", + "extra": { + "type": 9, + "code": "1738914789hd8fd91098he809h19y491", + "author": { + "id": "2732467349811313213", + "username": "shuiping233", + "identify_num": "4198", + "online": true, + "os": "Websocket", + "status": 1, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "shuiping233", + "roles": [], + "is_vip": false, + "vip_amp": false, + "bot": false, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "decorations_id_map": null, + "is_sys": false + }, + "visible_only": "", + "mention": [], + "mention_no_at": [], + "mention_all": false, + "mention_roles": [], + "mention_here": false, + "nav_channels": [], + "kmarkdown": { + "raw_content": "/help", + "mention_part": [], + "mention_role_part": [], + "channel_part": [], + "spl": [] + }, + "emoji": [], + "preview_content": "", + "last_msg_content": "/help", + "send_msg_device": 0 + }, + "msg_id": "b0f57b9e-2cd4-4e07-8f0e-9c1ecfeaa837", + "msg_timestamp": 1772587358662, + "nonce": "6AwzUe5YjgyC8pAfxcLGjewL", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 19 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_private_system_message.json b/tests/test_kook/data/kook_ws_event_private_system_message.json index 1a60adc4af..09063449ca 100644 --- a/tests/test_kook/data/kook_ws_event_private_system_message.json +++ b/tests/test_kook/data/kook_ws_event_private_system_message.json @@ -1,31 +1,31 @@ -{ - "s": 0, - "d": { - "channel_type": "PERSON", - "type": 255, - "target_id": "7324688132731983", - "author_id": "1", - "content": "[系统消息]", - "extra": { - "type": "guild_member_offline", - "body": { - "user_id": "2732467349811313213", - "event_time": 1772589748914, - "guilds": [ - "78941897317309873120973" - ] - } - }, - "msg_id": "e91b4451-75ce-47bd-bda6-e4498ed8d30d", - "msg_timestamp": 1772589748933, - "nonce": "", - "from_type": 1 - }, - "extra": { - "verifyToken": "kW4FH_ASHio1hosd", - "encryptKey": "", - "callbackUrl": "", - "intent": 255 - }, - "sn": 1 +{ + "s": 0, + "d": { + "channel_type": "PERSON", + "type": 255, + "target_id": "7324688132731983", + "author_id": "1", + "content": "[系统消息]", + "extra": { + "type": "guild_member_offline", + "body": { + "user_id": "2732467349811313213", + "event_time": 1772589748914, + "guilds": [ + "78941897317309873120973" + ] + } + }, + "msg_id": "e91b4451-75ce-47bd-bda6-e4498ed8d30d", + "msg_timestamp": 1772589748933, + "nonce": "", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 1 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_reconnect_err.json b/tests/test_kook/data/kook_ws_event_reconnect_err.json index 5346680f2e..0991715317 100644 --- a/tests/test_kook/data/kook_ws_event_reconnect_err.json +++ b/tests/test_kook/data/kook_ws_event_reconnect_err.json @@ -1,7 +1,7 @@ -{ - "s": 5, - "d": { - "code": 40108, - "err": "Invalid SN" - } +{ + "s": 5, + "d": { + "code": 40108, + "err": "Invalid SN" + } } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_resume.json b/tests/test_kook/data/kook_ws_event_resume.json index 427f4ca2a9..8fb13ce9fb 100644 --- a/tests/test_kook/data/kook_ws_event_resume.json +++ b/tests/test_kook/data/kook_ws_event_resume.json @@ -1,4 +1,4 @@ -{ - "s": 4, - "sn": 100 +{ + "s": 4, + "sn": 100 } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_resume_ack.json b/tests/test_kook/data/kook_ws_event_resume_ack.json index da8edab146..070edb160e 100644 --- a/tests/test_kook/data/kook_ws_event_resume_ack.json +++ b/tests/test_kook/data/kook_ws_event_resume_ack.json @@ -1,6 +1,6 @@ -{ - "s": 6, - "d": { - "session_id": "xxxx-xxxxxx-xxx-xxx" - } +{ + "s": 6, + "d": { + "session_id": "xxxx-xxxxxx-xxx-xxx" + } } \ No newline at end of file From 3061555d15f9a006f6c5074107e58c1dc7ab707e Mon Sep 17 00:00:00 2001 From: EterUltimate <139631158+EterUltimate@users.noreply.github.com> Date: Wed, 24 Jun 2026 22:43:24 +0800 Subject: [PATCH 2/7] fix: reconnect MCP client on terminated session (#8694) * fix: reconnect MCP client on terminated session * Update astrbot/core/agent/mcp_client.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update mcp_client.py --------- Co-authored-by: Weilong Liao <37870767+Soulter@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- astrbot/core/agent/mcp_client.py | 25 +++++- tests/unit/test_mcp_client_reconnect.py | 103 ++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_mcp_client_reconnect.py diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index dad9e5be4e..fc04307be4 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -13,7 +13,7 @@ from tenacity import ( before_sleep_log, retry, - retry_if_exception_type, + retry_if_exception, stop_after_attempt, wait_exponential, ) @@ -93,6 +93,10 @@ } ) _STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS" +_MCP_RECONNECT_ERROR_MESSAGES = ( + "session terminated", + "session was terminated", +) try: import anyio @@ -121,6 +125,13 @@ ) +def _is_mcp_reconnect_error(exc: BaseException) -> bool: + if "anyio" in globals() and isinstance(exc, anyio.ClosedResourceError): + return True + message = str(exc).lower() + return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES) + + def _prepare_config(config: dict) -> dict: """Prepare configuration, handle nested format""" if config.get("mcpServers"): @@ -635,7 +646,7 @@ async def call_tool_with_reconnect( """ @retry( - retry=retry_if_exception_type(anyio.ClosedResourceError), + retry=retry_if_exception(_is_mcp_reconnect_error), stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=1, max=3), before_sleep=before_sleep_log(logger, logging.WARNING), @@ -651,9 +662,15 @@ async def _call_with_retry(): arguments=arguments, read_timeout_seconds=read_timeout_seconds, ) - except anyio.ClosedResourceError: + except Exception as exc: + if not _is_mcp_reconnect_error(exc): + raise + logger.warning( - f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + "MCP tool %s call failed (%s: %s), attempting to reconnect...", + tool_name, + type(exc).__name__, + exc, ) # Attempt to reconnect await self._reconnect() diff --git a/tests/unit/test_mcp_client_reconnect.py b/tests/unit/test_mcp_client_reconnect.py new file mode 100644 index 0000000000..4ae1dcea44 --- /dev/null +++ b/tests/unit/test_mcp_client_reconnect.py @@ -0,0 +1,103 @@ +from datetime import timedelta + +import anyio +import pytest +from tenacity import wait_none + +from astrbot.core.agent import mcp_client + + +class FlakyMcpSession: + def __init__(self, first_error: Exception | None = None) -> None: + self.calls = 0 + self.first_error = first_error or RuntimeError("Session terminated") + + async def call_tool( + self, + *, + name: str, + arguments: dict, + read_timeout_seconds: timedelta, + ) -> dict[str, object]: + self.calls += 1 + if self.calls == 1: + raise self.first_error + return { + "name": name, + "arguments": arguments, + "timeout": read_timeout_seconds.total_seconds(), + } + + +@pytest.mark.parametrize( + ("error", "expected"), + [ + (RuntimeError("Session terminated"), True), + (RuntimeError("SESSION TERMINATED"), True), + (RuntimeError("session was terminated"), True), + (anyio.ClosedResourceError(), True), + (RuntimeError("business flow terminated normally"), False), + (RuntimeError("terminated"), False), + ], +) +def test_mcp_reconnect_error_detection_is_narrow( + error: BaseException, expected: bool +) -> None: + assert mcp_client._is_mcp_reconnect_error(error) is expected + + +@pytest.mark.asyncio +async def test_call_tool_reconnects_on_session_terminated(monkeypatch) -> None: + monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none()) + + client = mcp_client.MCPClient() + session = FlakyMcpSession() + reconnects = 0 + + async def reconnect() -> None: + nonlocal reconnects + reconnects += 1 + client.session = session + + client.session = session + client._reconnect = reconnect + + result = await client.call_tool_with_reconnect( + tool_name="lookup", + arguments={"url": "https://example.com"}, + read_timeout_seconds=timedelta(seconds=5), + ) + + assert result == { + "name": "lookup", + "arguments": {"url": "https://example.com"}, + "timeout": 5.0, + } + assert session.calls == 2 + assert reconnects == 1 + + +@pytest.mark.asyncio +async def test_call_tool_does_not_reconnect_on_business_error(monkeypatch) -> None: + monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none()) + + client = mcp_client.MCPClient() + session = FlakyMcpSession(first_error=ValueError("business logic failed")) + reconnects = 0 + + async def reconnect() -> None: + nonlocal reconnects + reconnects += 1 + + client.session = session + client._reconnect = reconnect + + with pytest.raises(ValueError, match="business logic failed"): + await client.call_tool_with_reconnect( + tool_name="lookup", + arguments={"url": "https://example.com"}, + read_timeout_seconds=timedelta(seconds=5), + ) + + assert session.calls == 1 + assert reconnects == 0 From abac4f252c1f91b79894d1bb81c5a01802c26870 Mon Sep 17 00:00:00 2001 From: Weilong Liao <37870767+Soulter@users.noreply.github.com> Date: Wed, 24 Jun 2026 22:44:45 +0800 Subject: [PATCH 3/7] Revert "fix: reconnect MCP client on terminated session (#8694)" (#8991) This reverts commit 2bda4e4d967c3ac23341f8ce101e716f4ec94965. --- astrbot/core/agent/mcp_client.py | 25 +----- tests/unit/test_mcp_client_reconnect.py | 103 ------------------------ 2 files changed, 4 insertions(+), 124 deletions(-) delete mode 100644 tests/unit/test_mcp_client_reconnect.py diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index fc04307be4..dad9e5be4e 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -13,7 +13,7 @@ from tenacity import ( before_sleep_log, retry, - retry_if_exception, + retry_if_exception_type, stop_after_attempt, wait_exponential, ) @@ -93,10 +93,6 @@ } ) _STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS" -_MCP_RECONNECT_ERROR_MESSAGES = ( - "session terminated", - "session was terminated", -) try: import anyio @@ -125,13 +121,6 @@ ) -def _is_mcp_reconnect_error(exc: BaseException) -> bool: - if "anyio" in globals() and isinstance(exc, anyio.ClosedResourceError): - return True - message = str(exc).lower() - return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES) - - def _prepare_config(config: dict) -> dict: """Prepare configuration, handle nested format""" if config.get("mcpServers"): @@ -646,7 +635,7 @@ async def call_tool_with_reconnect( """ @retry( - retry=retry_if_exception(_is_mcp_reconnect_error), + retry=retry_if_exception_type(anyio.ClosedResourceError), stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=1, max=3), before_sleep=before_sleep_log(logger, logging.WARNING), @@ -662,15 +651,9 @@ async def _call_with_retry(): arguments=arguments, read_timeout_seconds=read_timeout_seconds, ) - except Exception as exc: - if not _is_mcp_reconnect_error(exc): - raise - + except anyio.ClosedResourceError: logger.warning( - "MCP tool %s call failed (%s: %s), attempting to reconnect...", - tool_name, - type(exc).__name__, - exc, + f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." ) # Attempt to reconnect await self._reconnect() diff --git a/tests/unit/test_mcp_client_reconnect.py b/tests/unit/test_mcp_client_reconnect.py deleted file mode 100644 index 4ae1dcea44..0000000000 --- a/tests/unit/test_mcp_client_reconnect.py +++ /dev/null @@ -1,103 +0,0 @@ -from datetime import timedelta - -import anyio -import pytest -from tenacity import wait_none - -from astrbot.core.agent import mcp_client - - -class FlakyMcpSession: - def __init__(self, first_error: Exception | None = None) -> None: - self.calls = 0 - self.first_error = first_error or RuntimeError("Session terminated") - - async def call_tool( - self, - *, - name: str, - arguments: dict, - read_timeout_seconds: timedelta, - ) -> dict[str, object]: - self.calls += 1 - if self.calls == 1: - raise self.first_error - return { - "name": name, - "arguments": arguments, - "timeout": read_timeout_seconds.total_seconds(), - } - - -@pytest.mark.parametrize( - ("error", "expected"), - [ - (RuntimeError("Session terminated"), True), - (RuntimeError("SESSION TERMINATED"), True), - (RuntimeError("session was terminated"), True), - (anyio.ClosedResourceError(), True), - (RuntimeError("business flow terminated normally"), False), - (RuntimeError("terminated"), False), - ], -) -def test_mcp_reconnect_error_detection_is_narrow( - error: BaseException, expected: bool -) -> None: - assert mcp_client._is_mcp_reconnect_error(error) is expected - - -@pytest.mark.asyncio -async def test_call_tool_reconnects_on_session_terminated(monkeypatch) -> None: - monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none()) - - client = mcp_client.MCPClient() - session = FlakyMcpSession() - reconnects = 0 - - async def reconnect() -> None: - nonlocal reconnects - reconnects += 1 - client.session = session - - client.session = session - client._reconnect = reconnect - - result = await client.call_tool_with_reconnect( - tool_name="lookup", - arguments={"url": "https://example.com"}, - read_timeout_seconds=timedelta(seconds=5), - ) - - assert result == { - "name": "lookup", - "arguments": {"url": "https://example.com"}, - "timeout": 5.0, - } - assert session.calls == 2 - assert reconnects == 1 - - -@pytest.mark.asyncio -async def test_call_tool_does_not_reconnect_on_business_error(monkeypatch) -> None: - monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none()) - - client = mcp_client.MCPClient() - session = FlakyMcpSession(first_error=ValueError("business logic failed")) - reconnects = 0 - - async def reconnect() -> None: - nonlocal reconnects - reconnects += 1 - - client.session = session - client._reconnect = reconnect - - with pytest.raises(ValueError, match="business logic failed"): - await client.call_tool_with_reconnect( - tool_name="lookup", - arguments={"url": "https://example.com"}, - read_timeout_seconds=timedelta(seconds=5), - ) - - assert session.calls == 1 - assert reconnects == 0 From 90d94ebf88b7254c6ad61d1c641317a93919807f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 24 Jun 2026 23:21:47 +0800 Subject: [PATCH 4/7] fix: update max_context_length and dequeue_context_length defaults --- astrbot/core/config/default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 9ddb4aa64d..ae2151ae23 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -136,8 +136,8 @@ ), "llm_compress_keep_recent_ratio": 0.15, "llm_compress_provider_id": "", - "max_context_length": 50, - "dequeue_context_length": 10, + "max_context_length": -1, # 默认不限制 + "dequeue_context_length": 1, "streaming_response": False, "show_tool_use_status": False, "show_tool_call_result": False, From dcfbb678dd35262c0a5c39f86e20ad2577a50ba2 Mon Sep 17 00:00:00 2001 From: Weilong Liao <37870767+Soulter@users.noreply.github.com> Date: Wed, 24 Jun 2026 23:44:15 +0800 Subject: [PATCH 5/7] chore: bump version to 4.26.0 (#8994) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: bump version to 4.26.0 * feat: 更新 v4.26.0 更新日志,添加 WebUI 设置迁移提示及新功能说明 * fix: 修复多个 WebUI 和工具相关问题,提升稳定性和用户体验 --- astrbot/__init__.py | 2 +- astrbot/core/config/default.py | 2 +- changelogs/v4.26.0.md | 154 +++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 4 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 changelogs/v4.26.0.md diff --git a/astrbot/__init__.py b/astrbot/__init__.py index a6a5bc5cec..a7083c8cea 100644 --- a/astrbot/__init__.py +++ b/astrbot/__init__.py @@ -1,4 +1,4 @@ import logging -__version__ = "4.26.0-beta.12" +__version__ = "4.26.0" logger = logging.getLogger("astrbot") diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ae2151ae23..700a345ccc 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -136,7 +136,7 @@ ), "llm_compress_keep_recent_ratio": 0.15, "llm_compress_provider_id": "", - "max_context_length": -1, # 默认不限制 + "max_context_length": -1, # 默认不限制 "dequeue_context_length": 1, "streaming_response": False, "show_tool_use_status": False, diff --git a/changelogs/v4.26.0.md b/changelogs/v4.26.0.md new file mode 100644 index 0000000000..7537e40d9f --- /dev/null +++ b/changelogs/v4.26.0.md @@ -0,0 +1,154 @@ + +> Note: +> 1. WebUI “Config -> System Config” has moved to “Settings” at the bottom of the WebUI sidebar. +> 2. It's recommended to update to v4.25.6 before updating to this version. +> +> 提醒: +> 1. WebUI 的“配置 -> 系统配置”已迁移至 WebUI 侧边栏下方的“设置”。 +> 2. 建议先升级到 v4.25.6 再升级到此版本。 +> + +- [更新日志(简体中文)](#chinese) +- [Changelog(English)](#english) + + + +## What's Changed + +### ✨ 新功能 + +- 后端架构从 Quart 迁移至 FastAPI。新增多个 AstrBot OpenAPI。([#8688](https://github.com/AstrBotDevs/AstrBot/pull/8688)) +- 统一全平台消息媒体文件的处理逻辑,提升图片、音频、文件和引用消息媒体的解析一致性,对腾讯系 Silk 格式的语音文件不再使用 pilk 库。([#8764](https://github.com/AstrBotDevs/AstrBot/pull/8764)) +- WebUI 新增函数工具的逐工具权限管理,支持在工具面板中查看和切换工具权限。([#8693](https://github.com/AstrBotDevs/AstrBot/pull/8693)) +- WebUI 新增浅色、深色、跟随系统三种主题模式,并集中处理系统主题同步。([#8648](https://github.com/AstrBotDevs/AstrBot/pull/8648)) +- 重组 WebUI 系统配置页面,将系统配置入口迁移到侧边栏下方的设置区域,并优化相关设置项、自动保存和重启提示体验。([#8777](https://github.com/AstrBotDevs/AstrBot/pull/8777)) +- 新增 QQ 官方机器人 WebSocket 适配器扫码绑定流程,可通过 WebUI 一键扫码获取并回填 AppID 与 Secret,同时将 WebSocket 模板标记为推荐。([#8821](https://github.com/AstrBotDevs/AstrBot/pull/8821)) +- 增强 QQ 官方机器人群聊能力,支持群消息创建类型,并允许 Webhook 适配器在无缓存 `msg_id` 时主动发送群消息。([#8838](https://github.com/AstrBotDevs/AstrBot/pull/8838), [#8841](https://github.com/AstrBotDevs/AstrBot/pull/8841)) +- 为 OpenAI、Gemini、Anthropic 等模型请求加入可配置的重试机制,并新增请求最大重试次数配置,提升临时网络错误与 5xx 服务端错误下的稳定性。([#8893](https://github.com/AstrBotDevs/AstrBot/pull/8893)) +- 现在更新项目时,下载 AstrBot Core 会走 AstrBot 官方托管地址,提高网络稳定性。([#8888](https://github.com/AstrBotDevs/AstrBot/pull/8888)) +- 支持在请求中加载 workspace skills,并加固 workspace skill 发现流程。([#8884](https://github.com/AstrBotDevs/AstrBot/pull/8884)) +- 新增 Exa Web Search 提供商。([#8973](https://github.com/AstrBotDevs/AstrBot/pull/8973)) +- 新增 ElevenLabs TTS API Provider。([commit](https://github.com/AstrBotDevs/AstrBot/commit/0b2234936)) +- 新增启动时重置 WebUI 密码的命令行开关,便于无法登录时恢复访问。([commit](https://github.com/AstrBotDevs/AstrBot/commit/4f5075e60)) +- 新增预发布版本可见性开关。([commit](https://github.com/AstrBotDevs/AstrBot/commit/f9d408221)) +- 登录页新增公开版本详情展示。([#8986](https://github.com/AstrBotDevs/AstrBot/pull/8986)) +- 备份功能现在会包含 skills 目录。([#8700](https://github.com/AstrBotDevs/AstrBot/pull/8700)) + +### 优化 + +- 加强未来任务所有者校验,避免越权访问定时任务。([#8881](https://github.com/AstrBotDevs/AstrBot/pull/8881)) +- 优化知识库上传文件名路径穿越风险。([#8971](https://github.com/AstrBotDevs/AstrBot/pull/8971)) +- 优化插件上传文件名路径穿越风险。([#8968](https://github.com/AstrBotDevs/AstrBot/pull/8968)) +- 在受限本地文件系统工具中拒绝 hardlink 文件,避免通过工作区 hardlink 别名读写允许目录外的文件。 +- 加固沙箱文件传输与 CUA 健康检查流程,降低异常环境下的文件操作风险。([#8840](https://github.com/AstrBotDevs/AstrBot/pull/8840)) +- 消息组件日志输出现在会截断过长的 base64 字段,避免日志中出现大体积内联媒体内容。([#8591](https://github.com/AstrBotDevs/AstrBot/pull/8591)) +- 群聊上下文现在会展示被引用消息的内容。([commit](https://github.com/AstrBotDevs/AstrBot/commit/32cfcbf52)) +- 对话上下文新增当前星期信息。([#8669](https://github.com/AstrBotDevs/AstrBot/pull/8669)) +- 使用原子写入方式保存配置文件,降低写入中断导致配置损坏的概率。([#8793](https://github.com/AstrBotDevs/AstrBot/pull/8793)) +- 新增 GitHub 代理 `gh.dpik.top`,并移除失效代理 `gh.llkk.cc`。([#8772](https://github.com/AstrBotDevs/AstrBot/pull/8772), [#8761](https://github.com/AstrBotDevs/AstrBot/pull/8761)) + +### 修复 + +- 修复 aiocqhttp 平台适配器与消息事件处理器在多处 API 调用中缺少 `self_id` 路由参数的问题。([#8779](https://github.com/AstrBotDevs/AstrBot/pull/8779)) +- 修复生成平台 ID 时可能包含空白字符的问题。([#8768](https://github.com/AstrBotDevs/AstrBot/pull/8768)) +- 修复 Gemini Provider 工具定义没有正确传回模型,导致重复工具调用的问题。([#8833](https://github.com/AstrBotDevs/AstrBot/pull/8833)) +- 修复提供商源修改 ID 后保存被静默还原的问题。([#8915](https://github.com/AstrBotDevs/AstrBot/pull/8915)) +- 修复插件 LLM Tools 开关的归属校验问题,避免误操作其他插件的工具配置。([commit](https://github.com/AstrBotDevs/AstrBot/commit/fadada3d6)) +- 完善插件命名模式校验和边界场景处理。([commit](https://github.com/AstrBotDevs/AstrBot/commit/992aea986)) +- 修复插件重装后仓库来源丢失的问题。([commit](https://github.com/AstrBotDevs/AstrBot/commit/a3c25ec2c)) +- 修复子目录工具的 `handler_module_path` 不一致问题。([#8578](https://github.com/AstrBotDevs/AstrBot/pull/8578)) +- 修复 run-based tools 的保护逻辑,避免受保护工具在注册流程中被错误丢失。([#8790](https://github.com/AstrBotDevs/AstrBot/pull/8790)) +- 修复 persona 工具列表场景下系统工具被移除的问题。([#8908](https://github.com/AstrBotDevs/AstrBot/pull/8908)) +- 修复人格设定中将工具和 Skills 从指定列表切回“默认使用全部”后不生效的问题。([#8835](https://github.com/AstrBotDevs/AstrBot/pull/8835)) +- 修复新版 MCP 中 Streamable HTTP client 重命名导致的兼容问题,并保持 `mcp` 依赖小于 2。 +- 修复本地 Python 工具没有在当前 session workspace 中运行的问题。([#8792](https://github.com/AstrBotDevs/AstrBot/pull/8792)) +- 修复静态资源缺失时仍显示 WebUI ready banner 的问题。([#8804](https://github.com/AstrBotDevs/AstrBot/pull/8804)) +- 修复 Dashboard 创建文件夹时按 Enter 无法提交的问题。([#8597](https://github.com/AstrBotDevs/AstrBot/pull/8597)) +- 修复聊天输入框在非末尾位置使用输入法组合输入时可能丢失字符的问题。([#8811](https://github.com/AstrBotDevs/AstrBot/pull/8811)) +- 修复 changelog 弹窗中的锚点链接处理。([#8750](https://github.com/AstrBotDevs/AstrBot/pull/8750)) +- 修复 onboarding 平台配置与备份上传相关问题。([#8834](https://github.com/AstrBotDevs/AstrBot/pull/8834)) +- 将知识库上下文作为临时 user 内容注入,修复模型请求中知识库上下文角色不准确的问题。([#8904](https://github.com/AstrBotDevs/AstrBot/pull/8904)) +- 修复 cron 星期调度规范化问题。([#8984](https://github.com/AstrBotDevs/AstrBot/pull/8984)) +- 修复 QQ 官方平台发送消息时 At 组件被丢失的问题。([#8983](https://github.com/AstrBotDevs/AstrBot/pull/8983)) +- 修复引用消息中的 image caption 可能重复显示的问题。([#8718](https://github.com/AstrBotDevs/AstrBot/pull/8718)) +- 修复 Embedding API version 后缀被错误截断的问题。([#8736](https://github.com/AstrBotDevs/AstrBot/pull/8736)) +- 延迟导入 FAISS C 库,避免部分环境启动时进程卡住。([#8696](https://github.com/AstrBotDevs/AstrBot/pull/8696)) +- 关闭时主动释放数据库 engine,减少会话和测试环境中的资源残留。([#8650](https://github.com/AstrBotDevs/AstrBot/pull/8650)) +- 修复 CLI 版本来源不正确的问题。([#8692](https://github.com/AstrBotDevs/AstrBot/pull/8692)) +- 修复执行 `astrbot` 命令时不必要地创建 data 目录的问题。([#8932](https://github.com/AstrBotDevs/AstrBot/pull/8932)) +- 修复 sdist 构建产物路径,确保 Dashboard artifact 可被包含。([#8933](https://github.com/AstrBotDevs/AstrBot/pull/8933)) +- 修复插件页资源 token fallback。([#8970](https://github.com/AstrBotDevs/AstrBot/pull/8970)) +- 更新 `max_context_length` 与 `dequeue_context_length` 默认值。 +- 稳定 FastAPI Dashboard 路由注册测试,兼容 included router 节点。([commit](https://github.com/AstrBotDevs/AstrBot/commit/ad1b64d12)) +- 稳定 Dashboard 路由相关测试,兼容最新 FastAPI 行为。([commit](https://github.com/AstrBotDevs/AstrBot/commit/a2b6aad84)) + + + +## What's Changed (EN) + +### ✨ New Features + +- Migrated the backend architecture from Quart to FastAPI and added multiple OpenAPI definitions. ([#8688](https://github.com/AstrBotDevs/AstrBot/pull/8688)) +- Unified media file handling across platforms, improving consistency for images, audio, files, and quoted-message media. Tencent Silk voice files no longer use the pilk library. ([#8764](https://github.com/AstrBotDevs/AstrBot/pull/8764)) +- Added per-tool permission management for function tools in WebUI, with support for viewing and toggling tool permissions from the tools panel. ([#8693](https://github.com/AstrBotDevs/AstrBot/pull/8693)) +- Added light, dark, and system theme modes to WebUI, with centralized system theme synchronization. ([#8648](https://github.com/AstrBotDevs/AstrBot/pull/8648)) +- Reorganized the WebUI system configuration page. The system configuration entry has moved to the Settings area at the bottom of the sidebar, with improved settings, autosave, and restart notices. ([#8777](https://github.com/AstrBotDevs/AstrBot/pull/8777)) +- Added a QR binding flow for the QQ Official Bot WebSocket adapter. WebUI can now fetch and autofill AppID and Secret through one-click QR setup, and the WebSocket template is marked as recommended. ([#8821](https://github.com/AstrBotDevs/AstrBot/pull/8821)) +- Enhanced QQ Official Bot group chat support by adding group message creation types and allowing the Webhook adapter to proactively send group messages without a cached `msg_id`. ([#8838](https://github.com/AstrBotDevs/AstrBot/pull/8838), [#8841](https://github.com/AstrBotDevs/AstrBot/pull/8841)) +- Added configurable retry handling for OpenAI, Gemini, Anthropic, and related model requests, including a maximum request retry setting for better stability during temporary network errors and 5xx server errors. ([#8893](https://github.com/AstrBotDevs/AstrBot/pull/8893)) +- AstrBot Core downloads now use AstrBot's officially hosted source during project updates, improving network stability. ([#8888](https://github.com/AstrBotDevs/AstrBot/pull/8888)) +- Added support for loading workspace skills in requests and hardened workspace skill discovery. ([#8884](https://github.com/AstrBotDevs/AstrBot/pull/8884)) +- Added Exa as a web search provider. ([#8973](https://github.com/AstrBotDevs/AstrBot/pull/8973)) +- Added the ElevenLabs TTS API provider. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/0b2234936)) +- Added a startup flag to reset the WebUI password, making it easier to recover access when login is unavailable. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/4f5075e60)) +- Added a prerelease visibility toggle. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/f9d408221)) +- Added public version details to the login page. ([#8986](https://github.com/AstrBotDevs/AstrBot/pull/8986)) +- Backups now include the skills directory. ([#8700](https://github.com/AstrBotDevs/AstrBot/pull/8700)) + +### Improvements + +- Strengthened Future Task owner checks to prevent unauthorized scheduled-task access. ([#8881](https://github.com/AstrBotDevs/AstrBot/pull/8881)) +- Hardened knowledge base upload filename handling against path traversal risks. ([#8971](https://github.com/AstrBotDevs/AstrBot/pull/8971)) +- Hardened plugin upload filename handling against path traversal risks. ([#8968](https://github.com/AstrBotDevs/AstrBot/pull/8968)) +- Rejected hardlinked files in restricted local filesystem tools to prevent workspace hardlink aliases from reading or writing files outside allowed directories. +- Hardened sandbox file transfers and CUA health checks to reduce file-operation risks in abnormal environments. ([#8840](https://github.com/AstrBotDevs/AstrBot/pull/8840)) +- Message component logs now truncate long base64 fields to avoid large inline media payloads in logs. ([#8591](https://github.com/AstrBotDevs/AstrBot/pull/8591)) +- Group chat context now includes quoted message content. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/32cfcbf52)) +- Added current weekday information to conversation context. ([#8669](https://github.com/AstrBotDevs/AstrBot/pull/8669)) +- Configuration files are now saved atomically, reducing the chance of corruption during interrupted writes. ([#8793](https://github.com/AstrBotDevs/AstrBot/pull/8793)) +- Added GitHub proxy `gh.dpik.top` and removed the invalid `gh.llkk.cc` proxy. ([#8772](https://github.com/AstrBotDevs/AstrBot/pull/8772), [#8761](https://github.com/AstrBotDevs/AstrBot/pull/8761)) + +### Bug Fixes + +- Fixed missing `self_id` routing parameters across multiple aiocqhttp platform adapter and message event API calls. ([#8779](https://github.com/AstrBotDevs/AstrBot/pull/8779)) +- Fixed generated platform IDs potentially containing whitespace characters. ([#8768](https://github.com/AstrBotDevs/AstrBot/pull/8768)) +- Fixed Gemini Provider tool definitions not being passed back to the model correctly, which could cause repeated tool calls. ([#8833](https://github.com/AstrBotDevs/AstrBot/pull/8833)) +- Fixed provider source ID edits being silently restored after saving. ([#8915](https://github.com/AstrBotDevs/AstrBot/pull/8915)) +- Fixed ownership checks when toggling plugin LLM tools to prevent accidental changes to other plugins' tool settings. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/fadada3d6)) +- Improved plugin naming pattern validation and edge-case handling. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/992aea986)) +- Fixed repository source information being lost after reinstalling plugins. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/a3c25ec2c)) +- Fixed inconsistent `handler_module_path` values for tools inside subdirectories. ([#8578](https://github.com/AstrBotDevs/AstrBot/pull/8578)) +- Fixed run-based tool protection so protected tools are not incorrectly dropped during registration. ([#8790](https://github.com/AstrBotDevs/AstrBot/pull/8790)) +- Fixed system tools being removed when persona tool lists are configured. ([#8908](https://github.com/AstrBotDevs/AstrBot/pull/8908)) +- Fixed persona tool and Skill settings not taking effect after switching from selected items back to "use all by default". ([#8835](https://github.com/AstrBotDevs/AstrBot/pull/8835)) +- Fixed compatibility with the renamed Streamable HTTP client in newer MCP versions while keeping the `mcp` dependency below 2. +- Fixed local Python tools not running inside the current session workspace. ([#8792](https://github.com/AstrBotDevs/AstrBot/pull/8792)) +- Fixed the WebUI ready banner being shown even when static assets are missing. ([#8804](https://github.com/AstrBotDevs/AstrBot/pull/8804)) +- Fixed Dashboard folder creation not being submitted when pressing Enter. ([#8597](https://github.com/AstrBotDevs/AstrBot/pull/8597)) +- Fixed possible IME composition character loss when typing at a non-terminal cursor position in the chat input. ([#8811](https://github.com/AstrBotDevs/AstrBot/pull/8811)) +- Fixed changelog anchor link handling in the Dashboard dialog. ([#8750](https://github.com/AstrBotDevs/AstrBot/pull/8750)) +- Fixed onboarding platform configuration and backup upload issues. ([#8834](https://github.com/AstrBotDevs/AstrBot/pull/8834)) +- Injected knowledge base context as temporary user content, fixing the role used for knowledge context in model requests. ([#8904](https://github.com/AstrBotDevs/AstrBot/pull/8904)) +- Fixed cron weekday scheduling normalization. ([#8984](https://github.com/AstrBotDevs/AstrBot/pull/8984)) +- Fixed At components being dropped when sending messages on the QQ Official platform. ([#8983](https://github.com/AstrBotDevs/AstrBot/pull/8983)) +- Fixed duplicate captions for quoted images. ([#8718](https://github.com/AstrBotDevs/AstrBot/pull/8718)) +- Fixed Embedding API version suffixes being truncated incorrectly. ([#8736](https://github.com/AstrBotDevs/AstrBot/pull/8736)) +- Deferred FAISS C library imports to avoid startup hangs in some environments. ([#8696](https://github.com/AstrBotDevs/AstrBot/pull/8696)) +- Disposed the database engine on shutdown to reduce resource leftovers in sessions and tests. ([#8650](https://github.com/AstrBotDevs/AstrBot/pull/8650)) +- Fixed the CLI version source. ([#8692](https://github.com/AstrBotDevs/AstrBot/pull/8692)) +- Fixed unnecessary data directory creation when executing the `astrbot` command. ([#8932](https://github.com/AstrBotDevs/AstrBot/pull/8932)) +- Fixed the sdist build artifact path so the Dashboard artifact can be included. ([#8933](https://github.com/AstrBotDevs/AstrBot/pull/8933)) +- Fixed plugin page asset token fallback. ([#8970](https://github.com/AstrBotDevs/AstrBot/pull/8970)) +- Updated the `max_context_length` and `dequeue_context_length` defaults. +- Stabilized FastAPI Dashboard route registration tests for included router nodes. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/ad1b64d12)) +- Stabilized Dashboard route tests for the latest FastAPI behavior. ([commit](https://github.com/AstrBotDevs/AstrBot/commit/a2b6aad84)) diff --git a/pyproject.toml b/pyproject.toml index 17c30e2908..ba21f6e311 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "AstrBot" -version = "4.26.0-beta.12" +version = "4.26.0" description = "Easy-to-use multi-platform LLM chatbot and development framework" readme = "README.md" license = { text = "AGPL-3.0-or-later" } From 33d60c6de0abeed10a43b1428ecf11b33854ce35 Mon Sep 17 00:00:00 2001 From: supomaker Date: Thu, 25 Jun 2026 17:49:07 +0800 Subject: [PATCH 6/7] debug --- astrbot/core/core_lifecycle.py | 4 +- astrbot/core/platform/manager.py | 4 +- .../sources/webchat/webchat_queue_mgr.py | 108 +++++++++- astrbot/core/subagent_orchestrator.py | 91 ++++++-- astrbot/dashboard/api/app.py | 199 +++++++++++++----- astrbot/dashboard/server.py | 4 +- astrbot/dashboard/services/chat_service.py | 2 +- dashboard/src/composables/useMessages.ts | 3 + 8 files changed, 340 insertions(+), 75 deletions(-) diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index c325a2ea38..fc8498af50 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -302,12 +302,12 @@ def _load(self) -> None: for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore - tasks_ = [event_bus_task, *(extra_tasks if extra_tasks else [])] + tasks_ = [event_bus_task, *(extra_tasks if extra_tasks else [])] # [>] if cron_task: tasks_.append(cron_task) if temp_dir_cleaner_task: tasks_.append(temp_dir_cleaner_task) - for task in tasks_: + for task in tasks_: # 为每个任务创建 task self.curr_tasks.append( asyncio.create_task(self._task_wrapper(task), name=task.get_name()), ) diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 22409c0f83..dd7f903a7e 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -47,7 +47,7 @@ def _sanitize_platform_id(self, platform_id: str | None) -> tuple[str | None, bo return sanitized, sanitized != platform_id def _start_platform_task(self, task_name: str, inst: Platform) -> None: - run_task = asyncio.create_task(inst.run(), name=task_name) + run_task = asyncio.create_task(inst.run(), name=task_name) # 调用平台的 run 方法 wrapper_task = asyncio.create_task( self._task_wrapper(run_task, platform=inst), name=f"{task_name}_wrapper", @@ -90,7 +90,7 @@ async def initialize(self) -> None: try: if ensure_platform_webhook_config(platform): self.astrbot_config.save_config() - await self.load_platform(platform) + await self.load_platform(platform) # 里面会初始化platform实例,然后使用self.platform_insts.append()方法添加到self.platform_insts列表中 except Exception as e: logger.error(f"初始化 {platform} 平台适配器失败: {e}") diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index f3ade1589a..9413ccb3c7 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -1,164 +1,266 @@ +# 导入 asyncio 模块,用于异步编程,支持协程、事件循环和异步队列等 import asyncio +# 从 collections.abc 模块导入类型提示 Awaitable 和 Callable,用于类型注解 from collections.abc import Awaitable, Callable +# 从 astrbot 包导入 logger 对象,用于记录日志 from astrbot import logger +# 定义一个 WebChat 队列管理器类,用于管理 WebChat 中的异步消息队列 class WebChatQueueMgr: + # 初始化方法,创建队列管理器实例 def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None: + # 存储对话ID到异步队列的映射,用于存放用户输入消息的队列 self.queues: dict[str, asyncio.Queue] = {} """Conversation ID to asyncio.Queue mapping""" + # 存储请求ID到异步队列的映射,用于存放对应响应的队列(回传队列) self.back_queues: dict[str, asyncio.Queue] = {} """Request ID to asyncio.Queue mapping for responses""" + # 存储对话ID到其所有关联的请求ID集合的映射,用于追踪对话下的活跃请求 self._conversation_back_requests: dict[str, set[str]] = {} + # 存储请求ID到其所属对话ID的映射,用于快速查找请求所属对话 self._request_conversation: dict[str, str] = {} + # 存储对话ID到asyncio.Event的映射,用于通知监听器停止工作 self._queue_close_events: dict[str, asyncio.Event] = {} + # 存储对话ID到asyncio.Task的映射,用于管理每个对话的监听器协程任务 self._listener_tasks: dict[str, asyncio.Task] = {} + # 存储监听器回调函数,当队列有新消息时调用此异步回调处理消息 self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None + # 设置输入队列的最大容量 self.queue_maxsize = queue_maxsize + # 设置回传队列的最大容量 self.back_queue_maxsize = back_queue_maxsize + # 获取或创建指定对话ID的输入队列 def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue: """Get or create a queue for the given conversation ID""" + # 检查该对话ID是否已经有对应的队列 if conversation_id not in self.queues: + # 如果没有,则创建一个新的异步队列,并设置最大容量 self.queues[conversation_id] = asyncio.Queue(maxsize=self.queue_maxsize) + # 同时为该对话创建一个关闭事件,用于后续停止监听器 self._queue_close_events[conversation_id] = asyncio.Event() + # 启动监听器(如果需要且回调已设置) self._start_listener_if_needed(conversation_id) + # 返回该对话ID对应的队列 return self.queues[conversation_id] + # 获取或创建指定请求ID的回传队列(可关联到对话) def get_or_create_back_queue( self, request_id: str, conversation_id: str | None = None, ) -> asyncio.Queue: """Get or create a back queue for the given request ID""" + # 检查该请求ID是否已经有对应的回传队列 if request_id not in self.back_queues: + # 如果没有,则创建一个新的异步回传队列,并设置最大容量 self.back_queues[request_id] = asyncio.Queue( maxsize=self.back_queue_maxsize ) + # 如果提供了对话ID,则进行关联映射 if conversation_id: + # 记录该请求ID属于哪个对话 self._request_conversation[request_id] = conversation_id + # 如果该对话ID在反向请求映射中不存在,则初始化一个空集合 if conversation_id not in self._conversation_back_requests: self._conversation_back_requests[conversation_id] = set() + # 将该请求ID添加到对话的活跃请求集合中 self._conversation_back_requests[conversation_id].add(request_id) + # 返回该请求ID对应的回传队列 return self.back_queues[request_id] + # 移除指定请求ID的回传队列及其关联关系 def remove_back_queue(self, request_id: str): """Remove back queue for the given request ID""" + # 从回传队列字典中删除该请求ID的队列,如果不存在则忽略 self.back_queues.pop(request_id, None) + # 从请求到对话的映射中取出该请求ID对应的对话ID conversation_id = self._request_conversation.pop(request_id, None) + # 如果存在关联的对话ID if conversation_id: + # 获取该对话ID下的所有请求ID集合 request_ids = self._conversation_back_requests.get(conversation_id) + # 如果集合存在 if request_ids is not None: + # 从集合中移除该请求ID request_ids.discard(request_id) + # 如果移除后集合为空,说明该对话没有活跃请求了 if not request_ids: + # 从反向请求映射中删除该对话ID的条目,清理内存 self._conversation_back_requests.pop(conversation_id, None) + # 移除指定对话ID的所有队列,包括其关联的所有回传队列和自身队列 def remove_queues(self, conversation_id: str) -> None: """Remove queues for the given conversation ID""" + # 遍历该对话下所有活跃的请求ID列表的副本,防止在迭代中修改集合 for request_id in list( self._conversation_back_requests.get(conversation_id, set()) ): + # 调用 remove_back_queue 方法移除每个关联的请求回传队列 self.remove_back_queue(request_id) + # 从反向请求映射中删除该对话ID的条目(可能在循环中已被删除,此处确保清理) self._conversation_back_requests.pop(conversation_id, None) + # 调用 remove_queue 方法移除该对话的输入队列和监听器 self.remove_queue(conversation_id) + # 移除指定对话ID的输入队列和监听器 def remove_queue(self, conversation_id: str): """Remove input queue and listener for the given conversation ID""" + # 从队列字典中删除该对话ID的输入队列 self.queues.pop(conversation_id, None) + # 从关闭事件字典中取出并删除该对话ID对应的关闭事件 close_event = self._queue_close_events.pop(conversation_id, None) + # 如果关闭事件存在 if close_event is not None: + # 设置该事件,通知所有等待此事件的监听器停止运行 close_event.set() + # 从监听器任务字典中取出并删除该对话ID对应的监听器协程任务 task = self._listener_tasks.pop(conversation_id, None) + # 如果任务存在且未完成 if task is not None: + # 取消该协程任务的执行 task.cancel() + # 列出指定对话ID下所有活跃的回传请求ID def list_back_request_ids(self, conversation_id: str) -> list[str]: """List active back-queue request IDs for a conversation.""" + # 返回该对话ID对应的请求ID集合的列表形式,如果不存在则返回空列表 return list(self._conversation_back_requests.get(conversation_id, set())) + # 检查指定对话ID是否存在输入队列 def has_queue(self, conversation_id: str) -> bool: """Check if a queue exists for the given conversation ID""" + # 判断该对话ID是否在队列字典的键中 return conversation_id in self.queues + # 设置监听器回调函数,并为所有已有对话启动监听任务 def set_listener( self, callback: Callable[[tuple], Awaitable[None]], ): + # 存储传入的回调函数,用于后续处理队列中的消息 self._listener_callback = callback + # 遍历当前所有已有队列的对话ID列表 for conversation_id in list(self.queues.keys()): - self._start_listener_if_needed(conversation_id) + # 为每个对话启动监听任务(如果尚未启动) + self._start_listener_if_needed(conversation_id) # 启动监听任务 + # 清除监听器,停止所有监听任务并清理相关事件 async def clear_listener(self) -> None: + # 将监听器回调设置为 None,后续消息将不会被处理 self._listener_callback = None + # 获取所有关闭事件的列表副本,并遍历 for close_event in list(self._queue_close_events.values()): + # 设置每个关闭事件,通知监听任务退出 close_event.set() + # 清空关闭事件字典 self._queue_close_events.clear() + # 获取所有监听器任务的列表副本 listener_tasks = list(self._listener_tasks.values()) + # 遍历所有监听器任务 for task in listener_tasks: + # 取消任务 task.cancel() + # 如果存在被取消的任务 if listener_tasks: + # 等待所有任务完成取消或抛出异常,并忽略这些异常 await asyncio.gather(*listener_tasks, return_exceptions=True) + # 清空监听器任务字典 self._listener_tasks.clear() + # 如果需要,为指定对话启动监听器任务(内部方法) def _start_listener_if_needed(self, conversation_id: str): + # 如果监听器回调尚未设置,则直接返回,无法启动监听 if self._listener_callback is None: return + # 检查该对话是否已有监听器任务 if conversation_id in self._listener_tasks: + # 获取已有任务 task = self._listener_tasks[conversation_id] + # 如果任务未完成,说明已在运行,直接返回 if not task.done(): return + # 获取该对话的输入队列 queue = self.queues.get(conversation_id) + # 获取该对话的关闭事件 close_event = self._queue_close_events.get(conversation_id) + # 如果队列或关闭事件不存在,则无法监听,直接返回 if queue is None or close_event is None: return + # 创建一个异步任务来运行监听协程 _listen_to_queue task = asyncio.create_task( - self._listen_to_queue(conversation_id, queue, close_event), + # TODO 监听消息队列 + self._listen_to_queue(conversation_id, queue, close_event), # 监听消息队列 + # 为任务指定一个有意义的名字,方便调试 name=f"webchat_listener_{conversation_id}", ) + # 将创建的任务保存到监听器任务字典中 self._listener_tasks[conversation_id] = task + # 为任务添加一个完成后的回调,用于自动清理 task.add_done_callback( + # 当任务完成(无论成功或异常)时,从字典中移除该任务的条目 lambda _: self._listener_tasks.pop(conversation_id, None) ) + # 记录调试日志,表示监听器已启动 logger.debug(f"Started listener for conversation: {conversation_id}") + # 监听指定对话队列的内部协程,持续从队列中获取消息并交给回调处理 async def _listen_to_queue( self, conversation_id: str, queue: asyncio.Queue, close_event: asyncio.Event, ): + # 无限循环,持续监听队列 while True: + # TODO 创建一个任务,用于从队列中获取消息(这是一个异步操作,会挂起直到有消息),有消息进来时,会返回该消息 get_task = asyncio.create_task(queue.get()) + # 创建一个任务,用于等待关闭事件被触发 close_task = asyncio.create_task(close_event.wait()) try: + # 同时等待 get_task 和 close_task,返回第一个完成的任务集合 done, pending = await asyncio.wait( {get_task, close_task}, return_when=asyncio.FIRST_COMPLETED, ) + # 取消所有未完成的任务,避免资源泄露(例如,如果 get_task 先完成,则取消 close_task) for task in pending: task.cancel() + # 如果关闭事件任务在完成的集合中,表示需要终止监听 if close_task in done: + # 跳出循环,结束协程 break + # 如果 get_task 先完成,则获取队列中的数据 data = get_task.result() + # 在调用回调前,再次检查回调是否还存在(可能在等待期间被清除) if self._listener_callback is None: + # 如果回调不存在,则跳过处理,继续下一次循环 continue try: + # TODO 调用监听器回调函数处理获取到的数据,这是一个异步调用 await self._listener_callback(data) except Exception as e: + # 捕获并记录回调函数中发生的任何异常,避免监听任务崩溃 logger.error( f"Error processing message from conversation {conversation_id}: {e}" ) except asyncio.CancelledError: + # 如果当前协程被取消,则退出循环 break finally: + # 确保在每次循环结束时清理任务,防止资源泄露 + # 如果 get_task 还未完成(例如因取消循环),则取消它 if not get_task.done(): get_task.cancel() + # 如果 close_task 还未完成,则取消它 if not close_task.done(): close_task.cancel() -webchat_queue_mgr = WebChatQueueMgr() +# 创建一个 WebChatQueueMgr 类的全局单例实例 +webchat_queue_mgr = WebChatQueueMgr() \ No newline at end of file diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index c6c595dfc9..a3e34e148d 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -1,104 +1,169 @@ +# 导入未来版本的注解特性,允许在类型注解中使用前向引用 from __future__ import annotations +# 导入深拷贝模块,用于创建对象的完整副本 import copy +# TYPE_CHECKING 用于类型检查时的条件导入,避免运行时循环引用 from typing import TYPE_CHECKING, Any +# 从 astrbot 导入日志记录器 from astrbot import logger +# 导入 Agent 类,表示一个智能代理 from astrbot.core.agent.agent import Agent +# 导入 HandoffTool 类,用于创建代理间的任务移交工具 from astrbot.core.agent.handoff import HandoffTool +# 导入 FunctionToolManager 类,管理所有可用的函数工具 from astrbot.core.provider.func_tool_manager import FunctionToolManager +# 类型检查时的条件导入块 if TYPE_CHECKING: + # 仅在类型检查时导入 PersonaManager,避免运行时依赖 from astrbot.core.persona_mgr import PersonaManager class SubAgentOrchestrator: - """Loads subagent definitions from config and registers handoff tools. - - This is intentionally lightweight: it does not execute agents itself. - Execution happens via HandoffTool in FunctionToolExecutor. + """ + 子代理编排器类 + + 作用:从配置中加载子代理定义,并注册移交工具。 + 设计原则:此类本身不执行代理,执行通过 FunctionToolExecutor 中的 HandoffTool 完成。 + 这是一个轻量级的编排层,负责配置管理和工具注册。 """ def __init__( self, tool_mgr: FunctionToolManager, persona_mgr: PersonaManager ) -> None: + """ + 初始化子代理编排器 + + 参数: + tool_mgr: 函数工具管理器,用于注册和管理工具 + persona_mgr: 人设管理器,管理代理的角色和提示词 + """ + # 保存函数工具管理器实例 self._tool_mgr = tool_mgr + # 保存人设管理器实例 self._persona_mgr = persona_mgr + # 初始化移交工具列表,存储所有已注册的 HandoffTool self.handoffs: list[HandoffTool] = [] async def reload_from_config(self, cfg: dict[str, Any]) -> None: + """ + 从配置字典重新加载子代理配置 + + 此方法会解析配置,创建 Agent 和 HandoffTool,并更新内部状态。 + + 参数: + cfg: 包含子代理配置的字典,应包含 "agents" 键 + """ + # 导入 AstrAgentContext 类,用于 Agent 的类型参数 from astrbot.core.astr_agent_context import AstrAgentContext + # 从配置中获取代理列表,默认为空列表 agents = cfg.get("agents", []) + # 检查代理配置是否为列表类型,不是则记录警告并返回 if not isinstance(agents, list): logger.warning("subagent_orchestrator.agents must be a list") return + # 初始化移交工具列表,用于存储本次加载的工具 handoffs: list[HandoffTool] = [] + # 遍历每个代理配置项 for item in agents: + # 跳过非字典类型的配置项 if not isinstance(item, dict): continue + # 检查代理是否启用(默认启用),未启用则跳过 if not item.get("enabled", True): continue + # 获取代理名称并去除首尾空格 name = str(item.get("name", "")).strip() + # 名称为空则跳过此代理 if not name: continue + # 获取人设 ID persona_id = item.get("persona_id") + # 如果人设 ID 存在,转换为字符串并去除空格,空字符串转为 None if persona_id is not None: persona_id = str(persona_id).strip() or None + # 通过人设 ID 获取人设数据 persona_data = self._persona_mgr.get_persona_v3_by_id(persona_id) + # 如果指定了人设但未找到对应数据,记录警告 if persona_id and persona_data is None: logger.warning( "SubAgent persona %s not found, fallback to inline prompt.", persona_id, ) + # 获取系统提示词(指令),去除首尾空格 instructions = str(item.get("system_prompt", "")).strip() + # 获取公开描述,去除首尾空格 public_description = str(item.get("public_description", "")).strip() + # 获取提供商 ID provider_id = item.get("provider_id") + # 如果提供商 ID 存在,转换为字符串并去除空格,空字符串转为 None if provider_id is not None: provider_id = str(provider_id).strip() or None + # 获取工具列表配置 tools = item.get("tools", []) + # 初始化对话开始数据为 None begin_dialogs = None + # 如果人设数据存在,使用人设数据覆盖配置 if persona_data: + # 获取人设的提示词,去除首尾空格 prompt = str(persona_data.get("prompt", "")).strip() + # 如果提示词不为空,使用人设提示词作为指令 if prompt: instructions = prompt + # 深拷贝处理后的对话开始数据,避免原始数据被修改 begin_dialogs = copy.deepcopy( persona_data.get("_begin_dialogs_processed") ) + # 获取人设的工具配置 tools = persona_data.get("tools") + # 如果公开描述为空且人设提示词存在,使用提示词的前120个字符作为描述 if public_description == "" and prompt: public_description = prompt[:120] + + # 工具配置规范化处理 if tools is None: + # 如果工具为 None,保持 None tools = None elif not isinstance(tools, list): + # 如果工具不是列表类型,设为空列表 tools = [] else: + # 过滤工具列表:转换为字符串,去除空格,并过滤空字符串 tools = [str(t).strip() for t in tools if str(t).strip()] + # 创建 Agent 实例,指定上下文类型为 AstrAgentContext agent = Agent[AstrAgentContext]( - name=name, - instructions=instructions, - tools=tools, # type: ignore + name=name, # 代理名称 + instructions=instructions, # 代理指令/系统提示词 + tools=tools, # type: ignore # 代理可用的工具列表 ) + # 设置代理的对话开始数据 agent.begin_dialogs = begin_dialogs - # The tool description should be a short description for the main LLM, - # while the subagent system prompt can be longer/more specific. + + # 创建移交工具 + # 工具描述是对主 LLM 的简短描述,子代理的系统提示词可以更长更具体 handoff = HandoffTool( - agent=agent, - tool_description=public_description or None, + agent=agent, # 关联的代理实例 + tool_description=public_description or None, # 工具描述(优先使用公开描述) ) - # Optional per-subagent chat provider override. + # 可选的子代理聊天提供商覆盖设置 handoff.provider_id = provider_id + # 将创建的移交工具添加到列表中 handoffs.append(handoff) + # 记录所有已注册的子代理移交工具 for handoff in handoffs: logger.info(f"Registered subagent handoff tool: {handoff.name}") - self.handoffs = handoffs + # 更新实例的移交工具列表为本次加载的结果 + self.handoffs = handoffs \ No newline at end of file diff --git a/astrbot/dashboard/api/app.py b/astrbot/dashboard/api/app.py index 00ceb11cfc..5ed0bf7f72 100644 --- a/astrbot/dashboard/api/app.py +++ b/astrbot/dashboard/api/app.py @@ -1,56 +1,94 @@ +# 启用 Python 的延迟注解求值特性(PEP 563),允许在类型注解中使用尚未定义的类型名称 from __future__ import annotations +# 从 types 模块导入 SimpleNamespace,用于创建简单的命名空间对象,可以用点号访问属性 from types import SimpleNamespace +# 从 fastapi 导入核心类:FastAPI 应用实例、HTTPException 异常类、Request 请求对象 from fastapi import FastAPI, HTTPException, Request +# 从 fastapi.responses 导入 JSONResponse,用于返回 JSON 格式的响应 from fastapi.responses import JSONResponse +# 导入核心模块:日志代理器,用于收集和分发日志消息 from astrbot.core import LogBroker +# 导入核心生命周期管理器,统筹整个 AstrBot 的启动、运行和关闭流程 from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +# 导入数据库抽象基类,定义统一的数据库操作接口 from astrbot.core.db import BaseDatabase +# 导入 API 响应工具:ApiError 自定义异常类和 error 响应构造辅助函数 from astrbot.dashboard.responses import ApiError, error + +# 以下是一系列服务层的导入,每个服务封装了特定业务逻辑,供路由层调用 +# API 密钥管理服务,处理 API 密钥的增删改查 from astrbot.dashboard.services.api_key_service import ApiKeyService +# 认证服务,处理用户登录、JWT 令牌生成与验证 from astrbot.dashboard.services.auth_service import AuthService +# 备份服务,管理系统配置和数据的备份与恢复 from astrbot.dashboard.services.backup_service import BackupService +# 聊天服务,处理聊天消息的收发逻辑 from astrbot.dashboard.services.chat_service import ChatService +# ChatUI 项目服务,管理前端聊天界面的项目配置 from astrbot.dashboard.services.chatui_project_service import ChatUIProjectService +# 命令服务,管理系统中的自定义命令 from astrbot.dashboard.services.command_service import CommandService +# 配置相关服务的分组导入: +# Bot 配置服务,管理机器人实例的配置 from astrbot.dashboard.services.config_service import ( BotConfigService, - ConfigDisplayService, - ConfigFileService, - ConfigProfileService, - ConfigRoutingService, - ProviderConfigService, + ConfigDisplayService, # 配置展示服务,格式化配置信息供前端展示 + ConfigFileService, # 配置文件服务,管理配置文件读写 + ConfigProfileService, # 配置概要服务,管理多套配置方案 + ConfigRoutingService, # 配置路由服务,管理配置与路由的映射 + ProviderConfigService, # 提供者配置服务,管理 LLM 提供商的配置 ) +# 对话服务,管理聊天对话的创建、历史记录等 from astrbot.dashboard.services.conversation_service import ConversationService +# 定时任务服务,管理系统中的 cron 定时任务 from astrbot.dashboard.services.cron_service import CronService +# 文件服务,处理文件上传、下载等操作 from astrbot.dashboard.services.file_service import FileService +# 知识库服务,管理 RAG 知识库的文档和检索 from astrbot.dashboard.services.knowledge_base_service import KnowledgeBaseService +# 实时聊天服务,管理基于 WebSocket 的实时对话 from astrbot.dashboard.services.live_chat_service import LiveChatService +# 日志服务,提供日志查询和流式推送功能 from astrbot.dashboard.services.log_service import LogService +# OpenAPI 服务,管理对外 API 的配置和调用 from astrbot.dashboard.services.open_api_service import OpenApiService +# 人格服务,管理 AI 回复的角色设定(人设) from astrbot.dashboard.services.persona_service import PersonaService +# 平台服务,管理机器人接入的不同消息平台(如 QQ、微信等) from astrbot.dashboard.services.platform_service import PlatformService +# 插件页面服务,管理插件提供的自定义前端页面 from astrbot.dashboard.services.plugin_page_service import PluginPageService +# 插件服务,管理插件的安装、卸载、启用、禁用等生命周期 from astrbot.dashboard.services.plugin_service import PluginService +# 会话管理服务,管理用户与机器人的会话状态 from astrbot.dashboard.services.session_management_service import ( SessionManagementService, ) +# 技能服务,管理 AI 可调用的技能和工具 from astrbot.dashboard.services.skills_service import SkillsService +# 统计服务,提供系统使用数据的统计和汇总 from astrbot.dashboard.services.stat_service import StatService +# 子智能体服务,管理子 AI 代理的配置和运行 from astrbot.dashboard.services.subagent_service import SubAgentService +# 文本转图片服务,管理文本转图片的相关功能 from astrbot.dashboard.services.t2i_service import T2iService +# 工具服务,管理可供 AI 调用的外部工具 from astrbot.dashboard.services.tools_service import ToolsService +# 更新服务,管理系统和控制面板的版本更新 from astrbot.dashboard.services.update_service import ( - DEMO_MODE, - UpdateService, - call_download_dashboard, - call_extract_dashboard, - call_get_dashboard_version, - call_pip_install, + DEMO_MODE, # 演示模式标志,控制是否启用演示限制 + UpdateService, # 更新服务主类 + call_download_dashboard, # 下载控制面板静态文件的函数 + call_extract_dashboard, # 解压控制面板文件的函数 + call_get_dashboard_version, # 获取当前控制面板版本的函数 + call_pip_install, # 执行 pip 安装命令的函数 ) +# 从当前包导入各个路由模块的 legacy 路由器(用于兼容旧版 API 路径) +# 每个 legacy_router 包含了对应模块的旧版路由规则 from .api_keys import legacy_router as legacy_api_keys_router from .auth import legacy_router as legacy_auth_router from .backups import legacy_router as legacy_backups_router @@ -69,9 +107,11 @@ from .platform import legacy_router as legacy_platform_router from .plugins import legacy_router as legacy_plugins_router from .providers import legacy_router as legacy_providers_router +# 从 router 模块导入 API 版本前缀常量和用于构建新 API 路由器的工厂函数 from .router import API_V1_PREFIX, build_api_router from .sessions import legacy_router as legacy_sessions_router from .skills import legacy_router as legacy_skills_router +# 导入静态文件路由器,负责提供控制面板的前端静态资源 from .static_files import router as static_files_router from .stats import legacy_router as legacy_stats_router from .subagents import legacy_router as legacy_subagents_router @@ -79,115 +119,170 @@ from .tools import legacy_router as legacy_tools_router from .updates import legacy_router as legacy_updates_router +# 定义清空站点数据的响应头,用于通知浏览器清除缓存('cache' 表示清除所有缓存数据) CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'} +# 工厂函数:创建并配置整个控制面板的 ASGI 应用实例 def create_dashboard_asgi_app( *, + # 核心生命周期管理器,提供对整个 AstrBot 运行状态的访问和控制 core_lifecycle: AstrBotCoreLifecycle, + # 数据库实例,用于持久化存储各类业务数据 db: BaseDatabase, + # JWT 签名密钥,用于生成和验证认证令牌 jwt_secret: str, + # 控制面板静态文件夹的路径,如果不提供则不启用静态文件服务 static_folder: str | None = None, ) -> FastAPI: + # 创建 FastAPI 应用实例,并配置 OpenAPI 文档相关参数 app = FastAPI( - title="AstrBot OpenAPI", - version="1.0.0", - openapi_url=f"{API_V1_PREFIX}/openapi.json", - docs_url=f"{API_V1_PREFIX}/docs", - redoc_url=f"{API_V1_PREFIX}/redoc", + title="AstrBot OpenAPI", # API 文档标题 + version="1.0.0", # API 版本号 + openapi_url=f"{API_V1_PREFIX}/openapi.json", # OpenAPI JSON 规范文档的访问地址 + docs_url=f"{API_V1_PREFIX}/docs", # Swagger UI 文档页面的访问地址 + redoc_url=f"{API_V1_PREFIX}/redoc", # ReDoc 文档页面的访问地址 ) + # 将核心组件挂载到 app.state 上,供所有路由和服务通过 request.app.state 访问 + # 存储核心生命周期管理器 app.state.core_lifecycle = core_lifecycle + # 存储数据库实例 app.state.db = db + # 存储 JWT 密钥,供认证中间件和路由使用 app.state.jwt_secret = jwt_secret + # 存储静态文件夹路径,供静态文件路由器使用 app.state.dashboard_static_folder = static_folder + # 获取或创建日志代理器实例,用于日志的收集和分发 log_broker = getattr(core_lifecycle, "log_broker", None) or LogBroker() + # 使用 SimpleNamespace 创建服务容器,将所有业务服务以属性方式组织在一起 app.state.services = SimpleNamespace( + # 配置方案服务:管理多套配置的切换和保存 config_profiles=ConfigProfileService(core_lifecycle, db), + # 配置展示服务:格式化配置数据供前端展示 config_display=ConfigDisplayService(core_lifecycle), + # 配置文件服务:处理配置文件(如 YAML/JSON)的读写操作 config_files=ConfigFileService(core_lifecycle), + # 配置路由服务:管理消息路由规则的配置 config_routes=ConfigRoutingService(core_lifecycle), + # API 密钥服务:管理用于外部调用的 API Key api_keys=ApiKeyService(db), + # 认证服务:处理用户登录和 JWT 令牌 auth=AuthService(db, core_lifecycle.astrbot_config), + # 备份服务:管理系统数据的备份与恢复 backups=BackupService(db, core_lifecycle), + # 聊天服务:处理 AI 对话的核心逻辑 chat=ChatService(db, core_lifecycle), + # ChatUI 项目服务:管理前端聊天界面的项目配置 chat_projects=ChatUIProjectService(db), + # 命令服务:管理自定义命令的注册和执行 commands=CommandService(core_lifecycle.astrbot_config, core_lifecycle), + # 对话服务:管理对话会话的创建和历史 conversations=ConversationService(db, core_lifecycle), + # 定时任务服务:管理 cron 定时任务的增删改 cron=CronService(core_lifecycle), + # 文件服务:处理文件的上传和下载 files=FileService(), + # 知识库服务:管理知识库文档和向量检索 knowledge_bases=KnowledgeBaseService(core_lifecycle), + # 实时聊天服务:管理 WebSocket 连接和实时消息推送 live_chat=LiveChatService(db, core_lifecycle), + # 日志服务:提供日志查询和流式推送 logs=LogService(log_broker, core_lifecycle.astrbot_config), + # Bot 配置服务:管理机器人实例的配置参数 bots=BotConfigService(core_lifecycle), + # 平台服务:管理接入的消息平台 platforms=PlatformService(core_lifecycle), + # 提供者配置服务:管理 LLM 提供商的 API 配置 providers=ProviderConfigService(core_lifecycle), + # 人格服务:管理 AI 回复的人设和角色 personas=PersonaService(core_lifecycle), + # 插件服务:管理插件的生命周期 plugins=PluginService(core_lifecycle, core_lifecycle.plugin_manager), + # 插件页面服务:管理插件提供的自定义网页 plugin_pages=PluginPageService( core_lifecycle.plugin_manager, - core_lifecycle=core_lifecycle, + core_lifecycle=core_lifecycle, # 显式传递核心生命周期引用 ), + # OpenAPI 服务:管理对外 API 的配置 open_api=OpenApiService(db, core_lifecycle), + # 会话管理服务:管理用户与机器人的交互会话 sessions=SessionManagementService(core_lifecycle, db), + # 技能服务:管理 AI 可调用的技能 skills=SkillsService(core_lifecycle), + # 统计服务:收集和汇总系统运行数据 stats=StatService(db, core_lifecycle, core_lifecycle.astrbot_config), + # 子智能体服务:管理子 AI 代理 subagents=SubAgentService(core_lifecycle), + # 文本转图片服务:管理文字转图片的渲染 t2i=T2iService(core_lifecycle), + # 工具服务:管理外部工具的定义和调用 tools=ToolsService(core_lifecycle), + # 更新服务:管理系统版本更新 updates=UpdateService( - core_lifecycle.astrbot_updator, - core_lifecycle, - download_dashboard_func=call_download_dashboard, - extract_dashboard_func=call_extract_dashboard, - get_dashboard_version_func=call_get_dashboard_version, - pip_install_func=call_pip_install, - demo_mode=DEMO_MODE, - clear_site_data_headers=CLEAR_SITE_DATA_HEADERS, + core_lifecycle.astrbot_updator, # 更新器实例 + core_lifecycle, # 核心生命周期引用 + download_dashboard_func=call_download_dashboard, # 下载控制面板的回调函数 + extract_dashboard_func=call_extract_dashboard, # 解压控制面板的回调函数 + get_dashboard_version_func=call_get_dashboard_version, # 获取版本的回调函数 + pip_install_func=call_pip_install, # pip 安装的回调函数 + demo_mode=DEMO_MODE, # 是否为演示模式 + clear_site_data_headers=CLEAR_SITE_DATA_HEADERS, # 清除缓存的响应头 ), ) + # 注册全局异常处理器:捕获并处理 ApiError 自定义异常 @app.exception_handler(ApiError) async def api_error_handler(_request: Request, exc: ApiError): + # 返回 JSON 格式的错误响应,状态码由异常对象提供 return JSONResponse( - error(exc.message, exc.data), + error(exc.message, exc.data), # 使用 error 辅助函数构造标准错误响应体 status_code=exc.status_code, ) + # 注册全局异常处理器:捕获并处理 ValueError 异常(通常表示参数验证失败) @app.exception_handler(ValueError) async def value_error_handler(_request: Request, exc: ValueError): + # 返回 400 Bad Request 错误,消息为异常的描述信息 return JSONResponse(error(str(exc)), status_code=400) + # 注册全局异常处理器:捕获并处理 FastAPI 内置的 HTTPException 异常 @app.exception_handler(HTTPException) async def http_error_handler(_request: Request, exc: HTTPException): + # 提取异常中的详细信息,如果是字符串则直接使用,否则使用默认消息 detail = exc.detail if isinstance(exc.detail, str) else "Request failed" + # 返回相应状态码的 JSON 错误响应 return JSONResponse(error(detail), status_code=exc.status_code) # Legacy dashboard routes keep old /api/* callers working without entering OpenAPI. - app.include_router(legacy_api_keys_router) - app.include_router(legacy_auth_router) - app.include_router(legacy_backups_router) - app.include_router(legacy_config_profiles_router) - app.include_router(legacy_bots_router) - app.include_router(legacy_providers_router) - app.include_router(legacy_chat_router) - app.include_router(legacy_chat_projects_router) - app.include_router(legacy_conversations_router) - app.include_router(legacy_cron_router) - app.include_router(legacy_extensions_router) - app.include_router(legacy_files_router) - app.include_router(legacy_knowledge_bases_router) - app.include_router(legacy_live_chat_router) - app.include_router(legacy_logs_router) - app.include_router(legacy_sessions_router) - app.include_router(legacy_skills_router) - app.include_router(legacy_stats_router) - app.include_router(legacy_subagents_router) - app.include_router(legacy_tools_router) - app.include_router(legacy_platform_router) - app.include_router(legacy_plugins_router) - app.include_router(legacy_t2i_router) - app.include_router(legacy_personas_router) - app.include_router(legacy_updates_router) + # 将各个旧版路由器注册到应用中,保持向后兼容,使得旧版 /api/* 路径请求仍然有效 + app.include_router(legacy_api_keys_router) # API 密钥管理相关路由 + app.include_router(legacy_auth_router) # 认证相关路由 + app.include_router(legacy_backups_router) # 备份相关路由 + app.include_router(legacy_config_profiles_router) # 配置方案相关路由 + app.include_router(legacy_bots_router) # 机器人配置相关路由 + app.include_router(legacy_providers_router) # 提供者配置相关路由 + app.include_router(legacy_chat_router) # 聊天相关路由 + app.include_router(legacy_chat_projects_router) # ChatUI 项目相关路由 + app.include_router(legacy_conversations_router) # 对话管理相关路由 + app.include_router(legacy_cron_router) # 定时任务相关路由 + app.include_router(legacy_extensions_router) # 扩展相关路由 + app.include_router(legacy_files_router) # 文件管理相关路由 + app.include_router(legacy_knowledge_bases_router) # 知识库相关路由 + app.include_router(legacy_live_chat_router) # 实时聊天相关路由 + app.include_router(legacy_logs_router) # 日志相关路由 + app.include_router(legacy_sessions_router) # 会话管理相关路由 + app.include_router(legacy_skills_router) # 技能相关路由 + app.include_router(legacy_stats_router) # 统计相关路由 + app.include_router(legacy_subagents_router) # 子智能体相关路由 + app.include_router(legacy_tools_router) # 工具相关路由 + app.include_router(legacy_platform_router) # 平台相关路由 + app.include_router(legacy_plugins_router) # 插件相关路由 + app.include_router(legacy_t2i_router) # 文本转图片相关路由 + app.include_router(legacy_personas_router) # 人格管理相关路由 + app.include_router(legacy_updates_router) # 系统更新相关路由 + # 注册新版本 API 路由器(基于 OpenAPI 规范的标准化路由) app.include_router(build_api_router()) + # 注册静态文件路由器,负责提供控制面板的前端页面和资源 app.include_router(static_files_router) - return app + # 返回配置完成的 FastAPI 应用实例 + return app \ No newline at end of file diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 6776927ba0..34a8dd5243 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -176,7 +176,7 @@ def __init__( core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase, shutdown_event: asyncio.Event, - webui_dir: str | None = None, + webui_dir: str | None = None, # 项目目录的 AstrBot\\data\\dist 路径 ) -> None: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config @@ -225,7 +225,7 @@ def __init__( self._rate_limiter_registry = _RateLimiterRegistry() self._init_jwt_secret() - self.asgi_app = create_dashboard_asgi_app( + self.asgi_app = create_dashboard_asgi_app( # 启动dashboard页面 core_lifecycle=core_lifecycle, db=db, jwt_secret=self._jwt_secret, diff --git a/astrbot/dashboard/services/chat_service.py b/astrbot/dashboard/services/chat_service.py index 80bf6bbf3b..7ff6d67f5b 100644 --- a/astrbot/dashboard/services/chat_service.py +++ b/astrbot/dashboard/services/chat_service.py @@ -938,7 +938,7 @@ def build_attachment_saved_event(part: dict | None) -> str | None: webchat_queue_mgr.remove_back_queue(message_id) chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) - await chat_queue.put( + await chat_queue.put( # TODO 这里将信息放入queue ( username, webchat_conv_id, diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts index 59d591af20..1d1dd42cd2 100644 --- a/dashboard/src/composables/useMessages.ts +++ b/dashboard/src/composables/useMessages.ts @@ -302,6 +302,7 @@ export function useMessages(options: UseMessagesOptions) { ); return; } + // TODO 使用startSseStream startSseStream( sessionId, messageId, @@ -514,6 +515,7 @@ export function useMessages(options: UseMessagesOptions) { skipUserHistory = false, llmCheckpointId: string | null = null, ) { + console.log("使用startSseStream", chatApi.sendStreamUrl()) const abort = new AbortController(); activeConnections[sessionId] = { sessionId, @@ -568,6 +570,7 @@ export function useMessages(options: UseMessagesOptions) { selectedProvider: string, selectedModel: string, ) { + console.log("使用starWeb") const ws = getOrCreateChatWebSocket(sessionId); activeConnections[sessionId] = { From 5161eb45b139fa5484f8247bde31a8e62082cb5e Mon Sep 17 00:00:00 2001 From: supoMaker Date: Thu, 25 Jun 2026 23:08:21 +0800 Subject: [PATCH 7/7] debug --- .../agent/runners/tool_loop_agent_runner.py | 2 +- astrbot/core/event_bus.py | 2 +- .../core/pipeline/preprocess_stage/stage.py | 262 ++++++++++--- .../method/agent_sub_stages/internal.py | 4 +- astrbot/core/pipeline/process_stage/stage.py | 2 +- astrbot/core/pipeline/respond/stage.py | 6 +- astrbot/core/pipeline/waking_check/stage.py | 241 +++++++++--- astrbot/core/platform/platform.py | 2 +- .../sources/webchat/webchat_adapter.py | 8 +- .../platform/sources/webchat/webchat_event.py | 2 +- .../sources/webchat/webchat_queue_mgr.py | 6 +- astrbot/dashboard/api/open_api.py | 7 +- astrbot/dashboard/server.py | 362 +++++++++++++++--- astrbot/dashboard/services/chat_service.py | 28 +- 14 files changed, 738 insertions(+), 196 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index b56d7e62fb..74a24152f2 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -512,7 +512,7 @@ async def _iter_llm_responses_with_fallback( has_stream_output = False with attempt: try: - async for resp in self._iter_llm_responses( + async for resp in self._iter_llm_responses( # TODO 调用llm模型,获取流式响应 include_model=idx == 0 ): if resp.is_chunk: diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index c2a96273ec..2b9537a95e 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -38,7 +38,7 @@ def __init__( async def dispatch(self) -> None: while True: - event: AstrMessageEvent = await self.event_queue.get() + event: AstrMessageEvent = await self.event_queue.get() # 这里会监控 event_queue,监听新信息 conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) conf_id = conf_info["id"] conf_name = conf_info.get("name") or conf_id diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index c284600a62..5cd951d307 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -1,161 +1,281 @@ -import asyncio -import random -import traceback -from collections.abc import AsyncGenerator -from pathlib import Path - -from astrbot.core import logger -from astrbot.core.message.components import Image, Plain, Record, Reply -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path -from astrbot.core.utils.media_utils import ( - describe_media_ref, - ensure_jpeg, - ensure_wav, - file_uri_to_path, - is_file_uri, +# 导入所需的库和模块 +import asyncio # 异步IO支持,用于处理异步操作 +import random # 随机数生成,用于随机选择表情 +import traceback # 异常追踪,用于格式化异常信息 +from collections.abc import AsyncGenerator # 异步生成器类型注解 +from pathlib import Path # 文件路径处理 + +# 从AstrBot核心模块导入 +from astrbot.core import logger # 日志记录器 +from astrbot.core.message.components import Image, Plain, Record, Reply # 消息组件类型 +from astrbot.core.platform.astr_message_event import AstrMessageEvent # 消息事件基类 +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path # 获取临时文件路径 +from astrbot.core.utils.media_utils import ( # 多媒体工具函数 + describe_media_ref, # 描述媒体引用(用于日志) + ensure_jpeg, # 确保图像为JPEG格式 + ensure_wav, # 确保音频为WAV格式 + file_uri_to_path, # 文件URI转本地路径 + is_file_uri, # 判断是否为文件URI ) -from ..context import PipelineContext -from ..stage import Stage, register_stage +# 导入管道相关的基类和注册装饰器 +from ..context import PipelineContext # 管道上下文 +from ..stage import Stage, register_stage # 管道阶段基类和注册装饰器 -@register_stage +@register_stage # 注册为管道阶段,使该阶段能被管道自动发现和调用 class PreProcessStage(Stage): + """ + 消息预处理阶段 + + 功能:在消息进入核心处理之前进行必要的预处理工作,包括: + 1. 发送预回应表情(如Telegram的表情回应) + 2. 路径映射处理(支持不同环境的文件路径转换) + 3. 媒体文件格式标准化(音频转WAV、图像转JPEG) + 4. 语音转文本处理(STT - Speech to Text) + 5. 临时文件生命周期管理 + + 该阶段确保下游处理器接收到的消息都是标准化、可处理的形式 + """ + async def initialize(self, ctx: PipelineContext) -> None: + """ + 初始化预处理阶段,加载必要的配置 + + 功能:在管道启动时进行初始化,从上下文配置中加载各项设置 + 包括平台设置、STT设置等 + + Args: + ctx (PipelineContext): 管道上下文对象,包含全局配置和插件管理器 + """ + # 保存管道上下文的引用 self.ctx = ctx + # 保存全局配置的引用 self.config = ctx.astrbot_config + # 保存插件管理器的引用 self.plugin_manager = ctx.plugin_manager + # 获取语音转文本(STT)的配置设置,默认为空字典 self.stt_settings: dict = self.config.get("provider_stt_settings", {}) + # 获取平台通用设置,默认为空字典 self.platform_settings: dict = self.config.get("platform_settings", {}) @staticmethod def _track_temp_media(event: AstrMessageEvent, media_path: str) -> None: - """Track a media file owned by the current event. - + """ + 追踪事件拥有的临时媒体文件 + + 功能:管理临时媒体文件的生命周期,确保当事件结束时临时文件能被正确清理 + 只追踪位于AstrBot临时目录下的文件,避免意外删除用户文件 + + 工作原理: + 1. 解析媒体文件的绝对路径 + 2. 检查文件是否在AstrBot临时目录下 + 3. 如果是临时文件,注册到事件的生命周期管理中 + Args: - event: Message event whose lifecycle owns the temporary file. - media_path: Local media path to track when it lives under AstrBot temp. + event (AstrMessageEvent): 拥有该临时文件的消息事件 + media_path (str): 需要追踪的本地媒体文件路径 """ - try: + # 将媒体路径解析为绝对路径 path = Path(media_path).resolve() + # 获取AstrBot临时目录的绝对路径 temp_dir = Path(get_astrbot_temp_path()).resolve() + # 检查媒体文件是否位于临时目录下 + # relative_to 会抛出异常如果路径不在临时目录下 path.relative_to(temp_dir) except (OSError, ValueError): + # 如果文件不在临时目录下或路径无效,不进行追踪,直接返回 return + # 将临时文件注册到事件中,事件结束时会自动清理 event.track_temporary_local_file(str(path)) async def process( self, event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - """在处理事件之前的预处理""" - # 平台特异配置:platform_specific..pre_ack_emoji + """ + 处理消息事件的预处理流程 + + 功能:执行完整的消息预处理流程,包括: + 1. 平台特定功能(如预回应表情) + 2. 路径映射转换 + 3. 媒体文件格式标准化 + 4. 回复链中的媒体处理 + 5. 语音转文本转换 + + Args: + event (AstrMessageEvent): 需要预处理的消息事件 + + Returns: + None | AsyncGenerator[None, None]: 通常返回None,异步生成器用于流式处理 + """ + + # ===== 第一步:平台特定的预回应表情处理 ===== + # 定义支持预回应表情的平台列表 supported = {"telegram", "lark", "discord"} + # 获取当前消息来自的平台名称 platform = event.get_platform_name() + + # 从配置中获取平台特定的预回应表情设置 + # 配置路径:platform_specific..pre_ack_emoji cfg = ( - self.config.get("platform_specific", {}) - .get(platform, {}) - .get("pre_ack_emoji", {}) - ) or {} + self.config.get("platform_specific", {}) # 获取平台特定配置 + .get(platform, {}) # 获取当前平台的配置 + .get("pre_ack_emoji", {}) # 获取预回应表情配置 + ) or {} # 如果为None则使用空字典 + # 获取可用的表情列表 emojis = cfg.get("emojis") or [] + + # 检查是否需要发送预回应表情 if ( - cfg.get("enable", False) - and platform in supported - and emojis - and event.is_at_or_wake_command + cfg.get("enable", False) # 功能已启用 + and platform in supported # 当前平台支持 + and emojis # 有可用的表情列表 + and event.is_at_or_wake_command # 是唤醒命令或@消息 ): try: + # 随机选择一个表情并发送给消息作为回应 await event.react(random.choice(emojis)) except Exception as e: + # 表情发送失败时记录警告日志 logger.warning(f"{platform} 预回应表情发送失败: {e}") - # 路径映射 + # ===== 第二步:路径映射处理 ===== + # 检查是否配置了路径映射规则 if mappings := self.platform_settings.get("path_mapping", []): - # 支持 Record,Image 消息段的路径映射。 + # 获取消息的所有消息组件 message_chain = event.get_messages() + # 遍历每个消息组件 for idx, component in enumerate(message_chain): + # 只处理Record或Image类型的组件,且必须有URL if isinstance(component, Record | Image) and component.url: + # 遍历所有映射规则 for mapping in mappings: + # 解析映射规则:格式为 "原始路径:目标路径" from_, to_ = mapping.split(":") + # 去除路径末尾的斜杠,统一格式 from_ = from_.removesuffix("/") to_ = to_.removesuffix("/") + # 获取URL的实际路径(如果是文件URI则转换) url = ( - file_uri_to_path(component.url) - if is_file_uri(component.url) - else component.url + file_uri_to_path(component.url) # 文件URI转路径 + if is_file_uri(component.url) # 判断是否为文件URI + else component.url # 不是URI则保持原样 ) + # 如果URL以映射的源路径开头 if url.startswith(from_): + # 执行路径替换映射 component.url = url.replace(from_, to_, 1) + # 记录映射的调试信息 logger.debug(f"路径映射: {url} -> {component.url}") + # 更新消息链中的组件 message_chain[idx] = component - # Normalize provider-facing media early so downstream code sees local files. + # ===== 第三步:媒体文件格式标准化 ===== + # 获取消息链(可能在第二步中被修改) message_chain = event.get_messages() + + # 遍历所有消息组件进行标准化处理 for idx, component in enumerate(message_chain): + # 处理音频组件(Record) if isinstance(component, Record): try: + # 将音频组件转换为本地文件路径 original_path = await component.convert_to_file_path() + # 追踪原始音频文件的临时文件 self._track_temp_media(event, original_path) + # 确保音频格式为WAV(统一的音频格式) record_path = await ensure_wav(original_path) + # 追踪转换后的WAV文件 self._track_temp_media(event, record_path) + # 更新组件的文件路径属性 component.file = record_path component.path = record_path + # 更新消息链中的组件 message_chain[idx] = component except Exception as e: + # 音频处理失败时记录警告 logger.warning(f"Voice processing failed: {e}") + + # 处理图像组件(Image) elif isinstance(component, Image): try: + # 将图像组件转换为本地文件路径 original_path = await component.convert_to_file_path() + # 追踪原始图像文件的临时文件 self._track_temp_media(event, original_path) + # 确保图像格式为JPEG(统一的图像格式) image_path = await ensure_jpeg(original_path) + # 追踪转换后的JPEG文件 self._track_temp_media(event, image_path) + # 更新组件的文件路径和URL属性 component.file = image_path component.path = image_path - # Image.convert_to_file_path() prefers url, so keep it aligned. + # Image.convert_to_file_path() 方法优先使用url属性,所以保持url同步 component.url = image_path + # 更新消息链中的组件 message_chain[idx] = component except Exception as e: + # 获取媒体引用描述用于日志 media_ref = component.url or component.file logger.warning( "Image processing failed for %s: %s", - describe_media_ref(media_ref), + describe_media_ref(media_ref), # 描述媒体来源 e, ) - # Also normalize media components inside Reply chains. + # ===== 第四步:处理回复链中的媒体组件 ===== + # 遍历所有消息组件,检查是否包含回复(Reply) for component in event.get_messages(): + # 如果是回复消息且有回复的消息链 if isinstance(component, Reply) and component.chain: + # 遍历回复消息链中的每个组件 for idx, reply_comp in enumerate(component.chain): + # 处理回复中的音频组件 if isinstance(reply_comp, Record): try: + # 转换为本地文件路径 original_path = await reply_comp.convert_to_file_path() + # 追踪临时文件 self._track_temp_media(event, original_path) + # 转换为WAV格式 record_path = await ensure_wav(original_path) + # 追踪转换后的文件 self._track_temp_media(event, record_path) + # 更新组件属性 reply_comp.file = record_path reply_comp.path = record_path + # 更新回复链中的组件 component.chain[idx] = reply_comp except Exception as e: + # 回复链中的音频处理失败 logger.warning( f"Voice processing in reply chain failed: {e}" ) + # 处理回复中的图像组件 elif isinstance(reply_comp, Image): try: + # 转换为本地文件路径 original_path = await reply_comp.convert_to_file_path() + # 追踪临时文件 self._track_temp_media(event, original_path) + # 转换为JPEG格式 image_path = await ensure_jpeg(original_path) + # 追踪转换后的文件 self._track_temp_media(event, image_path) + # 更新组件属性 reply_comp.file = image_path reply_comp.path = image_path - # Image.convert_to_file_path() prefers url, so keep it aligned. + # 保持url与文件路径同步 reply_comp.url = image_path + # 更新回复链中的组件 component.chain[idx] = reply_comp except Exception as e: + # 获取媒体引用描述 media_ref = reply_comp.url or reply_comp.file logger.warning( "Image processing in reply chain failed for %s: %s", @@ -163,63 +283,103 @@ async def process( e, ) - # STT + # ===== 第五步:语音转文本处理(STT) ===== + # 检查是否启用了语音转文本功能 if self.stt_settings.get("enable", False): - # TODO: 独立 + # 获取插件管理器上下文 ctx = self.plugin_manager.context + # 获取当前会话的STT提供者 stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin) + + # 如果没有配置STT提供者,记录警告并返回 if not stt_provider: logger.warning( f"会话 {event.unified_msg_origin} 未配置语音转文本模型。", ) return + # 定义内部异步函数:处理单个音频组件的语音转文本 async def _stt_record(record_comp: Record, is_reply: bool = False): - """对单个 Record 组件执行语音转文本,成功返回 Plain,失败返回 None。""" + """ + 对单个音频组件执行语音转文本转换 + + 功能:将音频组件转换为文本,支持重试机制 + 因为某些平台(如napcat)的文件可能不会立即就绪 + + Args: + record_comp (Record): 需要转换的音频组件 + is_reply (bool): 是否为回复消息中的音频 + + Returns: + Plain | None: 成功返回包含文本的Plain组件,失败返回None + """ + # 根据是否为回复消息设置前缀文本 prefix = "引用消息" if is_reply else "" try: + # 将音频组件转换为本地文件路径 path = await record_comp.convert_to_file_path() except Exception as e: + # 获取音频路径失败 logger.warning(f"获取{prefix}语音路径失败: {e}") return None + # 设置重试次数为5次 retry = 5 + # 重试循环 for i in range(retry): try: + # 调用STT提供者进行语音转文本 result = await stt_provider.get_text(audio_url=path) if result: + # 转文本成功,添加标记后缀 suffix = "(引用消息)" if is_reply else "" logger.info(f"语音转文本{suffix}结果: " + result) + # 返回包含文本的Plain组件 return Plain(result) + # 如果结果为空,跳出重试循环 break except FileNotFoundError: - # napcat workaround: file may not be ready immediately + # 文件未找到的特殊处理(napcat平台的已知问题) + # 文件可能不会立即准备就绪,等待后重试 logger.debug(f"文件尚未就绪 ({path}),重试 {i + 1}/{retry}") - await asyncio.sleep(0.5) + await asyncio.sleep(0.5) # 等待0.5秒后重试 continue except BaseException as e: + # 其他异常,记录错误并停止重试 logger.error(traceback.format_exc()) suffix = "(引用消息)" if is_reply else "" logger.error(f"语音转文本{suffix}失败: {e}") break + # 所有重试都失败,返回None return None + # 处理当前消息中的音频组件 message_chain = event.get_messages() for idx, component in enumerate(message_chain): + # 如果是音频组件 if isinstance(component, Record): + # 执行语音转文本 plain_comp = await _stt_record(component) if plain_comp: + # 转换成功,替换音频组件为文本组件 message_chain[idx] = plain_comp + # 将识别出的文本添加到消息字符串中 event.message_str += plain_comp.text event.message_obj.message_str += plain_comp.text - # Also STT for Record components inside Reply chains + # 处理回复消息链中的音频组件 for component in event.get_messages(): + # 如果是回复消息且有回复链 if isinstance(component, Reply) and component.chain: + # 遍历回复链中的组件 for idx, reply_comp in enumerate(component.chain): + # 如果是音频组件 if isinstance(reply_comp, Record): + # 执行语音转文本(标记为回复消息) plain_comp = await _stt_record(reply_comp, is_reply=True) if plain_comp: + # 替换音频组件为文本组件 component.chain[idx] = plain_comp + # 将文本添加到消息字符串中 event.message_str += plain_comp.text - event.message_obj.message_str += plain_comp.text + event.message_obj.message_str += plain_comp.text \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 0b636b5b2b..91101ef93e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -222,7 +222,7 @@ async def process( provider_wake_prefix=provider_wake_prefix, streaming_response=streaming_response, ) - + # TODO 构建主代理 build_result: MainAgentBuildResult | None = await build_main_agent( event=event, plugin_context=self.ctx.plugin_manager.context, @@ -334,7 +334,7 @@ async def process( ) elif streaming_response and not stream_to_general: - # 流式响应 + # TODO 流式响应,这里设置 set_async_stream(run_agent()),其里面调用llm处理请求 event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 076f7f12ac..bb3a95c5bf 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -62,5 +62,5 @@ async def process( if ( event.get_result() and not event.is_stopped() ) or not event.get_result(): - async for _ in self.agent_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): # 调用LLM处理请求 yield diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 0145456908..67f51fbc01 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -170,7 +170,7 @@ async def process( self, event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - result = event.get_result() + result = event.get_result() # 这里获取比如astrbot\core\pipeline\process_stage\method\agent_sub_stages\internal.py中的InternalAgentSubStage类在process()方法中使用event.set_result()设置的内容,其中包含run_agent()方法 if result is None: return if event.get_extra("_streaming_finished", False): @@ -184,7 +184,7 @@ async def process( f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}", ) - if result.result_content_type == ResultContentType.STREAMING_RESULT: + if result.result_content_type == ResultContentType.STREAMING_RESULT: # 流式结果 if result.async_stream is None: logger.warning("async_stream 为空,跳过发送。") return @@ -197,7 +197,7 @@ async def process( == "realtime_segmenting" ) logger.info(f"应用流式输出({event.get_platform_id()})") - await event.send_streaming(result.async_stream, realtime_segmenting) + await event.send_streaming(result.async_stream, realtime_segmenting) # TODO 这里调用 run_agent()方法,比如astrbot\core\pipeline\process_stage\method\agent_sub_stages\internal.py中的InternalAgentSubStage类在process()方法中使用event.set_result()设置的内容,其中包含run_agent()方法 return if len(result.chain) > 0: # 检查路径映射 diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index ddc2a6cb83..1fb0322e18 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -1,3 +1,4 @@ +# 导入必要的类型和模块 from collections.abc import AsyncGenerator, Callable from astrbot import logger @@ -14,225 +15,347 @@ from ..context import PipelineContext from ..stage import Stage, register_stage +# 不同平台的唯一会话ID构建器映射表 +# 用于在多平台环境下为每个用户-群组组合创建唯一的会话标识 +# 键:平台名称,值:lambda函数,接收事件对象,返回唯一会话ID字符串或None UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = { - "aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", - "slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", - "dingtalk": lambda e: e.get_sender_id(), - "qq_official": lambda e: e.get_sender_id(), - "qq_official_webhook": lambda e: e.get_sender_id(), - "lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}", - "misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}", - "matrix": lambda e: f"{e.get_sender_id()}_{e.get_group_id() or e.get_session_id()}", + "aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", # QQ平台:发送者ID_群组ID + "slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", # Slack平台:发送者ID_群组ID + "dingtalk": lambda e: e.get_sender_id(), # 钉钉平台:仅使用发送者ID + "qq_official": lambda e: e.get_sender_id(), # QQ官方平台:仅使用发送者ID + "qq_official_webhook": lambda e: e.get_sender_id(), # QQ官方Webhook:仅使用发送者ID + "lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}", # 飞书平台:发送者ID%群组ID + "misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}", # Misskey平台:会话ID_发送者ID + "matrix": lambda e: f"{e.get_sender_id()}_{e.get_group_id() or e.get_session_id()}", # Matrix平台:发送者ID_群组ID或会话ID } def build_unique_session_id(event: AstrMessageEvent) -> str | None: + """ + 根据事件所属平台构建唯一的会话ID + + 功能:从不同平台的消息事件中提取信息,构建唯一的会话标识符 + 用于实现跨平台的统一会话管理 + + Args: + event: Astrbot消息事件对象 + + Returns: + str | None: 成功构建的会话ID字符串,如果平台不支持则返回None + """ + # 获取当前消息所属平台名称 platform = event.get_platform_name() + # 根据平台名称查找对应的构建器函数 builder = UNIQUE_SESSION_ID_BUILDERS.get(platform) + # 如果找到构建器就调用它,否则返回None return builder(event) if builder else None -@register_stage +@register_stage # 注册为管道阶段,使其能被管道自动调用 class WakingCheckStage(Stage): - """检查是否需要唤醒。唤醒机器人有如下几点条件: + """ + 消息管道唤醒检查阶段 + + 功能:判断机器人是否应该被唤醒并响应消息,是消息处理管道的早期阶段 + 检查是否需要唤醒。唤醒机器人有如下几点条件: - 1. 机器人被 @ 了 - 2. 机器人的消息被提到了 + 1. 机器人被 @ 了(群聊中被@) + 2. 机器人的消息被提到了(使用回复功能) 3. 以 wake_prefix 前缀开头,并且消息没有以 At 消息段开头 - 4. 插件(Star)的 handler filter 通过 + 4. 插件(Star)的 handler filter 通过(插件自定义的过滤条件) 5. 私聊情况下,位于 admins_id 列表中的管理员的消息(在白名单阶段中) + + 该阶段负责:唯一会话设置、机器人自身消息过滤、唤醒判断、权限检查、插件处理器激活 """ async def initialize(self, ctx: PipelineContext) -> None: - """初始化唤醒检查阶段 - + """ + 初始化唤醒检查阶段,从配置中加载各种设置参数 + + 功能:在管道启动时进行一次性的初始化配置加载 + 从上下文配置中读取各种与唤醒相关的设置项 + Args: - ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 - + ctx (PipelineContext): 消息管道上下文对象, 包括全局配置和插件管理器 """ + # 保存上下文对象的引用 self.ctx = ctx + # 获取无权限回复设置:当用户无权限时是否发送提示消息 self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get( "no_permission_reply", True, ) - # 私聊是否需要 wake_prefix 才能唤醒机器人 + # 获取私聊消息是否需要唤醒前缀的设置 + # 如果为True,则私聊也需要使用唤醒前缀才能触发机器人 self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[ "platform_settings" ].get("friend_message_needs_wake_prefix", False) - # 是否忽略机器人自己发送的消息 + # 获取是否忽略机器人自己发送的消息的设置 + # 防止机器人响应自己发出的消息造成死循环 self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get( "ignore_bot_self_message", False, ) + # 获取是否忽略@全体成员消息的设置 self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get( "ignore_at_all", False, ) + # 获取是否禁用内置命令的设置 self.disable_builtin_commands = self.ctx.astrbot_config.get( "disable_builtin_commands", False ) + # 获取平台设置的完整配置 platform_settings = self.ctx.astrbot_config.get("platform_settings", {}) + # 获取是否启用唯一会话的设置 self.unique_session = platform_settings.get("unique_session", False) async def process( self, event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - # apply unique session + """ + 处理消息事件的核心方法:执行唤醒检查流程 + + 功能:实现完整的唤醒检查逻辑,包括: + 1. 设置唯一会话ID + 2. 过滤机器人自身消息 + 3. 设置发送者身份(管理员/普通用户) + 4. 检查是否需要唤醒(@、回复、前缀、私聊等) + 5. 检查插件处理器过滤条件 + 6. 管理激活的处理器列表 + + Args: + event: Astrbot消息事件对象 + + Returns: + None | AsyncGenerator[None, None]: 如果消息被停止则返回None + """ + # ===== 第一步:唯一会话处理 ===== + # 应用唯一会话设置:如果启用了唯一会话且是群组消息 if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE: + # 为当前事件构建唯一会话ID sid = build_unique_session_id(event) if sid: + # 设置事件的会话ID,实现会话隔离 event.session_id = sid - # ignore bot self message + # ===== 第二步:过滤机器人自身消息 ===== + # 检查是否需要忽略机器人自己发送的消息 if ( - self.ignore_bot_self_message - and event.get_self_id() == event.get_sender_id() + self.ignore_bot_self_message # 配置了忽略自身消息 + and event.get_self_id() == event.get_sender_id() # 发送者ID等于机器人自身ID ): + # 停止事件处理,防止机器人响应自己的消息 event.stop_event() return - # 设置 sender 身份 + # ===== 第三步:设置发送者身份 ===== + # 去除消息字符串首尾空白字符 event.message_str = event.message_str.strip() + # 检查发送者是否在管理员列表中 for admin_id in self.ctx.astrbot_config["admins_id"]: + # 将发送者ID与管理员ID进行字符串比较 if str(event.get_sender_id()) == admin_id: + # 设置为管理员角色 event.role = "admin" break - # 检查 wake + # ===== 第四步:检查唤醒条件 ===== + # 获取唤醒前缀列表,例如:['/', '!', '#' 等] wake_prefixes = self.ctx.astrbot_config["wake_prefix"] + # 获取消息中的所有消息段(可以包含文本、@、图片等) messages = event.get_messages() + # 唤醒标志位,初始为False is_wake = False + # 遍历所有唤醒前缀进行匹配 for wake_prefix in wake_prefixes: + # 检查消息是否以某个唤醒前缀开头 if event.message_str.startswith(wake_prefix): + # 特殊处理:群聊中@某人但不是@机器人或@全体成员的情况 if ( - not event.is_private_chat() - and isinstance(messages[0], At) - and str(messages[0].qq) != str(event.get_self_id()) - and str(messages[0].qq) != "all" + not event.is_private_chat() # 不是私聊 + and isinstance(messages[0], At) # 第一个消息段是@消息 + and str(messages[0].qq) != str(event.get_self_id()) # @的不是机器人自己 + and str(messages[0].qq) != "all" # @的不是全体成员 ): # 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒 + # 跳出循环,不设置唤醒 break + # 标记为已唤醒 is_wake = True + # 设置事件的唤醒相关标志 event.is_at_or_wake_command = True event.is_wake = True - event.message_str = event.message_str[len(wake_prefix) :].strip() + # 移除消息中的唤醒前缀,提取实际命令内容 + event.message_str = event.message_str[len(wake_prefix):].strip() break + + # 如果没有通过前缀唤醒,检查其他唤醒方式 if not is_wake: - # 检查是否有at消息 / at全体成员消息 / 引用了bot的消息 + # 检查消息段中是否有@消息、@全体成员消息或引用了机器人的消息 for message in messages: + # 条件1:被@了且@的是机器人 if ( - ( - isinstance(message, At) - and (str(message.qq) == str(event.get_self_id())) - ) - or (isinstance(message, AtAll) and not self.ignore_at_all) - or ( - isinstance(message, Reply) - and str(message.sender_id) == str(event.get_self_id()) - ) + isinstance(message, At) + and (str(message.qq) == str(event.get_self_id())) ): is_wake = True event.is_wake = True - wake_prefix = "" + wake_prefix = "" # 清空前缀 event.is_at_or_wake_command = True break - # 检查是否是私聊 + # 条件2:有人@全体成员且没有配置忽略 + elif (isinstance(message, AtAll) and not self.ignore_at_all): + is_wake = True + event.is_wake = True + wake_prefix = "" # 清空前缀 + event.is_at_or_wake_command = True + break + # 条件3:有人回复了机器人的消息 + elif ( + isinstance(message, Reply) + and str(message.sender_id) == str(event.get_self_id()) + ): + is_wake = True + event.is_wake = True + wake_prefix = "" # 清空前缀 + event.is_at_or_wake_command = True + break + + # 条件4:如果是私聊且配置中私聊不需要唤醒前缀 if event.is_private_chat() and not self.friend_message_needs_wake_prefix: is_wake = True event.is_wake = True event.is_at_or_wake_command = True - wake_prefix = "" + wake_prefix = "" # 清空前缀 - # 检查插件的 handler filter + # ===== 第五步:检查插件的处理器过滤条件 ===== + # 存储被激活的处理器列表 activated_handlers = [] - handlers_parsed_params = {} # 注册了指令的 handler + # 存储已经解析了参数的处理器(注册了指令的 handler) + handlers_parsed_params = {} - # 将 plugins_name 设置到 event 中 - enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"]) + # 设置事件的插件名称列表 + # 从配置中获取已启用的插件名称 + enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"]) # ["*"]获取插件设置 if enabled_plugins_name == ["*"]: - # 如果是 *,则表示所有插件都启用 + # 如果是通配符 "*",则表示所有插件都启用 event.plugins_name = None else: + # 否则指定具体的启用插件列表 event.plugins_name = enabled_plugins_name + # 记录调试信息 logger.debug(f"enabled_plugins_name: {enabled_plugins_name}") + # 遍历所有注册的适配器消息事件处理器 for handler in star_handlers_registry.get_handlers_by_event_type( - EventType.AdapterMessageEvent, - plugins_name=event.plugins_name, + EventType.AdapterMessageEvent, # 指定事件类型 + plugins_name=event.plugins_name, # 按插件名称过滤 ): + # 检查是否需要禁用内置命令处理器 if ( - self.disable_builtin_commands + self.disable_builtin_commands # 配置了禁用内置命令 and handler.handler_module_path - == "astrbot.builtin_stars.builtin_commands.main" + == "astrbot.builtin_stars.builtin_commands.main" # 是内置命令模块 ): + # 跳过该处理器 continue - # filter 需满足 AND 逻辑关系 + # 过滤条件需要满足 AND 逻辑关系(所有条件都必须通过) passed = True - permission_not_pass = False - permission_filter_raise_error = False + permission_not_pass = False # 权限未通过的标志 + permission_filter_raise_error = False # 权限过滤器是否抛出错误的标志 + + # 如果处理器没有事件过滤器,则跳过 if len(handler.event_filters) == 0: continue + # 遍历处理器的所有事件过滤器 for filter in handler.event_filters: try: + # 检查是否是权限类型过滤器 if isinstance(filter, PermissionTypeFilter): + # 调用权限过滤器的filter方法 if not filter.filter(event, self.ctx.astrbot_config): + # 权限未通过 permission_not_pass = True + # 记录权限过滤器是否需要抛出错误 permission_filter_raise_error = filter.raise_error + # 其他类型的过滤器 elif not filter.filter(event, self.ctx.astrbot_config): + # 过滤未通过 passed = False - break + break # 跳出循环,不再检查其他过滤器 except Exception as e: + # 过滤器执行异常,发送错误消息给用户 await event.send( MessageEventResult().message( f"插件 {star_map[handler.handler_module_path].name}: {e}", ), ) + # 停止事件处理 event.stop_event() passed = False break + + # 如果所有过滤条件都通过了 if passed: + # 如果存在权限未通过的情况 if permission_not_pass: + # 如果权限过滤器不需要抛出错误 if not permission_filter_raise_error: - # 跳过 + # 跳过该处理器,继续处理下一个 continue + # 如果配置了无权限回复 if self.no_permission_reply: + # 发送无权限提示消息 await event.send( MessageChain().message( f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。", ), ) + # 记录权限不足的日志 logger.info( f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。", ) + # 停止事件处理 event.stop_event() return + # 标记为已唤醒 is_wake = True event.is_wake = True + # 检查是否是命令组处理器(包含CommandGroupFilter过滤器) is_group_cmd_handler = any( isinstance(f, CommandGroupFilter) for f in handler.event_filters ) + # 如果不是命令组处理器,添加到激活列表 if not is_group_cmd_handler: activated_handlers.append(handler) + # 如果有解析的参数,保存到参数字典中 if "parsed_params" in event.get_extra(default={}): handlers_parsed_params[handler.handler_full_name] = ( event.get_extra("parsed_params") ) + # 清除事件的解析参数,避免影响下一个处理器的判断 event._extras.pop("parsed_params", None) - # 根据会话配置过滤插件处理器 + # ===== 第六步:根据会话配置过滤插件处理器 ===== + # 根据会话设置进一步过滤已激活的处理器 activated_handlers = await SessionPluginManager.filter_handlers_by_session( event, activated_handlers, ) + # ===== 第七步:保存结果到事件对象 ===== + # 将激活的处理器列表和解析的参数保存到事件的额外数据中 + # 供后续管道阶段使用 event.set_extra("activated_handlers", activated_handlers) event.set_extra("handlers_parsed_params", handlers_parsed_params) + # 如果最终没有唤醒,停止事件处理 if not is_wake: - event.stop_event() + event.stop_event() \ No newline at end of file diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 3a74c3b91a..43170918e5 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -146,7 +146,7 @@ async def send_by_session( def commit_event(self, event: AstrMessageEvent) -> None: """提交一个事件到事件队列。""" - self._event_queue.put_nowait(event) + self._event_queue.put_nowait(event) # 对总线的_envent_queue 进行put,队列监控到信息之后,就会进入处理流程 def create_event(self, message: AstrBotMessage) -> AstrMessageEvent: """Creates a message event for this platform. diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 1bfa254079..078c85522e 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -201,7 +201,7 @@ async def get_reply_parts( return components, text_parts async def convert_message(self, data: tuple) -> AstrBotMessage: - username, cid, payload = data + username, cid, payload = data # username = 'astrbot', cid = '8be47862-3fdd-48c9-aada-1b11e03ced50', payload = {'message': [{'type': 'plain', 'text': '你好'}], 'selected_provider': 'lm_studio/qwen3.5-2b', 'selected_model': 'qwen3.5-2b', 'enable_streaming': True, 'message_id': 'de84fbbb-203e-4414-8a1a-c189ccae6363', 'llm_checkpoint_id': 'f66a6a1a-2763-404e-84cd-3cfc8a424cb9', 'thread_selected_text': None} abm = AstrBotMessage() abm.self_id = "webchat" @@ -209,13 +209,13 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: abm.type = MessageType.FRIEND_MESSAGE - abm.session_id = f"webchat!{username}!{cid}" + abm.session_id = f"webchat!{username}!{cid}" # 设置session_id abm.message_id = payload.get("message_id") # 处理消息段列表 message_parts = payload.get("message", []) - abm.message, message_str_parts = await self._parse_message_parts(message_parts) + abm.message, message_str_parts = await self._parse_message_parts(message_parts) # message_str_parts = ['你好'] logger.debug(f"WebChatAdapter: {abm.message}") @@ -273,7 +273,7 @@ def create_event(self, message: AstrBotMessage) -> WebChatMessageEvent: return message_event async def handle_msg(self, message: AstrBotMessage) -> None: - self.commit_event(self.create_event(message)) + self.commit_event(self.create_event(message)) # 这里将信息提交,放到event_bus中的event_queue中,就进入了处理流程 async def terminate(self) -> None: self._shutdown_event.set() diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index bc1e1a6bcd..5b685e7b1c 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -152,7 +152,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None: message_id = self.message_obj.message_id request_id = str(message_id) conversation_id = _extract_conversation_id(self.session_id) - web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( # 这里获取 back_queue,跟 astrbot\dashboard\services\chat_service.py 的 ChatService 类的 build_chat_stream方法的 back_queue 一致,从这里推送llm的result,然后被那边接收,进行消息的回传 request_id, conversation_id, ) diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 9413ccb3c7..7ad1cd7533 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -143,7 +143,7 @@ def set_listener( callback: Callable[[tuple], Awaitable[None]], ): # 存储传入的回调函数,用于后续处理队列中的消息 - self._listener_callback = callback + self._listener_callback = callback # 将 callback 函数存储在 _listener_callback 属性中 # 遍历当前所有已有队列的对话ID列表 for conversation_id in list(self.queues.keys()): # 为每个对话启动监听任务(如果尚未启动) @@ -236,14 +236,14 @@ async def _listen_to_queue( # 跳出循环,结束协程 break # 如果 get_task 先完成,则获取队列中的数据 - data = get_task.result() + data = get_task.result() # ('astrbot', '8be47862-3fdd-48c9-aada-1b11e03ced50', {'message': [{'type': 'plain', 'text': '你好'}], 'selected_provider': 'lm_studio/qwen3.5-2b', 'selected_model': 'qwen3.5-2b', 'enable_streaming': True, 'message_id': 'de84fbbb-203e-4414-8a1a-c189ccae6363', 'llm_checkpoint_id': 'f66a6a1a-2763-404e-84cd-3cfc8a424cb9', 'thread_selected_text': None}) # 在调用回调前,再次检查回调是否还存在(可能在等待期间被清除) if self._listener_callback is None: # 如果回调不存在,则跳过处理,继续下一次循环 continue try: # TODO 调用监听器回调函数处理获取到的数据,这是一个异步调用 - await self._listener_callback(data) + await self._listener_callback(data) # 调用每个设置的 callback() 方法 except Exception as e: # 捕获并记录回调函数中发生的任何异常,避免监听任务崩溃 logger.error( diff --git a/astrbot/dashboard/api/open_api.py b/astrbot/dashboard/api/open_api.py index 6a4b035307..b64846c38e 100644 --- a/astrbot/dashboard/api/open_api.py +++ b/astrbot/dashboard/api/open_api.py @@ -69,7 +69,7 @@ async def _build_streaming_chat_response( username: str, post_data: dict[str, Any], ) -> StreamingResponse | JSONResponse: - try: + try: # TODO: 处理流式响应 stream = await chat_service.build_chat_stream(username, post_data) except ChatServiceError as exc: return _open_api_error(str(exc)) @@ -91,7 +91,7 @@ async def _open_api_chat_response( open_api_service: OpenApiService, chat_service: ChatService, ) -> StreamingResponse | JSONResponse: - if auth.via != "api_key": + if auth.via != "api_key": # TODO 如果不是 API Key 认证,直接返回流式响应 return await _build_streaming_chat_response( chat_service, auth.username, @@ -182,7 +182,8 @@ async def chat( auth: AuthContext = Depends(require_chat_scope), service: OpenApiService = Depends(get_service), chat_service: ChatService = Depends(get_chat_service), -): +): + """这个接口接收聊天请求""" return await _open_api_chat_response( _model_dict(payload), auth, diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 34a8dd5243..b900f47b43 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -1,211 +1,325 @@ +# 导入 asyncio 模块,用于异步编程 import asyncio +# 导入 ipaddress 模块,用于处理 IP 地址的解析和验证 import ipaddress +# 导入 os 模块,用于操作系统相关功能(环境变量、文件路径等) import os +# 导入 socket 模块,用于网络通信和端口检测 import socket +# 导入 time 模块,用于时间相关的操作 import time +# 从 pathlib 导入 Path 类,用于面向对象的文件路径操作 from pathlib import Path +# 从 typing 导入类型提示相关工具:Any 任意类型、Protocol 协议类、cast 类型转换 from typing import Any, Protocol, cast +# 导入 jwt 库,用于 JSON Web Token 的编码和解码 import jwt +# 导入 psutil 库,用于获取系统和进程信息 import psutil +# 从 fastapi 导入 Request 类,表示 HTTP 请求对象 from fastapi import Request +# 从 fastapi.responses 导入 JSONResponse,用于返回 JSON 格式的响应 from fastapi.responses import JSONResponse +# 从 hypercorn.asyncio 导入 serve 函数,用于启动 ASGI 服务器 from hypercorn.asyncio import serve +# 从 hypercorn.config 导入 Config 并重命名为 HyperConfig,用于配置服务器 from hypercorn.config import Config as HyperConfig +# 从 hypercorn.logging 导入 AccessLogAtoms,表示访问日志的原子数据 from hypercorn.logging import AccessLogAtoms +# 从 hypercorn.logging 导入 Logger 并重命名为 HypercornLogger,表示日志记录器 from hypercorn.logging import Logger as HypercornLogger +# 从 astrbot.core 导入 logger 对象,用于记录日志 from astrbot.core import logger +# 从配置默认模块导入 VERSION 常量,表示当前核心版本号 from astrbot.core.config.default import VERSION +# 导入核心生命周期管理器类 from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +# 导入数据库抽象基类 from astrbot.core.db import BaseDatabase +# 导入获取 AstrBot 数据目录路径的工具函数 from astrbot.core.utils.astrbot_path import get_astrbot_data_path +# 从 IO 工具模块导入多个与仪表盘静态文件相关的工具函数 from astrbot.core.utils.io import ( - get_bundled_dashboard_dist_path, - get_dashboard_dist_version, - get_local_ip_addresses, - is_dashboard_dist_compatible, - should_use_bundled_dashboard_dist, + get_bundled_dashboard_dist_path, # 获取打包的内置仪表盘静态文件路径 + get_dashboard_dist_version, # 获取仪表盘静态文件的版本号 + get_local_ip_addresses, # 获取本机的所有 IP 地址列表 + is_dashboard_dist_compatible, # 检查仪表盘静态文件是否与当前核心版本兼容 + should_use_bundled_dashboard_dist, # 判断是否应该使用内置的仪表盘静态文件 ) +# 从仪表盘的 ASGI 运行时模块导入相关类 from astrbot.dashboard.asgi_runtime import ( - DashboardRequestState, - FastAPIAppAdapter, + DashboardRequestState, # 仪表盘请求状态对象,用于在请求中传递状态信息 + FastAPIAppAdapter, # FastAPI 应用适配器,封装 ASGI 应用 ) +# 从仪表盘响应模块导入 error 函数,用于构造错误响应 from astrbot.dashboard.responses import error +# 从当前包的 api.app 模块导入创建仪表盘 ASGI 应用的工厂函数 from .api.app import create_dashboard_asgi_app +# 导入插件页面认证类,用于验证插件页面的访问权限 from .plugin_page_auth import PluginPageAuth +# 从认证服务模块导入 JWT Cookie 名称常量 from .services.auth_service import DASHBOARD_JWT_COOKIE_NAME +# 定义需要进行速率限制的 API 端点集合(使用 frozenset 使其不可变) +# 这些端点涉及登录和更新等敏感操作,需要防止暴力破解 _RATE_LIMITED_ENDPOINTS: frozenset = frozenset( { - "/api/config/astrbot/update", - "/api/auth/totp/setup", - "/api/v1/auth/totp/setup", - "/api/auth/login", - "/api/v1/auth/login", + "/api/config/astrbot/update", # 更新配置端点 + "/api/auth/totp/setup", # TOTP 设置端点 + "/api/v1/auth/totp/setup", # v1 版本的 TOTP 设置端点 + "/api/auth/login", # 登录端点 + "/api/v1/auth/login", # v1 版本的登录端点 } ) +# 基于令牌桶算法的认证速率限制器类 class _AuthRateLimiter: + # 初始化速率限制器,设置桶容量和令牌补充速率 def __init__(self, capacity: int, refill_rate: float): + # 桶的最大容量(允许的最大突发请求数) self.capacity = capacity + # 令牌补充速率(每秒补充的令牌数) self.refill_rate = refill_rate + # 当前桶中的令牌数量,初始为满桶 self.tokens = float(capacity) + # 上次补充令牌的时间戳 self.last_refill = time.monotonic() + # 上次访问时间,用于过期淘汰判断 self.last_accessed = time.monotonic() + # 异步锁,确保并发安全 self.lock = asyncio.Lock() + # 尝试获取一个令牌,返回是否成功 async def acquire(self) -> bool: + # 使用异步锁保护临界区 async with self.lock: + # 获取当前时间 now = time.monotonic() + # 计算距离上次补充经过的时间 elapsed = now - self.last_refill + # 根据经过的时间和补充速率计算新令牌数,不超过桶容量 self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate) + # 更新上次补充时间 self.last_refill = now + # 更新最后访问时间 self.last_accessed = now + # 如果桶中至少有一个令牌 if self.tokens >= 1: + # 消耗一个令牌 self.tokens -= 1 return True + # 令牌不足,请求被限流 return False +# 基于 IP 的令牌桶速率限制器注册表类 class _RateLimiterRegistry: """Per-IP token-bucket rate limiter registry. Idle entries expire after 1 hour.""" + # 每个 IP 限流器实例的过期时间(秒),闲置超过此时间将被清除 + _ENTRY_TTL: float = 3600.0 # 1 小时 + # 过期清理的时间间隔(秒),每半小时检查一次 + _INTERVAL: float = 1800.0 # 30 分钟 - _ENTRY_TTL: float = 3600.0 - _INTERVAL: float = 1800.0 - + # 初始化注册表 def __init__(self) -> None: + # 存储 IP 地址到速率限制器实例的映射字典 self._limiters: dict[str, _AuthRateLimiter] = {} + # 上次执行过期清理的时间戳 self._last_eviction = time.monotonic() + # 获取或创建指定 IP 的速率限制器 def get_or_create( self, key: str, capacity: int, refill_rate: float ) -> _AuthRateLimiter: + # 先执行过期清理(根据时间间隔判断是否需要清理) self._evict_expired() + # 尝试获取已存在的限制器 limiter = self._limiters.get(key) + # 如果不存在 if limiter is None: + # 创建新的速率限制器实例 limiter = _AuthRateLimiter(capacity=capacity, refill_rate=refill_rate) + # 存入注册表 self._limiters[key] = limiter + # 返回限制器 return limiter + # 清除过期的速率限制器条目 def _evict_expired(self) -> None: + # 获取当前时间 now = time.monotonic() + # 如果距离上次清理未达到间隔时间,则跳过 if now - self._last_eviction < self._INTERVAL: return + # 更新上次清理时间 self._last_eviction = now + # 计算过期时间点(当前时间减去 TTL) cutoff = now - self._ENTRY_TTL + # 找出所有最后访问时间早于过期时间点的 IP 列表 stale = [k for k, v in self._limiters.items() if v.last_accessed < cutoff] + # 删除所有过期的条目 for k in stale: del self._limiters[k] + # 清空所有速率限制器 def clear(self) -> None: self._limiters.clear() + # 返回当前注册的限制器数量 def __len__(self) -> int: return len(self._limiters) + # 检查指定 IP 是否有对应的限制器 def __contains__(self, key: str) -> bool: return key in self._limiters +# 定义带有 port 属性的协议类,用于类型提示 class _AddrWithPort(Protocol): + # 端口号属性 port: int +# 全局变量,存储当前的 FastAPI 应用适配器实例,初始为 None APP: FastAPIAppAdapter | None = None +# 解析环境变量中的布尔值 def _parse_env_bool(value: str | None, default: bool) -> bool: + # 如果值为 None,返回默认值 if value is None: return default + # 去除首尾空格并转为小写,判断是否在真值集合中 return value.strip().lower() in {"1", "true", "yes", "on"} +# 支持代理感知的 Hypercorn 日志记录器类 class _ProxyAwareHypercornLogger(HypercornLogger): + # 静态方法:从请求作用域中提取真实的客户端 IP 地址 @staticmethod def _get_request_log_host(request_scope) -> str | None: + # 初始化代理头变量 forwarded_for = None real_ip = None + # 遍历请求头 for raw_name, raw_value in request_scope.get("headers", []): + # 解码头名称并转为小写 header_name = raw_name.decode("latin1").lower() + # 如果找到 X-Forwarded-For 头 if header_name == "x-forwarded-for": + # 解码并存储其值 forwarded_for = raw_value.decode("latin1") + # 如果找到 X-Real-IP 头 elif header_name == "x-real-ip": + # 解码并存储其值 real_ip = raw_value.decode("latin1") + # 如果两个头都找到了,提前跳出循环 if forwarded_for is not None and real_ip is not None: break + # 处理 X-Forwarded-For 头,提取第一个 IP forwarded_for = str(forwarded_for or "").strip() if forwarded_for: + # 用逗号分割,取第一个 IP 地址 first_ip = forwarded_for.split(",", 1)[0].strip() + # 如果 IP 有效且不是 "unknown" if first_ip and first_ip.lower() != "unknown": try: + # 尝试解析为 IP 地址并返回 return str(ipaddress.ip_address(first_ip)) except ValueError: + # 解析失败则继续尝试其他方式 pass + # 处理 X-Real-IP 头 real_ip = str(real_ip or "").strip() if real_ip and real_ip.lower() != "unknown": try: + # 尝试解析并返回 return str(ipaddress.ip_address(real_ip)) except ValueError: pass + # 如果代理头都不可用,尝试从连接信息中获取 client = request_scope.get("client") if not client: return None + # 提取客户端 IP 地址 host = str(client[0]).strip() if host: return host return None + # 重写 atoms 方法,使用代理感知的 IP 地址构建访问日志 def atoms(self, request, response, request_time): + # 调用父类方法获取基本日志原子数据 atoms = AccessLogAtoms(request, response, request_time) + # 获取真实的客户端 IP client_host = self._get_request_log_host(request) + # 如果获取到了真实 IP if client_host: + # 替换日志中的主机地址 atoms["h"] = client_host return atoms +# AstrBot 仪表盘主类,管理 WebUI 的配置、启动和运行 class AstrBotDashboard: + # 初始化仪表盘实例 def __init__( self, - core_lifecycle: AstrBotCoreLifecycle, - db: BaseDatabase, - shutdown_event: asyncio.Event, - webui_dir: str | None = None, # 项目目录的 AstrBot\\data\\dist 路径 + core_lifecycle: AstrBotCoreLifecycle, # 核心生命周期管理器 + db: BaseDatabase, # 数据库实例 + shutdown_event: asyncio.Event, # 关闭事件,用于通知服务器停止 + webui_dir: str | None = None, # WebUI 静态文件目录路径(可选) ) -> None: + # 保存核心生命周期引用 self.core_lifecycle = core_lifecycle + # 保存配置对象引用 self.config = core_lifecycle.astrbot_config + # 保存数据库引用 self.db = db - # Path priority: - # 1. Explicit webui_dir argument - # 2. data/dist/ when it matches the core version - # 3. astrbot/dashboard/dist/ when it matches the core version + # 确定静态文件路径的优先级顺序: + # 1. 明确指定的 webui_dir 参数 + # 2. data/dist/ 目录(如果与核心版本匹配) + # 3. 内置的 astrbot/dashboard/dist/ 目录(如果与核心版本匹配) if webui_dir and os.path.exists(webui_dir): + # 如果指定了路径且存在,使用该路径 self.data_path = os.path.abspath(webui_dir) else: + # 获取用户数据目录下的 dist 路径 user_dist = os.path.join(get_astrbot_data_path(), "dist") + # 获取内置打包的仪表盘静态文件路径 bundled_dist = get_bundled_dashboard_dist_path() + # 获取用户 dist 目录的版本 user_version = get_dashboard_dist_version(user_dist) + # 如果用户 dist 存在且与当前核心版本兼容 if os.path.exists(user_dist) and is_dashboard_dist_compatible( user_dist, VERSION, ): + # 使用用户 dist 目录 self.data_path = os.path.abspath(user_dist) + # 如果建议使用内置版本,或内置版本兼容 elif should_use_bundled_dashboard_dist( user_dist, VERSION, ) or is_dashboard_dist_compatible(bundled_dist, VERSION): + # 使用内置打包的静态文件路径 self.data_path = str(bundled_dist) + # 记录日志 logger.info("Using bundled dashboard dist: %s", self.data_path) + # 如果用户 dist 存在且包含 index.html elif ( os.path.exists(user_dist) and (Path(user_dist) / "index.html").is_file() ): + # 警告:版本不匹配,但回退使用用户 dist logger.warning( "Using existing data/dist as a fallback even though WebUI version mismatches core: %s, expected v%s. " "Some dashboard features may not work until the matching WebUI is available.", @@ -213,98 +327,135 @@ def __init__( VERSION, ) self.data_path = os.path.abspath(user_dist) + # 如果用户 dist 存在但不完整 elif os.path.exists(user_dist): + # 警告:文件不完整,忽略 logger.warning( "Ignoring data/dist because WebUI files are incomplete for core v%s.", VERSION, ) self.data_path = None else: - # Fall back to expected user path (will fail gracefully later) + # 最后回退到用户路径(后续会优雅失败) self.data_path = os.path.abspath(user_dist) + # 创建速率限制器注册表 self._rate_limiter_registry = _RateLimiterRegistry() + # 初始化 JWT 密钥 self._init_jwt_secret() - self.asgi_app = create_dashboard_asgi_app( # 启动dashboard页面 + # 创建 ASGI 应用(包括所有 API 路由和服务) + self.asgi_app = create_dashboard_asgi_app( # 启动dashboard页面 core_lifecycle=core_lifecycle, db=db, jwt_secret=self._jwt_secret, static_folder=self.data_path, ) + # 创建 FastAPI 应用适配器,封装 ASGI 应用 self.app = FastAPIAppAdapter(self.asgi_app, static_folder=self.data_path) + # 将适配器实例挂载到 ASGI 应用的状态中,方便路由访问 self.asgi_app.state.dashboard_app_adapter = self.app + # 建立反向引用,让适配器可以访问仪表盘服务器实例 self.app._dashboard_server = self + # 设置全局 APP 变量 global APP APP = self.app + # 设置最大上传文件体大小限制为 128 MB self.app.config["MAX_CONTENT_LENGTH"] = ( 128 * 1024 * 1024 ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB + # 注册 HTTP 中间件,用于仪表盘认证 @self.asgi_app.middleware("http") async def dashboard_auth_middleware(request_, call_next): + # 初始化请求状态对象 request_.state.dashboard_g = DashboardRequestState() + # 调用认证中间件 auth_response = await self.auth_middleware(request_) + # 如果认证失败(返回了错误响应),直接返回 if auth_response is not None: return auth_response + # 认证成功,继续处理请求 return await call_next(request_) + # 保存关闭事件引用 self.shutdown_event = shutdown_event + # 认证中间件:验证请求的 JWT 令牌和权限 async def auth_middleware(self, current_request: Request): + # 获取请求路径 path = current_request.url.path + # 如果路径不以 /api 开头,不需要认证 if not path.startswith("/api"): return None + # 应用速率限制检查 rate_limit_response = await self._apply_auth_rate_limit(current_request, path) + # 如果触发了速率限制,返回错误响应 if rate_limit_response is not None: return rate_limit_response + # V1 版本的 API 使用不同的认证机制(OpenAPI Bearer Token),此处放行 if path.startswith("/api/v1"): return None + # 定义不需要认证的精确匹配端点集合 allowed_exact_endpoints = { - "/api/auth/login", - "/api/auth/logout", - "/api/auth/setup-status", - "/api/auth/setup", - "/api/stat/versions", + "/api/auth/login", # 登录接口 + "/api/auth/logout", # 登出接口 + "/api/auth/setup-status", # 设置状态查询 + "/api/auth/setup", # 初始设置 + "/api/stat/versions", # 版本信息查询 } + # 定义不需要认证的路径前缀列表 allowed_endpoint_prefixes = [ - "/api/file", - "/api/v1/files/tokens", - "/api/platform/webhook", - "/api/stat/start-time", - "/api/backup/download", # 备份下载使用 URL 参数传递 token + "/api/file", # 文件相关接口 + "/api/v1/files/tokens", # 文件令牌接口 + "/api/platform/webhook", # 平台 Webhook 回调 + "/api/stat/start-time", # 启动时间查询 + "/api/backup/download", # 备份下载(使用 URL 参数传递 token) ] + # 如果路径在白名单中,跳过认证 if path in allowed_exact_endpoints or any( path.startswith(prefix) for prefix in allowed_endpoint_prefixes ): return None + # 检查是否为受保护的插件页面路径 is_plugin_page_path = PluginPageAuth.is_protected_path(path) + # 从请求中提取仪表盘 JWT(从 Header 或 Cookie 中) dashboard_token = self._extract_dashboard_jwt(current_request) + # 如果是插件页面路径,尝试从查询参数中提取资产令牌 asset_token = ( PluginPageAuth.extract_asset_token(current_request.query_params) if is_plugin_page_path else None ) + # 收集所有候选令牌 token_candidates = [] if dashboard_token: token_candidates.append(dashboard_token) if asset_token and asset_token != dashboard_token: token_candidates.append(asset_token) + # 如果没有提供任何令牌,返回 401 未授权 if not token_candidates: r = JSONResponse(error("未授权")) r.status_code = 401 return r + # 记录验证失败的错误信息 token_errors: list[str] = [] + # 遍历候选令牌进行验证 for token in token_candidates: + # 验证令牌有效性 payload, token_error = self._validate_dashboard_token(token, path) + # 如果验证成功 if payload is not None: + # 将用户名存入请求状态 current_request.state.dashboard_g.username = cast( str, payload["username"] ) return None + # 记录错误信息 token_errors.append(token_error) + # 根据错误类型返回不同的错误消息 error_message = ( "Token 过期" if token_errors and all(item == "Token 过期" for item in token_errors) @@ -314,10 +465,11 @@ async def auth_middleware(self, current_request: Request): r.status_code = 401 return r + # 验证仪表盘 JWT 令牌或插件页面资产令牌 def _validate_dashboard_token( self, - token: str, - path: str, + token: str, # JWT 令牌字符串 + path: str, # 当前请求路径,用于插件页面令牌的作用域检查 ) -> tuple[dict[str, Any] | None, str]: """Validate a dashboard JWT or scoped plugin page asset token. @@ -330,68 +482,93 @@ def _validate_dashboard_token( present only when the token is valid for the current request path. """ try: + # 尝试解码 JWT 令牌 payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) except jwt.ExpiredSignatureError: + # 令牌过期 return None, "Token 过期" except jwt.InvalidTokenError: + # 令牌无效 return None, "Token 无效" + # 如果是资产令牌,检查作用域是否有效 if PluginPageAuth.is_asset_token(payload) and not PluginPageAuth.is_scope_valid( payload, path, ): return None, "Token 无效" + # 验证载荷中是否包含有效的用户名 username = payload.get("username") if not isinstance(username, str) or not username.strip(): return None, "Token 无效" + # 返回解码后的载荷和空错误消息(表示验证成功) return payload, "" + # 对敏感端点应用认证速率限制 async def _apply_auth_rate_limit( self, current_request: Request, path: str, ) -> JSONResponse | None: + # 如果不在测试模式且当前路径在速率限制端点集合中 if ( os.environ.get("ASTRBOT_TEST_MODE") != "true" and path in _RATE_LIMITED_ENDPOINTS ): + # 获取速率限制配置 rl_config = self.config.get("dashboard", {}).get("auth_rate_limit", {}) + # 检查是否启用速率限制(默认启用) rl_enabled = rl_config.get("enable", True) if rl_enabled: + # 获取平均请求间隔(秒),默认 1 秒 average_interval = float(rl_config.get("average_interval", 1.0)) + # 获取最大突发请求数,默认 3 max_burst = int(rl_config.get("max_burst", 3)) + # 参数验证和修正 if average_interval <= 0: average_interval = 1.0 if max_burst <= 0: max_burst = 3 + # 计算令牌补充速率(每秒补充的令牌数) refill_rate = 1.0 / average_interval + # 获取客户端 IP 地址 client_ip = self._get_request_client_ip(current_request) + # 获取或创建该 IP 的速率限制器 limiter = self._rate_limiter_registry.get_or_create( client_ip, capacity=max_burst, refill_rate=refill_rate ) + # 尝试获取令牌 if not await limiter.acquire(): + # 令牌不足,返回 429 状态码 r = JSONResponse( error("验证尝试过于频繁,系统可能正在遭受暴力破解") ) r.status_code = 429 return r + # 速率限制通过或不适用 return None + # 获取请求的真实客户端 IP 地址 def _get_request_client_ip(self, current_request) -> str: + # 如果配置了信任代理头 if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)): + # 尝试从 X-Forwarded-For 头获取 forwarded_for = str( current_request.headers.get("X-Forwarded-For", "") ).strip() if forwarded_for: + # 取第一个 IP 地址 first_ip = forwarded_for.split(",", 1)[0].strip() + # 验证 IP 有效性 if first_ip and first_ip.lower() != "unknown": try: return str(ipaddress.ip_address(first_ip)) except ValueError: pass + # 尝试从 X-Real-IP 头获取 real_ip = str(current_request.headers.get("X-Real-IP", "")).strip() if real_ip and real_ip.lower() != "unknown": try: @@ -399,6 +576,7 @@ def _get_request_client_ip(self, current_request) -> str: except ValueError: pass + # 从连接信息中获取远程地址 remote_addr = ( str(current_request.client.host).strip() if current_request.client is not None @@ -410,190 +588,254 @@ def _get_request_client_ip(self, current_request) -> str: except ValueError: pass + # 无法获取有效 IP,返回 unknown return "unknown" + # 从请求中提取仪表盘 JWT 令牌(静态方法) @staticmethod def _extract_dashboard_jwt(current_request: Request) -> str | None: + # 尝试从 Authorization 头中提取 Bearer 令牌 auth_header = current_request.headers.get("Authorization", "").strip() if auth_header.startswith("Bearer "): + # 移除 "Bearer " 前缀并去除空白 token = auth_header.removeprefix("Bearer ").strip() if token: return token + # 尝试从 Cookie 中提取令牌 cookie_token = current_request.cookies.get( DASHBOARD_JWT_COOKIE_NAME, "", ).strip() if cookie_token: return cookie_token + # 没有找到令牌 return None + # 检测指定端口是否被占用 def check_port_in_use(self, port: int) -> bool: """跨平台检测端口是否被占用""" try: # 创建 IPv4 TCP Socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # 设置 SO_REUSEADDR 选项,允许重用地址 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # 设置超时时间 + # 设置连接超时时间为 2 秒 sock.settimeout(2) + # 尝试连接到本地指定端口 result = sock.connect_ex(("127.0.0.1", port)) + # 关闭 Socket sock.close() # result 为 0 表示端口被占用 return result == 0 except Exception as e: + # 出现异常时记录警告日志 logger.warning(f"检查端口 {port} 时发生错误: {e!s}") - # 如果出现异常,保守起见认为端口可能被占用 + # 保守起见认为端口可能被占用 return True + # 获取占用指定端口的进程详细信息 def get_process_using_port(self, port: int) -> str: """获取占用端口的进程详细信息""" try: + # 遍历所有网络连接 for conn in psutil.net_connections(kind="inet"): + # 如果连接的本地端口与指定端口匹配 if cast(_AddrWithPort, conn.laddr).port == port: try: + # 获取占用端口的进程对象 process = psutil.Process(conn.pid) - # 获取详细信息 + # 构造进程详细信息列表 proc_info = [ - f"进程名: {process.name()}", - f"PID: {process.pid}", - f"执行路径: {process.exe()}", - f"工作目录: {process.cwd()}", - f"启动命令: {' '.join(process.cmdline())}", + f"进程名: {process.name()}", # 进程名称 + f"PID: {process.pid}", # 进程 ID + f"执行路径: {process.exe()}", # 可执行文件路径 + f"工作目录: {process.cwd()}", # 工作目录 + f"启动命令: {' '.join(process.cmdline())}", # 完整启动命令 ] + # 将信息用换行和空格连接 return "\n ".join(proc_info) except (psutil.NoSuchProcess, psutil.AccessDenied) as e: + # 进程不存在或无权限访问 return f"无法获取进程详细信息(可能需要管理员权限): {e!s}" + # 未找到占用端口的进程 return "未找到占用进程" except Exception as e: + # 获取进程信息失败 return f"获取进程信息失败: {e!s}" + # 初始化 JWT 密钥 def _init_jwt_secret(self) -> None: + # 检查配置中是否已经设置了 JWT 密钥 if not self.config.get("dashboard", {}).get("jwt_secret", None): - # 如果没有设置 JWT 密钥,则生成一个新的密钥 + # 如果没有设置,生成一个随机的 32 字节密钥(十六进制字符串) jwt_secret = os.urandom(32).hex() + # 将密钥存入配置 self.config["dashboard"]["jwt_secret"] = jwt_secret + # 保存配置到文件 self.config.save_config() + # 记录日志 logger.info("Initialized random JWT secret for dashboard.") + # 从配置中读取 JWT 密钥 self._jwt_secret = self.config["dashboard"]["jwt_secret"] + # 构建仪表盘登录凭据显示信息 def _build_dashboard_credentials_display(self) -> str: + # 获取用户名 username = self.config["dashboard"].get("username", "astrbot") + # 获取生成的初始密码(仅首次设置时存在) generated_password = getattr(self.config, "_generated_dashboard_password", None) + # 如果没有生成的密码,只显示用户名 if not generated_password: return f" ➜ Username: {username}\n ✨✨✨\n" + # 如果有生成的初始密码,显示用户名和密码 credentials_display = ( f" ➜ Initial username: {username}\n" f" ➜ Initial password: {generated_password}\n" " ➜ Change it after logging in\n ✨✨✨\n" ) + # 清除生成的密码(仅显示一次),避免在后续日志中再次出现 object.__setattr__(self.config, "_generated_dashboard_password", None) return credentials_display + # 解析并验证仪表盘的 SSL 配置(静态方法) @staticmethod def _resolve_dashboard_ssl_config( ssl_config: dict, ) -> tuple[bool, dict[str, str]]: + # 从环境变量或配置中获取证书文件路径 cert_file = ( os.environ.get("DASHBOARD_SSL_CERT") or os.environ.get("ASTRBOT_DASHBOARD_SSL_CERT") or ssl_config.get("cert_file", "") ) + # 从环境变量或配置中获取私钥文件路径 key_file = ( os.environ.get("DASHBOARD_SSL_KEY") or os.environ.get("ASTRBOT_DASHBOARD_SSL_KEY") or ssl_config.get("key_file", "") ) + # 从环境变量或配置中获取 CA 证书文件路径 ca_certs = ( os.environ.get("DASHBOARD_SSL_CA_CERTS") or os.environ.get("ASTRBOT_DASHBOARD_SSL_CA_CERTS") or ssl_config.get("ca_certs", "") ) + # 如果证书或私钥文件路径缺失,SSL 不可用 if not cert_file or not key_file: logger.warning( "dashboard.ssl.enable is set, but cert_file or key_file is missing. SSL disabled.", ) return False, {} + # 展开用户目录路径(如 ~/ 替换为实际路径) cert_path = Path(cert_file).expanduser() key_path = Path(key_file).expanduser() + # 检查证书文件是否存在 if not cert_path.is_file(): logger.warning( f"dashboard.ssl.enable is set, but cert file is missing: {cert_path}. SSL disabled.", ) return False, {} + # 检查私钥文件是否存在 if not key_path.is_file(): logger.warning( f"dashboard.ssl.enable is set, but key file is missing: {key_path}. SSL disabled.", ) return False, {} + # 构建 SSL 配置字典 resolved_ssl_config = { - "certfile": str(cert_path.resolve()), - "keyfile": str(key_path.resolve()), + "certfile": str(cert_path.resolve()), # 证书文件的绝对路径 + "keyfile": str(key_path.resolve()), # 私钥文件的绝对路径 } + # 如果配置了 CA 证书 if ca_certs: ca_path = Path(ca_certs).expanduser() + # 检查 CA 证书文件是否存在 if not ca_path.is_file(): logger.warning( f"dashboard.ssl.enable is set, but CA cert file is missing: {ca_path}. SSL disabled.", ) return False, {} + # 添加 CA 证书路径 resolved_ssl_config["ca_certs"] = str(ca_path.resolve()) + # SSL 配置成功 return True, resolved_ssl_config + # 启动仪表盘服务器 def run(self): + # 初始化 IP 地址列表 ip_addr = [] + # 获取仪表盘配置 dashboard_config = self.core_lifecycle.astrbot_config.get("dashboard", {}) + # 获取端口配置(优先级:环境变量 > 配置文件 > 默认值 6185) port = ( os.environ.get("DASHBOARD_PORT") or os.environ.get("ASTRBOT_DASHBOARD_PORT") or dashboard_config.get("port", 6185) ) + # 获取主机配置(优先级:环境变量 > 配置文件 > 默认值 0.0.0.0) host = ( os.environ.get("DASHBOARD_HOST") or os.environ.get("ASTRBOT_DASHBOARD_HOST") or dashboard_config.get("host", "0.0.0.0") ) + # 是否启用仪表盘 enable = dashboard_config.get("enable", True) + # 获取 SSL 配置 ssl_config = dashboard_config.get("ssl", {}) + # 确保 ssl_config 是字典类型 if not isinstance(ssl_config, dict): ssl_config = {} + # 解析是否启用 SSL(优先级:环境变量 > 配置文件) ssl_enable = _parse_env_bool( os.environ.get("DASHBOARD_SSL_ENABLE") or os.environ.get("ASTRBOT_DASHBOARD_SSL_ENABLE"), bool(ssl_config.get("enable", False)), ) + # 初始化 SSL 配置字典 resolved_ssl_config: dict[str, str] = {} + # 如果启用了 SSL,解析 SSL 证书配置 if ssl_enable: ssl_enable, resolved_ssl_config = self._resolve_dashboard_ssl_config( ssl_config, ) + # 根据是否启用 SSL 确定协议方案 scheme = "https" if ssl_enable else "http" + # 如果仪表盘被禁用,记录日志并返回 if not enable: logger.info("WebUI disabled.") return None + # 记录启动信息 logger.info("Starting WebUI at %s://%s:%s", scheme, host, port) + # 如果监听在所有接口上,发出安全警告 if host == "0.0.0.0": logger.info( "WebUI listens on all interfaces. Check security. Set dashboard.host in data/cmd_config.json to change it.", ) + # 如果主机不是本地地址,获取本机的 IP 地址列表用于显示 if host not in ["localhost", "127.0.0.1"]: try: ip_addr = get_local_ip_addresses() except Exception as _: pass + # 确保端口是整数类型 if isinstance(port, str): port = int(port) + # 检查端口是否被占用 if self.check_port_in_use(port): + # 获取占用端口的进程信息 process_info = self.get_process_using_port(port) + # 记录错误日志 logger.error( f"错误:端口 {port} 已被占用\n" f"占用信息: \n {process_info}\n" @@ -602,34 +844,45 @@ def run(self): f"2. 端口 {port} 没有被其他程序占用\n" f"3. 如需使用其他端口,请修改配置文件", ) - + # 抛出异常阻止启动 raise Exception(f"端口 {port} 已被占用") + # 检查 WebUI 静态文件是否就绪 if self.data_path and (Path(self.data_path) / "index.html").is_file(): webui_status = "WebUI is ready" else: webui_status = ( f"WebUI is NOT ready: static files are missing at {self.data_path}" ) + # 构建欢迎信息 parts = [f"\n ✨✨✨\n AstrBot v{VERSION} {webui_status}\n\n"] + # 添加本地访问地址 parts.append(f" ➜ Local: {scheme}://localhost:{port}\n") + # 添加网络访问地址 for ip in ip_addr: parts.append(f" ➜ Network: {scheme}://{ip}:{port}\n") + # 添加登录凭据信息 parts.append(self._build_dashboard_credentials_display()) + # 拼接显示字符串 display = "".join(parts) + # 如果没有获取到 IP 地址,提示如何启用远程访问 if not ip_addr: display += ( "Set dashboard.host in data/cmd_config.json to enable remote access.\n" ) + # 记录欢迎信息 logger.info(display) - # 配置 Hypercorn + # 配置 Hypercorn ASGI 服务器 config = HyperConfig() + # 设置绑定的主机和端口 config.bind = [f"{host}:{port}"] + # 如果信任代理头,使用代理感知的日志记录器 if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)): config.logger_class = _ProxyAwareHypercornLogger + # 如果启用了 SSL,配置证书 if ssl_enable: config.certfile = resolved_ssl_config["certfile"] config.keyfile = resolved_ssl_config["keyfile"] @@ -639,16 +892,21 @@ def run(self): # 根据配置决定是否禁用访问日志 disable_access_log = dashboard_config.get("disable_access_log", True) if disable_access_log: + # 禁用访问日志 config.accesslog = None else: - # 启用访问日志,使用简洁格式 + # 启用访问日志,使用简洁格式:主机 请求行 状态码 响应大小 响应时间(微秒) config.accesslog = "-" config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s" + # 启动 Hypercorn ASGI 服务器,传入关闭触发器 return serve( cast(Any, self.asgi_app), config, shutdown_trigger=self.shutdown_trigger ) + # 关闭触发器协程:等待关闭事件被触发 async def shutdown_trigger(self) -> None: + # 等待关闭事件 await self.shutdown_event.wait() - logger.info("AstrBot WebUI 已经被关闭") + # 记录关闭日志 + logger.info("AstrBot WebUI 已经被关闭") \ No newline at end of file diff --git a/astrbot/dashboard/services/chat_service.py b/astrbot/dashboard/services/chat_service.py index 7ff6d67f5b..8c2734f50a 100644 --- a/astrbot/dashboard/services/chat_service.py +++ b/astrbot/dashboard/services/chat_service.py @@ -51,7 +51,7 @@ async def track_conversation(convs: dict, conv_id: str): async def poll_webchat_stream_result(back_queue, username: str): try: - result = await asyncio.wait_for(back_queue.get(), timeout=1) + result = await asyncio.wait_for(back_queue.get(), timeout=1) # 使用back_queue监听,意味着模型调用结果放在 back_queue 中 except asyncio.TimeoutError: return None, False except asyncio.CancelledError: @@ -693,15 +693,15 @@ async def build_chat_stream( ) -> AsyncIterator[str]: if "message" not in post_data and "files" not in post_data: raise ChatServiceError("Missing key: message or files") - if "session_id" not in post_data and "conversation_id" not in post_data: + if "session_id" not in post_data and "conversation_id" not in post_data: # {'message': [{'type': 'plain', 'text': '你好'}], 'session_id': '8be47862-3fdd-48c9-aada-1b11e03ced50', 'enable_streaming': True, 'selected_provider': 'lm_studio/qwen3.5-2b', 'selected_model': 'qwen3.5-2b', '_skip_user_history': False} raise ChatServiceError("Missing key: session_id or conversation_id") message = post_data.get("message", post_data.get("files", [])) session_id = post_data.get("session_id", post_data.get("conversation_id")) - selected_provider = post_data.get("selected_provider") - selected_model = post_data.get("selected_model") + selected_provider = post_data.get("selected_provider") # 'lm_studio/qwen3.5-2b' + selected_model = post_data.get("selected_model") # 'qwen3.5-2b' enable_streaming = post_data.get("enable_streaming", True) - platform_history_id = post_data.get("_platform_history_id") or "webchat" + platform_history_id = post_data.get("_platform_history_id") or "webchat" # 'webchat' thread_selected_text = post_data.get("_thread_selected_text") if not session_id: @@ -714,10 +714,10 @@ async def build_chat_stream( "Message content is empty (reply only is not allowed)" ) - message_id = str(uuid.uuid4()) + message_id = str(uuid.uuid4()) # 'de84fbbb-203e-4414-8a1a-c189ccae6363' llm_checkpoint_id = post_data.get("_llm_checkpoint_id") or str(uuid.uuid4()) - skip_user_history = bool(post_data.get("_skip_user_history")) - back_queue = webchat_queue_mgr.get_or_create_back_queue( + skip_user_history = bool(post_data.get("_skip_user_history")) # False + back_queue = webchat_queue_mgr.get_or_create_back_queue( # 获取消息处理的queue back_queue: 0x216a3eb6e10 message_id, webchat_conv_id, ) @@ -801,7 +801,7 @@ def build_attachment_saved_event(part: dict | None) -> str | None: async with track_conversation(self.running_convs, webchat_conv_id): while True: - result, should_break = await poll_webchat_stream_result( + result, should_break = await poll_webchat_stream_result( # TODO 这里监听模型调用结果 back_queue, username ) if should_break: @@ -936,15 +936,15 @@ def build_attachment_saved_event(part: dict | None) -> str | None: exc_info=True, ) webchat_queue_mgr.remove_back_queue(message_id) - - chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) + # chat_queue: 地址:Queue at 0x216a351c7d0, 这里的 chat_queue 跟 astrbot\core\platform\sources\webchat\webchat_queue_mgr.py 的 WebChatQueueMgr 类的 _listen_to_queue 方法中的 queue 是同一个。这里推入信息之后,会被_listen_to_queue 方法监听。 + chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) # 这个queue跟back_queue不一样,后者用来接收模型的返回结果,然后传递给前端,这个是用来激活platform消息处理的。而且这里面会启动监听器。 await chat_queue.put( # TODO 这里将信息放入queue ( username, - webchat_conv_id, + webchat_conv_id, # '8be47862-3fdd-48c9-aada-1b11e03ced50' { "message": message_parts, - "selected_provider": selected_provider, + "selected_provider": selected_provider, "selected_model": selected_model, "enable_streaming": enable_streaming, "message_id": message_id, @@ -954,7 +954,7 @@ def build_attachment_saved_event(part: dict | None) -> str | None: ), ) - message_parts_for_storage = strip_message_parts_path_fields(message_parts) + message_parts_for_storage = strip_message_parts_path_fields(message_parts) # [{'type': 'plain', 'text': '你好'}] if not skip_user_history: saved_user_record = await self.platform_history_mgr.insert( platform_id=platform_history_id,