Files
ai-device/intelligent_cabin/app/services/rewrite_engine.py
2026-06-11 16:28:00 +08:00

109 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import math
import re
from dataclasses import dataclass, field
from typing import Any

from app.schemas.configuration import ContextRewriteConfig, ParamContextDefinition
from app.services.session_store import SessionState
@dataclass
class RewriteResult:
    """Outcome of one context-rewrite attempt.

    When ``applied`` is False, ``ContextRewriteEngine`` fills ``rewritten_text``
    with the input unchanged. Otherwise it holds the explicit command produced
    from the configured template.
    """

    # The utterance the rewrite was attempted on.
    original_text: str
    # Text to pass downstream (the rewrite, or the input when nothing fired).
    rewritten_text: str
    # True only when a rewrite rule actually produced a new command.
    applied: bool = False
    # Human-readable explanation for an applied rewrite; None otherwise.
    reason: str | None = None
    # Free-form diagnostics about the rewrite; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
class ContextRewriteEngine:
    """Rewrite short follow-up utterances into explicit, complete commands.

    Turns relative follow-ups such as "再快一点" ("a bit faster") or
    "电压高一点" ("voltage a bit higher") into full commands such as
    "速度设为 85 mm/min" ("set speed to 85 mm/min"). The rewritten command can
    then reuse the local fast path instead of re-running full planning on every
    turn.

    All rewrite rules come from the ``ContextRewriteConfig`` passed in. It is
    intended to be loaded from an external file such as ``context_rewrite.yml``.
    No device-specific parameters are hard-coded, so the same engine can be
    deployed on different machines (wire EDM, laser cutting, injection
    molding, ...).
    """

    def __init__(self, config: ContextRewriteConfig | None = None) -> None:
        """Build the engine and index its parameter contexts by intent id.

        Args:
            config: Rewrite configuration. A default ``ContextRewriteConfig()``
                is used when no (or a falsy) config is given.
        """
        self._config = config or ContextRewriteConfig()
        # Reverse index: intent_id -> ParamContextDefinition. If several
        # contexts list the same intent id, the last one in config order wins.
        self._intent_index: dict[str, ParamContextDefinition] = {}
        for ctx in self._config.param_contexts:
            for intent_id in ctx.intent_ids:
                self._intent_index[intent_id] = ctx

    # ------------------------------------------------------------------ public
    def rewrite(self, text: str, session: SessionState) -> RewriteResult:
        """Try to rewrite *text* using the session's current intent.

        Args:
            text: Raw user utterance. Leading and trailing whitespace is
                ignored for matching.
            session: Supplies ``current_intent`` plus the slot and context
                memory used to find the parameter's last known value.

        Returns:
            A ``RewriteResult`` with ``applied=True`` and the explicit command
            when a parameter-adjustment rule matched. Otherwise, a result whose
            ``rewritten_text`` is *text* unchanged. In both cases
            ``original_text`` is the caller's raw *text*.
        """
        normalized = text.strip()
        if not normalized:
            return RewriteResult(original_text=text, rewritten_text=text)

        current_intent = session.current_intent
        if current_intent and current_intent in self._intent_index:
            ctx = self._intent_index[current_intent]
            result = self._rewrite_param_adjustment(normalized, session, ctx)
            if result.applied:
                # Report the caller's raw input, not the stripped copy used for
                # phrase matching, so both return paths agree.
                result.original_text = text
                return result
        return RewriteResult(original_text=text, rewritten_text=text)

    # ----------------------------------------------------------------- private
    def _rewrite_param_adjustment(
        self,
        text: str,
        session: SessionState,
        ctx: ParamContextDefinition,
    ) -> RewriteResult:
        """Turn a relative "up"/"down" request into an explicit target value.

        The slot's last known value (or ``ctx.default_value`` when there is
        none) is stepped by ``ctx.step`` in the detected direction. The result
        is clamped to ``[ctx.min_value, ctx.max_value]`` and rendered through
        ``ctx.rewrite_template``.

        Returns:
            An applied ``RewriteResult`` with diagnostic metadata, or an
            unapplied one when *text* contains no configured up/down phrase.
        """
        direction = self._detect_direction(text, ctx)
        if direction is None:
            return RewriteResult(original_text=text, rewritten_text=text)

        previous_value = self._last_slot_value(session, ctx.slot_name)
        base_value = previous_value if previous_value is not None else ctx.default_value
        next_value = self._step_value(ctx, base_value, direction)
        rewritten = ctx.rewrite_template.format(value=next_value)
        return RewriteResult(
            original_text=text,
            rewritten_text=rewritten,
            applied=True,
            reason=f"normalize relative {ctx.slot_name} adjustment into an explicit target",
            metadata={
                # NOTE(review): always True for applied rewrites. Presumably it
                # tells downstream that the fast path can be used; confirm.
                "cache_hit": True,
                "rewrite_type": "param_adjustment",
                "slot_name": ctx.slot_name,
                "direction": direction,
                "previous_value": previous_value,
                "base_value": base_value,
                "next_value": next_value,
            },
        )

    @staticmethod
    def _detect_direction(text: str, ctx: ParamContextDefinition) -> str | None:
        """Return "up" or "down" if *text* contains a configured phrase, else None.

        Matching is plain substring containment and empty phrases are skipped.
        Up phrases are checked first, so they win when both kinds match.
        """
        if any(phrase and phrase in text for phrase in ctx.up_phrases):
            return "up"
        if any(phrase and phrase in text for phrase in ctx.down_phrases):
            return "down"
        return None

    @staticmethod
    def _step_value(
        ctx: ParamContextDefinition,
        base_value: int | float,
        direction: str,
    ) -> int | float:
        """Step *base_value* by ``ctx.step`` in *direction* and clamp to the range.

        Float arithmetic is used if any of min/max/step is a float. Otherwise
        int arithmetic is used, which truncates a non-integer *base_value*.
        """
        delta = ctx.step if direction == "up" else -ctx.step
        if any(isinstance(bound, float) for bound in (ctx.min_value, ctx.max_value, ctx.step)):
            return max(
                float(ctx.min_value),
                min(float(ctx.max_value), float(base_value) + float(delta)),
            )
        return max(int(ctx.min_value), min(int(ctx.max_value), int(base_value) + int(delta)))

    @staticmethod
    def _last_slot_value(session: SessionState, slot_name: str) -> int | float | None:
        """Return the last known numeric value of *slot_name*, or None.

        ``context_memory["last_<slot_name>"]`` takes precedence over
        ``slots[<slot_name>]``. Numbers are returned as-is. Strings are parsed
        as int first, then float, ignoring surrounding whitespace.

        Returns None for anything unusable:
        - missing or unparsable values;
        - other types;
        - booleans, which are ints in Python but not meaningful magnitudes;
        - NaN or infinity. These would otherwise crash the int path or jump the
          float path straight to the clamp bound.
        """
        raw = session.context_memory.get(f"last_{slot_name}", session.slots.get(slot_name))
        if raw is None or isinstance(raw, bool):
            return None

        value: int | float
        if isinstance(raw, (int, float)):
            value = raw
        elif isinstance(raw, str):
            cleaned = raw.strip()
            try:
                value = int(cleaned)
            except ValueError:
                try:
                    value = float(cleaned)
                except ValueError:
                    return None
        else:
            return None

        if isinstance(value, float) and not math.isfinite(value):
            return None
        return value