Files
qiweimanager-master/helper/auto_reply_retrieval.go

973 lines
28 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
"qiweimanager/config"
)
// Retrieval configuration constants: the accepted values of
// cfg.Retrieval.RetrievalMode and the reciprocal-rank-fusion k.
const (
	// retrievalModeKeywordOnly answers from keyword recall alone; no vector
	// recall or reranking is attempted.
	retrievalModeKeywordOnly = "keyword"
	// retrievalModeHybridRerank fuses keyword and vector recall and reranks
	// the result. It is also used when the configured mode is empty.
	retrievalModeHybridRerank = "hybrid_rerank"

	// defaultRRFK is k in the reciprocal-rank-fusion term 1/(k+rank).
	defaultRRFK = 60.0
)
// EmbeddingEntry is one persisted vector of the embedding index. It records
// the chunk the vector was computed from together with that chunk's Hash at
// embedding time, so rebuildEmbeddingIndex can reuse the vector while the
// chunk is unchanged.
type EmbeddingEntry struct {
ChunkID string `json:"chunkId"`
// Hash is compared with KnowledgeChunk.Hash to decide whether the vector is stale.
Hash string `json:"hash"`
Source string `json:"source"`
Title string `json:"title"`
Embedding []float64 `json:"embedding"`
// UpdatedAt is copied from the chunk's UpdatedAt (unit not visible here; presumably Unix seconds, confirm).
UpdatedAt int64 `json:"updatedAt"`
}
// EmbeddingIndex is the embedding store that is serialized as JSON to
// cfg.Retrieval.EmbeddingIndexPath. Model and Dimensions record the settings
// the vectors were produced with. On rebuild, entries are only reused when
// both still match the current configuration.
type EmbeddingIndex struct {
Model string `json:"model"`
Dimensions int `json:"dimensions"`
// Entries is keyed by chunk ID.
Entries map[string]EmbeddingEntry `json:"entries"`
// LastIndexedAt is the Unix time (seconds) at which the index was last rebuilt.
LastIndexedAt int64 `json:"lastIndexedAt"`
}
// KnowledgeSearchResult is the detailed outcome of searchKnowledgeDetailed:
// the final hits plus diagnostic scores and per-stage timings.
type KnowledgeSearchResult struct {
Hits []KnowledgeChunk
// KeywordScore and VectorScore are topChunkScore of the keyword and vector recall hits.
KeywordScore float64
VectorScore float64
// RerankScore is topCandidateRerankScore of the final candidates (0 when nothing was reranked).
RerankScore float64
// RetrievalMode is the configured mode string ("hybrid_rerank" when the configuration is empty).
RetrievalMode string
// UsedKnowledgeSources lists the distinct, non-empty chunk sources of Hits in hit order.
UsedKnowledgeSources []string
Timings autoReplyTimings
}
// wikiLinkPattern matches wiki-style links such as [[Name]], [[Name|alias]]
// or [[Name#heading]]. Capture group 1 is the link target without the alias
// or heading part.
var wikiLinkPattern = regexp.MustCompile(`\[\[([^\]|#]+)(?:[|#][^\]]*)?\]\]`)
// retrievalCandidate tracks one chunk through hybrid retrieval. It holds the
// chunk's raw score and 1-based rank in each recall list (rank 0 means the
// chunk was not in that list), the fused score, and the reranker's score.
type retrievalCandidate struct {
Chunk KnowledgeChunk
KeywordScore float64
VectorScore float64
FusionScore float64
RerankScore float64
KeywordRank int
VectorRank int
}
// NewEmbeddingIndex returns an empty index tagged with the given embedding
// model and vector dimensions. Entries is always a non-nil map, so callers
// can insert into it directly. LastIndexedAt starts at zero.
func NewEmbeddingIndex(model string, dimensions int) *EmbeddingIndex {
	idx := new(EmbeddingIndex)
	idx.Model = model
	idx.Dimensions = dimensions
	idx.Entries = map[string]EmbeddingEntry{}
	return idx
}
// loadEmbeddingIndex reads the persisted embedding index from
// cfg.Retrieval.EmbeddingIndexPath and publishes it via updateEmbeddingStatus.
//
// A missing file is not an error. In that case an empty index tagged with the
// configured model and dimensions is published instead. Any other read error,
// or a JSON decode error, is returned unchanged and leaves the currently
// published index untouched.
func (e *AutoReplyEngine) loadEmbeddingIndex() error {
	cfg := e.getConfig()
	raw, err := os.ReadFile(resolveAutoReplyPath(cfg.Retrieval.EmbeddingIndexPath))
	switch {
	case os.IsNotExist(err):
		e.updateEmbeddingStatus(NewEmbeddingIndex(cfg.Retrieval.EmbeddingModel, cfg.Retrieval.EmbeddingDimensions))
		return nil
	case err != nil:
		return err
	}

	loaded := &EmbeddingIndex{}
	if err := json.Unmarshal(raw, loaded); err != nil {
		return err
	}
	if loaded.Entries == nil {
		// A file with "entries": null (or no entries key) must still yield a
		// writable map.
		loaded.Entries = map[string]EmbeddingEntry{}
	}
	e.updateEmbeddingStatus(loaded)
	return nil
}
// updateEmbeddingStatus installs idx as the engine's active embedding index
// and copies its entry count, model, dimensions and last-indexed time into
// e.status, all while holding e.mu. A nil idx is replaced by an empty index
// with no model or dimensions.
func (e *AutoReplyEngine) updateEmbeddingStatus(idx *EmbeddingIndex) {
	active := idx
	if active == nil {
		active = NewEmbeddingIndex("", 0)
	}

	e.mu.Lock()
	defer e.mu.Unlock()
	e.embeddingIndex = active
	e.status.EmbeddingModel = active.Model
	e.status.EmbeddingDimensions = active.Dimensions
	e.status.EmbeddingChunkCount = len(active.Entries)
	e.status.EmbeddingLastIndexedAt = active.LastIndexedAt
}
// rebuildEmbeddingIndex brings the embedding index in line with the chunks of
// idx, persists it, and publishes it.
//
// A chunk reuses its vector from the currently active index when its Hash is
// unchanged and the active index was built with the configured model and
// dimensions. Every other chunk is embedded through callDashScopeEmbeddings
// in batches of embeddingBatchSize. The new index is written as JSON to
// cfg.Retrieval.EmbeddingIndexPath (temp file plus rename) and then installed
// with updateEmbeddingStatus.
//
// Returns:
//   - an error after publishing an empty index when the AI base URL or key is missing;
//   - nil without doing anything when idx is nil;
//   - an error on any embedding or I/O failure, leaving the previous index published.
func (e *AutoReplyEngine) rebuildEmbeddingIndex(idx *KnowledgeIndex) error {
	const embeddingBatchSize = 10

	cfg := e.getConfig()
	if strings.TrimSpace(cfg.AI.APIKey) == "" || strings.TrimSpace(cfg.AI.BaseURL) == "" {
		e.updateEmbeddingStatus(NewEmbeddingIndex(cfg.Retrieval.EmbeddingModel, cfg.Retrieval.EmbeddingDimensions))
		return fmt.Errorf("Embedding索引跳过AI Base URL 或 API Key 未配置")
	}
	if idx == nil {
		return nil
	}

	// Read the active index under e.mu, like every other reader in this file.
	// updateEmbeddingStatus may swap the pointer concurrently.
	// NOTE(review): previous.Entries is then read without the lock; this
	// assumes published indexes are never mutated in place (true for the code
	// visible here).
	e.mu.Lock()
	previous := e.embeddingIndex
	e.mu.Unlock()
	reusable := previous != nil &&
		previous.Model == cfg.Retrieval.EmbeddingModel &&
		previous.Dimensions == cfg.Retrieval.EmbeddingDimensions

	next := NewEmbeddingIndex(cfg.Retrieval.EmbeddingModel, cfg.Retrieval.EmbeddingDimensions)
	next.LastIndexedAt = time.Now().Unix()

	var batchChunks []KnowledgeChunk
	var batchTexts []string
	// flush embeds the pending batch and stores one entry per returned vector.
	flush := func() error {
		if len(batchChunks) == 0 {
			return nil
		}
		vectors, err := callDashScopeEmbeddings(cfg.AI, cfg.Retrieval, batchTexts)
		if err != nil {
			return err
		}
		// vectors is positionally aligned with batchTexts. A short reply simply
		// leaves the trailing chunks without an entry.
		for i, vector := range vectors {
			if i >= len(batchChunks) {
				break
			}
			if len(vector) == 0 {
				// No vector came back for this input. An empty entry would only
				// inflate EmbeddingChunkCount (vector search skips it). Leave the
				// chunk out so the next rebuild embeds it again.
				continue
			}
			chunk := batchChunks[i]
			next.Entries[chunk.ID] = EmbeddingEntry{
				ChunkID:   chunk.ID,
				Hash:      chunk.Hash,
				Source:    chunk.Source,
				Title:     chunk.Title,
				Embedding: vector,
				UpdatedAt: chunk.UpdatedAt,
			}
		}
		batchChunks = nil
		batchTexts = nil
		return nil
	}

	for _, chunk := range idx.Chunks {
		if reusable {
			if entry, ok := previous.Entries[chunk.ID]; ok && entry.Hash == chunk.Hash && len(entry.Embedding) > 0 {
				next.Entries[chunk.ID] = entry
				continue
			}
		}
		batchChunks = append(batchChunks, chunk)
		batchTexts = append(batchTexts, buildRetrievalDocumentText(chunk))
		if len(batchChunks) >= embeddingBatchSize {
			if err := flush(); err != nil {
				return err
			}
		}
	}
	if err := flush(); err != nil {
		return err
	}

	path := resolveAutoReplyPath(cfg.Retrieval.EmbeddingIndexPath)
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return err
	}
	data, err := json.MarshalIndent(next, "", " ")
	if err != nil {
		return err
	}
	// Write a sibling temp file and rename it into place. A crash mid-write
	// then cannot leave a truncated JSON file for loadEmbeddingIndex to choke on.
	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0644); err != nil {
		return err
	}
	if err := os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup; the rename error is what we report
		return err
	}
	e.updateEmbeddingStatus(next)
	return nil
}
// searchKnowledge runs the full retrieval pipeline for query and returns only
// the final hits. Scores and timings from searchKnowledgeDetailed are dropped.
func (e *AutoReplyEngine) searchKnowledge(query string) []KnowledgeChunk {
	detailed := e.searchKnowledgeDetailed(query)
	return detailed.Hits
}
// searchKnowledgeDetailed runs the configured retrieval pipeline for query and
// returns the final hits together with diagnostic scores and stage timings.
//
// Pipeline (cfg.Retrieval.RetrievalMode; an empty value means "hybrid_rerank"):
//  1. Keyword recall, limited to max(RecallTopK, Knowledge.TopK). Broad
//     "which products" queries are expanded with product summary chunks.
//  2. Mode "keyword": keep FinalTopK hits, add neighbouring chunks, return.
//  3. Otherwise, vector recall limited to RecallTopK. On failure the error is
//     recorded via setLastErrorWithScope and the keyword-only result is returned.
//  4. Fuse both hit lists, keep RerankTopK candidates, and call the rerank
//     API. A rerank error is recorded and the fused candidates are kept.
//  5. Sort by candidateScore and keep FinalTopK. Add neighbouring chunks, plus
//     product summaries again for broad product queries.
//
// Timings.KnowledgeDurationMS is the sum of the stage durations that actually ran.
func (e *AutoReplyEngine) searchKnowledgeDetailed(query string) KnowledgeSearchResult {
cfg := e.getConfig()
mode := strings.TrimSpace(cfg.Retrieval.RetrievalMode)
if mode == "" {
mode = retrievalModeHybridRerank
}
result := KnowledgeSearchResult{RetrievalMode: mode}
// Stage 1: keyword recall (product expansion is counted in the keyword timing).
keywordStart := time.Now()
keywordHits := e.searchKeywordKnowledge(query, maxInt(cfg.Retrieval.RecallTopK, cfg.Knowledge.TopK))
if isGenericProductQuery(query) {
keywordHits = e.expandProductKnowledgeHits(query, keywordHits)
}
result.Timings.KeywordDurationMS = time.Since(keywordStart).Milliseconds()
result.KeywordScore = topChunkScore(keywordHits)
// Stage 2: keyword-only mode stops here.
if mode == retrievalModeKeywordOnly {
result.Hits = e.expandKnowledgeNeighborHits(query, limitKnowledgeChunks(keywordHits, cfg.Retrieval.FinalTopK))
result.UsedKnowledgeSources = knowledgeSources(result.Hits)
result.Timings.KnowledgeDurationMS = result.Timings.KeywordDurationMS
return result
}
// Stage 3: vector recall; any failure degrades to the keyword-only result.
vectorStart := time.Now()
vectorHits, vectorErr := e.searchVectorKnowledge(query, cfg.Retrieval.RecallTopK)
result.Timings.VectorDurationMS = time.Since(vectorStart).Milliseconds()
result.VectorScore = topChunkScore(vectorHits)
if vectorErr != nil {
e.setLastErrorWithScope(autoReplyErrorScopeKnowledge, "向量召回失败,已降级关键词检索: "+vectorErr.Error())
result.Hits = e.expandKnowledgeNeighborHits(query, limitKnowledgeChunks(keywordHits, cfg.Retrieval.FinalTopK))
result.UsedKnowledgeSources = knowledgeSources(result.Hits)
result.Timings.KnowledgeDurationMS = result.Timings.KeywordDurationMS + result.Timings.VectorDurationMS
return result
}
// Stage 4: fuse both recall lists; nothing to rerank means no hits at all.
candidates := fuseRetrievalCandidates(keywordHits, vectorHits, query)
if len(candidates) == 0 {
result.Hits = nil
result.Timings.KnowledgeDurationMS = result.Timings.KeywordDurationMS + result.Timings.VectorDurationMS
return result
}
candidates = limitCandidates(candidates, cfg.Retrieval.RerankTopK)
rerankStart := time.Now()
reranked, rerankErr := callDashScopeRerank(cfg.AI, cfg.Retrieval, query, candidates)
result.Timings.RerankDurationMS = time.Since(rerankStart).Milliseconds()
if rerankErr == nil && len(reranked) > 0 {
candidates = reranked
// Record that reranking actually ran. On failure RetrievalMode keeps the
// configured mode string.
result.RetrievalMode = retrievalModeHybridRerank
} else if rerankErr != nil {
e.setLastErrorWithScope(autoReplyErrorScopeKnowledge, "重排序失败,已使用混合召回结果: "+rerankErr.Error())
}
// Stage 5: order by rerank score when present, otherwise by fused score.
// sort.Slice is not stable, so equal scores may come out in any order.
sort.Slice(candidates, func(i, j int) bool {
return candidateScore(candidates[i]) > candidateScore(candidates[j])
})
candidates = limitCandidates(candidates, cfg.Retrieval.FinalTopK)
result.Hits = e.expandKnowledgeNeighborHits(query, candidatesToKnowledgeChunks(candidates))
if isGenericProductQuery(query) {
result.Hits = e.expandProductKnowledgeHits(query, result.Hits)
}
result.RerankScore = topCandidateRerankScore(candidates)
result.UsedKnowledgeSources = knowledgeSources(result.Hits)
result.Timings.KnowledgeDurationMS = result.Timings.KeywordDurationMS + result.Timings.VectorDurationMS + result.Timings.RerankDurationMS
return result
}
// isGenericProductQuery reports whether query is a broad "what products do you
// offer" question rather than a question about one specific product.
//
// It matches if the trimmed, lower-cased query either contains one of a fixed
// list of phrases, or contains "产品" together with a listing qualifier such
// as "什么" or "全部".
func isGenericProductQuery(query string) bool {
	normalized := strings.ToLower(strings.TrimSpace(query))
	if normalized == "" {
		return false
	}

	phrases := []string{
		"有什么产品", "有哪些产品", "具体有什么产品", "产品介绍", "产品线", "产品矩阵",
		"产品清单", "产品列表", "产品型号", "型号", "设备型号", "哪些型号",
		"全部产品", "所有产品", "全部产品介绍", "所有产品介绍", "产品大全", "完整产品线",
		"你们公司的全部产品", "你们公司全部产品", "你们所有产品", "公司的全部产品",
	}
	for _, phrase := range phrases {
		if strings.Contains(normalized, strings.ToLower(phrase)) {
			return true
		}
	}

	if !strings.Contains(normalized, "产品") {
		return false
	}
	for _, qualifier := range []string{"什么", "哪些", "介绍", "全部", "所有", "完整"} {
		if strings.Contains(normalized, qualifier) {
			return true
		}
	}
	return false
}
// expandProductKnowledgeHits augments hits for broad product questions with at
// most one summary chunk per product page, capping the list at 10 chunks.
//
// Product names come from the wiki links of every hub-like hit
// (isProductHubChunk), followed by defaultProductKnowledgeNames. For each
// unique name, the first chunk of "<name>.md" that is not low-value, not
// already included, and passes isProductSummaryChunk is appended with a
// productExpansionScore. The combined list is then stably ordered by
// productHitRank. hits is returned as-is when no index is loaded.
func (e *AutoReplyEngine) expandProductKnowledgeHits(query string, hits []KnowledgeChunk) []KnowledgeChunk {
	const maxExpandedHits = 10

	e.mu.Lock()
	idx := e.index
	e.mu.Unlock()
	if idx == nil || len(idx.Chunks) == 0 {
		return hits
	}

	chunksBySource := map[string][]KnowledgeChunk{}
	for _, chunk := range idx.Chunks {
		if !isLowValueKnowledgeBlock(chunk.Title, chunk.Content) {
			key := normalizeKnowledgeSourceKey(chunk.Source)
			chunksBySource[key] = append(chunksBySource[key], chunk)
		}
	}

	expanded := append([]KnowledgeChunk(nil), hits...)
	included := make(map[string]bool, len(expanded))
	for _, hit := range expanded {
		included[hit.ID] = true
	}

	var names []string
	for _, hit := range hits {
		if isProductHubChunk(hit) {
			names = append(names, extractWikiLinkNames(hit.Content)...)
		}
	}
	names = append(names, defaultProductKnowledgeNames()...)

	for _, name := range uniqueStrings(names) {
		if len(expanded) >= maxExpandedHits {
			break
		}
		bucket := chunksBySource[normalizeKnowledgeSourceKey(name+".md")]
		for i := range bucket {
			candidate := bucket[i]
			if included[candidate.ID] || !isProductSummaryChunk(candidate, name) {
				continue
			}
			candidate.Score = productExpansionScore(query, candidate)
			expanded = append(expanded, candidate)
			included[candidate.ID] = true
			break // one summary chunk per product page
		}
	}

	sort.SliceStable(expanded, func(a, b int) bool {
		return productHitRank(expanded[a]) < productHitRank(expanded[b])
	})
	return expanded
}
// isProductHubChunk reports whether chunk looks like an overview page that
// links out to individual products. That is the case when its source, title
// and content, joined with spaces, contain any of a fixed set of hub markers.
func isProductHubChunk(chunk KnowledgeChunk) bool {
	text := strings.Join([]string{chunk.Source, chunk.Title, chunk.Content}, " ")
	for _, marker := range []string{"产品矩阵", "AgentBox", "硬件载体", "模型引擎", "AI 应用"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// extractWikiLinkNames returns the trimmed target name of every wiki link in
// text (see wikiLinkPattern), in order of appearance. Links whose target is
// blank are dropped.
func extractWikiLinkNames(text string) []string {
	found := wikiLinkPattern.FindAllStringSubmatch(text, -1)
	names := make([]string, 0, len(found))
	for _, groups := range found {
		if len(groups) < 2 {
			continue
		}
		if target := strings.TrimSpace(groups[1]); target != "" {
			names = append(names, target)
		}
	}
	return names
}
// defaultProductKnowledgeNames lists the product pages (file names without
// extension) that broad product queries always try to surface. The slice
// order is the ranking order used by productHitRank. A fresh slice is
// returned on every call.
func defaultProductKnowledgeNames() []string {
	names := make([]string, 0, 10)
	names = append(names, "产品矩阵", "AgentBox", "VISION-S01", "PRO-S01", "PRO-Y01", "SUPER-S01")
	names = append(names, "AWIN25", "数字员工", "万川智媒", "智雕工坊")
	return names
}
// isProductSummaryChunk reports whether chunk can serve as the summary of
// product name. It qualifies when any of the following holds:
//   - its trimmed title equals name, ignoring case;
//   - its trimmed content starts with a ">" quote block;
//   - its title mentions positioning ("核心定位"), definition ("定义") or key
//     capabilities ("关键能力").
func isProductSummaryChunk(chunk KnowledgeChunk, name string) bool {
	title := strings.TrimSpace(chunk.Title)
	switch {
	case strings.EqualFold(title, name):
		return true
	case strings.HasPrefix(strings.TrimSpace(chunk.Content), ">"):
		return true
	}
	for _, marker := range []string{"核心定位", "定义", "关键能力"} {
		if strings.Contains(title, marker) {
			return true
		}
	}
	return false
}
// productExpansionScore is the score given to a product summary chunk that
// expandProductKnowledgeHits adds. It is 0.82 plus exactMatchBoost, plus a
// further 0.12 when the chunk's source mentions the product matrix page.
func productExpansionScore(query string, chunk KnowledgeChunk) float64 {
	matrixBonus := 0.0
	if strings.Contains(chunk.Source, "产品矩阵") {
		matrixBonus = 0.12
	}
	return 0.82 + exactMatchBoost(query, chunk) + matrixBonus
}
// productHitRank returns the sort key used for product answers. It is the
// position of the chunk's source within defaultProductKnowledgeNames, or
// len(names)+1 for any other source. Sources are compared after
// normalizeKnowledgeSourceKey.
func productHitRank(chunk KnowledgeChunk) int {
	key := normalizeKnowledgeSourceKey(chunk.Source)
	names := defaultProductKnowledgeNames()
	for rank := range names {
		if normalizeKnowledgeSourceKey(names[rank]+".md") == key {
			return rank
		}
	}
	return len(names) + 1
}
// normalizeKnowledgeSourceKey turns a knowledge source path into a comparison
// key, in these steps:
//  1. Convert separators to slashes, trim surrounding space, lower-case.
//  2. Strip the document extensions .md, .txt, .csv, .xlsx, .docx and .pdf,
//     each checked once, in that order.
//  3. Reduce the result to its base name.
// Because filepath.Base maps "" to ".", an empty source yields ".".
func normalizeKnowledgeSourceKey(source string) string {
	key := strings.ToLower(strings.TrimSpace(filepath.ToSlash(source)))
	for _, ext := range [...]string{".md", ".txt", ".csv", ".xlsx", ".docx", ".pdf"} {
		key = strings.TrimSuffix(key, ext)
	}
	return filepath.Base(key)
}
// uniqueStrings trims every value and keeps the first occurrence of each
// distinct normalizeKnowledgeSourceKey key, preserving input order. Values
// that are blank after trimming are dropped. The returned strings are the
// trimmed originals, not the keys.
func uniqueStrings(values []string) []string {
	kept := make([]string, 0, len(values))
	seenKeys := map[string]struct{}{}
	for _, raw := range values {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			continue
		}
		key := normalizeKnowledgeSourceKey(trimmed)
		if _, dup := seenKeys[key]; dup {
			continue
		}
		seenKeys[key] = struct{}{}
		kept = append(kept, trimmed)
	}
	return kept
}
// searchKeywordKnowledge scores every chunk of the loaded knowledge index
// against query and returns the matching chunks, best first.
//
// A chunk's score is scoreKnowledgeChunk over the tokenizeKnowledgeText
// tokens, plus exactMatchBoost. Only chunks with a positive score are kept.
// They are ordered by descending Score, with ties keeping index order so
// results are deterministic, and truncated to limit (limit <= 0 means no
// truncation). Returns nil when no index is loaded or the query has no tokens.
func (e *AutoReplyEngine) searchKeywordKnowledge(query string, limit int) []KnowledgeChunk {
	e.mu.Lock()
	idx := e.index
	e.mu.Unlock()
	if idx == nil || len(idx.Chunks) == 0 {
		return nil
	}
	queryTokens := tokenizeKnowledgeText(query)
	if len(queryTokens) == 0 {
		return nil
	}

	// limit is only a capacity hint here. A negative limit ("no limit" for
	// limitKnowledgeChunks) would make make() panic, and a huge one would
	// over-allocate, so clamp it to the number of chunks.
	capacity := limit
	if capacity < 0 || capacity > len(idx.Chunks) {
		capacity = len(idx.Chunks)
	}
	results := make([]KnowledgeChunk, 0, capacity)
	for _, chunk := range idx.Chunks {
		score := scoreKnowledgeChunk(queryTokens, chunk)
		score += exactMatchBoost(query, chunk)
		if score <= 0 {
			continue
		}
		chunk.Score = score // chunk is a per-iteration copy; the index is untouched
		results = append(results, chunk)
	}
	sort.SliceStable(results, func(i, j int) bool {
		return results[i].Score > results[j].Score
	})
	return limitKnowledgeChunks(results, limit)
}
// searchVectorKnowledge embeds query with the configured embedding model and
// ranks indexed chunks by cosine similarity to their stored vectors.
//
// A chunk is returned only if it is present in both the knowledge index and
// the embedding index, has a non-empty vector, and has positive similarity.
// Its Score is the similarity mapped from (0,1] to (0.5,1] via (s+1)/2.
// Results are ordered by descending Score, with ties keeping knowledge-index
// order, and truncated to limit (limit <= 0 means no truncation).
//
// It returns an error when the embedding index is empty, the AI credentials
// are missing, the embedding call fails, or the embedding reply is empty.
func (e *AutoReplyEngine) searchVectorKnowledge(query string, limit int) ([]KnowledgeChunk, error) {
	cfg := e.getConfig()
	e.mu.Lock()
	idx := e.index
	embeddingIndex := e.embeddingIndex
	e.mu.Unlock()
	if idx == nil || embeddingIndex == nil || len(embeddingIndex.Entries) == 0 {
		return nil, fmt.Errorf("向量索引为空,请先重建知识库索引")
	}
	if strings.TrimSpace(cfg.AI.APIKey) == "" || strings.TrimSpace(cfg.AI.BaseURL) == "" {
		return nil, fmt.Errorf("AI Base URL 或 API Key 未配置")
	}
	vectors, err := callDashScopeEmbeddings(cfg.AI, cfg.Retrieval, []string{query})
	if err != nil {
		return nil, err
	}
	if len(vectors) == 0 {
		return nil, fmt.Errorf("Embedding返回空向量")
	}
	queryVector := vectors[0]

	// limit is only a capacity hint; clamp it so a negative value cannot
	// panic in make().
	capacity := limit
	if capacity < 0 || capacity > len(idx.Chunks) {
		capacity = len(idx.Chunks)
	}
	results := make([]KnowledgeChunk, 0, capacity)
	// Walk the knowledge index (a slice) instead of the Entries map. Candidate
	// order, and therefore tie-breaking, is then deterministic. scored keeps
	// one result per chunk ID, as the map walk did.
	scored := make(map[string]bool, len(embeddingIndex.Entries))
	for _, chunk := range idx.Chunks {
		if scored[chunk.ID] {
			continue
		}
		entry, ok := embeddingIndex.Entries[chunk.ID]
		if !ok || len(entry.Embedding) == 0 {
			continue
		}
		scored[chunk.ID] = true
		similarity := cosineSimilarity(queryVector, entry.Embedding)
		if similarity <= 0 {
			continue
		}
		chunk.Score = (similarity + 1) / 2
		results = append(results, chunk)
	}
	sort.SliceStable(results, func(i, j int) bool {
		return results[i].Score > results[j].Score
	})
	return limitKnowledgeChunks(results, limit), nil
}
// callDashScopeEmbeddings requests one embedding per input from
// "<BaseURL>/embeddings". The suffix is not appended twice if BaseURL already
// ends with it.
//
// The request carries model, input, encoding_format "float" and, when
// EmbeddingDimensions > 0, dimensions. Each returned vector is placed at the
// position given by its "index" field, or at its reply position when that
// index is out of range. The result therefore has len(Data) slots, some
// possibly nil.
//
// An empty inputs slice returns (nil, nil). Transport, HTTP, decode and
// API-level errors are returned. If the endpoint rejects the model as an
// unsupported rerank model, a configuration hint is returned instead.
func callDashScopeEmbeddings(aiCfg config.AIConfig, retrievalCfg config.RetrievalConfig, inputs []string) ([][]float64, error) {
	if len(inputs) == 0 {
		return nil, nil
	}

	endpoint := strings.TrimRight(aiCfg.BaseURL, "/")
	if !strings.HasSuffix(endpoint, "/embeddings") {
		endpoint += "/embeddings"
	}

	request := map[string]interface{}{
		"model":           retrievalCfg.EmbeddingModel,
		"input":           inputs,
		"encoding_format": "float",
	}
	if dims := retrievalCfg.EmbeddingDimensions; dims > 0 {
		request["dimensions"] = dims
	}

	var reply struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
			Index     int       `json:"index"`
		} `json:"data"`
		Error interface{} `json:"error"`
	}
	if err := doRetrievalJSONRequest(aiCfg, endpoint, request, &reply); err != nil {
		// A rerank model configured as the embedding model is a common mistake;
		// turn the provider's error into an actionable message.
		lowered := strings.ToLower(err.Error())
		if strings.Contains(lowered, "unsupported model") && strings.Contains(lowered, "rerank") {
			return nil, fmt.Errorf("Embedding模型配置错误'%s' 是一个Rerank模型不是Embedding模型。请使用 text-embedding-v4 或 text-embedding-v3 等Embedding模型", retrievalCfg.EmbeddingModel)
		}
		return nil, err
	}
	if reply.Error != nil {
		return nil, fmt.Errorf("Embedding返回错误: %v", reply.Error)
	}

	vectors := make([][]float64, len(reply.Data))
	for pos, item := range reply.Data {
		slot := pos
		if item.Index >= 0 && item.Index < len(reply.Data) {
			slot = item.Index
		}
		vectors[slot] = item.Embedding
	}
	return vectors, nil
}
// callDashScopeRerank asks the rerank API to score candidates against query.
// It returns the candidates the API ranked, in the API's result order, with
// RerankScore set to relevance_score (or to score when relevance_score <= 0).
//
// Request details:
//   - each candidate is sent as buildRetrievalDocumentText, cut to 1200 via
//     truncateTextForPrompt;
//   - top_n is FinalTopK, or len(candidates) when FinalTopK is non-positive
//     or larger than that.
//
// The endpoints from dashScopeRerankURLs are tried in order until one
// succeeds. If every attempt fails, the last error is returned. Results with
// an out-of-range index are skipped. An API-level error or an empty result
// set is reported as an error. Empty candidates return (nil, nil).
func callDashScopeRerank(aiCfg config.AIConfig, retrievalCfg config.RetrievalConfig, query string, candidates []retrievalCandidate) ([]retrievalCandidate, error) {
	if len(candidates) == 0 {
		return nil, nil
	}
	documents := make([]string, 0, len(candidates))
	for _, candidate := range candidates {
		documents = append(documents, truncateTextForPrompt(buildRetrievalDocumentText(candidate.Chunk), 1200))
	}
	topN := retrievalCfg.FinalTopK
	if topN <= 0 || topN > len(documents) {
		topN = len(documents)
	}
	payload := map[string]interface{}{
		"model":     retrievalCfg.RerankModel,
		"query":     query,
		"documents": documents,
		"top_n":     topN,
		"instruct":  "Given a customer support query, retrieve passages that directly answer the query about Lingze Wanchuan products, services, or after-sales support.",
	}

	type rerankReply struct {
		Results []struct {
			Index          int     `json:"index"`
			RelevanceScore float64 `json:"relevance_score"`
			Score          float64 `json:"score"`
		} `json:"results"`
		Error interface{} `json:"error"`
	}
	var response rerankReply
	var lastErr error
	for _, url := range dashScopeRerankURLs(aiCfg) {
		// Decode every attempt into a fresh value. json.Unmarshal can partially
		// populate its target before returning an error, and that state (e.g. a
		// stale Error field) must not leak into the next endpoint's result.
		var attempt rerankReply
		if err := doRetrievalJSONRequest(aiCfg, url, payload, &attempt); err != nil {
			lastErr = err
			continue
		}
		response = attempt
		lastErr = nil
		break
	}
	if lastErr != nil {
		return nil, lastErr
	}
	if response.Error != nil {
		return nil, fmt.Errorf("Rerank返回错误: %v", response.Error)
	}
	if len(response.Results) == 0 {
		return nil, fmt.Errorf("Rerank返回空结果")
	}

	reranked := make([]retrievalCandidate, 0, len(response.Results))
	for _, item := range response.Results {
		if item.Index < 0 || item.Index >= len(candidates) {
			continue
		}
		candidate := candidates[item.Index]
		candidate.RerankScore = item.RelevanceScore
		if candidate.RerankScore <= 0 {
			// Some providers report the value as "score" instead of "relevance_score".
			candidate.RerankScore = item.Score
		}
		reranked = append(reranked, candidate)
	}
	return reranked, nil
}
// doRetrievalJSONRequest POSTs payload as JSON to url and decodes a 2xx JSON
// reply into out. The request is authenticated with a Bearer token taken from
// aiCfg.APIKey (trimmed).
//
// The whole exchange is bounded by aiCfg.TimeoutSeconds, or 20s when that
// value is not positive. Errors cover payload encoding, request construction,
// transport failure, reading the body, a non-2xx status and an undecodable
// body. The last two include a truncateText(…, 240) excerpt of the body.
func doRetrievalJSONRequest(aiCfg config.AIConfig, url string, payload interface{}, out interface{}) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	timeout := 20 * time.Second
	if configured := time.Duration(aiCfg.TimeoutSeconds) * time.Second; configured > 0 {
		timeout = configured
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(encoded))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(aiCfg.APIKey))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if success := resp.StatusCode >= 200 && resp.StatusCode < 300; !success {
		return fmt.Errorf("HTTP状态码错误: %d, body=%s", resp.StatusCode, truncateText(string(raw), 240))
	}
	if err := json.Unmarshal(raw, out); err != nil {
		return fmt.Errorf("解析响应失败: %v, body=%s", err, truncateText(string(raw), 240))
	}
	return nil
}
// dashScopeRerankURLs returns the rerank endpoints to try, in order: the
// plural ".../reranks" path first, then the singular ".../rerank".
//
// Any base URL containing dashscope.aliyuncs.com maps to DashScope's fixed
// compatible-api/v1 endpoints. For every other base URL, trailing slashes are
// removed and the two suffixes are appended.
func dashScopeRerankURLs(aiCfg config.AIConfig) []string {
	base := strings.TrimRight(aiCfg.BaseURL, "/")
	if strings.Contains(base, "dashscope.aliyuncs.com") {
		base = "https://dashscope.aliyuncs.com/compatible-api/v1"
	}
	return []string{base + "/reranks", base + "/rerank"}
}
// fuseRetrievalCandidates merges keyword and vector recall results into one
// candidate list, ordered by descending FusionScore.
//
// For each distinct chunk ID:
//
//	FusionScore = 0.45*keyword/maxKeyword + 0.45*vector/maxVector
//	            + min(exactMatchBoost, 0.10)
//	            + sum over lists of 1/(defaultRRFK + rank)
//
// maxKeyword and maxVector are topChunkScore of each input list, and rank is
// the chunk's 1-based position in each list it appears in. A chunk found by
// both recalls keeps the KnowledgeChunk value from its first appearance,
// i.e. the keyword list.
//
// Candidates are collected in first-seen order (keyword hits, then new vector
// hits) and stably sorted. Ties are therefore resolved deterministically
// rather than by map iteration order.
func fuseRetrievalCandidates(keywordHits []KnowledgeChunk, vectorHits []KnowledgeChunk, query string) []retrievalCandidate {
	maxKeyword := topChunkScore(keywordHits)
	maxVector := topChunkScore(vectorHits)

	byID := make(map[string]*retrievalCandidate, len(keywordHits)+len(vectorHits))
	order := make([]*retrievalCandidate, 0, len(keywordHits)+len(vectorHits))
	lookup := func(hit KnowledgeChunk) *retrievalCandidate {
		if candidate, ok := byID[hit.ID]; ok {
			return candidate
		}
		candidate := &retrievalCandidate{Chunk: hit}
		byID[hit.ID] = candidate
		order = append(order, candidate)
		return candidate
	}
	for i, hit := range keywordHits {
		candidate := lookup(hit)
		candidate.KeywordScore = hit.Score
		candidate.KeywordRank = i + 1
	}
	for i, hit := range vectorHits {
		candidate := lookup(hit)
		candidate.VectorScore = hit.Score
		candidate.VectorRank = i + 1
	}

	result := make([]retrievalCandidate, 0, len(order))
	for _, candidate := range order {
		keywordScore := normalizedScore(candidate.KeywordScore, maxKeyword)
		vectorScore := normalizedScore(candidate.VectorScore, maxVector)
		boost := exactMatchBoost(query, candidate.Chunk)
		rrfScore := 0.0
		if candidate.KeywordRank > 0 {
			rrfScore += 1 / (defaultRRFK + float64(candidate.KeywordRank))
		}
		if candidate.VectorRank > 0 {
			rrfScore += 1 / (defaultRRFK + float64(candidate.VectorRank))
		}
		candidate.FusionScore = keywordScore*0.45 + vectorScore*0.45 + math.Min(boost, 0.10) + rrfScore
		result = append(result, *candidate)
	}
	sort.SliceStable(result, func(i, j int) bool {
		return result[i].FusionScore > result[j].FusionScore
	})
	return result
}
// buildRetrievalDocumentText renders chunk as the text sent to the embedding
// and rerank models. The output is newline-separated lines:
//   - "文件:" + source, only when the source is not blank;
//   - "标题:" + title, only when the title is not blank;
//   - "内容:" + content, always.
func buildRetrievalDocumentText(chunk KnowledgeChunk) string {
	lines := make([]string, 0, 3)
	if strings.TrimSpace(chunk.Source) != "" {
		lines = append(lines, "文件:"+chunk.Source)
	}
	if strings.TrimSpace(chunk.Title) != "" {
		lines = append(lines, "标题:"+chunk.Title)
	}
	lines = append(lines, "内容:"+chunk.Content)
	return strings.Join(lines, "\n")
}
// expandKnowledgeNeighborHits adds context around hits. For every hit it adds
// the chunks immediately before and after it within the same source. Low-value
// blocks are ignored when determining adjacency.
//
// Each added neighbour gets 95% of the originating hit's score. The combined
// list is de-duplicated by chunk ID, stably sorted by descending Score and
// capped at 12 chunks. hits is returned as-is when it is empty or no index is
// loaded. query is currently unused.
func (e *AutoReplyEngine) expandKnowledgeNeighborHits(query string, hits []KnowledgeChunk) []KnowledgeChunk {
	const maxNeighborHits = 12

	e.mu.Lock()
	idx := e.index
	e.mu.Unlock()
	if idx == nil || len(idx.Chunks) == 0 || len(hits) == 0 {
		return hits
	}

	ordered := map[string][]KnowledgeChunk{}
	for _, chunk := range idx.Chunks {
		if isLowValueKnowledgeBlock(chunk.Title, chunk.Content) {
			continue
		}
		key := normalizeKnowledgeSourceKey(chunk.Source)
		ordered[key] = append(ordered[key], chunk)
	}

	taken := make(map[string]bool, len(hits))
	merged := make([]KnowledgeChunk, 0, len(hits)+4)
	for _, hit := range hits {
		if !taken[hit.ID] {
			taken[hit.ID] = true
			merged = append(merged, hit)
		}
	}

	for _, hit := range hits {
		siblings := ordered[normalizeKnowledgeSourceKey(hit.Source)]
		pos := -1
		for i := range siblings {
			if siblings[i].ID == hit.ID {
				pos = i
				break
			}
		}
		if pos < 0 {
			continue
		}
		for _, at := range [2]int{pos - 1, pos + 1} {
			if at < 0 || at >= len(siblings) {
				continue
			}
			neighbor := siblings[at]
			if neighbor.ID == "" || taken[neighbor.ID] {
				continue
			}
			neighbor.Score = hit.Score * 0.95
			taken[neighbor.ID] = true
			merged = append(merged, neighbor)
		}
	}

	sort.SliceStable(merged, func(a, b int) bool {
		return merged[a].Score > merged[b].Score
	})
	if len(merged) > maxNeighborHits {
		merged = merged[:maxNeighborHits]
	}
	return merged
}
// exactMatchBoost returns an additive relevance bonus for chunk, based on
// literal occurrences of query-derived terms. Matching is done on lower-cased
// text.
//
// Bonuses:
//   - Identifier tokens (extractExactBoostTokens) and document references
//     (extractKnowledgeReferenceTokens): +0.18 when found in the chunk's
//     source or title, otherwise +0.08 when found anywhere in source, title
//     or content.
//   - Phrases from extractChineseBoostPhrases: +0.22 in source or title,
//     otherwise +0.12 anywhere.
//
// Every matching term contributes, so the total is not capped. A blank query
// yields 0.
func exactMatchBoost(query string, chunk KnowledgeChunk) float64 {
	query = strings.ToLower(strings.TrimSpace(query))
	if query == "" {
		return 0
	}
	// Both lower-cased views are built once per call rather than once per
	// term. This function runs for every chunk on every keyword search.
	head := strings.ToLower(chunk.Source + " " + chunk.Title)
	haystack := strings.ToLower(chunk.Source + " " + chunk.Title + " " + chunk.Content)

	boost := 0.0
	for _, token := range append(extractExactBoostTokens(query), extractKnowledgeReferenceTokens(query)...) {
		switch {
		case token == "":
		case strings.Contains(head, token):
			boost += 0.18
		case strings.Contains(haystack, token):
			boost += 0.08
		}
	}
	for _, phrase := range extractChineseBoostPhrases(query) {
		switch {
		case phrase == "":
		case strings.Contains(head, phrase):
			boost += 0.22
		case strings.Contains(haystack, phrase):
			boost += 0.12
		}
	}
	return boost
}
// extractExactBoostTokens pulls identifier-like terms out of query for exact
// matching. The result has two parts, and duplicates are kept:
//  1. Every maximal run of ASCII letters, digits, '-' or '_', lower-cased,
//     kept when it is at least 3 runes long or contains '-'.
//  2. Each fixed product keyword that appears verbatim (case-sensitively) in
//     query.
func extractExactBoostTokens(query string) []string {
	isTokenRune := func(r rune) bool {
		return r == '-' || r == '_' || ('0' <= r && r <= '9') || ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z')
	}
	parts := strings.FieldsFunc(query, func(r rune) bool { return !isTokenRune(r) })

	tokens := make([]string, 0, len(parts))
	for _, part := range parts {
		token := strings.ToLower(strings.Trim(part, ".,;:!?,。!?;:、"))
		if strings.Contains(token, "-") || len([]rune(token)) >= 3 {
			tokens = append(tokens, token)
		}
	}

	knownTerms := []string{"产品", "产品线", "设备", "工作站", "模型", "数字员工", "agentbox", "awin25", "pro-s01", "pro-y01", "super-s01", "vision-s01"}
	for _, term := range knownTerms {
		if strings.Contains(query, term) {
			tokens = append(tokens, term)
		}
	}
	return tokens
}
// extractChineseBoostPhrases derives match phrases from a (typically Chinese)
// query.
//
// Filler words such as "有哪些" and "是什么" are removed in a fixed order,
// re-trimming after each removal. If fewer than 2 runes remain, nil is
// returned. Otherwise the phrases are the cleaned query plus, when it has at
// least 3 runes, its 2- and 3-rune prefixes, passed through
// dedupeNonEmptyStrings.
func extractChineseBoostPhrases(query string) []string {
	cleaned := strings.TrimSpace(query)
	if cleaned == "" {
		return nil
	}
	// Order matters: "有哪些" must be removed before its substring "哪些".
	fillers := [...]string{"有哪些", "有啥", "是什么", "怎么", "如何", "哪些", "问题", "内容"}
	for _, filler := range fillers {
		cleaned = strings.TrimSpace(strings.ReplaceAll(cleaned, filler, ""))
	}

	runes := []rune(cleaned)
	if len(runes) < 2 {
		return nil
	}
	phrases := []string{cleaned}
	if len(runes) >= 3 {
		phrases = append(phrases, string(runes[:2]), string(runes[:3]))
	}
	return dedupeNonEmptyStrings(phrases)
}
// knowledgeFileReferencePattern matches a file name ending in a supported
// document extension, optionally wrapped in 《》, <> or quotes. Group 1 is
// the bare file name. It is compiled once at package scope, because
// extractKnowledgeReferenceTokens runs for every chunk on every keyword search.
var knowledgeFileReferencePattern = regexp.MustCompile(`[《<"“]?([^《》<>"“”\s]+?\.(?:xlsx|xls|docx|doc|pdf|md|txt|csv))[》>"”]?`)

// knowledgeQuotedReferencePattern matches text wrapped in 《》 or quotes.
// Group 1 is the inner text.
var knowledgeQuotedReferencePattern = regexp.MustCompile(`[《"“]([^》"”]+)[》"”]`)

// extractKnowledgeReferenceTokens finds explicit document references in
// query: file names with a supported extension, and text wrapped in 《》 or
// quotes.
//
// Each reference is lower-cased and slash-normalised, and emitted together
// with its normalizeKnowledgeSourceKey form. Duplicates are removed, keeping
// first-seen order. A blank query returns nil; a query without references
// returns an empty slice.
func extractKnowledgeReferenceTokens(query string) []string {
	query = strings.TrimSpace(query)
	if query == "" {
		return nil
	}
	candidates := make([]string, 0)
	for _, match := range knowledgeFileReferencePattern.FindAllStringSubmatch(query, -1) {
		if len(match) > 1 {
			candidates = append(candidates, match[1])
		}
	}
	for _, wrapped := range knowledgeQuotedReferencePattern.FindAllStringSubmatch(query, -1) {
		if len(wrapped) > 1 {
			candidates = append(candidates, wrapped[1])
		}
	}

	result := make([]string, 0, len(candidates)*2)
	seen := make(map[string]bool)
	for _, candidate := range candidates {
		candidate = strings.ToLower(strings.TrimSpace(filepath.ToSlash(candidate)))
		if candidate == "" {
			continue
		}
		for _, token := range []string{candidate, normalizeKnowledgeSourceKey(candidate)} {
			token = strings.TrimSpace(token)
			if token != "" && !seen[token] {
				seen[token] = true
				result = append(result, token)
			}
		}
	}
	return result
}
// cosineSimilarity returns the cosine of the angle between a and b. When the
// lengths differ, only their common prefix is used. It returns 0 when either
// vector is empty or has zero norm over that prefix.
func cosineSimilarity(a []float64, b []float64) float64 {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}
	if len(b) < len(a) {
		a = a[:len(b)]
	} else {
		b = b[:len(a)]
	}

	var dot, sumSqA, sumSqB float64
	for i, x := range a {
		y := b[i]
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}
	if sumSqA == 0 || sumSqB == 0 {
		return 0
	}
	return dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB))
}
// candidatesToKnowledgeChunks unwraps candidates into their chunks, keeping
// the same order, and sets each chunk's Score to candidateScore of its
// candidate. The input candidates are not modified.
func candidatesToKnowledgeChunks(candidates []retrievalCandidate) []KnowledgeChunk {
	chunks := make([]KnowledgeChunk, len(candidates))
	for i := range candidates {
		chunks[i] = candidates[i].Chunk
		chunks[i].Score = candidateScore(candidates[i])
	}
	return chunks
}
// candidateScore picks the score used to rank a candidate: the rerank score
// if positive, else the fusion score if positive, else the larger of the
// keyword and vector scores.
func candidateScore(candidate retrievalCandidate) float64 {
	switch {
	case candidate.RerankScore > 0:
		return candidate.RerankScore
	case candidate.FusionScore > 0:
		return candidate.FusionScore
	case candidate.KeywordScore > candidate.VectorScore:
		return candidate.KeywordScore
	default:
		return candidate.VectorScore
	}
}
// topCandidateRerankScore returns the RerankScore of the first candidate that
// has a positive rerank score, or 0 when no candidate has one.
func topCandidateRerankScore(candidates []retrievalCandidate) float64 {
	for i := range candidates {
		if score := candidates[i].RerankScore; score > 0 {
			return score
		}
	}
	return 0
}
// topChunkScore returns the highest Score among chunks, or 0 for an empty
// slice.
//
// It scans the whole slice instead of trusting chunks[0]. Keyword hits can be
// re-ordered by product rank (expandProductKnowledgeHits) before this value is
// reported as KeywordScore and used as the normalisation maximum in
// fuseRetrievalCandidates. A non-maximal value there would push normalised
// keyword scores above 1.
func topChunkScore(chunks []KnowledgeChunk) float64 {
	if len(chunks) == 0 {
		return 0
	}
	best := chunks[0].Score
	for i := 1; i < len(chunks); i++ {
		if score := chunks[i].Score; score > best {
			best = score
		}
	}
	return best
}
// normalizedScore returns score divided by maxScore, or 0 when either
// argument is not positive.
func normalizedScore(score float64, maxScore float64) float64 {
	switch {
	case score <= 0, maxScore <= 0:
		return 0
	default:
		return score / maxScore
	}
}
// limitCandidates returns at most the first limit candidates. A non-positive
// limit returns candidates unchanged. The result shares the input's backing
// array.
func limitCandidates(candidates []retrievalCandidate, limit int) []retrievalCandidate {
	if limit > 0 && len(candidates) > limit {
		candidates = candidates[:limit]
	}
	return candidates
}
// limitKnowledgeChunks returns at most the first limit chunks. A non-positive
// limit returns chunks unchanged. The result shares the input's backing array.
func limitKnowledgeChunks(chunks []KnowledgeChunk, limit int) []KnowledgeChunk {
	if limit > 0 && len(chunks) > limit {
		chunks = chunks[:limit]
	}
	return chunks
}
// knowledgeSources lists the distinct, whitespace-trimmed, non-empty Source
// values of chunks in first-appearance order. The result is never nil.
func knowledgeSources(chunks []KnowledgeChunk) []string {
	sources := make([]string, 0, len(chunks))
	seen := make(map[string]struct{}, len(chunks))
	for i := range chunks {
		src := strings.TrimSpace(chunks[i].Source)
		if src == "" {
			continue
		}
		if _, dup := seen[src]; dup {
			continue
		}
		seen[src] = struct{}{}
		sources = append(sources, src)
	}
	return sources
}
// maxInt returns the larger of a and b.
func maxInt(a int, b int) int {
	if b >= a {
		return b
	}
	return a
}