174 lines
7.3 KiB
Python
174 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from app.schemas.configuration import (
|
|
ActionsConfig,
|
|
ContextRewriteConfig,
|
|
DialogActsConfig,
|
|
DialogRulesConfig,
|
|
DomainConfig,
|
|
FormsConfig,
|
|
ResponsesConfig,
|
|
WorkflowTemplatesConfig,
|
|
)
|
|
from app.services.dialog_act import DialogActEngine
|
|
from app.services.dialog_rules import DialogRuleEngine
|
|
from app.services.intent_registry import IntentRegistry
|
|
from app.services.rewrite_engine import ContextRewriteEngine
|
|
|
|
|
|
@dataclass()
class RuntimeConfigBundle:
    """Runtime components assembled by ``ConfigLoader.load``.

    All fields except ``rewrite_engine`` are required.  ``rewrite_engine``
    defaults to a fresh ``ContextRewriteEngine()`` built with no arguments,
    created per instance via ``default_factory``.
    """

    # Registry of intent definitions used for intent lookup.
    intent_registry: IntentRegistry
    # Response templates (str -> str mapping) taken from the responses config.
    response_templates: dict[str, str]
    # intent id -> label hint; ConfigLoader leaves this empty for the legacy
    # intent-json layout.
    intent_hints: dict[str, str]
    # Stop-phrase / confirmation rule engine.
    dialog_rules: DialogRuleEngine
    # Dialog-act engine built from phrase and numeric patterns.
    dialog_act_engine: DialogActEngine
    # Workflow templates configuration (validated pydantic model).
    workflow_templates: WorkflowTemplatesConfig
    # Context rewrite engine; optional when constructing the bundle.
    rewrite_engine: ContextRewriteEngine = field(default_factory=ContextRewriteEngine)
|
|
|
|
|
|
class ConfigLoader:
    """Assemble a ``RuntimeConfigBundle`` from configuration files on disk.

    Two layouts are supported:

    * structured layout: a domain file and an actions file (plus optional
      forms, rules, dialog-act, workflow and context-rewrite files) are
      validated and combined into an ``IntentRegistry``;
    * legacy layout: a single intent JSON file handed to
      ``IntentRegistry.from_json``.

    Files whose suffix is ``.json`` (case-insensitive) are parsed as JSON;
    every other suffix is parsed as YAML with ``yaml.safe_load``.
    """

    def __init__(
        self,
        domain_path: str,
        action_path: str,
        response_path: str,
        form_path: str | None = None,
        rule_path: str | None = None,
        dialog_act_path: str | None = None,
        workflow_path: str | None = None,
        legacy_intent_path: str | None = None,
        context_rewrite_path: str | None = None,
    ) -> None:
        """Store the configuration paths; no file is read until ``load``.

        Optional paths given as ``None`` or ``""`` are stored as ``None`` and
        the corresponding component falls back to its default.
        """
        self._domain_path = Path(domain_path)
        self._action_path = Path(action_path)
        self._response_path = Path(response_path)
        self._form_path = self._optional_path(form_path)
        self._rule_path = self._optional_path(rule_path)
        self._dialog_act_path = self._optional_path(dialog_act_path)
        self._workflow_path = self._optional_path(workflow_path)
        self._legacy_intent_path = self._optional_path(legacy_intent_path)
        self._context_rewrite_path = self._optional_path(context_rewrite_path)

    @staticmethod
    def _optional_path(value: str | None) -> Path | None:
        """Return ``Path(value)``, or ``None`` when *value* is falsy."""
        return Path(value) if value else None

    def load(self) -> RuntimeConfigBundle:
        """Load the runtime configuration bundle.

        The structured layout wins when both the domain and the actions file
        exist.  Otherwise the legacy intent JSON is used if it is configured
        and exists (a half-present structured layout also falls through here).

        Raises:
            FileNotFoundError: if neither layout is available; the message
                lists the paths that were checked.
        """
        if self._domain_path.exists() and self._action_path.exists():
            return self._load_from_config_files()

        if self._legacy_intent_path is not None and self._legacy_intent_path.exists():
            return self._build_bundle(
                intent_registry=IntentRegistry.from_json(str(self._legacy_intent_path)),
                # The legacy format carries no labels to derive hints from.
                intent_hints={},
            )

        checked = [str(self._domain_path), str(self._action_path)]
        if self._legacy_intent_path is not None:
            checked.append(str(self._legacy_intent_path))
        raise FileNotFoundError(
            "no runtime config found, expected config/*.yml or legacy intent json"
            f" (checked: {', '.join(checked)})"
        )

    def _build_bundle(
        self,
        intent_registry: IntentRegistry,
        intent_hints: dict[str, str],
    ) -> RuntimeConfigBundle:
        """Combine *intent_registry* and *intent_hints* with the optional components.

        The optional components are loaded in the same order for both layouts:
        responses, dialog rules, dialog acts, workflows, context rewrite.
        """
        return RuntimeConfigBundle(
            intent_registry=intent_registry,
            response_templates=self._load_response_templates(),
            intent_hints=intent_hints,
            dialog_rules=self._load_dialog_rules(),
            dialog_act_engine=self._load_dialog_act_engine(),
            workflow_templates=self._load_workflow_templates(),
            rewrite_engine=self._load_rewrite_engine(),
        )

    def _load_from_config_files(self) -> RuntimeConfigBundle:
        """Build the bundle from the structured (domain + actions) layout.

        When a form exists for a domain intent (matched by ``intent_id``), the
        form's ``required_slots`` and ``ask_templates`` replace the intent's.
        Forms whose ``intent_id`` matches no domain intent are ignored, and if
        several forms share an ``intent_id`` the last one wins.
        """
        domain = DomainConfig.model_validate(self._read_structured_file(self._domain_path))
        actions = ActionsConfig.model_validate(self._read_structured_file(self._action_path))
        forms = self._load_forms()

        action_map = {action.action_id: action for action in actions.actions}
        form_map = {form.intent_id: form for form in forms.forms}

        intents = []
        for intent in domain.intents:
            form = form_map.get(intent.intent_id)
            if form is not None:
                # model_copy returns a new object; the validated domain config
                # is left untouched.
                intent = intent.model_copy(
                    update={
                        "required_slots": form.required_slots,
                        "ask_templates": form.ask_templates,
                    }
                )
            intents.append(intent.to_intent_definition(action_map))

        # Only intents with a non-blank label get a hint; labels are stripped.
        intent_hints = {
            intent.intent_id: intent.label.strip()
            for intent in domain.intents
            if intent.label and intent.label.strip()
        }

        return self._build_bundle(
            intent_registry=IntentRegistry(intents),
            intent_hints=intent_hints,
        )

    def _load_response_templates(self) -> dict[str, str]:
        """Return the response templates, or ``{}`` if the file is missing."""
        if not self._response_path.exists():
            return {}
        raw = self._read_structured_file(self._response_path)
        return ResponsesConfig.model_validate(raw).templates

    def _load_forms(self) -> FormsConfig:
        """Return the forms config, or an empty ``FormsConfig()`` if unavailable."""
        if self._form_path is None or not self._form_path.exists():
            return FormsConfig()
        return FormsConfig.model_validate(self._read_structured_file(self._form_path))

    def _load_dialog_rules(self) -> DialogRuleEngine:
        """Build the dialog rule engine from the rules file.

        Without a rules file a default ``DialogRuleEngine()`` is returned.
        Each configured list that is empty falls back to the corresponding
        class-level default on ``DialogRuleEngine``.
        """
        if self._rule_path is None or not self._rule_path.exists():
            return DialogRuleEngine()
        parsed = DialogRulesConfig.model_validate(self._read_structured_file(self._rule_path))
        confirmation = parsed.confirmation
        # The `or` fallbacks are kept inline on purpose: the class attribute is
        # only read when the configured tuple is empty.
        return DialogRuleEngine(
            stop_phrases=tuple(parsed.stop.phrases) or DialogRuleEngine.stop_phrases,
            positive_confirmation_tokens=tuple(confirmation.positive_tokens)
            or DialogRuleEngine.positive_confirmation_tokens,
            negative_confirmation_tokens=tuple(confirmation.negative_tokens)
            or DialogRuleEngine.negative_confirmation_tokens,
            confirmation_required_intents=tuple(confirmation.required_intents)
            or DialogRuleEngine.confirmation_required_intents,
            confirmation_required_risk_levels=tuple(confirmation.required_risk_levels)
            or DialogRuleEngine.confirmation_required_risk_levels,
            metadata={"source": str(self._rule_path)},
        )

    def _load_dialog_act_engine(self) -> DialogActEngine:
        """Build the dialog-act engine from the dialog-act file.

        Every act contributes its phrase tuple; only acts with at least one
        numeric pattern appear in ``numeric_patterns``.
        """
        if self._dialog_act_path is None or not self._dialog_act_path.exists():
            return DialogActEngine()
        parsed = DialogActsConfig.model_validate(
            self._read_structured_file(self._dialog_act_path)
        )
        return DialogActEngine(
            patterns={act.act_id: tuple(act.phrases) for act in parsed.acts},
            numeric_patterns={
                act.act_id: tuple(act.numeric_patterns)
                for act in parsed.acts
                if act.numeric_patterns
            },
        )

    def _load_rewrite_engine(self) -> ContextRewriteEngine:
        """Build the context rewrite engine, or a default one if unavailable."""
        if self._context_rewrite_path is None or not self._context_rewrite_path.exists():
            return ContextRewriteEngine()
        config = ContextRewriteConfig.model_validate(
            self._read_structured_file(self._context_rewrite_path)
        )
        return ContextRewriteEngine(config=config)

    def _load_workflow_templates(self) -> WorkflowTemplatesConfig:
        """Return the workflow templates config, or a default one if unavailable."""
        if self._workflow_path is None or not self._workflow_path.exists():
            return WorkflowTemplatesConfig()
        return WorkflowTemplatesConfig.model_validate(
            self._read_structured_file(self._workflow_path)
        )

    def _read_structured_file(self, path: Path) -> dict[str, Any]:
        """Parse *path* as JSON (``.json`` suffix) or YAML and return its mapping.

        An empty or falsy YAML document is treated as ``{}``.

        Raises:
            ValueError: if the parsed top level is not a mapping (the message
                names *path*); ``json.JSONDecodeError`` is also a
                ``ValueError``.
        """
        text = path.read_text(encoding="utf-8")
        if path.suffix.lower() == ".json":
            data = json.loads(text)
        else:
            # safe_load only builds plain Python types, never arbitrary objects.
            data = yaml.safe_load(text) or {}
        if not isinstance(data, dict):
            # Fail here with the file name instead of deep inside model_validate.
            raise ValueError(
                f"{path}: expected a mapping at the top level, got {type(data).__name__}"
            )
        return data
|