Update project and configurations
This commit is contained in:
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import yaml
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
SOCIAL_LABEL = "__social__"
|
||||
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||||
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}
|
||||
MAX_LENGTH = 48
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 12
|
||||
LEARNING_RATE = 2e-5
|
||||
THRESHOLD = 0.5
|
||||
TOP_K = 4
|
||||
SEED = 42
|
||||
|
||||
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
|
||||
("并", "然后"),
|
||||
("然后", "并"),
|
||||
("顺便", "再"),
|
||||
("再", "顺便"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultiLabelSample:
|
||||
text: str
|
||||
intent_ids: tuple[str, ...]
|
||||
|
||||
|
||||
class MultiLabelIntentDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
samples: list[MultiLabelSample],
|
||||
tokenizer,
|
||||
label_to_id: dict[str, int],
|
||||
) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._label_to_id = label_to_id
|
||||
self._label_size = len(label_to_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=MAX_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = torch.zeros(self._label_size, dtype=torch.float32)
|
||||
for intent_id in sample.intent_ids:
|
||||
labels[self._label_to_id[intent_id]] = 1.0
|
||||
return {
|
||||
"input_ids": encoded["input_ids"].squeeze(0),
|
||||
"attention_mask": encoded["attention_mask"].squeeze(0),
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def resolve_base_model() -> str:
|
||||
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
return DEFAULT_BASE_MODEL
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
return " ".join(str(text).strip().split())
|
||||
|
||||
|
||||
def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
|
||||
cleaned = sorted(
|
||||
{
|
||||
str(intent_id).strip()
|
||||
for intent_id in intent_ids
|
||||
if str(intent_id).strip() and str(intent_id).strip() not in BLOCKED_LABELS
|
||||
}
|
||||
)
|
||||
return tuple(cleaned)
|
||||
|
||||
|
||||
def expand_single_label_variants(text: str) -> list[str]:
|
||||
normalized = text.strip().strip(",。!?;; ")
|
||||
if not normalized:
|
||||
return []
|
||||
variants = {
|
||||
normalized,
|
||||
normalized.replace("一下", "").strip(),
|
||||
normalized.replace("帮我", "").strip(),
|
||||
normalized.replace("请", "").strip(),
|
||||
f"帮我{normalized}",
|
||||
f"请{normalized}",
|
||||
f"{normalized}一下",
|
||||
}
|
||||
cleaned: list[str] = []
|
||||
for item in variants:
|
||||
compact = " ".join(item.split()).strip(",。!?;; ")
|
||||
if compact:
|
||||
cleaned.append(compact)
|
||||
return cleaned
|
||||
|
||||
|
||||
def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
samples: list[MultiLabelSample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
intent_ids = normalize_intent_ids([str(payload.get("intent_id") or "")])
|
||||
if not intent_ids:
|
||||
continue
|
||||
text = normalize_text(str(payload.get("text") or ""))
|
||||
if not text:
|
||||
continue
|
||||
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
if not file_path.exists():
|
||||
return []
|
||||
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
|
||||
intents = payload.get("intents", [])
|
||||
samples: list[MultiLabelSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
for item in intents:
|
||||
intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
|
||||
if not intent_ids:
|
||||
continue
|
||||
seed_texts = list(item.get("examples") or [])
|
||||
seed_texts.extend(item.get("keywords") or [])
|
||||
label = str(item.get("label") or "").strip()
|
||||
if label:
|
||||
seed_texts.append(label)
|
||||
for text in seed_texts:
|
||||
normalized = normalize_text(text)
|
||||
if not normalized:
|
||||
continue
|
||||
for variant in expand_single_label_variants(normalized):
|
||||
key = (variant, intent_ids)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
samples: list[MultiLabelSample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
intent_ids = normalize_intent_ids(list(payload.get("intent_ids") or []))
|
||||
if len(intent_ids) < 2:
|
||||
continue
|
||||
text = normalize_text(str(payload.get("text") or ""))
|
||||
if not text:
|
||||
continue
|
||||
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
|
||||
augmented = list(samples)
|
||||
seen = {(sample.text, sample.intent_ids) for sample in augmented}
|
||||
for sample in list(samples):
|
||||
variants = {
|
||||
sample.text,
|
||||
f"帮我{sample.text}",
|
||||
f"请{sample.text}",
|
||||
sample.text.replace(",", ", "),
|
||||
sample.text.replace(",", ""),
|
||||
}
|
||||
for source, target in CONNECTOR_VARIANTS:
|
||||
if source in sample.text:
|
||||
variants.add(sample.text.replace(source, target, 1))
|
||||
for variant in variants:
|
||||
normalized = normalize_text(variant).strip(",。!?;; ")
|
||||
key = (normalized, sample.intent_ids)
|
||||
if normalized and key not in seen:
|
||||
augmented.append(MultiLabelSample(text=normalized, intent_ids=sample.intent_ids))
|
||||
seen.add(key)
|
||||
return augmented
|
||||
|
||||
|
||||
def load_all_samples() -> list[MultiLabelSample]:
|
||||
samples = load_single_label_samples(SINGLE_LABEL_PATH)
|
||||
samples.extend(load_domain_samples(DOMAIN_PATH))
|
||||
samples.extend(augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)))
|
||||
deduped: list[MultiLabelSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
for sample in samples:
|
||||
key = (sample.text, sample.intent_ids)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(sample)
|
||||
random.shuffle(deduped)
|
||||
return deduped
|
||||
|
||||
|
||||
def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
|
||||
grouped: dict[tuple[str, ...], list[MultiLabelSample]] = {}
|
||||
for sample in samples:
|
||||
grouped.setdefault(sample.intent_ids, []).append(sample)
|
||||
train_samples: list[MultiLabelSample] = []
|
||||
dev_samples: list[MultiLabelSample] = []
|
||||
for items in grouped.values():
|
||||
random.shuffle(items)
|
||||
if len(items) == 1:
|
||||
train_samples.extend(items)
|
||||
continue
|
||||
cut = max(1, int(len(items) * 0.8))
|
||||
if cut >= len(items):
|
||||
cut = len(items) - 1
|
||||
train_samples.extend(items[:cut])
|
||||
dev_samples.extend(items[cut:])
|
||||
if not dev_samples:
|
||||
dev_samples = train_samples[-max(1, min(32, len(train_samples) // 5 or 1)) :]
|
||||
train_samples = train_samples[: len(train_samples) - len(dev_samples)]
|
||||
random.shuffle(train_samples)
|
||||
random.shuffle(dev_samples)
|
||||
return train_samples, dev_samples
|
||||
|
||||
|
||||
def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
|
||||
return torch.sigmoid(logits).detach().cpu().tolist()
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
probabilities: list[list[float]],
|
||||
targets: list[list[float]],
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
recall_at_k_total = 0.0
|
||||
total = len(probabilities)
|
||||
for scores, target in zip(probabilities, targets):
|
||||
predicted = {index for index, score in enumerate(scores) if score >= threshold}
|
||||
expected = {index for index, value in enumerate(target) if value >= 0.5}
|
||||
if predicted == expected:
|
||||
exact_match += 1
|
||||
true_positive += len(predicted & expected)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
|
||||
if expected:
|
||||
recall_at_k_total += len(set(top_indices) & expected) / len(expected)
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
probabilities: list[list[float]] = []
|
||||
targets: list[list[float]] = []
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
total_loss += float(outputs.loss.item())
|
||||
probabilities.extend(logits_to_probabilities(outputs.logits))
|
||||
targets.extend(labels.detach().cpu().tolist())
|
||||
avg_loss = total_loss / max(len(loader), 1)
|
||||
return avg_loss, compute_metrics(probabilities, targets, threshold=threshold, top_k=top_k)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed(SEED)
|
||||
samples = load_all_samples()
|
||||
intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
|
||||
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
|
||||
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
|
||||
|
||||
train_samples, dev_samples = split_samples(samples)
|
||||
base_model = resolve_base_model()
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
|
||||
dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
base_model,
|
||||
num_labels=len(intents),
|
||||
id2label=id_to_label,
|
||||
label2id=label_to_id,
|
||||
problem_type="multi_label_classification",
|
||||
)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
best_dev_f1 = 0.0
|
||||
best_state = None
|
||||
best_metrics: dict[str, float] = {}
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += float(loss.item())
|
||||
|
||||
dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
|
||||
avg_loss = total_loss / max(len(train_loader), 1)
|
||||
print(
|
||||
" ".join(
|
||||
[
|
||||
f"epoch={epoch}",
|
||||
f"train_loss={avg_loss:.4f}",
|
||||
f"dev_loss={dev_loss:.4f}",
|
||||
f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
|
||||
f"dev_exact_match={dev_metrics['exact_match']:.4f}",
|
||||
]
|
||||
)
|
||||
)
|
||||
if dev_metrics["micro_f1"] >= best_dev_f1:
|
||||
best_dev_f1 = dev_metrics["micro_f1"]
|
||||
best_metrics = dict(dev_metrics)
|
||||
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(OUTPUT_DIR)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
|
||||
(OUTPUT_DIR / "label_map.json").write_text(
|
||||
json.dumps(label_map, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
train_summary = {
|
||||
"task_type": "multi_label_intent_detection",
|
||||
"base_model": base_model,
|
||||
"epochs": EPOCHS,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"learning_rate": LEARNING_RATE,
|
||||
"threshold": THRESHOLD,
|
||||
"top_k": TOP_K,
|
||||
"train_size": len(train_samples),
|
||||
"dev_size": len(dev_samples),
|
||||
"label_count": len(intents),
|
||||
"labels": intents,
|
||||
"best_dev_metrics": best_metrics,
|
||||
"device": str(device),
|
||||
}
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(train_summary, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user