diff --git a/bin/main.go b/bin/main.go index 25898df..341e101 100644 --- a/bin/main.go +++ b/bin/main.go @@ -164,7 +164,7 @@ func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64) } // 用向量查询相似图片 - topk := 1000 + topk := 200 sp, _ := entity.NewIndexIvfFlatSearchParam(64) vectors := []entity.Vector{entity.FloatVector(embedding)} resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp) @@ -278,7 +278,7 @@ func main() { // 排序 // 分页 - // 获取查询条件(忽略空值), 超级简洁写法 + // 获取查询条件(忽略空值) QueryConditions := func(key string) (list []string) { for _, item := range strings.Split(r.URL.Query().Get(key), ",") { if item != "" { @@ -288,7 +288,7 @@ func main() { return list } - // 拼接查询条件, 超级简洁写法 + // 拼接查询条件 var addCondition = func(conditions *strings.Builder, key, column string) { if values := QueryConditions(key); len(values) > 0 { if conditions.Len() > 0 { @@ -305,6 +305,11 @@ func main() { addCondition(&conditions, "categories", "categorie") addCondition(&conditions, "sets", "sets") + // 获取图片列表 + var images ListView + var image_list []Image + images.Page, images.PageSize = stringToInt(r.URL.Query().Get("page"), 1), stringToInt(r.URL.Query().Get("pageSize"), 20) + var ids []int64 if similar := QueryConditions("similar"); len(similar) > 0 { id, err := strconv.Atoi(strings.Trim(similar[0], "'")) @@ -314,6 +319,15 @@ func main() { } // 如果指定以某个图片为基准的相似图片列表范围, 获取相似图片ID的列表 ids = (&Image{Id: id}).GetSimilarImagesIdList("default") + images.Total = len(ids) + + // 按照分页取相应的图片ID + if len(ids) > images.Page*images.PageSize { + ids = ids[(images.Page-1)*images.PageSize : images.Page*images.PageSize] + } else { + ids = ids[(images.Page-1)*images.PageSize:] + } + idsStr := make([]string, len(ids)) for i, v := range ids { idsStr[i] = strconv.FormatInt(v, 10) @@ -328,11 +342,8 @@ func main() { } } - // 获取图片列表 - var images ListView - var image_list []Image - images.Page, images.PageSize = stringToInt(r.URL.Query().Get("page"), 1), stringToInt(r.URL.Query().Get("pageSize"), 20) - rows, err := mysqlConnection.Database.Query("SELECT id, width, height, content, update_time, create_time, user_id, article_id, article_category_top_id, praise_count, collect_count FROM web_images"+conditions.String()+" LIMIT ?, ?", (images.Page-1)*images.PageSize, images.PageSize) + // 开始查询 +" LIMIT ?, ?", (images.Page-1)*images.PageSize, images.PageSize + rows, err := mysqlConnection.Database.Query(fmt.Sprintf("SELECT id, width, height, content, update_time, create_time, user_id, article_id, article_category_top_id, praise_count, collect_count FROM web_images %s", conditions.String())) if err != nil { log.Println("获取图片列表失败", err) http.Error(w, err.Error(), http.StatusBadRequest) @@ -439,14 +450,26 @@ func main() { images.List[i] = v } - // 获取总数 - err = mysqlConnection.Database.QueryRow("SELECT COUNT(*) FROM web_images" + conditions.String()).Scan(&images.Total) - if err != nil { - log.Println("获取图片总数失败", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return + if similar := QueryConditions("similar"); len(similar) > 0 { + // 总数不变 + } else { + // 获取总数 + err = mysqlConnection.Database.QueryRow("SELECT COUNT(*) FROM web_images" + conditions.String()).Scan(&images.Total) + if err != nil { + log.Println("获取图片总数失败", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } } + //// 获取总数 + //err = mysqlConnection.Database.QueryRow("SELECT COUNT(*) FROM web_images" + conditions.String()).Scan(&images.Total) + //if err != nil { + // log.Println("获取图片总数失败", err) + // http.Error(w, err.Error(), http.StatusBadRequest) + // return + //} + // 是否有下一页 images.Next = images.Total > images.Page*images.PageSize