diff --git a/api/graphql.go b/api/graphql.go index 2b1651f..cb02e18 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -290,70 +290,34 @@ func NewSchema(config Config) (graphql.Schema, error) { var query strings.Builder query.WriteString(fmt.Sprintf("SELECT %s FROM web_images WHERE %s LIMIT %d OFFSET %s", fields_str, where_str, first, after)) - var images []Image + var images ImageList if err := connection.Select(&images, query.String()); err != nil { fmt.Println("获取图像列表失败", err) return nil, err } // 获取用户信息(如果图像列表不为空且请求字段中包含user) - if len(images) > 0 && strings.Contains(fields_str, "user_id") { - // 取到所有的用户ID, 去除重复 - user_ids := make(map[int]bool) - for _, image := range images { - user_ids[image.UserID] = true - } - // map 转换为数组 - uniqueIds := make([]int, 0, len(user_ids)) - for id := range user_ids { - uniqueIds = append(uniqueIds, id) - } - // 合并为以逗号分隔的字符串 - user_ids_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(uniqueIds)), ","), "[]") - // 查询用户信息 + if len(images) > 0 && strings.Contains(fields_str, "user") { + user_ids_str := images.ToAllUserID().ToString() var users []User if err := connection.Select(&users, fmt.Sprintf("SELECT id,user_name,avatar,rank,create_time,update_time FROM web_member WHERE id IN (%s)", user_ids_str)); err != nil { fmt.Println("获取用户列表失败", err) return nil, err } // 将用户信息与图像信息关联 - for i, image := range images { - for _, user := range users { - if image.UserID == user.ID { - images[i].User = user - } - } - } + images.SetUser(users) } // 获取文章信息(如果图像列表不为空且请求字段中包含article) - if len(images) > 0 && strings.Contains(fields_str, "article_id") { - // 取到所有的文章ID, 去除重复 - article_ids := make(map[int]bool) - for _, image := range images { - article_ids[image.ArticleID] = true - } - // map 转换为数组 - uniqueIds := make([]int, 0, len(article_ids)) - for id := range article_ids { - uniqueIds = append(uniqueIds, id) - } - // 合并为以逗号分隔的字符串 - article_ids_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(uniqueIds)), ","), "[]") - // 查询文章信息 + if len(images) > 0 && strings.Contains(fields_str, "article") { + article_ids_str := images.ToAllArticleID().ToString() var articles []Article if err := connection.Select(&articles, fmt.Sprintf("SELECT id,title,tags,create_time,update_time FROM web_article WHERE id IN (%s)", article_ids_str)); err != nil { fmt.Println("获取文章列表失败", err) return nil, err } // 将文章信息与图像信息关联 - for i, image := range images { - for _, article := range articles { - if image.ArticleID == article.ID { - images[i].Article = article - } - } - } + images.SetArticle(articles) } return map[string]interface{}{ diff --git a/api/struct.go b/api/struct.go index 0ac851e..f603b43 100644 --- a/api/struct.go +++ b/api/struct.go @@ -2,9 +2,66 @@ package api import ( "encoding/json" + "fmt" + "strings" "time" ) +type IDS []int + +// 合并为以逗号分隔的字符串 +func (ids IDS) ToString() (str string) { + return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids)), ","), "[]") +} + +type ImageList []Image + +// 取到所有的文章ID, 去除重复 +func (images *ImageList) ToAllArticleID() (uniqueIds IDS) { + article_ids := make(map[int]bool) + for _, image := range *images { + article_ids[image.ArticleID] = true + } + for id := range article_ids { + uniqueIds = append(uniqueIds, id) + } + return uniqueIds +} + +// 取到所有的用户ID, 去除重复 +func (images *ImageList) ToAllUserID() (uniqueIds IDS) { + user_ids := make(map[int]bool) + for _, image := range *images { + user_ids[image.UserID] = true + } + for id := range user_ids { + uniqueIds = append(uniqueIds, id) + } + return uniqueIds +} + +// 为每个图像设置用户信息 +func (images *ImageList) SetUser(userList []User) { + for i, image := range *images { + for _, user := range userList { + if image.UserID == user.ID { + (*images)[i].User = user + } + } + } +} + +// 为每个图像设置文章信息 +func (images *ImageList) SetArticle(articleList []Article) { + for i, image := range *images { + for _, article := range articleList { + if image.ArticleID == article.ID { + (*images)[i].Article = article + } + } + } +} + type Image struct { ID int `json:"id" db:"id"` Width int `json:"width" db:"width"`