Files
ai-device/intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
2026-06-11 16:28:00 +08:00

501 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import random
import re
import sys
from dataclasses import dataclass
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import yaml
# Project root: the grandparent directory of this script's resolved location.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Prepend the project root to sys.path (only if absent) so project-local imports
# such as the ``app`` package can be resolved when the script is run directly.
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
from app.services.joint_nlu import JointBertForNLU
# --- Data sources -------------------------------------------------------------
# Single-intent training rows: JSON objects with "text" and "intent_id".
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
# Multi-intent training rows: "text", "intent_ids" and optionally "slots".
MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
# Single-intent rows whose "slots" are used as given instead of being auto-annotated.
SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl"
# Single-intent evaluation rows ("text", "intent_id", optional "slots").
EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl"
# Optional multi-intent evaluation rows; only read when the file exists.
MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl"
# YAML domain config; each intent's "examples" and "keywords" become extra training texts.
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
# Receives the tokenizer, model_state.pt, joint_nlu_config.json and train_summary.json.
OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"

# --- Model / optimisation settings -------------------------------------------
# Pretrained checkpoint name used for the tokenizer and as the model's base_model_name.
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
# Every utterance is truncated / padded to exactly this many tokens.
MAX_LENGTH = 64
BATCH_SIZE = 8
EPOCHS = 8
# AdamW learning rate.
LEARNING_RATE = 2e-5
# Seed for Python's `random` module and torch's global RNG.
SEED = 42
# Slot label for tokens excluded from the slot loss and slot metrics
# (-100 is also the default ignore_index of torch's cross_entropy).
IGNORE_INDEX = -100
# Music genres that the music-span heuristic reports as "genre" slot values.
GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌")
# NOTE(review): not referenced anywhere else in this script as shown; the saved
# threshold is picked from THRESHOLD_CANDIDATES instead. Confirm it is still needed.
DEFAULT_INTENT_THRESHOLD = 0.3
# Each training sample with two or more intents appears this many times in total.
MULTI_INTENT_REPEAT = 6
# Sigmoid decision thresholds tried on the eval set after every epoch; the winner
# is written to the output config.
THRESHOLD_CANDIDATES = [0.1, 0.12, 0.15, 0.18, 0.2, 0.22, 0.25, 0.28, 0.3, 0.33, 0.35, 0.38, 0.4, 0.45]
@dataclass
class JointSample:
    """One utterance with its gold intent set and character-level slot spans."""

    # Raw utterance; slot "start"/"end" values index characters of this string.
    text: str
    # Gold intent ids: one element for single-intent rows, possibly several for multi-intent rows.
    intent_ids: list[str]
    # Slot dicts with "slot_name", "start" and exclusive "end" (auto-annotated slots also carry "value").
    slots: list[dict[str, object]]
class JointDataset(Dataset):
    """Map-style dataset turning JointSample rows into model-ready tensors.

    Each item is a dict of 1-D tensors:
    - ``input_ids`` / ``attention_mask``: tokenizer output padded to MAX_LENGTH;
    - ``intent_labels``: float32 multi-hot vector over ``intent_to_index`` (unknown ids ignored);
    - ``slot_labels``: one BIO label id per token, IGNORE_INDEX for tokens whose
      offset span is empty, falling back to the "O" id for unknown BIO tags.
    """

    def __init__(
        self,
        samples: list[JointSample],
        tokenizer,
        intent_to_index: dict[str, int],
        slot_to_index: dict[str, int],
    ) -> None:
        self._samples = samples
        self._tokenizer = tokenizer
        self._intent_to_index = intent_to_index
        self._slot_to_index = slot_to_index

    def __len__(self) -> int:
        return len(self._samples)

    @staticmethod
    def _char_bio_tags(text: str, slots: list[dict[str, object]]) -> list[str]:
        """Return one BIO tag per character of *text*; out-of-range or empty spans are skipped."""
        tags = ["O"] * len(text)
        for slot in slots:
            first, stop = int(slot["start"]), int(slot["end"])
            name = str(slot["slot_name"])
            if not (0 <= first < stop <= len(text)):
                continue
            tags[first] = f"B-{name}"
            tags[first + 1 : stop] = [f"I-{name}"] * (stop - first - 1)
        return tags

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        sample = self._samples[index]
        encoded = self._tokenizer(
            sample.text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # Offsets are (char_start, char_end) per token; special/padding tokens have an empty span.
        token_offsets = encoded.pop("offset_mapping")[0].tolist()
        tags = self._char_bio_tags(sample.text, sample.slots)
        # Each real token takes the tag of its first character.
        token_labels = [
            IGNORE_INDEX
            if tok_end <= tok_start
            else self._slot_to_index.get(tags[tok_start], self._slot_to_index["O"])
            for tok_start, tok_end in token_offsets
        ]

        intent_vector = torch.zeros(len(self._intent_to_index), dtype=torch.float32)
        known_positions = (
            self._intent_to_index[intent_id]
            for intent_id in sample.intent_ids
            if intent_id in self._intent_to_index
        )
        for position in known_positions:
            intent_vector[position] = 1.0

        return {
            "input_ids": encoded["input_ids"][0],
            "attention_mask": encoded["attention_mask"][0],
            "intent_labels": intent_vector,
            "slot_labels": torch.tensor(token_labels, dtype=torch.long),
        }
def set_seed() -> None:
    """Seed Python's ``random`` module, then PyTorch's global RNG, with SEED."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(SEED)
def load_jsonl(path: Path) -> list[dict[str, object]]:
    """Parse a UTF-8 JSON Lines file into a list of objects.

    Lines that are empty after stripping whitespace are skipped; any other line
    must be valid JSON, otherwise ``json.JSONDecodeError`` propagates.
    """
    with path.open("r", encoding="utf-8") as stream:
        stripped_lines = (raw.strip() for raw in stream)
        return [json.loads(content) for content in stripped_lines if content]
def find_order_id_span(text: str) -> tuple[str, int, int] | None:
    """Locate the first order id (an ASCII letter followed by 5+ digits) in *text*.

    Returns ``(value, start, end)`` with ``end`` exclusive, or ``None`` if absent.
    """
    found = re.search(r"[A-Za-z]\d{5,}", text)
    return None if found is None else (found.group(), *found.span())
def find_temperature_span(text: str) -> tuple[str, int, int] | None:
    """Locate the first two-digit temperature followed by "度" (optional whitespace between).

    Returns ``(value, start, end)`` with ``end`` exclusive, or ``None`` if absent.
    """
    hit = re.search(r"(\d{2}\s*度)", text)
    if hit is None:
        return None
    start, end = hit.span()
    return hit.group(1), start, end
def find_destination_span(text: str) -> tuple[str, int, int] | None:
    """Heuristically extract a navigation destination from *text*.

    The trigger patterns 导航去 / 导航到 / 带我去 / 送我去 / 去 are tried in order.
    For each match, the text after the trigger is cut at the first clause
    connective (然后, 并且, 同时, 再) or comma/semicolon (ASCII or full-width),
    then trimmed of spaces, ASCII commas and "。". The first non-empty value wins.

    Returns ``(value, start, end)`` with ``end`` exclusive and offsets pointing at
    the matched occurrence, or ``None`` if no trigger yields a value.
    """
    # The delimiter set previously contained empty alternatives ("||"). An empty
    # alternative matches at position 0, so re.split(..., maxsplit=1)[0] was always ""
    # and no destination was ever found. Full-width ，/； are spelled as escapes.
    clause_break = r"(?:然后|并且|同时|再|,|\uff0c|;|\uff1b)"
    trim_chars = " ,。"
    for pattern in (
        r"导航去(?P<destination>.+)",
        r"导航到(?P<destination>.+)",
        r"带我去(?P<destination>.+)",
        r"送我去(?P<destination>.+)",
        r"去(?P<destination>.+)",
    ):
        match = re.search(pattern, text)
        if match is None:
            continue
        head = re.split(clause_break, match.group("destination"), maxsplit=1)[0]
        destination = head.strip(trim_chars)
        if not destination:
            continue
        # Derive the offset from the match itself; text.find() could land on an
        # earlier, unrelated occurrence of the same string.
        start = match.start("destination") + len(head) - len(head.lstrip(trim_chars))
        return destination, start, start + len(destination)
    return None
def find_music_span(text: str) -> tuple[str, str, int, int] | None:
    """Heuristically extract a music slot from *text*.

    Returns ``(slot_name, value, start, end)`` with ``end`` exclusive:
    - ``("genre", keyword, ...)`` for the first GENRE_KEYWORDS entry found anywhere in *text*;
    - otherwise ``("song", title, ...)``, where the title is the text after the first
      matching trigger (播放, 来点, 放点, 来首, 来一首, 放一首). The title is cut at the
      first clause connective or comma/semicolon and trimmed of filler characters.

    Returns ``None`` when neither applies.
    """
    for genre in GENRE_KEYWORDS:
        start = text.find(genre)
        if start >= 0:
            return "genre", genre, start, start + len(genre)
    # Any genre present in the text was returned above, so the trigger loop only
    # needs to handle song titles (the old per-target genre re-check was unreachable).

    # The delimiter set previously contained empty alternatives ("||"), which made
    # re.split(...)[0] always "". Full-width ，/； are spelled as escapes.
    clause_break = r"(?:然后|并且|同时|再|,|\uff0c|;|\uff1b)"
    trim_chars = " 的一首首个歌曲音乐吧呀啊,。"
    # An empty-string trigger was removed: "" is always "in" text and
    # str.split("") raises ValueError, which crashed most non-genre utterances.
    for trigger in ("播放", "来点", "放点", "来首", "来一首", "放一首"):
        trigger_at = text.find(trigger)
        if trigger_at < 0:
            continue
        tail_start = trigger_at + len(trigger)
        head = re.split(clause_break, text[tail_start:], maxsplit=1)[0]
        target = head.strip(trim_chars)
        if not target or target == "音乐":
            continue
        # Offset derived from the trigger position, not text.find(target), which
        # could land on an earlier occurrence.
        start = tail_start + len(head) - len(head.lstrip(trim_chars))
        return "song", target, start, start + len(target)
    return None
def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]:
    """Auto-annotate slots in *text* for a single *intent_id*.

    Order intents get an ``order_id`` slot, ``cabin_set_ac`` a ``temperature``,
    ``cabin_nav_to`` a ``destination``, and ``cabin_play_music`` a ``genre`` or
    ``song``. Each slot is a dict with "slot_name", "value", "start" and exclusive
    "end". Other intents, or texts where the heuristic finds nothing, yield ``[]``.
    """

    def as_slot(slot_name: str, value: str, start: int, end: int) -> dict[str, object]:
        return {"slot_name": slot_name, "value": value, "start": start, "end": end}

    if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}:
        found = find_order_id_span(text)
        return [] if found is None else [as_slot("order_id", *found)]
    if intent_id == "cabin_set_ac":
        found = find_temperature_span(text)
        return [] if found is None else [as_slot("temperature", *found)]
    if intent_id == "cabin_nav_to":
        found = find_destination_span(text)
        return [] if found is None else [as_slot("destination", *found)]
    if intent_id == "cabin_play_music":
        found = find_music_span(text)
        # The music finder already returns (slot_name, value, start, end).
        return [] if found is None else [as_slot(*found)]
    return []
def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]:
    """Union of :func:`annotate_slots` over several intents.

    Slots are de-duplicated on ``(slot_name, start, end)``, keeping the first one
    produced, and returned ordered by ``(start, end)``; ties keep discovery order.
    """
    first_seen: dict[tuple[str, int, int], dict[str, object]] = {}
    for intent_id in intent_ids:
        for slot in annotate_slots(text, intent_id):
            identity = (str(slot["slot_name"]), int(slot["start"]), int(slot["end"]))
            first_seen.setdefault(identity, slot)
    return sorted(first_seen.values(), key=lambda slot: (int(slot["start"]), int(slot["end"])))
def build_train_samples() -> list[JointSample]:
    """Assemble and shuffle the joint NLU training set.

    Sources are read in this order; the first occurrence of a ``(text, intent_ids)``
    pair wins and later duplicates are dropped:

    1. ``intents[*].examples`` and ``intents[*].keywords`` from DOMAIN_PATH
       (texts stripped, slots auto-annotated with ``annotate_slots``);
    2. TRAIN_PATH rows ``{"text", "intent_id"}`` (slots auto-annotated);
    3. SEED_PATH rows ``{"text", "intent_id", "slots"}`` (slots used as given);
    4. MULTI_TRAIN_PATH rows ``{"text", "intent_ids", "slots"?}``. Intent ids are
       stripped, de-duplicated and sorted. Slots are auto-annotated when absent.
       Rows with two or more intents are added MULTI_INTENT_REPEAT times in total.

    NOTE(review): because TRAIN_PATH is read before SEED_PATH, a seed row whose
    (text, intent) already appeared earlier is dropped together with its
    hand-labelled slots. Confirm this precedence is intended.

    The final list is shuffled with the module-level ``random`` RNG.
    """
    samples: list[JointSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()

    def is_new(text: str, intent_ids: tuple[str, ...]) -> bool:
        """Record ``(text, intent_ids)``; return False if it was already recorded."""
        key = (text, intent_ids)
        if key in seen:
            return False
        seen.add(key)
        return True

    domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {}
    # ``or []`` / ``or ""``: a YAML key with no value parses as None, which
    # list() rejects and str() would turn into the literal intent "None".
    for intent in domain_data.get("intents") or []:
        intent_id = str(intent.get("intent_id") or "").strip()
        if not intent_id:
            continue
        phrases = list(intent.get("examples") or []) + list(intent.get("keywords") or [])
        for phrase in phrases:
            text = str(phrase).strip()
            if text and is_new(text, (intent_id,)):
                samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))

    for row in load_jsonl(TRAIN_PATH):
        text = str(row["text"])
        intent_id = str(row["intent_id"])
        if is_new(text, (intent_id,)):
            samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))

    for row in load_jsonl(SEED_PATH):
        # Seed texts are kept verbatim: their slot offsets index the raw string.
        text = str(row["text"])
        intent_id = str(row["intent_id"])
        if is_new(text, (intent_id,)):
            samples.append(JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots") or [])))

    for row in load_jsonl(MULTI_TRAIN_PATH):
        text = str(row["text"]).strip()
        intent_ids = sorted({str(item).strip() for item in row.get("intent_ids") or [] if str(item).strip()})
        if not text or not intent_ids or not is_new(text, tuple(intent_ids)):
            continue
        slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
        samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
        if len(intent_ids) >= 2:
            # Oversample multi-intent utterances; each copy gets its own slot list.
            for _ in range(MULTI_INTENT_REPEAT - 1):
                samples.append(JointSample(text=text, intent_ids=intent_ids, slots=list(slots)))

    random.shuffle(samples)
    return samples
def build_eval_samples() -> list[JointSample]:
    """Load the evaluation set.

    Single-intent rows from EVAL_PATH (``{"text", "intent_id", "slots"?}``) are
    always loaded. Multi-intent rows from MULTI_EVAL_PATH
    (``{"text", "intent_ids", "slots"?}``) are appended when that file exists.
    Texts are used verbatim so that any provided slot offsets stay valid.
    Multi-intent ids are stripped, de-duplicated and sorted. Missing slots are
    auto-annotated using those cleaned ids, the same ids stored on the sample.
    """
    samples = [
        JointSample(
            text=str(row["text"]),
            intent_ids=[str(row["intent_id"])],
            # ``or []`` tolerates an explicit ``"slots": null``.
            slots=list(row.get("slots") or []),
        )
        for row in load_jsonl(EVAL_PATH)
    ]
    if MULTI_EVAL_PATH.exists():
        for row in load_jsonl(MULTI_EVAL_PATH):
            text = str(row["text"])
            intent_ids = sorted({str(item).strip() for item in row.get("intent_ids") or [] if str(item).strip()})
            # Annotate with the cleaned ids (previously the raw, unstripped ids were
            # used, so ids with stray whitespace silently produced no slots).
            slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
            samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
    return samples
def build_slot_labels(samples: list[JointSample]) -> list[str]:
    """Return the BIO label inventory for every slot name used in *samples*.

    "O" comes first, then a ``B-``/``I-`` pair per distinct slot name in sorted
    order, e.g. ``["O", "B-genre", "I-genre", "B-song", "I-song"]``.
    """
    distinct_names = {str(slot["slot_name"]) for sample in samples for slot in sample.slots}
    return ["O"] + [f"{tag}-{name}" for name in sorted(distinct_names) for tag in ("B", "I")]
def _collect_eval_outputs(
    model: JointBertForNLU,
    dataloader: DataLoader,
    device: torch.device,
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run *model* once over *dataloader*; keep per-batch CPU tensors needed for scoring.

    Each tuple is ``(intent_probs, gold_intents, top_intent, predicted_slots, gold_slots)``:
    - sigmoid intent probabilities;
    - the boolean gold multi-hot matrix (labels > 0.5);
    - the arg-max intent per row;
    - the arg-max slot label per token;
    - the gold slot labels.

    None of these depend on the decision threshold, so one pass serves every threshold.
    Leaves the model in eval mode.
    """
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch in dataloader:
            batch = {key: value.to(device) for key, value in batch.items()}
            intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
            intent_probs = torch.sigmoid(intent_logits)
            outputs.append(
                (
                    intent_probs.cpu(),
                    (batch["intent_labels"] > 0.5).cpu(),
                    torch.argmax(intent_probs, dim=-1).cpu(),
                    torch.argmax(slot_logits, dim=-1).cpu(),
                    batch["slot_labels"].cpu(),
                )
            )
    return outputs


def _score_eval_outputs(
    outputs: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    threshold: float,
    intent_labels: list[str],
    slot_labels: list[str],
) -> dict[str, float]:
    """Compute intent/slot metrics from cached model outputs at one intent *threshold*."""
    intent_tp = intent_fp = intent_fn = 0
    intent_exact_match = 0
    single_intent_correct = single_intent_total = 0
    correct_slot_tokens = total_slot_tokens = 0
    exact_slot_samples = total_samples = 0
    for intent_probs, gold_multi, top_predicted, predicted_slots, gold_slots in outputs:
        predicted_multi = intent_probs >= threshold
        intent_tp += int((predicted_multi & gold_multi).sum().item())
        intent_fp += int((predicted_multi & ~gold_multi).sum().item())
        intent_fn += int((~predicted_multi & gold_multi).sum().item())
        # A row is an exact match only if every intent's on/off decision matches gold.
        intent_exact_match += int((predicted_multi == gold_multi).all(dim=1).sum().item())
        # Top-1 accuracy is only defined for rows whose gold set has exactly one intent.
        single_mask = gold_multi.sum(dim=-1) == 1
        single_count = int(single_mask.sum().item())
        if single_count > 0:
            gold_top = torch.argmax(gold_multi.float(), dim=-1)
            single_intent_correct += int((top_predicted[single_mask] == gold_top[single_mask]).sum().item())
            single_intent_total += single_count
        # Positions labelled IGNORE_INDEX carry no slot label and are excluded.
        mask = gold_slots != IGNORE_INDEX
        correct_slot_tokens += int(((predicted_slots == gold_slots) & mask).sum().item())
        total_slot_tokens += int(mask.sum().item())
        for row_gold, row_pred, row_mask in zip(gold_slots, predicted_slots, mask):
            exact_slot_samples += int(torch.equal(row_gold[row_mask], row_pred[row_mask]))
            total_samples += 1
    precision = intent_tp / max(intent_tp + intent_fp, 1)
    recall = intent_tp / max(intent_tp + intent_fn, 1)
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "intent_threshold": round(threshold, 4),
        "intent_micro_precision": round(precision, 4),
        "intent_micro_recall": round(recall, 4),
        "intent_micro_f1": round(f1, 4),
        "intent_exact_match": round(intent_exact_match / max(total_samples, 1), 4),
        "single_intent_top1_accuracy": round(single_intent_correct / max(single_intent_total, 1), 4),
        "slot_token_accuracy": round(correct_slot_tokens / max(total_slot_tokens, 1), 4),
        "slot_exact_match": round(exact_slot_samples / max(total_samples, 1), 4),
        "intent_label_count": float(len(intent_labels)),
        "slot_label_count": float(len(slot_labels)),
    }


def compute_metrics(
    model: JointBertForNLU,
    dataloader: DataLoader,
    device: torch.device,
    intent_labels: list[str],
    slot_labels: list[str],
    threshold: float,
) -> dict[str, float]:
    """Evaluate *model* on *dataloader* with sigmoid intent decisions at *threshold*.

    Returns rounded micro precision/recall/F1 over intents and the intent exact-match
    rate. Also returns top-1 accuracy over single-intent rows, slot token accuracy,
    slot exact match per sample, and the label counts. Leaves the model in eval mode.
    """
    outputs = _collect_eval_outputs(model, dataloader, device)
    return _score_eval_outputs(outputs, threshold, intent_labels, slot_labels)


def search_best_threshold(
    model: JointBertForNLU,
    dataloader: DataLoader,
    device: torch.device,
    intent_labels: list[str],
    slot_labels: list[str],
) -> dict[str, float]:
    """Return the metrics of the best threshold in THRESHOLD_CANDIDATES.

    Candidates are ranked by (intent micro F1, intent exact match, slot exact match);
    ties keep the earlier candidate. Inference runs once; only scoring is repeated
    per threshold (previously the whole eval set was re-run for every candidate).

    Raises:
        ValueError: if THRESHOLD_CANDIDATES is empty.
    """
    outputs = _collect_eval_outputs(model, dataloader, device)
    best_metrics: dict[str, float] | None = None
    best_score: tuple[float, float, float] | None = None
    for threshold in THRESHOLD_CANDIDATES:
        metrics = _score_eval_outputs(outputs, threshold, intent_labels, slot_labels)
        score = (metrics["intent_micro_f1"], metrics["intent_exact_match"], metrics["slot_exact_match"])
        if best_score is None or score > best_score:
            best_metrics, best_score = metrics, score
    if best_metrics is None:
        raise ValueError("THRESHOLD_CANDIDATES must contain at least one threshold")
    return best_metrics
def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor:
    """Per-intent ``pos_weight`` for BCE-with-logits: negatives / positives, clamped to [1, 12].

    Positive counts are floored at 1. Negatives are the remaining samples, also
    floored at 1. The result is a float32 tensor aligned with *intent_labels*.
    """
    positive_totals = {label: 0 for label in intent_labels}
    for sample in samples:
        present = set(sample.intent_ids)
        for label in filter(present.__contains__, intent_labels):
            positive_totals[label] += 1
    sample_count = max(len(samples), 1)

    def clamped_ratio(label: str) -> float:
        positives = max(positive_totals[label], 1)
        negatives = max(sample_count - positives, 1)
        return min(max(negatives / positives, 1.0), 12.0)

    return torch.tensor([clamped_ratio(label) for label in intent_labels], dtype=torch.float32)
def _select_device() -> torch.device:
    """Return CUDA if available, else Apple MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _snapshot_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Return an independent CPU copy of *model*'s state dict.

    ``Tensor.cpu()`` returns the tensor itself when it already lives on the CPU.
    Without ``clone()`` the snapshot would alias the live parameters, and later
    optimizer steps would overwrite the "best" checkpoint in place.
    """
    return {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}


def _train_one_epoch(
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    pos_weight: torch.Tensor,
    device: torch.device,
) -> float:
    """Run one optimisation pass over *train_loader*; return the mean per-batch loss.

    The loss is the sum of two terms:
    - BCE-with-logits over the multi-hot intent vector, weighted by *pos_weight*;
    - token-level cross-entropy over slot labels, skipping IGNORE_INDEX positions.
    """
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        batch = {key: value.to(device) for key, value in batch.items()}
        optimizer.zero_grad()
        intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
        intent_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            intent_logits,
            batch["intent_labels"],
            pos_weight=pos_weight,
        )
        slot_loss = torch.nn.functional.cross_entropy(
            slot_logits.view(-1, slot_logits.size(-1)),
            batch["slot_labels"].view(-1),
            ignore_index=IGNORE_INDEX,
        )
        loss = intent_loss + slot_loss
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
    return total_loss / max(len(train_loader), 1)


def _save_artifacts(
    tokenizer,
    best_state: dict[str, torch.Tensor],
    best_metrics: dict[str, float],
    intent_labels: list[str],
    slot_labels: list[str],
    train_size: int,
    eval_size: int,
) -> None:
    """Write the tokenizer, best weights, inference config and training summary to OUTPUT_DIR."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(OUTPUT_DIR)
    torch.save(best_state, OUTPUT_DIR / "model_state.pt")
    config = {
        "base_model_name": DEFAULT_BASE_MODEL,
        "intent_task": "multi_label",
        "intent_labels": intent_labels,
        "slot_labels": slot_labels,
        "max_length": MAX_LENGTH,
        "intent_threshold": float(best_metrics["intent_threshold"]),
        "multi_intent_threshold": float(best_metrics["intent_threshold"]),
        "max_multi_intents": 4,
    }
    (OUTPUT_DIR / "joint_nlu_config.json").write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8")
    (OUTPUT_DIR / "train_summary.json").write_text(
        json.dumps(
            {
                "train_size": train_size,
                "eval_size": eval_size,
                "metrics": best_metrics,
            },
            ensure_ascii=False,
            indent=2,
        ),
        encoding="utf-8",
    )


def main() -> None:
    """Train the joint intent/slot model and save the best epoch to OUTPUT_DIR.

    After every epoch, the eval set is scored at each threshold in
    THRESHOLD_CANDIDATES and one JSON line of metrics is printed. The epoch with
    the highest intent micro-F1 is kept; the first epoch wins ties.

    Raises:
        RuntimeError: if no epoch produced a checkpoint (e.g. EPOCHS <= 0).
    """
    set_seed()
    train_samples = build_train_samples()
    eval_samples = build_eval_samples()
    all_samples = train_samples + eval_samples
    # Label inventories cover eval-only labels too, so eval tensors always index cleanly.
    intent_labels = sorted({intent_id for sample in all_samples for intent_id in sample.intent_ids})
    slot_labels = build_slot_labels(all_samples)
    intent_to_index = {label: index for index, label in enumerate(intent_labels)}
    slot_to_index = {label: index for index, label in enumerate(slot_labels)}

    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL)
    train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index)
    eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

    device = _select_device()
    model = JointBertForNLU(
        base_model_name=DEFAULT_BASE_MODEL,
        num_intents=len(intent_labels),
        num_slot_labels=len(slot_labels),
    )
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    pos_weight = build_pos_weight(train_samples, intent_labels).to(device)

    best_metrics: dict[str, float] | None = None
    best_state: dict[str, torch.Tensor] | None = None
    for epoch in range(EPOCHS):
        mean_loss = _train_one_epoch(model, train_loader, optimizer, pos_weight, device)
        metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels)
        metrics["train_loss"] = round(mean_loss, 4)
        print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False))
        if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]:
            best_metrics = metrics
            best_state = _snapshot_state(model)
    if best_state is None or best_metrics is None:
        raise RuntimeError("joint nlu training did not produce a best checkpoint")

    _save_artifacts(
        tokenizer,
        best_state,
        best_metrics,
        intent_labels,
        slot_labels,
        train_size=len(train_samples),
        eval_size=len(eval_samples),
    )
# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()