Files
webp/models/milvus.go
2024-07-31 18:18:03 +08:00

139 lines
3.3 KiB
Go

package models
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/spf13/viper"
lru "github.com/hashicorp/golang-lru/v2"
)
// MilvusConnection wraps the client used to talk to the Milvus vector
// database. The zero value has a nil Client; Init populates it.
type MilvusConnection struct {
	// Client is the Milvus client created by Init.
	Client client.Client
}

// GetClient returns the Milvus client stored on the connection, which is nil
// until Init has assigned it.
func (m *MilvusConnection) GetClient() client.Client {
	c := m.Client
	return c
}
// Init dials the Milvus server addressed by the "milvus.host" and
// "milvus.port" configuration keys and stores the client in m.Client.
// A dial error is logged and returned; nil is returned on success.
//
// NOTE(review): os.Setenv replaces any existing NO_PROXY value instead of
// appending to it, and its error is ignored — confirm this is intended.
func (m *MilvusConnection) Init() (err error) {
	log.Println("Milvus connection init")

	host := config.GetString("milvus.host")
	// Exclude the Milvus host from proxying (presumably so the gRPC dial is
	// not routed through a proxy configured in the environment).
	os.Setenv("NO_PROXY", host)

	address := fmt.Sprintf("%s:%d", host, config.GetInt("milvus.port"))
	if m.Client, err = client.NewGrpcClient(context.Background(), address); err != nil {
		log.Println("Milvus connection failed:", err)
		return err
	}

	log.Println("Milvus connection success")
	return nil
}
var (
	// milvusConnection is the package-wide Milvus connection; its Client is
	// nil until Init is called.
	milvusConnection MilvusConnection

	// lruCache memoises similar-image ID lists keyed by image ID, holding at
	// most 100000 entries. golang-lru's New only reports an error for a
	// non-positive size, so with this constant size the error is dropped.
	lruCache, _ = lru.New[int, []int64](100000)
)
// GetNetWorkEmbedding fetches the image feature vector for id from the
// embedding HTTP service located by the "embedding.host" and
// "embedding.port" viper settings.
//
// It issues PUT http://<host>:<port>/api/default/<id> with no body and decodes
// a JSON reply carrying "code", "message" and "feature" fields. Any failure —
// building the request, the transport, a non-2xx status, an undecodable body,
// or a non-zero "code" — is logged and yields a nil slice.
//
// NOTE(review): http.DefaultClient has no timeout, so a stalled embedding
// service blocks the caller indefinitely; consider a dedicated client with
// Timeout set.
func GetNetWorkEmbedding(id int) (embedding []float32) {
	host := viper.GetString("embedding.host")
	port := viper.GetInt("embedding.port")
	url := fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id)

	request, err := http.NewRequest(http.MethodPut, url, nil)
	if err != nil {
		log.Println("请求失败1:", err)
		return nil
	}
	response, err := http.DefaultClient.Do(request)
	if err != nil {
		log.Println("请求失败2:", err)
		return nil
	}
	defer response.Body.Close()

	// Reject error statuses up front so an error page is never decoded and
	// mistaken for a (possibly empty) successful reply.
	if response.StatusCode < 200 || response.StatusCode >= 300 {
		log.Println("请求失败4:", response.Status)
		return nil
	}

	var reply struct {
		Code    int       `json:"code"`
		Message string    `json:"message"`
		Feature []float32 `json:"feature"`
	}
	if err := json.NewDecoder(response.Body).Decode(&reply); err != nil {
		log.Println("解析失败:", err)
		return nil
	}
	if reply.Code != 0 {
		log.Println("请求失败3:", reply.Message)
		return nil
	}
	return reply.Feature
}
// GetSimilarImagesIdList returns the IDs of the images nearest (by L2
// distance) to image id in the "default" Milvus collection.
//
// topK bounds the number of IDs returned; a value <= 0 returns every hit of
// the search. At least 200 neighbours are always searched, so the list for
// any topK <= 200 is a prefix of the same cached result.
//
// The embedding of id is read from Milvus; when that query fails or yields no
// vector, it is regenerated with GetNetWorkEmbedding. Search results are
// memoised in lruCache keyed by id. Failures are logged and return nil.
//
// NOTE(review): the lazy initialisation of milvusConnection below is not
// synchronised, so concurrent first calls can race.
func GetSimilarImagesIdList(id int, topK int) (ids []int64) {
	const (
		collectionName = "default"
		defaultTopK    = 200
		nprobe         = 64
	)

	// Search at least defaultTopK neighbours so one cached list serves every
	// topK up to that size.
	searchK := defaultTopK
	if topK > searchK {
		searchK = topK
	}
	// firstK trims a result list to topK entries. When it trims, it also caps
	// the capacity so a caller appending to the result cannot write into the
	// cached backing array.
	firstK := func(all []int64) []int64 {
		if topK > 0 && len(all) > topK {
			return all[:topK:topK]
		}
		return all
	}

	// Entries stored below come from a search of at least defaultTopK
	// neighbours, so they are complete for topK <= defaultTopK; larger
	// requests need a cached list that is already long enough.
	if cached, ok := lruCache.Get(id); ok && (topK <= defaultTopK || len(cached) >= topK) {
		return firstK(cached)
	}

	ctx := context.Background()

	// Connect lazily on first use.
	if milvusConnection.Client == nil {
		if err := milvusConnection.Init(); err != nil {
			log.Println("Milvus 初始化失败:", err)
			return nil
		}
	}

	// Fetch the stored embedding of id from Milvus.
	var embedding []float32
	columns, err := milvusConnection.Client.Query(ctx, collectionName, nil, fmt.Sprintf("id in [%d]", id), []string{"embedding"})
	if err != nil {
		log.Println("查詢向量失敗:", err)
	} else {
		for _, column := range columns {
			if column.Name() != "embedding" {
				continue
			}
			// Nil-safe getters: a column carrying no float vector yields nil
			// here instead of panicking on a nil pointer dereference.
			embedding = column.FieldData().GetVectors().GetFloatVector().GetData()
			break
		}
	}

	// Missing or unreadable in Milvus: regenerate it via the embedding API.
	if len(embedding) == 0 {
		embedding = GetNetWorkEmbedding(id)
	}
	if len(embedding) == 0 {
		log.Println("向量不存在, 也未能重新生成")
		return nil
	}

	// Search for the nearest neighbours of the embedding.
	searchParam, err := entity.NewIndexIvfFlatSearchParam(nprobe)
	if err != nil {
		log.Println("搜索相似失敗:", err)
		return nil
	}
	vectors := []entity.Vector{entity.FloatVector(embedding)}
	results, err := milvusConnection.Client.Search(ctx, collectionName, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, searchK, searchParam)
	if err != nil {
		log.Println("搜索相似失敗:", err)
		return nil
	}
	// Exactly one query vector was sent, so normally there is a single result
	// set; if several arrive, the last non-empty one wins.
	for _, res := range results {
		if res.IDs == nil {
			continue
		}
		ids = res.IDs.FieldData().GetScalars().GetLongData().GetData()
	}

	lruCache.Add(id, ids)
	return firstK(ids)
}