416 lines
15 KiB
Python
416 lines
15 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader, Dataset
|
||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||
import yaml
|
||
|
||
|
||
# --- Filesystem layout ------------------------------------------------------
# Two levels up from this file: parents[0] is the file's directory,
# parents[1] is that directory's parent.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"

# --- Model / labels ---------------------------------------------------------
# Checkpoint used unless AGENT_BERT_BASE_MODEL overrides it.
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
# Intent ids that are dropped during normalisation and never become labels.
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}

# --- Training hyper-parameters ----------------------------------------------
MAX_LENGTH = 48  # tokenizer max_length (inputs are padded/truncated to this)
BATCH_SIZE = 8
EPOCHS = 12
LEARNING_RATE = 2e-5
THRESHOLD = 0.5  # sigmoid score at or above which a label counts as predicted
TOP_K = 4  # k used for the recall@k dev metric
SEED = 42

# (source, target) connector words; augmentation swaps the first occurrence of
# `source` in a multi-intent utterance for `target`.
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
    ("并", "然后"),
    ("然后", "并"),
    ("顺便", "再"),
    ("再", "顺便"),
)
|
||
|
||
|
||
@dataclass(frozen=True)
class MultiLabelSample:
    """A single utterance paired with the intent ids assigned to it.

    The dataclass is frozen, so instances are immutable and hashable and can
    be compared by value.
    """

    # Utterance text; passed to the tokenizer by the dataset.
    text: str
    # Intent ids attached to the utterance.
    intent_ids: tuple[str, ...]
|
||
|
||
|
||
class MultiLabelIntentDataset(Dataset):
    """Dataset that tokenizes samples on access and emits multi-hot labels.

    Args:
        samples: Samples to serve, indexed positionally.
        tokenizer: Callable invoked with the sample text and tokenizer keyword
            arguments; its result must expose ``input_ids`` and
            ``attention_mask`` tensors.
        label_to_id: Mapping from intent id to its column in the label vector.
    """

    def __init__(
        self,
        samples: list[MultiLabelSample],
        tokenizer,
        label_to_id: dict[str, int],
    ) -> None:
        self._samples = samples
        self._tokenizer = tokenizer
        self._label_to_id = label_to_id
        # Width of every label vector produced by __getitem__.
        self._label_size = len(label_to_id)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one sample.

        Raises:
            KeyError: If the sample carries an intent id missing from
                ``label_to_id``.
        """
        item = self._samples[index]
        encoding = self._tokenizer(
            item.text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )

        # Multi-hot target: 1.0 at the column of every intent on the sample.
        target = torch.zeros(self._label_size, dtype=torch.float32)
        positions = [self._label_to_id[name] for name in item.intent_ids]
        if positions:
            target[positions] = 1.0

        # squeeze(0) drops the leading dimension of each encoded tensor
        # (presumably the size-1 batch axis added by return_tensors="pt").
        features = {
            key: encoding[key].squeeze(0)
            for key in ("input_ids", "attention_mask")
        }
        features["labels"] = target
        return features
|
||
|
||
|
||
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch with ``seed``.

    The MPS generator is seeded as well when the MPS backend is available.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.backends.mps.is_available():
        return
    torch.mps.manual_seed(seed)
|
||
|
||
|
||
def resolve_base_model() -> str:
    """Return the base checkpoint name to fine-tune.

    The ``AGENT_BERT_BASE_MODEL`` environment variable wins when it holds a
    non-blank value (surrounding whitespace is stripped); otherwise
    ``DEFAULT_BASE_MODEL`` is returned.
    """
    override = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
    return override or DEFAULT_BASE_MODEL
|
||
|
||
|
||
def normalize_text(text: str) -> str:
    """Collapse every whitespace run in ``text`` to one space and trim the ends.

    ``text`` is passed through ``str()`` first, so non-string values are
    accepted.
    """
    # split() without arguments already ignores leading/trailing whitespace.
    tokens = str(text).split()
    return " ".join(tokens)
|
||
|
||
|
||
def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
    """Return the cleaned, de-duplicated intent ids as a sorted tuple.

    Each id is converted with ``str()`` and stripped; blank ids and ids listed
    in ``BLOCKED_LABELS`` are discarded.
    """
    unique_ids: set[str] = set()
    for raw_id in intent_ids:
        candidate = str(raw_id).strip()
        if candidate and candidate not in BLOCKED_LABELS:
            unique_ids.add(candidate)
    return tuple(sorted(unique_ids))
|
||
|
||
|
||
def expand_single_label_variants(text: str) -> list[str]:
    """Return phrasing variants of ``text`` for single-intent augmentation.

    The text is trimmed of surrounding whitespace and punctuation, then
    variants are produced by removing the filler words "一下" / "帮我" / "请"
    and by adding the prefixes "帮我" / "请" and the suffix "一下".

    Args:
        text: Seed utterance.

    Returns:
        Unique, non-empty variants in a fixed order (the trimmed text first).
        An empty list when nothing is left after trimming.
    """
    # NOTE(review): this set is kept exactly as found; if full-width marks such
    # as "," "!" "?" were intended, they are not covered — confirm.
    edge_chars = ",。!?;; "
    base = text.strip().strip(edge_chars)
    if not base:
        return []

    candidates = (
        base,
        base.replace("一下", "").strip(),
        base.replace("帮我", "").strip(),
        base.replace("请", "").strip(),
        f"帮我{base}",
        f"请{base}",
        f"{base}一下",
    )

    # A dict keeps insertion order, so the result is de-duplicated *and*
    # reproducible. Iterating a set of str would follow the per-process hash
    # seed and make the seeded shuffle/split downstream vary between runs.
    ordered: dict[str, None] = {}
    for candidate in candidates:
        compact = " ".join(candidate.split()).strip(edge_chars)
        if compact:
            ordered.setdefault(compact, None)
    return list(ordered)
|
||
|
||
|
||
def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
    """Read single-intent samples from a JSON Lines file.

    Each non-blank line is parsed as a JSON object with ``intent_id`` and
    ``text`` fields. Records whose intent id is blank or blocked, or whose
    text is empty after normalisation, are skipped.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        The loaded samples in file order; an empty list if the file is absent.
    """
    if not file_path.exists():
        return []

    loaded: list[MultiLabelSample] = []
    for raw_line in file_path.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue
        record = json.loads(stripped)
        labels = normalize_intent_ids([str(record.get("intent_id") or "")])
        utterance = normalize_text(str(record.get("text") or ""))
        if labels and utterance:
            loaded.append(MultiLabelSample(text=utterance, intent_ids=labels))
    return loaded
|
||
|
||
|
||
def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
    """Build single-intent samples from the domain YAML definition.

    For every entry under ``intents``, the entry's ``examples``, ``keywords``
    and ``label`` are used as seed texts. Each seed is normalised and expanded
    with ``expand_single_label_variants``; every distinct
    (variant, intent ids) pair becomes one sample.

    Args:
        file_path: Path to the domain YAML file.

    Returns:
        The generated samples; an empty list if the file is absent, empty, or
        declares no intents.
    """
    if not file_path.exists():
        return []

    payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
    # An `intents:` key with no value parses to None. Use `or []` (as is done
    # for examples/keywords below) so that case is treated as "no intents"
    # instead of raising TypeError in the loop.
    intents = payload.get("intents") or []

    samples: list[MultiLabelSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()
    for item in intents:
        intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
        if not intent_ids:
            # Blank or blocked intent id: nothing to learn from this entry.
            continue

        seed_texts = list(item.get("examples") or [])
        seed_texts.extend(item.get("keywords") or [])
        label = str(item.get("label") or "").strip()
        if label:
            seed_texts.append(label)

        for text in seed_texts:
            normalized = normalize_text(text)
            if not normalized:
                continue
            for variant in expand_single_label_variants(normalized):
                key = (variant, intent_ids)
                if key in seen:
                    continue
                seen.add(key)
                samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
    return samples
|
||
|
||
|
||
def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
    """Read multi-intent samples from a JSON Lines file.

    Each non-blank line is parsed as a JSON object with ``intent_ids`` and
    ``text`` fields. Records with fewer than two usable intent ids after
    normalisation, or with empty text, are skipped.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        The loaded samples in file order; an empty list if the file is absent.
    """
    if not file_path.exists():
        return []

    loaded: list[MultiLabelSample] = []
    for raw_line in file_path.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue
        record = json.loads(stripped)
        labels = normalize_intent_ids(list(record.get("intent_ids") or []))
        utterance = normalize_text(str(record.get("text") or ""))
        # Only genuinely multi-intent records belong in this set.
        if len(labels) >= 2 and utterance:
            loaded.append(MultiLabelSample(text=utterance, intent_ids=labels))
    return loaded
|
||
|
||
|
||
def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
    """Return ``samples`` plus rule-based paraphrases of each one.

    For every sample the text is prefixed with "帮我" / "请", rewritten around
    commas, and has the first occurrence of each connector in
    ``CONNECTOR_VARIANTS`` swapped for its counterpart. Each candidate is
    normalised and trimmed of surrounding punctuation, and is added only if
    its (text, intent ids) pair has not been seen yet.

    Args:
        samples: Multi-intent samples to augment; not modified.

    Returns:
        A new list: the original samples first, followed by the new variants
        in a deterministic order.
    """
    augmented = list(samples)
    seen = {(sample.text, sample.intent_ids) for sample in augmented}

    for sample in samples:
        # An ordered list rather than a set: set iteration order follows the
        # per-process string hash seed, which would make the augmented order
        # (and the seeded shuffle/split that follows) differ between runs.
        # NOTE(review): the comma rewrites below act on the ASCII comma exactly
        # as found; if the corpus uses the full-width "," they never match —
        # confirm intent.
        candidates = [
            sample.text,
            f"帮我{sample.text}",
            f"请{sample.text}",
            sample.text.replace(",", ", "),
            sample.text.replace(",", ""),
        ]
        candidates.extend(
            sample.text.replace(source, target, 1)
            for source, target in CONNECTOR_VARIANTS
            if source in sample.text
        )

        for candidate in candidates:
            normalized = normalize_text(candidate).strip(",。!?;; ")
            key = (normalized, sample.intent_ids)
            if normalized and key not in seen:
                seen.add(key)
                augmented.append(
                    MultiLabelSample(text=normalized, intent_ids=sample.intent_ids)
                )
    return augmented
|
||
|
||
|
||
def load_all_samples() -> list[MultiLabelSample]:
    """Load, merge, de-duplicate and shuffle every training sample.

    Sources, in order: the single-label JSONL file, the domain YAML file, and
    the augmented multi-label JSONL file. When the same (text, intent ids)
    pair occurs more than once, the first occurrence is kept. The result is
    shuffled with the module-level ``random`` generator.
    """
    combined = [
        *load_single_label_samples(SINGLE_LABEL_PATH),
        *load_domain_samples(DOMAIN_PATH),
        *augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)),
    ]

    # Insertion-ordered dict: setdefault keeps the first sample for each key.
    first_seen: dict[tuple[str, tuple[str, ...]], MultiLabelSample] = {}
    for sample in combined:
        first_seen.setdefault((sample.text, sample.intent_ids), sample)

    deduped = list(first_seen.values())
    random.shuffle(deduped)
    return deduped
|
||
|
||
|
||
def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
    """Split ``samples`` into train and dev lists, stratified by intent set.

    Samples are grouped by their exact ``intent_ids`` tuple. A group with one
    sample goes entirely to train; larger groups are shuffled and split
    roughly 80/20, always leaving at least one sample on each side. If no
    group produced a dev sample, the tail of the train list (about a fifth of
    it, between 1 and 32 samples) is moved to dev instead.

    Returns:
        ``(train_samples, dev_samples)``, each shuffled.
    """
    by_intent_set: dict[tuple[str, ...], list[MultiLabelSample]] = {}
    for entry in samples:
        bucket = by_intent_set.get(entry.intent_ids)
        if bucket is None:
            bucket = by_intent_set[entry.intent_ids] = []
        bucket.append(entry)

    train_part: list[MultiLabelSample] = []
    dev_part: list[MultiLabelSample] = []
    for bucket in by_intent_set.values():
        random.shuffle(bucket)
        size = len(bucket)
        if size == 1:
            train_part += bucket
            continue
        # At least one sample in train, and at least one left over for dev.
        boundary = min(max(1, int(size * 0.8)), size - 1)
        train_part += bucket[:boundary]
        dev_part += bucket[boundary:]

    if not dev_part:
        # Every group was a singleton: carve the dev set out of train.
        holdout = max(1, min(32, len(train_part) // 5 or 1))
        dev_part = train_part[-holdout:]
        train_part = train_part[: len(train_part) - len(dev_part)]

    random.shuffle(train_part)
    random.shuffle(dev_part)
    return train_part, dev_part
|
||
|
||
|
||
def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
    """Apply an element-wise sigmoid to ``logits`` and return nested Python lists.

    The result is detached from the autograd graph and moved to the CPU before
    conversion.
    """
    scores = logits.sigmoid()
    return scores.detach().cpu().tolist()
|
||
|
||
|
||
def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Compute micro-averaged and per-sample multi-label metrics.

    Args:
        probabilities: Per-sample label scores.
        targets: Per-sample gold vectors; a value >= 0.5 marks a gold label.
        threshold: Score at or above which a label counts as predicted.
        top_k: Number of highest-scoring labels considered for recall@k.

    Returns:
        ``micro_precision``, ``micro_recall``, ``micro_f1``, ``exact_match``
        and ``recall_at_k``, each rounded to four decimals. Samples without
        gold labels add nothing to recall@k but still count in its
        denominator. All values are 0.0 for empty input.
    """
    hits = spurious = missed = 0
    exact_hits = 0
    recall_k_sum = 0.0
    sample_count = len(probabilities)

    for row_scores, row_target in zip(probabilities, targets):
        predicted_ids = {pos for pos, value in enumerate(row_scores) if value >= threshold}
        gold_ids = {pos for pos, flag in enumerate(row_target) if flag >= 0.5}

        exact_hits += int(predicted_ids == gold_ids)
        overlap = len(predicted_ids & gold_ids)
        hits += overlap
        spurious += len(predicted_ids) - overlap
        missed += len(gold_ids) - overlap

        # Label positions ordered from highest to lowest score.
        ranked = sorted(range(len(row_scores)), key=row_scores.__getitem__, reverse=True)
        if gold_ids:
            recall_k_sum += len(gold_ids.intersection(ranked[:top_k])) / len(gold_ids)

    predicted_total = hits + spurious
    gold_total = hits + missed
    precision = hits / predicted_total if predicted_total else 0.0
    recall = hits / gold_total if gold_total else 0.0
    pr_sum = precision + recall
    micro_f1 = 2 * precision * recall / pr_sum if pr_sum else 0.0

    if sample_count:
        exact_ratio = round(exact_hits / sample_count, 4)
        recall_at_k = round(recall_k_sum / sample_count, 4)
    else:
        exact_ratio = 0.0
        recall_at_k = 0.0

    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": exact_ratio,
        "recall_at_k": recall_at_k,
    }
|
||
|
||
|
||
def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``loss`` and ``logits``.
        loader: Batches with ``input_ids``, ``attention_mask`` and ``labels``.
        device: Device the batch tensors are moved to.
        threshold: Decision threshold forwarded to ``compute_metrics``.
        top_k: ``k`` forwarded to ``compute_metrics`` for recall@k.

    Returns:
        ``(mean batch loss, metrics dict)``. The mean is 0.0 for an empty
        loader.
    """
    model.eval()
    loss_sum = 0.0
    all_scores: list[list[float]] = []
    all_targets: list[list[float]] = []

    with torch.no_grad():
        for batch in loader:
            token_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            gold = batch["labels"].to(device)
            result = model(input_ids=token_ids, attention_mask=mask, labels=gold)
            loss_sum += float(result.loss.item())
            all_scores += logits_to_probabilities(result.logits)
            all_targets += gold.detach().cpu().tolist()

    # max(..., 1) avoids dividing by zero when the loader has no batches.
    mean_loss = loss_sum / max(len(loader), 1)
    metrics = compute_metrics(all_scores, all_targets, threshold=threshold, top_k=top_k)
    return mean_loss, metrics
|
||
|
||
|
||
def _train_one_epoch(model, loader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for batch in loader:
        optimizer.zero_grad()
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        running_loss += float(loss.item())
    return running_loss / max(len(loader), 1)


def main() -> None:
    """Fine-tune the multi-label intent classifier and save its artifacts.

    Steps: seed the RNGs, load and split the samples, fine-tune the base model
    for ``EPOCHS`` epochs while tracking the best dev micro-F1, restore the
    best weights, then write the model, tokenizer, ``label_map.json`` and
    ``train_summary.json`` to ``OUTPUT_DIR``.

    Raises:
        ValueError: If no samples could be loaded, or the split leaves no
            training samples.
    """
    set_seed(SEED)
    samples = load_all_samples()
    if not samples:
        # Fail before loading the base model; otherwise the first error is an
        # opaque sampler failure on the empty training set.
        raise ValueError(
            "No training samples found; checked "
            f"{SINGLE_LABEL_PATH}, {DOMAIN_PATH} and {MULTI_LABEL_PATH}."
        )

    intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
    label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
    id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}

    train_samples, dev_samples = split_samples(samples)
    if not train_samples:
        raise ValueError(
            f"Not enough samples to train: {len(samples)} loaded, none left for training after the split."
        )

    base_model = resolve_base_model()
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
    dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(intents),
        id2label=id_to_label,
        label2id=label_to_id,
        problem_type="multi_label_classification",
    )

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    best_dev_f1 = 0.0
    best_state = None
    best_metrics: dict[str, float] = {}

    for epoch in range(1, EPOCHS + 1):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, device)
        dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
        print(
            " ".join(
                [
                    f"epoch={epoch}",
                    f"train_loss={avg_loss:.4f}",
                    f"dev_loss={dev_loss:.4f}",
                    f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
                    f"dev_exact_match={dev_metrics['exact_match']:.4f}",
                ]
            )
        )
        # ">=" means a later epoch that ties the best score replaces the
        # earlier checkpoint.
        if dev_metrics["micro_f1"] >= best_dev_f1:
            best_dev_f1 = dev_metrics["micro_f1"]
            best_metrics = dict(dev_metrics)
            # Snapshot on CPU so the copy is independent of further updates.
            best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
    (OUTPUT_DIR / "label_map.json").write_text(
        json.dumps(label_map, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    train_summary = {
        "task_type": "multi_label_intent_detection",
        "base_model": base_model,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "threshold": THRESHOLD,
        "top_k": TOP_K,
        "train_size": len(train_samples),
        "dev_size": len(dev_samples),
        "label_count": len(intents),
        "labels": intents,
        "best_dev_metrics": best_metrics,
        "device": str(device),
    }
    (OUTPUT_DIR / "train_summary.json").write_text(
        json.dumps(train_summary, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(json.dumps(train_summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
|