Update project and configurations
This commit is contained in:
108
intelligent_cabin/app/services/rewrite_engine.py
Normal file
108
intelligent_cabin/app/services/rewrite_engine.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.configuration import ContextRewriteConfig, ParamContextDefinition
|
||||
from app.services.session_store import SessionState
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewriteResult:
|
||||
original_text: str
|
||||
rewritten_text: str
|
||||
applied: bool = False
|
||||
reason: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextRewriteEngine:
|
||||
"""
|
||||
将短句 follow-up(如"再快一点"、"电压高一点")改写为完整命令(如"速度设为 85 mm/min"),
|
||||
使其能复用本地快路径而不必每轮重做完整规划。
|
||||
|
||||
改写规则完全由外部配置文件(context_rewrite.yml)驱动,不硬编码业务参数,
|
||||
适用于不同设备(线切割 / 激光切割 / 注塑机等)的部署切换。
|
||||
"""
|
||||
|
||||
def __init__(self, config: ContextRewriteConfig | None = None) -> None:
|
||||
self._config = config or ContextRewriteConfig()
|
||||
# 构建反向索引:intent_id → ParamContextDefinition
|
||||
self._intent_index: dict[str, ParamContextDefinition] = {}
|
||||
for ctx in self._config.param_contexts:
|
||||
for intent_id in ctx.intent_ids:
|
||||
self._intent_index[intent_id] = ctx
|
||||
|
||||
# ------------------------------------------------------------------ public
|
||||
|
||||
def rewrite(self, text: str, session: SessionState) -> RewriteResult:
|
||||
normalized = text.strip()
|
||||
if not normalized:
|
||||
return RewriteResult(original_text=text, rewritten_text=text)
|
||||
|
||||
current_intent = session.current_intent
|
||||
if current_intent and current_intent in self._intent_index:
|
||||
ctx = self._intent_index[current_intent]
|
||||
result = self._rewrite_param_adjustment(normalized, session, ctx)
|
||||
if result.applied:
|
||||
return result
|
||||
|
||||
return RewriteResult(original_text=text, rewritten_text=text)
|
||||
|
||||
# ----------------------------------------------------------------- private
|
||||
|
||||
def _rewrite_param_adjustment(
|
||||
self,
|
||||
text: str,
|
||||
session: SessionState,
|
||||
ctx: ParamContextDefinition,
|
||||
) -> RewriteResult:
|
||||
direction: str | None = None
|
||||
if any(phrase and phrase in text for phrase in ctx.up_phrases):
|
||||
direction = "up"
|
||||
elif any(phrase and phrase in text for phrase in ctx.down_phrases):
|
||||
direction = "down"
|
||||
|
||||
if direction is None:
|
||||
return RewriteResult(original_text=text, rewritten_text=text)
|
||||
|
||||
previous_value = self._last_slot_value(session, ctx.slot_name)
|
||||
base_value = previous_value if previous_value is not None else ctx.default_value
|
||||
delta = ctx.step if direction == "up" else -ctx.step
|
||||
|
||||
if isinstance(ctx.min_value, float) or isinstance(ctx.max_value, float) or isinstance(ctx.step, float):
|
||||
next_value: int | float = max(float(ctx.min_value), min(float(ctx.max_value), float(base_value) + float(delta)))
|
||||
else:
|
||||
next_value = max(int(ctx.min_value), min(int(ctx.max_value), int(base_value) + int(delta)))
|
||||
|
||||
rewritten = ctx.rewrite_template.format(value=next_value)
|
||||
return RewriteResult(
|
||||
original_text=text,
|
||||
rewritten_text=rewritten,
|
||||
applied=True,
|
||||
reason=f"normalize relative {ctx.slot_name} adjustment into an explicit target",
|
||||
metadata={
|
||||
"cache_hit": True,
|
||||
"rewrite_type": "param_adjustment",
|
||||
"slot_name": ctx.slot_name,
|
||||
"direction": direction,
|
||||
"previous_value": previous_value,
|
||||
"base_value": base_value,
|
||||
"next_value": next_value,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _last_slot_value(session: SessionState, slot_name: str) -> int | float | None:
|
||||
raw = session.context_memory.get(f"last_{slot_name}", session.slots.get(slot_name))
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (int, float)):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
return int(raw) if raw.isdigit() else float(raw)
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
Reference in New Issue
Block a user