# File-viewer metadata: 501 lines, 19 KiB, Python.
from __future__ import annotations
|
||
|
||
import json
|
||
import random
|
||
import re
|
||
import sys
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader, Dataset
|
||
from transformers import AutoTokenizer
|
||
import yaml
|
||
|
||
# Repository root: the parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the import path (only once) so the
# project-local ``app`` package can be imported when this file is run directly.
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))
|
||
|
||
from app.services.joint_nlu import JointBertForNLU
|
||
|
||
|
||
# --- Training inputs (resolved against the repository root) ------------------
# Single-intent rows: read with the keys "text" and "intent_id".
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
# Multi-intent rows: read with the keys "text", "intent_ids" and optional "slots".
MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
# Single-intent rows that carry their own "slots" annotations.
SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl"

# --- Evaluation inputs ---------------------------------------------------------
EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl"
# Optional: only loaded when the file exists.
MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl"

# Domain definition; its "intents" entries contribute examples and keywords.
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
# Destination for the tokenizer, weights, config and training summary.
OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"

# --- Model and optimisation settings -----------------------------------------
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
MAX_LENGTH = 64  # tokenizer max_length; inputs are padded/truncated to it
BATCH_SIZE = 8
EPOCHS = 8
LEARNING_RATE = 2e-5
SEED = 42
# Slot label given to tokens with an empty character span; also passed as
# ``ignore_index`` to the slot cross-entropy loss.
IGNORE_INDEX = -100

# --- Heuristic annotation and intent thresholding ------------------------------
# Genre words recognised by the music slot annotator.
GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌")
# NOTE(review): not referenced by the code visible in this file.
DEFAULT_INTENT_THRESHOLD = 0.3
# Total number of copies kept of each training sample with two or more intents.
MULTI_INTENT_REPEAT = 6
# Sigmoid cut-offs tried when searching for the best intent threshold.
THRESHOLD_CANDIDATES = [
    0.1,
    0.12,
    0.15,
    0.18,
    0.2,
    0.22,
    0.25,
    0.28,
    0.3,
    0.33,
    0.35,
    0.38,
    0.4,
    0.45,
]
|
||
|
||
|
||
@dataclass
class JointSample:
    """One utterance together with its intent ids and slot annotations.

    Each entry of ``slots`` is a dict; the code in this file reads the keys
    ``slot_name``, ``start`` and ``end`` from it, and the heuristic annotators
    additionally write ``value``.
    """

    # Raw utterance text.
    text: str
    # Identifiers of every intent expressed by the utterance.
    intent_ids: list[str]
    # Slot span annotations for the utterance.
    slots: list[dict[str, object]]
|
||
|
||
|
||
class JointDataset(Dataset):
    """Dataset that converts ``JointSample`` objects into training tensors.

    Every item is a dict with ``input_ids``, ``attention_mask``, a multi-hot
    ``intent_labels`` vector and per-token ``slot_labels`` in BIO form.
    """

    def __init__(
        self,
        samples: list[JointSample],
        tokenizer,
        intent_to_index: dict[str, int],
        slot_to_index: dict[str, int],
    ) -> None:
        """Store the samples, the tokenizer and the label-to-index lookups."""
        self._samples = samples
        self._tokenizer = tokenizer
        self._intent_to_index = intent_to_index
        self._slot_to_index = slot_to_index

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Tokenize sample *index* and build its intent and slot label tensors."""
        item = self._samples[index]
        encoding = self._tokenizer(
            item.text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # The offsets are only needed for label alignment, so they are removed
        # from the encoding before tensors are pulled out of it.
        offsets = encoding.pop("offset_mapping")[0].tolist()

        # Character-level BIO tags; slots with out-of-range or empty spans are
        # skipped, and later slots overwrite earlier ones where they overlap.
        tags = ["O"] * len(item.text)
        for span in item.slots:
            begin = int(span["start"])
            stop = int(span["end"])
            name = str(span["slot_name"])
            if begin < 0 or stop > len(item.text) or begin >= stop:
                continue
            tags[begin] = f"B-{name}"
            tags[begin + 1 : stop] = [f"I-{name}"] * (stop - begin - 1)

        # A token takes the tag of its first character. Tokens with an empty
        # character span keep IGNORE_INDEX; tags missing from the label map
        # fall back to "O".
        token_labels = [
            self._slot_to_index.get(tags[lo], self._slot_to_index["O"]) if hi > lo else IGNORE_INDEX
            for lo, hi in offsets
        ]

        # Multi-hot intent target; intent ids missing from the map are ignored.
        multi_hot = torch.zeros(len(self._intent_to_index), dtype=torch.float32)
        for intent_id in item.intent_ids:
            position = self._intent_to_index.get(intent_id)
            if position is not None:
                multi_hot[position] = 1.0

        return {
            "input_ids": encoding["input_ids"][0],
            "attention_mask": encoding["attention_mask"][0],
            "intent_labels": multi_hot,
            "slot_labels": torch.tensor(token_labels, dtype=torch.long),
        }
|
||
|
||
|
||
def set_seed() -> None:
    """Seed Python's ``random`` module and PyTorch with the module-level ``SEED``."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(SEED)
|
||
|
||
|
||
def load_jsonl(path: Path) -> list[dict[str, object]]:
    """Parse a UTF-8 JSON Lines file into a list of decoded rows.

    Lines that are empty after stripping whitespace are skipped; every other
    line is decoded with ``json.loads``.
    """
    with path.open("r", encoding="utf-8") as stream:
        stripped = (raw.strip() for raw in stream)
        return [json.loads(entry) for entry in stripped if entry]
|
||
|
||
|
||
def find_order_id_span(text: str) -> tuple[str, int, int] | None:
    """Find the first order id in *text*: one ASCII letter followed by 5+ digits.

    Returns ``(value, start, end)`` with ``end`` exclusive, or ``None`` when
    nothing matches.
    """
    hit = re.search(r"[A-Za-z]\d{5,}", text)
    if hit is None:
        return None
    begin, stop = hit.span()
    return hit.group(0), begin, stop
|
||
|
||
|
||
def find_temperature_span(text: str) -> tuple[str, int, int] | None:
    """Find the first temperature in *text*: two digits, optional spaces, then ``度``.

    Returns ``(value, start, end)`` with ``end`` exclusive, or ``None`` when
    nothing matches.
    """
    hit = re.search(r"(\d{2}\s*度)", text)
    if hit is None:
        return None
    begin, stop = hit.span()
    return hit.group(1), begin, stop
|
||
|
||
|
||
def find_destination_span(text: str) -> tuple[str, int, int] | None:
    """Locate the navigation destination mentioned in *text*.

    The trigger phrases are tried in order, most specific first. The text
    after the trigger is cut at the first connective (然后/并且/同时/再) or
    comma/semicolon, ASCII or full-width, and then trimmed of surrounding
    spaces, commas and ideographic full stops.

    Returns ``(destination, start, end)`` with ``end`` exclusive, where the
    offsets point at the occurrence that follows the trigger, or ``None``
    when no trigger yields a non-empty destination.
    """
    # Space, ASCII comma, full-width comma (U+FF0C) and ideographic full stop.
    trim_chars = " ,\uFF0C。"
    # Escapes keep the full-width comma/semicolon intact even if this file is
    # passed through Unicode normalization.
    clause_break = r"(?:然后|并且|同时|再|[,;\uFF0C\uFF1B])"
    for pattern in (
        r"导航去(?P<destination>.+)",
        r"导航到(?P<destination>.+)",
        r"带我去(?P<destination>.+)",
        r"送我去(?P<destination>.+)",
        r"去(?P<destination>.+)",
    ):
        match = re.search(pattern, text)
        if not match:
            continue
        head = re.split(clause_break, match.group("destination"), maxsplit=1)[0]
        destination = head.strip(trim_chars)
        if not destination:
            continue
        # Anchor the span to the captured group rather than searching the
        # whole utterance, which would hit an earlier occurrence of the same
        # words. The second term skips whatever was trimmed from the left.
        trimmed_left = len(head) - len(head.lstrip(trim_chars))
        start = match.start("destination") + trimmed_left
        return destination, start, start + len(destination)
    return None
|
||
|
||
|
||
def find_music_span(
    text: str,
    genres: tuple[str, ...] | None = None,
) -> tuple[str, str, int, int] | None:
    """Locate a music genre or song title mentioned in *text*.

    A genre keyword anywhere in the utterance wins. Otherwise the text after
    the first playback trigger is cut at the first connective or
    comma/semicolon (ASCII or full-width), trimmed of filler characters, and
    reported as a song.

    Args:
        text: The utterance to scan.
        genres: Genre keywords to look for; defaults to the module-level
            ``GENRE_KEYWORDS``.

    Returns:
        ``(slot_name, value, start, end)`` with ``slot_name`` being
        ``"genre"`` or ``"song"`` and ``end`` exclusive, or ``None`` when
        nothing usable is found.
    """
    if genres is None:
        genres = GENRE_KEYWORDS
    for genre in genres:
        start = text.find(genre)
        if start >= 0:
            return "genre", genre, start, start + len(genre)

    # From here on no genre keyword occurs in ``text``, so no candidate taken
    # from it can contain one and every hit is reported as a song.
    # Filler characters trimmed from both ends (full-width comma added as U+FF0C).
    trim_chars = " 的一首首个歌曲音乐吧呀啊,\uFF0C。"
    clause_break = r"(?:然后|并且|同时|再|[,;\uFF0C\uFF1B])"
    for trigger in ("播放", "来点", "放点", "听", "来首", "来一首", "放一首"):
        trigger_at = text.find(trigger)
        if trigger_at < 0:
            continue
        tail_start = trigger_at + len(trigger)
        head = re.split(clause_break, text[tail_start:], maxsplit=1)[0]
        # Trimming also removes a bare "歌"/"音乐", which leaves an empty target.
        target = head.strip(trim_chars)
        if not target:
            continue
        # Offsets are relative to the text after the trigger; searching the
        # whole utterance would hit an earlier occurrence of the same words.
        start = tail_start + len(head) - len(head.lstrip(trim_chars))
        return "song", target, start, start + len(target)
    return None
|
||
|
||
|
||
def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]:
    """Heuristically annotate the slot that belongs to *intent_id*.

    Order intents look for an order id, ``cabin_set_ac`` for a temperature,
    ``cabin_nav_to`` for a destination and ``cabin_play_music`` for a genre or
    song. Returns a list with at most one slot dict (keys ``slot_name``,
    ``value``, ``start``, ``end``); it is empty for any other intent or when
    the finder returns nothing.
    """
    # Normalise every finder result to (slot_name, value, start, end).
    found: tuple[str, str, int, int] | None = None
    if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}:
        span = find_order_id_span(text)
        found = None if span is None else ("order_id", *span)
    elif intent_id == "cabin_set_ac":
        span = find_temperature_span(text)
        found = None if span is None else ("temperature", *span)
    elif intent_id == "cabin_nav_to":
        span = find_destination_span(text)
        found = None if span is None else ("destination", *span)
    elif intent_id == "cabin_play_music":
        # The music finder already includes the slot name in its result.
        found = find_music_span(text)

    if found is None:
        return []
    slot_name, value, start, end = found
    return [{"slot_name": slot_name, "value": value, "start": start, "end": end}]
|
||
|
||
|
||
def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]:
    """Annotate *text* for every intent and merge the resulting slots.

    Slots with the same ``(slot_name, start, end)`` are kept once (the first
    one found wins), and the result is ordered by ``(start, end)``.
    """
    unique: dict[tuple[str, int, int], dict[str, object]] = {}
    for intent_id in intent_ids:
        for candidate in annotate_slots(text, intent_id):
            identity = (
                str(candidate["slot_name"]),
                int(candidate["start"]),
                int(candidate["end"]),
            )
            # setdefault keeps the earliest slot for each identity.
            unique.setdefault(identity, candidate)
    return sorted(unique.values(), key=lambda item: (int(item["start"]), int(item["end"])))
|
||
|
||
|
||
def build_train_samples() -> list[JointSample]:
    """Assemble the shuffled training set from every configured source.

    Sources, in order:

    1. ``DOMAIN_PATH``: each intent's examples and keywords, with heuristic slots.
    2. ``TRAIN_PATH``: single-intent rows, with heuristic slots.
    3. ``SEED_PATH``: single-intent rows that carry their own slots.
    4. ``MULTI_TRAIN_PATH``: multi-intent rows; their own slots are used when
       present and non-empty, otherwise heuristic slots. Rows with two or
       more intents are kept ``MULTI_INTENT_REPEAT`` times in total.

    A ``(text, intent ids)`` pair is only taken from the first source that
    provides it. Keys that are present but null (for example ``keywords:``
    with no value in the YAML file) are treated as empty instead of raising.
    """
    samples: list[JointSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()

    def first_time(text: str, intent_ids: tuple[str, ...]) -> bool:
        """Record the pair and report whether it had not been seen before."""
        key = (text, intent_ids)
        if key in seen:
            return False
        seen.add(key)
        return True

    domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {}
    for intent in domain_data.get("intents") or []:
        raw_intent_id = intent.get("intent_id")
        # A null id must not turn into the literal label "None".
        intent_id = "" if raw_intent_id is None else str(raw_intent_id).strip()
        if not intent_id:
            continue
        phrases = list(intent.get("examples") or []) + list(intent.get("keywords") or [])
        for phrase in phrases:
            text = str(phrase).strip()
            if text and first_time(text, (intent_id,)):
                samples.append(
                    JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id))
                )

    for row in load_jsonl(TRAIN_PATH):
        text = str(row["text"])
        intent_id = str(row["intent_id"])
        if first_time(text, (intent_id,)):
            samples.append(
                JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id))
            )

    for row in load_jsonl(SEED_PATH):
        text = str(row["text"])
        intent_id = str(row["intent_id"])
        if first_time(text, (intent_id,)):
            samples.append(
                JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots") or []))
            )

    for row in load_jsonl(MULTI_TRAIN_PATH):
        text = str(row["text"]).strip()
        intent_ids = sorted(
            {str(item).strip() for item in row.get("intent_ids") or [] if str(item).strip()}
        )
        if not text or not intent_ids:
            continue
        if not first_time(text, tuple(intent_ids)):
            continue
        slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
        samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
        if len(intent_ids) >= 2:
            # Oversample true multi-intent utterances; each copy gets its own
            # slot list so the samples do not share a mutable object.
            samples.extend(
                JointSample(text=text, intent_ids=intent_ids, slots=list(slots))
                for _ in range(MULTI_INTENT_REPEAT - 1)
            )

    random.shuffle(samples)
    return samples
|
||
|
||
|
||
def build_eval_samples() -> list[JointSample]:
    """Load the evaluation set.

    Rows from ``EVAL_PATH`` are single-intent and use their own slots. If
    ``MULTI_EVAL_PATH`` exists, its multi-intent rows are appended; they use
    their own slots when present and non-empty, otherwise heuristic slots.

    The heuristic fallback receives the same stripped, de-duplicated and
    sorted intent ids that are stored on the sample, so ids with stray
    whitespace still select the right annotator. Null ``slots`` or
    ``intent_ids`` values are treated as empty.
    """
    samples = [
        JointSample(
            text=str(row["text"]),
            intent_ids=[str(row["intent_id"])],
            slots=list(row.get("slots") or []),
        )
        for row in load_jsonl(EVAL_PATH)
    ]
    if MULTI_EVAL_PATH.exists():
        for row in load_jsonl(MULTI_EVAL_PATH):
            text = str(row["text"])
            intent_ids = sorted(
                {str(item).strip() for item in row.get("intent_ids") or [] if str(item).strip()}
            )
            slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
            samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
    return samples
|
||
|
||
|
||
def build_slot_labels(samples: list[JointSample]) -> list[str]:
    """Build the BIO label list for every slot name found in *samples*.

    The list starts with ``"O"`` and is followed by a ``B-``/``I-`` pair for
    each slot name, in sorted name order.
    """
    names: set[str] = set()
    for sample in samples:
        names.update(str(slot["slot_name"]) for slot in sample.slots)
    return ["O", *(f"{prefix}-{name}" for name in sorted(names) for prefix in ("B", "I"))]
|
||
|
||
|
||
def compute_metrics(
    model: JointBertForNLU,
    dataloader: DataLoader,
    device: torch.device,
    intent_labels: list[str],
    slot_labels: list[str],
    threshold: float,
) -> dict[str, float]:
    """Evaluate *model* on *dataloader* at one intent probability threshold.

    Intent metrics are micro precision/recall/F1 over all (sample, intent)
    pairs, the exact-match rate of the predicted intent set, and top-1
    accuracy over samples with exactly one gold intent. Slot metrics are
    token accuracy and the per-sample exact-match rate, both computed only
    over positions whose label is not ``IGNORE_INDEX``.

    Returns the metrics rounded to four decimals, plus the threshold used and
    the sizes of the two label lists.
    """
    model.eval()
    true_pos = false_pos = false_neg = 0
    top1_hits = top1_seen = 0
    intent_set_hits = 0
    slot_hits = slot_seen = 0
    slot_sequence_hits = 0
    seen = 0

    with torch.no_grad():
        for raw_batch in dataloader:
            batch = {name: tensor.to(device) for name, tensor in raw_batch.items()}
            intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])

            # Multi-label intent decisions against the multi-hot gold vector.
            probs = torch.sigmoid(intent_logits)
            fired = probs >= threshold
            gold = batch["intent_labels"] > 0.5
            true_pos += int((fired & gold).sum().item())
            false_pos += int((fired & ~gold).sum().item())
            false_neg += int((~fired & gold).sum().item())
            intent_set_hits += int((fired == gold).all(dim=1).sum().item())

            # Top-1 accuracy is only defined for samples with one gold intent.
            single = gold.sum(dim=-1) == 1
            single_count = int(single.sum().item())
            if single_count > 0:
                best_guess = torch.argmax(probs, dim=-1)
                gold_choice = torch.argmax(gold.float(), dim=-1)
                top1_hits += int((best_guess[single] == gold_choice[single]).sum().item())
                top1_seen += single_count

            # Slot tagging, restricted to positions that carry a real label.
            gold_tags = batch["slot_labels"]
            guessed_tags = torch.argmax(slot_logits, dim=-1)
            scored = gold_tags != IGNORE_INDEX
            slot_hits += int(((guessed_tags == gold_tags) & scored).sum().item())
            slot_seen += int(scored.sum().item())

            for row_gold, row_guess, row_scored in zip(gold_tags, guessed_tags, scored):
                slot_sequence_hits += int(torch.equal(row_gold[row_scored], row_guess[row_scored]))
                seen += 1

    # max(..., 1) keeps the ratios defined when a denominator is zero.
    precision = true_pos / max(true_pos + false_pos, 1)
    recall = true_pos / max(true_pos + false_neg, 1)
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return {
        "intent_threshold": round(threshold, 4),
        "intent_micro_precision": round(precision, 4),
        "intent_micro_recall": round(recall, 4),
        "intent_micro_f1": round(f1, 4),
        "intent_exact_match": round(intent_set_hits / max(seen, 1), 4),
        "single_intent_top1_accuracy": round(top1_hits / max(top1_seen, 1), 4),
        "slot_token_accuracy": round(slot_hits / max(slot_seen, 1), 4),
        "slot_exact_match": round(slot_sequence_hits / max(seen, 1), 4),
        "intent_label_count": float(len(intent_labels)),
        "slot_label_count": float(len(slot_labels)),
    }
|
||
|
||
|
||
def search_best_threshold(
    model: JointBertForNLU,
    dataloader: DataLoader,
    device: torch.device,
    intent_labels: list[str],
    slot_labels: list[str],
) -> dict[str, float]:
    """Evaluate every entry of ``THRESHOLD_CANDIDATES`` and keep the best metrics.

    Candidates are ranked by intent micro-F1, then intent exact match, then
    slot exact match; on a tie the earlier candidate is kept. The returned
    dict is the ``compute_metrics`` result for the winning threshold.
    """

    def rank(metrics: dict[str, float]) -> tuple[float, float, float]:
        return (
            metrics["intent_micro_f1"],
            metrics["intent_exact_match"],
            metrics["slot_exact_match"],
        )

    winner: dict[str, float] | None = None
    for candidate in THRESHOLD_CANDIDATES:
        outcome = compute_metrics(
            model,
            dataloader,
            device,
            intent_labels,
            slot_labels,
            threshold=candidate,
        )
        # Strict comparison, so a later candidate never replaces an equal one.
        if winner is None or rank(outcome) > rank(winner):
            winner = outcome
    assert winner is not None
    return winner
|
||
|
||
|
||
def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor:
    """Compute a per-intent positive-class weight for the intent BCE loss.

    For each label the weight is ``negatives / positives``, with both counts
    floored at 1 and the ratio clamped to ``[1.0, 12.0]``. The result is a
    float32 tensor ordered like *intent_labels*.
    """
    tally = dict.fromkeys(intent_labels, 0)
    for sample in samples:
        present = set(sample.intent_ids)
        for label in intent_labels:
            tally[label] += label in present

    sample_count = max(len(samples), 1)

    def clamped_ratio(hits: int) -> float:
        positives = max(hits, 1)
        negatives = max(sample_count - positives, 1)
        return min(max(negatives / positives, 1.0), 12.0)

    return torch.tensor([clamped_ratio(tally[label]) for label in intent_labels], dtype=torch.float32)
|
||
|
||
|
||
def main() -> None:
    """Train the joint intent/slot model and save the best checkpoint.

    After every epoch the model is evaluated over ``THRESHOLD_CANDIDATES`` and
    the epoch with the highest intent micro-F1 is kept. The tokenizer, the
    best weights (``model_state.pt``), ``joint_nlu_config.json`` and
    ``train_summary.json`` are written to ``OUTPUT_DIR``.

    Raises:
        RuntimeError: If no epoch produced a checkpoint.
    """
    set_seed()
    train_samples = build_train_samples()
    eval_samples = build_eval_samples()
    # The label spaces cover both splits, so evaluation never meets an
    # intent or slot label that has no index.
    all_samples = train_samples + eval_samples
    intent_labels = sorted({intent_id for sample in all_samples for intent_id in sample.intent_ids})
    slot_labels = build_slot_labels(all_samples)
    intent_to_index = {label: index for index, label in enumerate(intent_labels)}
    slot_to_index = {label: index for index, label in enumerate(slot_labels)}
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL)
    train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index)
    eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Prefer an accelerator when one is present: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model = JointBertForNLU(
        base_model_name=DEFAULT_BASE_MODEL,
        num_intents=len(intent_labels),
        num_slot_labels=len(slot_labels),
    )
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    pos_weight = build_pos_weight(train_samples, intent_labels).to(device)
    best_metrics: dict[str, float] | None = None
    best_state: dict[str, torch.Tensor] | None = None

    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0.0
        for batch in train_loader:
            batch = {key: value.to(device) for key, value in batch.items()}
            optimizer.zero_grad()
            intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
            intent_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                intent_logits,
                batch["intent_labels"],
                pos_weight=pos_weight,
            )
            slot_loss = torch.nn.functional.cross_entropy(
                slot_logits.view(-1, slot_logits.size(-1)),
                batch["slot_labels"].view(-1),
                ignore_index=IGNORE_INDEX,
            )
            # The two task losses are summed with equal weight.
            loss = intent_loss + slot_loss
            loss.backward()
            optimizer.step()
            epoch_loss += float(loss.item())

        metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels)
        metrics["train_loss"] = round(epoch_loss / max(len(train_loader), 1), 4)
        print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False))
        if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]:
            best_metrics = metrics
            # clone() is required: on CPU, detach().cpu() returns tensors that
            # share storage with the live parameters, so later optimizer steps
            # would silently overwrite the "best" snapshot.
            best_state = {
                key: value.detach().cpu().clone() for key, value in model.state_dict().items()
            }

    if best_state is None or best_metrics is None:
        raise RuntimeError("joint nlu training did not produce a best checkpoint")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(OUTPUT_DIR)
    torch.save(best_state, OUTPUT_DIR / "model_state.pt")
    config = {
        "base_model_name": DEFAULT_BASE_MODEL,
        "intent_task": "multi_label",
        "intent_labels": intent_labels,
        "slot_labels": slot_labels,
        "max_length": MAX_LENGTH,
        # Both thresholds are set to the value selected for the best epoch.
        "intent_threshold": float(best_metrics["intent_threshold"]),
        "multi_intent_threshold": float(best_metrics["intent_threshold"]),
        "max_multi_intents": 4,
    }
    (OUTPUT_DIR / "joint_nlu_config.json").write_text(
        json.dumps(config, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    (OUTPUT_DIR / "train_summary.json").write_text(
        json.dumps(
            {
                "train_size": len(train_samples),
                "eval_size": len(eval_samples),
                "metrics": best_metrics,
            },
            ensure_ascii=False,
            indent=2,
        ),
        encoding="utf-8",
    )


if __name__ == "__main__":
    main()
|