嵌入文字搜索

This commit is contained in:
2023-11-20 04:48:55 +08:00
parent fdb3edfb61
commit b94b2d60c2
3 changed files with 125 additions and 121 deletions

View File

@@ -55,34 +55,18 @@ func LogComponent(startTime int64, r *http.Request) {
log.Println(method, url, endTime) log.Println(method, url, endTime)
} }
type User struct {
Id int `json:"id"`
UserName string `json:"user_name"`
Avatar string `json:"avatar"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
type Article struct {
Id int `json:"id"`
Title string `json:"title"`
Tags string `json:"tags"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
type Image struct { type Image struct {
Id int `json:"id"` Id int `json:"id"`
Width int `json:"width"` Width int `json:"width"`
Height int `json:"height"` Height int `json:"height"`
Content string `json:"content"` Content string `json:"content"`
ArticleCategoryTopId int `json:"article_category_top_id"` ArticleCategoryTopId int `json:"article_category_top_id"`
PraiseCount int `json:"praise_count"` PraiseCount int `json:"praise_count"`
CollectCount int `json:"collect_count"` CollectCount int `json:"collect_count"`
CreateTime time.Time `json:"createTime"` CreateTime time.Time `json:"createTime"`
UpdateTime time.Time `json:"updateTime"` UpdateTime time.Time `json:"updateTime"`
User User `json:"user"` User models.User `json:"user"`
Article Article `json:"article"` Article models.Article `json:"article"`
} }
type Tag struct { type Tag struct {
@@ -277,29 +261,7 @@ func main() {
} }
return list return list
} }
// 拼接基本查询条件
// 如果是查询 text, 直接从 Elasticsearch 返回结果
if text := QueryConditions("text"); len(text) > 0 {
rest := models.ElasticsearchSearch(strings.Join(text, " "))
// 获取图片列表
// 是否有下一页
//images.Next = images.Total > images.Page*images.PageSize
// 将对象转换为有缩进的JSON输出
data, err := json.MarshalIndent(rest, "", " ")
if err != nil {
log.Println("转换图片列表失败", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.Write(data)
return
}
// 拼接查询条件
var addCondition = func(conditions *strings.Builder, key, column string) { var addCondition = func(conditions *strings.Builder, key, column string) {
if values := QueryConditions(key); len(values) > 0 { if values := QueryConditions(key); len(values) > 0 {
if conditions.Len() > 0 { if conditions.Len() > 0 {
@@ -311,10 +273,25 @@ func main() {
} }
} }
var conditions strings.Builder var conditions strings.Builder
addCondition(&conditions, "authors", "author") // 如果是查询 text, 直接从 Elasticsearch 返回结果
addCondition(&conditions, "tags", "tag") var text_ids []int
addCondition(&conditions, "categories", "categorie") if text := QueryConditions("text"); len(text) > 0 {
addCondition(&conditions, "sets", "sets") rest := models.ElasticsearchSearch(strings.Join(text, " "))
for _, hit := range rest["hits"].(map[string]interface{})["hits"].([]interface{}) {
id, err := strconv.Atoi(hit.(map[string]interface{})["_id"].(string))
if err != nil {
log.Println("strconv.Atoi failed:", err)
return
}
text_ids = append(text_ids, id)
}
conditions.WriteString(fmt.Sprintf(" WHERE id IN (%s)", strings.Trim(strings.Replace(fmt.Sprint(text_ids), " ", ",", -1), "[]")))
} else {
addCondition(&conditions, "authors", "author")
addCondition(&conditions, "tags", "tag")
addCondition(&conditions, "categories", "categorie")
addCondition(&conditions, "sets", "sets")
}
// 获取图片列表 // 获取图片列表
var images ListView var images ListView
@@ -352,8 +329,6 @@ func main() {
conditions.WriteString(fmt.Sprintf(" id IN (%s)", strings.Join(idsStr, ","))) // 拼接查询条件 conditions.WriteString(fmt.Sprintf(" id IN (%s)", strings.Join(idsStr, ","))) // 拼接查询条件
} }
} }
// 开始查询 +" LIMIT ?, ?", (images.Page-1)*images.PageSize, images.PageSize
rows, err := mysqlConnection.Database.Query(fmt.Sprintf("SELECT id, width, height, content, update_time, create_time, user_id, article_id, article_category_top_id, praise_count, collect_count FROM web_images %s", conditions.String())) rows, err := mysqlConnection.Database.Query(fmt.Sprintf("SELECT id, width, height, content, update_time, create_time, user_id, article_id, article_category_top_id, praise_count, collect_count FROM web_images %s", conditions.String()))
if err != nil { if err != nil {
log.Println("获取图片列表失败", err) log.Println("获取图片列表失败", err)
@@ -383,34 +358,29 @@ func main() {
image_list = image_list_sorted image_list = image_list_sorted
} }
// 附加用户信息(第一步: 获取用户ID列表) // 如果使用了图像文字检索, 需要按照图像文字检索的相似度重新排序 text_ids
var user_ids []int if len(text_ids) > 0 {
for _, image := range image_list { var image_list_sorted []Image
user_ids = append(user_ids, image.User.Id) for _, id := range text_ids {
for _, image := range image_list {
if image.Id == int(id) {
image_list_sorted = append(image_list_sorted, image)
}
}
}
image_list = image_list_sorted
} }
// 附加用户信息(第二步: 获取用户信息) // 用户ID, 图集ID
var users []User var user_ids []int
if len(user_ids) > 0 { var article_ids []int
// 使用逗号分隔的用户ID列表查询用户信息 strings.Join(strings.Fields(fmt.Sprint(user_ids)), ",") for _, image := range image_list {
user_ids_str := strings.Trim(strings.Replace(fmt.Sprint(user_ids), " ", ",", -1), "[]") user_ids = append(user_ids, image.User.Id)
rows, err := mysqlConnection.Database.Query("SELECT id, user_name, avatar, update_time, create_time FROM web_member WHERE id IN (" + user_ids_str + ")") article_ids = append(article_ids, image.Article.Id)
if err != nil {
log.Println("获取用户列表失败", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer rows.Close()
for rows.Next() {
var user User
rows.Scan(&user.Id, &user.UserName, &user.Avatar, &user.UpdateTime, &user.CreateTime)
user.UpdateTime = user.UpdateTime.UTC()
user.CreateTime = user.CreateTime.UTC()
users = append(users, user)
}
} }
// 附加用户信息(第三步: 将用户信息附加到图片信息中) // 附加用户信息(第三步: 将用户信息附加到图片信息中)
users := models.QueryUserList(user_ids)
for i, image := range image_list { for i, image := range image_list {
for _, user := range users { for _, user := range users {
if image.User.Id == user.Id { if image.User.Id == user.Id {
@@ -419,34 +389,8 @@ func main() {
} }
} }
// 附加图片集信息(第一步: 获取图片集ID列表)
var article_ids []int
for _, image := range image_list {
article_ids = append(article_ids, image.Article.Id)
}
// 附加图片集信息(第二步: 获取图片集信息)
var articles []Article
if len(article_ids) > 0 {
// 使用逗号分隔的图片集ID列表查询图片集信息 strings.Join(strings.Fields(fmt.Sprint(article_ids)), ",")
article_ids_str := strings.Trim(strings.Replace(fmt.Sprint(article_ids), " ", ",", -1), "[]")
rows, err := mysqlConnection.Database.Query("SELECT id, title, tags, update_time, create_time FROM web_article WHERE id IN (" + article_ids_str + ")")
if err != nil {
log.Println("获取图片集列表失败", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer rows.Close()
for rows.Next() {
var article Article
rows.Scan(&article.Id, &article.Title, &article.Tags, &article.UpdateTime, &article.CreateTime)
article.UpdateTime = article.UpdateTime.UTC()
article.CreateTime = article.CreateTime.UTC()
articles = append(articles, article)
}
}
// 附加图片集信息(第三步: 将图片集信息附加到图片信息中) // 附加图片集信息(第三步: 将图片集信息附加到图片信息中)
articles := models.QueryArticleList(article_ids)
for i, image := range image_list { for i, image := range image_list {
for _, article := range articles { for _, article := range articles {
if image.Article.Id == article.Id { if image.Article.Id == article.Id {
@@ -473,14 +417,6 @@ func main() {
} }
} }
//// 获取总数
//err = mysqlConnection.Database.QueryRow("SELECT COUNT(*) FROM web_images" + conditions.String()).Scan(&images.Total)
//if err != nil {
// log.Println("获取图片总数失败", err)
// http.Error(w, err.Error(), http.StatusBadRequest)
// return
//}
// 是否有下一页 // 是否有下一页
images.Next = images.Total > images.Page*images.PageSize images.Next = images.Total > images.Page*images.PageSize

View File

@@ -29,7 +29,7 @@ func elasticsearch_init() (es *elasticsearch.Client) {
return es return es
} }
func ElasticsearchSearch(text string) interface{} { func ElasticsearchSearch(text string) map[string]interface{} {
var ( var (
r map[string]interface{} r map[string]interface{}
) )
@@ -44,7 +44,8 @@ func ElasticsearchSearch(text string) interface{} {
}, },
} }
if err := json.NewEncoder(&buf).Encode(query); err != nil { if err := json.NewEncoder(&buf).Encode(query); err != nil {
log.Fatalf("Error encoding query: %s", err) log.Printf("Error encoding query: %s", err)
return nil
} }
es := elasticsearch_init() es := elasticsearch_init()
@@ -52,24 +53,27 @@ func ElasticsearchSearch(text string) interface{} {
// Perform the search request. // Perform the search request.
res, err := es.Search( res, err := es.Search(
es.Search.WithContext(context.Background()), es.Search.WithContext(context.Background()),
es.Search.WithIndex("news"), es.Search.WithIndex("my_index"),
es.Search.WithBody(&buf), es.Search.WithBody(&buf),
es.Search.WithTrackTotalHits(true), es.Search.WithTrackTotalHits(true),
es.Search.WithPretty(), es.Search.WithPretty(),
) )
if err != nil { if err != nil {
log.Fatalf("Error getting response: %s", err) log.Printf("Error getting response: %s", err)
return nil
} }
defer res.Body.Close() defer res.Body.Close()
// Check response status // Check response status
if res.IsError() { if res.IsError() {
log.Fatalf("Error: %s", res.String()) log.Printf("Error: %s", res.String())
return nil
} }
// Deserialize the response into a map. // Deserialize the response into a map.
if err := json.NewDecoder(res.Body).Decode(&r); err != nil { if err := json.NewDecoder(res.Body).Decode(&r); err != nil {
log.Fatalf("Error parsing the response body: %s", err) log.Printf("Error parsing the response body: %s", err)
return nil
} }
// Print the response status, number of results, and request duration. // Print the response status, number of results, and request duration.

View File

@@ -3,8 +3,11 @@ package models
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"log" "log"
"strconv" "strconv"
"strings"
"time"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
) )
@@ -13,6 +16,8 @@ type MysqlConnection struct {
Database *sql.DB Database *sql.DB
} }
var connection *sql.DB
// 初始化数据库连接 // 初始化数据库连接
func (m *MysqlConnection) Init() (err error) { func (m *MysqlConnection) Init() (err error) {
user := Viper.Get("mysql.user").(string) user := Viper.Get("mysql.user").(string)
@@ -21,11 +26,12 @@ func (m *MysqlConnection) Init() (err error) {
port := Viper.Get("mysql.port").(int) port := Viper.Get("mysql.port").(int)
database := Viper.Get("mysql.database").(string) database := Viper.Get("mysql.database").(string)
sqlconf := user + ":" + password + "@tcp(" + host + ":" + strconv.Itoa(port) + ")/" + database + "?charset=utf8mb4&parseTime=True&loc=Local" sqlconf := user + ":" + password + "@tcp(" + host + ":" + strconv.Itoa(port) + ")/" + database + "?charset=utf8mb4&parseTime=True&loc=Local"
m.Database, err = sql.Open("mysql", sqlconf) // 连接数据库 connection, err = sql.Open("mysql", sqlconf) // 连接数据库
if err != nil { if err != nil {
log.Println("连接数据库失败", err) log.Println("连接数据库失败", err)
return return
} }
m.Database = connection
return return
} }
@@ -75,3 +81,61 @@ func (m *MysqlConnection) GetImages(page int, size int) (images []byte, err erro
images = append(images, []byte("]")...) images = append(images, []byte("]")...)
return return
} }
type User struct {
Id int `json:"id"`
UserName string `json:"user_name"`
Avatar string `json:"avatar"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
// 获取一组用户信息
func QueryUserList(id_list []int) (users []User) {
count := len(id_list)
if count == 0 {
return
}
idstr := strings.Trim(strings.Replace(fmt.Sprint(id_list), " ", ",", -1), "[]")
rows, err := connection.Query("SELECT id, user_name, avatar, update_time, create_time FROM web_member WHERE id IN (" + idstr + ") LIMIT " + strconv.Itoa(count))
if err != nil {
log.Println("获取用户列表失败", err)
return
}
defer rows.Close()
for rows.Next() {
var user User
rows.Scan(&user.Id, &user.UserName, &user.Avatar, &user.UpdateTime, &user.CreateTime)
users = append(users, user)
}
return
}
type Article struct {
Id int `json:"id"`
Title string `json:"title"`
Tags string `json:"tags"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
// 获取一组文章信息
func QueryArticleList(id_list []int) (articles []Article) {
count := len(id_list)
if count == 0 {
return
}
idstr := strings.Trim(strings.Replace(fmt.Sprint(id_list), " ", ",", -1), "[]")
rows, err := connection.Query("SELECT id, title, tags, update_time, create_time FROM web_article WHERE id IN (" + idstr + ")")
if err != nil {
log.Println("获取文章列表失败", err)
return
}
defer rows.Close()
for rows.Next() {
var article Article
rows.Scan(&article.Id, &article.Title, &article.Tags, &article.UpdateTime, &article.CreateTime)
articles = append(articles, article)
}
return
}