Update project and configurations

Zou-Seay
2026-06-11 16:28:00 +08:00
parent 12d3922091
commit a29a91867d
237 changed files with 164880 additions and 90 deletions

File diff suppressed because it is too large


@@ -0,0 +1,467 @@
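# Implementation Flow Review: Aligning with the In-Vehicle Assistant Approach
## 1. Goal
This document re-examines the project's current implementation direction and pins down the core flow, branching decisions, and typical scenario walkthroughs for "aligning with an XPeng-style in-vehicle assistant approach".

The goal is not a general-purpose chatbot but a low-latency, execution-oriented Agent for in-vehicle, customer-service, and front-desk scenarios:
- Simple requests are closed out quickly on the local side
- Complex requests get cloud-enhanced processing
- Short multi-turn utterances can recover their context
- Multi-command requests can be split and executed as a workflow
- Requests beyond the capability boundary are explicitly rejected or clarified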
---
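## 2. Core Principles for Aligning with the In-Vehicle Approach
Drawing on the public XPeng patents referenced earlier, the following engineering principles can be abstracted:
- Local first, cloud enhanced: not every request goes straight to a remote large model
- The local side is not a single model but several branches running concurrently
- Fast feedback and final feedback are separated; the first feedback must be short and stable
- Short multi-turn utterances rely first on context rewriting and cache recovery rather than a full re-plan on every turn
- High-risk actions must be confirmed
- Requests beyond the capability boundary must be refused, not forced into a class and executed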
---
## 3. Overall Flowchart
```mermaid
flowchart TD
A[User voice input] --> B[ASR to text]
B --> C[Text normalization]
C --> D[Local concurrent branches]
D --> D1[keyword / rule / trie]
D --> D2[local bert classifier]
D --> D3[context rewrite / cache]
D --> D4[retrieval matcher]
D1 --> E[Local fusion grader]
D2 --> E
D3 --> E
D4 --> E
E -->|high confidence| F[Execute directly or generate workflow directly]
E -->|medium confidence| G[Wait 100~300ms for supplementary branch / cloud results]
E -->|low confidence| H[Cloud planner / LLM / RAG]
G --> H
H --> I[Generate intent / workflow / clarify / reject]
F --> J[Plugin execution]
I --> J
J --> K[First feedback ack / progress]
J --> L[Final feedback result / clarify / reject]
```
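The fan-out from the normalized text to the local branches and back into the fusion grader could be sketched with `asyncio` as follows. This is only an illustration under stated assumptions: the two branch functions, their scores, and the "highest score wins" fusion rule are hypothetical stand-ins, not the project's actual components.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class BranchResult:
    branch: str      # which local branch produced this candidate
    intent_id: str   # candidate intent id, empty string when the branch has no hit
    score: float     # branch confidence in [0, 1]


async def keyword_branch(text: str) -> BranchResult:
    # Hypothetical rule: both keywords present counts as a confident hit.
    hit = "车窗" in text and "打开" in text
    return BranchResult("keyword", "cabin_window_open" if hit else "", 0.95 if hit else 0.0)


async def classifier_branch(text: str) -> BranchResult:
    # Stand-in for a local BERT classifier call; the label and score are made up.
    await asyncio.sleep(0.01)
    return BranchResult("classifier", "cabin_window_open", 0.80)


async def run_local_branches(text: str) -> list[BranchResult]:
    # All branches start together, so latency is bounded by the slowest branch.
    return list(await asyncio.gather(keyword_branch(text), classifier_branch(text)))


def fuse(results: list[BranchResult]) -> BranchResult:
    # Simplest possible fusion: keep the highest-scoring non-empty candidate.
    candidates = [r for r in results if r.intent_id]
    if not candidates:
        return BranchResult("fusion", "", 0.0)
    return max(candidates, key=lambda r: r.score)


best = fuse(asyncio.run(run_local_branches("打开车窗")))
print(best.branch, best.intent_id, best.score)
```

A real fusion grader would likely weight branches differently and emit a confidence tier rather than a single winner; the sketch only shows the concurrency shape.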
---
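## 4. Division of Labor Between Local and Cloud
### 4.1 Local Fast Path
The local side is responsible for:
- High-frequency, fixed control commands
- Fast intent recognition within the known business set
- Context recovery for short utterances, elliptical utterances, and consecutive adjustments
- Direct execution of simple tasks
- A first response within `<1s`
Local components:
- `keyword / rule / trie`
- `local bert classifier`
- `context rewrite / cache`
- `retrieval matcher`
- `fusion grader`
### 4.2 Cloud Slow Path
The cloud is responsible for:
- Splitting multi-command requests
- Understanding conditional requests
- Disambiguation
- Complex question answering
- Planner-level workflow generation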
---
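## 5. Local Tiered Decision Diagram
```mermaid
flowchart TD
A[Local multi-branch results] --> B{Fusion grading}
B -->|high| C[Execute directly]
B -->|medium| D[Wait 100~300ms]
B -->|low| E[Do not execute directly]
D --> F{Did the cloud return in time}
F -->|yes| G[Use cloud result or override local]
F -->|no| H{Does local reach the minimum execution threshold}
H -->|yes| I[Execute local result]
H -->|no| J[Clarify or reject]
E --> K[Cloud understanding / clarify / reject]
```
Key corrections:
- It is not "execute as soon as the local BERT returns a result"
- It is not "always execute the local result if the cloud has not returned"
- The local result must first be checked against the minimum executable threshold
- High-risk actions must not be executed directly even at high confidence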
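A minimal sketch of the decision logic described by the diagram and the corrections, assuming illustrative thresholds. `HIGH_CONF`, `MEDIUM_CONF`, `MIN_EXEC`, the `Candidate` type, and the `decide` function are all hypothetical names introduced here; the source does not specify concrete values or an API.

```python
from dataclasses import dataclass
from typing import Optional

# Placeholder thresholds chosen only for illustration.
HIGH_CONF = 0.90
MEDIUM_CONF = 0.60
MIN_EXEC = 0.75


@dataclass
class Candidate:
    intent_id: str
    score: float
    high_risk: bool = False


def decide(local: Candidate, cloud: Optional[Candidate]) -> str:
    """Return one of: execute, confirm, clarify, route_to_cloud.

    `cloud` is the cloud result if it arrived within the wait window, else None.
    """
    if local.score >= HIGH_CONF:
        chosen = local
    elif local.score >= MEDIUM_CONF:
        if cloud is not None:
            chosen = cloud            # cloud returned in time: it overrides local
        elif local.score >= MIN_EXEC:
            chosen = local            # cloud timed out, local is still good enough
        else:
            return "clarify"          # cloud timed out and local is too weak
    else:
        return "route_to_cloud"       # low confidence: never execute locally
    # High-risk actions need confirmation even at high confidence.
    return "confirm" if chosen.high_risk else "execute"


print(decide(Candidate("cabin_window_open", 0.97), None))
print(decide(Candidate("cs_cancel_order", 0.97, high_risk=True), None))
print(decide(Candidate("cabin_set_ac", 0.70), None))
```

The point of the separate `MIN_EXEC` check is that a cloud timeout alone never promotes a medium-confidence local result to execution.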
---
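## 6. User Feedback State Diagram
```mermaid
stateDiagram-v2
[*] --> Received
Received --> Ack: Request received
Ack --> ExecutingFast: Local fast execution
Ack --> WaitingCloud: Waiting for cloud / complex planning
Ack --> Clarify: Key slot missing
Ack --> Confirm: High-risk action
Ack --> Reject: Beyond capability boundary
ExecutingFast --> Result
WaitingCloud --> Result
Clarify --> Result
Confirm --> Result
Reject --> [*]
Result --> [*]
```
Feedback rules:
- `ack`: received, handling it right away
- `progress`: working on it for you
- `result`: execution or query completed
- `clarify`: information is insufficient, ask for one key field
- `confirm`: confirmation of a high-risk action
- `reject`: refusal at the capability boundary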
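The state diagram can be encoded as a transition table and used to validate a feedback trace. This is a sketch only: the state names are copied from the diagram, while `TRANSITIONS` and `is_valid_path` are hypothetical helpers, not part of the project.

```python
# Allowed transitions, copied from the state diagram.
TRANSITIONS = {
    "Received": {"Ack"},
    "Ack": {"ExecutingFast", "WaitingCloud", "Clarify", "Confirm", "Reject"},
    "ExecutingFast": {"Result"},
    "WaitingCloud": {"Result"},
    "Clarify": {"Result"},
    "Confirm": {"Result"},
    "Reject": set(),   # terminal
    "Result": set(),   # terminal
}


def is_valid_path(states: list[str]) -> bool:
    """True if the path starts at Received, follows allowed edges, and ends in a terminal state."""
    if not states or states[0] != "Received":
        return False
    for current, nxt in zip(states, states[1:]):
        if nxt not in TRANSITIONS.get(current, set()):
            return False
    # A terminal state has no outgoing transitions.
    return not TRANSITIONS[states[-1]]


print(is_valid_path(["Received", "Ack", "ExecutingFast", "Result"]))
print(is_valid_path(["Received", "Result"]))
```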
---
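## 7. Feedback Template Strategy
### 7.1 Fast Execution
Applies to:
- Opening a window
- Turning the air conditioning down
- Playing music
- Navigating to the office
Recommended feedback:
- First feedback: `好的,正在打开车窗` ("OK, opening the window")
- Final feedback: `车窗已打开` ("The window is open")
If the device action is extremely fast, the final feedback can be spoken directly:
- `车窗已打开`
### 7.2 Slow Execution
Applies to:
- Order lookup
- Logistics lookup
- Complex multi-command planning
- Conditional tasks
Recommended feedback:
- First feedback: `收到,我先帮你查一下` ("Got it, let me check that for you first")
- Final feedback: `订单还没发货` ("The order has not shipped yet")
### 7.3 Compound Commands
For example:
- `打开车窗空调调低至20度` ("open the window, lower the AC to 20 degrees")
Recommended feedback:
- First feedback: `好的,正在为你打开车窗并调低空调` ("OK, opening the window and turning the AC down for you")
- Final feedback: `车窗已打开空调已调到20度` ("The window is open and the AC is set to 20 degrees")
### 7.4 Out-of-Boundary Requests
For example:
- `打开飞机门` ("open the airplane door")
Recommended feedback:
- `这个我暂时做不了,但我可以帮你导航、查订单、调空调或播放音乐` ("I can't do that for now, but I can help you navigate, look up orders, adjust the AC, or play music")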
---
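## 8. Multi-Turn Context Recovery Flow
```mermaid
flowchart TD
A[Current-turn input: 再低一点] --> B[Read session context]
B --> C{High-frequency rewrite cache hit}
C -->|yes| D[Rewrite into a full sentence]
C -->|no| E[Lightweight rewrite model / rule completion]
D --> F[Enter router]
E --> F
F --> G[Intent recognition + slot extraction + execution]
```
Notes:
- The "context capability" here is not something BERT caches by itself
- Instead, the Agent stores the previous task and key slots in `session state`
- The `rewrite engine` then restores the short utterance to a full one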
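A rule-based rewrite step of this kind could look like the sketch below. `SessionState`, `DELTA_PATTERNS`, and `rewrite` are hypothetical names; the one-degree step follows the document's 22 to 21 example, and the `再高一点` ("a bit higher") pattern is an added assumption for symmetry.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class SessionState:
    last_intent: Optional[str] = None
    last_temperature: Optional[int] = None


# Hypothetical high-frequency patterns: short utterance -> temperature delta.
DELTA_PATTERNS = {
    r"^再低一点$": -1,
    r"^再高一点$": +1,
}


def rewrite(text: str, state: SessionState) -> str:
    """Expand an elliptical follow-up into a full command, or return it unchanged."""
    if state.last_intent != "cabin_set_ac" or state.last_temperature is None:
        return text  # no usable context: leave the utterance for the router
    for pattern, delta in DELTA_PATTERNS.items():
        if re.match(pattern, text):
            return f"把空调调到{state.last_temperature + delta}度"
    return text


state = SessionState(last_intent="cabin_set_ac", last_temperature=22)
print(rewrite("再低一点", state))
```

Because the rewritten sentence is a complete command, the downstream classifier never needs to see the session state itself.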
---
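## 9. Typical Scenario Walkthroughs
### 9.1 Scenario 1: Fast Execution of a Single Command
User input:
- `打开车窗` ("open the window")
Processing flow:
1. ASR to text: `打开车窗`
2. Text normalization
3. The local branches run concurrently
4. If the local side hits `cabin_open_window` with high confidence
5. The window plugin is executed directly
6. Spoken feedback: `车窗已打开`
Sequence:
```mermaid
sequenceDiagram
participant U as User
participant A as ASR
participant R as Local router
participant P as Plugin execution
participant T as TTS
U->>A: 打开车窗
A->>R: 打开车窗
R->>P: Execute open_window
P-->>R: success
R->>T: 车窗已打开
T-->>U: 车窗已打开
```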
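### 9.2 Scenario 2: Fast Execution of a Compound Command
User input:
- `打开车窗空调调低至20度` ("open the window, lower the AC to 20 degrees")
Processing flow:
1. The text enters the local router
2. It is judged to be a multi-command request
3. The planner or a local splitter outputs two steps
4. A sequence workflow is generated
5. The steps run in order:
   - Open the window
   - Set the AC to 20 degrees
6. Aggregated feedback: `车窗已打开空调已调到20度`
Sequence:
```mermaid
sequenceDiagram
participant U as User
participant A as ASR
participant F as Fusion grader
participant W as Workflow
participant P as Plugin layer
participant T as TTS
U->>A: 打开车窗空调调低至20度
A->>F: Normalized text
F->>W: Output sequence workflow
W->>P: step1 open_window
P-->>W: success
W->>P: step2 set_ac(20)
P-->>W: success
W->>T: 车窗已打开空调已调到20度
T-->>U: 车窗已打开空调已调到20度
```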
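### 9.3 Scenario 3: Slow-Execution Query
User input:
- `帮我查一下订单A123456` ("look up order A123456 for me")
Processing flow:
1. The local high-frequency branch hits order lookup
2. The first feedback is given up front:
   - `收到,我帮你查一下` ("Got it, let me check that for you")
3. The order-lookup plugin is called
4. Final feedback:
   - `订单A123456当前待发货` ("Order A123456 is currently awaiting shipment")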
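### 9.4 Scenario 4: Conditional Request
User input:
- `查一下订单A123456如果还没发货就取消` ("look up order A123456 and cancel it if it has not shipped yet")
Processing flow:
1. The local side recognizes the request as complex and hands it to the cloud planner
2. The planner outputs a conditional workflow:
   - step1: query_order
   - step2: cancel_order
   - condition: order_status == pending_shipment
3. step1 is executed first
4. If the condition holds, the flow moves to confirmation
5. After the user confirms, the cancellation is executed
6. Final feedback:
   - `订单A123456已取消` ("Order A123456 has been cancelled")
Sequence:
```mermaid
sequenceDiagram
participant U as User
participant R as Local fusion
participant C as Cloud Planner
participant W as Workflow
participant P as Plugin layer
participant T as TTS
U->>R: 查一下订单A123456如果还没发货就取消
R->>C: Complex conditional request
C-->>W: conditional workflow
W->>P: step1 query_order
P-->>W: order_status=pending_shipment
W->>T: 即将取消订单,仅在订单未发货时取消。请回复确认或取消
T-->>U: Confirmation prompt
U->>W: 确认
W->>P: step2 cancel_order
P-->>W: success
W->>T: 订单A123456已取消
T-->>U: 订单A123456已取消
```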
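The query, condition check, confirmation, and cancellation steps of this scenario can be sketched as a small function. Everything here is an assumption for illustration: `run_conditional_workflow` and its three callables are hypothetical, and the returned status strings are invented; only the `pending_shipment` condition comes from the workflow description.

```python
from typing import Callable


def run_conditional_workflow(
    order_id: str,
    query_order: Callable[[str], str],
    cancel_order: Callable[[str], bool],
    ask_confirm: Callable[[str], bool],
) -> str:
    """step1 query -> check condition -> confirm -> step2 cancel. Returns a status string."""
    status = query_order(order_id)                       # step1
    if status != "pending_shipment":                     # condition not met: stop here
        return f"skipped:{status}"
    if not ask_confirm(f"Cancel order {order_id}?"):     # high-risk step needs confirmation
        return "declined"
    return "cancelled" if cancel_order(order_id) else "failed"


# Fake plugins standing in for the real order services.
cancelled_ids: list[str] = []
outcome = run_conditional_workflow(
    "A123456",
    query_order=lambda _id: "pending_shipment",
    cancel_order=lambda _id: cancelled_ids.append(_id) or True,
    ask_confirm=lambda _prompt: True,
)
print(outcome, cancelled_ids)
```

Passing the plugins in as callables keeps the confirmation gate in one place, so no code path reaches `cancel_order` without going through `ask_confirm`.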
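### 9.5 Scenario 5: Multi-Turn Short-Utterance Recovery
Dialogue:
1. User: `把空调调到22度` ("set the AC to 22 degrees")
2. System: `空调已调到22度` ("The AC is set to 22 degrees")
3. User: `再低一点` ("a bit lower")
4. The system reads `last_intent=cabin_set_ac` and `last_temperature=22`
5. The rewrite engine rewrites the utterance as: `把空调调到21度`
6. Intent recognition and execution then proceed as usual
7. System feedback: `空调已调到21度`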
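### 9.6 Scenario 6: Out-of-Boundary Request
User input:
- `打开飞机门` ("open the airplane door")
Correct handling:
1. None of the local branches can support the request reliably
2. If the score is below the execution threshold, an existing intent must not be executed directly
3. The flow proceeds to:
   - reject
   - or cloud clarification
4. Feedback:
   - `这个我暂时做不了,但我可以帮你导航、查订单、调空调或播放音乐`
Such cases must be handled through an `unknown / out_of_scope` mechanism; they cannot rely on a closed-set classifier forcing a choice.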
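One way to avoid a forced closed-set choice is to gate the classifier output on both its top score and its margin over the runner-up. The sketch below assumes hypothetical thresholds and a hypothetical `resolve_intent` helper; the `__out_of_scope__` label name is the one used in the BERT evaluation report in this commit.

```python
OUT_OF_SCOPE = "__out_of_scope__"


def resolve_intent(scores: dict[str, float], min_score: float = 0.6, min_margin: float = 0.2) -> str:
    """Return the top label, or OUT_OF_SCOPE when the classifier is not decisive enough."""
    if not scores:
        return OUT_OF_SCOPE
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    top_label, top_score = ranked[0]
    runner_up = ranked[1][1] if len(ranked) > 1 else 0.0
    if top_label == OUT_OF_SCOPE:
        return OUT_OF_SCOPE
    # Reject when the best score is weak or barely ahead of the second best.
    if top_score < min_score or top_score - runner_up < min_margin:
        return OUT_OF_SCOPE
    return top_label


print(resolve_intent({"cabin_window_open": 0.95, "cabin_sunroof_open": 0.03}))
print(resolve_intent({"cabin_window_open": 0.45, "cabin_sunroof_open": 0.40}))
```

A score gate alone is not sufficient: the BERT report in this commit records `左前窗打开一点` being misclassified with a score of `0.9951`, which no threshold of this kind would catch, so explicit out-of-scope and hard-negative training data are still needed.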
---
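## 10. How the Current Project Maps to the Target Approach
### 10.1 Capabilities Already in Place
- Local `keyword / classifier / retrieval / fusion`
- Local BERT classifier
- `session state`
- `context rewrite`
- `planner`
- `sequence / conditional workflow`
- High-risk confirmation
- Demo debug panel
### 10.2 Key Capabilities Still to Be Added
- `unknown / out_of_scope`
- Low-score rejection strategy
- An explicit `execute / reject / route_to_cloud` decision recommendation
- More real in-vehicle intents
- Real plugin integration
- Full integration of the voice front end with ASR/TTS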
---
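## 11. Formal Implementation Conclusion for the Current Stage
The current direction can continue, but the following must be made explicit:
- The local BERT is one of the local fast branches, not the sole arbiter of the whole system
- The final execution decision should come from "local fusion grader + planner + risk rules"
- The key to user experience is not only correct recognition but also:
  - Whether the first feedback is fast
  - Whether multi-turn interaction is smooth
  - Whether the boundaries are clear
  - Whether risk is controllable
The current formal approach should therefore be defined as:
```text
In-vehicle Agent = local concurrent fast path
+ context rewrite cache
+ tiered fusion decision
+ cloud planner
+ workflow execution
+ risk confirmation
+ reject / clarify / fallback strategy
```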
---
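## 12. Next-Step Priorities
The suggested implementation order is:
1. Add `unknown / out_of_scope` and a rejection threshold
2. Output a unified execution recommendation:
   - `execute`
   - `clarify`
   - `reject`
   - `route_to_cloud`
3. Expand the real in-vehicle intents:
   - Windows
   - Doors
   - Seats
   - Lights
   - Rearview mirrors
   - Defogging
4. Strengthen the high-frequency rewrite/cache patterns
5. Connect real plugins and the real voice pipeline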

Binary file not shown.


@@ -0,0 +1,65 @@
# Local BERT Intent Recognition Test Report
## Overview
- Model directory: `/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_bert_intent`
- Evaluation set: `/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/bert_intent_eval_independent.jsonl`
- Evaluation threshold: `0.0`
- Test samples: `42`
- Overall accuracy: `0.9762`
## Training Summary
- Base model: `hfl/chinese-macbert-base`
- Training set / validation set: `1557 / 401`
- Best validation accuracy: `0.9875`
- Training device: `mps`
## Results by Category
- `business`: 33/34 = 0.9706
- `out_of_scope`: 4/4 = 1.0
- `social`: 4/4 = 1.0
## Results by Label
- `__out_of_scope__` (out_of_scope): 4/4 = 1.0
- `__social__` (social): 4/4 = 1.0
- `cabin_ac_off` (business): 1/1 = 1.0
- `cabin_ac_on` (business): 1/1 = 1.0
- `cabin_defog_front_on` (business): 1/1 = 1.0
- `cabin_defog_rear_on` (business): 1/1 = 1.0
- `cabin_fan_down` (business): 1/1 = 1.0
- `cabin_fan_up` (business): 1/1 = 1.0
- `cabin_lights_off` (business): 1/1 = 1.0
- `cabin_lights_on` (business): 1/1 = 1.0
- `cabin_lock_doors` (business): 1/1 = 1.0
- `cabin_mirror_fold` (business): 1/1 = 1.0
- `cabin_mirror_unfold` (business): 1/1 = 1.0
- `cabin_nav_cancel` (business): 1/1 = 1.0
- `cabin_nav_to` (business): 1/1 = 1.0
- `cabin_next_track` (business): 1/1 = 1.0
- `cabin_pause_music` (business): 1/1 = 1.0
- `cabin_play_music` (business): 1/1 = 1.0
- `cabin_previous_track` (business): 1/1 = 1.0
- `cabin_seat_heat_off` (business): 1/1 = 1.0
- `cabin_seat_heat_on` (business): 1/1 = 1.0
- `cabin_set_ac` (business): 1/1 = 1.0
- `cabin_sunroof_close` (business): 1/1 = 1.0
- `cabin_sunroof_open` (business): 1/1 = 1.0
- `cabin_unlock_doors` (business): 1/1 = 1.0
- `cabin_volume_down` (business): 1/1 = 1.0
- `cabin_volume_mute` (business): 1/1 = 1.0
- `cabin_volume_up` (business): 1/1 = 1.0
- `cabin_window_close` (business): 1/1 = 1.0
- `cabin_window_open` (business): 0/1 = 0.0
- `cabin_wiper_off` (business): 1/1 = 1.0
- `cabin_wiper_on` (business): 1/1 = 1.0
- `cs_cancel_order` (business): 1/1 = 1.0
- `cs_query_logistics` (business): 1/1 = 1.0
- `cs_query_order` (business): 1/1 = 1.0
- `cs_transfer_human` (business): 1/1 = 1.0
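## Error Examples
- Text: `左前窗打开一点` | Category: `business` | Expected: `cabin_window_open` | Predicted: `cabin_defog_front_on` | Score: `0.9951`
## Conclusions
- The current local MacBERT already shows strong business-intent recognition on this evaluation set (41 of 42 samples correct) and can serve as the classifier on the local fast path.
- Misclassifications concentrate on control commands with opposite directions or close semantics; the next step is to add adversarial samples and real colloquial expressions.
- Before launch, continue adding samples with ASR typos, short multi-turn utterances, and clause-level multi-intent cases.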


@@ -0,0 +1,426 @@
{
"model_dir": "/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_bert_intent",
"threshold": 0.0,
"test_path": "/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/bert_intent_eval_independent.jsonl",
"test_case_count": 42,
"accuracy": 0.9762,
"train_summary": {
"base_model": "hfl/chinese-macbert-base",
"epochs": 16,
"batch_size": 8,
"learning_rate": 2e-05,
"train_size": 1557,
"dev_size": 401,
"best_dev_accuracy": 0.9875,
"device": "mps"
},
"per_category": [
{
"category": "business",
"total": 34,
"correct": 33,
"accuracy": 0.9706
},
{
"category": "out_of_scope",
"total": 4,
"correct": 4,
"accuracy": 1.0
},
{
"category": "social",
"total": 4,
"correct": 4,
"accuracy": 1.0
}
],
"per_label": [
{
"label": "__out_of_scope__",
"category": "out_of_scope",
"total": 4,
"correct": 4,
"accuracy": 1.0
},
{
"label": "__social__",
"category": "social",
"total": 4,
"correct": 4,
"accuracy": 1.0
},
{
"label": "cabin_ac_off",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_ac_on",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_defog_front_on",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_defog_rear_on",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_fan_down",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_fan_up",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_lights_off",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_lights_on",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_lock_doors",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_mirror_fold",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_mirror_unfold",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_nav_cancel",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_nav_to",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_next_track",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_pause_music",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_play_music",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_previous_track",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_seat_heat_off",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_seat_heat_on",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_set_ac",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_sunroof_close",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_sunroof_open",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_unlock_doors",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_volume_down",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_volume_mute",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_volume_up",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_window_close",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_window_open",
"category": "business",
"total": 1,
"correct": 0,
"accuracy": 0.0
},
{
"label": "cabin_wiper_off",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cabin_wiper_on",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cs_cancel_order",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cs_query_logistics",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cs_query_order",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
},
{
"label": "cs_transfer_human",
"category": "business",
"total": 1,
"correct": 1,
"accuracy": 1.0
}
],
"errors": [
{
"text": "左前窗打开一点",
"category": "business",
"expected_label": "cabin_window_open",
"predicted_label": "cabin_defog_front_on",
"score": 0.9951,
"raw_label": "cabin_defog_front_on",
"ok": false,
"top_candidates": [
{
"intent_id": "cabin_defog_front_on",
"score": 0.9951
},
{
"intent_id": "cabin_sunroof_open",
"score": 0.0005
},
{
"intent_id": "cabin_lights_on",
"score": 0.0004
}
]
}
],
"confusion": {
"cabin_ac_off": {
"cabin_ac_off": 1
},
"cabin_ac_on": {
"cabin_ac_on": 1
},
"cabin_defog_front_on": {
"cabin_defog_front_on": 1
},
"cabin_defog_rear_on": {
"cabin_defog_rear_on": 1
},
"cabin_fan_down": {
"cabin_fan_down": 1
},
"cabin_fan_up": {
"cabin_fan_up": 1
},
"cabin_lights_off": {
"cabin_lights_off": 1
},
"cabin_lights_on": {
"cabin_lights_on": 1
},
"cabin_lock_doors": {
"cabin_lock_doors": 1
},
"cabin_mirror_fold": {
"cabin_mirror_fold": 1
},
"cabin_mirror_unfold": {
"cabin_mirror_unfold": 1
},
"cabin_nav_cancel": {
"cabin_nav_cancel": 1
},
"cabin_nav_to": {
"cabin_nav_to": 1
},
"cabin_next_track": {
"cabin_next_track": 1
},
"cabin_pause_music": {
"cabin_pause_music": 1
},
"cabin_play_music": {
"cabin_play_music": 1
},
"cabin_previous_track": {
"cabin_previous_track": 1
},
"cabin_seat_heat_off": {
"cabin_seat_heat_off": 1
},
"cabin_seat_heat_on": {
"cabin_seat_heat_on": 1
},
"cabin_set_ac": {
"cabin_set_ac": 1
},
"cabin_sunroof_close": {
"cabin_sunroof_close": 1
},
"cabin_sunroof_open": {
"cabin_sunroof_open": 1
},
"cabin_unlock_doors": {
"cabin_unlock_doors": 1
},
"cabin_volume_down": {
"cabin_volume_down": 1
},
"cabin_volume_mute": {
"cabin_volume_mute": 1
},
"cabin_volume_up": {
"cabin_volume_up": 1
},
"cabin_window_close": {
"cabin_window_close": 1
},
"cabin_window_open": {
"cabin_defog_front_on": 1
},
"cabin_wiper_off": {
"cabin_wiper_off": 1
},
"cabin_wiper_on": {
"cabin_wiper_on": 1
},
"cs_cancel_order": {
"cs_cancel_order": 1
},
"cs_query_logistics": {
"cs_query_logistics": 1
},
"cs_query_order": {
"cs_query_order": 1
},
"cs_transfer_human": {
"cs_transfer_human": 1
},
"__social__": {
"__social__": 4
},
"__out_of_scope__": {
"__out_of_scope__": 4
}
}
}


@@ -0,0 +1,47 @@
# Local Multi-Label Detector Independent Evaluation Report
## Overview
- Model directory: `/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_bert_multi_intent`
- Evaluation set: `/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/bert_intent_multilabel_eval_independent.jsonl`
- Samples: `37`
- Threshold / top_k / max_labels: `0.45 / 8 / 4`
- `micro_precision`: `0.9362`
- `micro_recall`: `0.6377`
- `micro_f1`: `0.7586`
- `exact_match`: `0.5135`
- `multi_sentence_recall`: `0.4138`
- `single_guard_false_alarm_rate`: `0.0`
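As a quick consistency check, the reported `micro_f1` can be recomputed from the reported precision and recall with the usual harmonic-mean formula. `f1_score` is a helper written for this check, not a function from the project, and the inputs are the rounded values from the overview, so only four-decimal agreement should be expected.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


# Rounded values copied from the overview.
micro_precision = 0.9362
micro_recall = 0.6377

micro_f1 = round(f1_score(micro_precision, micro_recall), 4)
print(micro_f1)  # 0.7586
```

The large gap between precision and recall is in line with the error examples in this report, where most failures are missing labels (empty or partial predictions) rather than spurious ones.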
## Results by Category
- `cabin_parallel`: count=15 micro_f1=0.807 exact_match=0.4667
- `cabin_sequence`: count=9 micro_f1=0.5385 exact_match=0.3333
- `cs_conditional`: count=3 micro_f1=0.9091 exact_match=0.6667
- `cs_sequence`: count=2 micro_f1=0.6667 exact_match=0.0
- `single_guard`: count=8 micro_f1=0.875 exact_match=0.875
## Main Confusions
- Missed `cabin_sunroof_open` while falsely reporting `cabin_window_open`: `1` time(s)
- Missed `cabin_pause_music` while falsely reporting `cabin_play_music`: `1` time(s)
- Missed `cabin_window_open` while falsely reporting `cabin_defog_front_on`: `1` time(s)
## Error Examples
- Text: `锁车门,再把后视镜收起来` | Category: `cabin_sequence` | Expected: `['cabin_lock_doors', 'cabin_mirror_fold']` | Predicted: `[]`
- Text: `把车门解锁,再把镜子展开` | Category: `cabin_sequence` | Expected: `['cabin_mirror_unfold', 'cabin_unlock_doors']` | Predicted: `[]`
- Text: `路线别导了,音乐也停一下` | Category: `cabin_parallel` | Expected: `['cabin_nav_cancel', 'cabin_pause_music']` | Predicted: `[]`
- Text: `雨停了,雨刮关掉,再把窗开一点` | Category: `cabin_sequence` | Expected: `['cabin_window_open', 'cabin_wiper_off']` | Predicted: `[]`
- Text: `把天窗合上,然后把音乐暂停` | Category: `cabin_sequence` | Expected: `['cabin_pause_music', 'cabin_sunroof_close']` | Predicted: `[]`
- Text: `先把音量调大,再切下一首` | Category: `cabin_parallel` | Expected: `['cabin_next_track', 'cabin_volume_up']` | Predicted: `[]`
- Text: `静音之后切回上一首` | Category: `cabin_sequence` | Expected: `['cabin_previous_track', 'cabin_volume_mute']` | Predicted: `[]`
- Text: `把天窗打开透口气,再开空调` | Category: `cabin_parallel` | Expected: `['cabin_ac_on', 'cabin_sunroof_open']` | Predicted: `['cabin_ac_on', 'cabin_window_open']`
- Text: `音乐停一下,然后导航到公司` | Category: `cabin_sequence` | Expected: `['cabin_nav_to', 'cabin_pause_music']` | Predicted: `['cabin_nav_to', 'cabin_play_music']`
- Text: `把左前窗降一点` | Category: `single_guard` | Expected: `['cabin_window_open']` | Predicted: `['cabin_defog_front_on']`
- Text: `车里闷,给我透个气,再放点轻松的歌` | Category: `cabin_parallel` | Expected: `['cabin_play_music', 'cabin_window_open']` | Predicted: `['cabin_play_music']`
- Text: `把空调开了,风别太小,再来首歌` | Category: `cabin_parallel` | Expected: `['cabin_ac_on', 'cabin_fan_up', 'cabin_play_music']` | Predicted: `['cabin_ac_on', 'cabin_play_music']`
- Text: `开导航去徐家汇,顺便把风量调大` | Category: `cabin_parallel` | Expected: `['cabin_fan_up', 'cabin_nav_to']` | Predicted: `['cabin_nav_to']`
- Text: `温度调到二十三度,风稍微小一点` | Category: `cabin_parallel` | Expected: `['cabin_fan_down', 'cabin_set_ac']` | Predicted: `['cabin_set_ac']`
- Text: `帮我看A812302物流要是太慢就转人工` | Category: `cs_conditional` | Expected: `['cs_query_logistics', 'cs_transfer_human']` | Predicted: `['cs_query_logistics']`
## Conclusions and Recommendations
- First check whether multi-intent sentences show systematic under-recall, then check whether single-intent sentences are falsely flagged as multi-intent.
- If `single_guard_false_alarm_rate` is high, first tighten the detector threshold or add single-intent negative samples before moving on to NER.
- If `multi_sentence_recall` is unstable, keep adding conditional sentences, weakly connected sentences, and colloquial multi-action corpora.


@@ -0,0 +1,40 @@
# Joint NLU Independent Evaluation Report
## Overview
- Model directory: `/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_joint_bert_nlu`
- Evaluation set: `/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/joint_nlu_eval_independent.jsonl`
- Samples: `43`
- `intent_accuracy`: `0.9302`
- `slot_exact_match`: `1.0`
- `joint_exact_match`: `0.9302`
- `slot_micro_precision`: `1.0`
- `slot_micro_recall`: `1.0`
- `slot_micro_f1`: `1.0`
## Training Summary
- Training set / evaluation set: `337 / 10`
- Training-stage `intent_accuracy`: `1.0`
- Training-stage `slot_exact_match`: `0.8`
## Results by Category
- `failure_replay`: count=12 intent_acc=0.75 slot_exact=1.0 joint_exact=0.75
- `no_slot_control`: count=14 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
- `slot_destination`: count=4 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
- `slot_music`: count=5 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
- `slot_order`: count=4 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
- `slot_temperature`: count=4 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
## Main Intent Confusions
- Expected `cabin_window_open`, predicted `None`: `1` time(s)
- Expected `cabin_window_open`, predicted `cabin_play_music`: `1` time(s)
- Expected `cabin_fan_up`, predicted `cabin_fan_down`: `1` time(s)
## Failure Replay
- Text: `把左前窗降一点` | Category: `failure_replay` | Expected intent: `cabin_window_open` | Predicted intent: `None` | Expected slots: `{}` | Predicted slots: `{}` | Missing slots: `[]` | Extra slots: `[]`
- Text: `给我透个气` | Category: `failure_replay` | Expected intent: `cabin_window_open` | Predicted intent: `cabin_play_music` | Expected slots: `{}` | Predicted slots: `{}` | Missing slots: `[]` | Extra slots: `[]`
- Text: `风别太小` | Category: `failure_replay` | Expected intent: `cabin_fan_up` | Predicted intent: `cabin_fan_down` | Expected slots: `{}` | Predicted slots: `{}` | Missing slots: `[]` | Extra slots: `[]`
## Conclusions
- First check whether `failure_replay` still fails; this shows directly whether the earlier multi-intent failures come from the joint model itself or from the upper-level composition.
- If `slot_music` or `slot_destination` is still unstable, prioritize adding span annotations rather than falling back to rule-based slot extraction.
- If `no_slot_control` is stable but `failure_replay` still contains many errors, the next step is to add long-tail control-semantics data rather than rushing to a more complex architecture.


@@ -0,0 +1,92 @@
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter
from pathlib import Path

# Make the project root importable when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.joint_nlu import JointBertNLU

DEFAULT_TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_eval_independent.jsonl"


def load_cases(path: Path) -> list[dict[str, object]]:
    """Load evaluation cases from a JSONL file, skipping blank lines."""
    rows: list[dict[str, object]] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        payload = json.loads(line)
        rows.append(payload)
    return rows


def main() -> None:
    parser = argparse.ArgumentParser(description="Independent multi-intent evaluation for Joint BERT")
    parser.add_argument("--model-path", type=str, default="models/local_joint_bert_nlu")
    parser.add_argument("--test-path", type=str, default=str(DEFAULT_TEST_PATH))
    args = parser.parse_args()

    registry = build_intent_registry()
    nlu = JointBertNLU(model_path=args.model_path)
    cases = load_cases(Path(args.test_path))

    # Micro-averaged label counts accumulated over all cases.
    tp = 0
    fp = 0
    fn = 0
    exact = 0
    failures: list[dict[str, object]] = []
    category_correct: Counter[str] = Counter()
    category_total: Counter[str] = Counter()
    for case in cases:
        text = str(case["text"])
        expected = sorted({str(item) for item in case.get("expected_intent_ids", [])})
        predicted = sorted(item.intent_id for item in nlu.predict_multi_intents(text, registry.list(), top_k=8, max_labels=4))
        expected_set = set(expected)
        predicted_set = set(predicted)
        tp += len(expected_set & predicted_set)
        fp += len(predicted_set - expected_set)
        fn += len(expected_set - predicted_set)
        category = str(case.get("category") or "unknown")
        category_total[category] += 1
        if expected_set == predicted_set:
            exact += 1
            category_correct[category] += 1
        else:
            failures.append(
                {
                    "text": text,
                    "expected_intent_ids": expected,
                    "predicted_intent_ids": predicted,
                    "category": category,
                }
            )

    # max(..., 1) guards against division by zero on an empty evaluation set.
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    result = {
        "sample_count": len(cases),
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(f1, 4),
        "exact_match": round(exact / max(len(cases), 1), 4),
        "per_category_exact_match": {
            category: round(category_correct[category] / max(total, 1), 4)
            for category, total in sorted(category_total.items())
        },
        "failures": failures[:20],
    }
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()


@@ -0,0 +1,59 @@
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Make the project root importable when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.joint_nlu import JointBertNLU


def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate Joint BERT NLU intent and slot output for a single sentence")
    parser.add_argument("--text", type=str, required=True, help="Text to evaluate")
    parser.add_argument("--model-path", type=str, default="models/local_joint_bert_nlu", help="Model directory")
    args = parser.parse_args()

    registry = build_intent_registry()
    nlu = JointBertNLU(model_path=args.model_path)
    result = nlu.predict(args.text, registry.list())
    # Dump the full prediction (intent, candidates, slots) as pretty-printed JSON.
    print(
        json.dumps(
            {
                "text": args.text,
                "intent_id": result.intent_id,
                "intent_score": round(result.intent_score, 4),
                "candidates": [
                    {"intent_id": item.intent_id, "score": round(item.score, 4)}
                    for item in result.candidates
                ],
                "multi_intent_candidates": [
                    {"intent_id": item.intent_id, "score": round(item.score, 4)}
                    for item in result.multi_intent_candidates
                ],
                "slots": result.slots,
                "slot_items": [
                    {
                        "slot_name": item.slot_name,
                        "value": item.value,
                        "start": item.start,
                        "end": item.end,
                        "score": item.score,
                    }
                    for item in result.slot_items
                ],
                "error_message": result.error_message,
            },
            ensure_ascii=False,
            indent=2,
        )
    )


if __name__ == "__main__":
    main()


@@ -0,0 +1,275 @@
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path

# Make the project root importable when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.joint_nlu import JointBertNLU

TEST_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval_independent.jsonl"
MODEL_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
REPORT_DIR = PROJECT_ROOT / "reports"
RESULT_PATH = REPORT_DIR / "joint_nlu_independent_result.json"
REPORT_PATH = REPORT_DIR / "joint_nlu_independent_report.md"
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"


def load_cases(file_path: Path) -> list[dict[str, object]]:
    """Load JSONL cases, skipping blank lines and rows without an expected intent."""
    cases: list[dict[str, object]] = []
    for line in file_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        payload = json.loads(line)
        expected_intent_id = str(payload.get("expected_intent_id") or payload.get("intent_id") or "").strip()
        if not expected_intent_id:
            continue
        cases.append(
            {
                "text": str(payload["text"]),
                "expected_intent_id": expected_intent_id,
                "expected_slots": dict(payload.get("expected_slots") or {}),
                "category": str(payload.get("category") or "unknown"),
            }
        )
    return cases


def load_train_summary(file_path: Path) -> dict[str, object]:
    """Return the training summary, or an empty dict if the file is missing."""
    if not file_path.exists():
        return {}
    return json.loads(file_path.read_text(encoding="utf-8"))


def compare_slots(expected: dict[str, object], predicted: dict[str, object]) -> dict[str, object]:
    """Diff expected vs. predicted slots: missing keys, extra keys, and wrong values."""
    expected_keys = set(expected)
    predicted_keys = set(predicted)
    missing_keys = sorted(expected_keys - predicted_keys)
    extra_keys = sorted(predicted_keys - expected_keys)
    wrong_values: list[dict[str, object]] = []
    matched_keys = 0
    for key in sorted(expected_keys & predicted_keys):
        if expected[key] == predicted[key]:
            matched_keys += 1
        else:
            wrong_values.append(
                {
                    "slot_name": key,
                    "expected": expected[key],
                    "predicted": predicted[key],
                }
            )
    exact = not missing_keys and not extra_keys and not wrong_values
    return {
        "missing_keys": missing_keys,
        "extra_keys": extra_keys,
        "wrong_values": wrong_values,
        "matched_keys": matched_keys,
        "exact": exact,
    }


def compute_metrics(results: list[dict[str, object]]) -> dict[str, float]:
    """Compute intent accuracy, exact-match rates, and micro slot precision/recall/F1."""
    total = len(results)
    intent_correct = sum(1 for item in results if item["intent_ok"])
    slot_exact = sum(1 for item in results if item["slot_exact"])
    joint_exact = sum(1 for item in results if item["joint_ok"])
    slot_tp = 0
    slot_fp = 0
    slot_fn = 0
    for item in results:
        expected = dict(item["expected_slots"])
        predicted = dict(item["predicted_slots"])
        expected_keys = set(expected)
        predicted_keys = set(predicted)
        slot_tp += sum(1 for key in expected_keys & predicted_keys if expected[key] == predicted[key])
        slot_fp += len(predicted_keys - expected_keys)
        slot_fn += len(expected_keys - predicted_keys)
        # A slot with the right key but the wrong value counts as both a false positive and a false negative.
        slot_fp += sum(1 for key in expected_keys & predicted_keys if expected[key] != predicted[key])
        slot_fn += sum(1 for key in expected_keys & predicted_keys if expected[key] != predicted[key])
    precision = slot_tp / (slot_tp + slot_fp) if (slot_tp + slot_fp) else 0.0
    recall = slot_tp / (slot_tp + slot_fn) if (slot_tp + slot_fn) else 0.0
    slot_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "intent_accuracy": round(intent_correct / total, 4) if total else 0.0,
        "slot_exact_match": round(slot_exact / total, 4) if total else 0.0,
        "joint_exact_match": round(joint_exact / total, 4) if total else 0.0,
        "slot_micro_precision": round(precision, 4),
        "slot_micro_recall": round(recall, 4),
        "slot_micro_f1": round(slot_f1, 4),
    }


def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
    """Group results by category and compute metrics per group."""
    grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
    for item in results:
        grouped[str(item["category"])].append(item)
    summary: list[dict[str, object]] = []
    for category, items in sorted(grouped.items()):
        summary.append(
            {
                "category": category,
                "sample_count": len(items),
                "metrics": compute_metrics(items),
            }
        )
    return summary


def collect_top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
    """Count (expected, predicted) intent pairs among misclassified samples."""
    counter: Counter[tuple[str, str]] = Counter()
    for item in results:
        if item["intent_ok"]:
            continue
        counter[(str(item["expected_intent_id"]), str(item["predicted_intent_id"]))] += 1
    return [
        {"expected": expected, "predicted": predicted, "count": count}
        for (expected, predicted), count in counter.most_common(limit)
    ]


def collect_failures(results: list[dict[str, object]], limit: int = 20) -> list[dict[str, object]]:
    """Return the worst failures first: intent errors, then more slot errors, then longer text."""
    failures = [item for item in results if not item["joint_ok"]]

    def sort_key(item: dict[str, object]) -> tuple[int, int, int]:
        slot_errors = len(item["slot_diff"]["missing_keys"]) + len(item["slot_diff"]["extra_keys"]) + len(item["slot_diff"]["wrong_values"])
        return (0 if item["intent_ok"] else 1, slot_errors, len(str(item["text"])))

    return sorted(failures, key=sort_key, reverse=True)[:limit]


def main() -> None:
    parser = argparse.ArgumentParser(description="Joint NLU independent evaluation and failure replay")
    parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="Path to the evaluation set")
    parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="Path to the Joint NLU model")
    parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="Output path for the structured result")
    parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Output path for the Markdown report")
    args = parser.parse_args()

    cases = load_cases(Path(args.test_path))
    registry = build_intent_registry()
    nlu = JointBertNLU(model_path=args.model_path)
    results: list[dict[str, object]] = []
    for case in cases:
        prediction = nlu.predict(str(case["text"]), registry.list())
        predicted_slots = dict(prediction.slots)
        slot_diff = compare_slots(dict(case["expected_slots"]), predicted_slots)
        # A missing prediction is recorded as the string "None" so it can be compared and reported.
        predicted_intent_id = prediction.intent_id or "None"
        intent_ok = predicted_intent_id == case["expected_intent_id"]
        joint_ok = intent_ok and bool(slot_diff["exact"])
        results.append(
            {
                "text": case["text"],
                "category": case["category"],
                "expected_intent_id": case["expected_intent_id"],
                "predicted_intent_id": predicted_intent_id,
                "expected_slots": case["expected_slots"],
                "predicted_slots": predicted_slots,
                "intent_score": round(prediction.intent_score, 4),
                "intent_ok": intent_ok,
                "slot_exact": bool(slot_diff["exact"]),
                "joint_ok": joint_ok,
                "slot_diff": slot_diff,
                "top_candidates": [
                    {"intent_id": item.intent_id, "score": round(item.score, 4)}
                    for item in prediction.candidates
                ],
            }
        )

    summary = {
        "model_path": args.model_path,
        "test_path": args.test_path,
        "sample_count": len(results),
        "metrics": compute_metrics(results),
        "per_category": summarize_by_category(results),
        "top_confusions": collect_top_confusions(results),
        "failure_examples": collect_failures(results),
        # NOTE: the summary is read from the default MODEL_DIR, not from --model-path.
        "train_summary": load_train_summary(TRAIN_SUMMARY_PATH),
        "results": results,
    }
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
    print(json.dumps({"sample_count": summary["sample_count"], "metrics": summary["metrics"]}, ensure_ascii=False))


def render_report(summary: dict[str, object]) -> str:
    """Render the evaluation summary as a Markdown report."""
    metrics = summary["metrics"]
    per_category = summary["per_category"]
    confusions = summary["top_confusions"]
    failures = summary["failure_examples"]
    train_summary = summary.get("train_summary") or {}
    lines = [
        "# Joint NLU Independent Evaluation Report",
        "",
        "## Overview",
        f"- Model directory: `{summary['model_path']}`",
        f"- Evaluation set: `{summary['test_path']}`",
        f"- Sample count: `{summary['sample_count']}`",
        f"- `intent_accuracy`: `{metrics['intent_accuracy']}`",
        f"- `slot_exact_match`: `{metrics['slot_exact_match']}`",
        f"- `joint_exact_match`: `{metrics['joint_exact_match']}`",
        f"- `slot_micro_precision`: `{metrics['slot_micro_precision']}`",
        f"- `slot_micro_recall`: `{metrics['slot_micro_recall']}`",
        f"- `slot_micro_f1`: `{metrics['slot_micro_f1']}`",
        "",
        "## Training Summary",
    ]
    if train_summary:
        train_metrics = train_summary.get("metrics", {})
        lines.extend(
            [
                f"- Train set / eval set: `{train_summary.get('train_size', 'unknown')} / {train_summary.get('eval_size', 'unknown')}`",
                f"- Training-stage `intent_accuracy`: `{train_metrics.get('intent_accuracy', 'unknown')}`",
                f"- Training-stage `slot_exact_match`: `{train_metrics.get('slot_exact_match', 'unknown')}`",
                "",
            ]
        )
    else:
        lines.extend(["- No training summary found.", ""])
    lines.extend(["## Results by Category"])
    for item in per_category:
        category_metrics = item["metrics"]
        lines.append(
            f"- `{item['category']}`: count={item['sample_count']} "
            f"intent_acc={category_metrics['intent_accuracy']} "
            f"slot_exact={category_metrics['slot_exact_match']} "
            f"joint_exact={category_metrics['joint_exact_match']}"
        )
    lines.extend(["", "## Main Intent Confusions"])
    if not confusions:
        lines.append("- No intent confusions found.")
    else:
        for item in confusions:
            lines.append(
                f"- Expected `{item['expected']}`, predicted `{item['predicted']}`: `{item['count']}` time(s)"
            )
    lines.extend(["", "## Failure Replay"])
    if not failures:
        lines.append("- No failed examples.")
    else:
        for item in failures:
            slot_diff = item["slot_diff"]
            lines.append(
                f"- Text: `{item['text']}` | Category: `{item['category']}` "
                f"| Expected intent: `{item['expected_intent_id']}` "
                f"| Predicted intent: `{item['predicted_intent_id']}` "
                f"| Expected slots: `{item['expected_slots']}` "
                f"| Predicted slots: `{item['predicted_slots']}` "
                f"| Missing slots: `{slot_diff['missing_keys']}` "
                f"| Extra slots: `{slot_diff['extra_keys']}`"
            )
    lines.extend(
        [
            "",
            "## Conclusions",
            "- Check first whether the `failure_replay` cases still fail; this shows directly whether the earlier multi-intent failures come from the joint model itself or from the upper-level composition logic.",
            "- If `slot_music` or `slot_destination` is still unstable, add span annotations first rather than falling back to rule-based slot extraction.",
            "- If `no_slot_control` is stable but `failure_replay` still shows many errors, the next step is to add long-tail control-semantics data rather than rushing into a more complex architecture.",
            "",
        ]
    )
    return "\n".join(lines)


if __name__ == "__main__":
    main()
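
The report above reads `slot_diff['exact']`, `slot_diff['missing_keys']` and `slot_diff['extra_keys']`, but the helper that builds this dictionary is not visible in this part of the diff. The following is a minimal sketch of what such a comparison could look like, assuming slots are flat `dict[str, str]` mappings. The function name `diff_slots` and the `mismatched_keys` field are hypothetical and not taken from the project.

```python
from __future__ import annotations


def diff_slots(expected: dict[str, str], predicted: dict[str, str]) -> dict[str, object]:
    """Compare two flat slot dictionaries (hypothetical helper, not the project's own)."""
    # Slot names the model failed to produce.
    missing = sorted(set(expected) - set(predicted))
    # Slot names the model produced although none was expected.
    extra = sorted(set(predicted) - set(expected))
    # Slot names present on both sides but with different values.
    mismatched = sorted(
        key for key in set(expected) & set(predicted) if expected[key] != predicted[key]
    )
    return {
        "exact": not (missing or extra or mismatched),
        "missing_keys": missing,
        "extra_keys": extra,
        "mismatched_keys": mismatched,
    }


example = diff_slots(
    {"destination": "机场", "temperature": "24度"},
    {"destination": "火车站", "genre": "爵士"},
)
print(example)
```

Keeping value mismatches separate from missing and extra keys makes it easier to tell span-boundary errors apart from slots that were never detected.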


@@ -0,0 +1,231 @@
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier

TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_eval_independent.jsonl"
MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
REPORT_DIR = PROJECT_ROOT / "reports"
REPORT_PATH = REPORT_DIR / "bert_local_test_report.md"
RESULT_PATH = REPORT_DIR / "bert_local_test_result.json"
BERT_THRESHOLD = 0.0
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"


def load_cases(file_path: Path) -> list[dict[str, str]]:
    cases: list[dict[str, str]] = []
    for line in file_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        payload = json.loads(line)
        expected_label = str(payload.get("expected_label") or payload.get("intent_id") or "").strip()
        if not expected_label:
            continue
        category = str(payload.get("category") or infer_category(expected_label)).strip()
        cases.append(
            {
                "text": str(payload["text"]),
                "expected_label": expected_label,
                "category": category,
            }
        )
    return cases


def load_train_summary(file_path: Path) -> dict[str, object]:
    if not file_path.exists():
        return {}
    return json.loads(file_path.read_text(encoding="utf-8"))


def infer_category(label: str) -> str:
    if label == "__social__":
        return "social"
    if label == "__out_of_scope__":
        return "out_of_scope"
    return "business"


def resolve_predicted_label(result) -> str:
    if result.intent is not None:
        return result.intent.intent_id
    if result.raw_label:
        return str(result.raw_label)
    return "None"


def main() -> None:
    parser = argparse.ArgumentParser(description="Standalone evaluation script for the local BERT classifier")
    parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="Path to the evaluation set")
    parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="Output path for the structured evaluation results")
    parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Output path for the Markdown evaluation report")
    args = parser.parse_args()

    intent_registry = build_intent_registry()
    intents = intent_registry.list()
    classifier = BertIntentClassifier(
        model_path=str(MODEL_DIR),
        threshold=BERT_THRESHOLD,
        label_map_path=str(MODEL_DIR / "label_map.json"),
        fallback=None,
        top_k=3,
    )
    cases = load_cases(Path(args.test_path))
    results: list[dict[str, object]] = []
    confusion: dict[str, Counter[str]] = defaultdict(Counter)
    category_confusion: dict[str, Counter[str]] = defaultdict(Counter)
    correct = 0
    for case in cases:
        result = classifier.predict(case["text"], intents)
        predicted = resolve_predicted_label(result)
        expected = case["expected_label"]
        ok = predicted == expected
        if ok:
            correct += 1
        confusion[expected][predicted] += 1
        category_confusion[case["category"]]["correct" if ok else "wrong"] += 1
        results.append(
            {
                "text": case["text"],
                "category": case["category"],
                "expected_label": expected,
                "predicted_label": predicted,
                "score": round(result.score, 4),
                "raw_label": result.raw_label,
                "ok": ok,
                "top_candidates": [
                    {"intent_id": intent.intent_id, "score": round(score, 4)}
                    for intent, score in (result.candidates or [])
                ],
            }
        )
    accuracy = correct / len(cases) if cases else 0.0
    train_summary = load_train_summary(TRAIN_SUMMARY_PATH)

    # Per-label accuracy over the expected labels that occur in the evaluation set.
    per_label_stats: list[dict[str, object]] = []
    for label in sorted({case["expected_label"] for case in cases}):
        label_cases = [item for item in results if item["expected_label"] == label]
        label_correct = sum(1 for item in label_cases if item["ok"])
        per_label_stats.append(
            {
                "label": label,
                "category": infer_category(label),
                "total": len(label_cases),
                "correct": label_correct,
                "accuracy": round(label_correct / len(label_cases), 4) if label_cases else 0.0,
            }
        )

    # Per-category accuracy (business / social / out_of_scope or a custom category).
    per_category_stats: list[dict[str, object]] = []
    for category in sorted({case["category"] for case in cases}):
        category_cases = [item for item in results if item["category"] == category]
        category_correct = sum(1 for item in category_cases if item["ok"])
        per_category_stats.append(
            {
                "category": category,
                "total": len(category_cases),
                "correct": category_correct,
                "accuracy": round(category_correct / len(category_cases), 4) if category_cases else 0.0,
            }
        )

    errors = [item for item in results if not item["ok"]]
    summary = {
        "model_dir": str(MODEL_DIR),
        "threshold": BERT_THRESHOLD,
        "test_path": str(args.test_path),
        "test_case_count": len(cases),
        "accuracy": round(accuracy, 4),
        "train_summary": train_summary,
        "per_category": per_category_stats,
        "per_label": per_label_stats,
        "errors": errors,
        "confusion": {key: dict(value) for key, value in confusion.items()},
    }
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
    print(json.dumps({"accuracy": summary["accuracy"], "test_case_count": len(cases), "error_count": len(errors)}, ensure_ascii=False))


def render_report(summary: dict[str, object]) -> str:
    """Render the evaluation summary as a Markdown report."""
    per_category = summary["per_category"]
    per_label = summary["per_label"]
    errors = summary["errors"]
    train_summary = summary.get("train_summary") or {}
    lines = [
        "# Local BERT Intent Recognition Test Report",
        "",
        "## Overview",
        f"- Model directory: `{summary['model_dir']}`",
        f"- Evaluation set: `{summary['test_path']}`",
        f"- Evaluation threshold: `{summary['threshold']}`",
        f"- Test sample count: `{summary['test_case_count']}`",
        f"- Overall accuracy: `{summary['accuracy']}`",
        "",
        "## Training Summary",
    ]
    if train_summary:
        lines.extend(
            [
                f"- Base model: `{train_summary.get('base_model', 'unknown')}`",
                f"- Train set / dev set: `{train_summary.get('train_size', 'unknown')} / {train_summary.get('dev_size', 'unknown')}`",
                f"- Best dev accuracy: `{train_summary.get('best_dev_accuracy', 'unknown')}`",
                f"- Training device: `{train_summary.get('device', 'unknown')}`",
                "",
            ]
        )
    else:
        lines.extend(["- No training summary found.", ""])
    lines.extend(["## Results by Category"])
    for item in per_category:
        lines.append(
            f"- `{item['category']}`: {item['correct']}/{item['total']} = {item['accuracy']}"
        )
    lines.extend(["", "## Results by Label"])
    for item in per_label:
        lines.append(
            f"- `{item['label']}` ({item['category']}): {item['correct']}/{item['total']} = {item['accuracy']}"
        )
    lines.extend(["", "## Error Examples"])
    if not errors:
        lines.append("- No error examples.")
    else:
        # Only the first ten errors are listed; the JSON result file holds all of them.
        for item in errors[:10]:
            lines.append(
                f"- Text: `{item['text']}` | Category: `{item['category']}` "
                f"| Expected: `{item['expected_label']}` | Predicted: `{item['predicted_label']}` "
                f"| Score: `{item['score']}`"
            )
    lines.extend(
        [
            "",
            "## Conclusions",
            "- The local MacBERT model already recognizes business intents well and can serve as the classifier on the local fast path.",
            "- Misclassifications concentrate on control commands with opposite directions or similar semantics; the next step is to add adversarial samples and real colloquial phrasings.",
            "- Before release, keep adding samples for ASR typos, short multi-turn utterances, and clause-level multi-intent sentences.",
            "",
        ]
    )
    return "\n".join(lines)


if __name__ == "__main__":
    main()
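
The script above writes the nested `confusion` counters to the JSON result file, but the Markdown report does not list them. As a rough sketch, the most frequent wrong predictions could be extracted from that structure as shown below. The helper name `top_confusion_pairs` and the intent id `cabin_volume_up` / `cabin_volume_down` are made up for illustration; the counts are toy data, not real results.

```python
from __future__ import annotations

from collections import Counter


def top_confusion_pairs(
    confusion: dict[str, dict[str, int]], limit: int = 5
) -> list[tuple[str, str, int]]:
    """Return the most frequent (expected, predicted, count) pairs with expected != predicted."""
    pairs: Counter[tuple[str, str]] = Counter()
    for expected, predicted_counts in confusion.items():
        for predicted, count in predicted_counts.items():
            if predicted != expected:  # skip the diagonal (correct predictions)
                pairs[(expected, predicted)] += count
    return [(expected, predicted, count) for (expected, predicted), count in pairs.most_common(limit)]


# Toy input in the same shape as summary["confusion"].
toy_confusion = {
    "cabin_volume_up": {"cabin_volume_up": 18, "cabin_volume_down": 3},
    "cabin_set_ac": {"cabin_set_ac": 20, "__out_of_scope__": 1},
}
top_pairs = top_confusion_pairs(toy_confusion)
print(top_pairs)
```

A list like this would make the "opposite direction" confusions mentioned in the report conclusions directly visible.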


@@ -0,0 +1,123 @@
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from scripts.train_local_bert_multi_intent import (
    BATCH_SIZE,
    OUTPUT_DIR,
    TOP_K,
    THRESHOLD,
    MultiLabelIntentDataset,
    load_all_samples,
    split_samples,
    set_seed,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate local BERT multi-intent detector.")
    parser.add_argument("--model-path", default=str(OUTPUT_DIR), help="Path to trained multi-intent model.")
    parser.add_argument("--threshold", type=float, default=THRESHOLD, help="Probability threshold.")
    parser.add_argument("--top-k", type=int, default=TOP_K, help="Top-k for recall@k.")
    parser.add_argument(
        "--dataset",
        choices=("dev", "all"),
        default="dev",
        help="Evaluate on the held-out dev split or all combined samples.",
    )
    return parser.parse_args()


def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Compute micro precision/recall/F1, exact match, and recall@k for multi-label output."""
    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    recall_at_k_total = 0.0
    total = len(probabilities)
    for scores, target in zip(probabilities, targets):
        predicted = {index for index, score in enumerate(scores) if score >= threshold}
        expected = {index for index, value in enumerate(target) if value >= 0.5}
        if predicted == expected:
            exact_match += 1
        true_positive += len(predicted & expected)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)
        top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
        # Samples without any gold label add 0 here but still count in the denominator below.
        if expected:
            recall_at_k_total += len(set(top_indices) & expected) / len(expected)
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
    }


def main() -> None:
    args = parse_args()
    set_seed(42)
    samples = load_all_samples()
    _, dev_samples = split_samples(samples)
    eval_samples = samples if args.dataset == "all" else dev_samples
    model_path = Path(args.model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"model path not found: {model_path}")

    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()}
    if not label_to_id:
        raise RuntimeError("label2id is missing from model config")

    dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    model.eval()

    probabilities: list[list[float]] = []
    targets: list[list[float]] = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist())
            targets.extend(labels.detach().cpu().tolist())

    metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k)
    result = {
        "model_path": str(model_path),
        "dataset": args.dataset,
        "sample_size": len(eval_samples),
        "threshold": args.threshold,
        "top_k": args.top_k,
        "metrics": metrics,
    }
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
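
To make the metric definitions in `compute_metrics` concrete, here is a small worked example on two toy samples with three labels. It re-implements the same counting in compact form so that it runs without the project or a trained model; the helper name `toy_metrics` and the probabilities are invented for illustration.

```python
from __future__ import annotations


def toy_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Compact re-implementation of the micro P/R/F1 and recall@k counting."""
    tp = fp = fn = 0
    recall_at_k_total = 0.0
    for scores, target in zip(probabilities, targets):
        predicted = {i for i, score in enumerate(scores) if score >= threshold}
        expected = {i for i, value in enumerate(target) if value >= 0.5}
        tp += len(predicted & expected)
        fp += len(predicted - expected)
        fn += len(expected - predicted)
        top = set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k])
        if expected:
            recall_at_k_total += len(top & expected) / len(expected)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    total = len(probabilities)
    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(f1, 4),
        "recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
    }


# Sample 1: predicts labels {0, 1}, gold is {0}      -> 1 TP, 1 FP
# Sample 2: predicts label  {1},    gold is {1, 2}   -> 1 TP, 1 FN
toy = toy_metrics(
    probabilities=[[0.9, 0.6, 0.1], [0.2, 0.7, 0.4]],
    targets=[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
    threshold=0.5,
    top_k=2,
)
print(toy)
```

With 2 true positives, 1 false positive and 1 false negative, precision and recall are both 2/3. Recall@2 is 1.0 because every gold label sits in the top two scores, which shows that recall@k can look perfect even when the thresholded prediction is wrong.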


@@ -0,0 +1,247 @@
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.multi_intent_detector import BertMultiIntentDetector

TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_eval_independent.jsonl"
MODEL_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
REPORT_DIR = PROJECT_ROOT / "reports"
RESULT_PATH = REPORT_DIR / "bert_multi_intent_independent_result.json"
REPORT_PATH = REPORT_DIR / "bert_multi_intent_independent_report.md"
THRESHOLD = 0.45
TOP_K = 8
MAX_LABELS = 4


def load_cases(file_path: Path) -> list[dict[str, object]]:
    cases: list[dict[str, object]] = []
    for line in file_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        payload = json.loads(line)
        expected = sorted({str(item).strip() for item in payload.get("expected_intent_ids") or [] if str(item).strip()})
        if not expected:
            continue
        cases.append(
            {
                "text": str(payload["text"]),
                "expected_intent_ids": expected,
                "category": str(payload.get("category") or "unknown"),
            }
        )
    return cases


def compute_set_metrics(results: list[dict[str, object]]) -> dict[str, float]:
    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    multi_recall_hit = 0
    single_false_alarm = 0
    total = len(results)
    single_guard_total = 0
    for item in results:
        expected = set(item["expected_intent_ids"])
        predicted = set(item["predicted_intent_ids"])
        if expected == predicted:
            exact_match += 1
        true_positive += len(expected & predicted)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)
        if len(expected) >= 2 and expected.issubset(predicted):
            multi_recall_hit += 1
        if len(expected) == 1:
            single_guard_total += 1
            if len(predicted) > 1:
                single_false_alarm += 1
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    multi_total = sum(1 for item in results if len(item["expected_intent_ids"]) >= 2)
    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "multi_sentence_recall": round(multi_recall_hit / multi_total, 4) if multi_total else 0.0,
        "single_guard_false_alarm_rate": round(single_false_alarm / single_guard_total, 4) if single_guard_total else 0.0,
    }


def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
    grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
    for item in results:
        grouped[str(item["category"])].append(item)
    summary: list[dict[str, object]] = []
    for category, items in sorted(grouped.items()):
        summary.append(
            {
                "category": category,
                "sample_count": len(items),
                "metrics": compute_set_metrics(items),
            }
        )
    return summary


def collect_error_examples(results: list[dict[str, object]], limit: int = 15) -> list[dict[str, object]]:
    errors = [item for item in results if set(item["expected_intent_ids"]) != set(item["predicted_intent_ids"])]

    def sort_key(item: dict[str, object]) -> tuple[int, int]:
        expected = set(item["expected_intent_ids"])
        predicted = set(item["predicted_intent_ids"])
        miss = len(expected - predicted)
        extra = len(predicted - expected)
        return (miss + extra, miss)

    return sorted(errors, key=sort_key, reverse=True)[:limit]


def top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
    counter: Counter[tuple[str, str]] = Counter()
    for item in results:
        expected = set(item["expected_intent_ids"])
        predicted = set(item["predicted_intent_ids"])
        for miss in sorted(expected - predicted):
            for extra in sorted(predicted - expected):
                counter[(miss, extra)] += 1
    return [
        {"expected_missing": pair[0], "wrong_extra": pair[1], "count": count}
        for pair, count in counter.most_common(limit)
    ]


def main() -> None:
    parser = argparse.ArgumentParser(description="Standalone evaluation script for the local multi-label detector")
    parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="Path to the independent evaluation set")
    parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="Path to the multi-label model")
    parser.add_argument("--threshold", type=float, default=THRESHOLD, help="Detection threshold")
    parser.add_argument("--top-k", type=int, default=TOP_K, help="Number of top-k raw scores to output")
    parser.add_argument("--max-labels", type=int, default=MAX_LABELS, help="Maximum number of labels to return")
    parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="Output path for the structured results")
    parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Output path for the Markdown report")
    args = parser.parse_args()

    cases = load_cases(Path(args.test_path))
    intents = build_intent_registry().list()
    detector = BertMultiIntentDetector(
        model_path=args.model_path,
        threshold=args.threshold,
        top_k=args.top_k,
        max_labels=args.max_labels,
    )
    results: list[dict[str, object]] = []
    for case in cases:
        detection = detector.detect(str(case["text"]), intents)
        predicted = [candidate.intent_id for candidate in detection.candidates]
        raw_top = [
            {
                "intent_id": str(item.get("intent_id") or item.get("label") or ""),
                "score": round(float(item.get("score", 0.0)), 4),
            }
            for item in detection.raw_scores
        ]
        results.append(
            {
                "text": case["text"],
                "category": case["category"],
                "expected_intent_ids": case["expected_intent_ids"],
                "predicted_intent_ids": predicted,
                "detected": detection.detected,
                "backend_name": detection.backend_name,
                "reason": detection.reason,
                "raw_top_scores": raw_top,
            }
        )
    summary = {
        "model_path": args.model_path,
        "test_path": args.test_path,
        "threshold": args.threshold,
        "top_k": args.top_k,
        "max_labels": args.max_labels,
        "sample_count": len(results),
        "metrics": compute_set_metrics(results),
        "per_category": summarize_by_category(results),
        "top_confusions": top_confusions(results),
        "error_examples": collect_error_examples(results),
        "results": results,
    }
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
    print(json.dumps({"sample_count": len(results), "metrics": summary["metrics"]}, ensure_ascii=False))


def render_report(summary: dict[str, object]) -> str:
    """Render the evaluation summary as a Markdown report."""
    metrics = summary["metrics"]
    per_category = summary["per_category"]
    confusions = summary["top_confusions"]
    errors = summary["error_examples"]
    lines = [
        "# Local Multi-Label Detector Independent Evaluation Report",
        "",
        "## Overview",
        f"- Model directory: `{summary['model_path']}`",
        f"- Evaluation set: `{summary['test_path']}`",
        f"- Sample count: `{summary['sample_count']}`",
        f"- Threshold / top_k / max_labels: `{summary['threshold']} / {summary['top_k']} / {summary['max_labels']}`",
        f"- `micro_precision`: `{metrics['micro_precision']}`",
        f"- `micro_recall`: `{metrics['micro_recall']}`",
        f"- `micro_f1`: `{metrics['micro_f1']}`",
        f"- `exact_match`: `{metrics['exact_match']}`",
        f"- `multi_sentence_recall`: `{metrics['multi_sentence_recall']}`",
        f"- `single_guard_false_alarm_rate`: `{metrics['single_guard_false_alarm_rate']}`",
        "",
        "## Results by Category",
    ]
    for item in per_category:
        category_metrics = item["metrics"]
        lines.append(
            f"- `{item['category']}`: count={item['sample_count']} "
            f"micro_f1={category_metrics['micro_f1']} exact_match={category_metrics['exact_match']}"
        )
    lines.extend(["", "## Main Confusions"])
    if not confusions:
        lines.append("- No notable confusion pairs found.")
    else:
        for item in confusions:
            lines.append(
                f"- Missed `{item['expected_missing']}` while falsely reporting `{item['wrong_extra']}`: `{item['count']}` time(s)"
            )
    lines.extend(["", "## Error Examples"])
    if not errors:
        lines.append("- No error examples.")
    else:
        for item in errors:
            lines.append(
                f"- Text: `{item['text']}` | Category: `{item['category']}` "
                f"| Expected: `{item['expected_intent_ids']}` | Predicted: `{item['predicted_intent_ids']}`"
            )
    lines.extend(
        [
            "",
            "## Conclusions and Recommendations",
            "- First check whether multi-intent sentences show systematic recall misses, then check whether single-intent sentences are falsely reported as multi-intent.",
            "- If `single_guard_false_alarm_rate` is high, tighten the detector threshold or add single-intent negative samples before moving on to NER.",
            "- If `multi_sentence_recall` is unstable, keep adding conditional sentences, weakly connected sentences, and colloquial multi-action utterances.",
            "",
        ]
    )
    return "\n".join(lines)


if __name__ == "__main__":
    main()
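
The two guard metrics in `compute_set_metrics` are easy to misread, so a toy calculation may help. The sketch below applies the same two rules to four hand-written results; the helper name `guard_metrics` and the predictions are invented, while the intent ids are ones that appear elsewhere in this commit.

```python
from __future__ import annotations


def guard_metrics(results: list[dict[str, list[str]]]) -> dict[str, float]:
    """Recompute multi_sentence_recall and single_guard_false_alarm_rate on toy data."""
    multi_total = multi_hit = 0
    single_total = single_false_alarm = 0
    for item in results:
        expected = set(item["expected_intent_ids"])
        predicted = set(item["predicted_intent_ids"])
        if len(expected) >= 2:
            multi_total += 1
            # A hit only requires that every expected intent was found; extras are tolerated.
            if expected.issubset(predicted):
                multi_hit += 1
        if len(expected) == 1:
            single_total += 1
            # A false alarm is a single-intent sentence predicted with more than one intent.
            if len(predicted) > 1:
                single_false_alarm += 1
    return {
        "multi_sentence_recall": multi_hit / multi_total if multi_total else 0.0,
        "single_guard_false_alarm_rate": single_false_alarm / single_total if single_total else 0.0,
    }


toy_results = [
    # Multi-intent, all expected found plus one extra: still a recall hit.
    {"expected_intent_ids": ["cabin_nav_to", "cabin_set_ac"],
     "predicted_intent_ids": ["cabin_nav_to", "cabin_play_music", "cabin_set_ac"]},
    # Multi-intent, one intent missed: no hit.
    {"expected_intent_ids": ["cabin_play_music", "cabin_set_ac"],
     "predicted_intent_ids": ["cabin_set_ac"]},
    # Single-intent predicted as two intents: false alarm.
    {"expected_intent_ids": ["cabin_nav_to"],
     "predicted_intent_ids": ["cabin_nav_to", "cabin_play_music"]},
    # Single-intent predicted correctly.
    {"expected_intent_ids": ["cabin_play_music"],
     "predicted_intent_ids": ["cabin_play_music"]},
]
toy_guard = guard_metrics(toy_results)
print(toy_guard)
```

Because a multi-intent hit tolerates extra labels, `multi_sentence_recall` is best read together with `exact_match` and `micro_precision`.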


@@ -0,0 +1,97 @@
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.config import settings
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
from app.services.router import build_matcher_pipeline

DEFAULT_MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR / "label_map.json"


def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
    return BertIntentClassifier(
        model_path=str(DEFAULT_MODEL_DIR),
        threshold=threshold,
        label_map_path=str(DEFAULT_LABEL_MAP),
        fallback=None,
        top_k=top_k,
    )


def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]:
    classifier = build_classifier(threshold=threshold, top_k=top_k)
    registry = build_intent_registry()
    intents = registry.list()
    result = classifier.predict(text, intents)
    matcher = build_matcher_pipeline(
        registry,
        ["classifier"],
        classifier=classifier,
        route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
        clarify_margin_threshold=settings.local_clarify_margin_threshold,
    )
    route_result = matcher.match(text)
    fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
    return {
        "text": text,
        "predicted_intent": result.intent.intent_id if result.intent is not None else None,
        "score": round(result.score, 4),
        "model_name": result.model_name,
        "backend": result.backend_name,
        "raw_label": result.raw_label,
        "fallback_reason": result.fallback_reason,
        "error_message": result.error_message,
        "decision": route_result.debug.decision,
        "decision_reason": route_result.debug.decision_reason,
        "confidence_grade": route_result.debug.confidence_grade,
        "unknown_detected": route_result.debug.unknown_detected,
        "fusion_top_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
        "top_candidates": [
            {"intent_id": intent.intent_id, "score": round(score, 4)}
            for intent, score in (result.candidates or [])
        ],
    }


def interactive_loop(threshold: float, top_k: int) -> None:
    print("Local BERT intent test started. Type a sentence to see the prediction; type exit to quit.")
    while True:
        try:
            text = input("\nEnter a query> ").strip()
        except EOFError:
            print()
            break
        if not text:
            continue
        if text.lower() in {"exit", "quit", "q"}:
            break
        result = predict_once(text, threshold=threshold, top_k=top_k)
        print(json.dumps(result, ensure_ascii=False, indent=2))


def main() -> None:
    parser = argparse.ArgumentParser(description="Test script for local BERT intent recognition")
    parser.add_argument("--text", type=str, default="", help="Text for a single test run")
    parser.add_argument("--threshold", type=float, default=0.0, help="BERT confidence threshold")
    parser.add_argument("--top-k", type=int, default=3, help="Number of candidates to return")
    args = parser.parse_args()
    if args.text.strip():
        print(json.dumps(predict_once(args.text.strip(), threshold=args.threshold, top_k=args.top_k), ensure_ascii=False, indent=2))
        return
    interactive_loop(threshold=args.threshold, top_k=args.top_k)


if __name__ == "__main__":
    main()
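
The script passes `route_to_cloud_threshold` and `clarify_margin_threshold` into `build_matcher_pipeline`, but the fusion logic that uses them is not part of this file. The sketch below shows one plausible reading of these two parameters, consistent with the "execute locally / clarify / go to cloud" grading described in the design notes. It is an assumption, not the project's implementation: the function `decide_route` and the three decision strings are hypothetical.

```python
from __future__ import annotations


def decide_route(
    candidates: list[tuple[str, float]],
    route_to_cloud_threshold: float,
    clarify_margin_threshold: float,
) -> str:
    """Hypothetical decision rule over (intent_id, score) candidates."""
    if not candidates:
        return "route_to_cloud"
    ranked = sorted(candidates, key=lambda item: item[1], reverse=True)
    top_score = ranked[0][1]
    # Assumption: a weak top score means the local path should not decide on its own.
    if top_score < route_to_cloud_threshold:
        return "route_to_cloud"
    # Assumption: two close candidates mean the request is ambiguous and needs clarification.
    if len(ranked) > 1 and top_score - ranked[1][1] < clarify_margin_threshold:
        return "clarify"
    return "execute_local"


confident = decide_route([("cabin_set_ac", 0.92), ("cabin_nav_to", 0.30)], 0.5, 0.1)
ambiguous = decide_route([("cabin_set_ac", 0.62), ("cabin_nav_to", 0.58)], 0.5, 0.1)
weak = decide_route([("cabin_play_music", 0.31)], 0.5, 0.1)
print(confident, ambiguous, weak)
```

Comparing the `decision` and `confidence_grade` fields printed by `predict_once` against a rule like this is one way to check whether the configured thresholds behave as intended.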


@@ -0,0 +1,500 @@
from __future__ import annotations

import json
import random
import re
import sys
from dataclasses import dataclass
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import yaml

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.services.joint_nlu import JointBertForNLU

TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl"
EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl"
MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl"
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
MAX_LENGTH = 64
BATCH_SIZE = 8
EPOCHS = 8
LEARNING_RATE = 2e-5
SEED = 42
IGNORE_INDEX = -100
GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌")
DEFAULT_INTENT_THRESHOLD = 0.3
MULTI_INTENT_REPEAT = 6
THRESHOLD_CANDIDATES = [0.1, 0.12, 0.15, 0.18, 0.2, 0.22, 0.25, 0.28, 0.3, 0.33, 0.35, 0.38, 0.4, 0.45]


@dataclass
class JointSample:
    text: str
    intent_ids: list[str]
    slots: list[dict[str, object]]


class JointDataset(Dataset):
    def __init__(
        self,
        samples: list[JointSample],
        tokenizer,
        intent_to_index: dict[str, int],
        slot_to_index: dict[str, int],
    ) -> None:
        self._samples = samples
        self._tokenizer = tokenizer
        self._intent_to_index = intent_to_index
        self._slot_to_index = slot_to_index

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        sample = self._samples[index]
        encoded = self._tokenizer(
            sample.text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        offset_mapping = encoded.pop("offset_mapping")[0].tolist()
        slot_labels = [IGNORE_INDEX] * len(offset_mapping)

        # Build character-level BIO labels from the annotated spans.
        char_labels = ["O"] * len(sample.text)
        for slot in sample.slots:
            start = int(slot["start"])
            end = int(slot["end"])
            slot_name = str(slot["slot_name"])
            if start < 0 or end > len(sample.text) or start >= end:
                continue
            char_labels[start] = f"B-{slot_name}"
            for pos in range(start + 1, end):
                char_labels[pos] = f"I-{slot_name}"

        # Project character labels onto tokens; special and padding tokens
        # have empty offsets and keep IGNORE_INDEX.
        for token_index, (start, end) in enumerate(offset_mapping):
            if end <= start:
                continue
            label = char_labels[start]
            slot_labels[token_index] = self._slot_to_index.get(label, self._slot_to_index["O"])

        # Multi-hot intent vector for multi-label training.
        intent_vector = torch.zeros(len(self._intent_to_index), dtype=torch.float32)
        for intent_id in sample.intent_ids:
            if intent_id in self._intent_to_index:
                intent_vector[self._intent_to_index[intent_id]] = 1.0
        return {
            "input_ids": encoded["input_ids"][0],
            "attention_mask": encoded["attention_mask"][0],
            "intent_labels": intent_vector,
            "slot_labels": torch.tensor(slot_labels, dtype=torch.long),
        }


def set_seed() -> None:
    random.seed(SEED)
    torch.manual_seed(SEED)


def load_jsonl(path: Path) -> list[dict[str, object]]:
    rows: list[dict[str, object]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return rows


def find_order_id_span(text: str) -> tuple[str, int, int] | None:
    match = re.search(r"[A-Za-z]\d{5,}", text)
    if not match:
        return None
    return match.group(0), match.start(), match.end()


def find_temperature_span(text: str) -> tuple[str, int, int] | None:
    match = re.search(r"(\d{2}\s*度)", text)
    if not match:
        return None
    return match.group(1), match.start(), match.end()
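

def find_destination_span(text: str) -> tuple[str, int, int] | None:
    for pattern in (
        r"导航去(?P<destination>.+)",
        r"导航到(?P<destination>.+)",
        r"带我去(?P<destination>.+)",
        r"送我去(?P<destination>.+)",
        r"去(?P<destination>.+)",
    ):
        match = re.search(pattern, text)
        if not match:
            continue
        # Cut the destination at the first connective or clause separator.
        # An empty alternative must not appear in this pattern: it would match
        # at position 0 and leave an empty destination.
        destination = re.split(r"(?:然后|并且|同时|再|,|;)", match.group("destination"), maxsplit=1)[0].strip(" ,。")
        if not destination:
            continue
        # Search from the start of the captured group so an earlier occurrence
        # of the same string is not picked up by mistake.
        start = text.find(destination, match.start("destination"))
        if start >= 0:
            return destination, start, start + len(destination)
    return None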


def find_music_span(text: str) -> tuple[str, str, int, int] | None:
    # A known genre keyword anywhere in the text wins over song extraction.
    for genre in GENRE_KEYWORDS:
        start = text.find(genre)
        if start >= 0:
            return "genre", genre, start, start + len(genre)
    # Triggers must be non-empty: str.split("") raises ValueError.
    for trigger in ("播放", "来点", "放点", "来首", "来一首", "放一首"):
        if trigger not in text:
            continue
        target = text.split(trigger, maxsplit=1)[-1]
        target = re.split(r"(?:然后|并且|同时|再|,|;)", target, maxsplit=1)[0].strip(" 的一首首个歌曲音乐吧呀啊,。")
        if not target or target in {"", "音乐"}:
            continue
        for genre in GENRE_KEYWORDS:
            if genre in target:
                start = text.find(genre)
                return "genre", genre, start, start + len(genre)
        start = text.find(target)
        if start >= 0:
            return "song", target, start, start + len(target)
    return None


def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]:
    slots: list[dict[str, object]] = []
    if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}:
        matched = find_order_id_span(text)
        if matched is not None:
            value, start, end = matched
            slots.append({"slot_name": "order_id", "value": value, "start": start, "end": end})
    elif intent_id == "cabin_set_ac":
        matched = find_temperature_span(text)
        if matched is not None:
            value, start, end = matched
            slots.append({"slot_name": "temperature", "value": value, "start": start, "end": end})
    elif intent_id == "cabin_nav_to":
        matched = find_destination_span(text)
        if matched is not None:
            value, start, end = matched
            slots.append({"slot_name": "destination", "value": value, "start": start, "end": end})
    elif intent_id == "cabin_play_music":
        matched = find_music_span(text)
        if matched is not None:
            slot_name, value, start, end = matched
            slots.append({"slot_name": slot_name, "value": value, "start": start, "end": end})
    return slots


def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]:
    merged: list[dict[str, object]] = []
    seen: set[tuple[str, int, int]] = set()
    for intent_id in intent_ids:
        for slot in annotate_slots(text, intent_id):
            key = (str(slot["slot_name"]), int(slot["start"]), int(slot["end"]))
            if key in seen:
                continue
            seen.add(key)
            merged.append(slot)
    merged.sort(key=lambda item: (int(item["start"]), int(item["end"])))
    return merged


def build_train_samples() -> list[JointSample]:
    samples: list[JointSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()

    # 1) Examples and keywords declared in the domain config, slots annotated by rules.
    domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {}
    for intent in domain_data.get("intents", []):
        intent_id = str(intent.get("intent_id", "")).strip()
        if not intent_id:
            continue
        for text in list(intent.get("examples", [])) + list(intent.get("keywords", [])):
            text = str(text).strip()
            if not text:
                continue
            key = (text, (intent_id,))
            if key in seen:
                continue
            seen.add(key)
            samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))

    # 2) Single-intent classifier training data, slots annotated by rules.
    for row in load_jsonl(TRAIN_PATH):
        text = str(row["text"])
        intent_id = str(row["intent_id"])
        key = (text, (intent_id,))
        if key in seen:
            continue
        seen.add(key)
        samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))

    # 3) Seed data that already carries slot annotations.
    for row in load_jsonl(SEED_PATH):
        text = str(row["text"])
        intent_id = str(row["intent_id"])
        key = (text, (intent_id,))
        if key in seen:
            continue
        seen.add(key)
        samples.append(JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots", []))))

    # 4) Multi-label data; sentences with two or more intents are oversampled.
    for row in load_jsonl(MULTI_TRAIN_PATH):
        text = str(row["text"]).strip()
        intent_ids = sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()})
        if not text or not intent_ids:
            continue
        key = (text, tuple(intent_ids))
        if key in seen:
            continue
        seen.add(key)
        slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
        samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
        if len(intent_ids) >= 2:
            for _ in range(MULTI_INTENT_REPEAT - 1):
                samples.append(JointSample(text=text, intent_ids=intent_ids, slots=list(slots)))

    random.shuffle(samples)
    return samples
def build_eval_samples() -> list[JointSample]:
    rows = load_jsonl(EVAL_PATH)
    samples = [
        JointSample(
            text=str(row["text"]),
            intent_ids=[str(row["intent_id"])],
            slots=list(row.get("slots", [])),
        )
        for row in rows
    ]
    if MULTI_EVAL_PATH.exists():
        for row in load_jsonl(MULTI_EVAL_PATH):
            text = str(row["text"])
            # Normalize once so the gold labels and the fallback slot annotation use the same IDs.
            intent_ids = sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()})
            slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
            samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
    return samples
def build_slot_labels(samples: list[JointSample]) -> list[str]:
    slot_names = sorted({str(slot["slot_name"]) for sample in samples for slot in sample.slots})
    labels = ["O"]
    for name in slot_names:
        labels.append(f"B-{name}")
        labels.append(f"I-{name}")
    return labels
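# --- Illustrative sketch, not part of the original commit ---
# build_slot_labels() above expands every slot name into a B-/I- pair after a
# leading "O". The helper below shows one common way such labels are assigned
# per character. The "start"/"end" span keys are assumptions made for this
# example; the slot dictionaries in this script are only known to carry
# "slot_name". sketch_bio_tags is a hypothetical name.
def sketch_bio_tags(text: str, spans: list[dict]) -> list[str]:
    """Return one BIO tag per character; later spans overwrite earlier ones."""
    tags = ["O"] * len(text)
    for span in spans:
        start, end = int(span["start"]), int(span["end"])  # end is exclusive
        if not 0 <= start < end <= len(text):
            continue  # ignore spans that fall outside the text
        tags[start] = f"B-{span['slot_name']}"
        for position in range(start + 1, end):
            tags[position] = f"I-{span['slot_name']}"
    return tags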
def compute_metrics(
model: JointBertForNLU,
dataloader: DataLoader,
device: torch.device,
intent_labels: list[str],
slot_labels: list[str],
threshold: float,
) -> dict[str, float]:
model.eval()
intent_tp = 0
intent_fp = 0
intent_fn = 0
single_intent_correct = 0
single_intent_total = 0
intent_exact_match = 0
correct_slot_tokens = 0
total_slot_tokens = 0
exact_slot_samples = 0
total_samples = 0
with torch.no_grad():
for batch in dataloader:
batch = {key: value.to(device) for key, value in batch.items()}
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
predicted_probs = torch.sigmoid(intent_logits)
predicted_multi = predicted_probs >= threshold
gold_multi = batch["intent_labels"] > 0.5
intent_tp += int((predicted_multi & gold_multi).sum().item())
intent_fp += int((predicted_multi & ~gold_multi).sum().item())
intent_fn += int((~predicted_multi & gold_multi).sum().item())
intent_exact_match += int((predicted_multi == gold_multi).all(dim=1).sum().item())
top_predicted = torch.argmax(predicted_probs, dim=-1)
gold_counts = gold_multi.sum(dim=-1)
single_mask = gold_counts == 1
if int(single_mask.sum().item()) > 0:
gold_top = torch.argmax(gold_multi.float(), dim=-1)
single_intent_correct += int((top_predicted[single_mask] == gold_top[single_mask]).sum().item())
single_intent_total += int(single_mask.sum().item())
predicted_slots = torch.argmax(slot_logits, dim=-1)
mask = batch["slot_labels"] != IGNORE_INDEX
correct_slot_tokens += int(((predicted_slots == batch["slot_labels"]) & mask).sum().item())
total_slot_tokens += int(mask.sum().item())
for index in range(batch["slot_labels"].size(0)):
gold = batch["slot_labels"][index][mask[index]]
pred = predicted_slots[index][mask[index]]
exact_slot_samples += int(torch.equal(gold, pred))
total_samples += 1
precision = intent_tp / max(intent_tp + intent_fp, 1)
recall = intent_tp / max(intent_tp + intent_fn, 1)
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
return {
"intent_threshold": round(threshold, 4),
"intent_micro_precision": round(precision, 4),
"intent_micro_recall": round(recall, 4),
"intent_micro_f1": round(f1, 4),
"intent_exact_match": round(intent_exact_match / max(total_samples, 1), 4),
"single_intent_top1_accuracy": round(single_intent_correct / max(single_intent_total, 1), 4),
"slot_token_accuracy": round(correct_slot_tokens / max(total_slot_tokens, 1), 4),
"slot_exact_match": round(exact_slot_samples / max(total_samples, 1), 4),
"intent_label_count": float(len(intent_labels)),
"slot_label_count": float(len(slot_labels)),
}
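# --- Illustrative sketch, not part of the original commit ---
# compute_metrics() above accumulates TP/FP/FN over every (sample, label) cell
# and derives micro precision, recall, and F1 from the totals. The same
# arithmetic is shown here on plain Python sets so it can be checked by hand.
# sketch_micro_prf is a hypothetical helper and does not touch the model.
def sketch_micro_prf(predicted: list[set[str]], gold: list[set[str]]) -> dict[str, float]:
    tp = sum(len(p & g) for p, g in zip(predicted, gold))
    fp = sum(len(p - g) for p, g in zip(predicted, gold))
    fn = sum(len(g - p) for p, g in zip(predicted, gold))
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    # Exact match requires the full predicted label set to equal the gold set.
    exact = sum(p == g for p, g in zip(predicted, gold)) / max(len(gold), 1)
    return {"precision": precision, "recall": recall, "f1": f1, "exact_match": exact}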
def search_best_threshold(
    model: JointBertForNLU,
    dataloader: DataLoader,
    device: torch.device,
    intent_labels: list[str],
    slot_labels: list[str],
) -> dict[str, float]:
    # Rank thresholds by micro-F1, then intent exact match, then slot exact match.
    # The comparison is strict, so the earliest candidate wins ties.
    best_metrics: dict[str, float] | None = None
    for threshold in THRESHOLD_CANDIDATES:
        metrics = compute_metrics(
            model,
            dataloader,
            device,
            intent_labels,
            slot_labels,
            threshold=threshold,
        )
        if best_metrics is None:
            best_metrics = metrics
            continue
        current_score = (metrics["intent_micro_f1"], metrics["intent_exact_match"], metrics["slot_exact_match"])
        best_score = (
            best_metrics["intent_micro_f1"],
            best_metrics["intent_exact_match"],
            best_metrics["slot_exact_match"],
        )
        if current_score > best_score:
            best_metrics = metrics
    assert best_metrics is not None
    return best_metrics
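# --- Illustrative sketch, not part of the original commit ---
# search_best_threshold() above relies on Python's lexicographic tuple
# comparison: micro-F1 decides first, and the two exact-match scores only break
# ties. The hypothetical helper below isolates that selection rule on plain
# metric dictionaries so the tie-breaking can be checked without a model.
def sketch_pick_best(candidates: list[dict[str, float]]) -> dict[str, float]:
    keys = ("intent_micro_f1", "intent_exact_match", "slot_exact_match")
    best = candidates[0]
    for metrics in candidates[1:]:
        # Strict ">" keeps the earlier candidate when all three scores are equal.
        if tuple(metrics[key] for key in keys) > tuple(best[key] for key in keys):
            best = metrics
    return best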
def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor:
    # Per-label weight for BCE-with-logits: negatives / positives, clamped to [1.0, 12.0]
    # so rare intents are up-weighted without letting a single label dominate the loss.
    positive_counts = {label: 0 for label in intent_labels}
    for sample in samples:
        sample_intents = set(sample.intent_ids)
        for label in intent_labels:
            if label in sample_intents:
                positive_counts[label] += 1
    total = max(len(samples), 1)
    weights: list[float] = []
    for label in intent_labels:
        positives = max(positive_counts[label], 1)
        negatives = max(total - positives, 1)
        weight = negatives / positives
        weights.append(min(max(weight, 1.0), 12.0))
    return torch.tensor(weights, dtype=torch.float32)
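# --- Illustrative sketch, not part of the original commit ---
# build_pos_weight() above turns label frequencies into a clamped
# negatives/positives ratio. The hypothetical helper below repeats the same
# arithmetic on plain sets (no torch) so the clamping can be inspected directly.
# The default bounds mirror the 1.0 and 12.0 constants used above.
def sketch_pos_weights(
    label_sets: list[set[str]],
    labels: list[str],
    low: float = 1.0,
    high: float = 12.0,
) -> dict[str, float]:
    total = max(len(label_sets), 1)
    weights: dict[str, float] = {}
    for label in labels:
        positives = max(sum(label in item for item in label_sets), 1)
        negatives = max(total - positives, 1)
        weights[label] = min(max(negatives / positives, low), high)
    return weights
# With 1 sample of "a" and 19 of "b", the raw ratios are 19.0 and about 0.05;
# both end up at the clamp bounds.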
def main() -> None:
set_seed()
train_samples = build_train_samples()
eval_samples = build_eval_samples()
intent_labels = sorted({intent_id for sample in train_samples + eval_samples for intent_id in sample.intent_ids})
slot_labels = build_slot_labels(train_samples + eval_samples)
intent_to_index = {label: index for index, label in enumerate(intent_labels)}
slot_to_index = {label: index for index, label in enumerate(slot_labels)}
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL)
train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index)
eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = JointBertForNLU(
base_model_name=DEFAULT_BASE_MODEL,
num_intents=len(intent_labels),
num_slot_labels=len(slot_labels),
)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
pos_weight = build_pos_weight(train_samples, intent_labels).to(device)
best_metrics: dict[str, float] | None = None
best_state: dict[str, torch.Tensor] | None = None
for epoch in range(EPOCHS):
model.train()
epoch_loss = 0.0
for batch in train_loader:
batch = {key: value.to(device) for key, value in batch.items()}
optimizer.zero_grad()
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
intent_loss = torch.nn.functional.binary_cross_entropy_with_logits(
intent_logits,
batch["intent_labels"],
pos_weight=pos_weight,
)
slot_loss = torch.nn.functional.cross_entropy(
slot_logits.view(-1, slot_logits.size(-1)),
batch["slot_labels"].view(-1),
ignore_index=IGNORE_INDEX,
)
loss = intent_loss + slot_loss
loss.backward()
optimizer.step()
epoch_loss += float(loss.item())
metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels)
metrics["train_loss"] = round(epoch_loss / max(len(train_loader), 1), 4)
print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False))
if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]:
best_metrics = metrics
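# Clone each tensor: on a CPU device .cpu() returns the live parameter, which later epochs would overwrite.
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}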
if best_state is None or best_metrics is None:
raise RuntimeError("joint nlu training did not produce a best checkpoint")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
tokenizer.save_pretrained(OUTPUT_DIR)
torch.save(best_state, OUTPUT_DIR / "model_state.pt")
config = {
"base_model_name": DEFAULT_BASE_MODEL,
"intent_task": "multi_label",
"intent_labels": intent_labels,
"slot_labels": slot_labels,
"max_length": MAX_LENGTH,
"intent_threshold": float(best_metrics["intent_threshold"]),
"multi_intent_threshold": float(best_metrics["intent_threshold"]),
"max_multi_intents": 4,
}
(OUTPUT_DIR / "joint_nlu_config.json").write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8")
(OUTPUT_DIR / "train_summary.json").write_text(
json.dumps(
{
"train_size": len(train_samples),
"eval_size": len(eval_samples),
"metrics": best_metrics,
},
ensure_ascii=False,
indent=2,
),
encoding="utf-8",
)
if __name__ == "__main__":
main()


@@ -0,0 +1,684 @@
from __future__ import annotations
import json
import os
import random
from dataclasses import dataclass
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import yaml
PROJECT_ROOT = Path(__file__).resolve().parents[1]
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_intent"
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
MAX_LENGTH = 48
BATCH_SIZE = 8
EPOCHS = 16
LEARNING_RATE = 2e-5
SEED = 42
ORDER_IDS = ["A123456", "A700001", "A800002", "A900005", "A202501", "A808001"]
DESTINATIONS = ["公司停车场", "浦东机场", "徐家汇", "虹桥机场", "最近的充电站", "南京东路"]
TEMPERATURES = [18, 20, 21, 22, 23, 24, 26]
SONGS = ["夜曲", "稻香", "青花瓷", "晴天", "告白气球"]
GENRES = ["轻音乐", "摇滚", "古典音乐", "民谣", "爵士"]
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
TEMPLATES: dict[str, list[str]] = {
"cs_query_order": [
"查一下订单{order_id}现在什么状态",
"我的订单{order_id}到哪一步了",
"帮我看看{order_id}这个订单",
"确认下{order_id}订单状态",
"订单{order_id}现在处理到哪里",
"看下{order_id}这单进度",
"订单号{order_id}目前怎么样",
"帮忙确认订单{order_id}",
"订单{order_id}有结果了吗",
"帮我追一下订单{order_id}",
"订单{order_id}现在受理了吗",
"看看{order_id}这单现在啥情况",
"帮我查查{order_id}订单进展",
"{order_id}这个订单处理好了没",
"{order_id}这笔订单现在进展到哪了",
],
"cs_query_logistics": [
"快递{order_id}到哪儿了",
"帮我查{order_id}物流进度",
"看看{order_id}配送状态",
"订单{order_id}物流更新了吗",
"查询{order_id}的快递信息",
"我的{order_id}现在派送到哪了",
"查一下{order_id}这单物流",
"配送单{order_id}走到哪里了",
"帮我看下{order_id}快递到没到",
"物流单号{order_id}现在在哪",
"订单{order_id}物流到哪一步了",
"{order_id}这单现在派件了吗",
"帮我追踪{order_id}运输轨迹",
"{order_id}快件现在运到哪里了",
"我想看{order_id}的配送更新",
],
"cs_cancel_order": [
"帮我取消{order_id}这个订单",
"{order_id}别要了给我撤销",
"把订单{order_id}取消掉",
"我不要{order_id}",
"撤销一下{order_id}订单",
"订单{order_id}不要发了",
"帮我把{order_id}退掉并取消",
"{order_id}这一单停掉",
"{order_id}这单直接取消",
"订单号{order_id}撤回一下",
"订单{order_id}我不想要了",
"{order_id}这笔订单先别发了",
"{order_id}这单给我撤单",
"订单{order_id}停掉吧",
"{order_id}这个快给我取消了",
],
"cs_transfer_human": [
"我要找人工客服处理",
"现在转人工",
"麻烦给我接人工服务",
"帮我呼叫真人客服",
"别机器人了我要人工",
"转真人客服",
"我要人工坐席",
"帮我接人工处理",
"叫人工客服来",
"直接给我转人工",
"这个问题给我人工跟进",
"安排真人客服接手",
"机器人处理不了,转人工",
"帮我叫个客服专员",
"我要人工来处理这事",
],
"cabin_nav_to": [
"导航到{destination}",
"带我去{destination}",
"我要去{destination}",
"{destination}",
"开导航去{destination}",
"帮我导航到{destination}",
"送我去{destination}",
"现在去{destination}",
"带路到{destination}",
"去一下{destination}",
"规划路线去{destination}",
"直接开去{destination}",
"给我导到{destination}",
"{destination}怎么走,导航一下",
"出发去{destination}",
],
"cabin_set_ac": [
"把空调设到{temperature}",
"车里温度调成{temperature}",
"冷气开到{temperature}",
"空调给我调到{temperature}",
"温度改成{temperature}",
"车内设成{temperature}",
"把温度打到{temperature}",
"空调调为{temperature}",
"帮我把车里调成{temperature}",
"冷风调到{temperature}",
"把车内温度设为{temperature}",
"空调温度改到{temperature}",
"冷气帮我调到{temperature}",
"舱内调成{temperature}",
"给我把温度定在{temperature}",
"把车里弄凉快点",
"车里太热了,降一点",
"把里面调凉快一点",
"有点热,降温",
"空调再冷一点",
"车内温度低一点",
"把里面弄暖和点",
"车里太冷了,升一点温度",
],
"cabin_ac_on": [
"把空调打开",
"开一下冷气",
"把冷风开起来",
"车里热,空调开开",
"打开制冷",
"空调启动一下",
],
"cabin_window_open": [
"把车窗打开",
"开下窗",
"窗户开一点",
"帮我透透气",
"车里太闷了,开下窗",
"顺便开下车窗",
"把窗户降一点",
"把玻璃打开一点",
],
"cabin_window_close": [
"把车窗关上",
"窗户关一下",
"把窗升起来",
"外面太吵了,把窗关了",
"把窗户关严",
],
"cabin_fan_down": [
"风别这么大",
"风小一点",
"别吹这么猛",
"把风量调小一点",
"出风弱一点",
],
"cabin_fan_up": [
"风再大一点",
"把风量开大点",
"出风强一点",
"风不够,调大些",
],
"cabin_defog_front_on": [
"前挡起雾了,除一下",
"把前挡风玻璃雾气清掉",
"前窗看不清了,开除雾",
],
"cabin_defog_rear_on": [
"后挡有雾,开下除雾",
"后玻璃起雾了,清一下",
"后窗看不清了,除雾",
],
"cabin_play_music": [
"播放一首{genre}",
"来点{genre}",
"我想听{genre}",
"给我播点{genre}",
"放一首{song}",
"来一首{song}",
"播放{song}",
"放点音乐,来个{genre}",
"我想听首{song}",
"给我来点歌,放{song}",
"随机放点{genre}",
"帮我播首{song}",
"来点适合开车听的{genre}",
"打开音乐,放{song}",
"给我放一些{genre}",
"放点歌",
"来首歌",
"整点音乐",
"车里放点歌",
"来点能听的",
],
SOCIAL_LABEL: [
"你好",
"",
"哈喽",
"早上好",
"晚上好",
"你叫什么名字",
"你是谁",
"你能做什么",
"今天天气不错",
"陪我聊聊天",
],
OUT_OF_SCOPE_LABEL: [
"帮我点个外卖",
"订一张去北京的机票",
"帮我买杯咖啡",
"给我订一家酒店",
"人类诞生的意义是什么",
"帮我写一份年终总结",
"推荐一部电影",
"讲个笑话",
"帮我做一道数学题",
"去美团叫个外卖",
],
}
INTENT_REPLACEMENTS: dict[str, list[tuple[str, str]]] = {
"cs_query_order": [
("订单", "这单"),
("查一下", "看一下"),
("帮我", "麻烦帮我"),
("现在什么状态", "现在啥状态"),
("处理到哪里", "进展到哪里"),
],
"cs_query_logistics": [
("物流", "快递"),
("快递", "配送"),
("配送", "派送"),
("帮我", "麻烦帮我"),
("现在在哪", "现在到哪了"),
],
"cs_cancel_order": [
("取消", "撤销"),
("撤销", "撤单"),
("订单", "这单"),
("帮我", "麻烦帮我"),
("不要发了", "别发了"),
],
"cs_transfer_human": [
("人工客服", "真人客服"),
("人工", "人工坐席"),
("帮我", "麻烦帮我"),
],
"cabin_nav_to": [
("导航", "带路"),
("带我", "送我"),
("", "前往"),
],
"cabin_set_ac": [
("空调", "车里温度"),
("调到", "设到"),
("温度", "车内温度"),
("凉快点", "冷一点"),
("暖和点", "热一点"),
],
"cabin_ac_on": [
("空调", "冷气"),
("打开", ""),
("冷风", "制冷"),
],
"cabin_window_open": [
("车窗", "窗户"),
("打开", ""),
("透透气", "通通风"),
],
"cabin_window_close": [
("关上", "关掉"),
("车窗", "窗户"),
("关严", "关好"),
],
"cabin_fan_down": [
("风量", ""),
("调小", "调低"),
("别吹这么猛", "风小一点"),
],
"cabin_fan_up": [
("风量", ""),
("调大", "调高"),
],
"cabin_defog_front_on": [
("前挡", "前窗"),
("除雾", "除一下"),
],
"cabin_defog_rear_on": [
("后挡", "后窗"),
("除雾", "清一下雾"),
],
"cabin_play_music": [
("播放", ""),
("来点", "播点"),
("我想听", "给我来点"),
("放点歌", "来首歌"),
],
SOCIAL_LABEL: [
("你好", "您好"),
("哈喽", "hello"),
("你叫什么名字", "怎么称呼你"),
],
OUT_OF_SCOPE_LABEL: [
("点个外卖", "叫个外卖"),
("订一家酒店", "订个酒店"),
("讲个笑话", "说个笑话"),
],
}
@dataclass
class Sample:
text: str
intent_id: str
HARD_NEGATIVE_RAW_SAMPLES: list[tuple[str, str]] = [
("订单A700101物流到哪了", "cs_query_logistics"),
("帮我看下订单A700102配送到哪里了", "cs_query_logistics"),
("订单A700103现在派件了吗", "cs_query_logistics"),
("A700104这单物流有没有更新", "cs_query_logistics"),
("查一下订单A700105运输轨迹", "cs_query_logistics"),
("订单A700106不要了帮我撤单", "cs_cancel_order"),
("A700107这单别发了直接取消", "cs_cancel_order"),
("把订单A700108停掉吧", "cs_cancel_order"),
("A700109这个订单我不想要了", "cs_cancel_order"),
("订单A700110给我撤回", "cs_cancel_order"),
("订单A700111现在受理了吗", "cs_query_order"),
("帮我看看A700112这单处理得怎么样了", "cs_query_order"),
("A700113订单目前进展如何", "cs_query_order"),
("查下订单A700114现在什么情况", "cs_query_order"),
("帮我确认订单A700115是否已经处理", "cs_query_order"),
("你好呀", SOCIAL_LABEL),
("嗨,在吗", SOCIAL_LABEL),
("今天天气真不错", SOCIAL_LABEL),
("你叫什么名字呀", SOCIAL_LABEL),
("你是做什么的", SOCIAL_LABEL),
("陪我随便聊聊", SOCIAL_LABEL),
("帮我透透气", "cabin_window_open"),
("车里太闷了,开下窗", "cabin_window_open"),
("把窗户降一点", "cabin_window_open"),
("前挡起雾了,除一下", "cabin_defog_front_on"),
("后挡有雾,开下除雾", "cabin_defog_rear_on"),
("把车里弄凉快点", "cabin_set_ac"),
("车里太热了,降一点", "cabin_set_ac"),
("空调再冷一点", "cabin_set_ac"),
("把里面弄暖和点", "cabin_set_ac"),
("风别这么大", "cabin_fan_down"),
("别吹这么猛", "cabin_fan_down"),
("风再大一点", "cabin_fan_up"),
("放点歌", "cabin_play_music"),
("来首歌", "cabin_play_music"),
("整点音乐", "cabin_play_music"),
("透透气", "cabin_window_open"),
("通通风", "cabin_window_open"),
("车里太闷了", "cabin_window_open"),
("凉快点", "cabin_set_ac"),
("暖和点", "cabin_set_ac"),
("帮我点一份麻辣烫", OUT_OF_SCOPE_LABEL),
("给我订今晚的酒店", OUT_OF_SCOPE_LABEL),
("帮我买张电影票", OUT_OF_SCOPE_LABEL),
("人为什么会做梦", OUT_OF_SCOPE_LABEL),
("帮我做个旅游攻略", OUT_OF_SCOPE_LABEL),
("帮我点肯德基外卖", OUT_OF_SCOPE_LABEL),
("透透气,别给我除雾", "cabin_window_open"),
("后挡有雾,不是开窗,是除雾", "cabin_defog_rear_on"),
("前挡看不清了,除雾不要开窗", "cabin_defog_front_on"),
("凉快点,不是把风量调小", "cabin_set_ac"),
("别吹这么猛,不是降温", "cabin_fan_down"),
("来首歌,不是切下一首", "cabin_play_music"),
("放点歌,不是暂停音乐", "cabin_play_music"),
]
HARD_NEGATIVE_SAMPLES: list[Sample] = [
Sample(text=text, intent_id=intent_id) for text, intent_id in HARD_NEGATIVE_RAW_SAMPLES
]
class IntentDataset(Dataset):
def __init__(self, samples: list[Sample], tokenizer, label_to_id: dict[str, int]) -> None:
self._samples = samples
self._tokenizer = tokenizer
self._label_to_id = label_to_id
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
sample = self._samples[index]
encoded = self._tokenizer(
sample.text,
truncation=True,
padding="max_length",
max_length=MAX_LENGTH,
return_tensors="pt",
)
return {
"input_ids": encoded["input_ids"].squeeze(0),
"attention_mask": encoded["attention_mask"].squeeze(0),
"labels": torch.tensor(self._label_to_id[sample.intent_id], dtype=torch.long),
}
def set_seed(seed: int) -> None:
random.seed(seed)
torch.manual_seed(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def load_samples(file_path: Path) -> list[Sample]:
samples: list[Sample] = []
if not file_path.exists():
return samples
for line in file_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
payload = json.loads(line)
samples.append(Sample(text=str(payload["text"]), intent_id=str(payload["intent_id"])))
return samples
def load_domain_samples(file_path: Path) -> list[Sample]:
if not file_path.exists():
return []
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
intents = payload.get("intents", [])
samples: list[Sample] = []
seen: set[tuple[str, str]] = set()
for item in intents:
intent_id = str(item.get("intent_id") or "").strip()
if not intent_id:
continue
seed_texts = list(item.get("examples") or [])
seed_texts.extend(item.get("keywords") or [])
label = str(item.get("label") or "").strip()
if label:
seed_texts.append(label)
for text in seed_texts:
normalized = str(text).strip()
if not normalized:
continue
for variant in expand_seed_variants(normalized):
key = (variant, intent_id)
if key in seen:
continue
seen.add(key)
samples.append(Sample(text=variant, intent_id=intent_id))
return samples
def expand_seed_variants(text: str) -> list[str]:
normalized = text.strip().strip(",。!?;; ")
if not normalized:
return []
variants = {
normalized,
normalized.replace("一下", "").strip(),
normalized.replace("帮我", "").strip(),
normalized.replace("", "").strip(),
f"帮我{normalized}",
f"{normalized}",
f"{normalized}一下",
}
cleaned: list[str] = []
for item in variants:
compact = " ".join(item.split()).strip(",。!?;; ")
if compact:
cleaned.append(compact)
return cleaned
def load_training_samples() -> list[Sample]:
samples = load_samples(TRAIN_PATH)
samples.extend(load_domain_samples(DOMAIN_PATH))
deduped: list[Sample] = []
seen: set[tuple[str, str]] = set()
for sample in samples:
key = (sample.text, sample.intent_id)
if key in seen:
continue
seen.add(key)
deduped.append(sample)
return deduped
def augment_samples(samples: list[Sample]) -> list[Sample]:
augmented = list(samples)
seen = {(sample.text, sample.intent_id) for sample in augmented}
for intent_id, templates in TEMPLATES.items():
for index, template in enumerate(templates):
sample = render_template(intent_id, template, index)
key = (sample.text, sample.intent_id)
if key not in seen:
augmented.append(sample)
seen.add(key)
for sample in HARD_NEGATIVE_SAMPLES:
key = (sample.text, sample.intent_id)
if key not in seen:
augmented.append(sample)
seen.add(key)
for sample in list(augmented):
text = sample.text
for source, target in INTENT_REPLACEMENTS.get(sample.intent_id, []):
if source in text:
variant = text.replace(source, target, 1)
key = (variant, sample.intent_id)
if key not in seen:
augmented.append(Sample(text=variant, intent_id=sample.intent_id))
seen.add(key)
compact = text
for token in ("帮我", "麻烦", "", "一下"):
if token in compact:
compact = compact.replace(token, "", 1)
compact = compact.strip(" ,。!?")
if compact and compact != text:
key = (compact, sample.intent_id)
if key not in seen:
augmented.append(Sample(text=compact, intent_id=sample.intent_id))
seen.add(key)
random.shuffle(augmented)
return augmented
def render_template(intent_id: str, template: str, index: int) -> Sample:
order_id = ORDER_IDS[index % len(ORDER_IDS)]
destination = DESTINATIONS[index % len(DESTINATIONS)]
temperature = TEMPERATURES[index % len(TEMPERATURES)]
song = SONGS[index % len(SONGS)]
genre = GENRES[index % len(GENRES)]
text = template.format(
order_id=order_id,
destination=destination,
temperature=temperature,
song=song,
genre=genre,
)
return Sample(text=text, intent_id=intent_id)
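# --- Illustrative sketch, not part of the original commit ---
# render_template() above picks every filler with "index % len(options)", where
# index is the template's position inside its intent list. Each template is
# therefore rendered once, with a deterministic filler. The hypothetical helper
# below shows the same cycling on an arbitrary filler table. It relies on
# str.format ignoring keyword arguments that the template does not reference.
def sketch_render(template: str, index: int, fillers: dict[str, list]) -> str:
    values = {name: options[index % len(options)] for name, options in fillers.items()}
    return template.format(**values)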
def split_samples(samples: list[Sample]) -> tuple[list[Sample], list[Sample]]:
    # Stratified 80/20 split: shuffle within each intent, then cut per intent.
    # An intent with a single sample contributes to train only.
    grouped: dict[str, list[Sample]] = {}
    for sample in samples:
        grouped.setdefault(sample.intent_id, []).append(sample)
    train_samples: list[Sample] = []
    dev_samples: list[Sample] = []
    for items in grouped.values():
        random.shuffle(items)
        cut = max(1, int(len(items) * 0.8))
        train_samples.extend(items[:cut])
        dev_samples.extend(items[cut:])
    random.shuffle(train_samples)
    random.shuffle(dev_samples)
    return train_samples, dev_samples
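# --- Illustrative sketch, not part of the original commit ---
# split_samples() above cuts each intent group at 80% with a minimum of one
# training item. The hypothetical helper below reproduces that rule with a
# private random.Random instance, so the split sizes can be checked without
# touching the global seed. The (label, item) pair layout is an assumption
# made for this example.
import random
def sketch_stratified_split(
    items_by_label: dict[str, list],
    ratio: float = 0.8,
    seed: int = 0,
) -> tuple[list[tuple], list[tuple]]:
    rng = random.Random(seed)
    train: list[tuple] = []
    dev: list[tuple] = []
    for label in sorted(items_by_label):
        items = list(items_by_label[label])
        rng.shuffle(items)
        cut = max(1, int(len(items) * ratio))  # always keep at least one for training
        train.extend((label, item) for item in items[:cut])
        dev.extend((label, item) for item in items[cut:])
    return train, dev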
def accuracy(model, loader: DataLoader, device: torch.device) -> float:
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in loader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
preds = outputs.logits.argmax(dim=-1)
correct += int((preds == labels).sum().item())
total += int(labels.numel())
return correct / total if total else 0.0
def resolve_base_model() -> str:
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
if configured:
return configured
return DEFAULT_BASE_MODEL
def main() -> None:
set_seed(SEED)
samples = augment_samples(load_training_samples())
intents = sorted({sample.intent_id for sample in samples})
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
train_samples, dev_samples = split_samples(samples)
base_model = resolve_base_model()
tokenizer = AutoTokenizer.from_pretrained(base_model)
train_dataset = IntentDataset(train_samples, tokenizer, label_to_id)
dev_dataset = IntentDataset(dev_samples, tokenizer, label_to_id)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
model = AutoModelForSequenceClassification.from_pretrained(
base_model,
num_labels=len(intents),
id2label=id_to_label,
label2id=label_to_id,
)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
best_dev_acc = 0.0
best_state = None
for epoch in range(1, EPOCHS + 1):
model.train()
total_loss = 0.0
for batch in train_loader:
optimizer.zero_grad()
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
total_loss += float(loss.item())
dev_acc = accuracy(model, dev_loader, device)
avg_loss = total_loss / max(len(train_loader), 1)
print(f"epoch={epoch} loss={avg_loss:.4f} dev_acc={dev_acc:.4f}")
if dev_acc >= best_dev_acc:
best_dev_acc = dev_acc
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
if best_state is not None:
model.load_state_dict(best_state)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
(OUTPUT_DIR / "label_map.json").write_text(
json.dumps(label_map, ensure_ascii=False, indent=2),
encoding="utf-8",
)
train_summary = {
"base_model": base_model,
"epochs": EPOCHS,
"batch_size": BATCH_SIZE,
"learning_rate": LEARNING_RATE,
"train_size": len(train_samples),
"dev_size": len(dev_samples),
"best_dev_accuracy": round(best_dev_acc, 4),
"device": str(device),
}
(OUTPUT_DIR / "train_summary.json").write_text(
json.dumps(train_summary, ensure_ascii=False, indent=2),
encoding="utf-8",
)
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()


@@ -0,0 +1,415 @@
from __future__ import annotations
import json
import os
import random
from dataclasses import dataclass
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import yaml
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}
MAX_LENGTH = 48
BATCH_SIZE = 8
EPOCHS = 12
LEARNING_RATE = 2e-5
THRESHOLD = 0.5
TOP_K = 4
SEED = 42
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
("", "然后"),
("然后", ""),
("顺便", ""),
("", "顺便"),
)
@dataclass(frozen=True)
class MultiLabelSample:
text: str
intent_ids: tuple[str, ...]
class MultiLabelIntentDataset(Dataset):
def __init__(
self,
samples: list[MultiLabelSample],
tokenizer,
label_to_id: dict[str, int],
) -> None:
self._samples = samples
self._tokenizer = tokenizer
self._label_to_id = label_to_id
self._label_size = len(label_to_id)
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
sample = self._samples[index]
encoded = self._tokenizer(
sample.text,
truncation=True,
padding="max_length",
max_length=MAX_LENGTH,
return_tensors="pt",
)
labels = torch.zeros(self._label_size, dtype=torch.float32)
for intent_id in sample.intent_ids:
labels[self._label_to_id[intent_id]] = 1.0
return {
"input_ids": encoded["input_ids"].squeeze(0),
"attention_mask": encoded["attention_mask"].squeeze(0),
"labels": labels,
}
def set_seed(seed: int) -> None:
random.seed(seed)
torch.manual_seed(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def resolve_base_model() -> str:
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
if configured:
return configured
return DEFAULT_BASE_MODEL
def normalize_text(text: str) -> str:
return " ".join(str(text).strip().split())
def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
cleaned = sorted(
{
str(intent_id).strip()
for intent_id in intent_ids
if str(intent_id).strip() and str(intent_id).strip() not in BLOCKED_LABELS
}
)
return tuple(cleaned)
def expand_single_label_variants(text: str) -> list[str]:
normalized = text.strip().strip(",。!?;; ")
if not normalized:
return []
variants = {
normalized,
normalized.replace("一下", "").strip(),
normalized.replace("帮我", "").strip(),
normalized.replace("", "").strip(),
f"帮我{normalized}",
f"{normalized}",
f"{normalized}一下",
}
cleaned: list[str] = []
for item in variants:
compact = " ".join(item.split()).strip(",。!?;; ")
if compact:
cleaned.append(compact)
return cleaned
def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
samples: list[MultiLabelSample] = []
if not file_path.exists():
return samples
for line in file_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
payload = json.loads(line)
intent_ids = normalize_intent_ids([str(payload.get("intent_id") or "")])
if not intent_ids:
continue
text = normalize_text(str(payload.get("text") or ""))
if not text:
continue
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
return samples
def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
if not file_path.exists():
return []
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
intents = payload.get("intents", [])
samples: list[MultiLabelSample] = []
seen: set[tuple[str, tuple[str, ...]]] = set()
for item in intents:
intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
if not intent_ids:
continue
seed_texts = list(item.get("examples") or [])
seed_texts.extend(item.get("keywords") or [])
label = str(item.get("label") or "").strip()
if label:
seed_texts.append(label)
for text in seed_texts:
normalized = normalize_text(text)
if not normalized:
continue
for variant in expand_single_label_variants(normalized):
key = (variant, intent_ids)
if key in seen:
continue
seen.add(key)
samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
return samples
def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
samples: list[MultiLabelSample] = []
if not file_path.exists():
return samples
for line in file_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
payload = json.loads(line)
intent_ids = normalize_intent_ids(list(payload.get("intent_ids") or []))
if len(intent_ids) < 2:
continue
text = normalize_text(str(payload.get("text") or ""))
if not text:
continue
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
return samples
def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
augmented = list(samples)
seen = {(sample.text, sample.intent_ids) for sample in augmented}
for sample in list(samples):
variants = {
sample.text,
f"帮我{sample.text}",
f"{sample.text}",
sample.text.replace("", ", "),
sample.text.replace("", ""),
}
for source, target in CONNECTOR_VARIANTS:
if source in sample.text:
variants.add(sample.text.replace(source, target, 1))
for variant in variants:
normalized = normalize_text(variant).strip(",。!?;; ")
key = (normalized, sample.intent_ids)
if normalized and key not in seen:
augmented.append(MultiLabelSample(text=normalized, intent_ids=sample.intent_ids))
seen.add(key)
return augmented
def load_all_samples() -> list[MultiLabelSample]:
samples = load_single_label_samples(SINGLE_LABEL_PATH)
samples.extend(load_domain_samples(DOMAIN_PATH))
samples.extend(augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)))
deduped: list[MultiLabelSample] = []
seen: set[tuple[str, tuple[str, ...]]] = set()
for sample in samples:
key = (sample.text, sample.intent_ids)
if key in seen:
continue
seen.add(key)
deduped.append(sample)
random.shuffle(deduped)
return deduped
def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
    # Stratify by the full intent combination; singleton groups stay in train.
    grouped: dict[tuple[str, ...], list[MultiLabelSample]] = {}
    for sample in samples:
        grouped.setdefault(sample.intent_ids, []).append(sample)
    train_samples: list[MultiLabelSample] = []
    dev_samples: list[MultiLabelSample] = []
    for items in grouped.values():
        random.shuffle(items)
        if len(items) == 1:
            train_samples.extend(items)
            continue
        cut = max(1, int(len(items) * 0.8))
        if cut >= len(items):
            cut = len(items) - 1
        train_samples.extend(items[:cut])
        dev_samples.extend(items[cut:])
    if not dev_samples:
        # Fallback when every group is a singleton: hold out about one fifth of train, capped at 32.
        fallback_size = max(1, min(32, len(train_samples) // 5))
        dev_samples = train_samples[-fallback_size:]
        train_samples = train_samples[: len(train_samples) - len(dev_samples)]
    random.shuffle(train_samples)
    random.shuffle(dev_samples)
    return train_samples, dev_samples
def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
return torch.sigmoid(logits).detach().cpu().tolist()
def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    recall_at_k_total = 0.0
    total = len(probabilities)
    for scores, target in zip(probabilities, targets):
        predicted = {index for index, score in enumerate(scores) if score >= threshold}
        expected = {index for index, value in enumerate(target) if value >= 0.5}
        if predicted == expected:
            exact_match += 1
        true_positive += len(predicted & expected)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)
        top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
        if expected:
            recall_at_k_total += len(set(top_indices) & expected) / len(expected)
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
    }
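# --- Illustrative sketch, not part of the original commit ---
# compute_metrics() above reports recall_at_k next to the thresholded scores:
# for each sample it checks how many gold labels appear among the top_k
# highest-scoring labels, regardless of the threshold. The hypothetical helper
# below computes that value for one sample. Note that compute_metrics() divides
# the summed recall by the number of all samples, so a sample without gold
# labels counts as 0 there; this helper mirrors that by returning 0.0.
def sketch_recall_at_k(scores: list[float], expected: set[int], k: int) -> float:
    if not expected:
        return 0.0
    top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:k]
    return len(set(top_indices) & expected) / len(expected)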
def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
model.eval()
total_loss = 0.0
probabilities: list[list[float]] = []
targets: list[list[float]] = []
with torch.no_grad():
for batch in loader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
total_loss += float(outputs.loss.item())
probabilities.extend(logits_to_probabilities(outputs.logits))
targets.extend(labels.detach().cpu().tolist())
avg_loss = total_loss / max(len(loader), 1)
return avg_loss, compute_metrics(probabilities, targets, threshold=threshold, top_k=top_k)
def main() -> None:
set_seed(SEED)
samples = load_all_samples()
intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
train_samples, dev_samples = split_samples(samples)
base_model = resolve_base_model()
tokenizer = AutoTokenizer.from_pretrained(base_model)
train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
model = AutoModelForSequenceClassification.from_pretrained(
base_model,
num_labels=len(intents),
id2label=id_to_label,
label2id=label_to_id,
problem_type="multi_label_classification",
)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
best_dev_f1 = 0.0
best_state = None
best_metrics: dict[str, float] = {}
for epoch in range(1, EPOCHS + 1):
model.train()
total_loss = 0.0
for batch in train_loader:
optimizer.zero_grad()
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
total_loss += float(loss.item())
dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
avg_loss = total_loss / max(len(train_loader), 1)
print(
" ".join(
[
f"epoch={epoch}",
f"train_loss={avg_loss:.4f}",
f"dev_loss={dev_loss:.4f}",
f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
f"dev_exact_match={dev_metrics['exact_match']:.4f}",
]
)
)
if dev_metrics["micro_f1"] >= best_dev_f1:
best_dev_f1 = dev_metrics["micro_f1"]
best_metrics = dict(dev_metrics)
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
if best_state is not None:
model.load_state_dict(best_state)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
(OUTPUT_DIR / "label_map.json").write_text(
json.dumps(label_map, ensure_ascii=False, indent=2),
encoding="utf-8",
)
train_summary = {
"task_type": "multi_label_intent_detection",
"base_model": base_model,
"epochs": EPOCHS,
"batch_size": BATCH_SIZE,
"learning_rate": LEARNING_RATE,
"threshold": THRESHOLD,
"top_k": TOP_K,
"train_size": len(train_samples),
"dev_size": len(dev_samples),
"label_count": len(intents),
"labels": intents,
"best_dev_metrics": best_metrics,
"device": str(device),
}
(OUTPUT_DIR / "train_summary.json").write_text(
json.dumps(train_summary, ensure_ascii=False, indent=2),
encoding="utf-8",
)
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()
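Review note: the micro-averaged counting in `compute_metrics` can be sanity-checked without torch or a trained model. The sketch below re-implements only that counting on a hand-made batch; `micro_prf` and the toy scores are illustrative names and values, not part of this commit.

```python
def micro_prf(probabilities, targets, threshold):
    """Micro precision/recall/F1 over multi-label predictions (illustrative helper)."""
    tp = fp = fn = 0
    for scores, target in zip(probabilities, targets):
        predicted = {i for i, s in enumerate(scores) if s >= threshold}
        expected = {i for i, v in enumerate(target) if v >= 0.5}
        tp += len(predicted & expected)
        fp += len(predicted - expected)
        fn += len(expected - predicted)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1


# Two samples, three labels: sample 0 is fully correct,
# sample 1 has one false alarm (label 0) and one miss (label 2).
toy_probabilities = [[0.9, 0.1, 0.8], [0.7, 0.6, 0.2]]
toy_targets = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
precision, recall, f1 = micro_prf(toy_probabilities, toy_targets, threshold=0.5)
```

Because the counts are pooled across samples before dividing, a label-rich sample weighs more than a single-label one, which is worth keeping in mind when reading `dev_micro_f1` next to `dev_exact_match`.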


@@ -0,0 +1,132 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.schemas.chat import ChatRequest
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
from app.schemas.intent import IntentDefinition
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.planner import PlanningResult
from app.services.session_store import InMemorySessionStore
class _RouteToCloudRouter:
def route(self, text: str):
_ = text
return type(
"RouteResult",
(),
{
"intent": None,
"debug": RoutingDebug(
selected_intent="cabin_nav_to",
matched_stage="fusion",
decision="route_to_cloud",
decision_reason="local signal is not stable enough, routing to cloud planner",
confidence_grade="low",
stages=[
MatcherStageDebug(
stage="fusion",
accepted=False,
selected_intent="cabin_nav_to",
score=0.88,
reason="route to cloud",
candidates=[
IntentCandidate(intent_id="cabin_nav_to", score=0.88, reason="fusion", model_name="fusion"),
IntentCandidate(intent_id="cabin_play_music", score=0.75, reason="fusion", model_name="fusion"),
],
)
],
),
},
)()
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
_ = (text, intent)
return {}
class _PlannerRejects:
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
_ = (text, intents, context)
return PlanningResult(
accepted=False,
workflow_type="single",
model_name="qwen3.5-plus",
backend="dashscope",
reason="cloud planner could not produce a stable executable step",
)
class _PlannerOutOfScope:
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
_ = (text, intents, context)
return PlanningResult(
accepted=False,
workflow_type="single",
model_name="qwen3.5-plus",
backend="dashscope",
reason="The provided intent catalog only contains cabin and service actions. There is no matching intent for ordering food via a third-party app action.",
)
def _intent(intent_id: str) -> IntentDefinition:
return IntentDefinition(
intent_id=intent_id,
plugin_id=f"mock.{intent_id}",
domain="cabin",
keywords=[],
examples=[],
)
class AgentCloudRouteTests(unittest.TestCase):
def test_route_to_cloud_returns_explicit_clarify_feedback_when_planner_does_not_accept(self) -> None:
service = AgentService(
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
router=_RouteToCloudRouter(),
plugins=PluginRegistry(),
session_store=InMemorySessionStore(),
planner=_PlannerRejects(),
)
response = service.handle_chat(
ChatRequest(
session_id="sess_cloud_route",
user_id="user_1",
input_text="带我过去",
)
)
self.assertEqual(response.decision, "route_to_cloud")
self.assertEqual(response.reply_type, "clarify")
self.assertEqual(response.status, "route_to_cloud")
self.assertIn("请确认一下", response.reply_text)
def test_route_to_cloud_rejects_when_planner_marks_request_out_of_scope(self) -> None:
service = AgentService(
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
router=_RouteToCloudRouter(),
plugins=PluginRegistry(),
session_store=InMemorySessionStore(),
planner=_PlannerOutOfScope(),
)
response = service.handle_chat(
ChatRequest(
session_id="sess_cloud_route_reject",
user_id="user_1",
input_text="去美团叫个外卖",
)
)
self.assertEqual(response.reply_type, "reject")
self.assertEqual(response.decision, "reject")
self.assertEqual(response.status, "rejected")
self.assertIn("做不了", response.reply_text)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,235 @@
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from app.core.config import settings
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
from app.services.router import build_matcher_pipeline
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
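def resolve_model_path(model_path: str) -> Path:
    # The CLI argument wins, then the configured setting, then the project default.
    # The trailing `or ""` keeps `.strip()` safe if the setting is unset (None).
    configured = (model_path or settings.classifier_model_path or "").strip()
    if configured:
        return Path(configured)
    return DEFAULT_MODEL_DIR


def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
    # Same precedence as above; the fallback lives next to the model files.
    configured = (label_map_path or settings.classifier_label_map_path or "").strip()
    if configured:
        return Path(configured)
    return model_path / "label_map.json"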
def build_classifier(
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
) -> BertIntentClassifier:
return BertIntentClassifier(
model_path=str(model_path),
threshold=threshold,
label_map_path=str(label_map_path),
fallback=None,
top_k=top_k,
)
def predict_once(
text: str,
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
warmup: bool,
) -> dict[str, object]:
registry = build_intent_registry()
classifier = build_classifier(
model_path=model_path,
label_map_path=label_map_path,
threshold=threshold,
top_k=top_k,
)
warmup_ok = None
if warmup:
warmup_ok = classifier.warmup(settings.classifier_warmup_text)
if not warmup_ok:
error_message = getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
raise RuntimeError(error_message)
intents = registry.list()
result = classifier.predict(text, intents)
matcher = build_matcher_pipeline(
registry,
["classifier"],
classifier=classifier,
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
clarify_margin_threshold=settings.local_clarify_margin_threshold,
classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
)
route_result = matcher.match(text)
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
classifier_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "classifier"), None)
return {
"text": text,
"config": {
"model_path": str(model_path),
"label_map_path": str(label_map_path),
"threshold": threshold,
"top_k": top_k,
"warmup_requested": warmup,
"warmup_ok": warmup_ok,
"warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
"warmup_error_message": getattr(classifier, "_warmup_error_message", None),
},
"classifier_result": {
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
"score": round(result.score, 4),
"model_name": result.model_name,
"backend_name": result.backend_name,
"used_fallback": result.used_fallback,
"fallback_reason": result.fallback_reason,
"error_message": result.error_message,
"raw_label": result.raw_label,
"raw_candidates": result.raw_candidates or [],
"known_candidates": [
{"intent_id": intent.intent_id, "score": round(score, 4)}
for intent, score in (result.candidates or [])
],
},
"route_result": {
"decision": route_result.debug.decision,
"decision_reason": route_result.debug.decision_reason,
"matched_stage": route_result.debug.matched_stage,
"selected_intent": route_result.debug.selected_intent,
"confidence_grade": route_result.debug.confidence_grade,
"unknown_detected": route_result.debug.unknown_detected,
"classifier_score": round(classifier_stage.score, 4) if classifier_stage is not None else None,
"fusion_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
},
}
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
    classifier_result = dict(result.get("classifier_result") or {})
    route_result = dict(result.get("route_result") or {})
    predicted_intent = classifier_result.get("predicted_intent")
    raw_label = classifier_result.get("raw_label")
    effective_label = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent
    if effective_label in NON_BUSINESS_LABELS:
        classifier_result["predicted_intent"] = None
        classifier_result["non_business_label"] = effective_label
        classifier_result["business_interpretation"] = "non_business_label_detected"
        route_result["selected_intent"] = None
        route_result["decision"] = "reject"
        route_result["decision_reason"] = "classifier detected a non-business label"
        route_result["unknown_detected"] = True
    else:
        classifier_result["non_business_label"] = None
        classifier_result["business_interpretation"] = "known_business_intent_or_uncertain"
    return {
        "text": result.get("text"),
        "config": result.get("config"),
        "classifier_result": classifier_result,
        "route_result": route_result,
    }
def interactive_loop(
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
    warmup: bool,
    mode: str,
) -> None:
    print("BERT test started. Type a sentence to see the prediction; type exit to quit.")
    print(f"model_path={model_path}")
    print(f"label_map_path={label_map_path}")
    print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")
    while True:
        try:
            text = input("\nEnter a question> ").strip()
        except EOFError:
            print()
            break
        if not text:
            continue
        if text.lower() in {"exit", "quit", "q"}:
            break
        result = predict_once(
            text,
            model_path=model_path,
            label_map_path=label_map_path,
            threshold=threshold,
            top_k=top_k,
            warmup=warmup,
        )
        if mode == "business":
            result = summarize_business_view(result)
        print(json.dumps(result, ensure_ascii=False, indent=2))
def main() -> None:
    parser = argparse.ArgumentParser(description="BERT intent recognition test script for this project")
    parser.add_argument("--text", type=str, default="", help="Text for a single test run")
    parser.add_argument("--threshold", type=float, default=settings.classifier_bert_threshold, help="BERT confidence threshold")
    parser.add_argument("--top-k", type=int, default=settings.classifier_top_k, help="Number of candidates to return")
    parser.add_argument("--model-path", type=str, default="", help="Model directory; defaults to the .env value or models/local_bert_intent")
    parser.add_argument("--label-map-path", type=str, default="", help="Label map file; defaults to the .env value or model_path/label_map.json")
    parser.add_argument("--warmup", action="store_true", help="Run one warmup pass before predicting")
    parser.add_argument(
        "--mode",
        choices=("classifier", "business"),
        default="classifier",
        help="classifier shows the raw classification result; business folds non-business labels into 'no business intent matched'",
    )
    args = parser.parse_args()
    model_path = resolve_model_path(args.model_path)
    label_map_path = resolve_label_map_path(args.label_map_path, model_path)
    # With --text, run a single prediction and exit; otherwise start the interactive loop.
    if args.text.strip():
        result = predict_once(
            args.text.strip(),
            model_path=model_path,
            label_map_path=label_map_path,
            threshold=args.threshold,
            top_k=args.top_k,
            warmup=args.warmup,
        )
        if args.mode == "business":
            result = summarize_business_view(result)
        print(json.dumps(result, ensure_ascii=False, indent=2))
        return
    interactive_loop(
        model_path=model_path,
        label_map_path=label_map_path,
        threshold=args.threshold,
        top_k=args.top_k,
        warmup=args.warmup,
        mode=args.mode,
    )
if __name__ == "__main__":
main()
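Review note: the folding rule in `summarize_business_view` is easy to lose among the dictionary updates. A minimal sketch of just that rule follows, assuming the same two non-business labels; `fold_non_business` and its `"keep"` marker are hypothetical names used only for illustration.

```python
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}


def fold_non_business(predicted_intent, raw_label):
    """Return (business_intent, decision) after folding non-business labels.

    The raw label takes priority when it is a non-business label; otherwise
    the predicted intent is inspected. "keep" means the routing decision is
    left unchanged (hypothetical marker, not a value used by the project).
    """
    effective = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent
    if effective in NON_BUSINESS_LABELS:
        return None, "reject"
    return predicted_intent, "keep"


social = fold_non_business("cabin_ac_on", "__social__")
business = fold_non_business("cabin_ac_on", "cabin_ac_on")
out_of_scope = fold_non_business("__out_of_scope__", None)
```

Note that a non-business `raw_label` overrides even a known `predicted_intent`, which matches the `effective_label` expression in the script.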


@@ -0,0 +1,109 @@
from __future__ import annotations
import json
import os
import time
import unittest
from unittest.mock import patch
from fastapi.testclient import TestClient
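# Set these before importing app.main so the app is built with the mock classifier
# and skips model warmup during tests.
os.environ["AGENT_CLASSIFIER_BACKEND"] = "mock"
os.environ["AGENT_CLASSIFIER_WARMUP_ENABLED"] = "false"
from app.main import app  # noqa: E402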
from app.schemas.chat import ChatResponse
def _fake_response() -> ChatResponse:
return ChatResponse(
session_id="sess_stream_1",
reply_type="workflow_result",
reply_text="好,空调已经打开了。",
intent="cabin_ac_on",
status="completed",
trace_id="trace_stream_1",
)
class ChatStreamTests(unittest.TestCase):
def test_chat_stream_returns_final_only_when_fast(self) -> None:
client = TestClient(app)
with patch("app.main.agent_service.handle_chat", return_value=_fake_response()):
response = client.post(
"/api/v1/agent/chat-stream",
json={
"session_id": "sess_stream_1",
"user_id": "user_stream_1",
"channel": "test",
"input_text": "打开车窗",
"input_type": "text",
},
)
self.assertEqual(response.status_code, 200)
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
self.assertEqual(len(lines), 1)
final_event = json.loads(lines[0])
self.assertEqual(final_event.get("type"), "final")
def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None:
client = TestClient(app)
def _slow_handle_chat(_request):
time.sleep(1.2)
return _fake_response()
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
response = client.post(
"/api/v1/agent/chat-stream",
json={
"session_id": "sess_stream_1",
"user_id": "user_stream_1",
"channel": "test",
"input_text": "打开车窗",
"input_type": "text",
},
)
self.assertEqual(response.status_code, 200)
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
self.assertGreaterEqual(len(lines), 2)
ack_event = json.loads(lines[0])
final_event = json.loads(lines[-1])
self.assertEqual(ack_event.get("type"), "ack")
self.assertEqual(final_event.get("type"), "final")
self.assertIn("data", final_event)
self.assertIn("reply_text", final_event["data"])
def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None:
client = TestClient(app)
def _slow_handle_chat(_request):
time.sleep(1.2)
return _fake_response()
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
response = client.post(
"/api/v1/agent/chat-stream",
json={
"session_id": "sess_stream_1",
"user_id": "user_stream_1",
"channel": "test",
"input_text": "今天天气如何",
"input_type": "text",
},
)
self.assertEqual(response.status_code, 200)
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
self.assertGreaterEqual(len(lines), 2)
ack_event = json.loads(lines[0])
final_event = json.loads(lines[-1])
self.assertEqual(ack_event.get("type"), "ack")
self.assertEqual(final_event.get("type"), "final")
if __name__ == "__main__":
unittest.main()
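Review note: the three tests above pin down an "ack only when slow" contract for `/api/v1/agent/chat-stream`, but the endpoint itself is not shown in this diff. One way such behavior could be produced is sketched below, assuming a worker thread and a timeout; `stream_events` and `ack_after_seconds` are hypothetical names, and the real endpoint may work differently.

```python
import json
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout


def stream_events(handler, request, ack_after_seconds=1.0):
    """Yield JSON lines: an "ack" only if the handler is slow, then "final"."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(handler, request)
        try:
            # Fast path: the handler finishes before the ack deadline.
            result = future.result(timeout=ack_after_seconds)
        except FutureTimeout:
            # Slow path: tell the client we are working, then wait for the result.
            yield json.dumps({"type": "ack"})
            result = future.result()
    yield json.dumps({"type": "final", "data": result})


def fast_handler(_request):
    return {"reply_text": "ok"}


def slow_handler(_request):
    time.sleep(0.2)
    return {"reply_text": "ok"}


fast_events = [json.loads(line) for line in stream_events(fast_handler, {}, ack_after_seconds=0.1)]
slow_events = [json.loads(line) for line in stream_events(slow_handler, {}, ack_after_seconds=0.05)]
```

The short deadlines here only keep the example quick to run; the tests in this file sleep for 1.2 seconds, which suggests the real ack deadline is at most that long.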


@@ -0,0 +1,90 @@
from __future__ import annotations
import unittest
from unittest.mock import patch
from app.core.bootstrap import build_planner, load_runtime_bundle
from app.core.config import settings
from app.services.planner import CompositeWorkflowPlanner
from app.services.config_loader import ConfigLoader
from app.services.dialog_rules import DialogRuleEngine
from app.services.response_policy import ResponsePolicy
class ConfigLoaderTests(unittest.TestCase):
def test_loader_reads_domain_actions_and_responses(self) -> None:
bundle = ConfigLoader(
domain_path="config/domain.yml",
action_path="config/actions.yml",
response_path="config/responses.yml",
form_path="config/forms.yml",
rule_path="config/rules.yml",
dialog_act_path="config/dialog_acts.yml",
workflow_path="config/workflows.yml",
legacy_intent_path="app/data/intents.json",
).load()
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
self.assertEqual(bundle.intent_registry.get("cabin_window_open").plugin_id, "plugin.cabin.window.open")
self.assertEqual(bundle.intent_hints.get("cabin_window_open"), "打开车窗")
self.assertEqual(bundle.response_templates.get("task_stopped"), "好的,已停止当前任务。")
self.assertEqual(bundle.intent_registry.get("cabin_set_ac").required_slots, ["temperature"])
self.assertTrue(bundle.dialog_rules.is_stop_request("先不要了"))
self.assertEqual(bundle.dialog_rules.parse_confirmation_decision("确认"), True)
self.assertEqual(bundle.dialog_act_engine.detect("你好"), "chitchat")
self.assertGreaterEqual(len(bundle.workflow_templates.templates), 2)
def test_bootstrap_runtime_bundle_is_available(self) -> None:
bundle = load_runtime_bundle()
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
self.assertIn("fallback", bundle.response_templates)
self.assertEqual(bundle.dialog_act_engine.detect("确认"), "affirm")
def test_response_policy_can_be_driven_by_config_templates(self) -> None:
policy = ResponsePolicy(
templates={"reject": "这个能力暂未开通。"},
intent_hints={"cabin_window_open": "开车窗"},
)
self.assertEqual(policy.reject(), "这个能力暂未开通。")
self.assertEqual(policy.clarify(["cabin_window_open"]), "请确认一下,你是想开车窗吗?")
def test_response_policy_formats_multi_step_summary_naturally(self) -> None:
policy = ResponsePolicy()
summary = policy.workflow_summary(["好的,已打开空调。", "已将空调调到 20 度。"])
self.assertEqual(summary, "好,空调已经打开了,也调到 20 度了。")
def test_response_policy_formats_multi_step_summary_in_vehicle_style(self) -> None:
policy = ResponsePolicy()
summary = policy.workflow_summary(["好的,已打开空调。", "好的,已关闭车窗。"])
self.assertEqual(summary, "好,空调已经打开了,车窗也帮你关上了。")
def test_build_planner_prefers_local_planners_before_cloud(self) -> None:
with patch.object(settings, "planner_backend", "dashscope"):
planner = build_planner()
self.assertIsInstance(planner, CompositeWorkflowPlanner)
self.assertIsInstance(planner._planners[0], CompositeWorkflowPlanner)
def test_dialog_rule_engine_supports_configured_confirmation_and_stop(self) -> None:
rules = DialogRuleEngine(
stop_phrases=("先不用了",),
positive_confirmation_tokens=("好,继续",),
negative_confirmation_tokens=("取消吧",),
confirmation_required_intents=("foo",),
confirmation_required_risk_levels=("high",),
)
self.assertTrue(rules.is_stop_request("先不用了"))
self.assertTrue(rules.parse_confirmation_decision("好,继续"))
self.assertFalse(rules.parse_confirmation_decision("取消吧"))
self.assertTrue(rules.requires_confirmation("foo", "low"))
if __name__ == "__main__":
unittest.main()
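Review note: the last test expects `requires_confirmation("foo", "low")` to be true even though the risk level is low, which implies the intent list and the risk-level list are checked independently. A minimal sketch of that rule is shown below; `ConfirmationRules` is a hypothetical stand-in and not the project's `DialogRuleEngine`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConfirmationRules:
    """Hypothetical stand-in holding the two confirmation lists."""

    required_intents: frozenset
    required_risk_levels: frozenset

    def requires_confirmation(self, intent_id, risk_level):
        # Either condition alone is enough to require confirmation.
        return intent_id in self.required_intents or risk_level in self.required_risk_levels


rules = ConfirmationRules(frozenset({"foo"}), frozenset({"high"}))
listed_intent_low_risk = rules.requires_confirmation("foo", "low")
unlisted_intent_high_risk = rules.requires_confirmation("bar", "high")
unlisted_intent_low_risk = rules.requires_confirmation("bar", "low")
```

A test for the `("bar", "high")` and `("bar", "low")` cases would cover the other two branches of the rule, which the current test file does not exercise.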


@@ -0,0 +1,202 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.plugins.mock import MockPluginExecutor
from app.schemas.chat import ChatRequest
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
from app.schemas.workflow import Workflow, WorkflowStep
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.planner import HeuristicWorkflowPlanner
from app.services.response_policy import ResponsePolicy
from app.services.rewrite_engine import ContextRewriteEngine
from app.services.router import RouteMatchResult
from app.services.session_store import InMemorySessionStore
class _FailIfCalledPlanner:
def plan(self, text, intents, context=None):
_ = (intents, context)
raise AssertionError(f"planner should not be called for single intent request: {text}")
class _ScriptedRouter:
def __init__(self, registry: IntentRegistry) -> None:
self._registry = registry
self._route_map = {
"来点music": self._route_result("cabin_play_music", ["cabin_play_music"]),
"打开车窗和空调": self._route_result("cabin_window_open", ["cabin_window_open", "cabin_ac_on"]),
"关闭车窗": self._route_result("cabin_window_close", ["cabin_window_close", "cabin_window_open"]),
}
self._slot_map = {
("播放黄昏", "cabin_play_music"): {"song": "黄昏"},
("来一首黄昏", "cabin_play_music"): {"song": "黄昏"},
("来点黄昏", "cabin_play_music"): {"song": "黄昏"},
}
def route(self, text: str) -> RouteMatchResult:
if text not in self._route_map:
raise AssertionError(f"unexpected route request: {text}")
return self._route_map[text]
def extract_slots(self, text: str, intent) -> dict[str, object]:
return dict(self._slot_map.get((text, intent.intent_id), {}))
def _route_result(self, selected_intent: str, candidates: list[str]) -> RouteMatchResult:
intent = self._registry.get(selected_intent)
stage = MatcherStageDebug(
stage="fusion",
accepted=True,
selected_intent=selected_intent,
score=1.0,
candidates=[
IntentCandidate(intent_id=intent_id, score=max(0.5, 1.0 - index * 0.1))
for index, intent_id in enumerate(candidates)
],
)
return RouteMatchResult(
intent=intent,
debug=RoutingDebug(
selected_intent=selected_intent,
matched_stage="fusion",
decision="execute",
stages=[stage],
),
)
class DialogContinuationAndMultiIntentTests(unittest.TestCase):
def setUp(self) -> None:
self.registry = IntentRegistry.from_json("app/data/intents.json")
self.plugins = PluginRegistry()
MockPluginExecutor().register(self.plugins)
self.service = AgentService(
intent_registry=self.registry,
router=_ScriptedRouter(self.registry),
plugins=self.plugins,
session_store=InMemorySessionStore(),
rewrite_engine=ContextRewriteEngine(),
response_policy=ResponsePolicy(),
planner=HeuristicWorkflowPlanner(),
)
def test_music_followup_in_chat_continues_waiting_slot(self) -> None:
first = self.service.handle_chat(
ChatRequest(
session_id="sess_music_followup",
user_id="user_1",
input_text="来点music",
)
)
self.assertEqual(first.reply_type, "ask_slot")
self.assertEqual(first.pending_slots, ["media_query"])
second = self.service.handle_chat(
ChatRequest(
session_id="sess_music_followup",
user_id="user_1",
input_text="黄昏",
)
)
self.assertEqual(second.reply_type, "workflow_result")
self.assertEqual(second.intent, "cabin_play_music")
self.assertEqual(second.filled_slots.get("song"), "黄昏")
self.assertIn("黄昏", second.reply_text)
def test_parallel_compound_request_enters_planner(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_parallel_compound",
user_id="user_1",
input_text="打开车窗和空调",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.workflow.workflow_type, "sequence")
step_intents = [step.intent_id for step in response.workflow.steps]
self.assertEqual(step_intents, ["cabin_window_open", "cabin_ac_on"])
self.assertIn("车窗", response.reply_text)
self.assertIn("空调", response.reply_text)
def test_single_cabin_intent_does_not_enter_planner_from_top2_domain_candidates(self) -> None:
service = AgentService(
intent_registry=self.registry,
router=_ScriptedRouter(self.registry),
plugins=self.plugins,
session_store=InMemorySessionStore(),
rewrite_engine=ContextRewriteEngine(),
response_policy=ResponsePolicy(),
planner=_FailIfCalledPlanner(),
)
response = service.handle_chat(
ChatRequest(
session_id="sess_single_cabin_no_planner",
user_id="user_1",
input_text="关闭车窗",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_window_close")
self.assertEqual(response.routing_debug.decision, "execute")
self.assertFalse(any(stage.stage == "planner" for stage in response.routing_debug.stages))
def test_waiting_confirmation_can_continue_via_chat(self) -> None:
session = self.service.session_store.get_or_create("sess_confirm_chat", "user_1")
session.current_intent = "cs_cancel_order"
session.status = "waiting_confirmation"
session.pending_slots = ["confirmation"]
session.slots = {"order_id": "A123456"}
session.workflow = Workflow(
workflow_id="wf_confirm_chat",
workflow_type="conditional",
domain="customer_service",
intent_id="cs_cancel_order",
status="waiting_confirmation",
risk_level="high",
slots={"order_id": "A123456"},
steps=[
WorkflowStep(
step=1,
step_id="step_cancel",
intent_id="cs_cancel_order",
plugin_id="plugin.order.cancel",
action="cancel_order",
status="waiting_confirmation",
slots={"order_id": "A123456"},
requires_confirmation=True,
)
],
meta={
"pending_confirmation": {
"step_id": "step_cancel",
"intent_id": "cs_cancel_order",
"detail": "确认取消订单 A123456",
},
"step_results": {},
"confirmed_steps": [],
},
).model_dump()
self.service.session_store.save(session)
response = self.service.handle_chat(
ChatRequest(
session_id="sess_confirm_chat",
user_id="user_1",
input_text="确认",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.status, "completed")
self.assertIn("A123456", response.reply_text)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,149 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.plugins.mock import MockPluginExecutor
from app.services.classifier import MockIntentClassifier
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.response_policy import ResponsePolicy
from app.services.rewrite_engine import ContextRewriteEngine
from app.services.router import HeuristicSlotExtractor, IntentRouter, build_matcher_pipeline
from app.services.session_store import InMemorySessionStore
from app.schemas.chat import ChatRequest, FillSlotsRequest
class _BertLikeMockClassifier(MockIntentClassifier):
def predict(self, text, intents):
result = super().predict(text, intents)
result.model_name = "bert-local"
result.backend_name = "bert-local"
return result
def _build_service() -> AgentService:
registry = IntentRegistry.from_json("app/data/intents.json")
plugins = PluginRegistry()
MockPluginExecutor().register(plugins)
router = IntentRouter(
matcher=build_matcher_pipeline(
registry,
["classifier"],
classifier=_BertLikeMockClassifier(threshold=0.0, top_k=3),
),
slot_extractor=HeuristicSlotExtractor(),
)
return AgentService(
intent_registry=registry,
router=router,
plugins=plugins,
session_store=InMemorySessionStore(),
rewrite_engine=ContextRewriteEngine(),
response_policy=ResponsePolicy(),
planner=None,
)
class IntentCoverageAndStopTests(unittest.TestCase):
def setUp(self) -> None:
self.service = _build_service()
self.registry = IntentRegistry.from_json("app/data/intents.json")
def test_intent_catalog_has_at_least_30_items(self) -> None:
self.assertGreaterEqual(len(self.registry.list()), 30)
def test_close_ac_routes_to_power_off(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_close_ac",
user_id="user_1",
input_text="关闭空调",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_ac_off")
self.assertIn("已关闭空调", response.reply_text)
def test_open_window_is_covered(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_window_open",
user_id="user_1",
input_text="打开车窗",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_window_open")
self.assertIn("已打开车窗", response.reply_text)
def test_stop_current_task_while_waiting_for_slot(self) -> None:
first = self.service.handle_chat(
ChatRequest(
session_id="sess_stop_task",
user_id="user_1",
input_text="空调调到",
)
)
self.assertEqual(first.reply_type, "ask_slot")
self.assertEqual(first.pending_slots, ["temperature"])
stopped = self.service.handle_fill_slots(
FillSlotsRequest(
session_id="sess_stop_task",
user_id="user_1",
input_text="不用了",
)
)
self.assertEqual(stopped.reply_type, "text")
self.assertEqual(stopped.status, "stopped")
self.assertEqual(stopped.pending_slots, [])
self.assertIn("已停止当前任务", stopped.reply_text)
def test_relative_ac_adjustment_uses_two_degree_step(self) -> None:
session = self.service.session_store.get_or_create("sess_ac_lower", "user_1")
session.current_intent = "cabin_set_ac"
session.context_memory["last_temperature"] = 24
self.service.session_store.save(session)
response = self.service.handle_chat(
ChatRequest(
session_id="sess_ac_lower",
user_id="user_1",
input_text="调低一点",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_set_ac")
self.assertEqual(response.filled_slots.get("temperature"), 22)
self.assertIn("22", response.reply_text)
def test_relative_ac_adjustment_without_history_uses_default_baseline(self) -> None:
session = self.service.session_store.get_or_create("sess_ac_lower_default", "user_1")
session.current_intent = "cabin_ac_on"
self.service.session_store.save(session)
response = self.service.handle_chat(
ChatRequest(
session_id="sess_ac_lower_default",
user_id="user_1",
input_text="调低一点",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_set_ac")
self.assertEqual(response.filled_slots.get("temperature"), 22)
self.assertIn("22", response.reply_text)
def test_temperature_is_clamped_before_execution(self) -> None:
self.assertEqual(self.service._normalize_temperature_value(-1), 16)
self.assertEqual(self.service._normalize_temperature_value(40), 30)
if __name__ == "__main__":
unittest.main()
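Review note: the AC tests above imply three rules: a relative adjustment moves by two degrees, a missing history falls back to a default baseline, and the final value is clamped to the range 16 to 30. The sketch below restates them in isolation. The baseline of 24 is an assumption (the tests only show that baseline minus step equals 22), and the function names are illustrative.

```python
MIN_TEMPERATURE = 16
MAX_TEMPERATURE = 30
DEFAULT_BASELINE = 24  # assumed; the tests only imply baseline - step == 22
STEP = 2


def clamp_temperature(value):
    """Keep the requested temperature inside the supported range."""
    return max(MIN_TEMPERATURE, min(MAX_TEMPERATURE, value))


def lower_a_bit(last_temperature=None):
    """Apply one downward step from the last known value or the default baseline."""
    baseline = DEFAULT_BASELINE if last_temperature is None else last_temperature
    return clamp_temperature(baseline - STEP)


from_history = lower_a_bit(24)
from_default = lower_a_bit()
at_floor = lower_a_bit(16)
```

Clamping after the step means repeated "lower a bit" requests settle at the minimum instead of producing an out-of-range slot value.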


@@ -0,0 +1,93 @@
from __future__ import annotations
import unittest
from app.schemas.intent import IntentDefinition
from app.services.classifier import JointBertIntentClassifier
from app.services.planner import HeuristicWorkflowPlanner
from app.services.router import JointBertSlotExtractor
class FakeJointNLU:
def __init__(self) -> None:
self._predictions = {
"把空调调到22度": {
"intent_id": "cabin_set_ac",
"intent_score": 0.93,
"candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
"slots": {"temperature": 22},
},
"导航去公司然后把空调调到22度": {
"intent_id": "cabin_nav_to",
"intent_score": 0.88,
"candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
"slots": {"destination": "公司"},
},
}
self._slot_predictions = {
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
("导航去公司", "cabin_nav_to"): {"destination": "公司"},
}
def warmup(self, sample_text: str = "打开车窗") -> bool:
_ = sample_text
return True
def predict(self, text: str, intents: list[IntentDefinition]):
from app.services.joint_nlu import JointCandidate, JointNluResult
raw = self._predictions[text]
candidates = [JointCandidate(intent_id=intent_id, score=score) for intent_id, score in raw["candidates"]]
return JointNluResult(
intent_id=raw["intent_id"],
intent_score=raw["intent_score"],
candidates=candidates,
slots=dict(raw["slots"]),
)
def extract_slots(self, text: str, intent: IntentDefinition):
return dict(self._slot_predictions.get((text, intent.intent_id), {}))
def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
_ = required_slots
return dict(self._slot_predictions.get((text, intent_id), {}))
class JointNLUIntegrationTests(unittest.TestCase):
def setUp(self) -> None:
self.intents = [
IntentDefinition(intent_id="cabin_set_ac", plugin_id="x", domain="cabin", required_slots=["temperature"]),
IntentDefinition(intent_id="cabin_nav_to", plugin_id="x", domain="cabin", required_slots=["destination"]),
IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
]
self.fake_nlu = FakeJointNLU()
def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)
result = classifier.predict("把空调调到22度", self.intents)
self.assertIsNotNone(result.intent)
self.assertEqual(result.intent.intent_id, "cabin_set_ac")
self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")
def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
extractor = JointBertSlotExtractor(self.fake_nlu)
slots = extractor.extract("把空调调到22度", self.intents[0])
self.assertEqual(slots, {"temperature": 22})
def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)
result = planner.plan("导航去公司然后把空调调到22度", self.intents)
self.assertTrue(result.accepted)
self.assertEqual(result.steps[0].slots, {"destination": "公司"})
self.assertEqual(result.steps[1].slots, {"temperature": 22})
if __name__ == "__main__":
unittest.main()

@@ -0,0 +1,144 @@
from __future__ import annotations

import unittest

import torch

from app.schemas.intent import IntentDefinition
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector


class FakeTokenizer:
    def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
        _ = (text, truncation, padding, return_tensors)
        return {
            "input_ids": torch.tensor([[101, 102]], dtype=torch.long),
            "attention_mask": torch.tensor([[1, 1]], dtype=torch.long),
        }


class FakeModel:
    def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
        self.config = type("Config", (), {"id2label": id2label})()
        self._logits = torch.tensor([logits], dtype=torch.float32)

    def eval(self) -> None:
        return None

    def __call__(self, **kwargs):
        _ = kwargs
        return type("Output", (), {"logits": self._logits})()


class RuntimeBackedDetector(BertMultiIntentDetector):
    def __init__(
        self,
        logits: list[float],
        id2label: dict[int, str],
        threshold: float = 0.45,
        top_k: int = 8,
        max_labels: int = 4,
    ) -> None:
        super().__init__(
            model_path="unused",
            threshold=threshold,
            top_k=top_k,
            max_labels=max_labels,
        )
        self._runtime = (torch, FakeTokenizer(), FakeModel(logits, id2label))

    def _get_runtime(self):
        return self._runtime


class MultiIntentDetectorTests(unittest.TestCase):
    def test_detector_filters_blocked_and_unknown_labels(self) -> None:
        detector = RuntimeBackedDetector(
            logits=[2.4, 2.0, 3.2, 2.6],
            id2label={
                0: "cabin_window_open",
                1: "cabin_play_music",
                2: "__social__",
                3: "unknown_intent",
            },
            threshold=0.8,
            top_k=4,
        )
        intents = [
            IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
            IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
        ]

        result = detector.detect("打开车窗并播放音乐", intents)

        self.assertTrue(result.detected)
        self.assertEqual(result.backend_name, "bert-multi-label")
        self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])

    def test_detector_respects_threshold_and_max_labels(self) -> None:
        detector = RuntimeBackedDetector(
            logits=[2.8, 2.5, 2.2],
            id2label={
                0: "cabin_window_open",
                1: "cabin_play_music",
                2: "cabin_nav_to",
            },
            threshold=0.89,
            top_k=3,
            max_labels=2,
        )
        intents = [
            IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
            IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
            IntentDefinition(intent_id="cabin_nav_to", plugin_id="plugin.nav", domain="cabin"),
        ]

        result = detector.detect("开窗放歌去公司", intents)

        self.assertTrue(result.detected)
        self.assertEqual(len(result.candidates), 2)
        self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])

    def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
        intents = [
            IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
            IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
        ]

        class FakeJointNlu:
            def __init__(self) -> None:
                self.calls: list[dict[str, object]] = []

            def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
                self.calls.append(
                    {
                        "text": text,
                        "threshold": threshold,
                        "max_labels": max_labels,
                        "top_k": top_k,
                        "known_count": len(known_intents),
                    }
                )
                return [
                    type("Candidate", (), {"intent_id": "cabin_window_open", "score": 0.93})(),
                    type("Candidate", (), {"intent_id": "cabin_play_music", "score": 0.88})(),
                ]

            def warmup(self, sample_text="") -> bool:
                _ = sample_text
                return True

        fake_nlu = FakeJointNlu()
        detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)

        result = detector.detect("打开车窗并播放音乐", intents)

        self.assertTrue(result.detected)
        self.assertEqual(result.backend_name, "joint-bert-multi-label")
        self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
        self.assertEqual(fake_nlu.calls[0]["threshold"], 0.5)
        self.assertEqual(fake_nlu.calls[0]["top_k"], 6)


if __name__ == "__main__":
    unittest.main()

@@ -0,0 +1,195 @@
from __future__ import annotations

import unittest

from app.schemas.debug import IntentCandidate, MatcherStageDebug
from app.schemas.intent import IntentDefinition
from app.services.intent_registry import IntentRegistry
from app.services.router import IntentMatchResult, MultiStageIntentMatcher


class _FakeMatcher:
    def __init__(self, stage_debug: MatcherStageDebug) -> None:
        self._stage_debug = stage_debug

    def match(self, text: str) -> IntentMatchResult:
        _ = text
        return IntentMatchResult(intent=None, stage_debug=self._stage_debug)


def _intent(intent_id: str) -> IntentDefinition:
    return IntentDefinition(
        intent_id=intent_id,
        plugin_id=f"mock.{intent_id}",
        domain="test",
        keywords=[],
        examples=[],
    )


class RouterDecisionTests(unittest.TestCase):
    def setUp(self) -> None:
        self.registry = IntentRegistry([_intent("alpha"), _intent("beta"), _intent("gamma")])

    def test_execute_when_bert_classifier_is_clear(self) -> None:
        matcher = MultiStageIntentMatcher(
            registry=self.registry,
            matchers=[
                _FakeMatcher(
                    MatcherStageDebug(
                        stage="classifier",
                        accepted=True,
                        selected_intent="alpha",
                        score=0.92,
                        reason="classifier selected best candidate",
                        backend="joint-bert-local",
                        candidates=[
                            IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="joint-bert-local"),
                            IntentCandidate(intent_id="beta", score=0.21, reason="classifier", model_name="joint-bert-local"),
                        ],
                    )
                ),
            ],
        )

        result = matcher.match("alpha")

        self.assertEqual(result.debug.decision, "execute")
        self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")

    def test_clarify_when_bert_top_candidates_are_too_close(self) -> None:
        matcher = MultiStageIntentMatcher(
            registry=self.registry,
            matchers=[
                _FakeMatcher(
                    MatcherStageDebug(
                        stage="classifier",
                        accepted=True,
                        selected_intent="alpha",
                        score=0.22,
                        reason="classifier selected best candidate",
                        backend="bert-local",
                        metadata={"threshold": 0.2},
                        candidates=[
                            IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
                            IntentCandidate(intent_id="beta", score=0.28, reason="classifier", model_name="bert-local"),
                        ],
                    )
                ),
            ],
            route_to_cloud_threshold=0.2,
        )

        result = matcher.match("ambiguous request")

        self.assertEqual(result.debug.decision, "clarify")
        self.assertIsNone(result.intent)
        self.assertEqual(result.debug.confidence_grade, "medium")

    def test_route_to_cloud_when_bert_signal_is_weak_but_known(self) -> None:
        matcher = MultiStageIntentMatcher(
            registry=self.registry,
            matchers=[
                _FakeMatcher(
                    MatcherStageDebug(
                        stage="classifier",
                        accepted=False,
                        selected_intent="alpha",
                        score=0.29,
                        reason="classifier below execute threshold",
                        backend="joint-bert-local",
                        candidates=[
                            IntentCandidate(intent_id="alpha", score=0.29, reason="classifier", model_name="joint-bert-local"),
                            IntentCandidate(intent_id="beta", score=0.14, reason="classifier", model_name="joint-bert-local"),
                        ],
                    )
                ),
            ],
        )

        result = matcher.match("weak symbolic request")

        self.assertEqual(result.debug.decision, "route_to_cloud")
        self.assertIsNone(result.intent)

    def test_reject_when_no_branch_has_usable_signal(self) -> None:
        matcher = MultiStageIntentMatcher(
            registry=self.registry,
            matchers=[
                _FakeMatcher(
                    MatcherStageDebug(
                        stage="classifier",
                        accepted=False,
                        score=0.12,
                        reason="classifier below threshold",
                        backend="bert-local",
                        metadata={"threshold": 0.2},
                        candidates=[],
                    )
                ),
            ],
        )

        result = matcher.match("unknown request")

        self.assertEqual(result.debug.decision, "reject")
        self.assertTrue(result.debug.unknown_detected)
        self.assertIsNone(result.intent)

    def test_route_to_cloud_for_low_confidence_classifier_only_bert_signal(self) -> None:
        matcher = MultiStageIntentMatcher(
            registry=self.registry,
            matchers=[
                _FakeMatcher(
                    MatcherStageDebug(
                        stage="classifier",
                        accepted=True,
                        selected_intent="alpha",
                        score=0.31,
                        reason="classifier selected best candidate",
                        backend="bert-local",
                        metadata={"threshold": 0.0, "top_margin": 0.04},
                        candidates=[
                            IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
                            IntentCandidate(intent_id="beta", score=0.27, reason="classifier", model_name="bert-local"),
                        ],
                    )
                ),
            ],
        )

        result = matcher.match("bert only weak request")

        self.assertEqual(result.debug.decision, "route_to_cloud")
        self.assertIsNone(result.intent)

    def test_execute_for_high_confidence_classifier_only_bert_signal(self) -> None:
        matcher = MultiStageIntentMatcher(
            registry=self.registry,
            matchers=[
                _FakeMatcher(
                    MatcherStageDebug(
                        stage="classifier",
                        accepted=True,
                        selected_intent="alpha",
                        score=0.92,
                        reason="classifier selected best candidate",
                        backend="bert-local",
                        metadata={"threshold": 0.0, "top_margin": 0.63},
                        candidates=[
                            IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="bert-local"),
                            IntentCandidate(intent_id="beta", score=0.29, reason="classifier", model_name="bert-local"),
                        ],
                    )
                ),
            ],
        )

        result = matcher.match("bert only strong request")

        self.assertEqual(result.debug.decision, "execute")
        self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")


if __name__ == "__main__":
    unittest.main()

@@ -0,0 +1,180 @@
from __future__ import annotations

import unittest

from app.plugins.base import PluginRegistry
from app.schemas.chat import ChatRequest, FillSlotsRequest
from app.schemas.debug import RoutingDebug
from app.schemas.intent import IntentDefinition
from app.schemas.workflow import Workflow, WorkflowStep
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.session_store import InMemorySessionStore
from app.services.social import SocialReplyResult, SocialRouter


class _FailingRouter:
    def route(self, text: str):  # pragma: no cover - should not be called in these tests
        raise AssertionError(f"router should not be called for social input: {text}")

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
        _ = (text, intent)
        return {}


class _FakeSocialResponder:
    def reply(self, text: str, session) -> SocialReplyResult:
        _ = session
        normalized = text.strip()
        # Pick a canned reply by keyword; the input parameter is left untouched.
        if "你好" in normalized:
            reply_text = "你好呀,我在,想聊什么都可以。"
        elif "名字" in normalized or "你是谁" in normalized:
            reply_text = "我是一名智能座舱助手,你可以直接叫我座舱助手。"
        elif "天气" in normalized:
            reply_text = "是啊,今天确实挺舒服的。"
        else:
            reply_text = "我在,咱们继续聊。"
        return SocialReplyResult(
            text=reply_text,
            backend="fake-cloud",
            model_name="fake-social",
        )


def _intent(intent_id: str, plugin_id: str) -> IntentDefinition:
    return IntentDefinition(
        intent_id=intent_id,
        plugin_id=plugin_id,
        domain="service" if intent_id.startswith("cs_") else "cabin",
        risk_level="high" if intent_id == "cs_cancel_order" else "low",
        required_slots=["order_id"] if intent_id == "cs_cancel_order" else [],
        ask_templates={"order_id": "请告诉我订单号。"} if intent_id == "cs_cancel_order" else {},
        keywords=[],
        examples=[],
    )


class SocialChatTests(unittest.TestCase):
    def setUp(self) -> None:
        self.session_store = InMemorySessionStore()
        self.plugins = PluginRegistry()
        self.plugins.register(
            "mock.cancel_order",
            lambda slots: {"success": True, "message": f"已取消订单 {slots.get('order_id', '')}"},
        )
        self.service = AgentService(
            intent_registry=IntentRegistry([_intent("cs_cancel_order", "mock.cancel_order")]),
            router=_FailingRouter(),
            plugins=self.plugins,
            session_store=self.session_store,
            social_router=SocialRouter(),
            social_responder=_FakeSocialResponder(),
        )

    def test_greeting_social_reply_uses_social_responder(self) -> None:
        response = self.service.handle_chat(
            ChatRequest(
                session_id="sess_social_hi",
                user_id="user_1",
                input_text="你好",
            )
        )

        self.assertEqual(response.decision, "open_social")
        self.assertEqual(response.status, "social")
        self.assertIn("你好呀", response.reply_text)

    def test_capability_social_question_does_not_fall_into_business_intent(self) -> None:
        response = self.service.handle_chat(
            ChatRequest(
                session_id="sess_social_name",
                user_id="user_1",
                input_text="你叫什么名字",
            )
        )

        self.assertEqual(response.decision, "open_social")
        self.assertEqual(response.status, "social")
        self.assertNotEqual(response.reply_type, "ask_slot")
        self.assertIn("智能座舱助手", response.reply_text)

    def test_open_social_reply_uses_social_responder(self) -> None:
        response = self.service.handle_chat(
            ChatRequest(
                session_id="sess_social_open",
                user_id="user_1",
                input_text="今天天气真不错啊",
            )
        )

        self.assertEqual(response.decision, "open_social")
        self.assertEqual(response.status, "social")
        self.assertIn("挺舒服", response.reply_text)

    def test_social_turn_does_not_break_waiting_confirmation(self) -> None:
        session = self.session_store.get_or_create("sess_confirm", "user_1")
        session.current_intent = "cs_cancel_order"
        session.status = "waiting_confirmation"
        session.pending_slots = ["confirmation"]
        session.slots = {"order_id": "A123456"}
        session.routing_debug = RoutingDebug(selected_intent="cs_cancel_order", decision="execute").model_dump()
        session.workflow = Workflow(
            workflow_id="wf_sess_confirm",
            workflow_type="conditional",
            domain="service",
            intent_id="cs_cancel_order",
            status="waiting_confirmation",
            risk_level="high",
            slots={"order_id": "A123456"},
            steps=[
                WorkflowStep(
                    step=1,
                    step_id="step_confirm",
                    intent_id="cs_cancel_order",
                    plugin_id="mock.cancel_order",
                    action="cancel_order",
                    status="waiting_confirmation",
                    slots={"order_id": "A123456"},
                    requires_confirmation=True,
                )
            ],
            meta={
                "pending_confirmation": {
                    "step_id": "step_confirm",
                    "intent_id": "cs_cancel_order",
                    "detail": "确认取消订单 A123456",
                },
                "confirmed_steps": [],
                "step_results": {},
            },
        ).model_dump()
        self.session_store.save(session)

        social_response = self.service.handle_fill_slots(
            FillSlotsRequest(
                session_id="sess_confirm",
                user_id="user_1",
                input_text="今天天气真不错啊",
            )
        )

        self.assertEqual(social_response.decision, "open_social")
        self.assertEqual(social_response.status, "waiting_confirmation")
        self.assertEqual(social_response.pending_slots, ["confirmation"])
        self.assertIn("回复“确认”或“取消”即可", social_response.reply_text)

        confirm_response = self.service.handle_fill_slots(
            FillSlotsRequest(
                session_id="sess_confirm",
                user_id="user_1",
                input_text="确认",
            )
        )

        self.assertEqual(confirm_response.reply_type, "workflow_result")
        self.assertEqual(confirm_response.status, "completed")
        self.assertIn("已取消订单", confirm_response.reply_text)


if __name__ == "__main__":
    unittest.main()

@@ -0,0 +1,184 @@
from __future__ import annotations

import unittest
from pathlib import Path

from app.schemas.configuration import WorkflowTemplatesConfig
from app.services.classifier import ClassificationResult
from app.services.multi_intent_detector import MultiIntentCandidate, MultiIntentDetectionResult
from app.services.planner import CompositeWorkflowPlanner, HeuristicWorkflowPlanner, TemplateWorkflowPlanner
from app.services.intent_registry import IntentRegistry


class FakeClauseClassifier:
    def __init__(self, predictions: dict[str, list[dict[str, float | str]]]) -> None:
        self._predictions = predictions

    def predict(self, text, intents):
        _ = intents
        return ClassificationResult(
            intent=None,
            score=0.0,
            model_name="fake-bert-clause",
            backend_name="fake-bert-clause",
            raw_candidates=self._predictions.get(text, []),
        )


class FakeMultiIntentDetector:
    def __init__(self, predictions: dict[str, list[tuple[str, float]]]) -> None:
        self._predictions = predictions

    def detect(self, text, intents):
        _ = intents
        candidates = [
            MultiIntentCandidate(intent_id=intent_id, score=score, label=intent_id)
            for intent_id, score in self._predictions.get(text, [])
        ]
        return MultiIntentDetectionResult(
            detected=len(candidates) >= 2,
            candidates=candidates,
            reason="fake detector",
            backend_name="fake-multi",
        )


class WorkflowTemplateTests(unittest.TestCase):
    def setUp(self) -> None:
        self.registry = IntentRegistry.from_json("app/data/intents.json")
        self.templates = WorkflowTemplatesConfig.model_validate_json(
            Path("config/workflows.yml").read_text(encoding="utf-8")
        )

    def test_template_planner_matches_sequence_template(self) -> None:
        planner = TemplateWorkflowPlanner(self.templates)
        result = planner.plan("打开车窗然后把空调调到20度", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.backend, "local-template")
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_set_ac"])
        self.assertEqual(result.steps[1].slots.get("temperature"), 20)

    def test_template_planner_matches_conditional_template(self) -> None:
        planner = TemplateWorkflowPlanner(self.templates)
        result = planner.plan("查一下订单A123456如果还没发货就取消", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "conditional")
        self.assertEqual([step.intent_id for step in result.steps], ["cs_query_order", "cs_cancel_order"])
        self.assertEqual(result.steps[1].depends_on, [1])
        self.assertTrue(result.steps[1].requires_confirmation)

    def test_composite_planner_falls_back_to_heuristic_when_template_misses(self) -> None:
        planner = CompositeWorkflowPlanner([TemplateWorkflowPlanner(self.templates), HeuristicWorkflowPlanner()])
        result = planner.plan("打开车窗,并且播放轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertIn(result.backend, {"local-template", "local-heuristic"})
        self.assertEqual(result.workflow_type, "sequence")

    def test_heuristic_planner_parses_ac_then_window_close_sequence(self) -> None:
        planner = HeuristicWorkflowPlanner()
        result = planner.plan("打开空调,再把窗户降下来", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.backend, "local-heuristic")
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_close"])

    def test_planner_metadata_contains_clause_analysis(self) -> None:
        planner = HeuristicWorkflowPlanner()
        result = planner.plan("打开空调,然后打开车窗", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertTrue(result.metadata.get("multi_intent_detected"))
        clause_analysis = result.metadata.get("clause_analysis", [])
        self.assertEqual(len(clause_analysis), 2)
        self.assertEqual(clause_analysis[0].get("selected_intent_id"), "cabin_ac_on")
        self.assertEqual(clause_analysis[1].get("selected_intent_id"), "cabin_window_open")

    def test_heuristic_planner_supports_shared_action_parallel_objects(self) -> None:
        planner = HeuristicWorkflowPlanner()
        result = planner.plan("打开空调和车窗", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_open"])

    def test_heuristic_planner_supports_parallel_objects_with_suffix_action(self) -> None:
        planner = HeuristicWorkflowPlanner()
        result = planner.plan("把车窗和天窗打开", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_sunroof_open"])

    def test_heuristic_planner_supports_parallel_clause_with_bing_connector(self) -> None:
        planner = HeuristicWorkflowPlanner()
        result = planner.plan("打开车窗并播放轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])

    def test_heuristic_planner_can_use_clause_classifier_to_rescue_semantic_clause(self) -> None:
        planner = HeuristicWorkflowPlanner(
            clause_classifier=FakeClauseClassifier(
                {
                    "车里太闷了": [
                        {"label": "cabin_window_open", "intent_id": "cabin_window_open", "score": 0.83},
                    ],
                    "来点轻音乐": [
                        {"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
                    ],
                }
            )
        )
        result = planner.plan("车里太闷了,然后来点轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
        clause_analysis = result.metadata.get("clause_analysis", [])
        self.assertGreater(clause_analysis[0].get("candidates", [])[0].get("model_score", 0.0), 0.8)

    def test_heuristic_planner_can_use_multi_intent_detector_prior(self) -> None:
        planner = HeuristicWorkflowPlanner(
            clause_classifier=FakeClauseClassifier(
                {
                    "来点轻音乐": [
                        {"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
                    ],
                }
            ),
            multi_intent_detector=FakeMultiIntentDetector(
                {
                    "顺便开下车窗,再来点轻音乐": [
                        ("cabin_window_open", 0.87),
                        ("cabin_play_music", 0.82),
                    ]
                }
            ),
        )
        result = planner.plan("顺便开下车窗,再来点轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
        detector_meta = result.metadata.get("multi_intent_detector") or {}
        self.assertTrue(detector_meta.get("detected"))
        self.assertEqual(len(detector_meta.get("candidates", [])), 2)


if __name__ == "__main__":
    unittest.main()