139 lines
3.3 KiB
Go
139 lines
3.3 KiB
Go
package models
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	lru "github.com/hashicorp/golang-lru/v2"
	"github.com/milvus-io/milvus-sdk-go/v2/client"
	"github.com/milvus-io/milvus-sdk-go/v2/entity"
	"github.com/spf13/viper"
)
|
|
|
|
type MilvusConnection struct {
|
|
Client client.Client
|
|
}
|
|
|
|
func (m *MilvusConnection) GetClient() client.Client {
|
|
return m.Client
|
|
}
|
|
|
|
func (m *MilvusConnection) Init() (err error) {
|
|
log.Println("Milvus connection init")
|
|
os.Setenv("NO_PROXY", config.GetString("milvus.host"))
|
|
m.Client, err = client.NewGrpcClient(context.Background(), fmt.Sprintf(
|
|
"%s:%d",
|
|
config.GetString("milvus.host"),
|
|
config.GetInt("milvus.port"),
|
|
))
|
|
if err != nil {
|
|
log.Println("Milvus connection failed:", err)
|
|
return
|
|
}
|
|
log.Println("Milvus connection success")
|
|
return
|
|
}
|
|
|
|
var milvusConnection MilvusConnection
|
|
var lruCache, _ = lru.New[int, []int64](100000)
|
|
|
|
// 获取图像向量(从指定API)
|
|
func GetNetWorkEmbedding(id int) (embedding []float32) {
|
|
host := viper.GetString("embedding.host")
|
|
port := viper.GetInt("embedding.port")
|
|
httpClient := &http.Client{}
|
|
request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id), nil)
|
|
if err != nil {
|
|
log.Println("请求失败1:", err)
|
|
return
|
|
}
|
|
response, err := httpClient.Do(request)
|
|
if err != nil {
|
|
log.Println("请求失败2:", err)
|
|
return
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
var result struct {
|
|
Code int `json:"code"`
|
|
Message string `json:"message"`
|
|
Feature []float32 `json:"feature"`
|
|
}
|
|
err = json.NewDecoder(response.Body).Decode(&result)
|
|
if err != nil {
|
|
log.Println("解析失败:", err)
|
|
return
|
|
}
|
|
if result.Code != 0 {
|
|
log.Println("请求失败3:", result.Message)
|
|
return
|
|
}
|
|
return result.Feature
|
|
}
|
|
|
|
// 获取相似图像ID列
|
|
func GetSimilarImagesIdList(id int, topK int) (ids []int64) {
|
|
// 从缓存中获取
|
|
if cache, ok := lruCache.Get(id); ok {
|
|
return cache
|
|
}
|
|
|
|
// 从Milvus中获取向量
|
|
var collection_name = "default"
|
|
var embedding []float32
|
|
var ctx = context.Background()
|
|
|
|
// 先检查 milvusConnection 是否已经初始化
|
|
if milvusConnection.Client == nil {
|
|
err := milvusConnection.Init()
|
|
if err != nil {
|
|
log.Println("Milvus 初始化失败:", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", id), []string{"embedding"})
|
|
if err != nil {
|
|
log.Println("查詢向量失敗:", err)
|
|
embedding = GetNetWorkEmbedding(id)
|
|
} else {
|
|
for _, item := range result {
|
|
if item.Name() == "embedding" {
|
|
embedding = item.FieldData().GetVectors().GetFloatVector().Data
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// 处理向量不存在的情况
|
|
if len(embedding) == 0 {
|
|
log.Println("向量不存在, 也未能重新生成")
|
|
return ids
|
|
}
|
|
|
|
// 用向量查询相似图片
|
|
topk := 200
|
|
sp, _ := entity.NewIndexIvfFlatSearchParam(64)
|
|
vectors := []entity.Vector{entity.FloatVector(embedding)}
|
|
resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
|
|
if err != nil {
|
|
log.Println("搜索相似失敗:", err)
|
|
return ids
|
|
}
|
|
|
|
// 输出结果
|
|
for _, item := range resultx {
|
|
ids = item.IDs.FieldData().GetScalars().GetLongData().GetData()
|
|
}
|
|
|
|
// 将结果缓存到 LRU 中
|
|
lruCache.Add(id, ids)
|
|
|
|
return ids
|
|
}
|