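"""
Topic classification engine.

After the user clicks "AI analysis", every message in the selected time range is
fetched page by page, usable media evidence is parsed, and the messages are then
extracted in batches, merged across batches, and validated for topic ownership.
"""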
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
from datetime import datetime
|
||
import aiosqlite
|
||
from fastapi import HTTPException
|
||
|
||
from database import get_active_db_path
|
||
from services.ai_client import get_openai_client
|
||
from services.chatlog_client import chatlog_client
|
||
from services.message_formatter import append_quote_text, extract_quote
|
||
from services.media_parser import parse_media
|
||
from services.report_learning import build_report_learning_context
|
||
from services.runtime_settings import get_ai_settings
|
||
|
||
log = logging.getLogger(__name__)

# Lock held by run_classify_window while a classification run is in progress,
# and the group_id recorded by the run currently holding it (None when idle).
_classify_lock = asyncio.Lock()
_classifying_group: int | None = None

# Page size passed as ``limit`` when paging messages from chatlog_client.
FETCH_PAGE_SIZE = 500
# Number of core messages placed in each classification chunk.
CLASSIFY_BATCH_SIZE = 40
# Extra messages included before / after each chunk as read-only context.
CONTEXT_BEFORE = 4
CONTEXT_AFTER = 4
# Maximum length of a message's analysis text before it is cut and suffixed with "...".
CLASSIFY_CONTENT_LIMIT = 900
# Size of the semaphore used while parsing media.
MEDIA_PARSE_CONCURRENCY = 2
# Fallback topic titles: business-looking messages needing manual review, and daily chat.
UNCATEGORIZED_BUSINESS = "未归类设备问题(需人工复核)"
DAILY_TOPIC = "日常交流/待确认"
# Example device-issue categories interpolated into the batch classification prompt.
DEVICE_ISSUE_EXAMPLES = (
    "报警/故障代码", "主轴异常", "刀库/换刀异常", "回零/限位异常",
    "伺服/驱动异常", "加工精度异常", "气压/液压异常", "润滑/冷却异常",
    "电气线路/IO异常", "系统参数/程序操作问题",
)

# Substrings that _looks_business_message searches for (case-insensitively)
# in a message's analysis text.
BUSINESS_KEYWORDS = (
    "设备", "机床", "数控", "加工中心", "CNC", "cnc", "大铁", "客户", "公司", "厂家", "现场",
    "报警", "警报", "报错", "故障", "故障码", "代码", "异常", "报修", "维修", "售后", "工程师",
    "主轴", "转速", "拉刀", "松刀", "刀库", "换刀", "刀臂", "刀盘", "刀号", "卡刀", "掉刀",
    "回零", "原点", "限位", "行程", "伺服", "驱动", "驱动器", "电机", "编码器", "变频器",
    "丝杆", "导轨", "精度", "尺寸", "跑偏", "震刀", "振刀", "圆度", "平面度", "加工",
    "气压", "液压", "油压", "润滑", "冷却", "水泵", "切削液", "漏油", "漏水", "异响",
    "电柜", "线路", "电源", "跳闸", "断电", "IO", "I/O", "PLC", "系统", "参数", "程序",
    "G代码", "M代码", "面板", "手轮", "急停", "排屑", "卡住", "调试", "安装", "处理",
)
|
||
|
||
|
||
async def _get_client():
    """Return the result of ``get_openai_client()``.

    Callers in this module unpack the result as ``(client, ai_settings)``.
    """
    return await get_openai_client()
|
||
|
||
|
||
def get_classifying_group() -> int | None:
    """Return the group_id currently being analyzed, or None when idle.

    Intended for routers that need to check the serial-execution state.
    """
    return _classifying_group
|
||
|
||
|
||
def _msg_seq(m: dict) -> int:
    """Return the message sequence id as an int, or 0 when none is usable.

    The keys ``sort_seq``, ``seq``, ``Seq`` and ``id`` are tried in that order
    and the first truthy value wins. A value that cannot be converted to an
    int is skipped instead of raising, which matches the tolerant behaviour
    of ``_msg_type`` / ``_msg_sub_type``.
    """
    for key in ("sort_seq", "seq", "Seq", "id"):
        value = m.get(key)
        if not value:
            continue
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # Malformed id under this key; try the next candidate key.
            continue
    return 0
|
||
|
||
|
||
def _msg_time(m: dict) -> str:
    """Return the first truthy timestamp among ``create_time``, ``time``, ``CreateTime``.

    Returns an empty string when none of those keys holds a truthy value.
    """
    for field in ("create_time", "time", "CreateTime"):
        stamp = m.get(field)
        if stamp:
            return stamp
    return ""
|
||
|
||
|
||
def _msg_sender(m: dict) -> str:
    """Return the best available sender label for a message.

    Display-name keys are preferred over the raw ``sender`` / ``Sender``
    fields; an empty string is returned when nothing is set.
    """
    lookup_order = ("sender_name", "senderName", "SenderName", "sender", "Sender")
    for field in lookup_order:
        label = m.get(field)
        if label:
            return label
    return ""
|
||
|
||
|
||
def _raw_sender_id(m: dict) -> str:
    """Return the stripped raw sender id from ``sender`` or ``Sender`` ("" if absent)."""
    identifier = m.get("sender")
    if not identifier:
        identifier = m.get("Sender")
    if not identifier:
        identifier = ""
    return str(identifier).strip()
|
||
|
||
|
||
def _looks_like_raw_sender_id(value: str | None) -> bool:
|
||
value = (value or "").strip()
|
||
return (
|
||
not value
|
||
or value.startswith("wxid_")
|
||
or value.startswith("gh_")
|
||
or value.endswith("@chatroom")
|
||
or value.startswith("chatroom_")
|
||
)
|
||
|
||
|
||
def _sender_display_name(m: dict, sender: str = "") -> str:
    """Return the first human-readable name found on the message.

    A candidate is accepted when it is non-empty, differs from *sender*, and
    does not look like a raw id. Returns "" when no field qualifies.
    """
    name_fields = (
        "sender_name",
        "senderName",
        "SenderName",
        "accountName",
        "groupNickname",
        "displayName",
        "nickName",
        "remark",
    )
    candidates = (str(m.get(field) or "").strip() for field in name_fields)
    usable = (
        name
        for name in candidates
        if name and name != sender and not _looks_like_raw_sender_id(name)
    )
    return next(usable, "")
|
||
|
||
|
||
def _fill_message_sender_names(messages: list[dict]) -> None:
    """Back-fill sender display names across a batch of messages, in place.

    First pass: learn a display name per raw sender id (a later message
    overrides an earlier one). Second pass: write the learned name into
    ``sender_name`` / ``senderName`` on every message whose current name is
    missing, equal to the raw id, or shaped like a raw id.
    """
    known: dict[str, str] = {}
    for msg in messages:
        raw_id = _raw_sender_id(msg)
        display = _sender_display_name(msg, raw_id)
        if raw_id and display:
            known[raw_id] = display

    if not known:
        return

    for msg in messages:
        raw_id = _raw_sender_id(msg)
        display = known.get(raw_id) if raw_id else None
        if display is None:
            continue
        existing = str(
            msg.get("sender_name") or msg.get("senderName") or msg.get("SenderName") or ""
        ).strip()
        has_real_name = bool(existing) and existing != raw_id and not _looks_like_raw_sender_id(existing)
        if has_real_name:
            continue
        msg["sender_name"] = display
        msg["senderName"] = display
|
||
|
||
|
||
def _msg_type(m: dict) -> int:
    """Return the message type code from ``type`` / ``Type``.

    Defaults to 1 when the field is missing, falsy, or not convertible to int.
    """
    default_type = 1
    try:
        for field in ("type", "Type"):
            raw = m.get(field)
            if raw:
                return int(raw)
    except Exception:
        return default_type
    return default_type
|
||
|
||
|
||
def _msg_sub_type(m: dict) -> int:
    """Return the message sub-type from ``sub_type`` / ``subType`` / ``SubType``.

    Defaults to 0 when the field is missing, falsy, or not convertible to int.
    """
    default_sub_type = 0
    try:
        for field in ("sub_type", "subType", "SubType"):
            raw = m.get(field)
            if raw:
                return int(raw)
    except Exception:
        return default_sub_type
    return default_sub_type
|
||
|
||
|
||
def _contents(m: dict) -> dict:
    """Return the message's ``contents`` / ``Contents`` mapping.

    An empty dict is returned when the field is missing, falsy, or not a dict.
    """
    payload = m.get("contents") or m.get("Contents") or {}
    if isinstance(payload, dict):
        return payload
    return {}
|
||
|
||
|
||
def _media_key(m: dict) -> str:
    """Return the key identifying an image/video payload, with "/" separators.

    Lookup order: ``contents.rawmd5``, ``contents.md5``, ``contents.path``,
    then the message-level ``media_key``, ``mediaKey``, ``image_path``.
    Backslashes are normalised to forward slashes; "" when nothing is set.
    """
    inner = _contents(m)
    candidates = (
        inner.get("rawmd5"),
        inner.get("md5"),
        inner.get("path"),
        m.get("media_key"),
        m.get("mediaKey"),
        m.get("image_path"),
    )
    chosen = next((value for value in candidates if value), "")
    return str(chosen).replace("\\", "/")
|
||
|
||
|
||
def _voice_key(m: dict) -> str:
    """Return the voice payload key.

    ``contents.voice`` is preferred, then the message-level ``voice_key`` /
    ``voiceKey``; "" when none is set.
    """
    voice = _contents(m).get("voice") or m.get("voice_key") or m.get("voiceKey") or ""
    return str(voice)
|
||
|
||
|
||
def _message_snapshot(m: dict) -> str:
    """Serialize the raw message captured at classification time as compact JSON.

    The snapshot is stored so the message can still be displayed later if a
    chatlog batch lookup no longer finds it.
    """
    encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
    return encoder.encode(m)
|
||
|
||
|
||
def _base_msg_content(m: dict) -> str:
    """Build the textual body of a message, including link/file metadata.

    Text comes from ``content`` / ``Content``. When that text is XML-looking
    (starts with "<") and ``extract_quote`` finds a quote on the message, the
    raw text is discarded and the quote is appended by ``append_quote_text``.

    Link/file metadata (title, file name, description, source, URL) is
    rendered as labelled segments joined with ";". Values that merely repeat
    an earlier segment are skipped: the file name falls back to
    ``contents["title"]``, so without the comparison against ``link_title``
    the same title would be emitted twice.
    """
    content = m.get("content") or m.get("Content") or ""
    contents = _contents(m)
    if isinstance(content, str) and content.lstrip().startswith("<") and extract_quote(m):
        content = ""

    link_title = contents.get("title") or m.get("link_title") or ""
    link_desc = contents.get("desc") or m.get("link_desc") or ""
    link_source = contents.get("sourceName") or contents.get("source_name") or m.get("link_source") or ""
    link_url = contents.get("url") or m.get("link_url") or ""
    file_name = contents.get("fileName") or contents.get("filename") or contents.get("title") or ""

    parts: list[str] = []
    if link_title:
        parts.append(f"[链接/文件] {link_title}")
    # Compare against the raw title: `parts` only holds prefixed strings, so a
    # membership test on it can never detect the duplicate.
    if file_name and file_name != link_title:
        parts.append(f"文件名:{file_name}")
    if link_desc:
        parts.append(f"描述:{link_desc}")
    if link_source:
        parts.append(f"来源:{link_source}")
    if link_url:
        parts.append(f"URL:{link_url}")
    already_shown = (link_title, file_name, link_desc, link_url)
    if content and content not in parts and content not in already_shown:
        parts.append(content)
    return append_quote_text(";".join(parts) if parts else content, m)
|
||
|
||
|
||
def _is_image_message(m: dict) -> bool:
    """Return True when the message type code is 3 (handled as an image here)."""
    type_code = _msg_type(m)
    return type_code == 3
|
||
|
||
|
||
def _is_voice_message(m: dict) -> bool:
    """Return True when the message type code is 34 (handled as voice here)."""
    type_code = _msg_type(m)
    return type_code == 34
|
||
|
||
|
||
def _is_video_message(m: dict) -> bool:
    """Return True when the message type code is 43 (handled as video here)."""
    type_code = _msg_type(m)
    return type_code == 43
|
||
|
||
|
||
def _is_file_or_link_message(m: dict) -> bool:
    """Return True when the message type code is 49 (handled as file/link here)."""
    type_code = _msg_type(m)
    return type_code == 49
|
||
|
||
|
||
def _is_file_message(m: dict) -> bool:
    """Return True for type code 49 with sub-type 6 (handled as a file here)."""
    if _msg_type(m) != 49:
        return False
    return _msg_sub_type(m) == 6
|
||
|
||
|
||
def _media_kind(m: dict) -> str | None:
    """Map the message type code to "image", "voice" or "video".

    Returns None for every other type code.
    """
    kind_by_type = {3: "image", 34: "voice", 43: "video"}
    return kind_by_type.get(_msg_type(m))
|
||
|
||
|
||
def _media_parse_key(m: dict) -> str:
    """Return the key to hand to the media parser for this message.

    Voice messages use ``_voice_key``; images and videos use ``_media_key``;
    non-media messages yield "".
    """
    resolvers = {"voice": _voice_key, "image": _media_key, "video": _media_key}
    resolver = resolvers.get(_media_kind(m))
    if resolver is None:
        return ""
    return resolver(m)
|
||
|
||
|
||
def _message_analysis_text(m: dict) -> str:
    """Compose the text shown to the LLM for one message.

    Segments, joined with ";":
      * the base content, if any;
      * exactly one media segment: the parsed media text, else the parse
        error, else a placeholder carrying the media key;
      * a placeholder for a file/link message that produced nothing else;
      * the neighbouring-context note, only for media, file/link, or
        parse-error messages.
    The result is truncated to ``CLASSIFY_CONTENT_LIMIT`` characters plus "...".
    """
    kind = _media_kind(m)
    is_attachment = _is_file_or_link_message(m)
    media_text = str(m.get("_ai_media_text") or "").strip()
    media_error = str(m.get("_ai_media_error") or "").strip()

    segments: list[str] = []
    body = _base_msg_content(m).strip()
    if body:
        segments.append(body)

    if media_text:
        labels = {"image": "图片描述", "voice": "语音转写", "video": "视频截图描述"}
        label = labels.get(kind or "", "媒体解析")
        segments.append(f"[{label}] {media_text}")
    elif media_error:
        segments.append(f"[媒体解析失败] {media_error}")
    elif kind:
        segments.append(f"[{kind}消息] key={_media_parse_key(m) or '无'}")

    if is_attachment and not segments:
        segments.append("[文件/链接消息]")

    nearby = str(m.get("_neighbor_context") or "").strip()
    if nearby and (kind or is_attachment or media_error):
        segments.append(f"[附近上下文] {nearby}")

    joined = ";".join(segments).strip()
    if len(joined) <= CLASSIFY_CONTENT_LIMIT:
        return joined
    return joined[:CLASSIFY_CONTENT_LIMIT] + "..."
|
||
|
||
|
||
def _prompt_item(m: dict, core: bool = True) -> dict:
    """Project a message onto the compact record embedded in LLM prompts.

    *core* marks whether the message must be classified (True) or is only
    supplied as surrounding context (False).
    """
    record = {"seq": _msg_seq(m)}
    record["time"] = _msg_time(m)
    record["sender"] = _msg_sender(m)
    record["core"] = core
    record["type"] = _msg_type(m)
    record["content"] = _message_analysis_text(m)
    return record
|
||
|
||
|
||
def _extract_json_array(text: str) -> list:
    """Extract a JSON list from free-form LLM output.

    The span from the first opening bracket/brace (whichever comes first) to
    the last matching closer is parsed. A top-level object is unwrapped via
    its ``topics`` or ``items`` key, otherwise it yields an empty list.

    Raises:
        ValueError: when no JSON span is found, when parsing fails
            (``json.JSONDecodeError`` is a ``ValueError``), or when the
            result is not a list.
    """
    first_bracket = text.find("[")
    first_brace = text.find("{")
    array_comes_first = first_bracket >= 0 and (first_brace < 0 or first_bracket < first_brace)

    if array_comes_first:
        opener, closer, label = first_bracket, "]", "JSON 数组"
    elif first_brace >= 0:
        opener, closer, label = first_brace, "}", "JSON 对象"
    else:
        raise ValueError(f"LLM 返回内容无法解析为 JSON: {text[:200]}")

    stop = text.rfind(closer) + 1
    if stop <= 0:
        raise ValueError(f"LLM 返回内容无法解析为 {label}: {text[:200]}")

    parsed = json.loads(text[opener:stop])
    if isinstance(parsed, dict):
        parsed = parsed.get("topics") or parsed.get("items") or []
    if isinstance(parsed, list):
        return parsed
    raise ValueError("LLM JSON 不是数组")
|
||
|
||
|
||
def _parse_msg_time(raw: str) -> float | None:
|
||
"""把 chatlog 的 ISO 时间字符串解析为 unix 秒。失败返回 None。"""
|
||
if not raw:
|
||
return None
|
||
try:
|
||
s = raw.replace("/", "-")
|
||
if "T" not in s and " " in s:
|
||
s = s.replace(" ", "T", 1)
|
||
return datetime.fromisoformat(s).timestamp()
|
||
except Exception:
|
||
return None
|
||
|
||
|
||
def _looks_business_message(m: dict) -> bool:
    """Heuristically decide whether a message carries after-sales business signal.

    True when any of these holds:
      * it is a file/link or voice message;
      * it is an image/video with neighbour context or parsed media text;
      * its analysis text contains a business keyword (case-insensitive);
      * its analysis text contains a model/code-like token or a document
        file extension.
    """
    analysis = _message_analysis_text(m)
    if _is_file_or_link_message(m) or _is_voice_message(m):
        return True

    if _is_image_message(m) or _is_video_message(m):
        has_context = bool(str(m.get("_neighbor_context") or "").strip())
        has_media_text = bool(str(m.get("_ai_media_text") or "").strip())
        if has_context or has_media_text:
            return True

    lowered = analysis.lower()
    for keyword in BUSINESS_KEYWORDS:
        if keyword.lower() in lowered:
            return True

    code_like = re.search(r"\b(a20\d+|ap\d+|sp\d+|\d+p|[a-z]{2,}\d{2,})\b", lowered)
    if code_like:
        return True
    document_like = re.search(r"\.(dwg|xlsx?|pdf|docx?|zip|rar)\b", lowered)
    return bool(document_like)
|
||
|
||
|
||
def _has_text_signal(m: dict) -> bool:
    """Return True when the message offers any text usable as context.

    Either its base content is non-blank, or its ``contents`` mapping has a
    truthy ``desc``, ``title``, ``refer`` or ``recordInfo``.
    """
    if _base_msg_content(m).strip():
        return True
    inner = _contents(m)
    return any(inner.get(field) for field in ("desc", "title", "refer", "recordInfo"))
|
||
|
||
|
||
def _attach_neighbor_context(messages: list[dict], radius: int = 4) -> None:
    """Attach nearby text to image/video/voice/file messages, in place.

    For each media or file/link message, up to *radius* messages on each side
    are scanned for text. Each usable neighbour becomes one snippet (cut to
    180 characters) tagged as before/after and same-sender/nearby; the first
    six snippets are stored under ``_neighbor_context``, so the message can
    still be classified by context when its media cannot be parsed.
    """
    count = len(messages)
    for pos, target in enumerate(messages):
        if not (_media_kind(target) or _is_file_or_link_message(target)):
            continue

        target_sender = _msg_sender(target)
        window = range(max(0, pos - radius), min(count, pos + radius + 1))
        snippets: list[str] = []
        for near in window:
            if near == pos:
                continue
            neighbor = messages[near]
            if not _has_text_signal(neighbor):
                continue
            # Text from the same sender or close in time explains the media best;
            # a little context from other senders is kept as well.
            direction = "前文" if near < pos else "后文"
            same_sender = bool(target_sender) and _msg_sender(neighbor) == target_sender
            relation = "同发送人" if same_sender else "邻近"
            snippet = _base_msg_content(neighbor).strip()
            if not snippet:
                snippet = _message_analysis_text(neighbor).strip()
            if not snippet:
                continue
            if len(snippet) > 180:
                snippet = snippet[:180] + "..."
            snippets.append(f"{direction}/{relation} seq={_msg_seq(neighbor)} {snippet}")

        if snippets:
            target["_neighbor_context"] = " | ".join(snippets[:6])
|
||
|
||
|
||
async def _fetch_window_messages(talker: str, start_ts: int, end_ts: int) -> list[dict]:
    """Fetch every message of *talker* inside ``[start_ts, end_ts]``, page by page.

    The remote query is by calendar day, so each message is filtered again by
    its parsed timestamp; messages whose time cannot be parsed are kept.
    Messages are de-duplicated by non-zero seq and returned sorted by
    (timestamp, seq). Paging stops on an empty page, when the reported total
    is reached, or on a short page.
    """
    day_format = "%Y-%m-%d"
    first_day = datetime.fromtimestamp(start_ts).strftime(day_format)
    last_day = datetime.fromtimestamp(end_ts).strftime(day_format)
    time_range = f"{first_day},{last_day}"

    collected: list[dict] = []
    known_seqs: set[int] = set()
    cursor = 0
    while True:
        page = await chatlog_client.get_messages(
            talker,
            time=time_range,
            limit=FETCH_PAGE_SIZE,
            offset=cursor,
        )
        batch = page.get("items", []) or []
        reported_total = int(page.get("total") or 0)
        if not batch:
            break

        for msg in batch:
            seq = _msg_seq(msg)
            if seq and seq in known_seqs:
                continue
            when = _parse_msg_time(_msg_time(msg))
            outside_window = when is not None and not (start_ts <= when <= end_ts)
            if outside_window:
                continue
            collected.append(msg)
            if seq:
                known_seqs.add(seq)

        cursor += len(batch)
        reached_total = bool(reported_total) and cursor >= reported_total
        if reached_total or len(batch) < FETCH_PAGE_SIZE:
            break

    collected.sort(key=lambda m: (_parse_msg_time(_msg_time(m)) or 0, _msg_seq(m)))
    return collected
|
||
|
||
|
||
def _validate_media_settings(messages: list[dict], ai_settings: dict) -> str | None:
    """Check that the models needed for the media in *messages* are configured.

    Returns a user-facing error message when image/video messages exist
    without ``vision_model``, or voice messages exist without
    ``voice_model``; returns None when the settings are sufficient.
    """
    kinds_present = {_media_kind(m) for m in messages}
    needs_vision = bool(kinds_present & {"image", "video"})
    if needs_vision and not ai_settings.get("vision_model"):
        return "所选时间段包含图片/视频消息,请先在「设置」页面填入视觉模型"
    needs_voice = "voice" in kinds_present
    if needs_voice and not ai_settings.get("voice_model"):
        return "所选时间段包含语音消息,请先在「设置」页面填入语音模型"
    return None
|
||
|
||
|
||
async def _parse_message_media(messages: list[dict], update_progress, base_processed: int, total: int) -> int:
    """Parse image/voice/video payloads and annotate the messages in place.

    On success the parsed text is stored under ``_ai_media_text``; on failure
    the error text is stored under ``_ai_media_error``. Results are cached per
    ``(kind, key)`` so a payload shared by several messages is parsed once.

    A per-key lock guards the cache check: without it, two concurrent tasks
    for the same key could both miss the cache before either finished and
    parse the same payload twice.

    Args:
        messages: Messages to scan; only those with a media kind and key are parsed.
        update_progress: Awaitable callback ``(status, processed, total)``.
        base_processed: Progress units already completed before this stage.
        total: Total progress units for the whole task.

    Returns:
        The number of media messages that were handled.
    """
    semaphore = asyncio.Semaphore(MEDIA_PARSE_CONCURRENCY)
    media_messages = [m for m in messages if _media_kind(m) and _media_parse_key(m)]
    parsed_count = 0
    cache: dict[tuple[str, str], str] = {}
    key_locks: dict[tuple[str, str], asyncio.Lock] = {}

    async def parse_one(m: dict):
        nonlocal parsed_count
        kind = _media_kind(m)
        key = _media_parse_key(m)
        if not kind or not key:
            return
        cache_key = (kind, key)
        async with semaphore:
            try:
                lock = key_locks.setdefault(cache_key, asyncio.Lock())
                # Serialize work per payload so duplicates wait for the first
                # parse and then hit the cache. A failed parse is not cached,
                # so the next duplicate retries.
                async with lock:
                    if cache_key not in cache:
                        parsed = await parse_media(kind, key)
                        cache[cache_key] = parsed.get("text") or ""
                m["_ai_media_text"] = cache[cache_key]
            except HTTPException as e:
                m["_ai_media_error"] = str(e.detail)
            except Exception as e:
                log.warning(f"[classify] media parse failed seq={_msg_seq(m)}: {e}", exc_info=True)
                m["_ai_media_error"] = str(e)
            finally:
                parsed_count += 1
                await update_progress("running", base_processed + parsed_count, total)

    await asyncio.gather(*(parse_one(m) for m in media_messages))
    return len(media_messages)
|
||
|
||
|
||
async def _manual_assigned_seqs(db: aiosqlite.Connection, group_id: int) -> set[int]:
    """Return the message seqs already attached to manual topics of a group.

    A topic whose ``source`` is NULL is treated as manual. Rows are read by
    column name, so the connection must use a name-addressable row factory.
    """
    query = """
        SELECT DISTINCT tm.msg_seq
        FROM topic_messages tm
        JOIN topics t ON t.id = tm.topic_id
        WHERE t.group_id = ? AND COALESCE(t.source, 'manual') = 'manual'
        """
    async with db.execute(query, (group_id,)) as cursor:
        rows = await cursor.fetchall()
        return {int(record["msg_seq"]) for record in rows}
|
||
|
||
|
||
async def _delete_ai_topics(db: aiosqlite.Connection, group_id: int) -> None:
    """Delete every AI-generated topic of a group together with its dependants.

    Deletion order: full-text index rows, knowledge docs, topic-message
    links, then the topics themselves. Everything is committed once at the end.
    """
    statements = (
        """
        DELETE FROM knowledge_fts WHERE doc_id IN (
            SELECT d.id FROM knowledge_docs d
            JOIN topics t ON t.id = d.topic_id
            WHERE t.group_id=? AND COALESCE(t.source, 'manual')='ai'
        )
        """,
        """
        DELETE FROM knowledge_docs
        WHERE topic_id IN (
            SELECT id FROM topics WHERE group_id=? AND COALESCE(source, 'manual')='ai'
        )
        """,
        """
        DELETE FROM topic_messages
        WHERE topic_id IN (
            SELECT id FROM topics WHERE group_id=? AND COALESCE(source, 'manual')='ai'
        )
        """,
        "DELETE FROM topics WHERE group_id=? AND COALESCE(source, 'manual')='ai'",
    )
    for sql in statements:
        await db.execute(sql, (group_id,))
    await db.commit()
|
||
|
||
|
||
def _chunk_messages(messages: list[dict]) -> list[tuple[list[dict], list[dict]]]:
    """Split messages into ``(core, context)`` chunks for batch classification.

    ``core`` holds up to ``CLASSIFY_BATCH_SIZE`` consecutive messages;
    ``context`` is the same slice widened by ``CONTEXT_BEFORE`` /
    ``CONTEXT_AFTER`` messages, clamped to the list bounds.
    """
    size = len(messages)

    def _one_chunk(begin: int) -> tuple[list[dict], list[dict]]:
        finish = min(size, begin + CLASSIFY_BATCH_SIZE)
        wide_begin = max(0, begin - CONTEXT_BEFORE)
        wide_finish = min(size, finish + CONTEXT_AFTER)
        return messages[begin:finish], messages[wide_begin:wide_finish]

    return [_one_chunk(begin) for begin in range(0, size, CLASSIFY_BATCH_SIZE)]
|
||
|
||
|
||
def _batch_classify_prompt(core_messages: list[dict], context_messages: list[dict], guidance: str = "") -> str:
    """Build the user prompt for classifying one chunk of messages.

    Every context message is listed; those whose seq belongs to
    *core_messages* are flagged ``core=true``. A non-blank *guidance* text is
    appended after the output-format section, surrounded by blank lines.
    """
    core_seqs = {_msg_seq(m) for m in core_messages}
    payload = [_prompt_item(m, _msg_seq(m) in core_seqs) for m in context_messages]

    extra = guidance.strip() if guidance else ""
    guidance_block = f"\n\n{extra}\n\n" if extra else ""

    segments = [
        "你是广东大铁数控机械有限公司设备售后群话题分析助手。这个微信群里主要是大铁设备客户反馈问题,售后工程师对接处理。\n",
        "请只为 core=true 的消息按【设备问题/故障现象】聚类,而不是按客户公司、地域、日期、工程师或单次对话聚类。\n",
        "同类设备问题即使来自不同公司、不同地区,也要合并到同一话题;不同设备问题必须准确分开。\n",
        f"常见问题口径包括:{', '.join(DEVICE_ISSUE_EXAMPLES)}。\n\n",
        f"消息列表:\n{json.dumps(payload, ensure_ascii=False)}\n\n",
        "规则:\n",
        "1. 只输出 core=true 消息的归属,context 只用来理解上下文。\n",
        "2. 标题必须体现设备部件和问题现象,建议如「主轴报警/故障代码问题」「刀库换刀卡顿」「回零/限位异常」。不要把客户公司名作为标题核心。\n",
        "3. 不同公司出现相同设备问题时合并;同一公司出现不同问题时拆开。\n",
        "4. [引用消息] 是当前回复的强上下文,必须用于理解当前消息含义,但归类对象仍然只能是当前 core seq。\n",
        "5. 工程师回复、客户补充说明、图片、语音、视频、文件,必须跟随其对应的设备问题归类。\n",
        "6. 图片/视频无法识别时,不要编造内容,但必须根据附近上下文判断所属设备问题。\n",
        "7. 真正寒暄、入群通知、撤回、无设备售后含义短句才可归入「日常交流/待确认」;设备咨询和报修不得归入日常交流。\n",
        "8. 每条 core 消息最多归入一个话题,尽量覆盖所有 core seq。\n\n",
        "输出严格 JSON 数组,不要解释。格式:\n",
        '[{"topic_key":"稳定短键","title":"具体话题标题","seqs":[1,2],"reason":"分类依据"}]',
        guidance_block,
    ]
    return "".join(segments)
|
||
|
||
|
||
async def _classify_batches(
    messages: list[dict],
    update_progress,
    base_processed: int,
    total: int,
    guidance: str = "",
) -> list[dict]:
    """Classify messages chunk by chunk and collect the raw topic candidates.

    Each chunk is sent to the chat-completions API. Every dict in the
    returned JSON array is tagged with its 1-based chunk number under
    ``_batch``. A chunk that fails is logged and skipped; progress is
    reported after every chunk either way.
    """
    batches = _chunk_messages(messages)
    collected: list[dict] = []
    _client, _ai = await _get_client()

    for batch_no, (core, context) in enumerate(batches, start=1):
        user_prompt = _batch_classify_prompt(core, context, guidance)
        try:
            conversation = [
                {
                    "role": "system",
                    "content": "你是广东大铁数控机械有限公司设备售后问题分类器。你只输出 JSON 数组,不输出解释、Markdown 或额外文字。",
                },
                {"role": "user", "content": user_prompt},
            ]
            resp = await _client.chat.completions.create(
                model=_ai["ai_model"],
                messages=conversation,
                temperature=0.05,
            )
            raw_text = resp.choices[0].message.content.strip()
            for entry in _extract_json_array(raw_text):
                if not isinstance(entry, dict):
                    continue
                entry["_batch"] = batch_no
                collected.append(entry)
        except Exception as e:
            log.warning(f"[classify] batch {batch_no} classify failed: {e}", exc_info=True)
        await update_progress("running", base_processed + batch_no, total)
    return collected
|
||
|
||
|
||
def _sanitize_topic_items(items: list[dict], seq_to_msg: dict[int, dict], allowed_seqs: set[int]) -> list[dict]:
    """Validate raw LLM topic items and make message ownership unique.

    For each item, only seqs that are allowed, known, and not yet claimed
    (by an earlier item or earlier in the same item) are kept. Items without
    a title or without remaining seqs are dropped. Messages the model placed
    in the daily-chat topic but which look like business are moved to the
    "uncategorized, needs manual review" topic.

    Robustness: non-dict items are skipped, a ``seqs`` value given as a
    scalar or a string is still accepted, and a seq repeated inside one item
    is kept only once.

    Returns:
        A list of ``{"topic_key", "title", "seqs", "reason"}`` dicts.
    """
    assigned: set[int] = set()
    clean: list[dict] = []
    for item in items:
        if not isinstance(item, dict):
            continue
        title = str(item.get("title") or item.get("new_topic") or item.get("topic") or "").strip()
        raw_seqs = item.get("seqs") or item.get("message_seqs") or item.get("messages") or []
        if isinstance(raw_seqs, str):
            # e.g. "12" or "1, 2": pull out the digit runs instead of iterating characters.
            raw_seqs = re.findall(r"\d+", raw_seqs)
        elif isinstance(raw_seqs, (int, float)):
            raw_seqs = [raw_seqs]
        elif not isinstance(raw_seqs, (list, tuple, set)):
            raw_seqs = []
        topic_key = str(item.get("topic_key") or title).strip()
        reason = str(item.get("reason") or item.get("evidence") or "").strip()

        clean_seqs: list[int] = []
        picked: set[int] = set()
        for seq in raw_seqs:
            try:
                n = int(seq)
            except (TypeError, ValueError, OverflowError):
                continue
            if n in picked or n in assigned:
                continue
            if n in allowed_seqs and n in seq_to_msg:
                clean_seqs.append(n)
                picked.add(n)
        if not title or not clean_seqs:
            continue

        if title == DAILY_TOPIC:
            business = [n for n in clean_seqs if _looks_business_message(seq_to_msg[n])]
            business_set = set(business)
            non_business = [n for n in clean_seqs if n not in business_set]
            if business:
                clean.append({
                    "topic_key": UNCATEGORIZED_BUSINESS,
                    "title": UNCATEGORIZED_BUSINESS,
                    "seqs": business,
                    "reason": "模型归入日常交流但消息含业务信号,转入人工复核。",
                })
                assigned.update(business)
            clean_seqs = non_business
            if not clean_seqs:
                continue

        if title != DAILY_TOPIC and len(title) > 80:
            title = title[:80]
        clean.append({"topic_key": topic_key, "title": title, "seqs": clean_seqs, "reason": reason})
        assigned.update(clean_seqs)
    return clean
|
||
|
||
|
||
def _merge_prompt(candidates: list[dict]) -> str:
    """Build the user prompt asking the LLM to merge candidate topics across batches.

    Each candidate is listed with its list index as ``id`` plus its
    ``topic_key``, ``title``, ``seqs`` and ``reason`` (missing fields default
    to an empty string or empty list).
    """
    compact = []
    for position, candidate in enumerate(candidates):
        compact.append({
            "id": position,
            "topic_key": candidate.get("topic_key", ""),
            "title": candidate.get("title", ""),
            "seqs": candidate.get("seqs", []),
            "reason": candidate.get("reason", ""),
        })

    segments = [
        "你是广东大铁数控机械有限公司设备售后话题合并审核员。请把分批识别出的候选话题做跨批次合并。\n",
        "合并粒度是【设备问题/故障现象】。不要按客户公司、地域、日期、工程师或单次对话合并。\n",
        "不同公司出现相同设备问题时必须合并,例如两个客户都反馈主轴报警,应归为同一类主轴报警问题。\n",
        "不同设备部件、不同故障现象、不同处理路径必须分开,例如主轴报警、刀库卡刀、回零限位异常、加工精度异常不能互相合并。\n\n",
        f"候选话题:\n{json.dumps(compact, ensure_ascii=False)}\n\n",
        "输出严格 JSON 数组,不要解释。格式:\n",
        '[{"title":"合并后的具体话题标题","candidate_ids":[0,1],"reason":"合并依据"}]',
    ]
    return "".join(segments)
|
||
|
||
|
||
async def _merge_candidates(candidates: list[dict], update_progress, processed: int, total: int) -> list[dict]:
    """Merge per-batch topic candidates across batches with one LLM call.

    The model answers with groups of candidate indices. Each valid group
    becomes one merged topic whose seqs are the concatenation of its
    candidates' seqs. Candidates the model did not mention are appended
    unchanged.

    If the LLM call or its JSON parsing fails, the failure is logged and all
    candidates are returned unmerged, mirroring how the other LLM stages in
    this module degrade instead of aborting the whole run.

    Args:
        candidates: Sanitized candidates, each carrying ``title`` and ``seqs``.
        update_progress: Awaitable callback ``(status, processed, total)``.
        processed: Progress units completed before this stage.
        total: Total progress units for the task.

    Returns:
        The merged topic list (an empty list for empty input).
    """
    if not candidates:
        return []
    if len(candidates) == 1:
        await update_progress("running", processed + 1, total)
        return candidates

    try:
        _client, _ai = await _get_client()
        resp = await _client.chat.completions.create(
            model=_ai["ai_model"],
            messages=[
                {
                    "role": "system",
                    "content": "你是广东大铁数控机械有限公司设备售后话题合并器。你只输出 JSON 数组,不输出解释、Markdown 或额外文字。",
                },
                {"role": "user", "content": _merge_prompt(candidates)},
            ],
            temperature=0.05,
        )
        merged_raw = _extract_json_array(resp.choices[0].message.content.strip())
    except Exception as e:
        log.warning(f"[classify] merge candidates failed: {e}", exc_info=True)
        merged_raw = []

    used: set[int] = set()
    merged: list[dict] = []
    for item in merged_raw:
        if not isinstance(item, dict):
            continue
        ids = item.get("candidate_ids") or item.get("ids") or []
        if not isinstance(ids, (list, tuple, set)):
            ids = [ids]
        valid_ids: list[int] = []
        for raw_id in ids:
            try:
                idx = int(raw_id)
            except (TypeError, ValueError, OverflowError):
                continue
            # Skip ids already used by an earlier group or repeated within
            # this one; otherwise that candidate's seqs would be duplicated.
            if 0 <= idx < len(candidates) and idx not in used and idx not in valid_ids:
                valid_ids.append(idx)
        if not valid_ids:
            continue
        title = str(item.get("title") or candidates[valid_ids[0]].get("title") or "").strip()
        if not title:
            continue
        seqs: list[int] = []
        for idx in valid_ids:
            used.add(idx)
            seqs.extend(candidates[idx].get("seqs", []))
        merged.append({
            "title": title[:80] if title != DAILY_TOPIC else title,
            "seqs": seqs,
            "reason": str(item.get("reason") or "").strip(),
        })

    for idx, candidate in enumerate(candidates):
        if idx not in used:
            merged.append(candidate)

    await update_progress("running", processed + 1, total)
    return merged
|
||
|
||
|
||
def _supplement_prompt(unassigned: list[dict], topics: list[dict]) -> str:
    """Build the user prompt for assigning leftover messages to topics.

    Existing topics are listed with their index, title and at most their
    first eight seqs; every unassigned message is listed as a core item.
    """
    payload = [_prompt_item(message, True) for message in unassigned]
    topic_payload = []
    for position, topic in enumerate(topics):
        topic_payload.append({
            "index": position,
            "title": topic["title"],
            "seqs": topic.get("seqs", [])[:8],
        })

    segments = [
        "你是广东大铁数控机械有限公司设备售后消息补充分配审核员。请把未归属消息分配到已有话题,或新建设备问题话题。\n",
        "分配粒度是设备问题/故障现象,不是客户公司、地域、日期、工程师或单次对话。\n",
        "不同公司出现相同设备问题时分配到同一话题;不同设备部件、故障现象或处理路径要新建不同话题。\n",
        "[引用消息] 是当前回复的强上下文,必须用于理解当前消息含义,但归类对象仍然只能是当前消息 seq。\n",
        "图片/视频内容无法识别时,必须优先使用消息里的[附近上下文]归入最相关已有设备问题;只有完全孤立且无上下文线索才新建「未归类设备问题(需人工复核)」。\n",
        "设备咨询、报修、售后处理消息不得归入「日常交流/待确认」。只有寒暄、通知、撤回等无售后含义消息才可归入日常交流。\n\n",
        f"已有话题:\n{json.dumps(topic_payload, ensure_ascii=False)}\n\n",
        f"未归属消息:\n{json.dumps(payload, ensure_ascii=False)}\n\n",
        "输出严格 JSON 数组,不要解释。格式:\n",
        '[{"seq":1,"topic_index":0,"new_topic":"","reason":"依据"}]\n',
        '若需要新建话题,topic_index 用 null,并填写 new_topic。',
    ]
    return "".join(segments)
|
||
|
||
|
||
async def _supplement_assignments(
    topics: list[dict],
    messages: list[dict],
    seq_to_msg: dict[int, dict],
    allowed_seqs: set[int],
    update_progress,
    processed: int,
    total: int,
) -> list[dict]:
    """Assign every allowed message that no topic claimed yet.

    The unassigned messages are sent to the LLM, which either picks an
    existing topic by index or names a new one. Whatever is still unassigned
    afterwards (including after an LLM failure, which is logged and
    swallowed) falls back to the "uncategorized business" or "daily chat"
    topic depending on ``_looks_business_message``.

    The unassigned messages are listed in the order of *messages*, followed
    by any remaining allowed seqs in ascending order, so the prompt is
    deterministic instead of following set-iteration order. Allowed seqs with
    no entry in *seq_to_msg* are skipped.

    Note:
        *topics* is extended in place and also returned.
    """
    assigned = {seq for t in topics for seq in t.get("seqs", [])}

    missing: list[dict] = []
    queued: set[int] = set()
    for m in messages:
        seq = _msg_seq(m)
        if seq in queued or seq in assigned:
            continue
        if seq in allowed_seqs and seq in seq_to_msg:
            missing.append(seq_to_msg[seq])
            queued.add(seq)
    # Allowed seqs that do not appear in `messages` are still covered.
    for seq in sorted(allowed_seqs):
        if seq in queued or seq in assigned or seq not in seq_to_msg:
            continue
        missing.append(seq_to_msg[seq])
        queued.add(seq)

    if not missing:
        await update_progress("running", processed + 1, total)
        return topics

    try:
        _client, _ai = await _get_client()
        resp = await _client.chat.completions.create(
            model=_ai["ai_model"],
            messages=[
                {
                    "role": "system",
                    "content": "你是广东大铁数控机械有限公司设备售后消息补充分配器。你只输出 JSON 数组,不输出解释、Markdown 或额外文字。",
                },
                {"role": "user", "content": _supplement_prompt(missing, topics)},
            ],
            temperature=0.05,
        )
        results = _extract_json_array(resp.choices[0].message.content.strip())
    except Exception as e:
        log.warning(f"[classify] supplement assignment failed: {e}", exc_info=True)
        results = []

    topic_by_title = {t["title"]: t for t in topics}

    def _claim(title: str, seq: int, reason: str) -> None:
        """Append *seq* to the topic named *title*, creating the topic if needed."""
        if title not in topic_by_title:
            topic_by_title[title] = {"title": title, "seqs": [], "reason": reason}
            topics.append(topic_by_title[title])
        topic_by_title[title].setdefault("seqs", []).append(seq)
        assigned.add(seq)

    for item in results:
        if not isinstance(item, dict):
            continue
        try:
            seq = int(item.get("seq"))
        except Exception:
            continue
        if seq not in allowed_seqs or seq not in seq_to_msg or seq in assigned:
            continue
        title = ""
        idx = item.get("topic_index")
        try:
            if idx is not None and 0 <= int(idx) < len(topics):
                title = topics[int(idx)]["title"]
        except Exception:
            title = ""
        if not title:
            title = str(item.get("new_topic") or "").strip()
        if not title or title == DAILY_TOPIC:
            title = UNCATEGORIZED_BUSINESS if _looks_business_message(seq_to_msg[seq]) else DAILY_TOPIC
        if title != DAILY_TOPIC and len(title) > 80:
            title = title[:80]
        _claim(title, seq, str(item.get("reason") or "").strip())

    for m in missing:
        seq = _msg_seq(m)
        if seq in assigned:
            continue
        title = UNCATEGORIZED_BUSINESS if _looks_business_message(m) else DAILY_TOPIC
        _claim(title, seq, "补充分配后仍无法归入具体话题")

    await update_progress("running", processed + 1, total)
    return topics
|
||
|
||
|
||
def _finalize_topics(topics: list[dict], seq_to_msg: dict[int, dict], allowed_seqs: set[int]) -> list[dict]:
    """Produce the final topic list with unique, complete message ownership.

    * Topics with the same (truncated) title are merged.
    * Each seq is kept only the first time it appears and only when it is
      allowed and known.
    * A business-looking message found in the daily-chat topic is moved,
      individually, to the "uncategorized, needs manual review" topic. The
      rest of the daily-chat topic stays where it is, which matches the
      per-message split done by ``_sanitize_topic_items``.
    * Any allowed seq still unassigned gets a fallback topic, in ascending
      seq order.

    Returns:
        Topics as ``{"title", "seqs", "reason"}`` dicts, sorted by their
        smallest seq.
    """
    assigned: set[int] = set()
    grouped: dict[str, dict] = {}

    def _bucket(title: str, reason: str) -> dict:
        """Return the grouped topic for *title*, creating it on first use."""
        if title not in grouped:
            grouped[title] = {"title": title, "seqs": [], "reason": reason}
        return grouped[title]

    for topic in topics:
        title = str(topic.get("title") or "").strip()
        if not title:
            continue
        if title != DAILY_TOPIC:
            title = title[:80]
        reason = str(topic.get("reason") or "")
        for seq in topic.get("seqs", []):
            try:
                n = int(seq)
            except Exception:
                continue
            if n not in allowed_seqs or n not in seq_to_msg or n in assigned:
                continue
            if title == DAILY_TOPIC and _looks_business_message(seq_to_msg[n]):
                # Re-route only this message; do not retitle the whole topic.
                _bucket(
                    UNCATEGORIZED_BUSINESS,
                    "模型归入日常交流但消息含业务信号,转入人工复核。",
                )["seqs"].append(n)
            else:
                _bucket(title, reason)["seqs"].append(n)
            assigned.add(n)

    for seq in sorted(allowed_seqs):
        if seq in assigned or seq not in seq_to_msg:
            continue
        title = UNCATEGORIZED_BUSINESS if _looks_business_message(seq_to_msg[seq]) else DAILY_TOPIC
        _bucket(title, "最终兜底归属")["seqs"].append(seq)
        assigned.add(seq)

    result = list(grouped.values())
    result.sort(key=lambda t: min(t["seqs"]) if t["seqs"] else 0)
    return result
|
||
|
||
|
||
def _topic_text(topic: dict, seq_to_msg: dict[int, dict]) -> str:
    """Concatenate a topic's title, reason and member messages' analysis text.

    Seqs that cannot be converted to int or are unknown to *seq_to_msg* are
    skipped. Lines are joined with a newline.
    """
    def _member_messages():
        for raw_seq in topic.get("seqs", []):
            try:
                member = seq_to_msg[int(raw_seq)]
            except Exception:
                continue
            yield member

    lines = [str(topic.get("title") or ""), str(topic.get("reason") or "")]
    for member in _member_messages():
        lines.append(_message_analysis_text(member))
    return "\n".join(lines)
|
||
|
||
|
||
def _strip_customer_prefix(title: str) -> str:
    """Remove a leading customer/company/site prefix from a topic title.

    Two prefix shapes are removed in turn: "<name><公司|工厂|厂|客户|现场|车间><separator>"
    and "<客户|厂家|现场|公司><optional name><separator>". The result is
    stripped of surrounding whitespace after each removal.
    """
    prefix_patterns = (
        re.compile(r"^[\u4e00-\u9fa5A-Za-z0-9()()##_-]{2,30}(?:公司|工厂|厂|客户|现场|车间)[-—::/]+"),
        re.compile(r"^(?:客户|厂家|现场|公司)[\u4e00-\u9fa5A-Za-z0-9()()##_-]{0,20}[-—::/]+"),
    )
    cleaned = title.strip()
    for prefix_re in prefix_patterns:
        cleaned = prefix_re.sub("", cleaned).strip()
    return cleaned
|
||
|
||
|
||
def _infer_device_issue_title(topic: dict, seq_to_msg: dict[int, dict]) -> str:
    """Derive a device-issue title for a topic.

    Empty and daily-chat titles are returned unchanged. Otherwise any
    customer prefix is stripped. A specific title is kept (cut to 80
    characters); a broad, catch-all title is replaced by the first device
    category whose keywords occur in the topic's text (title, reason and
    message texts). When no category matches, the stripped title is returned.

    If stripping the customer prefix leaves nothing (the title was only a
    prefix), the unstripped title is used instead: an empty title would make
    ``_coalesce_device_issue_topics`` drop the topic and its messages.
    """
    raw_title = str(topic.get("title") or "").strip()
    if not raw_title or raw_title == DAILY_TOPIC:
        return raw_title

    original = _strip_customer_prefix(raw_title) or raw_title

    broad_titles = (
        UNCATEGORIZED_BUSINESS,
        "现场故障/售后处理",
        "业务事项跟进",
        "沟通响应跟进",
    )
    is_broad = original in broad_titles or any(original.endswith(f"-{title}") for title in broad_titles)
    if not is_broad:
        return original[:80] if original != DAILY_TOPIC else original

    text = _topic_text(topic, seq_to_msg)
    low = text.lower()

    def _mentions(keywords: tuple[str, ...]) -> bool:
        return any(k in text for k in keywords)

    # Spindle issues come first and split on whether an alarm/code is mentioned.
    if _mentions(("主轴", "转速", "拉刀", "松刀")):
        if _mentions(("报警", "警报", "报错", "故障码", "代码")):
            return "主轴报警/故障代码问题"
        return "主轴异常问题"

    # Ordered rules: the first category with a matching keyword wins.
    component_rules = (
        (("刀库", "换刀", "刀臂", "刀盘", "刀号", "卡刀", "掉刀"), "刀库/换刀异常问题"),
        (("回零", "原点", "限位", "行程"), "回零/限位异常问题"),
        (("伺服", "驱动", "驱动器", "电机", "编码器", "变频器"), "伺服/驱动异常问题"),
        (("精度", "尺寸", "跑偏", "震刀", "振刀", "圆度", "平面度"), "加工精度异常问题"),
        (("气压", "液压", "油压", "漏油"), "气压/液压异常问题"),
        (("润滑", "冷却", "水泵", "切削液", "漏水"), "润滑/冷却异常问题"),
        (("电柜", "线路", "电源", "跳闸", "断电", "IO", "I/O", "PLC", "急停"), "电气线路/IO异常问题"),
        (("系统", "参数", "程序", "G代码", "M代码", "面板", "手轮"), "系统参数/程序操作问题"),
    )
    for keywords, inferred in component_rules:
        if _mentions(keywords):
            return inferred

    if _mentions(("报警", "警报", "报错", "故障码")) or re.search(r"\b(error|alarm|err|e\d+)\b", low):
        return "设备报警/故障代码问题"

    if _mentions(("异响", "卡住", "排屑")):
        return "设备机械异常问题"

    return original
|
||
|
||
|
||
def _coalesce_device_issue_topics(topics: list[dict], seq_to_msg: dict[int, dict]) -> list[dict]:
    """Merge topics whose inferred device-issue titles are identical.

    Topics with an empty inferred title are dropped. Each seq is kept only
    the first time it appears across all merged topics; topics left without
    seqs are omitted. The result is sorted by each topic's smallest seq and
    every topic carries the same generic merge reason.
    """
    buckets: dict[str, list] = {}
    for topic in topics:
        inferred = _infer_device_issue_title(topic, seq_to_msg)
        if not inferred:
            continue
        buckets.setdefault(inferred, []).extend(topic.get("seqs", []))

    merged: list[dict] = []
    claimed: set[int] = set()
    for inferred, raw_seqs in buckets.items():
        unique: list[int] = []
        for raw in raw_seqs:
            try:
                number = int(raw)
            except Exception:
                continue
            if number in claimed:
                continue
            claimed.add(number)
            unique.append(number)
        if not unique:
            continue
        shown = inferred if inferred == DAILY_TOPIC else inferred[:80]
        merged.append({"title": shown, "seqs": unique, "reason": "按设备问题/故障现象自动合并"})

    merged.sort(key=lambda t: min(t["seqs"]) if t["seqs"] else 0)
    return merged
|
||
|
||
|
||
async def _save_topics(db: aiosqlite.Connection, group_id: int, talker: str, topics: list[dict], seq_to_msg: dict[int, dict]) -> None:
    """Persist AI topics and their message links.

    A topic row and its ``topic_messages`` rows are written in one
    transaction and committed together, so a failure while inserting the
    messages cannot leave an empty AI topic behind. The new topic id is read
    from the insert cursor's ``lastrowid``, which avoids a second query and
    any dependence on the connection's row factory.

    Topics without seqs are skipped, as are seqs that have no entry in
    *seq_to_msg* (there is no message to snapshot for them).
    """
    for topic in topics:
        title = topic["title"]
        snapshots = [
            (seq, _message_snapshot(seq_to_msg[seq]))
            for seq in topic.get("seqs", [])
            if seq in seq_to_msg
        ]
        if not snapshots:
            continue

        cursor = await db.execute(
            "INSERT INTO topics (group_id, title, source) VALUES (?, ?, 'ai')",
            (group_id, title),
        )
        topic_id = cursor.lastrowid
        await cursor.close()

        await db.executemany(
            """
            INSERT OR IGNORE INTO topic_messages
            (topic_id, msg_seq, talker, added_by, message_json)
            VALUES (?, ?, ?, 'ai', ?)
            """,
            [(topic_id, seq, talker, snapshot) for seq, snapshot in snapshots],
        )
        await db.commit()
|
||
|
||
|
||
async def run_classify_window(
    group_id: int,
    task_id: int,
    group: dict,
    start_ts: int,
    end_ts: int,
):
    """
    Run AI topic classification over every message in the given time window.

    - Serial execution: the module-level lock allows only one classification
      task at a time; ``_classifying_group`` holds the active group id while
      the lock is held.
    - On a rerun the previous AI topics are deleted; manual topics are kept and
      messages already assigned manually are excluded from classification.

    Progress and failures are reported by updating the ``ai_tasks`` row
    ``task_id``; the function itself returns ``None`` and does not raise for
    errors caught by its top-level handler.

    Args:
        group_id: Group being classified.
        task_id: ``ai_tasks`` row used for status/progress reporting.
        group: Group record; ``talker`` is read, ``analysis_prompt`` is optional.
        start_ts: Window start passed to the message fetcher.
        end_ts: Window end passed to the message fetcher.
    """
    global _classifying_group

    path = get_active_db_path()

    async def _update_task(status: str, processed: int = 0, total: int = 0, error: str = ""):
        """Write status/progress (and the error text, when given) to the task row.

        Best effort: a failure to update is logged and swallowed so progress
        reporting can never abort the classification itself.
        """
        progress = json.dumps({"processed": processed, "total": total})
        try:
            async with aiosqlite.connect(path) as _db:
                _db.row_factory = aiosqlite.Row
                if error:
                    await _db.execute(
                        "UPDATE ai_tasks SET status=?, progress=?, error=?, updated_at=CURRENT_TIMESTAMP WHERE id=?",
                        (status, progress, error, task_id),
                    )
                else:
                    # No error text supplied: leave the error column untouched.
                    await _db.execute(
                        "UPDATE ai_tasks SET status=?, progress=?, updated_at=CURRENT_TIMESTAMP WHERE id=?",
                        (status, progress, task_id),
                    )
                await _db.commit()
        except Exception as e:
            log.warning(f"[classify] 更新 task {task_id} 失败: {e}")

    _s = await get_ai_settings()
    if not _s.get("ai_api_key") or not _s.get("ai_model"):
        await _update_task("error", 0, 0, "AI 未配置,请在「设置」页面填入 API Key 和话题分析模型")
        log.warning(f"[classify] group={group_id} aborted: AI not configured")
        return

    async with _classify_lock:
        _classifying_group = group_id
        try:
            # Fetch failures get their own user-facing message.
            try:
                messages = await _fetch_window_messages(group["talker"], start_ts, end_ts)
            except Exception as e:
                log.error(f"[classify] group={group_id} fetch error: {e}", exc_info=True)
                await _update_task("error", 0, 0, f"拉取聊天记录失败:{e}")
                return

            total_messages = len(messages)
            log.info(f"[classify] group={group_id} got {total_messages} msgs in selected window")
            if total_messages == 0:
                await _update_task("done", 0, 0)
                return
            _fill_message_sender_names(messages)

            media_error = _validate_media_settings(messages, _s)
            if media_error:
                await _update_task("error", 0, total_messages, media_error)
                return

            _attach_neighbor_context(messages)

            # Progress units: fetched messages + media items to parse
            # + classification chunks + 2 trailing steps (merge and supplement).
            media_count = sum(1 for m in messages if _media_kind(m) and _media_parse_key(m))
            chunks_count = len(_chunk_messages(messages))
            total_units = total_messages + media_count + chunks_count + 2
            await _update_task("running", total_messages, total_units)

            async with aiosqlite.connect(path) as db:
                db.row_factory = aiosqlite.Row
                manual_seqs = await _manual_assigned_seqs(db, group_id)
                # NOTE(review): old AI topics are removed before the new run has
                # succeeded; presumably a later failure leaves the group without
                # AI topics until the next rerun — confirm this is intended.
                await _delete_ai_topics(db, group_id)

                parsed_units = await _parse_message_media(messages, _update_task, total_messages, total_units)
                classifiable = [m for m in messages if _msg_seq(m) and _msg_seq(m) not in manual_seqs]
                seq_to_msg = {_msg_seq(m): m for m in classifiable if _msg_seq(m)}
                allowed_seqs = set(seq_to_msg)
                if not classifiable:
                    await _update_task("done", total_units, total_units)
                    return

                # The group-level prompt takes precedence over the global one.
                configured_prompt = (group.get("analysis_prompt") or _s.get("topic_analysis_prompt") or "").strip()
                sample_query = "\n".join(_message_analysis_text(m) for m in classifiable[:80])
                learning_context = await build_report_learning_context(
                    db,
                    group_id=group_id,
                    query=sample_query,
                    purpose="topic",
                )
                guidance_parts = []
                if configured_prompt:
                    guidance_parts.append(f"客户自定义话题分析提示词:\n{configured_prompt}")
                if learning_context:
                    guidance_parts.append(
                        "报告库学习参考:以下是人工修订过的历史报告,请学习其话题命名、分类颗粒度和关注点;不要复制历史事实。\n"
                        f"{learning_context}"
                    )
                analysis_guidance = "\n\n".join(guidance_parts)

                base = total_messages + parsed_units
                candidates_raw = await _classify_batches(
                    classifiable,
                    _update_task,
                    base,
                    total_units,
                    analysis_guidance,
                )
                candidates = _sanitize_topic_items(candidates_raw, seq_to_msg, allowed_seqs)
                processed_after_batches = base + chunks_count
                merged = await _merge_candidates(candidates, _update_task, processed_after_batches, total_units)
                merged = _finalize_topics(merged, seq_to_msg, allowed_seqs)
                supplemented = await _supplement_assignments(
                    merged,
                    classifiable,
                    seq_to_msg,
                    allowed_seqs,
                    _update_task,
                    processed_after_batches + 1,
                    total_units,
                )
                final_topics = _finalize_topics(supplemented, seq_to_msg, allowed_seqs)
                final_topics = _coalesce_device_issue_topics(final_topics, seq_to_msg)
                final_topics = _finalize_topics(final_topics, seq_to_msg, allowed_seqs)
                await _save_topics(db, group_id, group["talker"], final_topics, seq_to_msg)

            await _update_task("done", total_units, total_units)
            log.info(f"[classify] group={group_id} done, messages={total_messages}, topics={len(final_topics)}")

        except Exception as e:
            log.error(f"[classify] group={group_id} error: {e}", exc_info=True)
            # Some exceptions (timeouts in particular) stringify to "". An empty
            # error would make _update_task skip the error column, leaving an
            # "error" task with no explanation, so fall back to the class name.
            await _update_task("error", 0, 0, str(e) or type(e).__name__)
        finally:
            _classifying_group = None
|