324 lines
13 KiB
Python
324 lines
13 KiB
Python
from app.core.config import settings
|
|
from app.plugins.base import PluginRegistry
|
|
from app.plugins.mock import MockPluginExecutor
|
|
from app.services.agent_service import AgentService
|
|
from app.services.classifier import (
|
|
BertIntentClassifier,
|
|
IntentClassifier,
|
|
JointBertIntentClassifier,
|
|
MockIntentClassifier,
|
|
RemoteIntentClassifier,
|
|
)
|
|
from app.services.config_loader import ConfigLoader
|
|
from app.services.intent_registry import IntentRegistry
|
|
from app.services.joint_nlu import JointBertNLU
|
|
from app.services.knowledge_llm import DashScopeKnowledgeLLM
|
|
from app.services.knowledge_store import KnowledgeStore
|
|
from app.services.multi_intent_detector import (
|
|
BertMultiIntentDetector,
|
|
JointBertMultiIntentDetector,
|
|
MultiIntentDetector,
|
|
)
|
|
from app.services.planner import (
|
|
CompositeWorkflowPlanner,
|
|
DashScopeWorkflowPlanner,
|
|
HeuristicWorkflowPlanner,
|
|
TemplateWorkflowPlanner,
|
|
WorkflowPlanner,
|
|
)
|
|
from app.services.response_policy import ResponsePolicy
|
|
from app.services.rewrite_engine import ContextRewriteEngine
|
|
from app.services.router import (
|
|
HeuristicSlotExtractor,
|
|
JointBertSlotExtractor,
|
|
IntentRouter,
|
|
Router,
|
|
build_matcher_pipeline,
|
|
)
|
|
from app.services.session_store import InMemorySessionStore, RedisSessionStore, SessionStore
|
|
from app.services.social import DashScopeSocialResponder, SocialResponder, SocialRouter
|
|
|
|
|
|
def build_session_store(session_backend: str | None = None) -> SessionStore:
    """Create the session store selected by ``session_backend``.

    When ``session_backend`` is ``None`` or empty, the backend name is taken
    from ``settings.session_backend``.

    Returns:
        An ``InMemorySessionStore`` for ``"memory"``, or a
        ``RedisSessionStore`` configured from ``settings`` for ``"redis"``.

    Raises:
        ValueError: If the resolved backend name is anything else.
    """
    chosen = session_backend or settings.session_backend

    if chosen == "redis":
        # Redis connection details always come from the global settings.
        return RedisSessionStore(
            redis_url=settings.redis_url,
            key_prefix=settings.redis_key_prefix,
            ttl_seconds=settings.session_ttl_seconds,
        )
    if chosen == "memory":
        return InMemorySessionStore()

    raise ValueError(f"Unsupported session backend: {chosen}")
|
|
|
|
|
|
def build_router(
    intent_registry: IntentRegistry,
    matcher_pipeline: str | None = None,
    classifier_backend: str | None = None,
    classifier: IntentClassifier | None = None,
    joint_nlu: JointBertNLU | None = None,
) -> Router:
    """Assemble an ``IntentRouter`` from a matcher pipeline and a slot extractor.

    Args:
        intent_registry: Registry handed to ``build_matcher_pipeline``.
        matcher_pipeline: Comma-separated stage names. Falls back to
            ``settings.matcher_pipeline``; a pipeline with no stages is
            treated as ``"classifier"``.
        classifier_backend: Forwarded to ``build_classifier`` when no
            ``classifier`` is supplied.
        classifier: Pre-built classifier to reuse. Built on demand when
            ``None``.
        joint_nlu: Shared Joint NLU runtime. Required when
            ``settings.slot_extractor_backend`` is ``"joint_bert"``.

    Raises:
        ValueError: If the pipeline is anything other than the single
            ``classifier`` stage, if the slot extractor backend is unknown,
            or if the ``joint_bert`` slot extractor is selected without a
            ``joint_nlu`` runtime.
    """
    active_pipeline = matcher_pipeline or settings.matcher_pipeline or ""
    matcher_stages = [stage.strip() for stage in active_pipeline.split(",") if stage.strip()]
    if not matcher_stages:
        matcher_stages = ["classifier"]
    if matcher_stages != ["classifier"]:
        raise ValueError("Only classifier matcher pipeline is supported in bert-first mode")

    slot_backend = settings.slot_extractor_backend
    if slot_backend not in {"heuristic", "joint_bert"}:
        raise ValueError(f"Unsupported slot extractor backend: {slot_backend}")
    # Validate the slot-extractor prerequisite *before* building the classifier,
    # so a misconfiguration fails without first loading/warming a model.
    if slot_backend == "joint_bert" and joint_nlu is None:
        raise ValueError("slot_extractor_backend=joint_bert requires a Joint NLU runtime")

    if classifier is None:
        classifier = build_classifier(
            # Pass the normalized stages rather than the raw string: when the
            # raw pipeline was empty we defaulted to ["classifier"] above, and
            # the classifier builder must see that same default or it would
            # return None for a pipeline that does contain a classifier stage.
            matcher_pipeline=",".join(matcher_stages),
            classifier_backend=classifier_backend,
            joint_nlu=joint_nlu,
        )

    if slot_backend == "heuristic":
        slot_extractor = HeuristicSlotExtractor()
    else:
        slot_extractor = JointBertSlotExtractor(joint_nlu)

    return IntentRouter(
        matcher=build_matcher_pipeline(
            intent_registry,
            matcher_stages,
            classifier=classifier,
            route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
            clarify_margin_threshold=settings.local_clarify_margin_threshold,
            classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
            classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
        ),
        slot_extractor=slot_extractor,
    )
|
|
|
|
|
|
def build_classifier(
    matcher_pipeline: str | None = None,
    classifier_backend: str | None = None,
    joint_nlu: JointBertNLU | None = None,
) -> IntentClassifier | None:
    """Build the intent classifier for the configured backend.

    Args:
        matcher_pipeline: Comma-separated stage names. Falls back to
            ``settings.matcher_pipeline``.
        classifier_backend: One of ``"mock"``, ``"bert"``, ``"joint_bert"``
            or ``"remote"``. Falls back to ``settings.classifier_backend``.
        joint_nlu: Joint NLU runtime to reuse for the ``joint_bert`` backend.
            A new one is built via ``build_joint_nlu`` when omitted.

    Returns:
        The classifier, or ``None`` when the pipeline has no ``classifier``
        stage.

    Raises:
        ValueError: If the resolved backend name is not supported.
    """
    active_pipeline = matcher_pipeline or settings.matcher_pipeline or ""
    active_backend = classifier_backend or settings.classifier_backend

    # Match whole stage names (comma-separated, whitespace-insensitive) instead
    # of a substring test on the raw string, so a stage such as
    # "my_classifier_v2" is not mistaken for the classifier stage.
    stage_names = {stage.strip() for stage in active_pipeline.split(",")}
    if "classifier" not in stage_names:
        return None

    fallback = MockIntentClassifier(
        threshold=settings.classifier_threshold,
        top_k=settings.classifier_top_k,
    )
    if active_backend == "mock":
        return fallback

    if active_backend == "bert":
        classifier = BertIntentClassifier(
            model_path=settings.classifier_model_path,
            threshold=settings.classifier_bert_threshold,
            # An empty label-map path is passed on as None.
            label_map_path=settings.classifier_label_map_path or None,
            fallback=fallback,
            top_k=settings.classifier_top_k,
        )
        if settings.classifier_warmup_enabled:
            classifier.warmup(settings.classifier_warmup_text)
        return classifier

    if active_backend == "joint_bert":
        runtime = joint_nlu or build_joint_nlu()
        classifier = JointBertIntentClassifier(
            nlu=runtime,
            # Non-positive configured thresholds are clamped to 0.0.
            threshold=settings.joint_nlu_intent_threshold if settings.joint_nlu_intent_threshold > 0 else 0.0,
            top_k=settings.joint_nlu_top_k,
        )
        if settings.classifier_warmup_enabled:
            classifier.warmup(settings.classifier_warmup_text)
        return classifier

    if active_backend == "remote":
        return RemoteIntentClassifier(
            endpoint=settings.classifier_remote_url,
            timeout_seconds=settings.classifier_remote_timeout_seconds,
            threshold=settings.classifier_threshold,
            fallback=fallback,
            label_map_path=settings.classifier_label_map_path or None,
            top_k=settings.classifier_top_k,
        )

    raise ValueError(f"Unsupported classifier backend: {active_backend}")
|
|
|
|
|
|
def build_agent_service() -> AgentService:
    """Build an ``AgentService`` with no runtime overrides.

    Delegates to ``build_agent_service_with_runtime`` using its defaults.
    """
    service = build_agent_service_with_runtime()
    return service
|
|
|
|
|
|
def build_agent_service_with_runtime(
    matcher_pipeline: str | None = None,
    classifier_backend: str | None = None,
    session_backend: str | None = None,
) -> AgentService:
    """Wire up a complete ``AgentService``, optionally overriding settings.

    Args:
        matcher_pipeline: Overrides ``settings.matcher_pipeline`` when given.
        classifier_backend: Overrides ``settings.classifier_backend`` when
            given.
        session_backend: Forwarded to ``build_session_store``.

    Returns:
        The assembled ``AgentService``.
    """
    bundle = load_runtime_bundle()
    registry = bundle.intent_registry
    backend = classifier_backend or settings.classifier_backend

    # One Joint NLU runtime is created at most once here and handed to the
    # classifier, multi-intent detector, router and planner builders.
    if backend == "joint_bert" or settings.slot_extractor_backend == "joint_bert":
        joint_nlu = build_joint_nlu()
    else:
        joint_nlu = None

    classifier = build_classifier(
        matcher_pipeline=matcher_pipeline or settings.matcher_pipeline,
        classifier_backend=backend,
        joint_nlu=joint_nlu,
    )

    # The planner only reuses the classifier for clause scoring when that is
    # enabled and the backend is one of the model-backed ones.
    use_clause_classifier = (
        settings.planner_clause_classifier_enabled
        and backend in {"bert", "remote", "joint_bert"}
    )
    planner_clause_classifier = classifier if use_clause_classifier else None

    multi_intent_detector = build_multi_intent_detector(
        classifier_backend=classifier_backend,
        joint_nlu=joint_nlu,
    )

    plugin_registry = PluginRegistry()
    MockPluginExecutor().register(plugin_registry)

    # Collaborators are built in the same order the service receives them.
    router = build_router(
        registry,
        matcher_pipeline=matcher_pipeline,
        classifier_backend=backend,
        classifier=classifier,
        joint_nlu=joint_nlu,
    )
    session_store = build_session_store(session_backend=session_backend)
    response_policy = ResponsePolicy(
        templates=bundle.response_templates,
        intent_hints=bundle.intent_hints,
    )
    planner = build_planner(
        bundle.workflow_templates,
        clause_classifier=planner_clause_classifier,
        multi_intent_detector=multi_intent_detector,
        joint_nlu=joint_nlu,
    )
    social_router = SocialRouter()
    social_responder = build_social_responder()
    knowledge_llm = build_knowledge_llm()

    return AgentService(
        intent_registry=registry,
        router=router,
        plugins=plugin_registry,
        session_store=session_store,
        rewrite_engine=bundle.rewrite_engine,
        response_policy=response_policy,
        dialog_rules=bundle.dialog_rules,
        dialog_act_engine=bundle.dialog_act_engine,
        planner=planner,
        social_router=social_router,
        social_responder=social_responder,
        knowledge_llm=knowledge_llm,
    )
|
|
|
|
|
|
def build_intent_registry() -> IntentRegistry:
    """Load the runtime bundle and return its ``intent_registry``."""
    bundle = load_runtime_bundle()
    return bundle.intent_registry
|
|
|
|
|
|
def load_runtime_bundle():
    """Load the runtime bundle from the config paths listed in ``settings``.

    Returns:
        Whatever ``ConfigLoader.load()`` returns for those paths.
    """
    loader = ConfigLoader(
        domain_path=settings.domain_config_path,
        action_path=settings.action_config_path,
        response_path=settings.response_config_path,
        form_path=settings.form_config_path,
        rule_path=settings.rule_config_path,
        dialog_act_path=settings.dialog_act_config_path,
        workflow_path=settings.workflow_config_path,
        # The "legacy" loader argument is fed from settings.intent_config_path.
        legacy_intent_path=settings.intent_config_path,
        context_rewrite_path=settings.context_rewrite_config_path,
    )
    return loader.load()
|
|
|
|
|
|
def build_joint_nlu() -> JointBertNLU:
    """Create the Joint BERT NLU runtime from ``settings``.

    A non-positive ``settings.joint_nlu_intent_threshold`` is passed to the
    runtime as ``None``. The runtime is warmed up when
    ``settings.classifier_warmup_enabled`` is set.
    """
    configured_threshold = settings.joint_nlu_intent_threshold
    if configured_threshold > 0:
        intent_threshold = configured_threshold
    else:
        intent_threshold = None

    nlu = JointBertNLU(
        model_path=settings.joint_nlu_model_path,
        intent_threshold=intent_threshold,
        top_k=settings.joint_nlu_top_k,
    )
    if settings.classifier_warmup_enabled:
        nlu.warmup(settings.classifier_warmup_text)
    return nlu
|
|
|
|
|
|
def build_multi_intent_detector(
    classifier_backend: str | None = None,
    joint_nlu: JointBertNLU | None = None,
) -> MultiIntentDetector | None:
    """Build the planner's multi-intent detector, if one applies.

    Args:
        classifier_backend: Backend name; falls back to
            ``settings.classifier_backend``.
        joint_nlu: Joint NLU runtime to reuse for the ``joint_bert`` backend.
            A new one is built via ``build_joint_nlu`` when omitted.

    Returns:
        A detector for the ``bert`` or ``joint_bert`` backend, or ``None``
        when the detector is disabled in settings or the backend is any
        other value.
    """
    backend = classifier_backend or settings.classifier_backend
    if not settings.planner_multi_intent_detector_enabled:
        return None

    if backend == "joint_bert":
        nlu = joint_nlu or build_joint_nlu()
        configured_threshold = settings.planner_multi_intent_detector_threshold
        detector = JointBertMultiIntentDetector(
            nlu=nlu,
            # A non-positive threshold is passed on as None.
            threshold=configured_threshold if configured_threshold > 0 else None,
            top_k=settings.planner_multi_intent_detector_top_k,
            max_labels=settings.planner_multi_intent_detector_max_labels,
        )
    elif backend == "bert":
        # Use the dedicated detector model path when set, else the classifier's.
        model_path = settings.planner_multi_intent_detector_model_path or settings.classifier_model_path
        detector = BertMultiIntentDetector(
            model_path=model_path,
            threshold=settings.planner_multi_intent_detector_threshold,
            top_k=settings.planner_multi_intent_detector_top_k,
            max_labels=settings.planner_multi_intent_detector_max_labels,
        )
    else:
        return None

    if settings.classifier_warmup_enabled:
        detector.warmup(settings.classifier_warmup_text)
    return detector
|
|
|
|
|
|
def build_planner(
    workflow_templates=None,
    clause_classifier: IntentClassifier | None = None,
    multi_intent_detector: MultiIntentDetector | None = None,
    joint_nlu: JointBertNLU | None = None,
) -> WorkflowPlanner:
    """Build the workflow planner selected by ``settings.planner_backend``.

    A local composite (template planner followed by heuristic planner) is
    always constructed. For ``"heuristic"`` it is returned as is; for
    ``"dashscope"`` it is combined with a ``DashScopeWorkflowPlanner`` that
    also receives it as ``fallback``.

    Raises:
        ValueError: If ``settings.planner_backend`` is neither
            ``"heuristic"`` nor ``"dashscope"``.
    """
    # Both local planners take the same clause-scoring configuration.
    clause_options = {
        "clause_classifier": clause_classifier,
        "multi_intent_detector": multi_intent_detector,
        "joint_nlu": joint_nlu,
        "classifier_weight": settings.planner_clause_classifier_weight,
        "model_only_threshold": settings.planner_clause_model_only_threshold,
    }
    template_planner = TemplateWorkflowPlanner(workflow_templates, **clause_options)
    heuristic_planner = HeuristicWorkflowPlanner(**clause_options)
    local_first = CompositeWorkflowPlanner([template_planner, heuristic_planner])

    backend = settings.planner_backend
    if backend == "heuristic":
        return local_first
    if backend != "dashscope":
        raise ValueError(f"Unsupported planner backend: {settings.planner_backend}")

    cloud_planner = DashScopeWorkflowPlanner(
        base_url=settings.planner_base_url,
        api_key=settings.planner_api_key,
        model_name=settings.planner_model_name,
        timeout_seconds=settings.planner_timeout_seconds,
        fallback=local_first,
        joint_nlu=joint_nlu,
    )
    return CompositeWorkflowPlanner([local_first, cloud_planner])
|
|
|
|
|
|
def build_social_responder() -> SocialResponder:
    """Create the DashScope social responder using the planner's settings.

    The base URL, API key, model name and timeout are all read from the
    ``settings.planner_*`` values.
    """
    responder = DashScopeSocialResponder(
        base_url=settings.planner_base_url,
        api_key=settings.planner_api_key,
        model_name=settings.planner_model_name,
        timeout_seconds=settings.planner_timeout_seconds,
    )
    return responder
|
|
|
|
|
|
def build_knowledge_llm(timeout_seconds: float = 12.0) -> DashScopeKnowledgeLLM:
    """Build the knowledge-base LLM question answerer.

    It shares the planner's DashScope configuration (base URL, API key and
    model name) and reads its documents from a ``KnowledgeStore`` rooted at
    ``settings.knowledge_dir``.

    Args:
        timeout_seconds: Timeout passed to the LLM client. Defaults to 12.0,
            the value this builder previously hard-coded.

    Returns:
        The configured ``DashScopeKnowledgeLLM``.
    """
    store = KnowledgeStore(settings.knowledge_dir)
    return DashScopeKnowledgeLLM(
        base_url=settings.planner_base_url,
        api_key=settings.planner_api_key,
        model_name=settings.planner_model_name,
        knowledge_store=store,
        timeout_seconds=timeout_seconds,
    )
|