支持相似图像查询
This commit is contained in:
		@@ -181,7 +181,8 @@ func NewSchema(config Config) (graphql.Schema, error) {
 | 
				
			|||||||
				},
 | 
									},
 | 
				
			||||||
			}),
 | 
								}),
 | 
				
			||||||
			Args: graphql.FieldConfigArgument{
 | 
								Args: graphql.FieldConfigArgument{
 | 
				
			||||||
				"id":            &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定ID的"},
 | 
									"similar":       &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取与指定ID图像相似的图像"},
 | 
				
			||||||
 | 
									"id":            &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取指定ID的图像"},
 | 
				
			||||||
				"width":         &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定宽度的"},
 | 
									"width":         &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定宽度的"},
 | 
				
			||||||
				"height":        &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定高度的"},
 | 
									"height":        &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定高度的"},
 | 
				
			||||||
				"content":       &graphql.ArgumentConfig{Type: graphql.String, Description: "筛选图像中含有指定内容的"},
 | 
									"content":       &graphql.ArgumentConfig{Type: graphql.String, Description: "筛选图像中含有指定内容的"},
 | 
				
			||||||
@@ -205,15 +206,16 @@ func NewSchema(config Config) (graphql.Schema, error) {
 | 
				
			|||||||
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
 | 
								Resolve: func(p graphql.ResolveParams) (interface{}, error) {
 | 
				
			||||||
				// 定义参数结构体
 | 
									// 定义参数结构体
 | 
				
			||||||
				var args struct {
 | 
									var args struct {
 | 
				
			||||||
					First  int
 | 
										First   int
 | 
				
			||||||
					Last   int
 | 
										Last    int
 | 
				
			||||||
					After  string
 | 
										After   string
 | 
				
			||||||
					Before string
 | 
										Before  string
 | 
				
			||||||
					Text   string
 | 
										Text    string
 | 
				
			||||||
 | 
										Similar int
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				mapstructure.Decode(p.Args, &args)
 | 
									mapstructure.Decode(p.Args, &args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// 返回字段
 | 
									// 处理要求返回的字段
 | 
				
			||||||
				var fields []string
 | 
									var fields []string
 | 
				
			||||||
				requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections
 | 
									requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections
 | 
				
			||||||
				for _, field := range requestedFields {
 | 
									for _, field := range requestedFields {
 | 
				
			||||||
@@ -271,8 +273,20 @@ func NewSchema(config Config) (graphql.Schema, error) {
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// 特殊处理 text 参数
 | 
					 | 
				
			||||||
				var id_list []string
 | 
									var id_list []string
 | 
				
			||||||
 | 
									// 特殊处理 similar 参数
 | 
				
			||||||
 | 
									if args.Similar != 0 {
 | 
				
			||||||
 | 
										fmt.Println("similar:", args.Similar)
 | 
				
			||||||
 | 
										id_list := models.GetSimilarImagesIdList(args.Similar, 200)
 | 
				
			||||||
 | 
										fmt.Println("ids:", id_list)
 | 
				
			||||||
 | 
										ids_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(id_list)), ","), "[]")
 | 
				
			||||||
 | 
										if ids_str == "" {
 | 
				
			||||||
 | 
											return map[string]interface{}{"list": []Image{}, "total": 0}, nil
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										where = append(where, fmt.Sprintf("id IN (%s) LIMIT %d", ids_str, len(id_list)))
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// 特殊处理 text 参数
 | 
				
			||||||
				if args.Text != "" {
 | 
									if args.Text != "" {
 | 
				
			||||||
					resp, err := models.ZincSearch(map[string]interface{}{
 | 
										resp, err := models.ZincSearch(map[string]interface{}{
 | 
				
			||||||
						"query": map[string]interface{}{
 | 
											"query": map[string]interface{}{
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										104
									
								
								models/milvus.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								models/milvus.go
									
									
									
									
									
								
							@@ -2,11 +2,17 @@ package models
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/milvus-io/milvus-sdk-go/v2/client"
 | 
						"github.com/milvus-io/milvus-sdk-go/v2/client"
 | 
				
			||||||
 | 
						"github.com/milvus-io/milvus-sdk-go/v2/entity"
 | 
				
			||||||
 | 
						"github.com/spf13/viper"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lru "github.com/hashicorp/golang-lru/v2"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type MilvusConnection struct {
 | 
					type MilvusConnection struct {
 | 
				
			||||||
@@ -32,3 +38,101 @@ func (m *MilvusConnection) Init() (err error) {
 | 
				
			|||||||
	log.Println("Milvus connection success")
 | 
						log.Println("Milvus connection success")
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var milvusConnection MilvusConnection
 | 
				
			||||||
 | 
					var lruCache, _ = lru.New[int, []int64](100000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 获取图像向量(从指定API)
 | 
				
			||||||
 | 
					func GetNetWorkEmbedding(id int) (embedding []float32) {
 | 
				
			||||||
 | 
						host := viper.GetString("embedding.host")
 | 
				
			||||||
 | 
						port := viper.GetInt("embedding.port")
 | 
				
			||||||
 | 
						httpClient := &http.Client{}
 | 
				
			||||||
 | 
						request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id), nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println("请求失败1:", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						response, err := httpClient.Do(request)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println("请求失败2:", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer response.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var result struct {
 | 
				
			||||||
 | 
							Code    int       `json:"code"`
 | 
				
			||||||
 | 
							Message string    `json:"message"`
 | 
				
			||||||
 | 
							Feature []float32 `json:"feature"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = json.NewDecoder(response.Body).Decode(&result)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println("解析失败:", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if result.Code != 0 {
 | 
				
			||||||
 | 
							log.Println("请求失败3:", result.Message)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result.Feature
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 获取相似图像ID列
 | 
				
			||||||
 | 
					func GetSimilarImagesIdList(id int, topK int) (ids []int64) {
 | 
				
			||||||
 | 
						// 从缓存中获取
 | 
				
			||||||
 | 
						if cache, ok := lruCache.Get(id); ok {
 | 
				
			||||||
 | 
							return cache
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 从Milvus中获取向量
 | 
				
			||||||
 | 
						var collection_name = "default"
 | 
				
			||||||
 | 
						var embedding []float32
 | 
				
			||||||
 | 
						var ctx = context.Background()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 先检查 milvusConnection 是否已经初始化
 | 
				
			||||||
 | 
						if milvusConnection.Client == nil {
 | 
				
			||||||
 | 
							err := milvusConnection.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Println("Milvus 初始化失败:", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", id), []string{"embedding"})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println("查詢向量失敗:", err)
 | 
				
			||||||
 | 
							embedding = GetNetWorkEmbedding(id)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							for _, item := range result {
 | 
				
			||||||
 | 
								if item.Name() == "embedding" {
 | 
				
			||||||
 | 
									embedding = item.FieldData().GetVectors().GetFloatVector().Data
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 处理向量不存在的情况
 | 
				
			||||||
 | 
						if len(embedding) == 0 {
 | 
				
			||||||
 | 
							log.Println("向量不存在, 也未能重新生成")
 | 
				
			||||||
 | 
							return ids
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 用向量查询相似图片
 | 
				
			||||||
 | 
						topk := 200
 | 
				
			||||||
 | 
						sp, _ := entity.NewIndexIvfFlatSearchParam(64)
 | 
				
			||||||
 | 
						vectors := []entity.Vector{entity.FloatVector(embedding)}
 | 
				
			||||||
 | 
						resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println("搜索相似失敗:", err)
 | 
				
			||||||
 | 
							return ids
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 输出结果
 | 
				
			||||||
 | 
						for _, item := range resultx {
 | 
				
			||||||
 | 
							ids = item.IDs.FieldData().GetScalars().GetLongData().GetData()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 将结果缓存到 LRU 中
 | 
				
			||||||
 | 
						lruCache.Add(id, ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return ids
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user