支持相似图像查询
This commit is contained in:
		
							
								
								
									
										104
									
								
								models/milvus.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								models/milvus.go
									
									
									
									
									
								
							@@ -2,11 +2,17 @@ package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"github.com/milvus-io/milvus-sdk-go/v2/client"
 | 
			
		||||
	"github.com/milvus-io/milvus-sdk-go/v2/entity"
 | 
			
		||||
	"github.com/spf13/viper"
 | 
			
		||||
 | 
			
		||||
	lru "github.com/hashicorp/golang-lru/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MilvusConnection struct {
 | 
			
		||||
@@ -32,3 +38,101 @@ func (m *MilvusConnection) Init() (err error) {
 | 
			
		||||
	log.Println("Milvus connection success")
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var milvusConnection MilvusConnection
 | 
			
		||||
var lruCache, _ = lru.New[int, []int64](100000)
 | 
			
		||||
 | 
			
		||||
// 获取图像向量(从指定API)
 | 
			
		||||
func GetNetWorkEmbedding(id int) (embedding []float32) {
 | 
			
		||||
	host := viper.GetString("embedding.host")
 | 
			
		||||
	port := viper.GetInt("embedding.port")
 | 
			
		||||
	httpClient := &http.Client{}
 | 
			
		||||
	request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id), nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("请求失败1:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	response, err := httpClient.Do(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("请求失败2:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer response.Body.Close()
 | 
			
		||||
 | 
			
		||||
	var result struct {
 | 
			
		||||
		Code    int       `json:"code"`
 | 
			
		||||
		Message string    `json:"message"`
 | 
			
		||||
		Feature []float32 `json:"feature"`
 | 
			
		||||
	}
 | 
			
		||||
	err = json.NewDecoder(response.Body).Decode(&result)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("解析失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if result.Code != 0 {
 | 
			
		||||
		log.Println("请求失败3:", result.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return result.Feature
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取相似图像ID列
 | 
			
		||||
func GetSimilarImagesIdList(id int, topK int) (ids []int64) {
 | 
			
		||||
	// 从缓存中获取
 | 
			
		||||
	if cache, ok := lruCache.Get(id); ok {
 | 
			
		||||
		return cache
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 从Milvus中获取向量
 | 
			
		||||
	var collection_name = "default"
 | 
			
		||||
	var embedding []float32
 | 
			
		||||
	var ctx = context.Background()
 | 
			
		||||
 | 
			
		||||
	// 先检查 milvusConnection 是否已经初始化
 | 
			
		||||
	if milvusConnection.Client == nil {
 | 
			
		||||
		err := milvusConnection.Init()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println("Milvus 初始化失败:", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", id), []string{"embedding"})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("查詢向量失敗:", err)
 | 
			
		||||
		embedding = GetNetWorkEmbedding(id)
 | 
			
		||||
	} else {
 | 
			
		||||
		for _, item := range result {
 | 
			
		||||
			if item.Name() == "embedding" {
 | 
			
		||||
				embedding = item.FieldData().GetVectors().GetFloatVector().Data
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 处理向量不存在的情况
 | 
			
		||||
	if len(embedding) == 0 {
 | 
			
		||||
		log.Println("向量不存在, 也未能重新生成")
 | 
			
		||||
		return ids
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 用向量查询相似图片
 | 
			
		||||
	topk := 200
 | 
			
		||||
	sp, _ := entity.NewIndexIvfFlatSearchParam(64)
 | 
			
		||||
	vectors := []entity.Vector{entity.FloatVector(embedding)}
 | 
			
		||||
	resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("搜索相似失敗:", err)
 | 
			
		||||
		return ids
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 输出结果
 | 
			
		||||
	for _, item := range resultx {
 | 
			
		||||
		ids = item.IDs.FieldData().GetScalars().GetLongData().GetData()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 将结果缓存到 LRU 中
 | 
			
		||||
	lruCache.Add(id, ids)
 | 
			
		||||
 | 
			
		||||
	return ids
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user