Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion astrbot/core/db/vec_db/faiss_impl/vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,15 @@ async def insert_batch(
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
embedding_texts: list[str] | None = None,
) -> list[int]:
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。

Args:
progress_callback: 进度回调函数,接收参数 (current, total)
embedding_texts: 可选的向量化文本,用于将"用于语义匹配的文本"与
"用于存储/检索返回的文本(contents)"解耦。表格知识库使用索引列
文本进行向量化,但存储并返回整行文本。缺省时回退为 contents。

"""
metadatas = metadatas or [{} for _ in contents]
Expand All @@ -81,6 +85,20 @@ async def insert_batch(
)
return []

texts_to_embed = embedding_texts if embedding_texts is not None else contents
if len(texts_to_embed) != len(contents):
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
f"向量化失败:用于向量化的文本数量与文本分块数量不一致"
f"(期望 {len(contents)},实际 {len(texts_to_embed)})。"
),
details={
"expected_contents": len(contents),
"actual_embedding_texts": len(texts_to_embed),
},
)

content_count = len(contents)
if len(metadatas) != content_count:
raise KnowledgeBaseUploadError(
Expand Down Expand Up @@ -110,7 +128,7 @@ async def insert_batch(
start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
texts_to_embed,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
Expand Down
40 changes: 40 additions & 0 deletions astrbot/core/knowledge_base/kb_db_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,46 @@ async def migrate_to_v1(self) -> None:

await session.commit()

async def migrate_to_v2(self) -> None:
"""Run knowledge base database v2 migration.

Adds the table knowledge base columns to existing databases that were
created before the table feature. SQLite does not support
``ADD COLUMN IF NOT EXISTS``, so existing columns are checked via
``PRAGMA table_info`` before issuing ``ALTER TABLE`` statements.
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
kb_columns = {
row[1]
for row in (
await session.execute(
text("PRAGMA table_info(knowledge_bases)")
)
).fetchall()
}
if "kb_type" not in kb_columns:
await session.execute(
text(
"ALTER TABLE knowledge_bases "
"ADD COLUMN kb_type VARCHAR(20) NOT NULL DEFAULT 'text'",
),
)

doc_columns = {
row[1]
for row in (
await session.execute(text("PRAGMA table_info(kb_documents)"))
).fetchall()
}
if "table_schema" not in doc_columns:
await session.execute(
text("ALTER TABLE kb_documents ADD COLUMN table_schema TEXT"),
)

await session.commit()

async def close(self) -> None:
"""关闭数据库连接"""
await self.engine.dispose()
Expand Down
188 changes: 178 additions & 10 deletions astrbot/core/knowledge_base/kb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,171 @@ async def embedding_progress_callback(current, total) -> None:

raise

@staticmethod
def _format_row_text(columns: list[dict], row: dict[str, str]) -> str:
"""Format selected columns of a row as ``name: value`` lines.

Args:
columns: Column descriptors to render, each with a ``name`` key.
row: Mapping of header name to cell value for one row.

Returns:
A newline-joined ``name: value`` representation, skipping blank cells.
"""
lines = []
for col in columns:
name = col.get("name")
if not name:
continue
value = str(row.get(name) or "").strip()
if value:
lines.append(f"{name}: {value}")
return "\n".join(lines)

async def upload_table_document(
self,
file_name: str,
file_type: str,
headers: list[str],
rows: list[list[str]],
columns_config: list[dict],
file_size: int = 0,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""Upload a structured table where each row is an independent chunk.

Index columns are concatenated to build the embedding text (used for
semantic matching), while the full row is stored/returned and the raw
row values are kept in chunk metadata under ``row_data``.

Args:
file_name: Original table file name.
file_type: File extension without the dot (e.g. ``csv``).
headers: Column header names aligned to ``rows``.
rows: Row values, each aligned to ``headers``.
columns_config: Per-column config items with keys ``name``,
``is_index`` and ``is_returned``.
file_size: Original file size in bytes.
batch_size: Embedding batch size.
tasks_limit: Embedding concurrency limit.
max_retries: Embedding retry count.
progress_callback: Async callback ``(stage, current, total)``.

Returns:
KBDocument: The created document metadata record.

Raises:
KnowledgeBaseUploadError: If no indexable row can be produced.
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())

index_cols = [c for c in columns_config if c.get("is_index")]
if not index_cols:
raise KnowledgeBaseUploadError(
stage="validation",
user_message="表格导入失败:至少需要选择一个索引列用于语义检索。",
details={"file_name": file_name},
)
returned_cols = [c for c in columns_config if c.get("is_returned")] or [
{"name": h} for h in headers
]

if progress_callback:
await progress_callback("parsing", 100, 100)

contents: list[str] = []
embedding_texts: list[str] = []
metadatas: list[dict] = []
for idx, row_values in enumerate(rows):
row = {
h: (row_values[i] if i < len(row_values) else "")
for i, h in enumerate(headers)
}
embedding_text = self._format_row_text(index_cols, row)
if not embedding_text:
# Skip rows whose index columns are all empty.
continue
content_text = self._format_row_text(returned_cols, row) or embedding_text
contents.append(content_text)
embedding_texts.append(embedding_text)
metadatas.append(
{
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": len(contents) - 1,
"row_index": idx,
"row_data": row,
"is_table_row": True,
},
)

if not contents:
raise KnowledgeBaseUploadError(
stage="validation",
user_message="表格导入失败:所选索引列在所有行中均为空,没有可索引的数据。",
details={"file_name": file_name},
)

if progress_callback:
await progress_callback("chunking", 100, 100)

async def embedding_progress_callback(current, total) -> None:
if progress_callback:
await progress_callback("embedding", current, total)

try:
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
embedding_texts=embedding_texts,
)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="storage",
user_message="存储失败:表格行已生成,但写入知识库索引时出错。",
details={"file_name": file_name},
) from exc

doc = KBDocument(
doc_id=doc_id,
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=file_size,
file_path="",
table_schema=json.dumps(columns_config, ensure_ascii=False),
chunk_count=len(contents),
media_count=0,
)
try:
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
await session.commit()
await session.refresh(doc)
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="metadata",
user_message="元数据保存失败:表格行已写入知识库,但文档记录保存失败。",
details={"file_name": file_name, "doc_id": doc_id},
) from exc

vec_db: FaissVecDB = self.vec_db # type: ignore
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
return doc

async def list_documents(
self,
offset: int = 0,
Expand Down Expand Up @@ -537,16 +702,19 @@ async def get_chunks_by_doc_id(
result = []
for chunk in chunks:
chunk_md = json.loads(chunk["metadata"])
result.append(
{
"chunk_id": chunk["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": chunk_md["kb_id"],
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
},
)
item = {
"chunk_id": chunk["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": chunk_md["kb_id"],
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
}
if chunk_md.get("is_table_row"):
item["is_table_row"] = True
item["row_index"] = chunk_md.get("row_index")
item["row_data"] = chunk_md.get("row_data")
result.append(item)
return result

async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
Expand Down
3 changes: 3 additions & 0 deletions astrbot/core/knowledge_base/kb_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def _init_kb_database(self) -> None:
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
await self.kb_db.migrate_to_v2()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")

async def load_kbs(self) -> None:
Expand Down Expand Up @@ -94,6 +95,7 @@ async def create_kb(
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
kb_type: str | None = None,
) -> KBHelper:
"""创建新的知识库实例"""
if embedding_provider_id is None:
Expand All @@ -102,6 +104,7 @@ async def create_kb(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
kb_type=kb_type or "text",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/knowledge_base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class KnowledgeBase(BaseKBModel, table=True):
kb_name: str = Field(max_length=100, nullable=False)
description: str | None = Field(default=None, sa_type=Text)
emoji: str | None = Field(default="📚", max_length=10)
# Knowledge base type: "text" (unstructured documents) or "table"
# (structured row-level data, Coze-like table knowledge base).
kb_type: str = Field(default="text", max_length=20, nullable=False)
embedding_provider_id: str | None = Field(default=None, max_length=100)
rerank_provider_id: str | None = Field(default=None, max_length=100)
# 分块配置参数
Expand Down Expand Up @@ -81,6 +84,9 @@ class KBDocument(BaseKBModel, table=True):
file_type: str = Field(max_length=20, nullable=False)
file_size: int = Field(nullable=False)
file_path: str = Field(max_length=512, nullable=False)
# JSON column configuration for table documents (None for text documents).
# Stores the per-document column schema chosen during table preprocessing.
table_schema: str | None = Field(default=None, sa_type=Text)
chunk_count: int = Field(default=0, nullable=False)
media_count: int = Field(default=0, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
Expand Down
Loading