258 lines
7.0 KiB
Go
258 lines
7.0 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// Limits applied to each conversation's remembered context history.
const (
	// autoReplyContextLimit is the maximum number of entries kept per
	// conversation; trimAutoReplyContextEntries drops the oldest beyond it.
	autoReplyContextLimit = 20

	// autoReplyContextPromptLimit bounds the total size of a conversation's
	// stored Content (counted in runes by trimAutoReplyContextEntries) and is
	// also the cap recentContextPrompt enforces while assembling a transcript.
	autoReplyContextPromptLimit = 4000
)
|
||
|
||
// autoReplyContextEntry is one remembered message in a conversation's
// history; it is persisted as part of autoReplyContextStore.
type autoReplyContextEntry struct {
	// Role is "user" for customer messages and "assistant" for replies sent
	// by the engine; rememberContextEntry defaults a blank role to "user".
	Role string `json:"role"`
	// Content is the message text (whitespace-trimmed by rememberContextEntry
	// before it is stored).
	Content string `json:"content"`
	// NormalizedContent is normalizeContextContent(Content), filled in by
	// rememberContextEntry.
	NormalizedContent string `json:"normalizedContent"`
	// MessageType is copied from the incoming message for user entries and
	// is "text" for assistant entries.
	MessageType string `json:"messageType"`
	// ServerID and LocalID are copied from the incoming message for user
	// entries; assistant entries leave them empty.
	ServerID string `json:"serverId"`
	LocalID  string `json:"localId"`
	// CreatedAt is the Unix time, in seconds, at which the entry was recorded.
	CreatedAt int64 `json:"createdAt"`
	// SenderName is the sender's nickname for user entries and "assistant"
	// for assistant entries.
	SenderName string `json:"senderName"`
}
|
||
|
||
type autoReplyContextStore struct {
|
||
Conversations map[string][]autoReplyContextEntry `json:"conversations"`
|
||
LastSavedAt int64 `json:"lastSavedAt"`
|
||
}
|
||
|
||
var contextCachePathOverride string
|
||
|
||
func autoReplyContextCachePath() string {
|
||
if strings.TrimSpace(contextCachePathOverride) != "" {
|
||
return contextCachePathOverride
|
||
}
|
||
return resolveAutoReplyPath("config/auto_reply_context_cache.json")
|
||
}
|
||
|
||
func (e *AutoReplyEngine) loadContextCache() error {
|
||
path := autoReplyContextCachePath()
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
if os.IsNotExist(err) {
|
||
e.mu.Lock()
|
||
if e.contextEntries == nil {
|
||
e.contextEntries = make(map[string][]autoReplyContextEntry)
|
||
}
|
||
e.mu.Unlock()
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
var store autoReplyContextStore
|
||
if err := json.Unmarshal(data, &store); err != nil {
|
||
return err
|
||
}
|
||
e.mu.Lock()
|
||
e.contextEntries = make(map[string][]autoReplyContextEntry, len(store.Conversations))
|
||
for key, entries := range store.Conversations {
|
||
key = strings.TrimSpace(key)
|
||
if key == "" {
|
||
continue
|
||
}
|
||
e.contextEntries[key] = trimAutoReplyContextEntries(entries)
|
||
}
|
||
e.mu.Unlock()
|
||
return nil
|
||
}
|
||
|
||
func (e *AutoReplyEngine) saveContextCache() {
|
||
if err := e.saveContextCacheToDisk(); err != nil {
|
||
e.setLastErrorWithScope(autoReplyErrorScopeRecords, "conversation context save failed: "+err.Error())
|
||
}
|
||
}
|
||
|
||
func (e *AutoReplyEngine) saveContextCacheToDisk() error {
|
||
e.mu.Lock()
|
||
store := autoReplyContextStore{
|
||
Conversations: make(map[string][]autoReplyContextEntry, len(e.contextEntries)),
|
||
LastSavedAt: time.Now().Unix(),
|
||
}
|
||
for key, entries := range e.contextEntries {
|
||
store.Conversations[key] = append([]autoReplyContextEntry(nil), trimAutoReplyContextEntries(entries)...)
|
||
}
|
||
e.mu.Unlock()
|
||
path := autoReplyContextCachePath()
|
||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||
return err
|
||
}
|
||
return atomicWriteJSON(path, store)
|
||
}
|
||
|
||
func (e *AutoReplyEngine) rememberUserMessage(msg autoReplyMessage) {
|
||
e.rememberContextEntry(msg, autoReplyContextEntry{
|
||
Role: "user",
|
||
Content: strings.TrimSpace(msg.Content),
|
||
MessageType: msg.MessageType,
|
||
ServerID: msg.ServerID,
|
||
LocalID: msg.LocalID,
|
||
CreatedAt: time.Now().Unix(),
|
||
SenderName: msg.FromNickName,
|
||
})
|
||
}
|
||
|
||
func (e *AutoReplyEngine) rememberAssistantMessage(msg autoReplyMessage, answer string) {
|
||
e.rememberContextEntry(msg, autoReplyContextEntry{
|
||
Role: "assistant",
|
||
Content: strings.TrimSpace(answer),
|
||
MessageType: "text",
|
||
CreatedAt: time.Now().Unix(),
|
||
SenderName: "assistant",
|
||
})
|
||
}
|
||
|
||
func (e *AutoReplyEngine) rememberContextEntry(msg autoReplyMessage, entry autoReplyContextEntry) {
|
||
entry.Content = strings.TrimSpace(entry.Content)
|
||
if entry.Content == "" || strings.TrimSpace(msg.ConversationID) == "" {
|
||
return
|
||
}
|
||
entry.Role = strings.TrimSpace(entry.Role)
|
||
if entry.Role == "" {
|
||
entry.Role = "user"
|
||
}
|
||
if entry.CreatedAt <= 0 {
|
||
entry.CreatedAt = time.Now().Unix()
|
||
}
|
||
entry.NormalizedContent = normalizeContextContent(entry.Content)
|
||
key := e.contextKeyForMessage(msg)
|
||
e.mu.Lock()
|
||
if e.contextEntries == nil {
|
||
e.contextEntries = make(map[string][]autoReplyContextEntry)
|
||
}
|
||
entries := append(e.contextEntries[key], entry)
|
||
e.contextEntries[key] = trimAutoReplyContextEntries(entries)
|
||
e.mu.Unlock()
|
||
e.saveContextCache()
|
||
}
|
||
|
||
func (e *AutoReplyEngine) previousUserQuestion(msg autoReplyMessage) string {
|
||
entries := e.contextEntriesForMessage(msg)
|
||
for i := len(entries) - 1; i >= 0; i-- {
|
||
entry := entries[i]
|
||
if entry.Role == "user" && strings.TrimSpace(entry.Content) != "" {
|
||
return strings.TrimSpace(entry.Content)
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func (e *AutoReplyEngine) recentContextPrompt(msg autoReplyMessage, maxEntries int) string {
|
||
entries := e.contextEntriesForMessage(msg)
|
||
if len(entries) == 0 {
|
||
return ""
|
||
}
|
||
if maxEntries <= 0 {
|
||
maxEntries = 6
|
||
}
|
||
start := len(entries) - maxEntries
|
||
if start < 0 {
|
||
start = 0
|
||
}
|
||
var b strings.Builder
|
||
for _, entry := range entries[start:] {
|
||
content := strings.TrimSpace(entry.Content)
|
||
if content == "" {
|
||
continue
|
||
}
|
||
role := "客户"
|
||
if entry.Role == "assistant" {
|
||
role = "客服"
|
||
}
|
||
line := role + ":" + content
|
||
if b.Len()+len([]rune(line))+1 > autoReplyContextPromptLimit {
|
||
break
|
||
}
|
||
if b.Len() > 0 {
|
||
b.WriteString("\n")
|
||
}
|
||
b.WriteString(line)
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
func (e *AutoReplyEngine) contextualSearchText(question string, msg autoReplyMessage) string {
|
||
contextText := e.recentContextPrompt(msg, 6)
|
||
question = strings.TrimSpace(question)
|
||
if contextText == "" {
|
||
return question
|
||
}
|
||
return contextText + "\n当前问题:" + question
|
||
}
|
||
|
||
func (e *AutoReplyEngine) contextEntriesForMessage(msg autoReplyMessage) []autoReplyContextEntry {
|
||
key := e.contextKeyForMessage(msg)
|
||
e.mu.Lock()
|
||
defer e.mu.Unlock()
|
||
return append([]autoReplyContextEntry(nil), e.contextEntries[key]...)
|
||
}
|
||
|
||
func (e *AutoReplyEngine) contextKeyForMessage(msg autoReplyMessage) string {
|
||
scope := strings.TrimSpace(e.identityScopeForClient(msg.ClientID))
|
||
if scope == "" {
|
||
scope = "client:" + stringFromAny(msg.ClientID)
|
||
}
|
||
robotID := strings.TrimSpace(msg.stableRobotID())
|
||
conversationID := strings.TrimSpace(msg.ConversationID)
|
||
return scope + "|" + robotID + "|" + conversationID
|
||
}
|
||
|
||
func trimAutoReplyContextEntries(entries []autoReplyContextEntry) []autoReplyContextEntry {
|
||
if len(entries) > autoReplyContextLimit {
|
||
entries = entries[len(entries)-autoReplyContextLimit:]
|
||
}
|
||
total := 0
|
||
for i := len(entries) - 1; i >= 0; i-- {
|
||
total += len([]rune(entries[i].Content))
|
||
if total > autoReplyContextPromptLimit {
|
||
return append([]autoReplyContextEntry(nil), entries[i+1:]...)
|
||
}
|
||
}
|
||
return append([]autoReplyContextEntry(nil), entries...)
|
||
}
|
||
|
||
func normalizeContextContent(content string) string {
|
||
return normalizeGreetingText(strings.TrimSpace(content))
|
||
}
|
||
|
||
func isPreviousQuestionQuery(content string) bool {
|
||
normalized := normalizeGreetingText(content)
|
||
if normalized == "" {
|
||
return false
|
||
}
|
||
for _, token := range []string{
|
||
"我上一个问题问了什么",
|
||
"我上个问题问了什么",
|
||
"我刚才问了什么",
|
||
"刚才我问了什么",
|
||
"上一句是什么",
|
||
"上一个问题是什么",
|
||
"上个问题是什么",
|
||
} {
|
||
if strings.Contains(normalized, normalizeGreetingText(token)) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// previousQuestionAnswer formats the reply that quotes the customer's previous
// question back to them. When previous is blank it returns a fixed apology
// asking the customer to resend the question.
func previousQuestionAnswer(previous string) string {
	if q := strings.TrimSpace(previous); q != "" {
		return "您上一个问题是:“" + q + "”。"
	}
	return "我这边暂时没有查到您上一条具体问题,您可以再发一遍,我继续帮您处理。"
}
|