支持相似图像查询
This commit is contained in:
@@ -181,7 +181,8 @@ func NewSchema(config Config) (graphql.Schema, error) {
|
||||
},
|
||||
}),
|
||||
Args: graphql.FieldConfigArgument{
|
||||
"id": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定ID的"},
|
||||
"similar": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取与指定ID图像相似的图像"},
|
||||
"id": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取指定ID的图像"},
|
||||
"width": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定宽度的"},
|
||||
"height": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定高度的"},
|
||||
"content": &graphql.ArgumentConfig{Type: graphql.String, Description: "筛选图像中含有指定内容的"},
|
||||
@@ -210,10 +211,11 @@ func NewSchema(config Config) (graphql.Schema, error) {
|
||||
After string
|
||||
Before string
|
||||
Text string
|
||||
Similar int
|
||||
}
|
||||
mapstructure.Decode(p.Args, &args)
|
||||
|
||||
// 返回字段
|
||||
// 处理要求返回的字段
|
||||
var fields []string
|
||||
requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections
|
||||
for _, field := range requestedFields {
|
||||
@@ -271,8 +273,20 @@ func NewSchema(config Config) (graphql.Schema, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 特殊处理 text 参数
|
||||
var id_list []string
|
||||
// 特殊处理 similar 参数
|
||||
if args.Similar != 0 {
|
||||
fmt.Println("similar:", args.Similar)
|
||||
id_list := models.GetSimilarImagesIdList(args.Similar, 200)
|
||||
fmt.Println("ids:", id_list)
|
||||
ids_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(id_list)), ","), "[]")
|
||||
if ids_str == "" {
|
||||
return map[string]interface{}{"list": []Image{}, "total": 0}, nil
|
||||
}
|
||||
where = append(where, fmt.Sprintf("id IN (%s) LIMIT %d", ids_str, len(id_list)))
|
||||
}
|
||||
|
||||
// 特殊处理 text 参数
|
||||
if args.Text != "" {
|
||||
resp, err := models.ZincSearch(map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
|
104
models/milvus.go
104
models/milvus.go
@@ -2,11 +2,17 @@ package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
)
|
||||
|
||||
type MilvusConnection struct {
|
||||
@@ -32,3 +38,101 @@ func (m *MilvusConnection) Init() (err error) {
|
||||
log.Println("Milvus connection success")
|
||||
return
|
||||
}
|
||||
|
||||
var milvusConnection MilvusConnection
|
||||
var lruCache, _ = lru.New[int, []int64](100000)
|
||||
|
||||
// 获取图像向量(从指定API)
|
||||
func GetNetWorkEmbedding(id int) (embedding []float32) {
|
||||
host := viper.GetString("embedding.host")
|
||||
port := viper.GetInt("embedding.port")
|
||||
httpClient := &http.Client{}
|
||||
request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id), nil)
|
||||
if err != nil {
|
||||
log.Println("请求失败1:", err)
|
||||
return
|
||||
}
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
log.Println("请求失败2:", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Feature []float32 `json:"feature"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&result)
|
||||
if err != nil {
|
||||
log.Println("解析失败:", err)
|
||||
return
|
||||
}
|
||||
if result.Code != 0 {
|
||||
log.Println("请求失败3:", result.Message)
|
||||
return
|
||||
}
|
||||
return result.Feature
|
||||
}
|
||||
|
||||
// 获取相似图像ID列
|
||||
func GetSimilarImagesIdList(id int, topK int) (ids []int64) {
|
||||
// 从缓存中获取
|
||||
if cache, ok := lruCache.Get(id); ok {
|
||||
return cache
|
||||
}
|
||||
|
||||
// 从Milvus中获取向量
|
||||
var collection_name = "default"
|
||||
var embedding []float32
|
||||
var ctx = context.Background()
|
||||
|
||||
// 先检查 milvusConnection 是否已经初始化
|
||||
if milvusConnection.Client == nil {
|
||||
err := milvusConnection.Init()
|
||||
if err != nil {
|
||||
log.Println("Milvus 初始化失败:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", id), []string{"embedding"})
|
||||
if err != nil {
|
||||
log.Println("查詢向量失敗:", err)
|
||||
embedding = GetNetWorkEmbedding(id)
|
||||
} else {
|
||||
for _, item := range result {
|
||||
if item.Name() == "embedding" {
|
||||
embedding = item.FieldData().GetVectors().GetFloatVector().Data
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理向量不存在的情况
|
||||
if len(embedding) == 0 {
|
||||
log.Println("向量不存在, 也未能重新生成")
|
||||
return ids
|
||||
}
|
||||
|
||||
// 用向量查询相似图片
|
||||
topk := 200
|
||||
sp, _ := entity.NewIndexIvfFlatSearchParam(64)
|
||||
vectors := []entity.Vector{entity.FloatVector(embedding)}
|
||||
resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
|
||||
if err != nil {
|
||||
log.Println("搜索相似失敗:", err)
|
||||
return ids
|
||||
}
|
||||
|
||||
// 输出结果
|
||||
for _, item := range resultx {
|
||||
ids = item.IDs.FieldData().GetScalars().GetLongData().GetData()
|
||||
}
|
||||
|
||||
// 将结果缓存到 LRU 中
|
||||
lruCache.Add(id, ids)
|
||||
|
||||
return ids
|
||||
}
|
||||
|
Reference in New Issue
Block a user