Update project and configurations
This commit is contained in:
2254
intelligent_cabin/archive/demo/demo.html
Normal file
2254
intelligent_cabin/archive/demo/demo.html
Normal file
File diff suppressed because it is too large
Load Diff
1722
intelligent_cabin/archive/docs/current_system_flow.md
Normal file
1722
intelligent_cabin/archive/docs/current_system_flow.md
Normal file
File diff suppressed because it is too large
Load Diff
2104
intelligent_cabin/archive/docs/design.md
Normal file
2104
intelligent_cabin/archive/docs/design.md
Normal file
File diff suppressed because it is too large
Load Diff
467
intelligent_cabin/archive/docs/solution_review.md
Normal file
467
intelligent_cabin/archive/docs/solution_review.md
Normal file
@@ -0,0 +1,467 @@
|
||||
# 对齐车机方案的实施流程复审
|
||||
|
||||
## 1. 目标
|
||||
|
||||
本文档用于重新审视当前项目的实现方向,并将“对齐小鹏车机类方案”的核心流程、分支决策和典型场景演示明确下来。
|
||||
|
||||
目标不是做一个普通聊天机器人,而是做一个面向车机/客服/前台场景的低延迟执行型 Agent:
|
||||
|
||||
- 简单请求本地快速闭环
|
||||
- 复杂请求云端增强处理
|
||||
- 多轮短句能够恢复上下文
|
||||
- 多命令能够拆分为 workflow 执行
|
||||
- 超出能力边界时明确拒绝或澄清
|
||||
|
||||
---
|
||||
|
||||
## 2. 对齐车机方案的核心原则
|
||||
|
||||
结合前面参考的小鹏公开专利,可以抽象出下面几个工程原则:
|
||||
|
||||
- 本地优先,云端增强,不是所有请求都直接走远端大模型
|
||||
- 本地不是单模型,而是多支路并发
|
||||
- 快反馈和最终反馈分离,首反馈必须短而稳
|
||||
- 多轮短句优先靠上下文改写和缓存恢复,不是每轮重做完整规划
|
||||
- 高风险动作必须确认
|
||||
- 超能力边界时必须拒答,不能强行分类执行
|
||||
|
||||
---
|
||||
|
||||
## 3. 总体流程图
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[用户语音输入] --> B[ASR 转文本]
|
||||
B --> C[文本归一化]
|
||||
C --> D[本地多支路并发]
|
||||
|
||||
D --> D1[keyword / rule / trie]
|
||||
D --> D2[local bert classifier]
|
||||
D --> D3[context rewrite / cache]
|
||||
D --> D4[retrieval matcher]
|
||||
|
||||
D1 --> E[本地融合分级器]
|
||||
D2 --> E
|
||||
D3 --> E
|
||||
D4 --> E
|
||||
|
||||
E -->|高置信| F[直接执行或直接生成 workflow]
|
||||
E -->|中置信| G[等待 100~300ms 观察补充分支/云端结果]
|
||||
E -->|低置信| H[云端 planner / LLM / RAG]
|
||||
|
||||
G --> H
|
||||
H --> I[生成 intent / workflow / clarify / reject]
|
||||
|
||||
F --> J[插件执行]
|
||||
I --> J
|
||||
|
||||
J --> K[首反馈 ack / progress]
|
||||
J --> L[最终反馈 result / clarify / reject]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 本地与云端分工
|
||||
|
||||
### 4.1 本地快链路
|
||||
|
||||
本地负责:
|
||||
|
||||
- 高频固定控制类命令
|
||||
- 已知业务集合内的快速意图识别
|
||||
- 短句、省略句、连续调节的上下文恢复
|
||||
- 简单任务的直接执行
|
||||
- `<1s` 以内的首响应
|
||||
|
||||
本地组件:
|
||||
|
||||
- `keyword / rule / trie`
|
||||
- `local bert classifier`
|
||||
- `context rewrite / cache`
|
||||
- `retrieval matcher`
|
||||
- `fusion grader`
|
||||
|
||||
### 4.2 云端慢链路
|
||||
|
||||
云端负责:
|
||||
|
||||
- 多命令拆分
|
||||
- 条件型请求理解
|
||||
- 歧义消解
|
||||
- 复杂问答
|
||||
- planner 级 workflow 生成
|
||||
|
||||
---
|
||||
|
||||
## 5. 本地分级决策图
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[本地多分支结果] --> B{融合分级}
|
||||
B -->|high| C[直接执行]
|
||||
B -->|medium| D[等待 100~300ms]
|
||||
B -->|low| E[不直接执行]
|
||||
|
||||
D --> F{云端是否及时返回}
|
||||
F -->|是| G[采用云端结果或覆盖本地]
|
||||
F -->|否| H{本地是否达到最低执行阈值}
|
||||
|
||||
H -->|是| I[执行本地结果]
|
||||
H -->|否| J[澄清或拒答]
|
||||
|
||||
E --> K[云端理解 / clarify / reject]
|
||||
```
|
||||
|
||||
关键修正点:
|
||||
|
||||
- 不是“本地 BERT 有结果就执行”
|
||||
- 不是“云端没返回就一定执行本地”
|
||||
- 必须先判断本地结果是否达到最低可执行阈值
|
||||
- 高风险动作即使高置信,也不能直接执行
|
||||
|
||||
---
|
||||
|
||||
## 6. 用户反馈状态图
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> Received
|
||||
Received --> Ack: 已接收请求
|
||||
Ack --> ExecutingFast: 本地快执行
|
||||
Ack --> WaitingCloud: 等待云端/复杂规划
|
||||
Ack --> Clarify: 缺关键槽位
|
||||
Ack --> Confirm: 高风险动作
|
||||
Ack --> Reject: 超能力边界
|
||||
|
||||
ExecutingFast --> Result
|
||||
WaitingCloud --> Result
|
||||
Clarify --> Result
|
||||
Confirm --> Result
|
||||
Reject --> [*]
|
||||
Result --> [*]
|
||||
```
|
||||
|
||||
反馈规则:
|
||||
|
||||
- `ack`:收到,马上处理
|
||||
- `progress`:正在为你处理
|
||||
- `result`:执行完成或查询完成
|
||||
- `clarify`:信息不足,补一个关键字段
|
||||
- `confirm`:高风险动作确认
|
||||
- `reject`:能力边界拒答
|
||||
|
||||
---
|
||||
|
||||
## 7. 反馈模板策略
|
||||
|
||||
### 7.1 快执行
|
||||
|
||||
适用:
|
||||
|
||||
- 打开车窗
|
||||
- 调低空调
|
||||
- 播放音乐
|
||||
- 导航去公司
|
||||
|
||||
推荐反馈:
|
||||
|
||||
- 首反馈:`好的,正在打开车窗`
|
||||
- 最终反馈:`车窗已打开`
|
||||
|
||||
如果设备动作极快,也可以直接播报最终反馈:
|
||||
|
||||
- `车窗已打开`
|
||||
|
||||
### 7.2 慢执行
|
||||
|
||||
适用:
|
||||
|
||||
- 查订单
|
||||
- 查物流
|
||||
- 多命令复杂规划
|
||||
- 条件型任务
|
||||
|
||||
推荐反馈:
|
||||
|
||||
- 首反馈:`收到,我先帮你查一下`
|
||||
- 最终反馈:`订单还没发货`
|
||||
|
||||
### 7.3 复合命令
|
||||
|
||||
例如:
|
||||
|
||||
- `打开车窗,空调调低至20度`
|
||||
|
||||
推荐反馈:
|
||||
|
||||
- 首反馈:`好的,正在为你打开车窗并调低空调`
|
||||
- 最终反馈:`车窗已打开,空调已调到20度`
|
||||
|
||||
### 7.4 边界外请求
|
||||
|
||||
例如:
|
||||
|
||||
- `打开飞机门`
|
||||
|
||||
推荐反馈:
|
||||
|
||||
- `这个我暂时做不了,但我可以帮你导航、查订单、调空调或播放音乐`
|
||||
|
||||
---
|
||||
|
||||
## 8. 多轮上下文恢复流程
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[当前轮输入: 再低一点] --> B[读取 session context]
|
||||
B --> C{是否命中高频改写缓存}
|
||||
C -->|是| D[改写为完整句]
|
||||
C -->|否| E[轻量改写模型/规则补全]
|
||||
D --> F[进入 router]
|
||||
E --> F
|
||||
F --> G[意图识别 + 槽位提取 + 执行]
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 这里的“上下文能力”不是 BERT 自己缓存的
|
||||
- 而是 Agent 在 `session state` 中保存上轮任务和关键槽位
|
||||
- 再由 `rewrite engine` 完成短句恢复
|
||||
|
||||
---
|
||||
|
||||
## 9. 典型场景流程演示
|
||||
|
||||
### 9.1 场景一:快执行单命令
|
||||
|
||||
用户输入:
|
||||
|
||||
- `打开车窗`
|
||||
|
||||
处理流程:
|
||||
|
||||
1. ASR 转文本:`打开车窗`
|
||||
2. 文本归一化
|
||||
3. 本地多支路并发
|
||||
4. 若本地高置信命中 `cabin_open_window`
|
||||
5. 直接执行车窗插件
|
||||
6. 播报:`车窗已打开`
|
||||
|
||||
时序演示:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant U as 用户
|
||||
participant A as ASR
|
||||
participant R as 本地路由
|
||||
participant P as 插件执行
|
||||
participant T as TTS
|
||||
|
||||
U->>A: 打开车窗
|
||||
A->>R: 打开车窗
|
||||
R->>P: 执行 open_window
|
||||
P-->>R: success
|
||||
R->>T: 车窗已打开
|
||||
T-->>U: 车窗已打开
|
||||
```
|
||||
|
||||
### 9.2 场景二:复合命令快执行
|
||||
|
||||
用户输入:
|
||||
|
||||
- `打开车窗,空调调低至20度`
|
||||
|
||||
处理流程:
|
||||
|
||||
1. 文本进入本地路由
|
||||
2. 判断为多命令
|
||||
3. planner 或本地 splitter 输出两个 step
|
||||
4. 生成 sequence workflow
|
||||
5. 顺序执行:
|
||||
- 打开车窗
|
||||
- 空调调到 20 度
|
||||
6. 汇总反馈:`车窗已打开,空调已调到20度`
|
||||
|
||||
时序演示:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant U as 用户
|
||||
participant A as ASR
|
||||
participant F as 融合分级器
|
||||
participant W as Workflow
|
||||
participant P as 插件层
|
||||
participant T as TTS
|
||||
|
||||
U->>A: 打开车窗,空调调低至20度
|
||||
A->>F: 规范化文本
|
||||
F->>W: 输出 sequence workflow
|
||||
W->>P: step1 open_window
|
||||
P-->>W: success
|
||||
W->>P: step2 set_ac(20)
|
||||
P-->>W: success
|
||||
W->>T: 车窗已打开,空调已调到20度
|
||||
T-->>U: 车窗已打开,空调已调到20度
|
||||
```
|
||||
|
||||
### 9.3 场景三:慢执行查询
|
||||
|
||||
用户输入:
|
||||
|
||||
- `帮我查一下订单A123456`
|
||||
|
||||
处理流程:
|
||||
|
||||
1. 本地高频分支命中订单查询
|
||||
2. 首反馈先给:
|
||||
- `收到,我帮你查一下`
|
||||
3. 调用订单查询插件
|
||||
4. 最终反馈:
|
||||
- `订单A123456当前待发货`
|
||||
|
||||
### 9.4 场景四:条件型请求
|
||||
|
||||
用户输入:
|
||||
|
||||
- `查一下订单A123456,如果还没发货就取消`
|
||||
|
||||
处理流程:
|
||||
|
||||
1. 本地识别该请求复杂,进入云端 planner
|
||||
2. planner 输出 conditional workflow:
|
||||
- step1: query_order
|
||||
- step2: cancel_order
|
||||
- condition: order_status == pending_shipment
|
||||
3. 先执行 step1
|
||||
4. 若满足条件,则进入确认
|
||||
5. 用户回复确认后,再执行取消
|
||||
6. 最终反馈:
|
||||
- `订单A123456已取消`
|
||||
|
||||
时序演示:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant U as 用户
|
||||
participant R as 本地融合器
|
||||
participant C as 云端 Planner
|
||||
participant W as Workflow
|
||||
participant P as 插件层
|
||||
participant T as TTS
|
||||
|
||||
U->>R: 查一下订单A123456,如果还没发货就取消
|
||||
R->>C: 复杂条件请求
|
||||
C-->>W: conditional workflow
|
||||
W->>P: step1 query_order
|
||||
P-->>W: order_status=pending_shipment
|
||||
W->>T: 即将取消订单,仅在订单未发货时取消。请回复确认或取消
|
||||
T-->>U: 确认提示
|
||||
U->>W: 确认
|
||||
W->>P: step2 cancel_order
|
||||
P-->>W: success
|
||||
W->>T: 订单A123456已取消
|
||||
T-->>U: 订单A123456已取消
|
||||
```
|
||||
|
||||
### 9.5 场景五:多轮短句恢复
|
||||
|
||||
对话过程:
|
||||
|
||||
1. 用户:`把空调调到22度`
|
||||
2. 系统:`空调已调到22度`
|
||||
3. 用户:`再低一点`
|
||||
4. 系统读取 `last_intent=cabin_set_ac` 和 `last_temperature=22`
|
||||
5. rewrite engine 改写为:`把空调调到21度`
|
||||
6. 再进入意图识别和执行
|
||||
7. 系统反馈:`空调已调到21度`
|
||||
|
||||
### 9.6 场景六:边界外请求
|
||||
|
||||
用户输入:
|
||||
|
||||
- `打开飞机门`
|
||||
|
||||
正确处理:
|
||||
|
||||
1. 本地分支都无法稳定支持
|
||||
2. 若低于执行阈值,不得直接执行已有意图
|
||||
3. 进入:
|
||||
- reject
|
||||
- 或云端澄清
|
||||
4. 反馈:
|
||||
- `这个我暂时做不了,但我可以帮你导航、查订单、调空调或播放音乐`
|
||||
|
||||
这类场景必须通过 `unknown / out_of_scope` 机制处理,不能靠封闭集分类硬选。
|
||||
|
||||
---
|
||||
|
||||
## 10. 当前项目与目标方案的对应关系
|
||||
|
||||
### 10.1 已经具备的能力
|
||||
|
||||
- 本地 `keyword / classifier / retrieval / fusion`
|
||||
- 本地 BERT 分类器
|
||||
- `session state`
|
||||
- `context rewrite`
|
||||
- `planner`
|
||||
- `sequence / conditional workflow`
|
||||
- 高风险确认
|
||||
- demo 调试面板
|
||||
|
||||
### 10.2 还需要补齐的关键能力
|
||||
|
||||
- `unknown / out_of_scope`
|
||||
- 低分拒识策略
|
||||
- 明确的 `execute / reject / route_to_cloud` 决策建议
|
||||
- 更多真实车机意图
|
||||
- 真实插件接入
|
||||
- 语音前端与 ASR/TTS 完整接入
|
||||
|
||||
---
|
||||
|
||||
## 11. 当前阶段的正式实施结论
|
||||
|
||||
当前方向可以继续,但必须明确:
|
||||
|
||||
- 本地 BERT 是本地快分支之一,不是整个系统的唯一裁决者
|
||||
- 最终执行依据应来自“本地融合分级器 + planner + 风险规则”
|
||||
- 用户体验的关键不只是识别正确,还包括:
|
||||
- 首反馈是否快
|
||||
- 多轮是否顺
|
||||
- 边界是否清楚
|
||||
- 风险是否可控
|
||||
|
||||
因此,当前正式方案应定义为:
|
||||
|
||||
```text
|
||||
车机型 Agent = 本地并发快链路
|
||||
+ 上下文改写缓存
|
||||
+ 分级融合决策
|
||||
+ 云端 planner
|
||||
+ workflow 执行
|
||||
+ 风险确认
|
||||
+ reject / clarify / fallback 策略
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. 下一步执行优先级
|
||||
|
||||
建议按以下顺序继续实现:
|
||||
|
||||
1. 补 `unknown / out_of_scope` 和拒识阈值
|
||||
2. 输出统一执行建议:
|
||||
- `execute`
|
||||
- `clarify`
|
||||
- `reject`
|
||||
- `route_to_cloud`
|
||||
3. 扩真实车机场景意图:
|
||||
- 车窗
|
||||
- 车门
|
||||
- 座椅
|
||||
- 灯光
|
||||
- 后视镜
|
||||
- 除雾
|
||||
4. 强化 rewrite/cache 高频模式
|
||||
5. 接真实插件与真实语音链路
|
||||
|
||||
BIN
intelligent_cabin/archive/patents/202510261979_CN120089140A.pdf
Normal file
BIN
intelligent_cabin/archive/patents/202510261979_CN120089140A.pdf
Normal file
Binary file not shown.
BIN
intelligent_cabin/archive/patents/CN114299931B.pdf
Normal file
BIN
intelligent_cabin/archive/patents/CN114299931B.pdf
Normal file
Binary file not shown.
BIN
intelligent_cabin/archive/patents/CN115394300A.pdf
Normal file
BIN
intelligent_cabin/archive/patents/CN115394300A.pdf
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
1095
intelligent_cabin/archive/patents/texts/CN114299931B.txt
Normal file
1095
intelligent_cabin/archive/patents/texts/CN114299931B.txt
Normal file
File diff suppressed because it is too large
Load Diff
1814
intelligent_cabin/archive/patents/texts/CN115394300A.txt
Normal file
1814
intelligent_cabin/archive/patents/texts/CN115394300A.txt
Normal file
File diff suppressed because it is too large
Load Diff
65
intelligent_cabin/archive/reports/bert_local_test_report.md
Normal file
65
intelligent_cabin/archive/reports/bert_local_test_report.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# 本地 BERT 意图识别测试报告
|
||||
|
||||
## 概览
|
||||
- 模型目录:`/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_bert_intent`
|
||||
- 评测集:`/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/bert_intent_eval_independent.jsonl`
|
||||
- 评测阈值:`0.0`
|
||||
- 测试样本数:`42`
|
||||
- 总体准确率:`0.9762`
|
||||
|
||||
## 训练摘要
|
||||
- 基座模型:`hfl/chinese-macbert-base`
|
||||
- 训练集 / 验证集:`1557 / 401`
|
||||
- 最佳验证准确率:`0.9875`
|
||||
- 训练设备:`mps`
|
||||
|
||||
## 分类别结果
|
||||
- `business`: 33/34 = 0.9706
|
||||
- `out_of_scope`: 4/4 = 1.0
|
||||
- `social`: 4/4 = 1.0
|
||||
|
||||
## 分标签结果
|
||||
- `__out_of_scope__` (out_of_scope): 4/4 = 1.0
|
||||
- `__social__` (social): 4/4 = 1.0
|
||||
- `cabin_ac_off` (business): 1/1 = 1.0
|
||||
- `cabin_ac_on` (business): 1/1 = 1.0
|
||||
- `cabin_defog_front_on` (business): 1/1 = 1.0
|
||||
- `cabin_defog_rear_on` (business): 1/1 = 1.0
|
||||
- `cabin_fan_down` (business): 1/1 = 1.0
|
||||
- `cabin_fan_up` (business): 1/1 = 1.0
|
||||
- `cabin_lights_off` (business): 1/1 = 1.0
|
||||
- `cabin_lights_on` (business): 1/1 = 1.0
|
||||
- `cabin_lock_doors` (business): 1/1 = 1.0
|
||||
- `cabin_mirror_fold` (business): 1/1 = 1.0
|
||||
- `cabin_mirror_unfold` (business): 1/1 = 1.0
|
||||
- `cabin_nav_cancel` (business): 1/1 = 1.0
|
||||
- `cabin_nav_to` (business): 1/1 = 1.0
|
||||
- `cabin_next_track` (business): 1/1 = 1.0
|
||||
- `cabin_pause_music` (business): 1/1 = 1.0
|
||||
- `cabin_play_music` (business): 1/1 = 1.0
|
||||
- `cabin_previous_track` (business): 1/1 = 1.0
|
||||
- `cabin_seat_heat_off` (business): 1/1 = 1.0
|
||||
- `cabin_seat_heat_on` (business): 1/1 = 1.0
|
||||
- `cabin_set_ac` (business): 1/1 = 1.0
|
||||
- `cabin_sunroof_close` (business): 1/1 = 1.0
|
||||
- `cabin_sunroof_open` (business): 1/1 = 1.0
|
||||
- `cabin_unlock_doors` (business): 1/1 = 1.0
|
||||
- `cabin_volume_down` (business): 1/1 = 1.0
|
||||
- `cabin_volume_mute` (business): 1/1 = 1.0
|
||||
- `cabin_volume_up` (business): 1/1 = 1.0
|
||||
- `cabin_window_close` (business): 1/1 = 1.0
|
||||
- `cabin_window_open` (business): 0/1 = 0.0
|
||||
- `cabin_wiper_off` (business): 1/1 = 1.0
|
||||
- `cabin_wiper_on` (business): 1/1 = 1.0
|
||||
- `cs_cancel_order` (business): 1/1 = 1.0
|
||||
- `cs_query_logistics` (business): 1/1 = 1.0
|
||||
- `cs_query_order` (business): 1/1 = 1.0
|
||||
- `cs_transfer_human` (business): 1/1 = 1.0
|
||||
|
||||
## 错误样例
|
||||
- 文本:`左前窗打开一点` | 类别:`business` | 期望:`cabin_window_open` | 预测:`cabin_defog_front_on` | 分数:`0.9951`
|
||||
|
||||
## 结论
|
||||
- 当前本地 MacBERT 已具备较强的业务意图识别能力,可作为本地快链路分类器。
|
||||
- 误判主要集中在方向相反或语义接近的控制指令,下一步应补充对抗样本和真实口语表达。
|
||||
- 上线前建议继续补充 ASR 错字、多轮短句和多意图子句级样本。
|
||||
426
intelligent_cabin/archive/reports/bert_local_test_result.json
Normal file
426
intelligent_cabin/archive/reports/bert_local_test_result.json
Normal file
@@ -0,0 +1,426 @@
|
||||
{
|
||||
"model_dir": "/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_bert_intent",
|
||||
"threshold": 0.0,
|
||||
"test_path": "/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/bert_intent_eval_independent.jsonl",
|
||||
"test_case_count": 42,
|
||||
"accuracy": 0.9762,
|
||||
"train_summary": {
|
||||
"base_model": "hfl/chinese-macbert-base",
|
||||
"epochs": 16,
|
||||
"batch_size": 8,
|
||||
"learning_rate": 2e-05,
|
||||
"train_size": 1557,
|
||||
"dev_size": 401,
|
||||
"best_dev_accuracy": 0.9875,
|
||||
"device": "mps"
|
||||
},
|
||||
"per_category": [
|
||||
{
|
||||
"category": "business",
|
||||
"total": 34,
|
||||
"correct": 33,
|
||||
"accuracy": 0.9706
|
||||
},
|
||||
{
|
||||
"category": "out_of_scope",
|
||||
"total": 4,
|
||||
"correct": 4,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"category": "social",
|
||||
"total": 4,
|
||||
"correct": 4,
|
||||
"accuracy": 1.0
|
||||
}
|
||||
],
|
||||
"per_label": [
|
||||
{
|
||||
"label": "__out_of_scope__",
|
||||
"category": "out_of_scope",
|
||||
"total": 4,
|
||||
"correct": 4,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "__social__",
|
||||
"category": "social",
|
||||
"total": 4,
|
||||
"correct": 4,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_ac_off",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_ac_on",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_defog_front_on",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_defog_rear_on",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_fan_down",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_fan_up",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_lights_off",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_lights_on",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_lock_doors",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_mirror_fold",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_mirror_unfold",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_nav_cancel",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_nav_to",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_next_track",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_pause_music",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_play_music",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_previous_track",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_seat_heat_off",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_seat_heat_on",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_set_ac",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_sunroof_close",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_sunroof_open",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_unlock_doors",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_volume_down",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_volume_mute",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_volume_up",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_window_close",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_window_open",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 0,
|
||||
"accuracy": 0.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_wiper_off",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cabin_wiper_on",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cs_cancel_order",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cs_query_logistics",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cs_query_order",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
},
|
||||
{
|
||||
"label": "cs_transfer_human",
|
||||
"category": "business",
|
||||
"total": 1,
|
||||
"correct": 1,
|
||||
"accuracy": 1.0
|
||||
}
|
||||
],
|
||||
"errors": [
|
||||
{
|
||||
"text": "左前窗打开一点",
|
||||
"category": "business",
|
||||
"expected_label": "cabin_window_open",
|
||||
"predicted_label": "cabin_defog_front_on",
|
||||
"score": 0.9951,
|
||||
"raw_label": "cabin_defog_front_on",
|
||||
"ok": false,
|
||||
"top_candidates": [
|
||||
{
|
||||
"intent_id": "cabin_defog_front_on",
|
||||
"score": 0.9951
|
||||
},
|
||||
{
|
||||
"intent_id": "cabin_sunroof_open",
|
||||
"score": 0.0005
|
||||
},
|
||||
{
|
||||
"intent_id": "cabin_lights_on",
|
||||
"score": 0.0004
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"confusion": {
|
||||
"cabin_ac_off": {
|
||||
"cabin_ac_off": 1
|
||||
},
|
||||
"cabin_ac_on": {
|
||||
"cabin_ac_on": 1
|
||||
},
|
||||
"cabin_defog_front_on": {
|
||||
"cabin_defog_front_on": 1
|
||||
},
|
||||
"cabin_defog_rear_on": {
|
||||
"cabin_defog_rear_on": 1
|
||||
},
|
||||
"cabin_fan_down": {
|
||||
"cabin_fan_down": 1
|
||||
},
|
||||
"cabin_fan_up": {
|
||||
"cabin_fan_up": 1
|
||||
},
|
||||
"cabin_lights_off": {
|
||||
"cabin_lights_off": 1
|
||||
},
|
||||
"cabin_lights_on": {
|
||||
"cabin_lights_on": 1
|
||||
},
|
||||
"cabin_lock_doors": {
|
||||
"cabin_lock_doors": 1
|
||||
},
|
||||
"cabin_mirror_fold": {
|
||||
"cabin_mirror_fold": 1
|
||||
},
|
||||
"cabin_mirror_unfold": {
|
||||
"cabin_mirror_unfold": 1
|
||||
},
|
||||
"cabin_nav_cancel": {
|
||||
"cabin_nav_cancel": 1
|
||||
},
|
||||
"cabin_nav_to": {
|
||||
"cabin_nav_to": 1
|
||||
},
|
||||
"cabin_next_track": {
|
||||
"cabin_next_track": 1
|
||||
},
|
||||
"cabin_pause_music": {
|
||||
"cabin_pause_music": 1
|
||||
},
|
||||
"cabin_play_music": {
|
||||
"cabin_play_music": 1
|
||||
},
|
||||
"cabin_previous_track": {
|
||||
"cabin_previous_track": 1
|
||||
},
|
||||
"cabin_seat_heat_off": {
|
||||
"cabin_seat_heat_off": 1
|
||||
},
|
||||
"cabin_seat_heat_on": {
|
||||
"cabin_seat_heat_on": 1
|
||||
},
|
||||
"cabin_set_ac": {
|
||||
"cabin_set_ac": 1
|
||||
},
|
||||
"cabin_sunroof_close": {
|
||||
"cabin_sunroof_close": 1
|
||||
},
|
||||
"cabin_sunroof_open": {
|
||||
"cabin_sunroof_open": 1
|
||||
},
|
||||
"cabin_unlock_doors": {
|
||||
"cabin_unlock_doors": 1
|
||||
},
|
||||
"cabin_volume_down": {
|
||||
"cabin_volume_down": 1
|
||||
},
|
||||
"cabin_volume_mute": {
|
||||
"cabin_volume_mute": 1
|
||||
},
|
||||
"cabin_volume_up": {
|
||||
"cabin_volume_up": 1
|
||||
},
|
||||
"cabin_window_close": {
|
||||
"cabin_window_close": 1
|
||||
},
|
||||
"cabin_window_open": {
|
||||
"cabin_defog_front_on": 1
|
||||
},
|
||||
"cabin_wiper_off": {
|
||||
"cabin_wiper_off": 1
|
||||
},
|
||||
"cabin_wiper_on": {
|
||||
"cabin_wiper_on": 1
|
||||
},
|
||||
"cs_cancel_order": {
|
||||
"cs_cancel_order": 1
|
||||
},
|
||||
"cs_query_logistics": {
|
||||
"cs_query_logistics": 1
|
||||
},
|
||||
"cs_query_order": {
|
||||
"cs_query_order": 1
|
||||
},
|
||||
"cs_transfer_human": {
|
||||
"cs_transfer_human": 1
|
||||
},
|
||||
"__social__": {
|
||||
"__social__": 4
|
||||
},
|
||||
"__out_of_scope__": {
|
||||
"__out_of_scope__": 4
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
# 本地多标签 Detector 独立评测报告
|
||||
|
||||
## 概览
|
||||
- 模型目录:`/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_bert_multi_intent`
|
||||
- 评测集:`/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/bert_intent_multilabel_eval_independent.jsonl`
|
||||
- 样本数:`37`
|
||||
- 阈值 / top_k / max_labels:`0.45 / 8 / 4`
|
||||
- `micro_precision`:`0.9362`
|
||||
- `micro_recall`:`0.6377`
|
||||
- `micro_f1`:`0.7586`
|
||||
- `exact_match`:`0.5135`
|
||||
- `multi_sentence_recall`:`0.4138`
|
||||
- `single_guard_false_alarm_rate`:`0.0`
|
||||
|
||||
## 分类别结果
|
||||
- `cabin_parallel`: count=15 micro_f1=0.807 exact_match=0.4667
|
||||
- `cabin_sequence`: count=9 micro_f1=0.5385 exact_match=0.3333
|
||||
- `cs_conditional`: count=3 micro_f1=0.9091 exact_match=0.6667
|
||||
- `cs_sequence`: count=2 micro_f1=0.6667 exact_match=0.0
|
||||
- `single_guard`: count=8 micro_f1=0.875 exact_match=0.875
|
||||
|
||||
## 主要混淆
|
||||
- 漏掉 `cabin_sunroof_open`,同时误报 `cabin_window_open`:`1` 次
|
||||
- 漏掉 `cabin_pause_music`,同时误报 `cabin_play_music`:`1` 次
|
||||
- 漏掉 `cabin_window_open`,同时误报 `cabin_defog_front_on`:`1` 次
|
||||
|
||||
## 错误样例
|
||||
- 文本:`锁车门,再把后视镜收起来` | 类别:`cabin_sequence` | 期望:`['cabin_lock_doors', 'cabin_mirror_fold']` | 预测:`[]`
|
||||
- 文本:`把车门解锁,再把镜子展开` | 类别:`cabin_sequence` | 期望:`['cabin_mirror_unfold', 'cabin_unlock_doors']` | 预测:`[]`
|
||||
- 文本:`路线别导了,音乐也停一下` | 类别:`cabin_parallel` | 期望:`['cabin_nav_cancel', 'cabin_pause_music']` | 预测:`[]`
|
||||
- 文本:`雨停了,雨刮关掉,再把窗开一点` | 类别:`cabin_sequence` | 期望:`['cabin_window_open', 'cabin_wiper_off']` | 预测:`[]`
|
||||
- 文本:`把天窗合上,然后把音乐暂停` | 类别:`cabin_sequence` | 期望:`['cabin_pause_music', 'cabin_sunroof_close']` | 预测:`[]`
|
||||
- 文本:`先把音量调大,再切下一首` | 类别:`cabin_parallel` | 期望:`['cabin_next_track', 'cabin_volume_up']` | 预测:`[]`
|
||||
- 文本:`静音之后切回上一首` | 类别:`cabin_sequence` | 期望:`['cabin_previous_track', 'cabin_volume_mute']` | 预测:`[]`
|
||||
- 文本:`把天窗打开透口气,再开空调` | 类别:`cabin_parallel` | 期望:`['cabin_ac_on', 'cabin_sunroof_open']` | 预测:`['cabin_ac_on', 'cabin_window_open']`
|
||||
- 文本:`音乐停一下,然后导航到公司` | 类别:`cabin_sequence` | 期望:`['cabin_nav_to', 'cabin_pause_music']` | 预测:`['cabin_nav_to', 'cabin_play_music']`
|
||||
- 文本:`把左前窗降一点` | 类别:`single_guard` | 期望:`['cabin_window_open']` | 预测:`['cabin_defog_front_on']`
|
||||
- 文本:`车里闷,给我透个气,再放点轻松的歌` | 类别:`cabin_parallel` | 期望:`['cabin_play_music', 'cabin_window_open']` | 预测:`['cabin_play_music']`
|
||||
- 文本:`把空调开了,风别太小,再来首歌` | 类别:`cabin_parallel` | 期望:`['cabin_ac_on', 'cabin_fan_up', 'cabin_play_music']` | 预测:`['cabin_ac_on', 'cabin_play_music']`
|
||||
- 文本:`开导航去徐家汇,顺便把风量调大` | 类别:`cabin_parallel` | 期望:`['cabin_fan_up', 'cabin_nav_to']` | 预测:`['cabin_nav_to']`
|
||||
- 文本:`温度调到二十三度,风稍微小一点` | 类别:`cabin_parallel` | 期望:`['cabin_fan_down', 'cabin_set_ac']` | 预测:`['cabin_set_ac']`
|
||||
- 文本:`帮我看A812302物流,要是太慢就转人工` | 类别:`cs_conditional` | 期望:`['cs_query_logistics', 'cs_transfer_human']` | 预测:`['cs_query_logistics']`
|
||||
|
||||
## 结论建议
|
||||
- 先看多意图句是否存在系统性漏召回,再看单意图是否被误报成多意图。
|
||||
- 若 `single_guard_false_alarm_rate` 偏高,需要先收紧 detector 阈值或补单意图负样本,再考虑进入 NER。
|
||||
- 若 `multi_sentence_recall` 不稳定,应继续补条件句、弱连接句和口语化多动作语料。
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,40 @@
|
||||
# Joint NLU 独立评测报告
|
||||
|
||||
## 概览
|
||||
- 模型目录:`/Users/hwp/Documents/trae_projects/intelligent_cabin/models/local_joint_bert_nlu`
|
||||
- 评测集:`/Users/hwp/Documents/trae_projects/intelligent_cabin/app/data/joint_nlu_eval_independent.jsonl`
|
||||
- 样本数:`43`
|
||||
- `intent_accuracy`:`0.9302`
|
||||
- `slot_exact_match`:`1.0`
|
||||
- `joint_exact_match`:`0.9302`
|
||||
- `slot_micro_precision`:`1.0`
|
||||
- `slot_micro_recall`:`1.0`
|
||||
- `slot_micro_f1`:`1.0`
|
||||
|
||||
## 训练摘要
|
||||
- 训练集 / 评测集:`337 / 10`
|
||||
- 训练阶段 `intent_accuracy`:`1.0`
|
||||
- 训练阶段 `slot_exact_match`:`0.8`
|
||||
|
||||
## 分类别结果
|
||||
- `failure_replay`: count=12 intent_acc=0.75 slot_exact=1.0 joint_exact=0.75
|
||||
- `no_slot_control`: count=14 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
|
||||
- `slot_destination`: count=4 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
|
||||
- `slot_music`: count=5 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
|
||||
- `slot_order`: count=4 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
|
||||
- `slot_temperature`: count=4 intent_acc=1.0 slot_exact=1.0 joint_exact=1.0
|
||||
|
||||
## 主要意图混淆
|
||||
- 期望 `cabin_window_open`,预测成 `None`:`1` 次
|
||||
- 期望 `cabin_window_open`,预测成 `cabin_play_music`:`1` 次
|
||||
- 期望 `cabin_fan_up`,预测成 `cabin_fan_down`:`1` 次
|
||||
|
||||
## 失败样例回放
|
||||
- 文本:`把左前窗降一点` | 类别:`failure_replay` | 期望意图:`cabin_window_open` | 预测意图:`None` | 期望槽位:`{}` | 预测槽位:`{}` | 缺失槽位:`[]` | 多出槽位:`[]`
|
||||
- 文本:`给我透个气` | 类别:`failure_replay` | 期望意图:`cabin_window_open` | 预测意图:`cabin_play_music` | 期望槽位:`{}` | 预测槽位:`{}` | 缺失槽位:`[]` | 多出槽位:`[]`
|
||||
- 文本:`风别太小` | 类别:`failure_replay` | 期望意图:`cabin_fan_up` | 预测意图:`cabin_fan_down` | 期望槽位:`{}` | 预测槽位:`{}` | 缺失槽位:`[]` | 多出槽位:`[]`
|
||||
|
||||
## 结论
|
||||
- 先看 `failure_replay` 是否仍然错,能直接判断先前多意图失败到底是联合模型本体问题还是上层组合问题。
|
||||
- 若 `slot_music` 或 `slot_destination` 仍不稳,优先补 span 标注,不要回退到规则抽槽。
|
||||
- 若 `no_slot_control` 很稳但 `failure_replay` 中仍有大量错误,下一步应补长尾控制语义数据,而不是急着上更复杂结构。
|
||||
1706
intelligent_cabin/archive/reports/joint_nlu_independent_result.json
Normal file
1706
intelligent_cabin/archive/reports/joint_nlu_independent_result.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
DEFAULT_TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_eval_independent.jsonl"
|
||||
|
||||
|
||||
def load_cases(path: Path) -> list[dict[str, object]]:
|
||||
rows: list[dict[str, object]] = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
rows.append(payload)
|
||||
return rows
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Joint BERT 多意图独立评测")
|
||||
parser.add_argument("--model-path", type=str, default="models/local_joint_bert_nlu")
|
||||
parser.add_argument("--test-path", type=str, default=str(DEFAULT_TEST_PATH))
|
||||
args = parser.parse_args()
|
||||
|
||||
registry = build_intent_registry()
|
||||
nlu = JointBertNLU(model_path=args.model_path)
|
||||
cases = load_cases(Path(args.test_path))
|
||||
|
||||
tp = 0
|
||||
fp = 0
|
||||
fn = 0
|
||||
exact = 0
|
||||
failures: list[dict[str, object]] = []
|
||||
category_correct: Counter[str] = Counter()
|
||||
category_total: Counter[str] = Counter()
|
||||
|
||||
for case in cases:
|
||||
text = str(case["text"])
|
||||
expected = sorted({str(item) for item in case.get("expected_intent_ids", [])})
|
||||
predicted = sorted(item.intent_id for item in nlu.predict_multi_intents(text, registry.list(), top_k=8, max_labels=4))
|
||||
expected_set = set(expected)
|
||||
predicted_set = set(predicted)
|
||||
tp += len(expected_set & predicted_set)
|
||||
fp += len(predicted_set - expected_set)
|
||||
fn += len(expected_set - predicted_set)
|
||||
category = str(case.get("category") or "unknown")
|
||||
category_total[category] += 1
|
||||
if expected_set == predicted_set:
|
||||
exact += 1
|
||||
category_correct[category] += 1
|
||||
else:
|
||||
failures.append(
|
||||
{
|
||||
"text": text,
|
||||
"expected_intent_ids": expected,
|
||||
"predicted_intent_ids": predicted,
|
||||
"category": category,
|
||||
}
|
||||
)
|
||||
|
||||
precision = tp / max(tp + fp, 1)
|
||||
recall = tp / max(tp + fn, 1)
|
||||
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
result = {
|
||||
"sample_count": len(cases),
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(f1, 4),
|
||||
"exact_match": round(exact / max(len(cases), 1), 4),
|
||||
"per_category_exact_match": {
|
||||
category: round(category_correct[category] / max(total, 1), 4)
|
||||
for category, total in sorted(category_total.items())
|
||||
},
|
||||
"failures": failures[:20],
|
||||
}
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
59
intelligent_cabin/archive/scripts/eval_joint_bert_nlu.py
Normal file
59
intelligent_cabin/archive/scripts/eval_joint_bert_nlu.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="评测 Joint BERT NLU 单句意图与槽位输出")
|
||||
parser.add_argument("--text", type=str, required=True, help="待评测文本")
|
||||
parser.add_argument("--model-path", type=str, default="models/local_joint_bert_nlu", help="模型目录")
|
||||
args = parser.parse_args()
|
||||
|
||||
registry = build_intent_registry()
|
||||
nlu = JointBertNLU(model_path=args.model_path)
|
||||
result = nlu.predict(args.text, registry.list())
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"text": args.text,
|
||||
"intent_id": result.intent_id,
|
||||
"intent_score": round(result.intent_score, 4),
|
||||
"candidates": [
|
||||
{"intent_id": item.intent_id, "score": round(item.score, 4)}
|
||||
for item in result.candidates
|
||||
],
|
||||
"multi_intent_candidates": [
|
||||
{"intent_id": item.intent_id, "score": round(item.score, 4)}
|
||||
for item in result.multi_intent_candidates
|
||||
],
|
||||
"slots": result.slots,
|
||||
"slot_items": [
|
||||
{
|
||||
"slot_name": item.slot_name,
|
||||
"value": item.value,
|
||||
"start": item.start,
|
||||
"end": item.end,
|
||||
"score": item.score,
|
||||
}
|
||||
for item in result.slot_items
|
||||
],
|
||||
"error_message": result.error_message,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
275
intelligent_cabin/archive/scripts/eval_joint_nlu_independent.py
Normal file
275
intelligent_cabin/archive/scripts/eval_joint_nlu_independent.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
TEST_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval_independent.jsonl"
|
||||
MODEL_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
|
||||
REPORT_DIR = PROJECT_ROOT / "reports"
|
||||
RESULT_PATH = REPORT_DIR / "joint_nlu_independent_result.json"
|
||||
REPORT_PATH = REPORT_DIR / "joint_nlu_independent_report.md"
|
||||
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"
|
||||
|
||||
|
||||
def load_cases(file_path: Path) -> list[dict[str, object]]:
|
||||
cases: list[dict[str, object]] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
expected_intent_id = str(payload.get("expected_intent_id") or payload.get("intent_id") or "").strip()
|
||||
if not expected_intent_id:
|
||||
continue
|
||||
cases.append(
|
||||
{
|
||||
"text": str(payload["text"]),
|
||||
"expected_intent_id": expected_intent_id,
|
||||
"expected_slots": dict(payload.get("expected_slots") or {}),
|
||||
"category": str(payload.get("category") or "unknown"),
|
||||
}
|
||||
)
|
||||
return cases
|
||||
|
||||
|
||||
def load_train_summary(file_path: Path) -> dict[str, object]:
|
||||
if not file_path.exists():
|
||||
return {}
|
||||
return json.loads(file_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def compare_slots(expected: dict[str, object], predicted: dict[str, object]) -> dict[str, object]:
|
||||
expected_keys = set(expected)
|
||||
predicted_keys = set(predicted)
|
||||
missing_keys = sorted(expected_keys - predicted_keys)
|
||||
extra_keys = sorted(predicted_keys - expected_keys)
|
||||
wrong_values: list[dict[str, object]] = []
|
||||
matched_keys = 0
|
||||
for key in sorted(expected_keys & predicted_keys):
|
||||
if expected[key] == predicted[key]:
|
||||
matched_keys += 1
|
||||
else:
|
||||
wrong_values.append(
|
||||
{
|
||||
"slot_name": key,
|
||||
"expected": expected[key],
|
||||
"predicted": predicted[key],
|
||||
}
|
||||
)
|
||||
exact = not missing_keys and not extra_keys and not wrong_values
|
||||
return {
|
||||
"missing_keys": missing_keys,
|
||||
"extra_keys": extra_keys,
|
||||
"wrong_values": wrong_values,
|
||||
"matched_keys": matched_keys,
|
||||
"exact": exact,
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(results: list[dict[str, object]]) -> dict[str, float]:
|
||||
total = len(results)
|
||||
intent_correct = sum(1 for item in results if item["intent_ok"])
|
||||
slot_exact = sum(1 for item in results if item["slot_exact"])
|
||||
joint_exact = sum(1 for item in results if item["joint_ok"])
|
||||
|
||||
slot_tp = 0
|
||||
slot_fp = 0
|
||||
slot_fn = 0
|
||||
for item in results:
|
||||
expected = dict(item["expected_slots"])
|
||||
predicted = dict(item["predicted_slots"])
|
||||
expected_keys = set(expected)
|
||||
predicted_keys = set(predicted)
|
||||
slot_tp += sum(1 for key in expected_keys & predicted_keys if expected[key] == predicted[key])
|
||||
slot_fp += len(predicted_keys - expected_keys)
|
||||
slot_fn += len(expected_keys - predicted_keys)
|
||||
slot_fp += sum(1 for key in expected_keys & predicted_keys if expected[key] != predicted[key])
|
||||
slot_fn += sum(1 for key in expected_keys & predicted_keys if expected[key] != predicted[key])
|
||||
|
||||
precision = slot_tp / (slot_tp + slot_fp) if (slot_tp + slot_fp) else 0.0
|
||||
recall = slot_tp / (slot_tp + slot_fn) if (slot_tp + slot_fn) else 0.0
|
||||
slot_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"intent_accuracy": round(intent_correct / total, 4) if total else 0.0,
|
||||
"slot_exact_match": round(slot_exact / total, 4) if total else 0.0,
|
||||
"joint_exact_match": round(joint_exact / total, 4) if total else 0.0,
|
||||
"slot_micro_precision": round(precision, 4),
|
||||
"slot_micro_recall": round(recall, 4),
|
||||
"slot_micro_f1": round(slot_f1, 4),
|
||||
}
|
||||
|
||||
|
||||
def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
|
||||
for item in results:
|
||||
grouped[str(item["category"])].append(item)
|
||||
summary: list[dict[str, object]] = []
|
||||
for category, items in sorted(grouped.items()):
|
||||
summary.append(
|
||||
{
|
||||
"category": category,
|
||||
"sample_count": len(items),
|
||||
"metrics": compute_metrics(items),
|
||||
}
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
def collect_top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
|
||||
counter: Counter[tuple[str, str]] = Counter()
|
||||
for item in results:
|
||||
if item["intent_ok"]:
|
||||
continue
|
||||
counter[(str(item["expected_intent_id"]), str(item["predicted_intent_id"]))] += 1
|
||||
return [
|
||||
{"expected": expected, "predicted": predicted, "count": count}
|
||||
for (expected, predicted), count in counter.most_common(limit)
|
||||
]
|
||||
|
||||
|
||||
def collect_failures(results: list[dict[str, object]], limit: int = 20) -> list[dict[str, object]]:
|
||||
failures = [item for item in results if not item["joint_ok"]]
|
||||
|
||||
def sort_key(item: dict[str, object]) -> tuple[int, int, int]:
|
||||
slot_errors = len(item["slot_diff"]["missing_keys"]) + len(item["slot_diff"]["extra_keys"]) + len(item["slot_diff"]["wrong_values"])
|
||||
return (0 if item["intent_ok"] else 1, slot_errors, len(str(item["text"])))
|
||||
|
||||
return sorted(failures, key=sort_key, reverse=True)[:limit]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Joint NLU 独立评测与失败样例回放")
|
||||
parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="评测集路径")
|
||||
parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="Joint NLU 模型路径")
|
||||
parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化结果输出路径")
|
||||
parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 报告输出路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
cases = load_cases(Path(args.test_path))
|
||||
registry = build_intent_registry()
|
||||
nlu = JointBertNLU(model_path=args.model_path)
|
||||
results: list[dict[str, object]] = []
|
||||
for case in cases:
|
||||
prediction = nlu.predict(str(case["text"]), registry.list())
|
||||
predicted_slots = dict(prediction.slots)
|
||||
slot_diff = compare_slots(dict(case["expected_slots"]), predicted_slots)
|
||||
predicted_intent_id = prediction.intent_id or "None"
|
||||
intent_ok = predicted_intent_id == case["expected_intent_id"]
|
||||
joint_ok = intent_ok and bool(slot_diff["exact"])
|
||||
results.append(
|
||||
{
|
||||
"text": case["text"],
|
||||
"category": case["category"],
|
||||
"expected_intent_id": case["expected_intent_id"],
|
||||
"predicted_intent_id": predicted_intent_id,
|
||||
"expected_slots": case["expected_slots"],
|
||||
"predicted_slots": predicted_slots,
|
||||
"intent_score": round(prediction.intent_score, 4),
|
||||
"intent_ok": intent_ok,
|
||||
"slot_exact": bool(slot_diff["exact"]),
|
||||
"joint_ok": joint_ok,
|
||||
"slot_diff": slot_diff,
|
||||
"top_candidates": [
|
||||
{"intent_id": item.intent_id, "score": round(item.score, 4)}
|
||||
for item in prediction.candidates
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
summary = {
|
||||
"model_path": args.model_path,
|
||||
"test_path": args.test_path,
|
||||
"sample_count": len(results),
|
||||
"metrics": compute_metrics(results),
|
||||
"per_category": summarize_by_category(results),
|
||||
"top_confusions": collect_top_confusions(results),
|
||||
"failure_examples": collect_failures(results),
|
||||
"train_summary": load_train_summary(TRAIN_SUMMARY_PATH),
|
||||
"results": results,
|
||||
}
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
|
||||
print(json.dumps({"sample_count": summary["sample_count"], "metrics": summary["metrics"]}, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_report(summary: dict[str, object]) -> str:
|
||||
metrics = summary["metrics"]
|
||||
per_category = summary["per_category"]
|
||||
confusions = summary["top_confusions"]
|
||||
failures = summary["failure_examples"]
|
||||
train_summary = summary.get("train_summary") or {}
|
||||
lines = [
|
||||
"# Joint NLU 独立评测报告",
|
||||
"",
|
||||
"## 概览",
|
||||
f"- 模型目录:`{summary['model_path']}`",
|
||||
f"- 评测集:`{summary['test_path']}`",
|
||||
f"- 样本数:`{summary['sample_count']}`",
|
||||
f"- `intent_accuracy`:`{metrics['intent_accuracy']}`",
|
||||
f"- `slot_exact_match`:`{metrics['slot_exact_match']}`",
|
||||
f"- `joint_exact_match`:`{metrics['joint_exact_match']}`",
|
||||
f"- `slot_micro_precision`:`{metrics['slot_micro_precision']}`",
|
||||
f"- `slot_micro_recall`:`{metrics['slot_micro_recall']}`",
|
||||
f"- `slot_micro_f1`:`{metrics['slot_micro_f1']}`",
|
||||
"",
|
||||
"## 训练摘要",
|
||||
]
|
||||
if train_summary:
|
||||
lines.extend(
|
||||
[
|
||||
f"- 训练集 / 评测集:`{train_summary.get('train_size', 'unknown')} / {train_summary.get('eval_size', 'unknown')}`",
|
||||
f"- 训练阶段 `intent_accuracy`:`{train_summary.get('metrics', {}).get('intent_accuracy', 'unknown')}`",
|
||||
f"- 训练阶段 `slot_exact_match`:`{train_summary.get('metrics', {}).get('slot_exact_match', 'unknown')}`",
|
||||
"",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.extend(["- 未找到训练摘要。", ""])
|
||||
lines.extend(["## 分类别结果"])
|
||||
for item in per_category:
|
||||
category_metrics = item["metrics"]
|
||||
lines.append(
|
||||
f"- `{item['category']}`: count={item['sample_count']} intent_acc={category_metrics['intent_accuracy']} slot_exact={category_metrics['slot_exact_match']} joint_exact={category_metrics['joint_exact_match']}"
|
||||
)
|
||||
lines.extend(["", "## 主要意图混淆"])
|
||||
if not confusions:
|
||||
lines.append("- 未发现意图混淆。")
|
||||
else:
|
||||
for item in confusions:
|
||||
lines.append(f"- 期望 `{item['expected']}`,预测成 `{item['predicted']}`:`{item['count']}` 次")
|
||||
lines.extend(["", "## 失败样例回放"])
|
||||
if not failures:
|
||||
lines.append("- 无失败样例。")
|
||||
else:
|
||||
for item in failures:
|
||||
slot_diff = item["slot_diff"]
|
||||
lines.append(
|
||||
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望意图:`{item['expected_intent_id']}` | 预测意图:`{item['predicted_intent_id']}` | 期望槽位:`{item['expected_slots']}` | 预测槽位:`{item['predicted_slots']}` | 缺失槽位:`{slot_diff['missing_keys']}` | 多出槽位:`{slot_diff['extra_keys']}`"
|
||||
)
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## 结论",
|
||||
"- 先看 `failure_replay` 是否仍然错,能直接判断先前多意图失败到底是联合模型本体问题还是上层组合问题。",
|
||||
"- 若 `slot_music` 或 `slot_destination` 仍不稳,优先补 span 标注,不要回退到规则抽槽。",
|
||||
"- 若 `no_slot_control` 很稳但 `failure_replay` 中仍有大量错误,下一步应补长尾控制语义数据,而不是急着上更复杂结构。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
231
intelligent_cabin/archive/scripts/eval_local_bert_intent.py
Normal file
231
intelligent_cabin/archive/scripts/eval_local_bert_intent.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
|
||||
|
||||
TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_eval_independent.jsonl"
|
||||
MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
REPORT_DIR = PROJECT_ROOT / "reports"
|
||||
REPORT_PATH = REPORT_DIR / "bert_local_test_report.md"
|
||||
RESULT_PATH = REPORT_DIR / "bert_local_test_result.json"
|
||||
BERT_THRESHOLD = 0.0
|
||||
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"
|
||||
|
||||
|
||||
def load_cases(file_path: Path) -> list[dict[str, str]]:
|
||||
cases: list[dict[str, str]] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
expected_label = str(payload.get("expected_label") or payload.get("intent_id") or "").strip()
|
||||
if not expected_label:
|
||||
continue
|
||||
category = str(payload.get("category") or infer_category(expected_label)).strip()
|
||||
cases.append(
|
||||
{
|
||||
"text": str(payload["text"]),
|
||||
"expected_label": expected_label,
|
||||
"category": category,
|
||||
}
|
||||
)
|
||||
return cases
|
||||
|
||||
|
||||
def load_train_summary(file_path: Path) -> dict[str, object]:
|
||||
if not file_path.exists():
|
||||
return {}
|
||||
return json.loads(file_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def infer_category(label: str) -> str:
|
||||
if label == "__social__":
|
||||
return "social"
|
||||
if label == "__out_of_scope__":
|
||||
return "out_of_scope"
|
||||
return "business"
|
||||
|
||||
|
||||
def resolve_predicted_label(result) -> str:
|
||||
if result.intent is not None:
|
||||
return result.intent.intent_id
|
||||
if result.raw_label:
|
||||
return str(result.raw_label)
|
||||
return "None"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地 BERT 独立评测脚本")
|
||||
parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="评测集路径")
|
||||
parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化评测结果输出路径")
|
||||
parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 评测报告输出路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
intent_registry = build_intent_registry()
|
||||
intents = intent_registry.list()
|
||||
classifier = BertIntentClassifier(
|
||||
model_path=str(MODEL_DIR),
|
||||
threshold=BERT_THRESHOLD,
|
||||
label_map_path=str(MODEL_DIR / "label_map.json"),
|
||||
fallback=None,
|
||||
top_k=3,
|
||||
)
|
||||
cases = load_cases(Path(args.test_path))
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
confusion: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
category_confusion: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
correct = 0
|
||||
|
||||
for case in cases:
|
||||
result = classifier.predict(case["text"], intents)
|
||||
predicted = resolve_predicted_label(result)
|
||||
expected = case["expected_label"]
|
||||
ok = predicted == expected
|
||||
if ok:
|
||||
correct += 1
|
||||
confusion[expected][predicted] += 1
|
||||
category_confusion[case["category"]]["correct" if ok else "wrong"] += 1
|
||||
results.append(
|
||||
{
|
||||
"text": case["text"],
|
||||
"category": case["category"],
|
||||
"expected_label": expected,
|
||||
"predicted_label": predicted,
|
||||
"score": round(result.score, 4),
|
||||
"raw_label": result.raw_label,
|
||||
"ok": ok,
|
||||
"top_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = correct / len(cases) if cases else 0.0
|
||||
train_summary = load_train_summary(TRAIN_SUMMARY_PATH)
|
||||
per_label_stats: list[dict[str, object]] = []
|
||||
for label in sorted({case["expected_label"] for case in cases}):
|
||||
label_cases = [item for item in results if item["expected_label"] == label]
|
||||
label_correct = sum(1 for item in label_cases if item["ok"])
|
||||
per_label_stats.append(
|
||||
{
|
||||
"label": label,
|
||||
"category": infer_category(label),
|
||||
"total": len(label_cases),
|
||||
"correct": label_correct,
|
||||
"accuracy": round(label_correct / len(label_cases), 4) if label_cases else 0.0,
|
||||
}
|
||||
)
|
||||
per_category_stats: list[dict[str, object]] = []
|
||||
for category in sorted({case["category"] for case in cases}):
|
||||
category_cases = [item for item in results if item["category"] == category]
|
||||
category_correct = sum(1 for item in category_cases if item["ok"])
|
||||
per_category_stats.append(
|
||||
{
|
||||
"category": category,
|
||||
"total": len(category_cases),
|
||||
"correct": category_correct,
|
||||
"accuracy": round(category_correct / len(category_cases), 4) if category_cases else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
errors = [item for item in results if not item["ok"]]
|
||||
summary = {
|
||||
"model_dir": str(MODEL_DIR),
|
||||
"threshold": BERT_THRESHOLD,
|
||||
"test_path": str(args.test_path),
|
||||
"test_case_count": len(cases),
|
||||
"accuracy": round(accuracy, 4),
|
||||
"train_summary": train_summary,
|
||||
"per_category": per_category_stats,
|
||||
"per_label": per_label_stats,
|
||||
"errors": errors,
|
||||
"confusion": {key: dict(value) for key, value in confusion.items()},
|
||||
}
|
||||
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
|
||||
print(json.dumps({"accuracy": summary["accuracy"], "test_case_count": len(cases), "error_count": len(errors)}, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_report(summary: dict[str, object]) -> str:
|
||||
per_category = summary["per_category"]
|
||||
per_label = summary["per_label"]
|
||||
errors = summary["errors"]
|
||||
train_summary = summary.get("train_summary") or {}
|
||||
lines = [
|
||||
"# 本地 BERT 意图识别测试报告",
|
||||
"",
|
||||
"## 概览",
|
||||
f"- 模型目录:`{summary['model_dir']}`",
|
||||
f"- 评测集:`{summary['test_path']}`",
|
||||
f"- 评测阈值:`{summary['threshold']}`",
|
||||
f"- 测试样本数:`{summary['test_case_count']}`",
|
||||
f"- 总体准确率:`{summary['accuracy']}`",
|
||||
"",
|
||||
"## 训练摘要",
|
||||
]
|
||||
if train_summary:
|
||||
lines.extend(
|
||||
[
|
||||
f"- 基座模型:`{train_summary.get('base_model', 'unknown')}`",
|
||||
f"- 训练集 / 验证集:`{train_summary.get('train_size', 'unknown')} / {train_summary.get('dev_size', 'unknown')}`",
|
||||
f"- 最佳验证准确率:`{train_summary.get('best_dev_accuracy', 'unknown')}`",
|
||||
f"- 训练设备:`{train_summary.get('device', 'unknown')}`",
|
||||
"",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.extend(["- 未找到训练摘要。", ""])
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"## 分类别结果",
|
||||
]
|
||||
)
|
||||
for item in per_category:
|
||||
lines.append(
|
||||
f"- `{item['category']}`: {item['correct']}/{item['total']} = {item['accuracy']}"
|
||||
)
|
||||
lines.extend(["", "## 分标签结果"])
|
||||
for item in per_label:
|
||||
lines.append(
|
||||
f"- `{item['label']}` ({item['category']}): {item['correct']}/{item['total']} = {item['accuracy']}"
|
||||
)
|
||||
lines.extend(["", "## 错误样例"])
|
||||
if not errors:
|
||||
lines.append("- 无错误样例。")
|
||||
else:
|
||||
for item in errors[:10]:
|
||||
lines.append(
|
||||
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望:`{item['expected_label']}` | 预测:`{item['predicted_label']}` | 分数:`{item['score']}`"
|
||||
)
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## 结论",
|
||||
"- 当前本地 MacBERT 已具备较强的业务意图识别能力,可作为本地快链路分类器。",
|
||||
"- 误判主要集中在方向相反或语义接近的控制指令,下一步应补充对抗样本和真实口语表达。",
|
||||
"- 上线前建议继续补充 ASR 错字、多轮短句和多意图子句级样本。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from scripts.train_local_bert_multi_intent import (
|
||||
BATCH_SIZE,
|
||||
OUTPUT_DIR,
|
||||
TOP_K,
|
||||
THRESHOLD,
|
||||
MultiLabelIntentDataset,
|
||||
load_all_samples,
|
||||
split_samples,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Evaluate local BERT multi-intent detector.")
|
||||
parser.add_argument("--model-path", default=str(OUTPUT_DIR), help="Path to trained multi-intent model.")
|
||||
parser.add_argument("--threshold", type=float, default=THRESHOLD, help="Probability threshold.")
|
||||
parser.add_argument("--top-k", type=int, default=TOP_K, help="Top-k for recall@k.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
choices=("dev", "all"),
|
||||
default="dev",
|
||||
help="Evaluate on the held-out dev split or all combined samples.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
probabilities: list[list[float]],
|
||||
targets: list[list[float]],
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
recall_at_k_total = 0.0
|
||||
total = len(probabilities)
|
||||
for scores, target in zip(probabilities, targets):
|
||||
predicted = {index for index, score in enumerate(scores) if score >= threshold}
|
||||
expected = {index for index, value in enumerate(target) if value >= 0.5}
|
||||
if predicted == expected:
|
||||
exact_match += 1
|
||||
true_positive += len(predicted & expected)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
|
||||
if expected:
|
||||
recall_at_k_total += len(set(top_indices) & expected) / len(expected)
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
set_seed(42)
|
||||
samples = load_all_samples()
|
||||
_, dev_samples = split_samples(samples)
|
||||
eval_samples = samples if args.dataset == "all" else dev_samples
|
||||
model_path = Path(args.model_path)
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"model path not found: {model_path}")
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()}
|
||||
if not label_to_id:
|
||||
raise RuntimeError("label2id is missing from model config")
|
||||
|
||||
dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id)
|
||||
loader = DataLoader(dataset, batch_size=BATCH_SIZE)
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
probabilities: list[list[float]] = []
|
||||
targets: list[list[float]] = []
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist())
|
||||
targets.extend(labels.detach().cpu().tolist())
|
||||
|
||||
metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k)
|
||||
result = {
|
||||
"model_path": str(model_path),
|
||||
"dataset": args.dataset,
|
||||
"sample_size": len(eval_samples),
|
||||
"threshold": args.threshold,
|
||||
"top_k": args.top_k,
|
||||
"metrics": metrics,
|
||||
}
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.multi_intent_detector import BertMultiIntentDetector
|
||||
|
||||
|
||||
TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_eval_independent.jsonl"
|
||||
MODEL_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
|
||||
REPORT_DIR = PROJECT_ROOT / "reports"
|
||||
RESULT_PATH = REPORT_DIR / "bert_multi_intent_independent_result.json"
|
||||
REPORT_PATH = REPORT_DIR / "bert_multi_intent_independent_report.md"
|
||||
THRESHOLD = 0.45
|
||||
TOP_K = 8
|
||||
MAX_LABELS = 4
|
||||
|
||||
|
||||
def load_cases(file_path: Path) -> list[dict[str, object]]:
|
||||
cases: list[dict[str, object]] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
expected = sorted({str(item).strip() for item in payload.get("expected_intent_ids") or [] if str(item).strip()})
|
||||
if not expected:
|
||||
continue
|
||||
cases.append(
|
||||
{
|
||||
"text": str(payload["text"]),
|
||||
"expected_intent_ids": expected,
|
||||
"category": str(payload.get("category") or "unknown"),
|
||||
}
|
||||
)
|
||||
return cases
|
||||
|
||||
|
||||
def compute_set_metrics(results: list[dict[str, object]]) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
multi_recall_hit = 0
|
||||
single_false_alarm = 0
|
||||
total = len(results)
|
||||
single_guard_total = 0
|
||||
for item in results:
|
||||
expected = set(item["expected_intent_ids"])
|
||||
predicted = set(item["predicted_intent_ids"])
|
||||
if expected == predicted:
|
||||
exact_match += 1
|
||||
true_positive += len(expected & predicted)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
if len(expected) >= 2 and expected.issubset(predicted):
|
||||
multi_recall_hit += 1
|
||||
if len(expected) == 1:
|
||||
single_guard_total += 1
|
||||
if len(predicted) > 1:
|
||||
single_false_alarm += 1
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
multi_total = sum(1 for item in results if len(item["expected_intent_ids"]) >= 2)
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"multi_sentence_recall": round(multi_recall_hit / multi_total, 4) if multi_total else 0.0,
|
||||
"single_guard_false_alarm_rate": round(single_false_alarm / single_guard_total, 4) if single_guard_total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
|
||||
for item in results:
|
||||
grouped[str(item["category"])].append(item)
|
||||
summary: list[dict[str, object]] = []
|
||||
for category, items in sorted(grouped.items()):
|
||||
summary.append(
|
||||
{
|
||||
"category": category,
|
||||
"sample_count": len(items),
|
||||
"metrics": compute_set_metrics(items),
|
||||
}
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
def collect_error_examples(results: list[dict[str, object]], limit: int = 15) -> list[dict[str, object]]:
|
||||
errors = [item for item in results if set(item["expected_intent_ids"]) != set(item["predicted_intent_ids"])]
|
||||
def sort_key(item: dict[str, object]) -> tuple[int, int]:
|
||||
expected = set(item["expected_intent_ids"])
|
||||
predicted = set(item["predicted_intent_ids"])
|
||||
miss = len(expected - predicted)
|
||||
extra = len(predicted - expected)
|
||||
return (miss + extra, miss)
|
||||
return sorted(errors, key=sort_key, reverse=True)[:limit]
|
||||
|
||||
|
||||
def top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
|
||||
counter: Counter[tuple[str, str]] = Counter()
|
||||
for item in results:
|
||||
expected = set(item["expected_intent_ids"])
|
||||
predicted = set(item["predicted_intent_ids"])
|
||||
for miss in sorted(expected - predicted):
|
||||
for extra in sorted(predicted - expected):
|
||||
counter[(miss, extra)] += 1
|
||||
return [
|
||||
{"expected_missing": pair[0], "wrong_extra": pair[1], "count": count}
|
||||
for pair, count in counter.most_common(limit)
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地多标签 detector 独立评测脚本")
|
||||
parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="独立评测集路径")
|
||||
parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="多标签模型路径")
|
||||
parser.add_argument("--threshold", type=float, default=THRESHOLD, help="检测阈值")
|
||||
parser.add_argument("--top-k", type=int, default=TOP_K, help="输出 top-k 原始分数")
|
||||
parser.add_argument("--max-labels", type=int, default=MAX_LABELS, help="最多返回标签数")
|
||||
parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化结果输出路径")
|
||||
parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 报告输出路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
cases = load_cases(Path(args.test_path))
|
||||
intents = build_intent_registry().list()
|
||||
detector = BertMultiIntentDetector(
|
||||
model_path=args.model_path,
|
||||
threshold=args.threshold,
|
||||
top_k=args.top_k,
|
||||
max_labels=args.max_labels,
|
||||
)
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
for case in cases:
|
||||
detection = detector.detect(str(case["text"]), intents)
|
||||
predicted = [candidate.intent_id for candidate in detection.candidates]
|
||||
raw_top = [
|
||||
{
|
||||
"intent_id": str(item.get("intent_id") or item.get("label") or ""),
|
||||
"score": round(float(item.get("score", 0.0)), 4),
|
||||
}
|
||||
for item in detection.raw_scores
|
||||
]
|
||||
results.append(
|
||||
{
|
||||
"text": case["text"],
|
||||
"category": case["category"],
|
||||
"expected_intent_ids": case["expected_intent_ids"],
|
||||
"predicted_intent_ids": predicted,
|
||||
"detected": detection.detected,
|
||||
"backend_name": detection.backend_name,
|
||||
"reason": detection.reason,
|
||||
"raw_top_scores": raw_top,
|
||||
}
|
||||
)
|
||||
|
||||
summary = {
|
||||
"model_path": args.model_path,
|
||||
"test_path": args.test_path,
|
||||
"threshold": args.threshold,
|
||||
"top_k": args.top_k,
|
||||
"max_labels": args.max_labels,
|
||||
"sample_count": len(results),
|
||||
"metrics": compute_set_metrics(results),
|
||||
"per_category": summarize_by_category(results),
|
||||
"top_confusions": top_confusions(results),
|
||||
"error_examples": collect_error_examples(results),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
|
||||
print(json.dumps({"sample_count": len(results), "metrics": summary["metrics"]}, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_report(summary: dict[str, object]) -> str:
|
||||
metrics = summary["metrics"]
|
||||
per_category = summary["per_category"]
|
||||
confusions = summary["top_confusions"]
|
||||
errors = summary["error_examples"]
|
||||
lines = [
|
||||
"# 本地多标签 Detector 独立评测报告",
|
||||
"",
|
||||
"## 概览",
|
||||
f"- 模型目录:`{summary['model_path']}`",
|
||||
f"- 评测集:`{summary['test_path']}`",
|
||||
f"- 样本数:`{summary['sample_count']}`",
|
||||
f"- 阈值 / top_k / max_labels:`{summary['threshold']} / {summary['top_k']} / {summary['max_labels']}`",
|
||||
f"- `micro_precision`:`{metrics['micro_precision']}`",
|
||||
f"- `micro_recall`:`{metrics['micro_recall']}`",
|
||||
f"- `micro_f1`:`{metrics['micro_f1']}`",
|
||||
f"- `exact_match`:`{metrics['exact_match']}`",
|
||||
f"- `multi_sentence_recall`:`{metrics['multi_sentence_recall']}`",
|
||||
f"- `single_guard_false_alarm_rate`:`{metrics['single_guard_false_alarm_rate']}`",
|
||||
"",
|
||||
"## 分类别结果",
|
||||
]
|
||||
for item in per_category:
|
||||
category_metrics = item["metrics"]
|
||||
lines.append(
|
||||
f"- `{item['category']}`: count={item['sample_count']} micro_f1={category_metrics['micro_f1']} exact_match={category_metrics['exact_match']}"
|
||||
)
|
||||
lines.extend(["", "## 主要混淆"])
|
||||
if not confusions:
|
||||
lines.append("- 未发现明显混淆对。")
|
||||
else:
|
||||
for item in confusions:
|
||||
lines.append(
|
||||
f"- 漏掉 `{item['expected_missing']}`,同时误报 `{item['wrong_extra']}`:`{item['count']}` 次"
|
||||
)
|
||||
lines.extend(["", "## 错误样例"])
|
||||
if not errors:
|
||||
lines.append("- 无错误样例。")
|
||||
else:
|
||||
for item in errors:
|
||||
lines.append(
|
||||
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望:`{item['expected_intent_ids']}` | 预测:`{item['predicted_intent_ids']}`"
|
||||
)
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## 结论建议",
|
||||
"- 先看多意图句是否存在系统性漏召回,再看单意图是否被误报成多意图。",
|
||||
"- 若 `single_guard_false_alarm_rate` 偏高,需要先收紧 detector 阈值或补单意图负样本,再考虑进入 NER。",
|
||||
"- 若 `multi_sentence_recall` 不稳定,应继续补条件句、弱连接句和口语化多动作语料。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
97
intelligent_cabin/archive/scripts/test_local_bert_intent.py
Normal file
97
intelligent_cabin/archive/scripts/test_local_bert_intent.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
from app.services.router import build_matcher_pipeline
|
||||
|
||||
|
||||
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR / "label_map.json"
|
||||
|
||||
|
||||
def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
|
||||
return BertIntentClassifier(
|
||||
model_path=str(DEFAULT_MODEL_DIR),
|
||||
threshold=threshold,
|
||||
label_map_path=str(DEFAULT_LABEL_MAP),
|
||||
fallback=None,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]:
|
||||
classifier = build_classifier(threshold=threshold, top_k=top_k)
|
||||
registry = build_intent_registry()
|
||||
intents = registry.list()
|
||||
result = classifier.predict(text, intents)
|
||||
matcher = build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=classifier,
|
||||
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
|
||||
clarify_margin_threshold=settings.local_clarify_margin_threshold,
|
||||
)
|
||||
route_result = matcher.match(text)
|
||||
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
|
||||
return {
|
||||
"text": text,
|
||||
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
|
||||
"score": round(result.score, 4),
|
||||
"model_name": result.model_name,
|
||||
"backend": result.backend_name,
|
||||
"raw_label": result.raw_label,
|
||||
"fallback_reason": result.fallback_reason,
|
||||
"error_message": result.error_message,
|
||||
"decision": route_result.debug.decision,
|
||||
"decision_reason": route_result.debug.decision_reason,
|
||||
"confidence_grade": route_result.debug.confidence_grade,
|
||||
"unknown_detected": route_result.debug.unknown_detected,
|
||||
"fusion_top_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
|
||||
"top_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def interactive_loop(threshold: float, top_k: int) -> None:
|
||||
print("本地 BERT 意图测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
|
||||
while True:
|
||||
try:
|
||||
text = input("\n请输入问题> ").strip()
|
||||
except EOFError:
|
||||
print()
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
if text.lower() in {"exit", "quit", "q"}:
|
||||
break
|
||||
result = predict_once(text, threshold=threshold, top_k=top_k)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地 BERT 意图识别测试脚本")
|
||||
parser.add_argument("--text", type=str, default="", help="单次测试文本")
|
||||
parser.add_argument("--threshold", type=float, default=0.0, help="BERT 置信度阈值")
|
||||
parser.add_argument("--top-k", type=int, default=3, help="返回候选数量")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.text.strip():
|
||||
print(json.dumps(predict_once(args.text.strip(), threshold=args.threshold, top_k=args.top_k), ensure_ascii=False, indent=2))
|
||||
return
|
||||
interactive_loop(threshold=args.threshold, top_k=args.top_k)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
500
intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
Normal file
500
intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoTokenizer
|
||||
import yaml
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.services.joint_nlu import JointBertForNLU
|
||||
|
||||
|
||||
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
|
||||
SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl"
|
||||
EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl"
|
||||
MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
MAX_LENGTH = 64
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 8
|
||||
LEARNING_RATE = 2e-5
|
||||
SEED = 42
|
||||
IGNORE_INDEX = -100
|
||||
GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌")
|
||||
DEFAULT_INTENT_THRESHOLD = 0.3
|
||||
MULTI_INTENT_REPEAT = 6
|
||||
THRESHOLD_CANDIDATES = [0.1, 0.12, 0.15, 0.18, 0.2, 0.22, 0.25, 0.28, 0.3, 0.33, 0.35, 0.38, 0.4, 0.45]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointSample:
|
||||
text: str
|
||||
intent_ids: list[str]
|
||||
slots: list[dict[str, object]]
|
||||
|
||||
|
||||
class JointDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
samples: list[JointSample],
|
||||
tokenizer,
|
||||
intent_to_index: dict[str, int],
|
||||
slot_to_index: dict[str, int],
|
||||
) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._intent_to_index = intent_to_index
|
||||
self._slot_to_index = slot_to_index
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
max_length=MAX_LENGTH,
|
||||
padding="max_length",
|
||||
return_offsets_mapping=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
offset_mapping = encoded.pop("offset_mapping")[0].tolist()
|
||||
slot_labels = [IGNORE_INDEX] * len(offset_mapping)
|
||||
char_labels = ["O"] * len(sample.text)
|
||||
for slot in sample.slots:
|
||||
start = int(slot["start"])
|
||||
end = int(slot["end"])
|
||||
slot_name = str(slot["slot_name"])
|
||||
if start < 0 or end > len(sample.text) or start >= end:
|
||||
continue
|
||||
char_labels[start] = f"B-{slot_name}"
|
||||
for pos in range(start + 1, end):
|
||||
char_labels[pos] = f"I-{slot_name}"
|
||||
|
||||
for token_index, (start, end) in enumerate(offset_mapping):
|
||||
if end <= start:
|
||||
continue
|
||||
label = char_labels[start]
|
||||
slot_labels[token_index] = self._slot_to_index.get(label, self._slot_to_index["O"])
|
||||
intent_vector = torch.zeros(len(self._intent_to_index), dtype=torch.float32)
|
||||
for intent_id in sample.intent_ids:
|
||||
if intent_id in self._intent_to_index:
|
||||
intent_vector[self._intent_to_index[intent_id]] = 1.0
|
||||
|
||||
return {
|
||||
"input_ids": encoded["input_ids"][0],
|
||||
"attention_mask": encoded["attention_mask"][0],
|
||||
"intent_labels": intent_vector,
|
||||
"slot_labels": torch.tensor(slot_labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def set_seed() -> None:
|
||||
random.seed(SEED)
|
||||
torch.manual_seed(SEED)
|
||||
|
||||
|
||||
def load_jsonl(path: Path) -> list[dict[str, object]]:
|
||||
rows: list[dict[str, object]] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def find_order_id_span(text: str) -> tuple[str, int, int] | None:
|
||||
match = re.search(r"[A-Za-z]\d{5,}", text)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(0), match.start(), match.end()
|
||||
|
||||
|
||||
def find_temperature_span(text: str) -> tuple[str, int, int] | None:
|
||||
match = re.search(r"(\d{2}\s*度)", text)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1), match.start(), match.end()
|
||||
|
||||
|
||||
def find_destination_span(text: str) -> tuple[str, int, int] | None:
|
||||
for pattern in (
|
||||
r"导航去(?P<destination>.+)",
|
||||
r"导航到(?P<destination>.+)",
|
||||
r"带我去(?P<destination>.+)",
|
||||
r"送我去(?P<destination>.+)",
|
||||
r"去(?P<destination>.+)",
|
||||
):
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
continue
|
||||
destination = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", match.group("destination"), maxsplit=1)[0].strip(" ,。")
|
||||
if not destination:
|
||||
continue
|
||||
start = text.find(destination)
|
||||
if start >= 0:
|
||||
return destination, start, start + len(destination)
|
||||
return None
|
||||
|
||||
|
||||
def find_music_span(text: str) -> tuple[str, str, int, int] | None:
|
||||
for genre in GENRE_KEYWORDS:
|
||||
start = text.find(genre)
|
||||
if start >= 0:
|
||||
return "genre", genre, start, start + len(genre)
|
||||
for trigger in ("播放", "来点", "放点", "听", "来首", "来一首", "放一首"):
|
||||
if trigger not in text:
|
||||
continue
|
||||
target = text.split(trigger, maxsplit=1)[-1]
|
||||
target = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", target, maxsplit=1)[0].strip(" 的一首首个歌曲音乐吧呀啊,。")
|
||||
if not target or target in {"歌", "音乐"}:
|
||||
continue
|
||||
for genre in GENRE_KEYWORDS:
|
||||
if genre in target:
|
||||
start = text.find(genre)
|
||||
return "genre", genre, start, start + len(genre)
|
||||
start = text.find(target)
|
||||
if start >= 0:
|
||||
return "song", target, start, start + len(target)
|
||||
return None
|
||||
|
||||
|
||||
def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]:
|
||||
slots: list[dict[str, object]] = []
|
||||
if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}:
|
||||
matched = find_order_id_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "order_id", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_set_ac":
|
||||
matched = find_temperature_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "temperature", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_nav_to":
|
||||
matched = find_destination_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "destination", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_play_music":
|
||||
matched = find_music_span(text)
|
||||
if matched is not None:
|
||||
slot_name, value, start, end = matched
|
||||
slots.append({"slot_name": slot_name, "value": value, "start": start, "end": end})
|
||||
return slots
|
||||
|
||||
|
||||
def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]:
|
||||
merged: list[dict[str, object]] = []
|
||||
seen: set[tuple[str, int, int]] = set()
|
||||
for intent_id in intent_ids:
|
||||
for slot in annotate_slots(text, intent_id):
|
||||
key = (str(slot["slot_name"]), int(slot["start"]), int(slot["end"]))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
merged.append(slot)
|
||||
merged.sort(key=lambda item: (int(item["start"]), int(item["end"])))
|
||||
return merged
|
||||
|
||||
|
||||
def build_train_samples() -> list[JointSample]:
|
||||
samples: list[JointSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {}
|
||||
for intent in domain_data.get("intents", []):
|
||||
intent_id = str(intent.get("intent_id", "")).strip()
|
||||
if not intent_id:
|
||||
continue
|
||||
for text in list(intent.get("examples", [])) + list(intent.get("keywords", [])):
|
||||
text = str(text).strip()
|
||||
if not text:
|
||||
continue
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))
|
||||
for row in load_jsonl(TRAIN_PATH):
|
||||
text = str(row["text"])
|
||||
intent_id = str(row["intent_id"])
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))
|
||||
for row in load_jsonl(SEED_PATH):
|
||||
text = str(row["text"])
|
||||
intent_id = str(row["intent_id"])
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots", []))))
|
||||
for row in load_jsonl(MULTI_TRAIN_PATH):
|
||||
text = str(row["text"]).strip()
|
||||
intent_ids = sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()})
|
||||
if not text or not intent_ids:
|
||||
continue
|
||||
key = (text, tuple(intent_ids))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
|
||||
samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
|
||||
if len(intent_ids) >= 2:
|
||||
for _ in range(MULTI_INTENT_REPEAT - 1):
|
||||
samples.append(JointSample(text=text, intent_ids=intent_ids, slots=list(slots)))
|
||||
random.shuffle(samples)
|
||||
return samples
|
||||
|
||||
|
||||
def build_eval_samples() -> list[JointSample]:
|
||||
rows = load_jsonl(EVAL_PATH)
|
||||
samples = [
|
||||
JointSample(
|
||||
text=str(row["text"]),
|
||||
intent_ids=[str(row["intent_id"])],
|
||||
slots=list(row.get("slots", [])),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
if MULTI_EVAL_PATH.exists():
|
||||
for row in load_jsonl(MULTI_EVAL_PATH):
|
||||
samples.append(
|
||||
JointSample(
|
||||
text=str(row["text"]),
|
||||
intent_ids=sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()}),
|
||||
slots=list(row.get("slots") or annotate_slots_for_intents(str(row["text"]), list(row.get("intent_ids", [])))),
|
||||
)
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def build_slot_labels(samples: list[JointSample]) -> list[str]:
|
||||
slot_names = sorted({str(slot["slot_name"]) for sample in samples for slot in sample.slots})
|
||||
labels = ["O"]
|
||||
for name in slot_names:
|
||||
labels.append(f"B-{name}")
|
||||
labels.append(f"I-{name}")
|
||||
return labels
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
model: JointBertForNLU,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
intent_labels: list[str],
|
||||
slot_labels: list[str],
|
||||
threshold: float,
|
||||
) -> dict[str, float]:
|
||||
model.eval()
|
||||
intent_tp = 0
|
||||
intent_fp = 0
|
||||
intent_fn = 0
|
||||
single_intent_correct = 0
|
||||
single_intent_total = 0
|
||||
intent_exact_match = 0
|
||||
correct_slot_tokens = 0
|
||||
total_slot_tokens = 0
|
||||
exact_slot_samples = 0
|
||||
total_samples = 0
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()}
|
||||
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
|
||||
predicted_probs = torch.sigmoid(intent_logits)
|
||||
predicted_multi = predicted_probs >= threshold
|
||||
gold_multi = batch["intent_labels"] > 0.5
|
||||
intent_tp += int((predicted_multi & gold_multi).sum().item())
|
||||
intent_fp += int((predicted_multi & ~gold_multi).sum().item())
|
||||
intent_fn += int((~predicted_multi & gold_multi).sum().item())
|
||||
intent_exact_match += int((predicted_multi == gold_multi).all(dim=1).sum().item())
|
||||
top_predicted = torch.argmax(predicted_probs, dim=-1)
|
||||
gold_counts = gold_multi.sum(dim=-1)
|
||||
single_mask = gold_counts == 1
|
||||
if int(single_mask.sum().item()) > 0:
|
||||
gold_top = torch.argmax(gold_multi.float(), dim=-1)
|
||||
single_intent_correct += int((top_predicted[single_mask] == gold_top[single_mask]).sum().item())
|
||||
single_intent_total += int(single_mask.sum().item())
|
||||
|
||||
predicted_slots = torch.argmax(slot_logits, dim=-1)
|
||||
mask = batch["slot_labels"] != IGNORE_INDEX
|
||||
correct_slot_tokens += int(((predicted_slots == batch["slot_labels"]) & mask).sum().item())
|
||||
total_slot_tokens += int(mask.sum().item())
|
||||
|
||||
for index in range(batch["slot_labels"].size(0)):
|
||||
gold = batch["slot_labels"][index][mask[index]]
|
||||
pred = predicted_slots[index][mask[index]]
|
||||
exact_slot_samples += int(torch.equal(gold, pred))
|
||||
total_samples += 1
|
||||
precision = intent_tp / max(intent_tp + intent_fp, 1)
|
||||
recall = intent_tp / max(intent_tp + intent_fn, 1)
|
||||
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"intent_threshold": round(threshold, 4),
|
||||
"intent_micro_precision": round(precision, 4),
|
||||
"intent_micro_recall": round(recall, 4),
|
||||
"intent_micro_f1": round(f1, 4),
|
||||
"intent_exact_match": round(intent_exact_match / max(total_samples, 1), 4),
|
||||
"single_intent_top1_accuracy": round(single_intent_correct / max(single_intent_total, 1), 4),
|
||||
"slot_token_accuracy": round(correct_slot_tokens / max(total_slot_tokens, 1), 4),
|
||||
"slot_exact_match": round(exact_slot_samples / max(total_samples, 1), 4),
|
||||
"intent_label_count": float(len(intent_labels)),
|
||||
"slot_label_count": float(len(slot_labels)),
|
||||
}
|
||||
|
||||
|
||||
def search_best_threshold(
|
||||
model: JointBertForNLU,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
intent_labels: list[str],
|
||||
slot_labels: list[str],
|
||||
) -> dict[str, float]:
|
||||
best_metrics: dict[str, float] | None = None
|
||||
for threshold in THRESHOLD_CANDIDATES:
|
||||
metrics = compute_metrics(
|
||||
model,
|
||||
dataloader,
|
||||
device,
|
||||
intent_labels,
|
||||
slot_labels,
|
||||
threshold=threshold,
|
||||
)
|
||||
if best_metrics is None:
|
||||
best_metrics = metrics
|
||||
continue
|
||||
current_score = (metrics["intent_micro_f1"], metrics["intent_exact_match"], metrics["slot_exact_match"])
|
||||
best_score = (
|
||||
best_metrics["intent_micro_f1"],
|
||||
best_metrics["intent_exact_match"],
|
||||
best_metrics["slot_exact_match"],
|
||||
)
|
||||
if current_score > best_score:
|
||||
best_metrics = metrics
|
||||
assert best_metrics is not None
|
||||
return best_metrics
|
||||
|
||||
|
||||
def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor:
|
||||
positive_counts = {label: 0 for label in intent_labels}
|
||||
for sample in samples:
|
||||
sample_intents = set(sample.intent_ids)
|
||||
for label in intent_labels:
|
||||
if label in sample_intents:
|
||||
positive_counts[label] += 1
|
||||
total = max(len(samples), 1)
|
||||
weights: list[float] = []
|
||||
for label in intent_labels:
|
||||
positives = max(positive_counts[label], 1)
|
||||
negatives = max(total - positives, 1)
|
||||
weight = negatives / positives
|
||||
weights.append(min(max(weight, 1.0), 12.0))
|
||||
return torch.tensor(weights, dtype=torch.float32)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed()
|
||||
train_samples = build_train_samples()
|
||||
eval_samples = build_eval_samples()
|
||||
intent_labels = sorted({intent_id for sample in train_samples + eval_samples for intent_id in sample.intent_ids})
|
||||
slot_labels = build_slot_labels(train_samples + eval_samples)
|
||||
intent_to_index = {label: index for index, label in enumerate(intent_labels)}
|
||||
slot_to_index = {label: index for index, label in enumerate(slot_labels)}
|
||||
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL)
|
||||
train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index)
|
||||
eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index)
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model = JointBertForNLU(
|
||||
base_model_name=DEFAULT_BASE_MODEL,
|
||||
num_intents=len(intent_labels),
|
||||
num_slot_labels=len(slot_labels),
|
||||
)
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
pos_weight = build_pos_weight(train_samples, intent_labels).to(device)
|
||||
best_metrics: dict[str, float] | None = None
|
||||
best_state: dict[str, torch.Tensor] | None = None
|
||||
|
||||
for epoch in range(EPOCHS):
|
||||
model.train()
|
||||
epoch_loss = 0.0
|
||||
for batch in train_loader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()}
|
||||
optimizer.zero_grad()
|
||||
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
|
||||
intent_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
intent_logits,
|
||||
batch["intent_labels"],
|
||||
pos_weight=pos_weight,
|
||||
)
|
||||
slot_loss = torch.nn.functional.cross_entropy(
|
||||
slot_logits.view(-1, slot_logits.size(-1)),
|
||||
batch["slot_labels"].view(-1),
|
||||
ignore_index=IGNORE_INDEX,
|
||||
)
|
||||
loss = intent_loss + slot_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
epoch_loss += float(loss.item())
|
||||
|
||||
metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels)
|
||||
metrics["train_loss"] = round(epoch_loss / max(len(train_loader), 1), 4)
|
||||
print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False))
|
||||
if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]:
|
||||
best_metrics = metrics
|
||||
best_state = {key: value.detach().cpu() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is None or best_metrics is None:
|
||||
raise RuntimeError("joint nlu training did not produce a best checkpoint")
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
torch.save(best_state, OUTPUT_DIR / "model_state.pt")
|
||||
config = {
|
||||
"base_model_name": DEFAULT_BASE_MODEL,
|
||||
"intent_task": "multi_label",
|
||||
"intent_labels": intent_labels,
|
||||
"slot_labels": slot_labels,
|
||||
"max_length": MAX_LENGTH,
|
||||
"intent_threshold": float(best_metrics["intent_threshold"]),
|
||||
"multi_intent_threshold": float(best_metrics["intent_threshold"]),
|
||||
"max_multi_intents": 4,
|
||||
}
|
||||
(OUTPUT_DIR / "joint_nlu_config.json").write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"train_size": len(train_samples),
|
||||
"eval_size": len(eval_samples),
|
||||
"metrics": best_metrics,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
684
intelligent_cabin/archive/scripts/train_local_bert_intent.py
Normal file
684
intelligent_cabin/archive/scripts/train_local_bert_intent.py
Normal file
@@ -0,0 +1,684 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import yaml
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
MAX_LENGTH = 48
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 16
|
||||
LEARNING_RATE = 2e-5
|
||||
SEED = 42
|
||||
|
||||
ORDER_IDS = ["A123456", "A700001", "A800002", "A900005", "A202501", "A808001"]
|
||||
DESTINATIONS = ["公司停车场", "浦东机场", "徐家汇", "虹桥机场", "最近的充电站", "南京东路"]
|
||||
TEMPERATURES = [18, 20, 21, 22, 23, 24, 26]
|
||||
SONGS = ["夜曲", "稻香", "青花瓷", "晴天", "告白气球"]
|
||||
GENRES = ["轻音乐", "摇滚", "古典音乐", "民谣", "爵士"]
|
||||
SOCIAL_LABEL = "__social__"
|
||||
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||||
|
||||
TEMPLATES: dict[str, list[str]] = {
|
||||
"cs_query_order": [
|
||||
"查一下订单{order_id}现在什么状态",
|
||||
"我的订单{order_id}到哪一步了",
|
||||
"帮我看看{order_id}这个订单",
|
||||
"确认下{order_id}订单状态",
|
||||
"订单{order_id}现在处理到哪里",
|
||||
"看下{order_id}这单进度",
|
||||
"订单号{order_id}目前怎么样",
|
||||
"帮忙确认订单{order_id}",
|
||||
"订单{order_id}有结果了吗",
|
||||
"帮我追一下订单{order_id}",
|
||||
"订单{order_id}现在受理了吗",
|
||||
"看看{order_id}这单现在啥情况",
|
||||
"帮我查查{order_id}订单进展",
|
||||
"{order_id}这个订单处理好了没",
|
||||
"{order_id}这笔订单现在进展到哪了",
|
||||
],
|
||||
"cs_query_logistics": [
|
||||
"快递{order_id}到哪儿了",
|
||||
"帮我查{order_id}物流进度",
|
||||
"看看{order_id}配送状态",
|
||||
"订单{order_id}物流更新了吗",
|
||||
"查询{order_id}的快递信息",
|
||||
"我的{order_id}现在派送到哪了",
|
||||
"查一下{order_id}这单物流",
|
||||
"配送单{order_id}走到哪里了",
|
||||
"帮我看下{order_id}快递到没到",
|
||||
"物流单号{order_id}现在在哪",
|
||||
"订单{order_id}物流到哪一步了",
|
||||
"{order_id}这单现在派件了吗",
|
||||
"帮我追踪{order_id}运输轨迹",
|
||||
"{order_id}快件现在运到哪里了",
|
||||
"我想看{order_id}的配送更新",
|
||||
],
|
||||
"cs_cancel_order": [
|
||||
"帮我取消{order_id}这个订单",
|
||||
"{order_id}别要了给我撤销",
|
||||
"把订单{order_id}取消掉",
|
||||
"我不要{order_id}了",
|
||||
"撤销一下{order_id}订单",
|
||||
"订单{order_id}不要发了",
|
||||
"帮我把{order_id}退掉并取消",
|
||||
"把{order_id}这一单停掉",
|
||||
"{order_id}这单直接取消",
|
||||
"订单号{order_id}撤回一下",
|
||||
"订单{order_id}我不想要了",
|
||||
"{order_id}这笔订单先别发了",
|
||||
"把{order_id}这单给我撤单",
|
||||
"订单{order_id}停掉吧",
|
||||
"{order_id}这个快给我取消了",
|
||||
],
|
||||
"cs_transfer_human": [
|
||||
"我要找人工客服处理",
|
||||
"现在转人工",
|
||||
"麻烦给我接人工服务",
|
||||
"帮我呼叫真人客服",
|
||||
"别机器人了我要人工",
|
||||
"转真人客服",
|
||||
"我要人工坐席",
|
||||
"帮我接人工处理",
|
||||
"叫人工客服来",
|
||||
"直接给我转人工",
|
||||
"这个问题给我人工跟进",
|
||||
"安排真人客服接手",
|
||||
"机器人处理不了,转人工",
|
||||
"帮我叫个客服专员",
|
||||
"我要人工来处理这事",
|
||||
],
|
||||
"cabin_nav_to": [
|
||||
"导航到{destination}",
|
||||
"带我去{destination}",
|
||||
"我要去{destination}",
|
||||
"去{destination}",
|
||||
"开导航去{destination}",
|
||||
"帮我导航到{destination}",
|
||||
"送我去{destination}",
|
||||
"现在去{destination}",
|
||||
"带路到{destination}",
|
||||
"去一下{destination}",
|
||||
"规划路线去{destination}",
|
||||
"直接开去{destination}",
|
||||
"给我导到{destination}",
|
||||
"{destination}怎么走,导航一下",
|
||||
"出发去{destination}",
|
||||
],
|
||||
"cabin_set_ac": [
|
||||
"把空调设到{temperature}度",
|
||||
"车里温度调成{temperature}度",
|
||||
"冷气开到{temperature}度",
|
||||
"空调给我调到{temperature}度",
|
||||
"温度改成{temperature}度",
|
||||
"车内设成{temperature}度",
|
||||
"把温度打到{temperature}度",
|
||||
"空调调为{temperature}度",
|
||||
"帮我把车里调成{temperature}度",
|
||||
"冷风调到{temperature}度",
|
||||
"把车内温度设为{temperature}度",
|
||||
"空调温度改到{temperature}度",
|
||||
"冷气帮我调到{temperature}度",
|
||||
"舱内调成{temperature}度",
|
||||
"给我把温度定在{temperature}度",
|
||||
"把车里弄凉快点",
|
||||
"车里太热了,降一点",
|
||||
"把里面调凉快一点",
|
||||
"有点热,降温",
|
||||
"空调再冷一点",
|
||||
"车内温度低一点",
|
||||
"把里面弄暖和点",
|
||||
"车里太冷了,升一点温度",
|
||||
],
|
||||
"cabin_ac_on": [
|
||||
"把空调打开",
|
||||
"开一下冷气",
|
||||
"把冷风开起来",
|
||||
"车里热,空调开开",
|
||||
"打开制冷",
|
||||
"空调启动一下",
|
||||
],
|
||||
"cabin_window_open": [
|
||||
"把车窗打开",
|
||||
"开下窗",
|
||||
"窗户开一点",
|
||||
"帮我透透气",
|
||||
"车里太闷了,开下窗",
|
||||
"顺便开下车窗",
|
||||
"把窗户降一点",
|
||||
"把玻璃打开一点",
|
||||
],
|
||||
"cabin_window_close": [
|
||||
"把车窗关上",
|
||||
"窗户关一下",
|
||||
"把窗升起来",
|
||||
"外面太吵了,把窗关了",
|
||||
"把窗户关严",
|
||||
],
|
||||
"cabin_fan_down": [
|
||||
"风别这么大",
|
||||
"风小一点",
|
||||
"别吹这么猛",
|
||||
"把风量调小一点",
|
||||
"出风弱一点",
|
||||
],
|
||||
"cabin_fan_up": [
|
||||
"风再大一点",
|
||||
"把风量开大点",
|
||||
"出风强一点",
|
||||
"风不够,调大些",
|
||||
],
|
||||
"cabin_defog_front_on": [
|
||||
"前挡起雾了,除一下",
|
||||
"把前挡风玻璃雾气清掉",
|
||||
"前窗看不清了,开除雾",
|
||||
],
|
||||
"cabin_defog_rear_on": [
|
||||
"后挡有雾,开下除雾",
|
||||
"后玻璃起雾了,清一下",
|
||||
"后窗看不清了,除雾",
|
||||
],
|
||||
"cabin_play_music": [
|
||||
"播放一首{genre}",
|
||||
"来点{genre}",
|
||||
"我想听{genre}",
|
||||
"给我播点{genre}",
|
||||
"放一首{song}",
|
||||
"来一首{song}",
|
||||
"播放{song}",
|
||||
"放点音乐,来个{genre}",
|
||||
"我想听首{song}",
|
||||
"给我来点歌,放{song}",
|
||||
"随机放点{genre}",
|
||||
"帮我播首{song}",
|
||||
"来点适合开车听的{genre}",
|
||||
"打开音乐,放{song}",
|
||||
"给我放一些{genre}",
|
||||
"放点歌",
|
||||
"来首歌",
|
||||
"整点音乐",
|
||||
"车里放点歌",
|
||||
"来点能听的",
|
||||
],
|
||||
SOCIAL_LABEL: [
|
||||
"你好",
|
||||
"嗨",
|
||||
"哈喽",
|
||||
"早上好",
|
||||
"晚上好",
|
||||
"你叫什么名字",
|
||||
"你是谁",
|
||||
"你能做什么",
|
||||
"今天天气不错",
|
||||
"陪我聊聊天",
|
||||
],
|
||||
OUT_OF_SCOPE_LABEL: [
|
||||
"帮我点个外卖",
|
||||
"订一张去北京的机票",
|
||||
"帮我买杯咖啡",
|
||||
"给我订一家酒店",
|
||||
"人类诞生的意义是什么",
|
||||
"帮我写一份年终总结",
|
||||
"推荐一部电影",
|
||||
"讲个笑话",
|
||||
"帮我做一道数学题",
|
||||
"去美团叫个外卖",
|
||||
],
|
||||
}
|
||||
|
||||
INTENT_REPLACEMENTS: dict[str, list[tuple[str, str]]] = {
|
||||
"cs_query_order": [
|
||||
("订单", "这单"),
|
||||
("查一下", "看一下"),
|
||||
("帮我", "麻烦帮我"),
|
||||
("现在什么状态", "现在啥状态"),
|
||||
("处理到哪里", "进展到哪里"),
|
||||
],
|
||||
"cs_query_logistics": [
|
||||
("物流", "快递"),
|
||||
("快递", "配送"),
|
||||
("配送", "派送"),
|
||||
("帮我", "麻烦帮我"),
|
||||
("现在在哪", "现在到哪了"),
|
||||
],
|
||||
"cs_cancel_order": [
|
||||
("取消", "撤销"),
|
||||
("撤销", "撤单"),
|
||||
("订单", "这单"),
|
||||
("帮我", "麻烦帮我"),
|
||||
("不要发了", "别发了"),
|
||||
],
|
||||
"cs_transfer_human": [
|
||||
("人工客服", "真人客服"),
|
||||
("人工", "人工坐席"),
|
||||
("帮我", "麻烦帮我"),
|
||||
],
|
||||
"cabin_nav_to": [
|
||||
("导航", "带路"),
|
||||
("带我", "送我"),
|
||||
("去", "前往"),
|
||||
],
|
||||
"cabin_set_ac": [
|
||||
("空调", "车里温度"),
|
||||
("调到", "设到"),
|
||||
("温度", "车内温度"),
|
||||
("凉快点", "冷一点"),
|
||||
("暖和点", "热一点"),
|
||||
],
|
||||
"cabin_ac_on": [
|
||||
("空调", "冷气"),
|
||||
("打开", "开"),
|
||||
("冷风", "制冷"),
|
||||
],
|
||||
"cabin_window_open": [
|
||||
("车窗", "窗户"),
|
||||
("打开", "开"),
|
||||
("透透气", "通通风"),
|
||||
],
|
||||
"cabin_window_close": [
|
||||
("关上", "关掉"),
|
||||
("车窗", "窗户"),
|
||||
("关严", "关好"),
|
||||
],
|
||||
"cabin_fan_down": [
|
||||
("风量", "风"),
|
||||
("调小", "调低"),
|
||||
("别吹这么猛", "风小一点"),
|
||||
],
|
||||
"cabin_fan_up": [
|
||||
("风量", "风"),
|
||||
("调大", "调高"),
|
||||
],
|
||||
"cabin_defog_front_on": [
|
||||
("前挡", "前窗"),
|
||||
("除雾", "除一下"),
|
||||
],
|
||||
"cabin_defog_rear_on": [
|
||||
("后挡", "后窗"),
|
||||
("除雾", "清一下雾"),
|
||||
],
|
||||
"cabin_play_music": [
|
||||
("播放", "放"),
|
||||
("来点", "播点"),
|
||||
("我想听", "给我来点"),
|
||||
("放点歌", "来首歌"),
|
||||
],
|
||||
SOCIAL_LABEL: [
|
||||
("你好", "您好"),
|
||||
("哈喽", "hello"),
|
||||
("你叫什么名字", "怎么称呼你"),
|
||||
],
|
||||
OUT_OF_SCOPE_LABEL: [
|
||||
("点个外卖", "叫个外卖"),
|
||||
("订一家酒店", "订个酒店"),
|
||||
("讲个笑话", "说个笑话"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sample:
|
||||
text: str
|
||||
intent_id: str
|
||||
|
||||
|
||||
HARD_NEGATIVE_RAW_SAMPLES: list[tuple[str, str]] = [
|
||||
("订单A700101物流到哪了", "cs_query_logistics"),
|
||||
("帮我看下订单A700102配送到哪里了", "cs_query_logistics"),
|
||||
("订单A700103现在派件了吗", "cs_query_logistics"),
|
||||
("A700104这单物流有没有更新", "cs_query_logistics"),
|
||||
("查一下订单A700105运输轨迹", "cs_query_logistics"),
|
||||
("订单A700106不要了,帮我撤单", "cs_cancel_order"),
|
||||
("A700107这单别发了,直接取消", "cs_cancel_order"),
|
||||
("把订单A700108停掉吧", "cs_cancel_order"),
|
||||
("A700109这个订单我不想要了", "cs_cancel_order"),
|
||||
("订单A700110给我撤回", "cs_cancel_order"),
|
||||
("订单A700111现在受理了吗", "cs_query_order"),
|
||||
("帮我看看A700112这单处理得怎么样了", "cs_query_order"),
|
||||
("A700113订单目前进展如何", "cs_query_order"),
|
||||
("查下订单A700114现在什么情况", "cs_query_order"),
|
||||
("帮我确认订单A700115是否已经处理", "cs_query_order"),
|
||||
("你好呀", SOCIAL_LABEL),
|
||||
("嗨,在吗", SOCIAL_LABEL),
|
||||
("今天天气真不错", SOCIAL_LABEL),
|
||||
("你叫什么名字呀", SOCIAL_LABEL),
|
||||
("你是做什么的", SOCIAL_LABEL),
|
||||
("陪我随便聊聊", SOCIAL_LABEL),
|
||||
("帮我透透气", "cabin_window_open"),
|
||||
("车里太闷了,开下窗", "cabin_window_open"),
|
||||
("把窗户降一点", "cabin_window_open"),
|
||||
("前挡起雾了,除一下", "cabin_defog_front_on"),
|
||||
("后挡有雾,开下除雾", "cabin_defog_rear_on"),
|
||||
("把车里弄凉快点", "cabin_set_ac"),
|
||||
("车里太热了,降一点", "cabin_set_ac"),
|
||||
("空调再冷一点", "cabin_set_ac"),
|
||||
("把里面弄暖和点", "cabin_set_ac"),
|
||||
("风别这么大", "cabin_fan_down"),
|
||||
("别吹这么猛", "cabin_fan_down"),
|
||||
("风再大一点", "cabin_fan_up"),
|
||||
("放点歌", "cabin_play_music"),
|
||||
("来首歌", "cabin_play_music"),
|
||||
("整点音乐", "cabin_play_music"),
|
||||
("透透气", "cabin_window_open"),
|
||||
("通通风", "cabin_window_open"),
|
||||
("车里太闷了", "cabin_window_open"),
|
||||
("凉快点", "cabin_set_ac"),
|
||||
("暖和点", "cabin_set_ac"),
|
||||
("帮我点一份麻辣烫", OUT_OF_SCOPE_LABEL),
|
||||
("给我订今晚的酒店", OUT_OF_SCOPE_LABEL),
|
||||
("帮我买张电影票", OUT_OF_SCOPE_LABEL),
|
||||
("人为什么会做梦", OUT_OF_SCOPE_LABEL),
|
||||
("帮我做个旅游攻略", OUT_OF_SCOPE_LABEL),
|
||||
("帮我点肯德基外卖", OUT_OF_SCOPE_LABEL),
|
||||
("透透气,别给我除雾", "cabin_window_open"),
|
||||
("后挡有雾,不是开窗,是除雾", "cabin_defog_rear_on"),
|
||||
("前挡看不清了,除雾不要开窗", "cabin_defog_front_on"),
|
||||
("凉快点,不是把风量调小", "cabin_set_ac"),
|
||||
("别吹这么猛,不是降温", "cabin_fan_down"),
|
||||
("来首歌,不是切下一首", "cabin_play_music"),
|
||||
("放点歌,不是暂停音乐", "cabin_play_music"),
|
||||
]
|
||||
|
||||
HARD_NEGATIVE_SAMPLES: list[Sample] = [
|
||||
Sample(text=text, intent_id=intent_id) for text, intent_id in HARD_NEGATIVE_RAW_SAMPLES
|
||||
]
|
||||
|
||||
|
||||
class IntentDataset(Dataset):
|
||||
def __init__(self, samples: list[Sample], tokenizer, label_to_id: dict[str, int]) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._label_to_id = label_to_id
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=MAX_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return {
|
||||
"input_ids": encoded["input_ids"].squeeze(0),
|
||||
"attention_mask": encoded["attention_mask"].squeeze(0),
|
||||
"labels": torch.tensor(self._label_to_id[sample.intent_id], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def load_samples(file_path: Path) -> list[Sample]:
|
||||
samples: list[Sample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
samples.append(Sample(text=str(payload["text"]), intent_id=str(payload["intent_id"])))
|
||||
return samples
|
||||
|
||||
|
||||
def load_domain_samples(file_path: Path) -> list[Sample]:
|
||||
if not file_path.exists():
|
||||
return []
|
||||
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
|
||||
intents = payload.get("intents", [])
|
||||
samples: list[Sample] = []
|
||||
seen: set[tuple[str, str]] = set()
|
||||
for item in intents:
|
||||
intent_id = str(item.get("intent_id") or "").strip()
|
||||
if not intent_id:
|
||||
continue
|
||||
seed_texts = list(item.get("examples") or [])
|
||||
seed_texts.extend(item.get("keywords") or [])
|
||||
label = str(item.get("label") or "").strip()
|
||||
if label:
|
||||
seed_texts.append(label)
|
||||
for text in seed_texts:
|
||||
normalized = str(text).strip()
|
||||
if not normalized:
|
||||
continue
|
||||
for variant in expand_seed_variants(normalized):
|
||||
key = (variant, intent_id)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(Sample(text=variant, intent_id=intent_id))
|
||||
return samples
|
||||
|
||||
|
||||
def expand_seed_variants(text: str) -> list[str]:
|
||||
normalized = text.strip().strip(",。!?;; ")
|
||||
if not normalized:
|
||||
return []
|
||||
variants = {
|
||||
normalized,
|
||||
normalized.replace("一下", "").strip(),
|
||||
normalized.replace("帮我", "").strip(),
|
||||
normalized.replace("请", "").strip(),
|
||||
f"帮我{normalized}",
|
||||
f"请{normalized}",
|
||||
f"{normalized}一下",
|
||||
}
|
||||
cleaned: list[str] = []
|
||||
for item in variants:
|
||||
compact = " ".join(item.split()).strip(",。!?;; ")
|
||||
if compact:
|
||||
cleaned.append(compact)
|
||||
return cleaned
|
||||
|
||||
|
||||
def load_training_samples() -> list[Sample]:
|
||||
samples = load_samples(TRAIN_PATH)
|
||||
samples.extend(load_domain_samples(DOMAIN_PATH))
|
||||
deduped: list[Sample] = []
|
||||
seen: set[tuple[str, str]] = set()
|
||||
for sample in samples:
|
||||
key = (sample.text, sample.intent_id)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(sample)
|
||||
return deduped
|
||||
|
||||
|
||||
def augment_samples(samples: list[Sample]) -> list[Sample]:
|
||||
augmented = list(samples)
|
||||
seen = {(sample.text, sample.intent_id) for sample in augmented}
|
||||
for intent_id, templates in TEMPLATES.items():
|
||||
for index, template in enumerate(templates):
|
||||
sample = render_template(intent_id, template, index)
|
||||
key = (sample.text, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(sample)
|
||||
seen.add(key)
|
||||
|
||||
for sample in HARD_NEGATIVE_SAMPLES:
|
||||
key = (sample.text, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(sample)
|
||||
seen.add(key)
|
||||
|
||||
for sample in list(augmented):
|
||||
text = sample.text
|
||||
for source, target in INTENT_REPLACEMENTS.get(sample.intent_id, []):
|
||||
if source in text:
|
||||
variant = text.replace(source, target, 1)
|
||||
key = (variant, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(Sample(text=variant, intent_id=sample.intent_id))
|
||||
seen.add(key)
|
||||
|
||||
compact = text
|
||||
for token in ("帮我", "麻烦", "请", "一下"):
|
||||
if token in compact:
|
||||
compact = compact.replace(token, "", 1)
|
||||
compact = compact.strip(" ,。!?")
|
||||
if compact and compact != text:
|
||||
key = (compact, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(Sample(text=compact, intent_id=sample.intent_id))
|
||||
seen.add(key)
|
||||
random.shuffle(augmented)
|
||||
return augmented
|
||||
|
||||
|
||||
def render_template(intent_id: str, template: str, index: int) -> Sample:
|
||||
order_id = ORDER_IDS[index % len(ORDER_IDS)]
|
||||
destination = DESTINATIONS[index % len(DESTINATIONS)]
|
||||
temperature = TEMPERATURES[index % len(TEMPERATURES)]
|
||||
song = SONGS[index % len(SONGS)]
|
||||
genre = GENRES[index % len(GENRES)]
|
||||
text = template.format(
|
||||
order_id=order_id,
|
||||
destination=destination,
|
||||
temperature=temperature,
|
||||
song=song,
|
||||
genre=genre,
|
||||
)
|
||||
return Sample(text=text, intent_id=intent_id)
|
||||
|
||||
|
||||
def split_samples(samples: list[Sample]) -> tuple[list[Sample], list[Sample]]:
|
||||
grouped: dict[str, list[Sample]] = {}
|
||||
for sample in samples:
|
||||
grouped.setdefault(sample.intent_id, []).append(sample)
|
||||
train_samples: list[Sample] = []
|
||||
dev_samples: list[Sample] = []
|
||||
for items in grouped.values():
|
||||
random.shuffle(items)
|
||||
cut = max(1, int(len(items) * 0.8))
|
||||
train_samples.extend(items[:cut])
|
||||
dev_samples.extend(items[cut:])
|
||||
random.shuffle(train_samples)
|
||||
random.shuffle(dev_samples)
|
||||
return train_samples, dev_samples
|
||||
|
||||
|
||||
def accuracy(model, loader: DataLoader, device: torch.device) -> float:
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
preds = outputs.logits.argmax(dim=-1)
|
||||
correct += int((preds == labels).sum().item())
|
||||
total += int(labels.numel())
|
||||
return correct / total if total else 0.0
|
||||
|
||||
|
||||
def resolve_base_model() -> str:
|
||||
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
return DEFAULT_BASE_MODEL
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed(SEED)
|
||||
samples = augment_samples(load_training_samples())
|
||||
intents = sorted({sample.intent_id for sample in samples})
|
||||
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
|
||||
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
|
||||
|
||||
train_samples, dev_samples = split_samples(samples)
|
||||
base_model = resolve_base_model()
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
train_dataset = IntentDataset(train_samples, tokenizer, label_to_id)
|
||||
dev_dataset = IntentDataset(dev_samples, tokenizer, label_to_id)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
base_model,
|
||||
num_labels=len(intents),
|
||||
id2label=id_to_label,
|
||||
label2id=label_to_id,
|
||||
)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
best_dev_acc = 0.0
|
||||
best_state = None
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += float(loss.item())
|
||||
|
||||
dev_acc = accuracy(model, dev_loader, device)
|
||||
avg_loss = total_loss / max(len(train_loader), 1)
|
||||
print(f"epoch={epoch} loss={avg_loss:.4f} dev_acc={dev_acc:.4f}")
|
||||
if dev_acc >= best_dev_acc:
|
||||
best_dev_acc = dev_acc
|
||||
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(OUTPUT_DIR)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
|
||||
(OUTPUT_DIR / "label_map.json").write_text(
|
||||
json.dumps(label_map, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
train_summary = {
|
||||
"base_model": base_model,
|
||||
"epochs": EPOCHS,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"learning_rate": LEARNING_RATE,
|
||||
"train_size": len(train_samples),
|
||||
"dev_size": len(dev_samples),
|
||||
"best_dev_accuracy": round(best_dev_acc, 4),
|
||||
"device": str(device),
|
||||
}
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(train_summary, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import yaml
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
SOCIAL_LABEL = "__social__"
|
||||
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||||
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}
|
||||
MAX_LENGTH = 48
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 12
|
||||
LEARNING_RATE = 2e-5
|
||||
THRESHOLD = 0.5
|
||||
TOP_K = 4
|
||||
SEED = 42
|
||||
|
||||
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
|
||||
("并", "然后"),
|
||||
("然后", "并"),
|
||||
("顺便", "再"),
|
||||
("再", "顺便"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultiLabelSample:
|
||||
text: str
|
||||
intent_ids: tuple[str, ...]
|
||||
|
||||
|
||||
class MultiLabelIntentDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
samples: list[MultiLabelSample],
|
||||
tokenizer,
|
||||
label_to_id: dict[str, int],
|
||||
) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._label_to_id = label_to_id
|
||||
self._label_size = len(label_to_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=MAX_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = torch.zeros(self._label_size, dtype=torch.float32)
|
||||
for intent_id in sample.intent_ids:
|
||||
labels[self._label_to_id[intent_id]] = 1.0
|
||||
return {
|
||||
"input_ids": encoded["input_ids"].squeeze(0),
|
||||
"attention_mask": encoded["attention_mask"].squeeze(0),
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def resolve_base_model() -> str:
|
||||
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
return DEFAULT_BASE_MODEL
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
return " ".join(str(text).strip().split())
|
||||
|
||||
|
||||
def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
|
||||
cleaned = sorted(
|
||||
{
|
||||
str(intent_id).strip()
|
||||
for intent_id in intent_ids
|
||||
if str(intent_id).strip() and str(intent_id).strip() not in BLOCKED_LABELS
|
||||
}
|
||||
)
|
||||
return tuple(cleaned)
|
||||
|
||||
|
||||
def expand_single_label_variants(text: str) -> list[str]:
|
||||
normalized = text.strip().strip(",。!?;; ")
|
||||
if not normalized:
|
||||
return []
|
||||
variants = {
|
||||
normalized,
|
||||
normalized.replace("一下", "").strip(),
|
||||
normalized.replace("帮我", "").strip(),
|
||||
normalized.replace("请", "").strip(),
|
||||
f"帮我{normalized}",
|
||||
f"请{normalized}",
|
||||
f"{normalized}一下",
|
||||
}
|
||||
cleaned: list[str] = []
|
||||
for item in variants:
|
||||
compact = " ".join(item.split()).strip(",。!?;; ")
|
||||
if compact:
|
||||
cleaned.append(compact)
|
||||
return cleaned
|
||||
|
||||
|
||||
def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
samples: list[MultiLabelSample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
intent_ids = normalize_intent_ids([str(payload.get("intent_id") or "")])
|
||||
if not intent_ids:
|
||||
continue
|
||||
text = normalize_text(str(payload.get("text") or ""))
|
||||
if not text:
|
||||
continue
|
||||
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
if not file_path.exists():
|
||||
return []
|
||||
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
|
||||
intents = payload.get("intents", [])
|
||||
samples: list[MultiLabelSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
for item in intents:
|
||||
intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
|
||||
if not intent_ids:
|
||||
continue
|
||||
seed_texts = list(item.get("examples") or [])
|
||||
seed_texts.extend(item.get("keywords") or [])
|
||||
label = str(item.get("label") or "").strip()
|
||||
if label:
|
||||
seed_texts.append(label)
|
||||
for text in seed_texts:
|
||||
normalized = normalize_text(text)
|
||||
if not normalized:
|
||||
continue
|
||||
for variant in expand_single_label_variants(normalized):
|
||||
key = (variant, intent_ids)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
samples: list[MultiLabelSample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
intent_ids = normalize_intent_ids(list(payload.get("intent_ids") or []))
|
||||
if len(intent_ids) < 2:
|
||||
continue
|
||||
text = normalize_text(str(payload.get("text") or ""))
|
||||
if not text:
|
||||
continue
|
||||
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
|
||||
augmented = list(samples)
|
||||
seen = {(sample.text, sample.intent_ids) for sample in augmented}
|
||||
for sample in list(samples):
|
||||
variants = {
|
||||
sample.text,
|
||||
f"帮我{sample.text}",
|
||||
f"请{sample.text}",
|
||||
sample.text.replace(",", ", "),
|
||||
sample.text.replace(",", ""),
|
||||
}
|
||||
for source, target in CONNECTOR_VARIANTS:
|
||||
if source in sample.text:
|
||||
variants.add(sample.text.replace(source, target, 1))
|
||||
for variant in variants:
|
||||
normalized = normalize_text(variant).strip(",。!?;; ")
|
||||
key = (normalized, sample.intent_ids)
|
||||
if normalized and key not in seen:
|
||||
augmented.append(MultiLabelSample(text=normalized, intent_ids=sample.intent_ids))
|
||||
seen.add(key)
|
||||
return augmented
|
||||
|
||||
|
||||
def load_all_samples() -> list[MultiLabelSample]:
|
||||
samples = load_single_label_samples(SINGLE_LABEL_PATH)
|
||||
samples.extend(load_domain_samples(DOMAIN_PATH))
|
||||
samples.extend(augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)))
|
||||
deduped: list[MultiLabelSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
for sample in samples:
|
||||
key = (sample.text, sample.intent_ids)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(sample)
|
||||
random.shuffle(deduped)
|
||||
return deduped
|
||||
|
||||
|
||||
def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
|
||||
grouped: dict[tuple[str, ...], list[MultiLabelSample]] = {}
|
||||
for sample in samples:
|
||||
grouped.setdefault(sample.intent_ids, []).append(sample)
|
||||
train_samples: list[MultiLabelSample] = []
|
||||
dev_samples: list[MultiLabelSample] = []
|
||||
for items in grouped.values():
|
||||
random.shuffle(items)
|
||||
if len(items) == 1:
|
||||
train_samples.extend(items)
|
||||
continue
|
||||
cut = max(1, int(len(items) * 0.8))
|
||||
if cut >= len(items):
|
||||
cut = len(items) - 1
|
||||
train_samples.extend(items[:cut])
|
||||
dev_samples.extend(items[cut:])
|
||||
if not dev_samples:
|
||||
dev_samples = train_samples[-max(1, min(32, len(train_samples) // 5 or 1)) :]
|
||||
train_samples = train_samples[: len(train_samples) - len(dev_samples)]
|
||||
random.shuffle(train_samples)
|
||||
random.shuffle(dev_samples)
|
||||
return train_samples, dev_samples
|
||||
|
||||
|
||||
def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
|
||||
return torch.sigmoid(logits).detach().cpu().tolist()
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
probabilities: list[list[float]],
|
||||
targets: list[list[float]],
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
recall_at_k_total = 0.0
|
||||
total = len(probabilities)
|
||||
for scores, target in zip(probabilities, targets):
|
||||
predicted = {index for index, score in enumerate(scores) if score >= threshold}
|
||||
expected = {index for index, value in enumerate(target) if value >= 0.5}
|
||||
if predicted == expected:
|
||||
exact_match += 1
|
||||
true_positive += len(predicted & expected)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
|
||||
if expected:
|
||||
recall_at_k_total += len(set(top_indices) & expected) / len(expected)
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
probabilities: list[list[float]] = []
|
||||
targets: list[list[float]] = []
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
total_loss += float(outputs.loss.item())
|
||||
probabilities.extend(logits_to_probabilities(outputs.logits))
|
||||
targets.extend(labels.detach().cpu().tolist())
|
||||
avg_loss = total_loss / max(len(loader), 1)
|
||||
return avg_loss, compute_metrics(probabilities, targets, threshold=threshold, top_k=top_k)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed(SEED)
|
||||
samples = load_all_samples()
|
||||
intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
|
||||
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
|
||||
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
|
||||
|
||||
train_samples, dev_samples = split_samples(samples)
|
||||
base_model = resolve_base_model()
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
|
||||
dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
base_model,
|
||||
num_labels=len(intents),
|
||||
id2label=id_to_label,
|
||||
label2id=label_to_id,
|
||||
problem_type="multi_label_classification",
|
||||
)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
best_dev_f1 = 0.0
|
||||
best_state = None
|
||||
best_metrics: dict[str, float] = {}
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += float(loss.item())
|
||||
|
||||
dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
|
||||
avg_loss = total_loss / max(len(train_loader), 1)
|
||||
print(
|
||||
" ".join(
|
||||
[
|
||||
f"epoch={epoch}",
|
||||
f"train_loss={avg_loss:.4f}",
|
||||
f"dev_loss={dev_loss:.4f}",
|
||||
f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
|
||||
f"dev_exact_match={dev_metrics['exact_match']:.4f}",
|
||||
]
|
||||
)
|
||||
)
|
||||
if dev_metrics["micro_f1"] >= best_dev_f1:
|
||||
best_dev_f1 = dev_metrics["micro_f1"]
|
||||
best_metrics = dict(dev_metrics)
|
||||
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(OUTPUT_DIR)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
|
||||
(OUTPUT_DIR / "label_map.json").write_text(
|
||||
json.dumps(label_map, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
train_summary = {
|
||||
"task_type": "multi_label_intent_detection",
|
||||
"base_model": base_model,
|
||||
"epochs": EPOCHS,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"learning_rate": LEARNING_RATE,
|
||||
"threshold": THRESHOLD,
|
||||
"top_k": TOP_K,
|
||||
"train_size": len(train_samples),
|
||||
"dev_size": len(dev_samples),
|
||||
"label_count": len(intents),
|
||||
"labels": intents,
|
||||
"best_dev_metrics": best_metrics,
|
||||
"device": str(device),
|
||||
}
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(train_summary, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
132
intelligent_cabin/archive/tests/test_agent_cloud_route.py
Normal file
132
intelligent_cabin/archive/tests/test_agent_cloud_route.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.schemas.chat import ChatRequest
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.planner import PlanningResult
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
|
||||
|
||||
class _RouteToCloudRouter:
|
||||
def route(self, text: str):
|
||||
_ = text
|
||||
return type(
|
||||
"RouteResult",
|
||||
(),
|
||||
{
|
||||
"intent": None,
|
||||
"debug": RoutingDebug(
|
||||
selected_intent="cabin_nav_to",
|
||||
matched_stage="fusion",
|
||||
decision="route_to_cloud",
|
||||
decision_reason="local signal is not stable enough, routing to cloud planner",
|
||||
confidence_grade="low",
|
||||
stages=[
|
||||
MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=False,
|
||||
selected_intent="cabin_nav_to",
|
||||
score=0.88,
|
||||
reason="route to cloud",
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="cabin_nav_to", score=0.88, reason="fusion", model_name="fusion"),
|
||||
IntentCandidate(intent_id="cabin_play_music", score=0.75, reason="fusion", model_name="fusion"),
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
},
|
||||
)()
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
|
||||
_ = (text, intent)
|
||||
return {}
|
||||
|
||||
|
||||
class _PlannerRejects:
|
||||
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
|
||||
_ = (text, intents, context)
|
||||
return PlanningResult(
|
||||
accepted=False,
|
||||
workflow_type="single",
|
||||
model_name="qwen3.5-plus",
|
||||
backend="dashscope",
|
||||
reason="cloud planner could not produce a stable executable step",
|
||||
)
|
||||
|
||||
|
||||
class _PlannerOutOfScope:
|
||||
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
|
||||
_ = (text, intents, context)
|
||||
return PlanningResult(
|
||||
accepted=False,
|
||||
workflow_type="single",
|
||||
model_name="qwen3.5-plus",
|
||||
backend="dashscope",
|
||||
reason="The provided intent catalog only contains cabin and service actions. There is no matching intent for ordering food via a third-party app action.",
|
||||
)
|
||||
|
||||
|
||||
def _intent(intent_id: str) -> IntentDefinition:
|
||||
return IntentDefinition(
|
||||
intent_id=intent_id,
|
||||
plugin_id=f"mock.{intent_id}",
|
||||
domain="cabin",
|
||||
keywords=[],
|
||||
examples=[],
|
||||
)
|
||||
|
||||
|
||||
class AgentCloudRouteTests(unittest.TestCase):
|
||||
def test_route_to_cloud_returns_explicit_clarify_feedback_when_planner_does_not_accept(self) -> None:
|
||||
service = AgentService(
|
||||
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
|
||||
router=_RouteToCloudRouter(),
|
||||
plugins=PluginRegistry(),
|
||||
session_store=InMemorySessionStore(),
|
||||
planner=_PlannerRejects(),
|
||||
)
|
||||
|
||||
response = service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_cloud_route",
|
||||
user_id="user_1",
|
||||
input_text="带我过去",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "route_to_cloud")
|
||||
self.assertEqual(response.reply_type, "clarify")
|
||||
self.assertEqual(response.status, "route_to_cloud")
|
||||
self.assertIn("请确认一下", response.reply_text)
|
||||
|
||||
def test_route_to_cloud_rejects_when_planner_marks_request_out_of_scope(self) -> None:
|
||||
service = AgentService(
|
||||
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
|
||||
router=_RouteToCloudRouter(),
|
||||
plugins=PluginRegistry(),
|
||||
session_store=InMemorySessionStore(),
|
||||
planner=_PlannerOutOfScope(),
|
||||
)
|
||||
|
||||
response = service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_cloud_route_reject",
|
||||
user_id="user_1",
|
||||
input_text="去美团叫个外卖",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "reject")
|
||||
self.assertEqual(response.decision, "reject")
|
||||
self.assertEqual(response.status, "rejected")
|
||||
self.assertIn("做不了", response.reply_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
235
intelligent_cabin/archive/tests/test_bert.py
Normal file
235
intelligent_cabin/archive/tests/test_bert.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
from app.services.router import build_matcher_pipeline
|
||||
|
||||
|
||||
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
|
||||
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
|
||||
|
||||
|
||||
def resolve_model_path(model_path: str) -> Path:
|
||||
configured = (model_path or settings.classifier_model_path).strip()
|
||||
if configured:
|
||||
return Path(configured)
|
||||
return DEFAULT_MODEL_DIR
|
||||
|
||||
|
||||
def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
|
||||
configured = (label_map_path or settings.classifier_label_map_path).strip()
|
||||
if configured:
|
||||
return Path(configured)
|
||||
return model_path / "label_map.json"
|
||||
|
||||
|
||||
def build_classifier(
|
||||
*,
|
||||
model_path: Path,
|
||||
label_map_path: Path,
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> BertIntentClassifier:
|
||||
return BertIntentClassifier(
|
||||
model_path=str(model_path),
|
||||
threshold=threshold,
|
||||
label_map_path=str(label_map_path),
|
||||
fallback=None,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def predict_once(
|
||||
text: str,
|
||||
*,
|
||||
model_path: Path,
|
||||
label_map_path: Path,
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
warmup: bool,
|
||||
) -> dict[str, object]:
|
||||
registry = build_intent_registry()
|
||||
classifier = build_classifier(
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
)
|
||||
warmup_ok = None
|
||||
if warmup:
|
||||
warmup_ok = classifier.warmup(settings.classifier_warmup_text)
|
||||
if not warmup_ok:
|
||||
error_message = getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
intents = registry.list()
|
||||
result = classifier.predict(text, intents)
|
||||
matcher = build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=classifier,
|
||||
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
|
||||
clarify_margin_threshold=settings.local_clarify_margin_threshold,
|
||||
classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
|
||||
classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
|
||||
)
|
||||
route_result = matcher.match(text)
|
||||
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
|
||||
classifier_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "classifier"), None)
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"config": {
|
||||
"model_path": str(model_path),
|
||||
"label_map_path": str(label_map_path),
|
||||
"threshold": threshold,
|
||||
"top_k": top_k,
|
||||
"warmup_requested": warmup,
|
||||
"warmup_ok": warmup_ok,
|
||||
"warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
|
||||
"warmup_error_message": getattr(classifier, "_warmup_error_message", None),
|
||||
},
|
||||
"classifier_result": {
|
||||
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
|
||||
"score": round(result.score, 4),
|
||||
"model_name": result.model_name,
|
||||
"backend_name": result.backend_name,
|
||||
"used_fallback": result.used_fallback,
|
||||
"fallback_reason": result.fallback_reason,
|
||||
"error_message": result.error_message,
|
||||
"raw_label": result.raw_label,
|
||||
"raw_candidates": result.raw_candidates or [],
|
||||
"known_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
},
|
||||
"route_result": {
|
||||
"decision": route_result.debug.decision,
|
||||
"decision_reason": route_result.debug.decision_reason,
|
||||
"matched_stage": route_result.debug.matched_stage,
|
||||
"selected_intent": route_result.debug.selected_intent,
|
||||
"confidence_grade": route_result.debug.confidence_grade,
|
||||
"unknown_detected": route_result.debug.unknown_detected,
|
||||
"classifier_score": round(classifier_stage.score, 4) if classifier_stage is not None else None,
|
||||
"fusion_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
|
||||
classifier_result = dict(result.get("classifier_result") or {})
|
||||
route_result = dict(result.get("route_result") or {})
|
||||
predicted_intent = classifier_result.get("predicted_intent")
|
||||
raw_label = classifier_result.get("raw_label")
|
||||
effective_label = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent
|
||||
if effective_label in NON_BUSINESS_LABELS:
|
||||
classifier_result["predicted_intent"] = None
|
||||
classifier_result["non_business_label"] = effective_label
|
||||
classifier_result["business_interpretation"] = "non_business_label_detected"
|
||||
route_result["selected_intent"] = None
|
||||
route_result["decision"] = "reject"
|
||||
route_result["decision_reason"] = "classifier detected a non-business label"
|
||||
route_result["unknown_detected"] = True
|
||||
else:
|
||||
classifier_result["non_business_label"] = None
|
||||
classifier_result["business_interpretation"] = "known_business_intent_or_uncertain"
|
||||
return {
|
||||
"text": result.get("text"),
|
||||
"config": result.get("config"),
|
||||
"classifier_result": classifier_result,
|
||||
"route_result": route_result,
|
||||
}
|
||||
|
||||
|
||||
def interactive_loop(
|
||||
*,
|
||||
model_path: Path,
|
||||
label_map_path: Path,
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
warmup: bool,
|
||||
mode: str,
|
||||
) -> None:
|
||||
print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
|
||||
print(f"model_path={model_path}")
|
||||
print(f"label_map_path={label_map_path}")
|
||||
print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")
|
||||
while True:
|
||||
try:
|
||||
text = input("\n请输入问题> ").strip()
|
||||
except EOFError:
|
||||
print()
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
if text.lower() in {"exit", "quit", "q"}:
|
||||
break
|
||||
result = predict_once(
|
||||
text,
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
warmup=warmup,
|
||||
)
|
||||
if mode == "business":
|
||||
result = summarize_business_view(result)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本")
|
||||
parser.add_argument("--text", type=str, default="", help="单次测试文本")
|
||||
parser.add_argument("--threshold", type=float, default=settings.classifier_bert_threshold, help="BERT 置信度阈值")
|
||||
parser.add_argument("--top-k", type=int, default=settings.classifier_top_k, help="返回候选数量")
|
||||
parser.add_argument("--model-path", type=str, default="", help="模型目录,默认取 .env 或 models/local_bert_intent")
|
||||
parser.add_argument("--label-map-path", type=str, default="", help="标签映射文件,默认取 .env 或 model_path/label_map.json")
|
||||
parser.add_argument("--warmup", action="store_true", help="先执行一次 warmup 再预测")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=("classifier", "business"),
|
||||
default="classifier",
|
||||
help="classifier 显示原始分类结果;business 会把非业务标签折叠成未命中业务意图",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = resolve_model_path(args.model_path)
|
||||
label_map_path = resolve_label_map_path(args.label_map_path, model_path)
|
||||
|
||||
if args.text.strip():
|
||||
result = predict_once(
|
||||
args.text.strip(),
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=args.threshold,
|
||||
top_k=args.top_k,
|
||||
warmup=args.warmup,
|
||||
)
|
||||
if args.mode == "business":
|
||||
result = summarize_business_view(result)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
return
|
||||
|
||||
interactive_loop(
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=args.threshold,
|
||||
top_k=args.top_k,
|
||||
warmup=args.warmup,
|
||||
mode=args.mode,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109
intelligent_cabin/archive/tests/test_chat_stream.py
Normal file
109
intelligent_cabin/archive/tests/test_chat_stream.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
os.environ["AGENT_CLASSIFIER_BACKEND"] = "mock"
|
||||
os.environ["AGENT_CLASSIFIER_WARMUP_ENABLED"] = "false"
|
||||
|
||||
from app.main import app
|
||||
from app.schemas.chat import ChatResponse
|
||||
|
||||
|
||||
def _fake_response() -> ChatResponse:
|
||||
return ChatResponse(
|
||||
session_id="sess_stream_1",
|
||||
reply_type="workflow_result",
|
||||
reply_text="好,空调已经打开了。",
|
||||
intent="cabin_ac_on",
|
||||
status="completed",
|
||||
trace_id="trace_stream_1",
|
||||
)
|
||||
|
||||
|
||||
class ChatStreamTests(unittest.TestCase):
|
||||
def test_chat_stream_returns_final_only_when_fast(self) -> None:
|
||||
client = TestClient(app)
|
||||
with patch("app.main.agent_service.handle_chat", return_value=_fake_response()):
|
||||
response = client.post(
|
||||
"/api/v1/agent/chat-stream",
|
||||
json={
|
||||
"session_id": "sess_stream_1",
|
||||
"user_id": "user_stream_1",
|
||||
"channel": "test",
|
||||
"input_text": "打开车窗",
|
||||
"input_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
|
||||
self.assertEqual(len(lines), 1)
|
||||
final_event = json.loads(lines[0])
|
||||
self.assertEqual(final_event.get("type"), "final")
|
||||
|
||||
def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
def _slow_handle_chat(_request):
|
||||
time.sleep(1.2)
|
||||
return _fake_response()
|
||||
|
||||
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
|
||||
response = client.post(
|
||||
"/api/v1/agent/chat-stream",
|
||||
json={
|
||||
"session_id": "sess_stream_1",
|
||||
"user_id": "user_stream_1",
|
||||
"channel": "test",
|
||||
"input_text": "打开车窗",
|
||||
"input_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
|
||||
self.assertGreaterEqual(len(lines), 2)
|
||||
|
||||
ack_event = json.loads(lines[0])
|
||||
final_event = json.loads(lines[-1])
|
||||
self.assertEqual(ack_event.get("type"), "ack")
|
||||
self.assertEqual(final_event.get("type"), "final")
|
||||
self.assertIn("data", final_event)
|
||||
self.assertIn("reply_text", final_event["data"])
|
||||
|
||||
def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
def _slow_handle_chat(_request):
|
||||
time.sleep(1.2)
|
||||
return _fake_response()
|
||||
|
||||
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
|
||||
response = client.post(
|
||||
"/api/v1/agent/chat-stream",
|
||||
json={
|
||||
"session_id": "sess_stream_1",
|
||||
"user_id": "user_stream_1",
|
||||
"channel": "test",
|
||||
"input_text": "今天天气如何",
|
||||
"input_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
|
||||
self.assertGreaterEqual(len(lines), 2)
|
||||
ack_event = json.loads(lines[0])
|
||||
final_event = json.loads(lines[-1])
|
||||
self.assertEqual(ack_event.get("type"), "ack")
|
||||
self.assertEqual(final_event.get("type"), "final")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
90
intelligent_cabin/archive/tests/test_config_loader.py
Normal file
90
intelligent_cabin/archive/tests/test_config_loader.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.core.bootstrap import build_planner, load_runtime_bundle
|
||||
from app.core.config import settings
|
||||
from app.services.planner import CompositeWorkflowPlanner
|
||||
from app.services.config_loader import ConfigLoader
|
||||
from app.services.dialog_rules import DialogRuleEngine
|
||||
from app.services.response_policy import ResponsePolicy
|
||||
|
||||
|
||||
class ConfigLoaderTests(unittest.TestCase):
|
||||
def test_loader_reads_domain_actions_and_responses(self) -> None:
|
||||
bundle = ConfigLoader(
|
||||
domain_path="config/domain.yml",
|
||||
action_path="config/actions.yml",
|
||||
response_path="config/responses.yml",
|
||||
form_path="config/forms.yml",
|
||||
rule_path="config/rules.yml",
|
||||
dialog_act_path="config/dialog_acts.yml",
|
||||
workflow_path="config/workflows.yml",
|
||||
legacy_intent_path="app/data/intents.json",
|
||||
).load()
|
||||
|
||||
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
|
||||
self.assertEqual(bundle.intent_registry.get("cabin_window_open").plugin_id, "plugin.cabin.window.open")
|
||||
self.assertEqual(bundle.intent_hints.get("cabin_window_open"), "打开车窗")
|
||||
self.assertEqual(bundle.response_templates.get("task_stopped"), "好的,已停止当前任务。")
|
||||
self.assertEqual(bundle.intent_registry.get("cabin_set_ac").required_slots, ["temperature"])
|
||||
self.assertTrue(bundle.dialog_rules.is_stop_request("先不要了"))
|
||||
self.assertEqual(bundle.dialog_rules.parse_confirmation_decision("确认"), True)
|
||||
self.assertEqual(bundle.dialog_act_engine.detect("你好"), "chitchat")
|
||||
self.assertGreaterEqual(len(bundle.workflow_templates.templates), 2)
|
||||
|
||||
def test_bootstrap_runtime_bundle_is_available(self) -> None:
|
||||
bundle = load_runtime_bundle()
|
||||
|
||||
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
|
||||
self.assertIn("fallback", bundle.response_templates)
|
||||
self.assertEqual(bundle.dialog_act_engine.detect("确认"), "affirm")
|
||||
|
||||
def test_response_policy_can_be_driven_by_config_templates(self) -> None:
|
||||
policy = ResponsePolicy(
|
||||
templates={"reject": "这个能力暂未开通。"},
|
||||
intent_hints={"cabin_window_open": "开车窗"},
|
||||
)
|
||||
|
||||
self.assertEqual(policy.reject(), "这个能力暂未开通。")
|
||||
self.assertEqual(policy.clarify(["cabin_window_open"]), "请确认一下,你是想开车窗吗?")
|
||||
|
||||
def test_response_policy_formats_multi_step_summary_naturally(self) -> None:
|
||||
policy = ResponsePolicy()
|
||||
|
||||
summary = policy.workflow_summary(["好的,已打开空调。", "已将空调调到 20 度。"])
|
||||
|
||||
self.assertEqual(summary, "好,空调已经打开了,也调到 20 度了。")
|
||||
|
||||
def test_response_policy_formats_multi_step_summary_in_vehicle_style(self) -> None:
|
||||
policy = ResponsePolicy()
|
||||
|
||||
summary = policy.workflow_summary(["好的,已打开空调。", "好的,已关闭车窗。"])
|
||||
|
||||
self.assertEqual(summary, "好,空调已经打开了,车窗也帮你关上了。")
|
||||
|
||||
def test_build_planner_prefers_local_planners_before_cloud(self) -> None:
|
||||
with patch.object(settings, "planner_backend", "dashscope"):
|
||||
planner = build_planner()
|
||||
|
||||
self.assertIsInstance(planner, CompositeWorkflowPlanner)
|
||||
self.assertIsInstance(planner._planners[0], CompositeWorkflowPlanner)
|
||||
|
||||
def test_dialog_rule_engine_supports_configured_confirmation_and_stop(self) -> None:
|
||||
rules = DialogRuleEngine(
|
||||
stop_phrases=("先不用了",),
|
||||
positive_confirmation_tokens=("好,继续",),
|
||||
negative_confirmation_tokens=("取消吧",),
|
||||
confirmation_required_intents=("foo",),
|
||||
confirmation_required_risk_levels=("high",),
|
||||
)
|
||||
|
||||
self.assertTrue(rules.is_stop_request("先不用了"))
|
||||
self.assertTrue(rules.parse_confirmation_decision("好,继续"))
|
||||
self.assertFalse(rules.parse_confirmation_decision("取消吧"))
|
||||
self.assertTrue(rules.requires_confirmation("foo", "low"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.plugins.mock import MockPluginExecutor
|
||||
from app.schemas.chat import ChatRequest
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
|
||||
from app.schemas.workflow import Workflow, WorkflowStep
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.planner import HeuristicWorkflowPlanner
|
||||
from app.services.response_policy import ResponsePolicy
|
||||
from app.services.rewrite_engine import ContextRewriteEngine
|
||||
from app.services.router import RouteMatchResult
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
|
||||
|
||||
class _FailIfCalledPlanner:
|
||||
def plan(self, text, intents, context=None):
|
||||
_ = (intents, context)
|
||||
raise AssertionError(f"planner should not be called for single intent request: {text}")
|
||||
|
||||
|
||||
class _ScriptedRouter:
|
||||
def __init__(self, registry: IntentRegistry) -> None:
|
||||
self._registry = registry
|
||||
self._route_map = {
|
||||
"来点music": self._route_result("cabin_play_music", ["cabin_play_music"]),
|
||||
"打开车窗和空调": self._route_result("cabin_window_open", ["cabin_window_open", "cabin_ac_on"]),
|
||||
"关闭车窗": self._route_result("cabin_window_close", ["cabin_window_close", "cabin_window_open"]),
|
||||
}
|
||||
self._slot_map = {
|
||||
("播放黄昏", "cabin_play_music"): {"song": "黄昏"},
|
||||
("来一首黄昏", "cabin_play_music"): {"song": "黄昏"},
|
||||
("来点黄昏", "cabin_play_music"): {"song": "黄昏"},
|
||||
}
|
||||
|
||||
def route(self, text: str) -> RouteMatchResult:
|
||||
if text not in self._route_map:
|
||||
raise AssertionError(f"unexpected route request: {text}")
|
||||
return self._route_map[text]
|
||||
|
||||
def extract_slots(self, text: str, intent) -> dict[str, object]:
|
||||
return dict(self._slot_map.get((text, intent.intent_id), {}))
|
||||
|
||||
def _route_result(self, selected_intent: str, candidates: list[str]) -> RouteMatchResult:
|
||||
intent = self._registry.get(selected_intent)
|
||||
stage = MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=True,
|
||||
selected_intent=selected_intent,
|
||||
score=1.0,
|
||||
candidates=[
|
||||
IntentCandidate(intent_id=intent_id, score=max(0.5, 1.0 - index * 0.1))
|
||||
for index, intent_id in enumerate(candidates)
|
||||
],
|
||||
)
|
||||
return RouteMatchResult(
|
||||
intent=intent,
|
||||
debug=RoutingDebug(
|
||||
selected_intent=selected_intent,
|
||||
matched_stage="fusion",
|
||||
decision="execute",
|
||||
stages=[stage],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DialogContinuationAndMultiIntentTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
self.plugins = PluginRegistry()
|
||||
MockPluginExecutor().register(self.plugins)
|
||||
self.service = AgentService(
|
||||
intent_registry=self.registry,
|
||||
router=_ScriptedRouter(self.registry),
|
||||
plugins=self.plugins,
|
||||
session_store=InMemorySessionStore(),
|
||||
rewrite_engine=ContextRewriteEngine(),
|
||||
response_policy=ResponsePolicy(),
|
||||
planner=HeuristicWorkflowPlanner(),
|
||||
)
|
||||
|
||||
def test_music_followup_in_chat_continues_waiting_slot(self) -> None:
|
||||
first = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_music_followup",
|
||||
user_id="user_1",
|
||||
input_text="来点music",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(first.reply_type, "ask_slot")
|
||||
self.assertEqual(first.pending_slots, ["media_query"])
|
||||
|
||||
second = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_music_followup",
|
||||
user_id="user_1",
|
||||
input_text="黄昏",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(second.reply_type, "workflow_result")
|
||||
self.assertEqual(second.intent, "cabin_play_music")
|
||||
self.assertEqual(second.filled_slots.get("song"), "黄昏")
|
||||
self.assertIn("黄昏", second.reply_text)
|
||||
|
||||
def test_parallel_compound_request_enters_planner(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_parallel_compound",
|
||||
user_id="user_1",
|
||||
input_text="打开车窗和空调",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.workflow.workflow_type, "sequence")
|
||||
step_intents = [step.intent_id for step in response.workflow.steps]
|
||||
self.assertEqual(step_intents, ["cabin_window_open", "cabin_ac_on"])
|
||||
self.assertIn("车窗", response.reply_text)
|
||||
self.assertIn("空调", response.reply_text)
|
||||
|
||||
def test_single_cabin_intent_does_not_enter_planner_from_top2_domain_candidates(self) -> None:
|
||||
service = AgentService(
|
||||
intent_registry=self.registry,
|
||||
router=_ScriptedRouter(self.registry),
|
||||
plugins=self.plugins,
|
||||
session_store=InMemorySessionStore(),
|
||||
rewrite_engine=ContextRewriteEngine(),
|
||||
response_policy=ResponsePolicy(),
|
||||
planner=_FailIfCalledPlanner(),
|
||||
)
|
||||
|
||||
response = service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_single_cabin_no_planner",
|
||||
user_id="user_1",
|
||||
input_text="关闭车窗",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_window_close")
|
||||
self.assertEqual(response.routing_debug.decision, "execute")
|
||||
self.assertFalse(any(stage.stage == "planner" for stage in response.routing_debug.stages))
|
||||
|
||||
def test_waiting_confirmation_can_continue_via_chat(self) -> None:
|
||||
session = self.service.session_store.get_or_create("sess_confirm_chat", "user_1")
|
||||
session.current_intent = "cs_cancel_order"
|
||||
session.status = "waiting_confirmation"
|
||||
session.pending_slots = ["confirmation"]
|
||||
session.slots = {"order_id": "A123456"}
|
||||
session.workflow = Workflow(
|
||||
workflow_id="wf_confirm_chat",
|
||||
workflow_type="conditional",
|
||||
domain="customer_service",
|
||||
intent_id="cs_cancel_order",
|
||||
status="waiting_confirmation",
|
||||
risk_level="high",
|
||||
slots={"order_id": "A123456"},
|
||||
steps=[
|
||||
WorkflowStep(
|
||||
step=1,
|
||||
step_id="step_cancel",
|
||||
intent_id="cs_cancel_order",
|
||||
plugin_id="plugin.order.cancel",
|
||||
action="cancel_order",
|
||||
status="waiting_confirmation",
|
||||
slots={"order_id": "A123456"},
|
||||
requires_confirmation=True,
|
||||
)
|
||||
],
|
||||
meta={
|
||||
"pending_confirmation": {
|
||||
"step_id": "step_cancel",
|
||||
"intent_id": "cs_cancel_order",
|
||||
"detail": "确认取消订单 A123456",
|
||||
},
|
||||
"step_results": {},
|
||||
"confirmed_steps": [],
|
||||
},
|
||||
).model_dump()
|
||||
self.service.session_store.save(session)
|
||||
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_confirm_chat",
|
||||
user_id="user_1",
|
||||
input_text="确认",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.status, "completed")
|
||||
self.assertIn("A123456", response.reply_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
149
intelligent_cabin/archive/tests/test_intent_coverage_and_stop.py
Normal file
149
intelligent_cabin/archive/tests/test_intent_coverage_and_stop.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.plugins.mock import MockPluginExecutor
|
||||
from app.services.classifier import MockIntentClassifier
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.response_policy import ResponsePolicy
|
||||
from app.services.rewrite_engine import ContextRewriteEngine
|
||||
from app.services.router import HeuristicSlotExtractor, IntentRouter, build_matcher_pipeline
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
from app.schemas.chat import ChatRequest, FillSlotsRequest
|
||||
|
||||
|
||||
class _BertLikeMockClassifier(MockIntentClassifier):
|
||||
def predict(self, text, intents):
|
||||
result = super().predict(text, intents)
|
||||
result.model_name = "bert-local"
|
||||
result.backend_name = "bert-local"
|
||||
return result
|
||||
|
||||
|
||||
def _build_service() -> AgentService:
|
||||
registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
plugins = PluginRegistry()
|
||||
MockPluginExecutor().register(plugins)
|
||||
router = IntentRouter(
|
||||
matcher=build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=_BertLikeMockClassifier(threshold=0.0, top_k=3),
|
||||
),
|
||||
slot_extractor=HeuristicSlotExtractor(),
|
||||
)
|
||||
return AgentService(
|
||||
intent_registry=registry,
|
||||
router=router,
|
||||
plugins=plugins,
|
||||
session_store=InMemorySessionStore(),
|
||||
rewrite_engine=ContextRewriteEngine(),
|
||||
response_policy=ResponsePolicy(),
|
||||
planner=None,
|
||||
)
|
||||
|
||||
|
||||
class IntentCoverageAndStopTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.service = _build_service()
|
||||
self.registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
|
||||
def test_intent_catalog_has_at_least_30_items(self) -> None:
|
||||
self.assertGreaterEqual(len(self.registry.list()), 30)
|
||||
|
||||
def test_close_ac_routes_to_power_off(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_close_ac",
|
||||
user_id="user_1",
|
||||
input_text="关闭空调",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_ac_off")
|
||||
self.assertIn("已关闭空调", response.reply_text)
|
||||
|
||||
def test_open_window_is_covered(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_window_open",
|
||||
user_id="user_1",
|
||||
input_text="打开车窗",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_window_open")
|
||||
self.assertIn("已打开车窗", response.reply_text)
|
||||
|
||||
def test_stop_current_task_while_waiting_for_slot(self) -> None:
|
||||
first = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_stop_task",
|
||||
user_id="user_1",
|
||||
input_text="空调调到",
|
||||
)
|
||||
)
|
||||
self.assertEqual(first.reply_type, "ask_slot")
|
||||
self.assertEqual(first.pending_slots, ["temperature"])
|
||||
|
||||
stopped = self.service.handle_fill_slots(
|
||||
FillSlotsRequest(
|
||||
session_id="sess_stop_task",
|
||||
user_id="user_1",
|
||||
input_text="不用了",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(stopped.reply_type, "text")
|
||||
self.assertEqual(stopped.status, "stopped")
|
||||
self.assertEqual(stopped.pending_slots, [])
|
||||
self.assertIn("已停止当前任务", stopped.reply_text)
|
||||
|
||||
def test_relative_ac_adjustment_uses_two_degree_step(self) -> None:
|
||||
session = self.service.session_store.get_or_create("sess_ac_lower", "user_1")
|
||||
session.current_intent = "cabin_set_ac"
|
||||
session.context_memory["last_temperature"] = 24
|
||||
self.service.session_store.save(session)
|
||||
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_ac_lower",
|
||||
user_id="user_1",
|
||||
input_text="调低一点",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_set_ac")
|
||||
self.assertEqual(response.filled_slots.get("temperature"), 22)
|
||||
self.assertIn("22", response.reply_text)
|
||||
|
||||
def test_relative_ac_adjustment_without_history_uses_default_baseline(self) -> None:
|
||||
session = self.service.session_store.get_or_create("sess_ac_lower_default", "user_1")
|
||||
session.current_intent = "cabin_ac_on"
|
||||
self.service.session_store.save(session)
|
||||
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_ac_lower_default",
|
||||
user_id="user_1",
|
||||
input_text="调低一点",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_set_ac")
|
||||
self.assertEqual(response.filled_slots.get("temperature"), 22)
|
||||
self.assertIn("22", response.reply_text)
|
||||
|
||||
def test_temperature_is_clamped_before_execution(self) -> None:
|
||||
self.assertEqual(self.service._normalize_temperature_value(-1), 16)
|
||||
self.assertEqual(self.service._normalize_temperature_value(40), 30)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.classifier import JointBertIntentClassifier
|
||||
from app.services.planner import HeuristicWorkflowPlanner
|
||||
from app.services.router import JointBertSlotExtractor
|
||||
|
||||
|
||||
class FakeJointNLU:
|
||||
def __init__(self) -> None:
|
||||
self._predictions = {
|
||||
"把空调调到22度": {
|
||||
"intent_id": "cabin_set_ac",
|
||||
"intent_score": 0.93,
|
||||
"candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
|
||||
"slots": {"temperature": 22},
|
||||
},
|
||||
"导航去公司,然后把空调调到22度": {
|
||||
"intent_id": "cabin_nav_to",
|
||||
"intent_score": 0.88,
|
||||
"candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
|
||||
"slots": {"destination": "公司"},
|
||||
},
|
||||
}
|
||||
self._slot_predictions = {
|
||||
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
|
||||
("导航去公司", "cabin_nav_to"): {"destination": "公司"},
|
||||
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
|
||||
}
|
||||
|
||||
def warmup(self, sample_text: str = "打开车窗") -> bool:
|
||||
_ = sample_text
|
||||
return True
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]):
|
||||
from app.services.joint_nlu import JointCandidate, JointNluResult
|
||||
|
||||
raw = self._predictions[text]
|
||||
candidates = [JointCandidate(intent_id=intent_id, score=score) for intent_id, score in raw["candidates"]]
|
||||
return JointNluResult(
|
||||
intent_id=raw["intent_id"],
|
||||
intent_score=raw["intent_score"],
|
||||
candidates=candidates,
|
||||
slots=dict(raw["slots"]),
|
||||
)
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition):
|
||||
return dict(self._slot_predictions.get((text, intent.intent_id), {}))
|
||||
|
||||
def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
|
||||
_ = required_slots
|
||||
return dict(self._slot_predictions.get((text, intent_id), {}))
|
||||
|
||||
|
||||
class JointNLUIntegrationTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.intents = [
|
||||
IntentDefinition(intent_id="cabin_set_ac", plugin_id="x", domain="cabin", required_slots=["temperature"]),
|
||||
IntentDefinition(intent_id="cabin_nav_to", plugin_id="x", domain="cabin", required_slots=["destination"]),
|
||||
IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
|
||||
]
|
||||
self.fake_nlu = FakeJointNLU()
|
||||
|
||||
def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
|
||||
classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)
|
||||
|
||||
result = classifier.predict("把空调调到22度", self.intents)
|
||||
|
||||
self.assertIsNotNone(result.intent)
|
||||
self.assertEqual(result.intent.intent_id, "cabin_set_ac")
|
||||
self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")
|
||||
|
||||
def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
|
||||
extractor = JointBertSlotExtractor(self.fake_nlu)
|
||||
|
||||
slots = extractor.extract("把空调调到22度", self.intents[0])
|
||||
|
||||
self.assertEqual(slots, {"temperature": 22})
|
||||
|
||||
def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)
|
||||
|
||||
result = planner.plan("导航去公司,然后把空调调到22度", self.intents)
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.steps[0].slots, {"destination": "公司"})
|
||||
self.assertEqual(result.steps[1].slots, {"temperature": 22})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
144
intelligent_cabin/archive/tests/test_multi_intent_detector.py
Normal file
144
intelligent_cabin/archive/tests/test_multi_intent_detector.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
|
||||
_ = (text, truncation, padding, return_tensors)
|
||||
return {
|
||||
"input_ids": torch.tensor([[101, 102]], dtype=torch.long),
|
||||
"attention_mask": torch.tensor([[1, 1]], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
class FakeModel:
|
||||
def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
|
||||
self.config = type("Config", (), {"id2label": id2label})()
|
||||
self._logits = torch.tensor([logits], dtype=torch.float32)
|
||||
|
||||
def eval(self) -> None:
|
||||
return None
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
_ = kwargs
|
||||
return type("Output", (), {"logits": self._logits})()
|
||||
|
||||
|
||||
class RuntimeBackedDetector(BertMultiIntentDetector):
|
||||
def __init__(
|
||||
self,
|
||||
logits: list[float],
|
||||
id2label: dict[int, str],
|
||||
threshold: float = 0.45,
|
||||
top_k: int = 8,
|
||||
max_labels: int = 4,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model_path="unused",
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
max_labels=max_labels,
|
||||
)
|
||||
self._runtime = (torch, FakeTokenizer(), FakeModel(logits, id2label))
|
||||
|
||||
def _get_runtime(self):
|
||||
return self._runtime
|
||||
|
||||
|
||||
class MultiIntentDetectorTests(unittest.TestCase):
|
||||
def test_detector_filters_blocked_and_unknown_labels(self) -> None:
|
||||
detector = RuntimeBackedDetector(
|
||||
logits=[2.4, 2.0, 3.2, 2.6],
|
||||
id2label={
|
||||
0: "cabin_window_open",
|
||||
1: "cabin_play_music",
|
||||
2: "__social__",
|
||||
3: "unknown_intent",
|
||||
},
|
||||
threshold=0.8,
|
||||
top_k=4,
|
||||
)
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
]
|
||||
|
||||
result = detector.detect("打开车窗并播放音乐", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(result.backend_name, "bert-multi-label")
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_detector_respects_threshold_and_max_labels(self) -> None:
|
||||
detector = RuntimeBackedDetector(
|
||||
logits=[2.8, 2.5, 2.2],
|
||||
id2label={
|
||||
0: "cabin_window_open",
|
||||
1: "cabin_play_music",
|
||||
2: "cabin_nav_to",
|
||||
},
|
||||
threshold=0.89,
|
||||
top_k=3,
|
||||
max_labels=2,
|
||||
)
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_nav_to", plugin_id="plugin.nav", domain="cabin"),
|
||||
]
|
||||
|
||||
result = detector.detect("开窗放歌去公司", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(len(result.candidates), 2)
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
]
|
||||
|
||||
class FakeJointNlu:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, object]] = []
|
||||
|
||||
def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
|
||||
self.calls.append(
|
||||
{
|
||||
"text": text,
|
||||
"threshold": threshold,
|
||||
"max_labels": max_labels,
|
||||
"top_k": top_k,
|
||||
"known_count": len(known_intents),
|
||||
}
|
||||
)
|
||||
return [
|
||||
type("Candidate", (), {"intent_id": "cabin_window_open", "score": 0.93})(),
|
||||
type("Candidate", (), {"intent_id": "cabin_play_music", "score": 0.88})(),
|
||||
]
|
||||
|
||||
def warmup(self, sample_text="") -> bool:
|
||||
_ = sample_text
|
||||
return True
|
||||
|
||||
fake_nlu = FakeJointNlu()
|
||||
detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)
|
||||
|
||||
result = detector.detect("打开车窗并播放音乐", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(result.backend_name, "joint-bert-multi-label")
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
self.assertEqual(fake_nlu.calls[0]["threshold"], 0.5)
|
||||
self.assertEqual(fake_nlu.calls[0]["top_k"], 6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
195
intelligent_cabin/archive/tests/test_router_decisions.py
Normal file
195
intelligent_cabin/archive/tests/test_router_decisions.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.router import IntentMatchResult, MultiStageIntentMatcher
|
||||
|
||||
|
||||
class _FakeMatcher:
|
||||
def __init__(self, stage_debug: MatcherStageDebug) -> None:
|
||||
self._stage_debug = stage_debug
|
||||
|
||||
def match(self, text: str) -> IntentMatchResult:
|
||||
_ = text
|
||||
return IntentMatchResult(intent=None, stage_debug=self._stage_debug)
|
||||
|
||||
|
||||
def _intent(intent_id: str) -> IntentDefinition:
|
||||
return IntentDefinition(
|
||||
intent_id=intent_id,
|
||||
plugin_id=f"mock.{intent_id}",
|
||||
domain="test",
|
||||
keywords=[],
|
||||
examples=[],
|
||||
)
|
||||
|
||||
|
||||
class RouterDecisionTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.registry = IntentRegistry([_intent("alpha"), _intent("beta"), _intent("gamma")])
|
||||
|
||||
def test_execute_when_bert_classifier_is_clear(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.92,
|
||||
reason="classifier selected best candidate",
|
||||
backend="joint-bert-local",
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="joint-bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.21, reason="classifier", model_name="joint-bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("alpha")
|
||||
|
||||
self.assertEqual(result.debug.decision, "execute")
|
||||
self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")
|
||||
|
||||
def test_clarify_when_bert_top_candidates_are_too_close(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.22,
|
||||
reason="classifier selected best candidate",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.2},
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.28, reason="classifier", model_name="bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
route_to_cloud_threshold=0.2,
|
||||
)
|
||||
|
||||
result = matcher.match("ambiguous request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "clarify")
|
||||
self.assertIsNone(result.intent)
|
||||
self.assertEqual(result.debug.confidence_grade, "medium")
|
||||
|
||||
def test_route_to_cloud_when_bert_signal_is_weak_but_known(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=False,
|
||||
selected_intent="alpha",
|
||||
score=0.29,
|
||||
reason="classifier below execute threshold",
|
||||
backend="joint-bert-local",
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.29, reason="classifier", model_name="joint-bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.14, reason="classifier", model_name="joint-bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("weak symbolic request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "route_to_cloud")
|
||||
self.assertIsNone(result.intent)
|
||||
|
||||
def test_reject_when_no_branch_has_usable_signal(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=False,
|
||||
score=0.12,
|
||||
reason="classifier below threshold",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.2},
|
||||
candidates=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("unknown request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "reject")
|
||||
self.assertTrue(result.debug.unknown_detected)
|
||||
self.assertIsNone(result.intent)
|
||||
|
||||
def test_route_to_cloud_for_low_confidence_classifier_only_bert_signal(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.31,
|
||||
reason="classifier selected best candidate",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.0, "top_margin": 0.04},
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.27, reason="classifier", model_name="bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("bert only weak request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "route_to_cloud")
|
||||
self.assertIsNone(result.intent)
|
||||
|
||||
def test_execute_for_high_confidence_classifier_only_bert_signal(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.92,
|
||||
reason="classifier selected best candidate",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.0, "top_margin": 0.63},
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.29, reason="classifier", model_name="bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("bert only strong request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "execute")
|
||||
self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
180
intelligent_cabin/archive/tests/test_social_chat.py
Normal file
180
intelligent_cabin/archive/tests/test_social_chat.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.schemas.chat import ChatRequest, FillSlotsRequest
|
||||
from app.schemas.debug import RoutingDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.schemas.workflow import Workflow, WorkflowStep
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
from app.services.social import SocialReplyResult, SocialRouter
|
||||
|
||||
|
||||
class _FailingRouter:
|
||||
def route(self, text: str): # pragma: no cover - should not be called in these tests
|
||||
raise AssertionError(f"router should not be called for social input: {text}")
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
|
||||
_ = (text, intent)
|
||||
return {}
|
||||
|
||||
|
||||
class _FakeSocialResponder:
|
||||
def reply(self, text: str, session) -> SocialReplyResult:
|
||||
_ = (text, session)
|
||||
normalized = text.strip()
|
||||
if "你好" in normalized:
|
||||
text = "你好呀,我在,想聊什么都可以。"
|
||||
elif "名字" in normalized or "你是谁" in normalized:
|
||||
text = "我是一名智能座舱助手,你可以直接叫我座舱助手。"
|
||||
elif "天气" in normalized:
|
||||
text = "是啊,今天确实挺舒服的。"
|
||||
else:
|
||||
text = "我在,咱们继续聊。"
|
||||
return SocialReplyResult(
|
||||
text=text,
|
||||
backend="fake-cloud",
|
||||
model_name="fake-social",
|
||||
)
|
||||
|
||||
|
||||
def _intent(intent_id: str, plugin_id: str) -> IntentDefinition:
|
||||
return IntentDefinition(
|
||||
intent_id=intent_id,
|
||||
plugin_id=plugin_id,
|
||||
domain="service" if intent_id.startswith("cs_") else "cabin",
|
||||
risk_level="high" if intent_id == "cs_cancel_order" else "low",
|
||||
required_slots=["order_id"] if intent_id == "cs_cancel_order" else [],
|
||||
ask_templates={"order_id": "请告诉我订单号。"} if intent_id == "cs_cancel_order" else {},
|
||||
keywords=[],
|
||||
examples=[],
|
||||
)
|
||||
|
||||
|
||||
class SocialChatTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.session_store = InMemorySessionStore()
|
||||
self.plugins = PluginRegistry()
|
||||
self.plugins.register(
|
||||
"mock.cancel_order",
|
||||
lambda slots: {"success": True, "message": f"已取消订单 {slots.get('order_id', '')}。"},
|
||||
)
|
||||
self.service = AgentService(
|
||||
intent_registry=IntentRegistry([_intent("cs_cancel_order", "mock.cancel_order")]),
|
||||
router=_FailingRouter(),
|
||||
plugins=self.plugins,
|
||||
session_store=self.session_store,
|
||||
social_router=SocialRouter(),
|
||||
social_responder=_FakeSocialResponder(),
|
||||
)
|
||||
|
||||
def test_greeting_social_reply_uses_social_responder(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_social_hi",
|
||||
user_id="user_1",
|
||||
input_text="你好",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "open_social")
|
||||
self.assertEqual(response.status, "social")
|
||||
self.assertIn("你好呀", response.reply_text)
|
||||
|
||||
def test_capability_social_question_does_not_fall_into_business_intent(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_social_name",
|
||||
user_id="user_1",
|
||||
input_text="你叫什么名字",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "open_social")
|
||||
self.assertEqual(response.status, "social")
|
||||
self.assertNotEqual(response.reply_type, "ask_slot")
|
||||
self.assertIn("智能座舱助手", response.reply_text)
|
||||
|
||||
def test_open_social_reply_uses_social_responder(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_social_open",
|
||||
user_id="user_1",
|
||||
input_text="今天天气真不错啊",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "open_social")
|
||||
self.assertEqual(response.status, "social")
|
||||
self.assertIn("挺舒服", response.reply_text)
|
||||
|
||||
def test_social_turn_does_not_break_waiting_confirmation(self) -> None:
|
||||
session = self.session_store.get_or_create("sess_confirm", "user_1")
|
||||
session.current_intent = "cs_cancel_order"
|
||||
session.status = "waiting_confirmation"
|
||||
session.pending_slots = ["confirmation"]
|
||||
session.slots = {"order_id": "A123456"}
|
||||
session.routing_debug = RoutingDebug(selected_intent="cs_cancel_order", decision="execute").model_dump()
|
||||
session.workflow = Workflow(
|
||||
workflow_id="wf_sess_confirm",
|
||||
workflow_type="conditional",
|
||||
domain="service",
|
||||
intent_id="cs_cancel_order",
|
||||
status="waiting_confirmation",
|
||||
risk_level="high",
|
||||
slots={"order_id": "A123456"},
|
||||
steps=[
|
||||
WorkflowStep(
|
||||
step=1,
|
||||
step_id="step_confirm",
|
||||
intent_id="cs_cancel_order",
|
||||
plugin_id="mock.cancel_order",
|
||||
action="cancel_order",
|
||||
status="waiting_confirmation",
|
||||
slots={"order_id": "A123456"},
|
||||
requires_confirmation=True,
|
||||
)
|
||||
],
|
||||
meta={
|
||||
"pending_confirmation": {
|
||||
"step_id": "step_confirm",
|
||||
"intent_id": "cs_cancel_order",
|
||||
"detail": "确认取消订单 A123456",
|
||||
},
|
||||
"confirmed_steps": [],
|
||||
"step_results": {},
|
||||
},
|
||||
).model_dump()
|
||||
self.session_store.save(session)
|
||||
|
||||
social_response = self.service.handle_fill_slots(
|
||||
FillSlotsRequest(
|
||||
session_id="sess_confirm",
|
||||
user_id="user_1",
|
||||
input_text="今天天气真不错啊",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(social_response.decision, "open_social")
|
||||
self.assertEqual(social_response.status, "waiting_confirmation")
|
||||
self.assertEqual(social_response.pending_slots, ["confirmation"])
|
||||
self.assertIn("回复“确认”或“取消”即可", social_response.reply_text)
|
||||
|
||||
confirm_response = self.service.handle_fill_slots(
|
||||
FillSlotsRequest(
|
||||
session_id="sess_confirm",
|
||||
user_id="user_1",
|
||||
input_text="确认",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(confirm_response.reply_type, "workflow_result")
|
||||
self.assertEqual(confirm_response.status, "completed")
|
||||
self.assertIn("已取消订单", confirm_response.reply_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
184
intelligent_cabin/archive/tests/test_workflow_templates.py
Normal file
184
intelligent_cabin/archive/tests/test_workflow_templates.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from app.schemas.configuration import WorkflowTemplatesConfig
|
||||
from app.services.classifier import ClassificationResult
|
||||
from app.services.multi_intent_detector import MultiIntentCandidate, MultiIntentDetectionResult
|
||||
from app.services.planner import CompositeWorkflowPlanner, HeuristicWorkflowPlanner, TemplateWorkflowPlanner
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
|
||||
|
||||
class FakeClauseClassifier:
|
||||
def __init__(self, predictions: dict[str, list[dict[str, float | str]]]) -> None:
|
||||
self._predictions = predictions
|
||||
|
||||
def predict(self, text, intents):
|
||||
_ = intents
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=0.0,
|
||||
model_name="fake-bert-clause",
|
||||
backend_name="fake-bert-clause",
|
||||
raw_candidates=self._predictions.get(text, []),
|
||||
)
|
||||
|
||||
|
||||
class FakeMultiIntentDetector:
|
||||
def __init__(self, predictions: dict[str, list[tuple[str, float]]]) -> None:
|
||||
self._predictions = predictions
|
||||
|
||||
def detect(self, text, intents):
|
||||
_ = intents
|
||||
candidates = [
|
||||
MultiIntentCandidate(intent_id=intent_id, score=score, label=intent_id)
|
||||
for intent_id, score in self._predictions.get(text, [])
|
||||
]
|
||||
return MultiIntentDetectionResult(
|
||||
detected=len(candidates) >= 2,
|
||||
candidates=candidates,
|
||||
reason="fake detector",
|
||||
backend_name="fake-multi",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowTemplateTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
self.templates = WorkflowTemplatesConfig.model_validate_json(
|
||||
Path("config/workflows.yml").read_text(encoding="utf-8")
|
||||
)
|
||||
|
||||
def test_template_planner_matches_sequence_template(self) -> None:
|
||||
planner = TemplateWorkflowPlanner(self.templates)
|
||||
|
||||
result = planner.plan("打开车窗,然后把空调调到20度", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.backend, "local-template")
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_set_ac"])
|
||||
self.assertEqual(result.steps[1].slots.get("temperature"), 20)
|
||||
|
||||
def test_template_planner_matches_conditional_template(self) -> None:
|
||||
planner = TemplateWorkflowPlanner(self.templates)
|
||||
|
||||
result = planner.plan("查一下订单A123456,如果还没发货就取消", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "conditional")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cs_query_order", "cs_cancel_order"])
|
||||
self.assertEqual(result.steps[1].depends_on, [1])
|
||||
self.assertTrue(result.steps[1].requires_confirmation)
|
||||
|
||||
def test_composite_planner_falls_back_to_heuristic_when_template_misses(self) -> None:
|
||||
planner = CompositeWorkflowPlanner([TemplateWorkflowPlanner(self.templates), HeuristicWorkflowPlanner()])
|
||||
|
||||
result = planner.plan("打开车窗,并且播放轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertIn(result.backend, {"local-template", "local-heuristic"})
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
|
||||
def test_heuristic_planner_parses_ac_then_window_close_sequence(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开空调,再把窗户降下来", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.backend, "local-heuristic")
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_close"])
|
||||
|
||||
def test_planner_metadata_contains_clause_analysis(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开空调,然后打开车窗", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertTrue(result.metadata.get("multi_intent_detected"))
|
||||
clause_analysis = result.metadata.get("clause_analysis", [])
|
||||
self.assertEqual(len(clause_analysis), 2)
|
||||
self.assertEqual(clause_analysis[0].get("selected_intent_id"), "cabin_ac_on")
|
||||
self.assertEqual(clause_analysis[1].get("selected_intent_id"), "cabin_window_open")
|
||||
|
||||
def test_heuristic_planner_supports_shared_action_parallel_objects(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开空调和车窗", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_open"])
|
||||
|
||||
def test_heuristic_planner_supports_parallel_objects_with_suffix_action(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("把车窗和天窗打开", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_sunroof_open"])
|
||||
|
||||
def test_heuristic_planner_supports_parallel_clause_with_bing_connector(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开车窗并播放轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_heuristic_planner_can_use_clause_classifier_to_rescue_semantic_clause(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(
|
||||
clause_classifier=FakeClauseClassifier(
|
||||
{
|
||||
"车里太闷了": [
|
||||
{"label": "cabin_window_open", "intent_id": "cabin_window_open", "score": 0.83},
|
||||
],
|
||||
"来点轻音乐": [
|
||||
{"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = planner.plan("车里太闷了,然后来点轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
|
||||
clause_analysis = result.metadata.get("clause_analysis", [])
|
||||
self.assertGreater(clause_analysis[0].get("candidates", [])[0].get("model_score", 0.0), 0.8)
|
||||
|
||||
def test_heuristic_planner_can_use_multi_intent_detector_prior(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(
|
||||
clause_classifier=FakeClauseClassifier(
|
||||
{
|
||||
"来点轻音乐": [
|
||||
{"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
|
||||
],
|
||||
}
|
||||
),
|
||||
multi_intent_detector=FakeMultiIntentDetector(
|
||||
{
|
||||
"顺便开下车窗,再来点轻音乐": [
|
||||
("cabin_window_open", 0.87),
|
||||
("cabin_play_music", 0.82),
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
result = planner.plan("顺便开下车窗,再来点轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
|
||||
detector_meta = result.metadata.get("multi_intent_detector") or {}
|
||||
self.assertTrue(detector_meta.get("detected"))
|
||||
self.assertEqual(len(detector_meta.get("candidates", [])), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user