from __future__ import annotations import re from dataclasses import dataclass, field from typing import Any from app.schemas.configuration import ContextRewriteConfig, ParamContextDefinition from app.services.session_store import SessionState @dataclass class RewriteResult: original_text: str rewritten_text: str applied: bool = False reason: str | None = None metadata: dict[str, Any] = field(default_factory=dict) class ContextRewriteEngine: """ 将短句 follow-up(如"再快一点"、"电压高一点")改写为完整命令(如"速度设为 85 mm/min"), 使其能复用本地快路径而不必每轮重做完整规划。 改写规则完全由外部配置文件(context_rewrite.yml)驱动,不硬编码业务参数, 适用于不同设备(线切割 / 激光切割 / 注塑机等)的部署切换。 """ def __init__(self, config: ContextRewriteConfig | None = None) -> None: self._config = config or ContextRewriteConfig() # 构建反向索引:intent_id → ParamContextDefinition self._intent_index: dict[str, ParamContextDefinition] = {} for ctx in self._config.param_contexts: for intent_id in ctx.intent_ids: self._intent_index[intent_id] = ctx # ------------------------------------------------------------------ public def rewrite(self, text: str, session: SessionState) -> RewriteResult: normalized = text.strip() if not normalized: return RewriteResult(original_text=text, rewritten_text=text) current_intent = session.current_intent if current_intent and current_intent in self._intent_index: ctx = self._intent_index[current_intent] result = self._rewrite_param_adjustment(normalized, session, ctx) if result.applied: return result return RewriteResult(original_text=text, rewritten_text=text) # ----------------------------------------------------------------- private def _rewrite_param_adjustment( self, text: str, session: SessionState, ctx: ParamContextDefinition, ) -> RewriteResult: direction: str | None = None if any(phrase and phrase in text for phrase in ctx.up_phrases): direction = "up" elif any(phrase and phrase in text for phrase in ctx.down_phrases): direction = "down" if direction is None: return RewriteResult(original_text=text, rewritten_text=text) previous_value = self._last_slot_value(session, ctx.slot_name) base_value = previous_value if previous_value is not None else ctx.default_value delta = ctx.step if direction == "up" else -ctx.step if isinstance(ctx.min_value, float) or isinstance(ctx.max_value, float) or isinstance(ctx.step, float): next_value: int | float = max(float(ctx.min_value), min(float(ctx.max_value), float(base_value) + float(delta))) else: next_value = max(int(ctx.min_value), min(int(ctx.max_value), int(base_value) + int(delta))) rewritten = ctx.rewrite_template.format(value=next_value) return RewriteResult( original_text=text, rewritten_text=rewritten, applied=True, reason=f"normalize relative {ctx.slot_name} adjustment into an explicit target", metadata={ "cache_hit": True, "rewrite_type": "param_adjustment", "slot_name": ctx.slot_name, "direction": direction, "previous_value": previous_value, "base_value": base_value, "next_value": next_value, }, ) @staticmethod def _last_slot_value(session: SessionState, slot_name: str) -> int | float | None: raw = session.context_memory.get(f"last_{slot_name}", session.slots.get(slot_name)) if raw is None: return None if isinstance(raw, (int, float)): return raw if isinstance(raw, str): try: return int(raw) if raw.isdigit() else float(raw) except ValueError: pass return None