903 lines
31 KiB
Go
903 lines
31 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"qiweimanager/config"
|
||
)
|
||
|
||
// AIResult carries the outcome of one chat-model request.
type AIResult struct {
	// Answer is the reply text; the callers in this file store it with
	// surrounding whitespace trimmed.
	Answer string `json:"answer"`
	// RawSummary is a shortened form of Answer; the callers in this file
	// fill it with truncateText(answer, 160).
	RawSummary string `json:"rawSummary"`
	// DurationMS is the elapsed time of the HTTP round trip in milliseconds.
	DurationMS int64 `json:"durationMs"`
}
|
||
|
||
// Limits applied when knowledge-base hits are packed into the AI user prompt.
const (
	aiPromptMaxHits        = 8    // raised to 8 fragments to provide more context
	aiPromptMaxChunkRunes  = 1200 // raised to 1200 characters to keep more detail
	aiPromptMaxContextRune = 8000 // raised to 8000 characters to support longer knowledge-base content
)

// Speech-recognition default model and the canonical audio mode identifiers
// returned by normalizeAudioMode / inferAudioMode.
const (
	defaultAudioModel = "qwen3-asr-flash"

	audioModeAuto          = "auto"
	audioModeOpenAIChat    = "openai_audio_chat"
	audioModeParaformer    = "dashscope_paraformer"
	audioModeTranscription = "local_openai_transcription"
	audioModeCustomHTTP    = "custom_http"
)
|
||
|
||
func (e *AutoReplyEngine) getConfig() config.AutoReplyConfig {
|
||
e.mu.Lock()
|
||
defer e.mu.Unlock()
|
||
cfg := e.config
|
||
if cfg.AI.TimeoutSeconds <= 0 {
|
||
cfg.AI.TimeoutSeconds = 20
|
||
}
|
||
if cfg.AI.MaxTokens <= 0 {
|
||
cfg.AI.MaxTokens = 700
|
||
}
|
||
if strings.TrimSpace(cfg.AI.ReplyDetail) == "" {
|
||
cfg.AI.ReplyDetail = "detailed"
|
||
}
|
||
if cfg.Knowledge.TopK <= 0 {
|
||
cfg.Knowledge.TopK = 3
|
||
}
|
||
if cfg.Knowledge.MinScore <= 0 {
|
||
cfg.Knowledge.MinScore = 0.40
|
||
}
|
||
if cfg.ReplyPolicy.UnknownAnswerToken == "" {
|
||
cfg.ReplyPolicy.UnknownAnswerToken = "NO_ANSWER"
|
||
}
|
||
return cfg
|
||
}
|
||
|
||
func (e *AutoReplyEngine) askAI(question string, hits []KnowledgeChunk, msg autoReplyMessage) (*AIResult, error) {
|
||
cfg := e.getConfig()
|
||
if strings.TrimSpace(cfg.AI.BaseURL) == "" {
|
||
return nil, fmt.Errorf("AI Base URL未配置")
|
||
}
|
||
if strings.TrimSpace(cfg.AI.Model) == "" {
|
||
return nil, fmt.Errorf("AI模型未配置")
|
||
}
|
||
systemPrompt := buildAutoReplySystemPrompt(cfg)
|
||
msg.ContextText = e.recentContextPrompt(msg, 6)
|
||
userPrompt := buildAutoReplyUserPrompt(question, hits, msg, cfg.ReplyPolicy.UnknownAnswerToken)
|
||
switch strings.ToLower(strings.TrimSpace(cfg.AI.Provider)) {
|
||
case "local", "ollama":
|
||
return callOllamaChat(cfg.AI, systemPrompt, userPrompt)
|
||
default:
|
||
return callOpenAICompatibleChat(cfg.AI, systemPrompt, userPrompt)
|
||
}
|
||
}
|
||
|
||
func (e *AutoReplyEngine) askGeneralAI(question string, msg autoReplyMessage) (*AIResult, error) {
|
||
cfg := e.getConfig()
|
||
if strings.TrimSpace(cfg.AI.BaseURL) == "" {
|
||
return nil, fmt.Errorf("AI Base URL未配置")
|
||
}
|
||
if strings.TrimSpace(cfg.AI.Model) == "" {
|
||
return nil, fmt.Errorf("AI模型未配置")
|
||
}
|
||
systemPrompt := buildGeneralAutoReplySystemPrompt(cfg)
|
||
msg.ContextText = e.recentContextPrompt(msg, 6)
|
||
userPrompt := buildGeneralAutoReplyUserPrompt(question, msg)
|
||
switch strings.ToLower(strings.TrimSpace(cfg.AI.Provider)) {
|
||
case "local", "ollama":
|
||
return callOllamaChat(cfg.AI, systemPrompt, userPrompt)
|
||
default:
|
||
return callOpenAICompatibleChat(cfg.AI, systemPrompt, userPrompt)
|
||
}
|
||
}
|
||
|
||
func (e *AutoReplyEngine) askNonTextAI(msg autoReplyMessage) (*AIResult, error) {
|
||
cfg := e.getConfig()
|
||
if strings.TrimSpace(cfg.AI.BaseURL) == "" {
|
||
return nil, fmt.Errorf("AI Base URL未配置")
|
||
}
|
||
if strings.TrimSpace(cfg.AI.Model) == "" {
|
||
return nil, fmt.Errorf("AI模型未配置")
|
||
}
|
||
systemPrompt := buildNonTextAutoReplySystemPrompt(cfg)
|
||
userPrompt := buildNonTextAutoReplyUserPrompt(msg)
|
||
switch strings.ToLower(strings.TrimSpace(cfg.AI.Provider)) {
|
||
case "local", "ollama":
|
||
return callOllamaChat(cfg.AI, systemPrompt, userPrompt)
|
||
default:
|
||
if mediaURL := strings.TrimSpace(msg.MediaURL); mediaURL != "" {
|
||
return callOpenAICompatibleVisionChat(cfg.AI, systemPrompt, userPrompt, mediaURL)
|
||
}
|
||
return callOpenAICompatibleChat(cfg.AI, systemPrompt, userPrompt)
|
||
}
|
||
}
|
||
|
||
func (e *AutoReplyEngine) testAIConnection() (*AIResult, error) {
|
||
testMsg := autoReplyMessage{
|
||
FromNickName: "测试客户",
|
||
ConversationID: "test",
|
||
}
|
||
hits := []KnowledgeChunk{{
|
||
Source: "test.md",
|
||
Content: "测试知识:自动客服连接测试时,请回复“连接正常”。",
|
||
Score: 1,
|
||
}}
|
||
return e.askAI("请回复连接正常", hits, testMsg)
|
||
}
|
||
|
||
func buildAutoReplySystemPrompt(cfg config.AutoReplyConfig) string {
|
||
token := cfg.ReplyPolicy.UnknownAnswerToken
|
||
if token == "" {
|
||
token = "NO_ANSWER"
|
||
}
|
||
return prependAISystemPrompt(cfg, "你是企业微信售后客服助手。只能根据提供的知识库片段回答客户问题。"+replyDetailInstruction(cfg)+"知识库不足以确定答案时,只输出 "+token+"。不要编造政策、价格、承诺、库存或物流时效。客户要求人工、投诉、退款、合同、发票、赔偿或价格特殊审批时,也只输出 "+token+"。")
|
||
}
|
||
|
||
func buildGeneralAutoReplySystemPrompt(cfg config.AutoReplyConfig) string {
|
||
token := cfg.ReplyPolicy.UnknownAnswerToken
|
||
if token == "" {
|
||
token = "NO_ANSWER"
|
||
}
|
||
return prependAISystemPrompt(cfg, "你是企业微信智能客服助手。请用中文自然、和蔼地回答普通问候、身份介绍和日常沟通问题。"+replyDetailInstruction(cfg)+"不要冒充真人,不要编造产品参数、价格、政策、库存、物流、合同、发票或售后结论。遇到需要公司专有资料、知识库、人工审批或无法确认的信息时,不要硬编,可以温和说明会按资料核对或请客户补充具体问题。不要输出 "+token+",除非客户明确要求停止回复。")
|
||
}
|
||
|
||
func buildNonTextAutoReplySystemPrompt(cfg config.AutoReplyConfig) string {
|
||
return prependAISystemPrompt(cfg, "你是企业微信客服岗位助手。用户发来非文本消息时,请根据消息类型和文字描述判断是否属于客服岗位可处理范围。范围内包括产品咨询、订单、售后、方案资料、使用问题、客户服务沟通;可回复时要自然、和蔼。"+replyDetailInstruction(cfg)+"不要编造图片里不存在的信息。若无法判断图片/表情内容,礼貌请客户补充文字说明。若明显超出客服岗位范围,只能回复:抱歉,你这问题超出我的岗位认知了,回答不了。不要主动转人工,除非客户明确要求人工。")
|
||
}
|
||
|
||
func buildVisionRecognitionSystemPrompt(cfg config.AutoReplyConfig) string {
|
||
return prependAISystemPrompt(cfg, "你是企业微信客服岗位的图片识别助手。请识别客户发来的图片/表情/封面中与客服沟通有关的内容,输出一句简洁中文描述;如果明显不是客服岗位可处理的内容,也请说明其大概内容。不要编造看不见的信息。")
|
||
}
|
||
|
||
func prependAISystemPrompt(cfg config.AutoReplyConfig, base string) string {
|
||
identity := strings.TrimSpace(cfg.AI.SystemPrompt)
|
||
if identity == "" {
|
||
identity = "你是一名企业微信智能客服。"
|
||
}
|
||
return identity + "\n" + base
|
||
}
|
||
|
||
func replyDetailInstruction(cfg config.AutoReplyConfig) string {
|
||
switch strings.ToLower(strings.TrimSpace(cfg.AI.ReplyDetail)) {
|
||
case "concise":
|
||
return "回复保持简洁,通常1-2句,约80-140个中文字符;先回答结论,必要时补一句下一步建议。"
|
||
case "medium":
|
||
return "回复详细程度适中,通常2-4句,约160-280个中文字符;先回答结论,再说明关键原因或注意事项,最后给出下一步建议。"
|
||
default:
|
||
return "回复尽量详细但不要啰嗦,通常3-6句,约280-500个中文字符;先明确回答客户问题,再结合可用资料说明关键点、适用场景或限制,最后给出具体下一步建议。"
|
||
}
|
||
}
|
||
|
||
func effectiveReplyMaxTokens(cfg config.AIConfig) int {
|
||
maxTokens := cfg.MaxTokens
|
||
switch strings.ToLower(strings.TrimSpace(cfg.ReplyDetail)) {
|
||
case "concise":
|
||
if maxTokens < 220 {
|
||
return 220
|
||
}
|
||
case "medium":
|
||
if maxTokens < 450 {
|
||
return 450
|
||
}
|
||
default:
|
||
if maxTokens < 700 {
|
||
return 700
|
||
}
|
||
}
|
||
return maxTokens
|
||
}
|
||
|
||
func buildGeneralAutoReplyUserPrompt(question string, msg autoReplyMessage) string {
|
||
var b strings.Builder
|
||
b.WriteString("客户昵称:")
|
||
if msg.FromNickName != "" {
|
||
b.WriteString(msg.FromNickName)
|
||
} else {
|
||
b.WriteString("未知")
|
||
}
|
||
b.WriteString("\n客户问题:")
|
||
b.WriteString(question)
|
||
if contextText := strings.TrimSpace(msg.ContextText); contextText != "" {
|
||
b.WriteString("\n\n最近对话上下文:\n")
|
||
b.WriteString(contextText)
|
||
}
|
||
b.WriteString("\n请直接给客户一条友好、可发送的回复。")
|
||
return b.String()
|
||
}
|
||
|
||
func buildNonTextAutoReplyUserPrompt(msg autoReplyMessage) string {
|
||
var b strings.Builder
|
||
b.WriteString("客户昵称:")
|
||
if msg.FromNickName != "" {
|
||
b.WriteString(msg.FromNickName)
|
||
} else {
|
||
b.WriteString("未知")
|
||
}
|
||
b.WriteString("\n消息类型:")
|
||
b.WriteString(msg.MessageType)
|
||
b.WriteString("\n原始类型:")
|
||
b.WriteString(fmt.Sprintf("%d", msg.RawType))
|
||
b.WriteString("\n消息描述:")
|
||
if strings.TrimSpace(msg.Content) != "" {
|
||
b.WriteString(msg.Content)
|
||
} else {
|
||
b.WriteString("无文字描述")
|
||
}
|
||
if strings.TrimSpace(msg.MediaURL) != "" {
|
||
b.WriteString("\n媒体地址:")
|
||
b.WriteString(msg.MediaURL)
|
||
}
|
||
b.WriteString("\n请直接给客户一条可发送的回复。")
|
||
return b.String()
|
||
}
|
||
|
||
func buildAutoReplyUserPrompt(question string, hits []KnowledgeChunk, msg autoReplyMessage, noAnswerToken string) string {
|
||
noAnswerToken = strings.TrimSpace(noAnswerToken)
|
||
if noAnswerToken == "" {
|
||
noAnswerToken = "NO_ANSWER"
|
||
}
|
||
var b strings.Builder
|
||
b.WriteString("客户昵称:")
|
||
if msg.FromNickName != "" {
|
||
b.WriteString(msg.FromNickName)
|
||
} else {
|
||
b.WriteString("未知")
|
||
}
|
||
b.WriteString("\n客户问题:")
|
||
b.WriteString(question)
|
||
if contextText := strings.TrimSpace(msg.ContextText); contextText != "" {
|
||
b.WriteString("\n\n最近对话上下文:\n")
|
||
b.WriteString(contextText)
|
||
}
|
||
b.WriteString("\n\n知识库片段:\n")
|
||
for i, hit := range compactKnowledgeHitsForAI(hits) {
|
||
b.WriteString(fmt.Sprintf("[%d] 来源:%s 分数:%.3f\n%s\n\n", i+1, hit.Source, hit.Score, hit.Content))
|
||
}
|
||
b.WriteString("\u53ea\u80fd\u4f7f\u7528\u4e0a\u9762\u7247\u6bb5\u4e2d\u660e\u786e\u51fa\u73b0\u7684\u4e8b\u5b9e\u56de\u7b54\uff1b\u5982\u679c\u8be2\u95ee\u90e8\u95e8\u3001\u4f1a\u8bae\u65f6\u95f4\u3001\u6807\u51c6\u6216\u89c4\u5b9a\uff0c\u53ea\u80fd\u5217\u51fa\u7247\u6bb5\u91cc\u76f4\u63a5\u51fa\u73b0\u7684\u503c\uff0c\u4e0d\u5f97\u6839\u636e\u5e38\u8bc6\u8865\u5145\u5176\u4ed6\u90e8\u95e8\u6216\u65f6\u95f4\u3002\n")
|
||
if isGenericProductQuery(question) {
|
||
b.WriteString("客户在泛问产品时,请优先按知识库列出具体产品或型号,每项用一句话说明定位,最后询问客户更关注硬件、模型还是AI应用。不要只概括为几大类。无法确认时只输出 ")
|
||
} else {
|
||
b.WriteString("请基于知识库片段回答客户。无法确认时只输出 ")
|
||
}
|
||
b.WriteString(noAnswerToken)
|
||
b.WriteString("。")
|
||
return b.String()
|
||
}
|
||
|
||
func compactKnowledgeHitsForAI(hits []KnowledgeChunk) []KnowledgeChunk {
|
||
if len(hits) == 0 {
|
||
return nil
|
||
}
|
||
limit := aiPromptMaxHits
|
||
if len(hits) < limit {
|
||
limit = len(hits)
|
||
}
|
||
result := make([]KnowledgeChunk, 0, limit)
|
||
totalRunes := 0
|
||
for i := 0; i < limit; i++ {
|
||
hit := hits[i]
|
||
content := strings.TrimSpace(hit.Content)
|
||
if content == "" {
|
||
continue
|
||
}
|
||
content = truncateTextForPrompt(content, aiPromptMaxChunkRunes)
|
||
remaining := aiPromptMaxContextRune - totalRunes
|
||
if remaining <= 0 {
|
||
break
|
||
}
|
||
if len([]rune(content)) > remaining {
|
||
content = truncateTextForPrompt(content, remaining)
|
||
}
|
||
hit.Content = content
|
||
totalRunes += len([]rune(content))
|
||
result = append(result, hit)
|
||
}
|
||
return result
|
||
}
|
||
|
||
// truncateTextForPrompt returns text limited to at most max runes. A
// non-positive max yields the empty string, and text that already fits is
// returned unchanged.
func truncateTextForPrompt(text string, max int) string {
	if max <= 0 {
		return ""
	}
	if rs := []rune(text); len(rs) > max {
		return string(rs[:max])
	}
	return text
}
|
||
|
||
func callOpenAICompatibleChat(cfg config.AIConfig, systemPrompt string, userPrompt string) (*AIResult, error) {
|
||
url := strings.TrimRight(cfg.BaseURL, "/")
|
||
if !strings.HasSuffix(url, "/chat/completions") {
|
||
url += "/chat/completions"
|
||
}
|
||
payload := map[string]interface{}{
|
||
"model": cfg.Model,
|
||
"temperature": cfg.Temperature,
|
||
"max_tokens": effectiveReplyMaxTokens(cfg),
|
||
"enable_thinking": cfg.EnableThinking,
|
||
"messages": []map[string]string{
|
||
{"role": "system", "content": systemPrompt},
|
||
{"role": "user", "content": userPrompt},
|
||
},
|
||
}
|
||
var response struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
Error interface{} `json:"error"`
|
||
}
|
||
result, err := doAIJSONRequest(cfg, url, payload, &response)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if response.Error != nil {
|
||
return nil, fmt.Errorf("AI返回错误: %v", response.Error)
|
||
}
|
||
if len(response.Choices) == 0 {
|
||
return nil, fmt.Errorf("AI返回空choices")
|
||
}
|
||
answer := strings.TrimSpace(response.Choices[0].Message.Content)
|
||
result.Answer = answer
|
||
result.RawSummary = truncateText(answer, 160)
|
||
return result, nil
|
||
}
|
||
|
||
func callOpenAICompatibleVisionChat(cfg config.AIConfig, systemPrompt string, userPrompt string, imageURL string) (*AIResult, error) {
|
||
visionCfg := visionRequestConfig(cfg)
|
||
url := strings.TrimRight(visionCfg.BaseURL, "/")
|
||
if !strings.HasSuffix(url, "/chat/completions") {
|
||
url += "/chat/completions"
|
||
}
|
||
payload := map[string]interface{}{
|
||
"model": visionCfg.Model,
|
||
"temperature": visionCfg.Temperature,
|
||
"max_tokens": visionCfg.MaxTokens,
|
||
"enable_thinking": visionCfg.EnableThinking,
|
||
"messages": []map[string]interface{}{
|
||
{"role": "system", "content": systemPrompt},
|
||
{
|
||
"role": "user",
|
||
"content": []map[string]interface{}{
|
||
{"type": "text", "text": userPrompt},
|
||
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
var response struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
Error interface{} `json:"error"`
|
||
}
|
||
result, err := doAIJSONRequest(visionCfg, url, payload, &response)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if response.Error != nil {
|
||
return nil, fmt.Errorf("AI返回错误: %v", response.Error)
|
||
}
|
||
if len(response.Choices) == 0 {
|
||
return nil, fmt.Errorf("AI返回空choices")
|
||
}
|
||
answer := strings.TrimSpace(response.Choices[0].Message.Content)
|
||
result.Answer = answer
|
||
result.RawSummary = truncateText(answer, 160)
|
||
return result, nil
|
||
}
|
||
|
||
func visionRequestConfig(cfg config.AIConfig) config.AIConfig {
|
||
visionCfg := cfg
|
||
visionCfg.Model = fallbackString(cfg.VisionModel, cfg.Model)
|
||
if strings.TrimSpace(cfg.VisionBaseURL) != "" {
|
||
visionCfg.BaseURL = strings.TrimSpace(cfg.VisionBaseURL)
|
||
}
|
||
visionKey := strings.TrimSpace(cfg.VisionAPIKey)
|
||
if visionKey != "" && !looksLikeURL(visionKey) {
|
||
visionCfg.APIKey = visionKey
|
||
}
|
||
return visionCfg
|
||
}
|
||
|
||
func callOpenAICompatibleAudioChatTranscription(cfg config.AIConfig, audioPath string) (string, error) {
|
||
audioCfg := audioRequestConfig(cfg)
|
||
audioDataURL, err := audioDataURLFromFile(audioPath)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
url := strings.TrimRight(audioCfg.BaseURL, "/")
|
||
if !strings.HasSuffix(url, "/chat/completions") {
|
||
url += "/chat/completions"
|
||
}
|
||
model := fallbackString(audioCfg.Model, defaultAudioModel)
|
||
payload := map[string]interface{}{
|
||
"model": model,
|
||
"temperature": 0,
|
||
"max_tokens": audioCfg.MaxTokens,
|
||
"enable_thinking": false,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "user",
|
||
"content": audioChatContentForModel(model, audioDataURL),
|
||
},
|
||
},
|
||
}
|
||
var response struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
Error interface{} `json:"error"`
|
||
}
|
||
if _, err := doAIJSONRequest(audioCfg, url, payload, &response); err != nil {
|
||
return "", fmt.Errorf("audio chat transcription failed (model=%s endpoint=%s): %w", audioCfg.Model, url, err)
|
||
}
|
||
if response.Error != nil {
|
||
return "", fmt.Errorf("audio chat transcription failed (model=%s endpoint=%s): %v", audioCfg.Model, url, response.Error)
|
||
}
|
||
if len(response.Choices) == 0 {
|
||
return "", fmt.Errorf("audio chat transcription failed (model=%s endpoint=%s): empty choices", audioCfg.Model, url)
|
||
}
|
||
text := strings.TrimSpace(response.Choices[0].Message.Content)
|
||
if text == "" {
|
||
return "", fmt.Errorf("audio chat transcription failed (model=%s endpoint=%s): empty text", audioCfg.Model, url)
|
||
}
|
||
return text, nil
|
||
}
|
||
|
||
func audioChatContentForModel(model string, audioDataURL string) []map[string]interface{} {
|
||
if isQwenASRModel(model) {
|
||
return []map[string]interface{}{
|
||
{"type": "input_audio", "input_audio": audioDataURL},
|
||
}
|
||
}
|
||
return []map[string]interface{}{
|
||
{"type": "text", "text": "请把这段语音转写成简体中文文本,只输出转写内容,不要解释。"},
|
||
{"type": "input_audio", "input_audio": map[string]interface{}{"data": audioDataURL}},
|
||
}
|
||
}
|
||
|
||
// isQwenASRModel reports whether model, after trimming and lower-casing,
// starts with "qwen3-asr" or "qwen-asr".
func isQwenASRModel(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	for _, prefix := range []string{"qwen3-asr", "qwen-asr"} {
		if strings.HasPrefix(normalized, prefix) {
			return true
		}
	}
	return false
}
|
||
|
||
func audioRequestConfig(cfg config.AIConfig) config.AIConfig {
|
||
audioCfg := cfg
|
||
audioCfg.Model = fallbackString(cfg.AudioModel, defaultAudioModel)
|
||
if strings.TrimSpace(cfg.AudioBaseURL) != "" {
|
||
audioCfg.BaseURL = strings.TrimSpace(cfg.AudioBaseURL)
|
||
}
|
||
audioKey := strings.TrimSpace(cfg.AudioAPIKey)
|
||
if audioKey != "" && !looksLikeURL(audioKey) {
|
||
audioCfg.APIKey = audioKey
|
||
}
|
||
audioCfg.EnableThinking = false
|
||
audioCfg.Temperature = 0
|
||
return audioCfg
|
||
}
|
||
|
||
func audioConfigWarning(cfg config.AIConfig) string {
|
||
if looksLikeURL(strings.TrimSpace(cfg.AudioAPIKey)) {
|
||
return "语音 API Key 误填为 URL,已忽略该值并复用主 API Key"
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func inferAudioMode(cfg config.AIConfig) string {
|
||
mode := normalizeAudioMode(cfg.AudioMode)
|
||
if mode != audioModeAuto {
|
||
return mode
|
||
}
|
||
provider := normalizeAudioMode(cfg.AudioProvider)
|
||
if provider != audioModeAuto {
|
||
return provider
|
||
}
|
||
model := strings.ToLower(strings.TrimSpace(cfg.AudioModel))
|
||
if strings.HasPrefix(model, "paraformer") {
|
||
return audioModeParaformer
|
||
}
|
||
if strings.Contains(model, "whisper") || strings.Contains(model, "transcribe") {
|
||
return audioModeTranscription
|
||
}
|
||
return audioModeOpenAIChat
|
||
}
|
||
|
||
func normalizeAudioMode(value string) string {
|
||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||
case "", audioModeAuto:
|
||
return audioModeAuto
|
||
case "openai", "openai_chat", "audio_chat", "qwen_audio", "qwen3_asr", audioModeOpenAIChat:
|
||
return audioModeOpenAIChat
|
||
case "dashscope", "paraformer", audioModeParaformer:
|
||
return audioModeParaformer
|
||
case "transcription", "openai_transcription", "local", "local_asr", audioModeTranscription:
|
||
return audioModeTranscription
|
||
case "custom", audioModeCustomHTTP:
|
||
return audioModeCustomHTTP
|
||
default:
|
||
return audioModeAuto
|
||
}
|
||
}
|
||
|
||
// looksLikeURL reports whether value, after trimming, starts with "http://"
// or "https://" in any letter case.
func looksLikeURL(value string) bool {
	lowered := strings.ToLower(strings.TrimSpace(value))
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(lowered, scheme) {
			return true
		}
	}
	return false
}
|
||
|
||
func supportsSilkDirectly(cfg config.AIConfig) bool {
|
||
model := strings.ToLower(strings.TrimSpace(cfg.AudioModel))
|
||
mode := inferAudioMode(cfg)
|
||
if mode == audioModeParaformer || mode == audioModeTranscription || mode == audioModeCustomHTTP {
|
||
return false
|
||
}
|
||
return strings.Contains(model, "silk")
|
||
}
|
||
|
||
func dashScopeAPIBaseURL(cfg config.AIConfig) string {
|
||
base := strings.TrimSpace(cfg.AudioBaseURL)
|
||
if base == "" {
|
||
base = strings.TrimSpace(cfg.BaseURL)
|
||
}
|
||
if base == "" || strings.Contains(base, "/compatible-mode/") {
|
||
return "https://dashscope.aliyuncs.com/api/v1"
|
||
}
|
||
base = strings.TrimRight(base, "/")
|
||
if strings.HasSuffix(base, "/services/audio/asr/transcription") {
|
||
return strings.TrimSuffix(base, "/services/audio/asr/transcription")
|
||
}
|
||
if strings.Contains(base, "/api/v1/") {
|
||
return strings.Split(base, "/api/v1/")[0] + "/api/v1"
|
||
}
|
||
if strings.HasSuffix(base, "/api/v1") {
|
||
return base
|
||
}
|
||
return base
|
||
}
|
||
|
||
func callOpenAICompatibleAudioTranscription(cfg config.AIConfig, audioPath string) (string, error) {
|
||
cfg = audioRequestConfig(cfg)
|
||
url := strings.TrimRight(cfg.BaseURL, "/")
|
||
if !strings.HasSuffix(url, "/audio/transcriptions") {
|
||
url += "/audio/transcriptions"
|
||
}
|
||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 20 * time.Second
|
||
}
|
||
file, err := os.Open(audioPath)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer file.Close()
|
||
body := &bytes.Buffer{}
|
||
writer := multipart.NewWriter(body)
|
||
if err := writer.WriteField("model", cfg.Model); err != nil {
|
||
return "", err
|
||
}
|
||
part, err := writer.CreateFormFile("file", filepath.Base(audioPath))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if _, err := io.Copy(part, file); err != nil {
|
||
return "", err
|
||
}
|
||
if err := writer.Close(); err != nil {
|
||
return "", err
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, body)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||
if strings.TrimSpace(cfg.APIKey) != "" {
|
||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(cfg.APIKey))
|
||
}
|
||
resp, err := (&http.Client{Timeout: timeout}).Do(req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer resp.Body.Close()
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
return "", fmt.Errorf("audio transcription failed (model=%s endpoint=%s): HTTP status %d, body=%s", cfg.Model, url, resp.StatusCode, truncateText(string(respBody), 240))
|
||
}
|
||
var parsed struct {
|
||
Text string `json:"text"`
|
||
Error interface{} `json:"error"`
|
||
}
|
||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||
return "", fmt.Errorf("parse audio transcription failed (model=%s endpoint=%s): %v, body=%s", cfg.Model, url, err, truncateText(string(respBody), 240))
|
||
}
|
||
if parsed.Error != nil {
|
||
return "", fmt.Errorf("audio transcription failed (model=%s endpoint=%s): %v", cfg.Model, url, parsed.Error)
|
||
}
|
||
text := strings.TrimSpace(parsed.Text)
|
||
if text == "" {
|
||
return "", fmt.Errorf("audio transcription failed (model=%s endpoint=%s): empty text", cfg.Model, url)
|
||
}
|
||
return text, nil
|
||
}
|
||
|
||
func callDashScopeParaformerTranscription(cfg config.AIConfig, fileURL string) (string, error) {
|
||
cfg = audioRequestConfig(cfg)
|
||
fileURL = strings.TrimSpace(fileURL)
|
||
if fileURL == "" {
|
||
return "", fmt.Errorf("paraformer transcription failed (model=%s): 需要公网可访问的音频 URL,本地文件不能直接提交给 Paraformer RESTful 接口", cfg.Model)
|
||
}
|
||
parsedURL, err := url.Parse(fileURL)
|
||
if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https" && parsedURL.Scheme != "oss") {
|
||
return "", fmt.Errorf("paraformer transcription failed (model=%s): 音频 URL 无效", cfg.Model)
|
||
}
|
||
base := dashScopeAPIBaseURL(cfg)
|
||
submitURL := strings.TrimRight(base, "/") + "/services/audio/asr/transcription"
|
||
payload := map[string]interface{}{
|
||
"model": fallbackString(cfg.Model, "paraformer-v2"),
|
||
"input": map[string]interface{}{
|
||
"file_urls": []string{fileURL},
|
||
},
|
||
"parameters": map[string]interface{}{
|
||
"channel_id": []int{0},
|
||
"language_hints": []string{"zh", "en"},
|
||
},
|
||
}
|
||
var submitResp struct {
|
||
Output struct {
|
||
TaskID string `json:"task_id"`
|
||
TaskStatus string `json:"task_status"`
|
||
} `json:"output"`
|
||
Code string `json:"code"`
|
||
Message string `json:"message"`
|
||
}
|
||
if err := doDashScopeJSONRequest(cfg, submitURL, "POST", payload, true, &submitResp); err != nil {
|
||
return "", fmt.Errorf("paraformer transcription submit failed (model=%s endpoint=%s): %w", cfg.Model, submitURL, err)
|
||
}
|
||
if submitResp.Code != "" || submitResp.Message != "" {
|
||
return "", fmt.Errorf("paraformer transcription submit failed (model=%s endpoint=%s): %s %s", cfg.Model, submitURL, submitResp.Code, submitResp.Message)
|
||
}
|
||
taskID := strings.TrimSpace(submitResp.Output.TaskID)
|
||
if taskID == "" {
|
||
return "", fmt.Errorf("paraformer transcription submit failed (model=%s endpoint=%s): empty task_id", cfg.Model, submitURL)
|
||
}
|
||
return waitDashScopeParaformerTask(cfg, base, taskID)
|
||
}
|
||
|
||
func waitDashScopeParaformerTask(cfg config.AIConfig, base string, taskID string) (string, error) {
|
||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 20 * time.Second
|
||
}
|
||
deadline := time.Now().Add(timeout)
|
||
queryURL := strings.TrimRight(base, "/") + "/tasks/" + url.PathEscape(taskID)
|
||
var lastStatus string
|
||
for time.Now().Before(deadline) {
|
||
var queryResp struct {
|
||
Output struct {
|
||
TaskStatus string `json:"task_status"`
|
||
Results []struct {
|
||
FileURL string `json:"file_url"`
|
||
TranscriptionURL string `json:"transcription_url"`
|
||
SubtaskStatus string `json:"subtask_status"`
|
||
Code string `json:"code"`
|
||
Message string `json:"message"`
|
||
} `json:"results"`
|
||
} `json:"output"`
|
||
Code string `json:"code"`
|
||
Message string `json:"message"`
|
||
}
|
||
if err := doDashScopeJSONRequest(cfg, queryURL, "GET", nil, false, &queryResp); err != nil {
|
||
return "", fmt.Errorf("paraformer transcription query failed (model=%s endpoint=%s task=%s): %w", cfg.Model, queryURL, taskID, err)
|
||
}
|
||
if queryResp.Code != "" || queryResp.Message != "" {
|
||
return "", fmt.Errorf("paraformer transcription query failed (model=%s endpoint=%s task=%s): %s %s", cfg.Model, queryURL, taskID, queryResp.Code, queryResp.Message)
|
||
}
|
||
lastStatus = strings.ToUpper(strings.TrimSpace(queryResp.Output.TaskStatus))
|
||
switch lastStatus {
|
||
case "SUCCEEDED":
|
||
for _, result := range queryResp.Output.Results {
|
||
if strings.EqualFold(result.SubtaskStatus, "SUCCEEDED") && strings.TrimSpace(result.TranscriptionURL) != "" {
|
||
return downloadDashScopeTranscriptionResult(cfg, result.TranscriptionURL)
|
||
}
|
||
if result.Code != "" || result.Message != "" {
|
||
return "", fmt.Errorf("paraformer transcription subtask failed (model=%s task=%s): %s %s", cfg.Model, taskID, result.Code, result.Message)
|
||
}
|
||
}
|
||
return "", fmt.Errorf("paraformer transcription finished without usable result (model=%s task=%s)", cfg.Model, taskID)
|
||
case "FAILED", "CANCELED", "UNKNOWN":
|
||
return "", fmt.Errorf("paraformer transcription task failed (model=%s task=%s status=%s)", cfg.Model, taskID, lastStatus)
|
||
}
|
||
time.Sleep(500 * time.Millisecond)
|
||
}
|
||
return "", fmt.Errorf("paraformer transcription timed out (model=%s task=%s last_status=%s)", cfg.Model, taskID, lastStatus)
|
||
}
|
||
|
||
func downloadDashScopeTranscriptionResult(cfg config.AIConfig, resultURL string) (string, error) {
|
||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 20 * time.Second
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
req, err := http.NewRequestWithContext(ctx, "GET", resultURL, nil)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
resp, err := (&http.Client{Timeout: timeout}).Do(req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer resp.Body.Close()
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
return "", fmt.Errorf("download paraformer result failed: HTTP status %d, body=%s", resp.StatusCode, truncateText(string(respBody), 240))
|
||
}
|
||
var parsed struct {
|
||
Transcripts []struct {
|
||
Text string `json:"text"`
|
||
} `json:"transcripts"`
|
||
}
|
||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||
return "", fmt.Errorf("parse paraformer result failed: %v, body=%s", err, truncateText(string(respBody), 240))
|
||
}
|
||
parts := make([]string, 0, len(parsed.Transcripts))
|
||
for _, transcript := range parsed.Transcripts {
|
||
if text := strings.TrimSpace(transcript.Text); text != "" {
|
||
parts = append(parts, text)
|
||
}
|
||
}
|
||
text := strings.TrimSpace(strings.Join(parts, "\n"))
|
||
if text == "" {
|
||
return "", fmt.Errorf("paraformer result returned empty text")
|
||
}
|
||
return text, nil
|
||
}
|
||
|
||
func doDashScopeJSONRequest(cfg config.AIConfig, endpoint string, method string, payload interface{}, async bool, out interface{}) error {
|
||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 20 * time.Second
|
||
}
|
||
var body io.Reader
|
||
if payload != nil {
|
||
data, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
body = bytes.NewBuffer(data)
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if payload != nil {
|
||
req.Header.Set("Content-Type", "application/json")
|
||
}
|
||
if async {
|
||
req.Header.Set("X-DashScope-Async", "enable")
|
||
}
|
||
if strings.TrimSpace(cfg.APIKey) != "" {
|
||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(cfg.APIKey))
|
||
}
|
||
resp, err := (&http.Client{Timeout: timeout}).Do(req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
return fmt.Errorf("HTTP status %d, body=%s", resp.StatusCode, truncateText(string(respBody), 240))
|
||
}
|
||
if err := json.Unmarshal(respBody, out); err != nil {
|
||
return fmt.Errorf("parse response failed: %v, body=%s", err, truncateText(string(respBody), 240))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func callOllamaChat(cfg config.AIConfig, systemPrompt string, userPrompt string) (*AIResult, error) {
|
||
url := strings.TrimRight(cfg.BaseURL, "/")
|
||
if !strings.HasSuffix(url, "/api/chat") {
|
||
url += "/api/chat"
|
||
}
|
||
payload := map[string]interface{}{
|
||
"model": cfg.Model,
|
||
"stream": false,
|
||
"messages": []map[string]string{
|
||
{"role": "system", "content": systemPrompt},
|
||
{"role": "user", "content": userPrompt},
|
||
},
|
||
"options": map[string]interface{}{
|
||
"temperature": cfg.Temperature,
|
||
"num_predict": effectiveReplyMaxTokens(cfg),
|
||
},
|
||
}
|
||
var response struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
Response string `json:"response"`
|
||
Error string `json:"error"`
|
||
}
|
||
result, err := doAIJSONRequest(cfg, url, payload, &response)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if response.Error != "" {
|
||
return nil, fmt.Errorf("本地模型返回错误: %s", response.Error)
|
||
}
|
||
answer := strings.TrimSpace(response.Message.Content)
|
||
if answer == "" {
|
||
answer = strings.TrimSpace(response.Response)
|
||
}
|
||
if answer == "" {
|
||
return nil, fmt.Errorf("本地模型返回空内容")
|
||
}
|
||
result.Answer = answer
|
||
result.RawSummary = truncateText(answer, 160)
|
||
return result, nil
|
||
}
|
||
|
||
func doAIJSONRequest(cfg config.AIConfig, url string, payload interface{}, out interface{}) (*AIResult, error) {
|
||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 20 * time.Second
|
||
}
|
||
start := time.Now()
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
body, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(body))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
if strings.TrimSpace(cfg.APIKey) != "" {
|
||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(cfg.APIKey))
|
||
}
|
||
client := &http.Client{Timeout: timeout}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
return nil, fmt.Errorf("AI HTTP状态码错误: %d, body=%s", resp.StatusCode, truncateText(string(respBody), 240))
|
||
}
|
||
if err := json.Unmarshal(respBody, out); err != nil {
|
||
return nil, fmt.Errorf("解析AI响应失败: %v, body=%s", err, truncateText(string(respBody), 240))
|
||
}
|
||
return &AIResult{DurationMS: time.Since(start).Milliseconds()}, nil
|
||
}
|