Files
ai-device/intelligent_cabin/archive/scripts/train_local_bert_multi_intent.py
2026-06-11 16:28:00 +08:00

416 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
import random
from dataclasses import dataclass
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import yaml
# --- Intent labels that are never trained as classes ---------------------
# normalize_intent_ids() drops these from every sample's label set.
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}

# --- Locations -------------------------------------------------------------
# Two levels above this file, i.e. the parent of the script's own directory.
# NOTE(review): if this script has been relocated (for example into an
# ``archive/`` folder), confirm this still resolves to the directory that
# holds ``app/``, ``config/`` and ``models/``.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"

# Checkpoint fine-tuned unless the AGENT_BERT_BASE_MODEL env var names another.
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"

# --- Training hyper-parameters ----------------------------------------------
MAX_LENGTH = 48  # tokens per utterance after truncation / padding
BATCH_SIZE = 8
EPOCHS = 12
LEARNING_RATE = 2e-5  # AdamW learning rate
THRESHOLD = 0.5  # a label is predicted when its sigmoid probability >= this
TOP_K = 4  # k used for the recall@k dev metric
SEED = 42

# (source, target) connector swaps used to paraphrase multi-intent utterances.
# NOTE(review): an empty ``source`` matches every string, and
# ``str.replace("", target, 1)`` merely prepends ``target`` instead of swapping
# a connector. A separator character appears to have been lost from these
# entries; confirm the intended value.
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
    ("", "然后"),
    ("然后", ""),
    ("顺便", ""),
    ("", "顺便"),
)
@dataclass(frozen=True)
class MultiLabelSample:
    """One training utterance paired with every intent it expresses.

    The dataclass is frozen, so instances are immutable, hashable, and compare
    equal exactly when both ``text`` and ``intent_ids`` are equal.
    """

    # Utterance text as fed to the tokenizer.
    text: str
    # Intent identifiers carried by the utterance (a tuple, so it is hashable).
    intent_ids: tuple[str, ...]
class MultiLabelIntentDataset(Dataset):
    """Map-style dataset that turns multi-label samples into model inputs.

    Tokenization happens lazily in ``__getitem__``: each utterance is truncated
    or padded to ``MAX_LENGTH`` tokens and paired with a float32 multi-hot
    target holding one slot per entry of ``label_to_id``.
    """

    def __init__(
        self,
        samples: list[MultiLabelSample],
        tokenizer,
        label_to_id: dict[str, int],
    ) -> None:
        """Keep the samples, the tokenizer and the intent -> column mapping."""
        self._samples = samples
        self._tokenizer = tokenizer
        self._label_to_id = label_to_id
        # Width of every multi-hot target vector.
        self._label_size = len(label_to_id)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Build ``input_ids``, ``attention_mask`` and multi-hot ``labels`` for one sample.

        Raises:
            KeyError: if the sample carries an intent absent from ``label_to_id``.
        """
        sample = self._samples[index]
        encoded = self._tokenizer(
            sample.text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        # The tokenizer is called on a single string with return_tensors="pt";
        # squeeze(0) removes the leading batch dimension of size one.
        item = {key: encoded[key].squeeze(0) for key in ("input_ids", "attention_mask")}
        multi_hot = torch.zeros(self._label_size, dtype=torch.float32)
        for column in (self._label_to_id[intent_id] for intent_id in sample.intent_ids):
            multi_hot[column] = 1.0
        item["labels"] = multi_hot
        return item
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch, plus the MPS generator when available."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
def resolve_base_model() -> str:
    """Return the checkpoint to fine-tune.

    The ``AGENT_BERT_BASE_MODEL`` environment variable wins when it holds a
    non-blank value (surrounding whitespace is removed); otherwise
    ``DEFAULT_BASE_MODEL`` is used.
    """
    override = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
    return override or DEFAULT_BASE_MODEL
def normalize_text(text: str) -> str:
    """Collapse each whitespace run in ``text`` to one space and trim both ends.

    Non-string input is converted with ``str()`` first. ``str.split()`` with no
    argument already discards leading and trailing whitespace.
    """
    words = str(text).split()
    return " ".join(words)
def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
    """Return the distinct, trimmed, trainable intent ids in sorted order.

    Each id is converted with ``str()`` and stripped. Empty results and ids
    listed in ``BLOCKED_LABELS`` are dropped.
    """
    stripped = (str(raw).strip() for raw in intent_ids)
    kept = {intent_id for intent_id in stripped if intent_id and intent_id not in BLOCKED_LABELS}
    return tuple(sorted(kept))
def expand_single_label_variants(text: str) -> list[str]:
    """Return simple paraphrases of a single-intent seed phrase.

    The seed is trimmed of surrounding whitespace and sentence punctuation
    (ASCII and full-width). The candidates are:

    * the seed itself;
    * the seed with "一下" removed;
    * the seed with "帮我" removed;
    * the seed prefixed with "帮我";
    * the seed suffixed with "一下".

    Whitespace inside each candidate is collapsed, edge punctuation is trimmed
    again, and empty candidates are dropped.

    The result is de-duplicated in the candidate order listed above, not by
    iterating a ``set``. Set iteration order for ``str`` changes with
    ``PYTHONHASHSEED``, which would make the seeded shuffle and split
    downstream non-reproducible.

    Args:
        text: Seed phrase.

    Returns:
        The distinct non-empty variants, in a fixed order; ``[]`` when the seed
        is empty after trimming.
    """
    # Full-width forms are included alongside ASCII: Chinese seed phrases
    # commonly end in "，！？；", and leaving them in place would yield
    # candidates such as "打开空调？一下".
    edge_punctuation = ",，。!！?？;； "
    normalized = text.strip().strip(edge_punctuation)
    if not normalized:
        return []
    # NOTE(review): the previous candidate list also held
    # ``normalized.replace("", "")`` and ``f"{normalized}"``. Both merely
    # reproduce ``normalized`` (their literals are empty, so a filler word
    # looks to have been lost) and are omitted. Confirm the intended filler.
    candidates = (
        normalized,
        normalized.replace("一下", "").strip(),
        normalized.replace("帮我", "").strip(),
        f"帮我{normalized}",
        f"{normalized}一下",
    )
    compacted = (" ".join(candidate.split()).strip(edge_punctuation) for candidate in candidates)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return [variant for variant in dict.fromkeys(compacted) if variant]
def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
    """Load single-intent samples from a JSON Lines file.

    Each non-blank line must be a JSON object with ``text`` and ``intent_id``
    keys. A line is skipped when its intent is empty or blocked after
    ``normalize_intent_ids``, or when its text is empty after
    ``normalize_text``.

    The file is iterated line by line, so only ``\\n``, ``\\r`` and ``\\r\\n``
    end a record. ``str.splitlines()`` would also split on characters such as
    U+2028, which may legally appear raw inside a JSON string. A leading UTF-8
    BOM is tolerated.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        The loaded samples, in file order; ``[]`` if the file does not exist.

    Raises:
        json.JSONDecodeError: a line is not valid JSON. The message is prefixed
            with ``<file>:<line>``.
        ValueError: a line decodes to something other than a JSON object.
    """
    samples: list[MultiLabelSample] = []
    if not file_path.exists():
        return samples
    with file_path.open(encoding="utf-8-sig") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(f"{file_path}:{line_number}: {err.msg}", err.doc, err.pos) from err
            if not isinstance(payload, dict):
                raise ValueError(
                    f"{file_path}:{line_number}: expected a JSON object, got {type(payload).__name__}"
                )
            intent_ids = normalize_intent_ids([str(payload.get("intent_id") or "")])
            if not intent_ids:
                continue
            text = normalize_text(str(payload.get("text") or ""))
            if not text:
                continue
            samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
    return samples
def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
    """Build single-intent samples from the intents declared in the domain YAML.

    For each entry under ``intents`` whose ``intent_id`` survives
    ``normalize_intent_ids``, the seed texts are taken in this order:

    1. its ``examples``;
    2. its ``keywords``;
    3. its non-blank ``label``.

    Each seed is expanded with ``expand_single_label_variants``. Repeated
    ``(text, intent_ids)`` pairs are emitted once.

    Tolerated malformed inputs:

    * ``intents: null``;
    * an ``examples`` or ``keywords`` value given as a single string (it would
      otherwise be split into single characters);
    * null list entries (``str(None)`` would otherwise become the training text
      "None").

    Args:
        file_path: Path to the domain YAML file.

    Returns:
        The samples in declaration order; ``[]`` if the file does not exist.
    """
    if not file_path.exists():
        return []
    payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
    intents = payload.get("intents") or []
    samples: list[MultiLabelSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()
    for item in intents:
        intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
        if not intent_ids:
            continue
        seed_texts: list[object] = []
        for field_name in ("examples", "keywords"):
            values = item.get(field_name) or []
            if isinstance(values, str):
                values = [values]
            seed_texts.extend(values)
        label = str(item.get("label") or "").strip()
        if label:
            seed_texts.append(label)
        for text in seed_texts:
            if text is None:
                continue
            normalized = normalize_text(text)
            if not normalized:
                continue
            for variant in expand_single_label_variants(normalized):
                key = (variant, intent_ids)
                if key in seen:
                    continue
                seen.add(key)
                samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
    return samples
def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
    """Load multi-intent samples from a JSON Lines file.

    Each non-blank line must be a JSON object with ``text`` and an
    ``intent_ids`` list. A line is kept only when at least two distinct,
    non-blocked intents remain after ``normalize_intent_ids`` and its
    normalized text is non-empty.

    An ``intent_ids`` value given as a single string counts as one intent, so
    that line is skipped. Spreading it with ``list()`` would instead turn every
    character into a bogus label.

    The file is iterated line by line, so only ``\\n``, ``\\r`` and ``\\r\\n``
    end a record. ``str.splitlines()`` would also split on characters such as
    U+2028, which may legally appear raw inside a JSON string. A leading UTF-8
    BOM is tolerated.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        The loaded samples, in file order; ``[]`` if the file does not exist.

    Raises:
        json.JSONDecodeError: a line is not valid JSON. The message is prefixed
            with ``<file>:<line>``.
        ValueError: a line decodes to something other than a JSON object.
    """
    samples: list[MultiLabelSample] = []
    if not file_path.exists():
        return samples
    with file_path.open(encoding="utf-8-sig") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(f"{file_path}:{line_number}: {err.msg}", err.doc, err.pos) from err
            if not isinstance(payload, dict):
                raise ValueError(
                    f"{file_path}:{line_number}: expected a JSON object, got {type(payload).__name__}"
                )
            raw_ids = payload.get("intent_ids") or []
            if isinstance(raw_ids, str):
                raw_ids = [raw_ids]
            intent_ids = normalize_intent_ids(list(raw_ids))
            if len(intent_ids) < 2:
                continue
            text = normalize_text(str(payload.get("text") or ""))
            if not text:
                continue
            samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
    return samples
def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
    """Return ``samples`` followed by simple paraphrases of each utterance.

    The candidates for each sample are:

    * the sample text itself;
    * the text prefixed with "帮我";
    * one substitution per ``CONNECTOR_VARIANTS`` pair whose non-empty source
      occurs in the text (first occurrence only).

    Candidates are whitespace-normalized and stripped of edge punctuation
    (ASCII and full-width). Empty candidates, and ``(text, intent_ids)`` pairs
    already present, are skipped.

    Candidates are kept in a list, not a ``set``, so the output order does not
    depend on ``PYTHONHASHSEED``.

    Args:
        samples: Multi-intent samples to paraphrase.

    Returns:
        A new list: the input samples unchanged and in order, then the new
        paraphrases carrying the same ``intent_ids``.
    """
    edge_punctuation = ",，。!！?？;； "
    augmented = list(samples)
    seen = {(sample.text, sample.intent_ids) for sample in augmented}
    for sample in samples:
        # NOTE(review): earlier candidates included ``text.replace("", ", ")``.
        # With an empty search string that interleaves ", " between every
        # character ("打开空调" -> "打, 开, 空, 调"), yielding garbage training
        # text. ``replace("", "")`` and ``f"{text}"`` were no-ops. The search
        # character seems to have been lost; those candidates are omitted.
        candidates = [sample.text, f"帮我{sample.text}"]
        for source, target in CONNECTOR_VARIANTS:
            # An empty source matches every string, and replace("", t, 1) only
            # prepends ``t`` rather than swapping a connector, so skip it.
            if source and source in sample.text:
                candidates.append(sample.text.replace(source, target, 1))
        for candidate in candidates:
            normalized = normalize_text(candidate).strip(edge_punctuation)
            key = (normalized, sample.intent_ids)
            if normalized and key not in seen:
                augmented.append(MultiLabelSample(text=normalized, intent_ids=sample.intent_ids))
                seen.add(key)
    return augmented
def load_all_samples() -> list[MultiLabelSample]:
    """Gather every training sample, drop exact duplicates, and shuffle.

    Sources are read in this order:

    1. the single-label JSONL;
    2. the domain YAML seeds;
    3. the multi-label JSONL, after augmentation.

    A duplicate is a repeated ``(text, intent_ids)`` pair; its first occurrence
    is kept. The final shuffle draws from the global ``random`` state.
    """
    pooled = [
        *load_single_label_samples(SINGLE_LABEL_PATH),
        *load_domain_samples(DOMAIN_PATH),
        *augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)),
    ]
    unique: dict[tuple[str, tuple[str, ...]], MultiLabelSample] = {}
    for sample in pooled:
        unique.setdefault((sample.text, sample.intent_ids), sample)
    result = list(unique.values())
    random.shuffle(result)
    return result
def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
    """Split samples roughly 80/20 into train/dev, stratified by intent combination.

    Samples are grouped by their exact ``intent_ids`` tuple, and each group is
    shuffled:

    * a group of one goes entirely to train;
    * a larger group sends its first 80% to train (at least one, at most all
      but one) and the rest to dev.

    If dev is still empty, the last samples added to train are moved to dev:
    about a fifth of train, at least one and at most 32. Both lists are then
    shuffled. All randomness comes from the global ``random`` state.
    """
    by_combo: dict[tuple[str, ...], list[MultiLabelSample]] = {}
    for sample in samples:
        by_combo.setdefault(sample.intent_ids, []).append(sample)

    train: list[MultiLabelSample] = []
    dev: list[MultiLabelSample] = []
    for group in by_combo.values():
        random.shuffle(group)
        size = len(group)
        if size == 1:
            train.extend(group)
            continue
        # 80% to train, but always leave at least one sample for dev.
        cut = min(max(1, int(size * 0.8)), size - 1)
        train += group[:cut]
        dev += group[cut:]

    if not dev:
        borrow = max(1, min(32, len(train) // 5 or 1))
        dev = train[-borrow:]
        train = train[: len(train) - len(dev)]

    random.shuffle(train)
    random.shuffle(dev)
    return train, dev
def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
    """Apply an element-wise sigmoid and return plain nested Python lists.

    The tensor is detached from the autograd graph and copied to CPU first.
    """
    probabilities = logits.detach().sigmoid()
    return probabilities.cpu().tolist()
def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Score multi-label predictions against multi-hot targets.

    A label counts as predicted when its probability is ``>= threshold``, and
    as expected when its target value is ``>= 0.5``.

    Args:
        probabilities: Per-sample, per-label probabilities.
        targets: Per-sample multi-hot targets, aligned with ``probabilities``.
        threshold: Decision threshold for a positive prediction.
        top_k: Number of highest-scoring labels inspected for recall@k.

    Returns:
        A dict of metrics, each rounded to 4 decimals:

        * ``micro_precision``, ``micro_recall``, ``micro_f1``: pooled over every
          label of every sample.
        * ``exact_match``: fraction of samples whose predicted set equals the
          expected set.
        * ``recall_at_k``: mean fraction of expected labels found among the
          ``top_k`` highest scores. Only samples with at least one expected
          label are averaged, because recall is undefined for the others; they
          used to be counted as 0.

        Empty input yields all zeros.

    Raises:
        ValueError: if ``probabilities`` and ``targets`` differ in length
            (``zip`` would otherwise silently drop the surplus rows).
    """
    if len(probabilities) != len(targets):
        raise ValueError(
            f"probabilities and targets differ in length: {len(probabilities)} != {len(targets)}"
        )
    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    recall_at_k_total = 0.0
    labelled = 0  # samples with at least one expected label
    total = len(probabilities)
    for scores, target in zip(probabilities, targets):
        predicted = {index for index, score in enumerate(scores) if score >= threshold}
        expected = {index for index, value in enumerate(target) if value >= 0.5}
        if predicted == expected:
            exact_match += 1
        true_positive += len(predicted & expected)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)
        if expected:
            # Stable sort: equal scores keep their label order.
            top_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:top_k]
            recall_at_k_total += len(set(top_indices) & expected) / len(expected)
            labelled += 1
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "recall_at_k": round(recall_at_k_total / labelled, 4) if labelled else 0.0,
    }
def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
    """Run ``model`` over ``loader`` without gradients and score its predictions.

    The model is switched to eval mode and left there. Each batch's
    ``input_ids``, ``attention_mask`` and ``labels`` are moved to ``device``
    and passed to the model as keyword arguments.

    Returns:
        A ``(mean_loss, metrics)`` pair:

        * ``mean_loss``: the sum of the per-batch losses divided by the number
          of batches (0.0 for an empty loader).
        * ``metrics``: ``compute_metrics`` applied to the sigmoid probabilities
          and the targets.
    """
    model.eval()
    batch_losses: list[float] = []
    probabilities: list[list[float]] = []
    targets: list[list[float]] = []
    with torch.no_grad():
        for batch in loader:
            inputs = {name: batch[name].to(device) for name in ("input_ids", "attention_mask", "labels")}
            outputs = model(**inputs)
            batch_losses.append(float(outputs.loss.item()))
            probabilities.extend(logits_to_probabilities(outputs.logits))
            targets.extend(inputs["labels"].detach().cpu().tolist())
    mean_loss = sum(batch_losses) / max(len(loader), 1)
    return mean_loss, compute_metrics(probabilities, targets, threshold=threshold, top_k=top_k)
def _select_device() -> torch.device:
    """Return the first available device: CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _train_one_epoch(model, loader: DataLoader, optimizer, device: torch.device) -> float:
    """Run one optimisation pass over ``loader``; return the mean per-batch training loss."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        optimizer.zero_grad()
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
    return total_loss / max(len(loader), 1)


def _write_artifacts(model, tokenizer, id_to_label: dict[int, str], train_summary: dict) -> None:
    """Save model, tokenizer, ``label_map.json`` and ``train_summary.json`` into ``OUTPUT_DIR``."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    # Maps "LABEL_<index>" names to intent ids.
    label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
    for file_name, content in (("label_map.json", label_map), ("train_summary.json", train_summary)):
        (OUTPUT_DIR / file_name).write_text(
            json.dumps(content, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )


def main() -> None:
    """Fine-tune the multi-label intent classifier and save the best checkpoint.

    Steps:

    1. Seed the RNGs, then load and split the samples.
    2. Fine-tune the ``resolve_base_model()`` checkpoint for ``EPOCHS`` epochs
       on CUDA, MPS or CPU, whichever is available first.
    3. Restore the weights of the epoch with the highest dev micro-F1 (the
       later epoch wins ties).
    4. Write the model, tokenizer, label map and a JSON training summary to
       ``OUTPUT_DIR``, and print the summary.

    Raises:
        SystemExit: when no samples were loaded, or too few to leave a
            non-empty training split. Without this check the run would fail
            later with an obscure error, e.g. a zero-label classifier.
    """
    set_seed(SEED)
    samples = load_all_samples()
    if not samples:
        raise SystemExit(
            f"No training samples found in {SINGLE_LABEL_PATH}, {DOMAIN_PATH} or {MULTI_LABEL_PATH}."
        )
    intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
    label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
    id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
    train_samples, dev_samples = split_samples(samples)
    if not train_samples:
        raise SystemExit(
            f"Need at least two samples to form a train/dev split; found {len(samples)}."
        )

    base_model = resolve_base_model()
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
    dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(intents),
        id2label=id_to_label,
        label2id=label_to_id,
        problem_type="multi_label_classification",
    )
    device = _select_device()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    best_dev_f1 = 0.0
    best_state = None
    best_metrics: dict[str, float] = {}
    for epoch in range(1, EPOCHS + 1):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, device)
        dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
        print(
            " ".join(
                [
                    f"epoch={epoch}",
                    f"train_loss={avg_loss:.4f}",
                    f"dev_loss={dev_loss:.4f}",
                    f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
                    f"dev_exact_match={dev_metrics['exact_match']:.4f}",
                ]
            )
        )
        # ">=" makes the latest epoch win among equal dev F1 scores.
        if dev_metrics["micro_f1"] >= best_dev_f1:
            best_dev_f1 = dev_metrics["micro_f1"]
            best_metrics = dict(dev_metrics)
            # On a CPU device .cpu() returns the live tensor itself, so clone()
            # is what keeps the snapshot from tracking later updates.
            best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)

    train_summary = {
        "task_type": "multi_label_intent_detection",
        "base_model": base_model,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "threshold": THRESHOLD,
        "top_k": TOP_K,
        "train_size": len(train_samples),
        "dev_size": len(dev_samples),
        "label_count": len(intents),
        "labels": intents,
        "best_dev_metrics": best_metrics,
        "device": str(device),
    }
    _write_artifacts(model, tokenizer, id_to_label, train_summary)
    print(json.dumps(train_summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()