From b94b2d60c24f90c9966d2e80c7a043f984f274b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A7=89?= Date: Mon, 20 Nov 2023 04:48:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B5=8C=E5=85=A5=E6=96=87=E5=AD=97=E6=90=9C?= =?UTF-8?q?=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bin/main.go | 164 ++++++++++++---------------------------- models/elasticsearch.go | 16 ++-- models/mysql.go | 66 +++++++++++++++- 3 files changed, 125 insertions(+), 121 deletions(-) diff --git a/bin/main.go b/bin/main.go index 1d3ebfd..1f26460 100644 --- a/bin/main.go +++ b/bin/main.go @@ -55,34 +55,18 @@ func LogComponent(startTime int64, r *http.Request) { log.Println(method, url, endTime) } -type User struct { - Id int `json:"id"` - UserName string `json:"user_name"` - Avatar string `json:"avatar"` - CreateTime time.Time `json:"create_time"` - UpdateTime time.Time `json:"update_time"` -} - -type Article struct { - Id int `json:"id"` - Title string `json:"title"` - Tags string `json:"tags"` - CreateTime time.Time `json:"create_time"` - UpdateTime time.Time `json:"update_time"` -} - type Image struct { - Id int `json:"id"` - Width int `json:"width"` - Height int `json:"height"` - Content string `json:"content"` - ArticleCategoryTopId int `json:"article_category_top_id"` - PraiseCount int `json:"praise_count"` - CollectCount int `json:"collect_count"` - CreateTime time.Time `json:"createTime"` - UpdateTime time.Time `json:"updateTime"` - User User `json:"user"` - Article Article `json:"article"` + Id int `json:"id"` + Width int `json:"width"` + Height int `json:"height"` + Content string `json:"content"` + ArticleCategoryTopId int `json:"article_category_top_id"` + PraiseCount int `json:"praise_count"` + CollectCount int `json:"collect_count"` + CreateTime time.Time `json:"createTime"` + UpdateTime time.Time `json:"updateTime"` + User models.User `json:"user"` + Article models.Article `json:"article"` } type Tag struct { @@ -277,29 +261,7 @@ func main() { } return list } - - // 如果是查询 text, 直接从 Elasticsearch 返回结果 - if text := QueryConditions("text"); len(text) > 0 { - - rest := models.ElasticsearchSearch(strings.Join(text, " ")) - - // 获取图片列表 - // 是否有下一页 - //images.Next = images.Total > images.Page*images.PageSize - - // 将对象转换为有缩进的JSON输出 - data, err := json.MarshalIndent(rest, "", " ") - if err != nil { - log.Println("转换图片列表失败", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - w.Write(data) - return - } - - // 拼接查询条件 + // 拼接基本查询条件 var addCondition = func(conditions *strings.Builder, key, column string) { if values := QueryConditions(key); len(values) > 0 { if conditions.Len() > 0 { @@ -311,10 +273,25 @@ func main() { } } var conditions strings.Builder - addCondition(&conditions, "authors", "author") - addCondition(&conditions, "tags", "tag") - addCondition(&conditions, "categories", "categorie") - addCondition(&conditions, "sets", "sets") + // 如果是查询 text, 直接从 Elasticsearch 返回结果 + var text_ids []int + if text := QueryConditions("text"); len(text) > 0 { + rest := models.ElasticsearchSearch(strings.Join(text, " ")) + for _, hit := range rest["hits"].(map[string]interface{})["hits"].([]interface{}) { + id, err := strconv.Atoi(hit.(map[string]interface{})["_id"].(string)) + if err != nil { + log.Println("strconv.Atoi failed:", err) + return + } + text_ids = append(text_ids, id) + } + conditions.WriteString(fmt.Sprintf(" WHERE id IN (%s)", strings.Trim(strings.Replace(fmt.Sprint(text_ids), " ", ",", -1), "[]"))) + } else { + addCondition(&conditions, "authors", "author") + addCondition(&conditions, "tags", "tag") + addCondition(&conditions, "categories", "categorie") + addCondition(&conditions, "sets", "sets") + } // 获取图片列表 var images ListView @@ -352,8 +329,6 @@ func main() { conditions.WriteString(fmt.Sprintf(" id IN (%s)", strings.Join(idsStr, ","))) // 拼接查询条件 } } - - // 开始查询 +" LIMIT ?, ?", (images.Page-1)*images.PageSize, images.PageSize rows, err := mysqlConnection.Database.Query(fmt.Sprintf("SELECT id, width, height, content, update_time, create_time, user_id, article_id, article_category_top_id, praise_count, collect_count FROM web_images %s", conditions.String())) if err != nil { log.Println("获取图片列表失败", err) @@ -383,34 +358,29 @@ func main() { image_list = image_list_sorted } - // 附加用户信息(第一步: 获取用户ID列表) - var user_ids []int - for _, image := range image_list { - user_ids = append(user_ids, image.User.Id) + // 如果使用了图像文字检索, 需要按照图像文字检索的相似度重新排序 text_ids + if len(text_ids) > 0 { + var image_list_sorted []Image + for _, id := range text_ids { + for _, image := range image_list { + if image.Id == int(id) { + image_list_sorted = append(image_list_sorted, image) + } + } + } + image_list = image_list_sorted } - // 附加用户信息(第二步: 获取用户信息) - var users []User - if len(user_ids) > 0 { - // 使用逗号分隔的用户ID列表查询用户信息 strings.Join(strings.Fields(fmt.Sprint(user_ids)), ",") - user_ids_str := strings.Trim(strings.Replace(fmt.Sprint(user_ids), " ", ",", -1), "[]") - rows, err := mysqlConnection.Database.Query("SELECT id, user_name, avatar, update_time, create_time FROM web_member WHERE id IN (" + user_ids_str + ")") - if err != nil { - log.Println("获取用户列表失败", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - defer rows.Close() - for rows.Next() { - var user User - rows.Scan(&user.Id, &user.UserName, &user.Avatar, &user.UpdateTime, &user.CreateTime) - user.UpdateTime = user.UpdateTime.UTC() - user.CreateTime = user.CreateTime.UTC() - users = append(users, user) - } + // 用户ID, 图集ID + var user_ids []int + var article_ids []int + for _, image := range image_list { + user_ids = append(user_ids, image.User.Id) + article_ids = append(article_ids, image.Article.Id) } // 附加用户信息(第三步: 将用户信息附加到图片信息中) + users := models.QueryUserList(user_ids) for i, image := range image_list { for _, user := range users { if image.User.Id == user.Id { @@ -419,34 +389,8 @@ func main() { } } - // 附加图片集信息(第一步: 获取图片集ID列表) - var article_ids []int - for _, image := range image_list { - article_ids = append(article_ids, image.Article.Id) - } - - // 附加图片集信息(第二步: 获取图片集信息) - var articles []Article - if len(article_ids) > 0 { - // 使用逗号分隔的图片集ID列表查询图片集信息 strings.Join(strings.Fields(fmt.Sprint(article_ids)), ",") - article_ids_str := strings.Trim(strings.Replace(fmt.Sprint(article_ids), " ", ",", -1), "[]") - rows, err := mysqlConnection.Database.Query("SELECT id, title, tags, update_time, create_time FROM web_article WHERE id IN (" + article_ids_str + ")") - if err != nil { - log.Println("获取图片集列表失败", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - defer rows.Close() - for rows.Next() { - var article Article - rows.Scan(&article.Id, &article.Title, &article.Tags, &article.UpdateTime, &article.CreateTime) - article.UpdateTime = article.UpdateTime.UTC() - article.CreateTime = article.CreateTime.UTC() - articles = append(articles, article) - } - } - // 附加图片集信息(第三步: 将图片集信息附加到图片信息中) + articles := models.QueryArticleList(article_ids) for i, image := range image_list { for _, article := range articles { if image.Article.Id == article.Id { @@ -473,14 +417,6 @@ func main() { } } - //// 获取总数 - //err = mysqlConnection.Database.QueryRow("SELECT COUNT(*) FROM web_images" + conditions.String()).Scan(&images.Total) - //if err != nil { - // log.Println("获取图片总数失败", err) - // http.Error(w, err.Error(), http.StatusBadRequest) - // return - //} - // 是否有下一页 images.Next = images.Total > images.Page*images.PageSize diff --git a/models/elasticsearch.go b/models/elasticsearch.go index 41e15fc..9f956a4 100644 --- a/models/elasticsearch.go +++ b/models/elasticsearch.go @@ -29,7 +29,7 @@ func elasticsearch_init() (es *elasticsearch.Client) { return es } -func ElasticsearchSearch(text string) interface{} { +func ElasticsearchSearch(text string) map[string]interface{} { var ( r map[string]interface{} ) @@ -44,7 +44,8 @@ func ElasticsearchSearch(text string) interface{} { }, } if err := json.NewEncoder(&buf).Encode(query); err != nil { - log.Fatalf("Error encoding query: %s", err) + log.Printf("Error encoding query: %s", err) + return nil } es := elasticsearch_init() @@ -52,24 +53,27 @@ func ElasticsearchSearch(text string) interface{} { // Perform the search request. res, err := es.Search( es.Search.WithContext(context.Background()), - es.Search.WithIndex("news"), + es.Search.WithIndex("my_index"), es.Search.WithBody(&buf), es.Search.WithTrackTotalHits(true), es.Search.WithPretty(), ) if err != nil { - log.Fatalf("Error getting response: %s", err) + log.Printf("Error getting response: %s", err) + return nil } defer res.Body.Close() // Check response status if res.IsError() { - log.Fatalf("Error: %s", res.String()) + log.Printf("Error: %s", res.String()) + return nil } // Deserialize the response into a map. if err := json.NewDecoder(res.Body).Decode(&r); err != nil { - log.Fatalf("Error parsing the response body: %s", err) + log.Printf("Error parsing the response body: %s", err) + return nil } // Print the response status, number of results, and request duration. diff --git a/models/mysql.go b/models/mysql.go index 2d03b3a..4ea1d34 100644 --- a/models/mysql.go +++ b/models/mysql.go @@ -3,8 +3,11 @@ package models import ( "database/sql" "errors" + "fmt" "log" "strconv" + "strings" + "time" _ "github.com/go-sql-driver/mysql" ) @@ -13,6 +16,8 @@ type MysqlConnection struct { Database *sql.DB } +var connection *sql.DB + // 初始化数据库连接 func (m *MysqlConnection) Init() (err error) { user := Viper.Get("mysql.user").(string) @@ -21,11 +26,12 @@ func (m *MysqlConnection) Init() (err error) { port := Viper.Get("mysql.port").(int) database := Viper.Get("mysql.database").(string) sqlconf := user + ":" + password + "@tcp(" + host + ":" + strconv.Itoa(port) + ")/" + database + "?charset=utf8mb4&parseTime=True&loc=Local" - m.Database, err = sql.Open("mysql", sqlconf) // 连接数据库 + connection, err = sql.Open("mysql", sqlconf) // 连接数据库 if err != nil { log.Println("连接数据库失败", err) return } + m.Database = connection return } @@ -75,3 +81,61 @@ func (m *MysqlConnection) GetImages(page int, size int) (images []byte, err erro images = append(images, []byte("]")...) return } + +type User struct { + Id int `json:"id"` + UserName string `json:"user_name"` + Avatar string `json:"avatar"` + CreateTime time.Time `json:"create_time"` + UpdateTime time.Time `json:"update_time"` +} + +// 获取一组用户信息 +func QueryUserList(id_list []int) (users []User) { + count := len(id_list) + if count == 0 { + return + } + idstr := strings.Trim(strings.Replace(fmt.Sprint(id_list), " ", ",", -1), "[]") + rows, err := connection.Query("SELECT id, user_name, avatar, update_time, create_time FROM web_member WHERE id IN (" + idstr + ") LIMIT " + strconv.Itoa(count)) + if err != nil { + log.Println("获取用户列表失败", err) + return + } + defer rows.Close() + for rows.Next() { + var user User + rows.Scan(&user.Id, &user.UserName, &user.Avatar, &user.UpdateTime, &user.CreateTime) + users = append(users, user) + } + return +} + +type Article struct { + Id int `json:"id"` + Title string `json:"title"` + Tags string `json:"tags"` + CreateTime time.Time `json:"create_time"` + UpdateTime time.Time `json:"update_time"` +} + +// 获取一组文章信息 +func QueryArticleList(id_list []int) (articles []Article) { + count := len(id_list) + if count == 0 { + return + } + idstr := strings.Trim(strings.Replace(fmt.Sprint(id_list), " ", ",", -1), "[]") + rows, err := connection.Query("SELECT id, title, tags, update_time, create_time FROM web_article WHERE id IN (" + idstr + ")") + if err != nil { + log.Println("获取文章列表失败", err) + return + } + defer rows.Close() + for rows.Next() { + var article Article + rows.Scan(&article.Id, &article.Title, &article.Tags, &article.UpdateTime, &article.CreateTime) + articles = append(articles, article) + } + return +}