Files
get_wechat/chatlog_fastAPI/services/topic_engine.py

1095 lines
43 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
话题分类引擎
用户点击 AI 分析后,按所选时间段全量分页拉取消息,解析可用媒体证据,
再分批抽取、跨批合并并校验消息归属。
"""
import asyncio
import json
import logging
import re
from datetime import datetime
import aiosqlite
from fastapi import HTTPException
from database import get_active_db_path
from services.ai_client import get_openai_client
from services.chatlog_client import chatlog_client
from services.message_formatter import append_quote_text, extract_quote
from services.media_parser import parse_media
from services.report_learning import build_report_learning_context
from services.runtime_settings import get_ai_settings
log = logging.getLogger(__name__)
_classify_lock = asyncio.Lock()
_classifying_group: int | None = None
FETCH_PAGE_SIZE = 500
CLASSIFY_BATCH_SIZE = 40
CONTEXT_BEFORE = 4
CONTEXT_AFTER = 4
CLASSIFY_CONTENT_LIMIT = 900
MEDIA_PARSE_CONCURRENCY = 2
UNCATEGORIZED_BUSINESS = "未归类设备问题(需人工复核)"
DAILY_TOPIC = "日常交流/待确认"
DEVICE_ISSUE_EXAMPLES = (
"报警/故障代码", "主轴异常", "刀库/换刀异常", "回零/限位异常",
"伺服/驱动异常", "加工精度异常", "气压/液压异常", "润滑/冷却异常",
"电气线路/IO异常", "系统参数/程序操作问题",
)
BUSINESS_KEYWORDS = (
"设备", "机床", "数控", "加工中心", "CNC", "cnc", "大铁", "客户", "公司", "厂家", "现场",
"报警", "警报", "报错", "故障", "故障码", "代码", "异常", "报修", "维修", "售后", "工程师",
"主轴", "转速", "拉刀", "松刀", "刀库", "换刀", "刀臂", "刀盘", "刀号", "卡刀", "掉刀",
"回零", "原点", "限位", "行程", "伺服", "驱动", "驱动器", "电机", "编码器", "变频器",
"丝杆", "导轨", "精度", "尺寸", "跑偏", "震刀", "振刀", "圆度", "平面度", "加工",
"气压", "液压", "油压", "润滑", "冷却", "水泵", "切削液", "漏油", "漏水", "异响",
"电柜", "线路", "电源", "跳闸", "断电", "IO", "I/O", "PLC", "系统", "参数", "程序",
"G代码", "M代码", "面板", "手轮", "急停", "排屑", "卡住", "调试", "安装", "处理",
)
async def _get_client():
    """Return whatever ``get_openai_client()`` yields; callers unpack it as ``(client, ai_settings)``."""
    client_bundle = await get_openai_client()
    return client_bundle
def get_classifying_group() -> int | None:
    """Return the group_id currently being analysed, or None when no run is active.

    Used by the routers to check the serial-execution state.
    """
    active_group = _classifying_group
    return active_group
def _msg_seq(m: dict) -> int:
    """Return the message sequence number from the first truthy of several key spellings; 0 if none."""
    for key in ("sort_seq", "seq", "Seq", "id"):
        value = m.get(key)
        if value:
            return int(value)
    return 0
def _msg_time(m: dict) -> str:
    """Return the raw timestamp value of a message (first truthy of several keys), or ""."""
    for key in ("create_time", "time", "CreateTime"):
        if m.get(key):
            return m[key]
    return ""
def _msg_sender(m: dict) -> str:
    """Return the best sender label: a display-name field if set, else the raw sender id, else ""."""
    label_keys = ("sender_name", "senderName", "SenderName", "sender", "Sender")
    return next((m.get(key) for key in label_keys if m.get(key)), "")
def _raw_sender_id(m: dict) -> str:
    """Return the raw sender id as a stripped string, or "" when the message has none."""
    raw = m.get("sender") or m.get("Sender")
    if not raw:
        return ""
    return str(raw).strip()
def _looks_like_raw_sender_id(value: str | None) -> bool:
value = (value or "").strip()
return (
not value
or value.startswith("wxid_")
or value.startswith("gh_")
or value.endswith("@chatroom")
or value.startswith("chatroom_")
)
def _sender_display_name(m: dict, sender: str = "") -> str:
    """Return the first usable human-readable name on the message, or "".

    Candidate fields are checked in a fixed priority order. A candidate is rejected when
    it is blank, equal to ``sender`` (the raw id), or itself looks like a raw id.
    """
    name_fields = (
        "sender_name",
        "senderName",
        "SenderName",
        "accountName",
        "groupNickname",
        "displayName",
        "nickName",
        "remark",
    )
    candidates = (str(m.get(field) or "").strip() for field in name_fields)
    return next(
        (
            name
            for name in candidates
            if name and name != sender and not _looks_like_raw_sender_id(name)
        ),
        "",
    )
def _fill_message_sender_names(messages: list[dict]) -> None:
    """Backfill display names, in place, on messages that only carry a raw sender id.

    Pass 1 learns ``sender id -> display name`` from every message that has both
    (a later message overrides an earlier one). Pass 2 writes the learned name into
    ``sender_name`` and ``senderName`` of every message whose current name is empty,
    equal to its sender id, or itself looks like a raw id.
    """
    learned: dict[str, str] = {}
    for msg in messages:
        sender_id = _raw_sender_id(msg)
        display = _sender_display_name(msg, sender_id)
        if sender_id and display:
            learned[sender_id] = display
    if not learned:
        return
    for msg in messages:
        sender_id = _raw_sender_id(msg)
        if not sender_id:
            continue
        display = learned.get(sender_id)
        if display is None:
            continue
        current = str(msg.get("sender_name") or msg.get("senderName") or msg.get("SenderName") or "").strip()
        needs_fill = not current or current == sender_id or _looks_like_raw_sender_id(current)
        if needs_fill:
            msg["sender_name"] = display
            msg["senderName"] = display
def _msg_type(m: dict) -> int:
    """Return the numeric message type, falling back to 1 when missing or not convertible."""
    raw_type = m.get("type") or m.get("Type") or 1
    try:
        return int(raw_type)
    except Exception:
        return 1
def _msg_sub_type(m: dict) -> int:
    """Return the numeric message sub-type, falling back to 0 when missing or not convertible."""
    raw_sub_type = m.get("sub_type") or m.get("subType") or m.get("SubType") or 0
    try:
        return int(raw_sub_type)
    except Exception:
        return 0
def _contents(m: dict) -> dict:
    """Return the structured ``contents`` payload of a message, or {} when absent or not a dict."""
    payload = m.get("contents") or m.get("Contents") or {}
    if not isinstance(payload, dict):
        return {}
    return payload
def _media_key(m: dict) -> str:
    """Return the lookup key of an image/video message, with backslashes turned into "/".

    Priority: ``contents`` rawmd5 / md5 / path, then top-level media_key / mediaKey /
    image_path; "" when none is set.
    """
    contents = _contents(m)
    sources = (
        contents.get("rawmd5"),
        contents.get("md5"),
        contents.get("path"),
        m.get("media_key"),
        m.get("mediaKey"),
        m.get("image_path"),
    )
    key = next((source for source in sources if source), "")
    return str(key).replace("\\", "/")
def _voice_key(m: dict) -> str:
    """Return the voice lookup key: ``contents["voice"]`` if set, else voice_key / voiceKey, else ""."""
    voice = _contents(m).get("voice")
    if voice:
        return str(voice)
    return str(m.get("voice_key") or m.get("voiceKey") or "")
def _message_snapshot(m: dict) -> str:
    """Serialise the raw message as compact JSON (non-ASCII kept verbatim).

    The snapshot is stored with the topic link so it can be displayed as a fallback
    when a later chatlog batch lookup cannot find the message.
    """
    encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
    return encoder.encode(m)
def _base_msg_content(m: dict) -> str:
    """Build the plain-text body of a message from its content and link/file metadata.

    Raw XML content is dropped when ``extract_quote`` finds a quote in it (the quote is
    re-attached by ``append_quote_text``). Link/file fields become labelled parts, one
    per line, followed by the content itself unless it repeats one of those parts.
    """
    content = m.get("content") or m.get("Content") or ""
    if not isinstance(content, str):
        # Guard against non-string payloads; "\n".join below needs strings.
        content = str(content)
    contents = _contents(m)
    if content.lstrip().startswith("<") and extract_quote(m):
        content = ""

    link_title = contents.get("title") or m.get("link_title") or ""
    link_desc = contents.get("desc") or m.get("link_desc") or ""
    link_source = contents.get("sourceName") or contents.get("source_name") or m.get("link_source") or ""
    link_url = contents.get("url") or m.get("link_url") or ""
    file_name = contents.get("fileName") or contents.get("filename") or contents.get("title") or ""

    parts: list[str] = []
    if link_title:
        parts.append(f"[链接/文件] {link_title}")
    # file_name falls back to contents["title"]; don't repeat it when it equals link_title.
    if file_name and file_name != link_title:
        parts.append(f"文件名:{file_name}")
    if link_desc:
        parts.append(f"描述:{link_desc}")
    if link_source:
        parts.append(f"来源:{link_source}")
    if link_url:
        parts.append(f"URL:{link_url}")
    if content and content not in parts:
        parts.append(content)
    # One part per line so labelled fields do not run together.
    return append_quote_text("\n".join(parts), m)
def _is_image_message(m: dict) -> bool:
    """Return True for image messages (message type 3)."""
    message_type = _msg_type(m)
    return message_type == 3
def _is_voice_message(m: dict) -> bool:
    """Return True for voice messages (message type 34)."""
    message_type = _msg_type(m)
    return message_type == 34
def _is_video_message(m: dict) -> bool:
    """Return True for video messages (message type 43)."""
    message_type = _msg_type(m)
    return message_type == 43
def _is_file_or_link_message(m: dict) -> bool:
    """Return True for file/link messages (message type 49)."""
    message_type = _msg_type(m)
    return message_type == 49
def _is_file_message(m: dict) -> bool:
    """Return True for type-49 messages whose sub-type is 6 (presumably file attachments)."""
    if _msg_type(m) != 49:
        return False
    return _msg_sub_type(m) == 6
def _media_kind(m: dict) -> str | None:
    """Return "image", "voice" or "video" for media messages, None for anything else."""
    checks = (
        ("image", _is_image_message),
        ("voice", _is_voice_message),
        ("video", _is_video_message),
    )
    for kind, matches in checks:
        if matches(m):
            return kind
    return None
def _media_parse_key(m: dict) -> str:
    """Return the key used to parse this message's media ("" for non-media messages)."""
    kind = _media_kind(m)
    if kind in {"image", "video"}:
        return _media_key(m)
    if kind == "voice":
        return _voice_key(m)
    return ""
def _message_analysis_text(m: dict) -> str:
    """Return the text the LLM sees for one message, capped at CLASSIFY_CONTENT_LIMIT chars.

    Parts, one per line:
    - the base text of the message;
    - the parsed media text, else the media-parse error, else a placeholder naming the
      media kind and key;
    - "[文件/链接消息]" for a file/link message that has nothing else;
    - the neighbour context, for media, file/link or failed-parse messages.
    Text over the limit is truncated and suffixed with "...".
    """
    kind = _media_kind(m)
    parsed = str(m.get("_ai_media_text") or "").strip()
    parse_error = str(m.get("_ai_media_error") or "").strip()

    parts: list[str] = []
    base = _base_msg_content(m).strip()
    if base:
        parts.append(base)
    if parsed:
        label = {"image": "图片描述", "voice": "语音转写", "video": "视频截图描述"}.get(kind or "", "媒体解析")
        parts.append(f"[{label}] {parsed}")
    elif parse_error:
        parts.append(f"[媒体解析失败] {parse_error}")
    elif kind:
        parts.append(f"[{kind}消息] key={_media_parse_key(m) or ''}")
    if not parts and _is_file_or_link_message(m):
        parts.append("[文件/链接消息]")

    neighbor_context = str(m.get("_neighbor_context") or "").strip()
    if neighbor_context and (kind or _is_file_or_link_message(m) or parse_error):
        parts.append(f"[附近上下文] {neighbor_context}")

    # Separate the parts so labels and texts do not run together.
    text = "\n".join(parts).strip()
    if len(text) > CLASSIFY_CONTENT_LIMIT:
        text = text[:CLASSIFY_CONTENT_LIMIT] + "..."
    return text
def _prompt_item(m: dict, core: bool = True) -> dict:
    """Return the JSON-ready dict describing one message in an LLM prompt.

    Key order (seq, time, sender, core, type, content) is kept because it shows up in
    the serialised prompt.
    """
    return dict(
        seq=_msg_seq(m),
        time=_msg_time(m),
        sender=_msg_sender(m),
        core=core,
        type=_msg_type(m),
        content=_message_analysis_text(m),
    )
def _extract_json_array(text: str) -> list:
    """Pull the JSON list out of an LLM reply, tolerating surrounding prose or code fences.

    Whichever of ``[`` / ``{`` occurs first decides the shape:
    - an array spans from the first ``[`` to the last ``]``;
    - an object spans from the first ``{`` to the last ``}``.

    An object is unwrapped via its ``topics`` or ``items`` key. An object with neither
    key that looks like a single result item (it has one of the item fields the prompts
    in this module ask for) is wrapped as a one-element list rather than silently
    dropped. Any other object yields ``[]``.

    Raises:
        ValueError: no JSON found, a closing bracket is missing, the JSON is malformed
            (``json.JSONDecodeError`` subclasses ``ValueError``), or the payload is not a list.
    """
    item_fields = ("seqs", "seq", "title", "candidate_ids", "topic_index", "new_topic")
    array_start = text.find("[")
    object_start = text.find("{")
    if array_start >= 0 and (object_start < 0 or array_start < object_start):
        end = text.rfind("]") + 1
        if end <= array_start:
            raise ValueError(f"LLM 返回内容无法解析为 JSON 数组: {text[:200]}")
        data = json.loads(text[array_start:end])
    elif object_start >= 0:
        end = text.rfind("}") + 1
        if end <= object_start:
            raise ValueError(f"LLM 返回内容无法解析为 JSON 对象: {text[:200]}")
        data = json.loads(text[object_start:end])
    else:
        raise ValueError(f"LLM 返回内容无法解析为 JSON: {text[:200]}")
    if isinstance(data, dict):
        if "topics" in data or "items" in data:
            data = data.get("topics") or data.get("items") or []
        elif any(field in data for field in item_fields):
            # The model answered with a single bare item instead of a list.
            data = [data]
        else:
            data = []
    if not isinstance(data, list):
        raise ValueError("LLM JSON 不是数组")
    return data
def _parse_msg_time(raw: str) -> float | None:
"""把 chatlog 的 ISO 时间字符串解析为 unix 秒。失败返回 None。"""
if not raw:
return None
try:
s = raw.replace("/", "-")
if "T" not in s and " " in s:
s = s.replace(" ", "T", 1)
return datetime.fromisoformat(s).timestamp()
except Exception:
return None
def _looks_business_message(m: dict) -> bool:
    """Heuristically decide whether a message carries after-sales / device business signal.

    True when any of these holds:
    - it is a file/link or voice message;
    - it is an image/video message with neighbour context or parsed media text;
    - its analysis text contains a BUSINESS_KEYWORDS entry (case-insensitive);
    - its analysis text contains a model/part-number-like token;
    - its analysis text contains an office/CAD/archive file extension.
    """
    lowered = _message_analysis_text(m).lower()
    if _is_file_or_link_message(m) or _is_voice_message(m):
        return True
    is_visual = _is_image_message(m) or _is_video_message(m)
    has_visual_evidence = (
        str(m.get("_neighbor_context") or "").strip()
        or str(m.get("_ai_media_text") or "").strip()
    )
    if is_visual and has_visual_evidence:
        return True
    if any(keyword.lower() in lowered for keyword in BUSINESS_KEYWORDS):
        return True
    part_number = re.search(r"\b(a20\d+|ap\d+|sp\d+|\d+p|[a-z]{2,}\d{2,})\b", lowered)
    if part_number:
        return True
    return bool(re.search(r"\.(dwg|xlsx?|pdf|docx?|zip|rar)\b", lowered))
def _has_text_signal(m: dict) -> bool:
    """Return True when the message has readable text or a desc/title/refer/recordInfo payload."""
    if _base_msg_content(m).strip():
        return True
    contents = _contents(m)
    return any(contents.get(field) for field in ("desc", "title", "refer", "recordInfo"))
def _attach_neighbor_context(messages: list[dict], radius: int = 4, max_lines: int = 6) -> None:
    """Attach nearby text to media and file/link messages as ``m["_neighbor_context"]``.

    The context lets such messages be classified even when their media cannot be parsed.

    For every image/voice/video/file/link message:
    - neighbours within ``radius`` list positions that carry a text signal are summarised,
      each truncated to 180 characters;
    - at most ``max_lines`` are kept, preferring same-sender neighbours and then the
      closest ones;
    - the kept lines are emitted in list order, joined with " | ".

    All contexts are computed before any is assigned, so a context attached to one
    message never leaks into another message's context.
    """
    updates: dict[int, str] = {}
    for idx, m in enumerate(messages):
        if not (_media_kind(m) or _is_file_or_link_message(m)):
            continue
        sender = _msg_sender(m)
        # Each candidate: ((not same sender, distance), position, rendered line).
        candidates: list[tuple[tuple[bool, int], int, str]] = []
        for j in range(max(0, idx - radius), min(len(messages), idx + radius + 1)):
            if j == idx:
                continue
            n = messages[j]
            if not _has_text_signal(n):
                continue
            text = _base_msg_content(n).strip() or _message_analysis_text(n).strip()
            if not text:
                continue
            if len(text) > 180:
                text = text[:180] + "..."
            same_sender = bool(sender) and _msg_sender(n) == sender
            prefix = "前文" if j < idx else "后文"
            relation = "同发送人" if same_sender else "邻近"
            line = f"{prefix}/{relation} seq={_msg_seq(n)} {text}"
            candidates.append(((not same_sender, abs(j - idx)), j, line))
        if not candidates:
            continue
        kept = sorted(candidates)[:max_lines]
        kept.sort(key=lambda candidate: candidate[1])
        updates[idx] = " | ".join(line for _, _, line in kept)
    for idx, context in updates.items():
        messages[idx]["_neighbor_context"] = context
async def _fetch_window_messages(talker: str, start_ts: int, end_ts: int) -> list[dict]:
    """Fetch every message of ``talker`` whose time lies in ``[start_ts, end_ts]``.

    Pages through chatlog over the covering day range, ``FETCH_PAGE_SIZE`` at a time.
    Messages are de-duplicated by seq and filtered to the exact window; messages whose
    time cannot be parsed are kept. The result is sorted by (time, seq).

    Paging stops when:
    - the reported total is reached;
    - a page is empty or short;
    - a page contains no previously unseen seq. This guards against a backend that
      ignores ``offset`` and would otherwise be polled forever.
    """
    start_dt = datetime.fromtimestamp(start_ts)
    end_dt = datetime.fromtimestamp(end_ts)
    time_str = f"{start_dt.strftime('%Y-%m-%d')},{end_dt.strftime('%Y-%m-%d')}"
    all_items: list[dict] = []
    seen: set[int] = set()
    offset = 0
    while True:
        data = await chatlog_client.get_messages(
            talker,
            time=time_str,
            limit=FETCH_PAGE_SIZE,
            offset=offset,
        )
        data = data or {}
        items = data.get("items", []) or []
        total = int(data.get("total") or 0)
        if not items:
            break
        fresh = 0
        for m in items:
            seq = _msg_seq(m)
            if seq:
                if seq in seen:
                    continue
                seen.add(seq)
            fresh += 1
            ts = _parse_msg_time(_msg_time(m))
            if ts is None or start_ts <= ts <= end_ts:
                all_items.append(m)
        offset += len(items)
        if fresh == 0:
            log.warning(f"[classify] talker={talker} page at offset={offset} returned no new messages; stop paging")
            break
        if total and offset >= total:
            break
        if len(items) < FETCH_PAGE_SIZE:
            break
    all_items.sort(key=lambda m: (_parse_msg_time(_msg_time(m)) or 0, _msg_seq(m)))
    return all_items
def _validate_media_settings(messages: list[dict], ai_settings: dict) -> str | None:
has_visual = any(_media_kind(m) in {"image", "video"} for m in messages)
has_voice = any(_media_kind(m) == "voice" for m in messages)
if has_visual and not ai_settings.get("vision_model"):
return "所选时间段包含图片/视频消息,请先在「设置」页面填入视觉模型"
if has_voice and not ai_settings.get("voice_model"):
return "所选时间段包含语音消息,请先在「设置」页面填入语音模型"
return None
async def _parse_message_media(messages: list[dict], update_progress, base_processed: int, total: int) -> int:
    """Parse every image/voice/video message that has a media key, updating messages in place.

    Each message gets ``m["_ai_media_text"]`` on success or ``m["_ai_media_error"]`` on
    failure. At most MEDIA_PARSE_CONCURRENCY ``parse_media`` calls run at once.

    Results are cached per (kind, key). A per-key lock makes messages that share a key
    wait for the first parse instead of issuing duplicate requests at the same time;
    after a failure the next waiter retries. Progress is reported after every message,
    outside the concurrency slot.

    Returns the number of media messages handled (one progress unit each).
    """
    semaphore = asyncio.Semaphore(MEDIA_PARSE_CONCURRENCY)
    media_messages = [m for m in messages if _media_kind(m) and _media_parse_key(m)]
    parsed_count = 0
    cache: dict[tuple[str, str], str] = {}
    key_locks: dict[tuple[str, str], asyncio.Lock] = {}

    async def parse_one(m: dict) -> None:
        nonlocal parsed_count
        kind = _media_kind(m)
        key = _media_parse_key(m)
        if not kind or not key:
            return
        cache_key = (kind, key)
        key_lock = key_locks.setdefault(cache_key, asyncio.Lock())
        try:
            # Lock order is always key lock -> semaphore, so this cannot deadlock.
            async with key_lock:
                if cache_key not in cache:
                    async with semaphore:
                        parsed = await parse_media(kind, key)
                    cache[cache_key] = parsed.get("text") or ""
            m["_ai_media_text"] = cache[cache_key]
        except HTTPException as e:
            m["_ai_media_error"] = str(e.detail)
        except Exception as e:
            log.warning(f"[classify] media parse failed seq={_msg_seq(m)}: {e}", exc_info=True)
            m["_ai_media_error"] = str(e)
        finally:
            parsed_count += 1
            await update_progress("running", base_processed + parsed_count, total)

    await asyncio.gather(*(parse_one(m) for m in media_messages))
    return len(media_messages)
async def _manual_assigned_seqs(db: aiosqlite.Connection, group_id: int) -> set[int]:
    """Return every msg_seq already attached to a manual topic of this group.

    A topic with a NULL source counts as manual (``COALESCE(t.source, 'manual')``).
    """
    query = """
SELECT DISTINCT tm.msg_seq
FROM topic_messages tm
JOIN topics t ON t.id = tm.topic_id
WHERE t.group_id = ? AND COALESCE(t.source, 'manual') = 'manual'
"""
    async with db.execute(query, (group_id,)) as cursor:
        rows = await cursor.fetchall()
    return {int(row["msg_seq"]) for row in rows}
async def _delete_ai_topics(db: aiosqlite.Connection, group_id: int) -> None:
    """Delete all topics with source 'ai' for the group, plus their dependent rows, then commit.

    Deletion order:
    1. FTS rows of the topics' knowledge docs;
    2. the knowledge docs;
    3. the topic/message links;
    4. the topics themselves.
    Topics whose source is NULL or 'manual' are not affected.
    """
    statements = (
        """
DELETE FROM knowledge_fts WHERE doc_id IN (
SELECT d.id FROM knowledge_docs d
JOIN topics t ON t.id = d.topic_id
WHERE t.group_id=? AND COALESCE(t.source, 'manual')='ai'
)
""",
        """
DELETE FROM knowledge_docs
WHERE topic_id IN (
SELECT id FROM topics WHERE group_id=? AND COALESCE(source, 'manual')='ai'
)
""",
        """
DELETE FROM topic_messages
WHERE topic_id IN (
SELECT id FROM topics WHERE group_id=? AND COALESCE(source, 'manual')='ai'
)
""",
        "DELETE FROM topics WHERE group_id=? AND COALESCE(source, 'manual')='ai'",
    )
    params = (group_id,)
    for sql in statements:
        await db.execute(sql, params)
    await db.commit()
def _chunk_messages(messages: list[dict]) -> list[tuple[list[dict], list[dict]]]:
    """Split messages into consecutive batches of CLASSIFY_BATCH_SIZE.

    Each entry is ``(core, context)``, where ``context`` is ``core`` widened by up to
    CONTEXT_BEFORE messages before it and CONTEXT_AFTER messages after it.
    """
    count = len(messages)
    return [
        (
            messages[first:min(count, first + CLASSIFY_BATCH_SIZE)],
            messages[
                max(0, first - CONTEXT_BEFORE):min(count, first + CLASSIFY_BATCH_SIZE + CONTEXT_AFTER)
            ],
        )
        for first in range(0, count, CLASSIFY_BATCH_SIZE)
    ]
def _batch_classify_prompt(core_messages: list[dict], context_messages: list[dict], guidance: str = "") -> str:
    """Build the user prompt for one classification batch.

    Every message in ``context_messages`` is listed. Those whose seq belongs to
    ``core_messages`` are flagged ``core=true``, and only they must be assigned.
    Non-blank ``guidance`` is appended after the output specification, surrounded by
    blank lines.
    """
    core_seqs = {_msg_seq(m) for m in core_messages}
    listed = [_prompt_item(m, _msg_seq(m) in core_seqs) for m in context_messages]
    role_and_scope = (
        "你是广东大铁数控机械有限公司设备售后群话题分析助手。这个微信群里主要是大铁设备客户反馈问题,售后工程师对接处理。\n"
        "请只为 core=true 的消息按【设备问题/故障现象】聚类,而不是按客户公司、地域、日期、工程师或单次对话聚类。\n"
        "同类设备问题即使来自不同公司、不同地区,也要合并到同一话题;不同设备问题必须准确分开。\n"
        f"常见问题口径包括:{', '.join(DEVICE_ISSUE_EXAMPLES)}\n\n"
    )
    message_section = f"消息列表:\n{json.dumps(listed, ensure_ascii=False)}\n\n"
    rules = (
        "规则:\n"
        "1. 只输出 core=true 消息的归属context 只用来理解上下文。\n"
        "2. 标题必须体现设备部件和问题现象,建议如「主轴报警/故障代码问题」「刀库换刀卡顿」「回零/限位异常」。不要把客户公司名作为标题核心。\n"
        "3. 不同公司出现相同设备问题时合并;同一公司出现不同问题时拆开。\n"
        "4. [引用消息] 是当前回复的强上下文,必须用于理解当前消息含义,但归类对象仍然只能是当前 core seq。\n"
        "5. 工程师回复、客户补充说明、图片、语音、视频、文件,必须跟随其对应的设备问题归类。\n"
        "6. 图片/视频无法识别时,不要编造内容,但必须根据附近上下文判断所属设备问题。\n"
        "7. 真正寒暄、入群通知、撤回、无设备售后含义短句才可归入「日常交流/待确认」;设备咨询和报修不得归入日常交流。\n"
        "8. 每条 core 消息最多归入一个话题,尽量覆盖所有 core seq。\n\n"
    )
    output_spec = (
        "输出严格 JSON 数组,不要解释。格式:\n"
        '[{"topic_key":"稳定短键","title":"具体话题标题","seqs":[1,2],"reason":"分类依据"}]'
    )
    prompt = role_and_scope + message_section + rules + output_spec
    extra = guidance.strip() if guidance else ""
    if extra:
        prompt += f"\n\n{extra}\n\n"
    return prompt
async def _classify_batches(
    messages: list[dict],
    update_progress,
    base_processed: int,
    total: int,
    guidance: str = "",
) -> list[dict]:
    """Run the per-batch LLM classification and collect its raw topic items.

    Every dict item returned by the model is tagged in place with its 1-based
    ``_batch`` number. A failing batch (API error, unparsable reply) is logged and
    skipped. Progress advances by one unit per batch either way.
    """
    batches = _chunk_messages(messages)
    client, ai_settings = await _get_client()
    system_prompt = "你是广东大铁数控机械有限公司设备售后问题分类器。你只输出 JSON 数组不输出解释、Markdown 或额外文字。"
    collected: list[dict] = []
    for batch_no, (core, context) in enumerate(batches, start=1):
        user_prompt = _batch_classify_prompt(core, context, guidance)
        try:
            response = await client.chat.completions.create(
                model=ai_settings["ai_model"],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.05,
            )
            reply = response.choices[0].message.content.strip()
            for entry in _extract_json_array(reply):
                if not isinstance(entry, dict):
                    continue
                entry["_batch"] = batch_no
                collected.append(entry)
        except Exception as e:
            log.warning(f"[classify] batch {batch_no} classify failed: {e}", exc_info=True)
        await update_progress("running", base_processed + batch_no, total)
    return collected
def _sanitize_topic_items(items: list[dict], seq_to_msg: dict[int, dict], allowed_seqs: set[int]) -> list[dict]:
    """Normalise raw batch items into clean topic dicts with valid, unique seqs.

    Extraction:
    - title: ``title`` / ``new_topic`` / ``topic``;
    - seqs: ``seqs`` / ``message_seqs`` / ``messages``. A string is scanned for integers
      and any other scalar is treated as a one-element list, rather than crashing or
      being iterated character by character.

    Filtering:
    - a seq is dropped if it is not an int, not allowed, unknown, or already claimed
      by an earlier item or earlier in the same item;
    - non-dict items and items left without a title or seqs are skipped.

    Business-looking messages that the model put under DAILY_TOPIC are diverted to
    UNCATEGORIZED_BUSINESS. Titles are capped at 80 characters (DAILY_TOPIC excepted).
    """
    assigned: set[int] = set()
    clean: list[dict] = []
    for item in items:
        if not isinstance(item, dict):
            continue
        title = str(item.get("title") or item.get("new_topic") or item.get("topic") or "").strip()
        raw_seqs = item.get("seqs") or item.get("message_seqs") or item.get("messages") or []
        if isinstance(raw_seqs, str):
            raw_seqs = re.findall(r"\d+", raw_seqs)
        elif not isinstance(raw_seqs, (list, tuple, set)):
            raw_seqs = [raw_seqs]
        topic_key = str(item.get("topic_key") or title).strip()
        reason = str(item.get("reason") or item.get("evidence") or "").strip()

        clean_seqs: list[int] = []
        taken_here: set[int] = set()
        for seq in raw_seqs:
            try:
                n = int(seq)
            except Exception:
                continue
            if n in allowed_seqs and n in seq_to_msg and n not in assigned and n not in taken_here:
                clean_seqs.append(n)
                taken_here.add(n)
        if not title or not clean_seqs:
            continue

        if title == DAILY_TOPIC:
            business = [n for n in clean_seqs if _looks_business_message(seq_to_msg[n])]
            if business:
                clean.append({
                    "topic_key": UNCATEGORIZED_BUSINESS,
                    "title": UNCATEGORIZED_BUSINESS,
                    "seqs": business,
                    "reason": "模型归入日常交流但消息含业务信号,转入人工复核。",
                })
                assigned.update(business)
                business_set = set(business)
                clean_seqs = [n for n in clean_seqs if n not in business_set]
            if not clean_seqs:
                continue
        elif len(title) > 80:
            title = title[:80]
        clean.append({"topic_key": topic_key, "title": title, "seqs": clean_seqs, "reason": reason})
        assigned.update(clean_seqs)
    return clean
def _merge_prompt(candidates: list[dict]) -> str:
    """Build the cross-batch merge prompt; each candidate is listed with its list index as ``id``."""
    compact: list[dict] = []
    for position, candidate in enumerate(candidates):
        compact.append({
            "id": position,
            "topic_key": candidate.get("topic_key", ""),
            "title": candidate.get("title", ""),
            "seqs": candidate.get("seqs", []),
            "reason": candidate.get("reason", ""),
        })
    instructions = (
        "你是广东大铁数控机械有限公司设备售后话题合并审核员。请把分批识别出的候选话题做跨批次合并。\n"
        "合并粒度是【设备问题/故障现象】。不要按客户公司、地域、日期、工程师或单次对话合并。\n"
        "不同公司出现相同设备问题时必须合并,例如两个客户都反馈主轴报警,应归为同一类主轴报警问题。\n"
        "不同设备部件、不同故障现象、不同处理路径必须分开,例如主轴报警、刀库卡刀、回零限位异常、加工精度异常不能互相合并。\n\n"
    )
    candidate_section = f"候选话题:\n{json.dumps(compact, ensure_ascii=False)}\n\n"
    output_spec = (
        "输出严格 JSON 数组,不要解释。格式:\n"
        '[{"title":"合并后的具体话题标题","candidate_ids":[0,1],"reason":"合并依据"}]'
    )
    return instructions + candidate_section + output_spec
async def _merge_candidates(candidates: list[dict], update_progress, processed: int, total: int) -> list[dict]:
    """Ask the LLM to merge per-batch candidate topics across batches.

    Returns the merged topics (``title`` / ``seqs`` / ``reason``), followed unchanged by
    every candidate the model did not reference. No LLM call is made for 0 or 1
    candidates.

    If the LLM call fails or its reply cannot be parsed, the failure is logged and the
    candidates are returned unmerged, so the later supplement/finalize stages still run.
    This matches how batch classification and supplement assignment degrade.

    Progress advances by one unit, except when there are no candidates.
    """
    if not candidates:
        return []
    if len(candidates) == 1:
        await update_progress("running", processed + 1, total)
        return candidates
    try:
        _client, _ai = await _get_client()
        resp = await _client.chat.completions.create(
            model=_ai["ai_model"],
            messages=[
                {
                    "role": "system",
                    "content": "你是广东大铁数控机械有限公司设备售后话题合并器。你只输出 JSON 数组不输出解释、Markdown 或额外文字。",
                },
                {"role": "user", "content": _merge_prompt(candidates)},
            ],
            temperature=0.05,
        )
        merged_raw = _extract_json_array((resp.choices[0].message.content or "").strip())
    except Exception as e:
        log.warning(f"[classify] merge candidates failed, keeping unmerged candidates: {e}", exc_info=True)
        await update_progress("running", processed + 1, total)
        return candidates

    used: set[int] = set()
    merged: list[dict] = []
    for item in merged_raw:
        if not isinstance(item, dict):
            continue
        ids = item.get("candidate_ids") or item.get("ids") or []
        if not isinstance(ids, (list, tuple)):
            ids = [ids]
        valid_ids: list[int] = []
        for raw_id in ids:
            try:
                idx = int(raw_id)
            except Exception:
                continue
            if 0 <= idx < len(candidates) and idx not in used and idx not in valid_ids:
                valid_ids.append(idx)
        if not valid_ids:
            continue
        title = str(item.get("title") or candidates[valid_ids[0]].get("title") or "").strip()
        if not title:
            continue
        seqs: list[int] = []
        for idx in valid_ids:
            used.add(idx)
            seqs.extend(candidates[idx].get("seqs", []))
        merged.append({
            "title": title[:80] if title != DAILY_TOPIC else title,
            "seqs": seqs,
            "reason": str(item.get("reason") or "").strip(),
        })
    merged.extend(candidate for idx, candidate in enumerate(candidates) if idx not in used)
    await update_progress("running", processed + 1, total)
    return merged
def _supplement_prompt(unassigned: list[dict], topics: list[dict]) -> str:
    """Build the prompt that asks the LLM to place leftover messages into existing or new topics.

    Existing topics are listed by index with their title and at most their first 8 seqs.
    """
    pending = [_prompt_item(msg, True) for msg in unassigned]
    existing = [
        {"index": position, "title": topic["title"], "seqs": topic.get("seqs", [])[:8]}
        for position, topic in enumerate(topics)
    ]
    guidance = (
        "你是广东大铁数控机械有限公司设备售后消息补充分配审核员。请把未归属消息分配到已有话题,或新建设备问题话题。\n"
        "分配粒度是设备问题/故障现象,不是客户公司、地域、日期、工程师或单次对话。\n"
        "不同公司出现相同设备问题时分配到同一话题;不同设备部件、故障现象或处理路径要新建不同话题。\n"
        "[引用消息] 是当前回复的强上下文,必须用于理解当前消息含义,但归类对象仍然只能是当前消息 seq。\n"
        "图片/视频内容无法识别时,必须优先使用消息里的[附近上下文]归入最相关已有设备问题;只有完全孤立且无上下文线索才新建「未归类设备问题(需人工复核)」。\n"
        "设备咨询、报修、售后处理消息不得归入「日常交流/待确认」。只有寒暄、通知、撤回等无售后含义消息才可归入日常交流。\n\n"
    )
    listing = (
        f"已有话题:\n{json.dumps(existing, ensure_ascii=False)}\n\n"
        f"未归属消息:\n{json.dumps(pending, ensure_ascii=False)}\n\n"
    )
    output_spec = (
        "输出严格 JSON 数组,不要解释。格式:\n"
        '[{"seq":1,"topic_index":0,"new_topic":"","reason":"依据"}]\n'
        '若需要新建话题topic_index 用 null并填写 new_topic。'
    )
    return guidance + listing + output_spec
async def _supplement_assignments(
    topics: list[dict],
    messages: list[dict],
    seq_to_msg: dict[int, dict],
    allowed_seqs: set[int],
    update_progress,
    processed: int,
    total: int,
) -> list[dict]:
    """Assign every allowed seq not yet in any topic, using one LLM call plus a fallback.

    Ordering: unassigned messages are shown to the LLM in chronological order, i.e.
    their order in ``messages``. Seqs absent from ``messages`` follow, sorted.

    Assignment: the model may attach a seq to an existing topic by index or name a new
    topic. Anything it leaves unassigned, or puts under DAILY_TOPIC or an empty title,
    goes to UNCATEGORIZED_BUSINESS when it looks business-related and to DAILY_TOPIC
    otherwise.

    Failure: an LLM failure is logged and only the fallback runs.

    ``topics`` is mutated in place and returned; progress advances by one unit.
    """
    assigned = {seq for t in topics for seq in t.get("seqs", [])}
    # allowed_seqs is a set, whose iteration order says nothing about time.
    position = {_msg_seq(m): i for i, m in enumerate(messages)}
    missing_seqs = sorted(
        (s for s in allowed_seqs if s not in assigned),
        key=lambda s: (position.get(s, len(position)), s),
    )
    if not missing_seqs:
        await update_progress("running", processed + 1, total)
        return topics
    missing = [seq_to_msg[s] for s in missing_seqs]

    try:
        _client, _ai = await _get_client()
        resp = await _client.chat.completions.create(
            model=_ai["ai_model"],
            messages=[
                {
                    "role": "system",
                    "content": "你是广东大铁数控机械有限公司设备售后消息补充分配器。你只输出 JSON 数组不输出解释、Markdown 或额外文字。",
                },
                {"role": "user", "content": _supplement_prompt(missing, topics)},
            ],
            temperature=0.05,
        )
        results = _extract_json_array((resp.choices[0].message.content or "").strip())
    except Exception as e:
        log.warning(f"[classify] supplement assignment failed: {e}", exc_info=True)
        results = []

    topic_by_title = {t["title"]: t for t in topics}

    def bucket(title: str, reason: str) -> dict:
        # Fetch the topic with this title, creating (and appending) it on first use.
        if title not in topic_by_title:
            topic_by_title[title] = {"title": title, "seqs": [], "reason": reason}
            topics.append(topic_by_title[title])
        return topic_by_title[title]

    for item in results:
        if not isinstance(item, dict):
            continue
        try:
            seq = int(item.get("seq"))
        except Exception:
            continue
        if seq not in allowed_seqs or seq in assigned:
            continue
        title = ""
        idx = item.get("topic_index")
        try:
            if idx is not None and 0 <= int(idx) < len(topics):
                title = topics[int(idx)]["title"]
        except Exception:
            title = ""
        if not title:
            title = str(item.get("new_topic") or "").strip()
        if not title or title == DAILY_TOPIC:
            title = UNCATEGORIZED_BUSINESS if _looks_business_message(seq_to_msg[seq]) else DAILY_TOPIC
        if title != DAILY_TOPIC and len(title) > 80:
            title = title[:80]
        bucket(title, str(item.get("reason") or "").strip())["seqs"].append(seq)
        assigned.add(seq)

    for seq in missing_seqs:
        if seq in assigned:
            continue
        title = UNCATEGORIZED_BUSINESS if _looks_business_message(seq_to_msg[seq]) else DAILY_TOPIC
        bucket(title, "补充分配后仍无法归入具体话题")["seqs"].append(seq)
        assigned.add(seq)

    await update_progress("running", processed + 1, total)
    return topics
def _finalize_topics(topics: list[dict], seq_to_msg: dict[int, dict], allowed_seqs: set[int]) -> list[dict]:
    """Produce the final topic list: every allowed seq in exactly one topic, titles unique.

    - Seqs that are not ints, not allowed, unknown, or already claimed are dropped.
    - Inside a DAILY_TOPIC topic, business-looking messages are moved individually to
      UNCATEGORIZED_BUSINESS and the remaining small talk stays under DAILY_TOPIC.
      ``_sanitize_topic_items`` applies the same per-message rule.
    - Topics whose (80-char capped) titles match are merged; the first topic's reason
      is kept.
    - Allowed seqs still unassigned go, in seq order, to UNCATEGORIZED_BUSINESS or
      DAILY_TOPIC.

    Topics are returned ordered by their smallest seq.
    """
    assigned: set[int] = set()
    grouped: dict[str, dict] = {}

    def add(title: str, seqs: list[int], reason: str) -> None:
        # Same title -> same topic; the first reason seen for a title wins.
        if title not in grouped:
            grouped[title] = {"title": title, "seqs": [], "reason": reason}
        grouped[title]["seqs"].extend(seqs)

    for topic in topics:
        title = str(topic.get("title") or "").strip()
        if not title:
            continue
        if title != DAILY_TOPIC:
            title = title[:80]
        reason = str(topic.get("reason") or "")
        kept: list[int] = []
        diverted: list[int] = []
        for seq in topic.get("seqs") or []:
            try:
                n = int(seq)
            except Exception:
                continue
            if n not in allowed_seqs or n not in seq_to_msg or n in assigned:
                continue
            assigned.add(n)
            if title == DAILY_TOPIC and _looks_business_message(seq_to_msg[n]):
                diverted.append(n)
            else:
                kept.append(n)
        if kept:
            add(title, kept, reason)
        if diverted:
            add(UNCATEGORIZED_BUSINESS, diverted, reason)

    for seq in sorted(s for s in allowed_seqs if s not in assigned):
        title = UNCATEGORIZED_BUSINESS if _looks_business_message(seq_to_msg[seq]) else DAILY_TOPIC
        add(title, [seq], "最终兜底归属")

    result = list(grouped.values())
    result.sort(key=lambda t: min(t["seqs"]) if t["seqs"] else 0)
    return result
def _topic_text(topic: dict, seq_to_msg: dict[int, dict]) -> str:
    """Join a topic's title, reason and each member message's analysis text, one per line.

    Seqs that are not ints or not present in ``seq_to_msg`` are skipped.
    """
    header = [str(topic.get("title") or ""), str(topic.get("reason") or "")]
    bodies: list[str] = []
    for raw_seq in topic.get("seqs", []):
        try:
            member = seq_to_msg[int(raw_seq)]
        except Exception:
            continue
        bodies.append(_message_analysis_text(member))
    return "\n".join(header + bodies)
def _strip_customer_prefix(title: str) -> str:
    """Remove a leading customer/company label such as "XX公司-" or "客户XX:" from a title.

    Two patterns anchored at the start are applied in turn, and whitespace is
    re-stripped after each.
    """
    cleaned = title.strip()
    company_then_separator = r"^[\u4e00-\u9fa5A-Za-z0-9()#_-]{2,30}(?:公司|工厂|厂|客户|现场|车间)[-—:/]+"
    role_word_then_separator = r"^(?:客户|厂家|现场|公司)[\u4e00-\u9fa5A-Za-z0-9()#_-]{0,20}[-—:/]+"
    for prefix_pattern in (company_then_separator, role_word_then_separator):
        cleaned = re.sub(prefix_pattern, "", cleaned).strip()
    return cleaned
def _infer_device_issue_title(topic: dict, seq_to_msg: dict[int, dict]) -> str:
    """Return a device-issue title for a topic, mapping broad/generic titles to a category.

    - Empty titles and DAILY_TOPIC are returned unchanged.
    - A leading customer/company prefix is stripped. If nothing would remain, the
      unstripped title is kept, so the topic is not dropped during coalescing.
    - Specific titles are kept, capped at 80 characters.
    - Broad titles (UNCATEGORIZED_BUSINESS and a few generic labels, optionally
      "-"-suffixed) are replaced by the first category whose keywords occur in the
      topic text (title, reason and member messages). If none matches, the stripped
      title is kept.
    """
    original = str(topic.get("title") or "").strip()
    if not original or original == DAILY_TOPIC:
        return original
    original = _strip_customer_prefix(original) or original

    broad_titles = (
        UNCATEGORIZED_BUSINESS,
        "现场故障/售后处理",
        "业务事项跟进",
        "沟通响应跟进",
    )
    is_broad = original in broad_titles or any(original.endswith(f"-{broad}") for broad in broad_titles)
    if not is_broad:
        return original[:80] if original != DAILY_TOPIC else original

    text = _topic_text(topic, seq_to_msg)

    # Spindle issues are split further into alarm vs. generic abnormality.
    if any(k in text for k in ("主轴", "转速", "拉刀", "松刀")):
        if any(k in text for k in ("报警", "警报", "报错", "故障码", "代码")):
            return "主轴报警/故障代码问题"
        return "主轴异常问题"

    # Checked in order; the first category whose keyword appears wins.
    keyword_categories = (
        (("刀库", "换刀", "刀臂", "刀盘", "刀号", "卡刀", "掉刀"), "刀库/换刀异常问题"),
        (("回零", "原点", "限位", "行程"), "回零/限位异常问题"),
        (("伺服", "驱动", "驱动器", "电机", "编码器", "变频器"), "伺服/驱动异常问题"),
        (("精度", "尺寸", "跑偏", "震刀", "振刀", "圆度", "平面度"), "加工精度异常问题"),
        (("气压", "液压", "油压", "漏油"), "气压/液压异常问题"),
        (("润滑", "冷却", "水泵", "切削液", "漏水"), "润滑/冷却异常问题"),
        (("电柜", "线路", "电源", "跳闸", "断电", "IO", "I/O", "PLC", "急停"), "电气线路/IO异常问题"),
        (("系统", "参数", "程序", "G代码", "M代码", "面板", "手轮"), "系统参数/程序操作问题"),
    )
    for keywords, category in keyword_categories:
        if any(k in text for k in keywords):
            return category

    if any(k in text for k in ("报警", "警报", "报错", "故障码")) or re.search(
        r"\b(error|alarm|err|e\d+)\b", text.lower()
    ):
        return "设备报警/故障代码问题"
    if any(k in text for k in ("异响", "卡住", "排屑")):
        return "设备机械异常问题"
    return original
def _coalesce_device_issue_topics(topics: list[dict], seq_to_msg: dict[int, dict]) -> list[dict]:
    """Re-title topics by device issue and merge the ones that end up sharing a title.

    Topics whose inferred title is empty are dropped. Seqs are de-duplicated across the
    whole result: the first occurrence wins, in order of first title appearance. Merged
    topics get a fixed reason, titles are capped at 80 characters (DAILY_TOPIC
    excepted), and the result is ordered by smallest seq.
    """
    buckets: dict[str, list] = {}
    for topic in topics:
        new_title = _infer_device_issue_title(topic, seq_to_msg)
        if new_title:
            buckets.setdefault(new_title, []).extend(topic.get("seqs", []))

    claimed: set[int] = set()
    merged: list[dict] = []
    for new_title, raw_seqs in buckets.items():
        unique: list[int] = []
        for raw in raw_seqs:
            try:
                seq = int(raw)
            except Exception:
                continue
            if seq not in claimed:
                claimed.add(seq)
                unique.append(seq)
        if not unique:
            continue
        merged.append({
            "title": new_title if new_title == DAILY_TOPIC else new_title[:80],
            "seqs": unique,
            "reason": "按设备问题/故障现象自动合并",
        })
    merged.sort(key=lambda t: min(t["seqs"]) if t["seqs"] else 0)
    return merged
async def _save_topics(db: aiosqlite.Connection, group_id: int, talker: str, topics: list[dict], seq_to_msg: dict[int, dict]) -> None:
    """Insert AI topics and their message links, committing once at the end.

    Each topic with at least one seq becomes a ``topics`` row with source 'ai'. Its seqs
    are linked in ``topic_messages`` together with a JSON snapshot of the message;
    INSERT OR IGNORE leaves an existing link as-is. The new topic id comes from the
    insert cursor's ``lastrowid``, so no extra ``SELECT last_insert_rowid()`` and no
    per-topic commit are needed.

    Committing once means a failure part-way leaves no half-saved result. sqlite3 does
    not commit on close, so uncommitted rows are lost when the connection closes.
    """
    insert_link_sql = """
INSERT OR IGNORE INTO topic_messages
(topic_id, msg_seq, talker, added_by, message_json)
VALUES (?, ?, ?, 'ai', ?)
"""
    for topic in topics:
        seqs = topic.get("seqs", [])
        if not seqs:
            continue
        async with db.execute(
            "INSERT INTO topics (group_id, title, source) VALUES (?, ?, 'ai')",
            (group_id, topic["title"]),
        ) as cur:
            topic_id = cur.lastrowid
        await db.executemany(
            insert_link_sql,
            [(topic_id, seq, talker, _message_snapshot(seq_to_msg[seq])) for seq in seqs],
        )
    await db.commit()
async def _build_analysis_guidance(
    db: aiosqlite.Connection,
    group_id: int,
    group: dict,
    ai_settings: dict,
    classifiable: list[dict],
) -> str:
    """Return extra prompt guidance for batch classification ("" when there is none).

    Combines the group's (or the global) custom analysis prompt with report-library
    learning context retrieved for a sample of the first 80 classifiable messages.
    """
    configured_prompt = (group.get("analysis_prompt") or ai_settings.get("topic_analysis_prompt") or "").strip()
    sample_query = "\n".join(_message_analysis_text(m) for m in classifiable[:80])
    learning_context = await build_report_learning_context(
        db,
        group_id=group_id,
        query=sample_query,
        purpose="topic",
    )
    guidance_parts: list[str] = []
    if configured_prompt:
        guidance_parts.append(f"客户自定义话题分析提示词:\n{configured_prompt}")
    if learning_context:
        guidance_parts.append(
            "报告库学习参考:以下是人工修订过的历史报告,请学习其话题命名、分类颗粒度和关注点;不要复制历史事实。\n"
            f"{learning_context}"
        )
    return "\n\n".join(guidance_parts)


async def run_classify_window(
    group_id: int,
    task_id: int,
    group: dict,
    start_ts: int,
    end_ts: int,
):
    """Run AI topic classification over every message of ``group`` in [start_ts, end_ts].

    Pipeline:
    1. fetch the messages;
    2. backfill sender names;
    3. validate media settings;
    4. attach neighbour context;
    5. parse media;
    6. batch classify, then merge, then finalize;
    7. supplement, then finalize;
    8. coalesce by device issue, then finalize;
    9. replace the group's AI topics.

    Behaviour:
    - Runs are serialised by ``_classify_lock``; ``_classifying_group`` exposes the
      active group while a run holds the lock.
    - Messages already in manual topics are excluded; manual topics are never touched.
    - Old AI topics of the group are deleted only once the new topics are ready to be
      saved, so a failure part-way keeps the previous AI result.
    - Progress and final status are written to ``ai_tasks`` row ``task_id``. Errors are
      recorded there instead of being raised.
    """
    global _classifying_group
    path = get_active_db_path()

    async def _update_task(status: str, processed: int = 0, total: int = 0, error: str = ""):
        # Best-effort progress write on its own connection; failures are only logged.
        try:
            async with aiosqlite.connect(path) as _db:
                _db.row_factory = aiosqlite.Row
                if error:
                    await _db.execute(
                        "UPDATE ai_tasks SET status=?, progress=?, error=?, updated_at=CURRENT_TIMESTAMP WHERE id=?",
                        (status, json.dumps({"processed": processed, "total": total}), error, task_id),
                    )
                else:
                    await _db.execute(
                        "UPDATE ai_tasks SET status=?, progress=?, updated_at=CURRENT_TIMESTAMP WHERE id=?",
                        (status, json.dumps({"processed": processed, "total": total}), task_id),
                    )
                await _db.commit()
        except Exception as e:
            log.warning(f"[classify] 更新 task {task_id} 失败: {e}")

    _s = await get_ai_settings()
    if not _s.get("ai_api_key") or not _s.get("ai_model"):
        await _update_task("error", 0, 0, "AI 未配置,请在「设置」页面填入 API Key 和话题分析模型")
        log.warning(f"[classify] group={group_id} aborted: AI not configured")
        return

    async with _classify_lock:
        _classifying_group = group_id
        try:
            try:
                messages = await _fetch_window_messages(group["talker"], start_ts, end_ts)
            except Exception as e:
                log.error(f"[classify] group={group_id} fetch error: {e}", exc_info=True)
                await _update_task("error", 0, 0, f"拉取聊天记录失败:{e}")
                return
            total_messages = len(messages)
            log.info(f"[classify] group={group_id} got {total_messages} msgs in selected window")
            if total_messages == 0:
                await _update_task("done", 0, 0)
                return

            _fill_message_sender_names(messages)
            media_error = _validate_media_settings(messages, _s)
            if media_error:
                await _update_task("error", 0, total_messages, media_error)
                return
            _attach_neighbor_context(messages)

            # Progress units: fetched messages + media parses + batches + merge + supplement.
            media_count = sum(1 for m in messages if _media_kind(m) and _media_parse_key(m))
            chunks_count = len(_chunk_messages(messages))
            total_units = total_messages + media_count + chunks_count + 2
            await _update_task("running", total_messages, total_units)

            async with aiosqlite.connect(path) as db:
                db.row_factory = aiosqlite.Row
                manual_seqs = await _manual_assigned_seqs(db, group_id)
                parsed_units = await _parse_message_media(messages, _update_task, total_messages, total_units)
                classifiable = [m for m in messages if _msg_seq(m) and _msg_seq(m) not in manual_seqs]
                seq_to_msg = {_msg_seq(m): m for m in classifiable if _msg_seq(m)}
                allowed_seqs = set(seq_to_msg)
                if not classifiable:
                    # Everything is manually assigned: the rerun still clears old AI topics.
                    await _delete_ai_topics(db, group_id)
                    await _update_task("done", total_units, total_units)
                    return

                # NOTE(review): learning context is now gathered before old AI topics are
                # deleted; confirm build_report_learning_context does not read AI topics.
                analysis_guidance = await _build_analysis_guidance(db, group_id, group, _s, classifiable)

                base = total_messages + parsed_units
                candidates_raw = await _classify_batches(
                    classifiable,
                    _update_task,
                    base,
                    total_units,
                    analysis_guidance,
                )
                candidates = _sanitize_topic_items(candidates_raw, seq_to_msg, allowed_seqs)
                processed_after_batches = base + chunks_count
                merged = await _merge_candidates(candidates, _update_task, processed_after_batches, total_units)
                merged = _finalize_topics(merged, seq_to_msg, allowed_seqs)
                supplemented = await _supplement_assignments(
                    merged,
                    classifiable,
                    seq_to_msg,
                    allowed_seqs,
                    _update_task,
                    processed_after_batches + 1,
                    total_units,
                )
                final_topics = _finalize_topics(supplemented, seq_to_msg, allowed_seqs)
                final_topics = _coalesce_device_issue_topics(final_topics, seq_to_msg)
                final_topics = _finalize_topics(final_topics, seq_to_msg, allowed_seqs)

                # Replace the previous AI result only now that the new one is complete.
                await _delete_ai_topics(db, group_id)
                await _save_topics(db, group_id, group["talker"], final_topics, seq_to_msg)
            await _update_task("done", total_units, total_units)
            log.info(f"[classify] group={group_id} done, messages={total_messages}, topics={len(final_topics)}")
        except Exception as e:
            log.error(f"[classify] group={group_id} error: {e}", exc_info=True)
            # str(e) can be empty; _update_task only records non-empty errors.
            await _update_task("error", 0, 0, str(e) or type(e).__name__)
        finally:
            _classifying_group = None