Update project and configurations
This commit is contained in:
500
intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
Normal file
500
intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoTokenizer
|
||||
import yaml
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.services.joint_nlu import JointBertForNLU
|
||||
|
||||
|
||||
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
|
||||
SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl"
|
||||
EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl"
|
||||
MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
MAX_LENGTH = 64
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 8
|
||||
LEARNING_RATE = 2e-5
|
||||
SEED = 42
|
||||
IGNORE_INDEX = -100
|
||||
GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌")
|
||||
DEFAULT_INTENT_THRESHOLD = 0.3
|
||||
MULTI_INTENT_REPEAT = 6
|
||||
THRESHOLD_CANDIDATES = [0.1, 0.12, 0.15, 0.18, 0.2, 0.22, 0.25, 0.28, 0.3, 0.33, 0.35, 0.38, 0.4, 0.45]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointSample:
|
||||
text: str
|
||||
intent_ids: list[str]
|
||||
slots: list[dict[str, object]]
|
||||
|
||||
|
||||
class JointDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
samples: list[JointSample],
|
||||
tokenizer,
|
||||
intent_to_index: dict[str, int],
|
||||
slot_to_index: dict[str, int],
|
||||
) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._intent_to_index = intent_to_index
|
||||
self._slot_to_index = slot_to_index
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
max_length=MAX_LENGTH,
|
||||
padding="max_length",
|
||||
return_offsets_mapping=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
offset_mapping = encoded.pop("offset_mapping")[0].tolist()
|
||||
slot_labels = [IGNORE_INDEX] * len(offset_mapping)
|
||||
char_labels = ["O"] * len(sample.text)
|
||||
for slot in sample.slots:
|
||||
start = int(slot["start"])
|
||||
end = int(slot["end"])
|
||||
slot_name = str(slot["slot_name"])
|
||||
if start < 0 or end > len(sample.text) or start >= end:
|
||||
continue
|
||||
char_labels[start] = f"B-{slot_name}"
|
||||
for pos in range(start + 1, end):
|
||||
char_labels[pos] = f"I-{slot_name}"
|
||||
|
||||
for token_index, (start, end) in enumerate(offset_mapping):
|
||||
if end <= start:
|
||||
continue
|
||||
label = char_labels[start]
|
||||
slot_labels[token_index] = self._slot_to_index.get(label, self._slot_to_index["O"])
|
||||
intent_vector = torch.zeros(len(self._intent_to_index), dtype=torch.float32)
|
||||
for intent_id in sample.intent_ids:
|
||||
if intent_id in self._intent_to_index:
|
||||
intent_vector[self._intent_to_index[intent_id]] = 1.0
|
||||
|
||||
return {
|
||||
"input_ids": encoded["input_ids"][0],
|
||||
"attention_mask": encoded["attention_mask"][0],
|
||||
"intent_labels": intent_vector,
|
||||
"slot_labels": torch.tensor(slot_labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def set_seed() -> None:
|
||||
random.seed(SEED)
|
||||
torch.manual_seed(SEED)
|
||||
|
||||
|
||||
def load_jsonl(path: Path) -> list[dict[str, object]]:
|
||||
rows: list[dict[str, object]] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def find_order_id_span(text: str) -> tuple[str, int, int] | None:
|
||||
match = re.search(r"[A-Za-z]\d{5,}", text)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(0), match.start(), match.end()
|
||||
|
||||
|
||||
def find_temperature_span(text: str) -> tuple[str, int, int] | None:
|
||||
match = re.search(r"(\d{2}\s*度)", text)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1), match.start(), match.end()
|
||||
|
||||
|
||||
def find_destination_span(text: str) -> tuple[str, int, int] | None:
|
||||
for pattern in (
|
||||
r"导航去(?P<destination>.+)",
|
||||
r"导航到(?P<destination>.+)",
|
||||
r"带我去(?P<destination>.+)",
|
||||
r"送我去(?P<destination>.+)",
|
||||
r"去(?P<destination>.+)",
|
||||
):
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
continue
|
||||
destination = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", match.group("destination"), maxsplit=1)[0].strip(" ,。")
|
||||
if not destination:
|
||||
continue
|
||||
start = text.find(destination)
|
||||
if start >= 0:
|
||||
return destination, start, start + len(destination)
|
||||
return None
|
||||
|
||||
|
||||
def find_music_span(text: str) -> tuple[str, str, int, int] | None:
|
||||
for genre in GENRE_KEYWORDS:
|
||||
start = text.find(genre)
|
||||
if start >= 0:
|
||||
return "genre", genre, start, start + len(genre)
|
||||
for trigger in ("播放", "来点", "放点", "听", "来首", "来一首", "放一首"):
|
||||
if trigger not in text:
|
||||
continue
|
||||
target = text.split(trigger, maxsplit=1)[-1]
|
||||
target = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", target, maxsplit=1)[0].strip(" 的一首首个歌曲音乐吧呀啊,。")
|
||||
if not target or target in {"歌", "音乐"}:
|
||||
continue
|
||||
for genre in GENRE_KEYWORDS:
|
||||
if genre in target:
|
||||
start = text.find(genre)
|
||||
return "genre", genre, start, start + len(genre)
|
||||
start = text.find(target)
|
||||
if start >= 0:
|
||||
return "song", target, start, start + len(target)
|
||||
return None
|
||||
|
||||
|
||||
def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]:
|
||||
slots: list[dict[str, object]] = []
|
||||
if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}:
|
||||
matched = find_order_id_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "order_id", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_set_ac":
|
||||
matched = find_temperature_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "temperature", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_nav_to":
|
||||
matched = find_destination_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "destination", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_play_music":
|
||||
matched = find_music_span(text)
|
||||
if matched is not None:
|
||||
slot_name, value, start, end = matched
|
||||
slots.append({"slot_name": slot_name, "value": value, "start": start, "end": end})
|
||||
return slots
|
||||
|
||||
|
||||
def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]:
|
||||
merged: list[dict[str, object]] = []
|
||||
seen: set[tuple[str, int, int]] = set()
|
||||
for intent_id in intent_ids:
|
||||
for slot in annotate_slots(text, intent_id):
|
||||
key = (str(slot["slot_name"]), int(slot["start"]), int(slot["end"]))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
merged.append(slot)
|
||||
merged.sort(key=lambda item: (int(item["start"]), int(item["end"])))
|
||||
return merged
|
||||
|
||||
|
||||
def build_train_samples() -> list[JointSample]:
|
||||
samples: list[JointSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {}
|
||||
for intent in domain_data.get("intents", []):
|
||||
intent_id = str(intent.get("intent_id", "")).strip()
|
||||
if not intent_id:
|
||||
continue
|
||||
for text in list(intent.get("examples", [])) + list(intent.get("keywords", [])):
|
||||
text = str(text).strip()
|
||||
if not text:
|
||||
continue
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))
|
||||
for row in load_jsonl(TRAIN_PATH):
|
||||
text = str(row["text"])
|
||||
intent_id = str(row["intent_id"])
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))
|
||||
for row in load_jsonl(SEED_PATH):
|
||||
text = str(row["text"])
|
||||
intent_id = str(row["intent_id"])
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots", []))))
|
||||
for row in load_jsonl(MULTI_TRAIN_PATH):
|
||||
text = str(row["text"]).strip()
|
||||
intent_ids = sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()})
|
||||
if not text or not intent_ids:
|
||||
continue
|
||||
key = (text, tuple(intent_ids))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
|
||||
samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
|
||||
if len(intent_ids) >= 2:
|
||||
for _ in range(MULTI_INTENT_REPEAT - 1):
|
||||
samples.append(JointSample(text=text, intent_ids=intent_ids, slots=list(slots)))
|
||||
random.shuffle(samples)
|
||||
return samples
|
||||
|
||||
|
||||
def build_eval_samples() -> list[JointSample]:
|
||||
rows = load_jsonl(EVAL_PATH)
|
||||
samples = [
|
||||
JointSample(
|
||||
text=str(row["text"]),
|
||||
intent_ids=[str(row["intent_id"])],
|
||||
slots=list(row.get("slots", [])),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
if MULTI_EVAL_PATH.exists():
|
||||
for row in load_jsonl(MULTI_EVAL_PATH):
|
||||
samples.append(
|
||||
JointSample(
|
||||
text=str(row["text"]),
|
||||
intent_ids=sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()}),
|
||||
slots=list(row.get("slots") or annotate_slots_for_intents(str(row["text"]), list(row.get("intent_ids", [])))),
|
||||
)
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def build_slot_labels(samples: list[JointSample]) -> list[str]:
|
||||
slot_names = sorted({str(slot["slot_name"]) for sample in samples for slot in sample.slots})
|
||||
labels = ["O"]
|
||||
for name in slot_names:
|
||||
labels.append(f"B-{name}")
|
||||
labels.append(f"I-{name}")
|
||||
return labels
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
model: JointBertForNLU,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
intent_labels: list[str],
|
||||
slot_labels: list[str],
|
||||
threshold: float,
|
||||
) -> dict[str, float]:
|
||||
model.eval()
|
||||
intent_tp = 0
|
||||
intent_fp = 0
|
||||
intent_fn = 0
|
||||
single_intent_correct = 0
|
||||
single_intent_total = 0
|
||||
intent_exact_match = 0
|
||||
correct_slot_tokens = 0
|
||||
total_slot_tokens = 0
|
||||
exact_slot_samples = 0
|
||||
total_samples = 0
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()}
|
||||
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
|
||||
predicted_probs = torch.sigmoid(intent_logits)
|
||||
predicted_multi = predicted_probs >= threshold
|
||||
gold_multi = batch["intent_labels"] > 0.5
|
||||
intent_tp += int((predicted_multi & gold_multi).sum().item())
|
||||
intent_fp += int((predicted_multi & ~gold_multi).sum().item())
|
||||
intent_fn += int((~predicted_multi & gold_multi).sum().item())
|
||||
intent_exact_match += int((predicted_multi == gold_multi).all(dim=1).sum().item())
|
||||
top_predicted = torch.argmax(predicted_probs, dim=-1)
|
||||
gold_counts = gold_multi.sum(dim=-1)
|
||||
single_mask = gold_counts == 1
|
||||
if int(single_mask.sum().item()) > 0:
|
||||
gold_top = torch.argmax(gold_multi.float(), dim=-1)
|
||||
single_intent_correct += int((top_predicted[single_mask] == gold_top[single_mask]).sum().item())
|
||||
single_intent_total += int(single_mask.sum().item())
|
||||
|
||||
predicted_slots = torch.argmax(slot_logits, dim=-1)
|
||||
mask = batch["slot_labels"] != IGNORE_INDEX
|
||||
correct_slot_tokens += int(((predicted_slots == batch["slot_labels"]) & mask).sum().item())
|
||||
total_slot_tokens += int(mask.sum().item())
|
||||
|
||||
for index in range(batch["slot_labels"].size(0)):
|
||||
gold = batch["slot_labels"][index][mask[index]]
|
||||
pred = predicted_slots[index][mask[index]]
|
||||
exact_slot_samples += int(torch.equal(gold, pred))
|
||||
total_samples += 1
|
||||
precision = intent_tp / max(intent_tp + intent_fp, 1)
|
||||
recall = intent_tp / max(intent_tp + intent_fn, 1)
|
||||
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"intent_threshold": round(threshold, 4),
|
||||
"intent_micro_precision": round(precision, 4),
|
||||
"intent_micro_recall": round(recall, 4),
|
||||
"intent_micro_f1": round(f1, 4),
|
||||
"intent_exact_match": round(intent_exact_match / max(total_samples, 1), 4),
|
||||
"single_intent_top1_accuracy": round(single_intent_correct / max(single_intent_total, 1), 4),
|
||||
"slot_token_accuracy": round(correct_slot_tokens / max(total_slot_tokens, 1), 4),
|
||||
"slot_exact_match": round(exact_slot_samples / max(total_samples, 1), 4),
|
||||
"intent_label_count": float(len(intent_labels)),
|
||||
"slot_label_count": float(len(slot_labels)),
|
||||
}
|
||||
|
||||
|
||||
def search_best_threshold(
|
||||
model: JointBertForNLU,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
intent_labels: list[str],
|
||||
slot_labels: list[str],
|
||||
) -> dict[str, float]:
|
||||
best_metrics: dict[str, float] | None = None
|
||||
for threshold in THRESHOLD_CANDIDATES:
|
||||
metrics = compute_metrics(
|
||||
model,
|
||||
dataloader,
|
||||
device,
|
||||
intent_labels,
|
||||
slot_labels,
|
||||
threshold=threshold,
|
||||
)
|
||||
if best_metrics is None:
|
||||
best_metrics = metrics
|
||||
continue
|
||||
current_score = (metrics["intent_micro_f1"], metrics["intent_exact_match"], metrics["slot_exact_match"])
|
||||
best_score = (
|
||||
best_metrics["intent_micro_f1"],
|
||||
best_metrics["intent_exact_match"],
|
||||
best_metrics["slot_exact_match"],
|
||||
)
|
||||
if current_score > best_score:
|
||||
best_metrics = metrics
|
||||
assert best_metrics is not None
|
||||
return best_metrics
|
||||
|
||||
|
||||
def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor:
|
||||
positive_counts = {label: 0 for label in intent_labels}
|
||||
for sample in samples:
|
||||
sample_intents = set(sample.intent_ids)
|
||||
for label in intent_labels:
|
||||
if label in sample_intents:
|
||||
positive_counts[label] += 1
|
||||
total = max(len(samples), 1)
|
||||
weights: list[float] = []
|
||||
for label in intent_labels:
|
||||
positives = max(positive_counts[label], 1)
|
||||
negatives = max(total - positives, 1)
|
||||
weight = negatives / positives
|
||||
weights.append(min(max(weight, 1.0), 12.0))
|
||||
return torch.tensor(weights, dtype=torch.float32)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed()
|
||||
train_samples = build_train_samples()
|
||||
eval_samples = build_eval_samples()
|
||||
intent_labels = sorted({intent_id for sample in train_samples + eval_samples for intent_id in sample.intent_ids})
|
||||
slot_labels = build_slot_labels(train_samples + eval_samples)
|
||||
intent_to_index = {label: index for index, label in enumerate(intent_labels)}
|
||||
slot_to_index = {label: index for index, label in enumerate(slot_labels)}
|
||||
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL)
|
||||
train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index)
|
||||
eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index)
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model = JointBertForNLU(
|
||||
base_model_name=DEFAULT_BASE_MODEL,
|
||||
num_intents=len(intent_labels),
|
||||
num_slot_labels=len(slot_labels),
|
||||
)
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
pos_weight = build_pos_weight(train_samples, intent_labels).to(device)
|
||||
best_metrics: dict[str, float] | None = None
|
||||
best_state: dict[str, torch.Tensor] | None = None
|
||||
|
||||
for epoch in range(EPOCHS):
|
||||
model.train()
|
||||
epoch_loss = 0.0
|
||||
for batch in train_loader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()}
|
||||
optimizer.zero_grad()
|
||||
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
|
||||
intent_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
intent_logits,
|
||||
batch["intent_labels"],
|
||||
pos_weight=pos_weight,
|
||||
)
|
||||
slot_loss = torch.nn.functional.cross_entropy(
|
||||
slot_logits.view(-1, slot_logits.size(-1)),
|
||||
batch["slot_labels"].view(-1),
|
||||
ignore_index=IGNORE_INDEX,
|
||||
)
|
||||
loss = intent_loss + slot_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
epoch_loss += float(loss.item())
|
||||
|
||||
metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels)
|
||||
metrics["train_loss"] = round(epoch_loss / max(len(train_loader), 1), 4)
|
||||
print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False))
|
||||
if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]:
|
||||
best_metrics = metrics
|
||||
best_state = {key: value.detach().cpu() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is None or best_metrics is None:
|
||||
raise RuntimeError("joint nlu training did not produce a best checkpoint")
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
torch.save(best_state, OUTPUT_DIR / "model_state.pt")
|
||||
config = {
|
||||
"base_model_name": DEFAULT_BASE_MODEL,
|
||||
"intent_task": "multi_label",
|
||||
"intent_labels": intent_labels,
|
||||
"slot_labels": slot_labels,
|
||||
"max_length": MAX_LENGTH,
|
||||
"intent_threshold": float(best_metrics["intent_threshold"]),
|
||||
"multi_intent_threshold": float(best_metrics["intent_threshold"]),
|
||||
"max_multi_intents": 4,
|
||||
}
|
||||
(OUTPUT_DIR / "joint_nlu_config.json").write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"train_size": len(train_samples),
|
||||
"eval_size": len(eval_samples),
|
||||
"metrics": best_metrics,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user