diff --git a/api/graphql.go b/api/graphql.go index fcfa8d4..07793b6 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -181,7 +181,8 @@ func NewSchema(config Config) (graphql.Schema, error) { }, }), Args: graphql.FieldConfigArgument{ - "id": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定ID的"}, + "similar": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取与指定ID图像相似的图像"}, + "id": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取指定ID的图像"}, "width": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定宽度的"}, "height": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定高度的"}, "content": &graphql.ArgumentConfig{Type: graphql.String, Description: "筛选图像中含有指定内容的"}, @@ -205,15 +206,16 @@ func NewSchema(config Config) (graphql.Schema, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) { // 定义参数结构体 var args struct { - First int - Last int - After string - Before string - Text string + First int + Last int + After string + Before string + Text string + Similar int } mapstructure.Decode(p.Args, &args) - // 返回字段 + // 处理要求返回的字段 var fields []string requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections for _, field := range requestedFields { @@ -271,8 +273,20 @@ func NewSchema(config Config) (graphql.Schema, error) { } } - // 特殊处理 text 参数 var id_list []string + // 特殊处理 similar 参数 + if args.Similar != 0 { + fmt.Println("similar:", args.Similar) + id_list := models.GetSimilarImagesIdList(args.Similar, 200) + fmt.Println("ids:", id_list) + ids_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(id_list)), ","), "[]") + if ids_str == "" { + return map[string]interface{}{"list": []Image{}, "total": 0}, nil + } + where = append(where, fmt.Sprintf("id IN (%s) LIMIT %d", ids_str, len(id_list))) + } + + // 特殊处理 text 参数 if args.Text != "" { resp, err := models.ZincSearch(map[string]interface{}{ "query": map[string]interface{}{ diff --git a/models/milvus.go b/models/milvus.go index 2620a37..18784a5 100644 --- a/models/milvus.go +++ b/models/milvus.go @@ -2,11 +2,17 @@ package models import ( "context" + "encoding/json" "fmt" "log" + "net/http" "os" "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" + "github.com/spf13/viper" + + lru "github.com/hashicorp/golang-lru/v2" ) type MilvusConnection struct { @@ -32,3 +38,101 @@ func (m *MilvusConnection) Init() (err error) { log.Println("Milvus connection success") return } + +var milvusConnection MilvusConnection +var lruCache, _ = lru.New[int, []int64](100000) + +// 获取图像向量(从指定API) +func GetNetWorkEmbedding(id int) (embedding []float32) { + host := viper.GetString("embedding.host") + port := viper.GetInt("embedding.port") + httpClient := &http.Client{} + request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id), nil) + if err != nil { + log.Println("请求失败1:", err) + return + } + response, err := httpClient.Do(request) + if err != nil { + log.Println("请求失败2:", err) + return + } + defer response.Body.Close() + + var result struct { + Code int `json:"code"` + Message string `json:"message"` + Feature []float32 `json:"feature"` + } + err = json.NewDecoder(response.Body).Decode(&result) + if err != nil { + log.Println("解析失败:", err) + return + } + if result.Code != 0 { + log.Println("请求失败3:", result.Message) + return + } + return result.Feature +} + +// 获取相似图像ID列 +func GetSimilarImagesIdList(id int, topK int) (ids []int64) { + // 从缓存中获取 + if cache, ok := lruCache.Get(id); ok { + return cache + } + + // 从Milvus中获取向量 + var collection_name = "default" + var embedding []float32 + var ctx = context.Background() + + // 先检查 milvusConnection 是否已经初始化 + if milvusConnection.Client == nil { + err := milvusConnection.Init() + if err != nil { + log.Println("Milvus 初始化失败:", err) + return + } + } + + result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", id), []string{"embedding"}) + if err != nil { + log.Println("查詢向量失敗:", err) + embedding = GetNetWorkEmbedding(id) + } else { + for _, item := range result { + if item.Name() == "embedding" { + embedding = item.FieldData().GetVectors().GetFloatVector().Data + continue + } + } + } + + // 处理向量不存在的情况 + if len(embedding) == 0 { + log.Println("向量不存在, 也未能重新生成") + return ids + } + + // 用向量查询相似图片 + topk := 200 + sp, _ := entity.NewIndexIvfFlatSearchParam(64) + vectors := []entity.Vector{entity.FloatVector(embedding)} + resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp) + if err != nil { + log.Println("搜索相似失敗:", err) + return ids + } + + // 输出结果 + for _, item := range resultx { + ids = item.IDs.FieldData().GetScalars().GetLongData().GetData() + } + + // 将结果缓存到 LRU 中 + lruCache.Add(id, ids) + + return ids +}