diff --git a/Makefile b/Makefile index d91e6c5..e0480e6 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,9 @@ dev: go run github.com/air-verse/air@latest --build.cmd "go build -o ./data/ bin/main.go" --build.bin "./data/main"; \ wait +link: + ssh -NCPf main -L 3306:localhost:3306 -L 19530:localhost:19530 + # 编译项目 build: go mod tidy @@ -23,7 +26,7 @@ update: build ssh $host "systemctl restart webp" ssh $host "rm ~/webp/main_old" -# 设为系统服务和日志轮转 +# 设为系统服务和设置日志轮转 service: sudo cp webp.service /etc/systemd/system/webp.service sudo systemctl enable webp @@ -37,3 +40,7 @@ gorse: export GORSE_DASHBOARD_USER="gorse" export GORSE_DASHBOARD_PASS="gorse" curl -fsSL https://gorse.io/playground | bash + +# 安装搜图服务 python embedding +reverse: + git clone \ No newline at end of file diff --git a/README.md b/README.md index 3dd8e2e..8ceb8f6 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # webp +- [x] 相似图像推荐(迁移) +- [ ] 以图搜图(迁移) +- [ ] 标签筛选(补充筛选条件) + - [ ] 原始图像 - [ ] 缩略缓存 diff --git a/api/graphql.go b/api/graphql.go index 0229d57..2ecebc6 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "reflect" + "strconv" "strings" "git.satori.love/gameui/webp/models" @@ -11,6 +12,7 @@ import ( "github.com/graphql-go/graphql/language/ast" "github.com/jmoiron/sqlx" "github.com/mitchellh/mapstructure" + "github.com/thoas/go-funk" "gorm.io/driver/mysql" "gorm.io/gorm" ) @@ -97,12 +99,13 @@ func NewSchema(config Config) (graphql.Schema, error) { Name: "Article", Description: "文章", Fields: graphql.Fields{ - "id": &graphql.Field{Type: graphql.Int, Description: "文章ID"}, - "title": &graphql.Field{Type: graphql.String, Description: "文章标题"}, - "tags": &graphql.Field{Type: graphql.String, Description: "文章标签"}, - "user": &graphql.Field{Type: user, Description: "文章所属用户"}, - "create_time": &graphql.Field{Type: graphql.DateTime, Description: "文章创建时间"}, - "update_time": &graphql.Field{Type: graphql.DateTime, Description: "文章更新时间"}, + "id": &graphql.Field{Type: graphql.Int, Description: "ID"}, + "title": &graphql.Field{Type: graphql.String, Description: "标题"}, + "orientation": &graphql.Field{Type: graphql.String, Description: "方向"}, + "tags": &graphql.Field{Type: graphql.String, Description: "标签"}, + "user": &graphql.Field{Type: user, Description: "所属用户"}, + "create_time": &graphql.Field{Type: graphql.DateTime, Description: "创建时间"}, + "update_time": &graphql.Field{Type: graphql.DateTime, Description: "更新时间"}, }, }) @@ -167,7 +170,6 @@ func NewSchema(config Config) (graphql.Schema, error) { if p.Args["text"] != nil { var texts TextList for _, text := range p.Source.(Image).Text { - fmt.Println("san", text.Text) if strings.Contains(text.Text, p.Args["text"].(string)) { texts = append(texts, text) } @@ -193,18 +195,6 @@ func NewSchema(config Config) (graphql.Schema, error) { for _, field := range fieldAST.SelectionSet.Selections { fieldAST, ok := field.(*ast.Field) if ok { - if fieldAST.Name.Value == "user" { - fields = append(fields, "user_id") - continue - } - if fieldAST.Name.Value == "article" { - fields = append(fields, "article_id") - continue - } - if fieldAST.Name.Value == "similars" { - // 跳过自定义字段 - continue - } fields = append(fields, fieldAST.Name.Value) } } @@ -344,7 +334,7 @@ func NewSchema(config Config) (graphql.Schema, error) { }, }), Args: graphql.FieldConfigArgument{ - "preference": &graphql.ArgumentConfig{Type: graphql.String, Description: "使用浏览记录获取的偏好推荐图像"}, + "interest": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取指定ID用户的兴趣推荐图像"}, "similar": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取与指定ID图像相似的图像"}, "id": &graphql.ArgumentConfig{Type: graphql.Int, Description: "获取指定ID的图像"}, "width": &graphql.ArgumentConfig{Type: graphql.Int, Description: "筛选图像中指定宽度的"}, @@ -364,72 +354,49 @@ func NewSchema(config Config) (graphql.Schema, error) { "update_time": &graphql.ArgumentConfig{Type: graphql.DateTime, Description: "筛选图像中更新时间等于指定值的"}, "first": &graphql.ArgumentConfig{Type: graphql.Int, Description: "翻页参数(傳回清單中的前n個元素)"}, "last": &graphql.ArgumentConfig{Type: graphql.Int, Description: "翻页参数(傳回清單中的最後n個元素)"}, - "after": &graphql.ArgumentConfig{Type: graphql.String, Description: "翻页参数(傳回清單中指定遊標之後的元素)"}, - "before": &graphql.ArgumentConfig{Type: graphql.String, Description: "翻页参数(傳回清單中指定遊標之前的元素)"}, + "after": &graphql.ArgumentConfig{Type: graphql.Int, Description: "翻页参数(傳回清單中指定遊標之後的元素)"}, + "before": &graphql.ArgumentConfig{Type: graphql.Int, Description: "翻页参数(傳回清單中指定遊標之前的元素)"}, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { // 定义参数结构体 var args struct { - First int - Last int - After string - Before string - Text string - Preference string - Similar int + First int + Last int + After int + Before int + Text string + Interest int + Similar int } mapstructure.Decode(p.Args, &args) - // 参数到 SQL 格式字符串的映射 - var argToSQLFormat = map[string]string{ - "id": "id=%d", - "width": "width=%d", - "height": "height=%d", - "content": "content='%s'", - "remark": "remark='%s'", - "description": "description='%s'", - "tags": "tags='%s'", - "rank": "rank='%s'", - "comment_num": "comment_num=%d", - "praise_count": "praise_count=%d", - "collect_count": "collect_count=%d", - "article_id": "article_id=%d", - "user_id": "user_id=%d", - "create_time": "create_time='%s'", - "update_time": "update_time='%s'", + // 限制长度防止全表扫描 + var limit = 10 + if args.First != 0 { + limit = args.First + } else if args.Last != 0 { + limit = args.Last } + var total int + var images []Image + var fields = get_fields(p.Info.FieldASTs[0].SelectionSet.Selections) + var query = db.Limit(limit) + + // 参数映射 + var argFormat = []string{"id", "width", "height", "content", "remark", "description", "tags", "rank", "comment_num", "praise_count", "collect_count", "article_id", "user_id", "create_time", "update_time"} + // 筛选条件 - var where []string - var order []string - for arg, format := range argToSQLFormat { - if p.Args[arg] != nil { - where = append(where, fmt.Sprintf(format, p.Args[arg])) + for _, format := range argFormat { + if p.Args[format] != nil { + query = query.Where(fmt.Sprintf(format, " = ?"), p.Args[format]) } } - var id_list []string + var list []int + var id_list [][]int - // 特殊处理 preference 参数 - if args.Preference != "" { - // 去除空格并拆分以逗号分割的ID - id_list = strings.Split(strings.ReplaceAll(args.Preference, " ", ""), ",") - // 使用这一组 id 推荐 - fmt.Println("preference:", args.Preference) - } - - // 特殊处理 similar 参数 - if args.Similar != 0 { - id_list := models.GetSimilarImagesIdList(args.Similar, 200) - ids_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(id_list)), ","), "[]") - if ids_str == "" { - return map[string]interface{}{"list": []Image{}, "total": 0}, nil - } - where = append(where, fmt.Sprintf("id IN (%s)", ids_str)) - order = append(order, fmt.Sprintf("ORDER BY FIELD(id,%s)", ids_str)) - } - - // 特殊处理 text 参数 + // 筛选:提取文字 if args.Text != "" { resp, err := models.ZincSearch(map[string]interface{}{ "query": map[string]interface{}{ @@ -447,84 +414,123 @@ func NewSchema(config Config) (graphql.Schema, error) { "from": 0, "size": 200, }) + if err != nil { fmt.Println("ZincSearch 获取图像列表失败", err) return nil, err } - id_list = resp.ToIDList(args.First, args.Last, args.After, args.Before) - id_list_str := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(id_list)), ","), "[]") - if id_list_str == "" { - return map[string]interface{}{ - "list": []Image{}, - "total": 0, - }, nil + + var item []int + + for _, hit := range resp.Hits.Hits { + num, _ := strconv.Atoi(hit.ID) + item = append(item, num) + } + + id_list = append(id_list, item) + + if len(id_list) == 0 { + return map[string]interface{}{"list": []Image{}, "total": 0}, nil } - where = append(where, fmt.Sprintf("id IN (%s)", id_list_str)) } - where_str := strings.Join(where, " AND ") - order_str := strings.Join(order, "") + // 筛选:相似图像 + if args.Similar != 0 { + var item []int + for _, id := range models.GetSimilarImagesIdList(args.Similar, 200) { + item = append(item, int(id)) + } + id_list = append(id_list, item) - if where_str != "" { - where_str = "WHERE " + where_str + if len(id_list) == 0 { + return map[string]interface{}{"list": []Image{}, "total": 0}, nil + } } - // 处理翻页参数 - var limit, offset int - if args.First == 0 && args.Last == 0 { - limit = 10 - offset = 0 + // 筛选:兴趣推荐 + if args.Interest != 0 { + fmt.Println("Interest:", args.Interest) + } + + // 排序 + + // 截取:取交集 + if len(id_list) != 0 { + list = id_list[0] + if len(id_list) > 1 { + for _, slice := range id_list[1:] { + list = funk.Join(list, slice, funk.InnerJoin).([]int) + } + } + if len(list) == 0 { + return map[string]interface{}{"list": []Image{}, "total": 0}, nil + } + } + + total = len(list) + + // 截取: 分页 + if args.After != 0 { + index := -1 + for i, id := range list { + if id == args.After { + index = i + break + } + } + if index != -1 { + list = list[index+1:] + } + } + + if args.Before != 0 { + index := -1 + for i, id := range list { + if id == args.Before { + index = i + break + } + } + if index != -1 { + list = list[:index] + } } if args.First != 0 { - limit = args.First - offset = 0 + list = list[:args.First] } if args.Last != 0 { - limit = args.Last - offset = len(id_list) - limit + list = list[len(list)-args.Last:] } - // 执行查询 - var query strings.Builder - fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",") - query.WriteString(fmt.Sprintf("SELECT %s FROM web_images %s %s LIMIT %d OFFSET %d", fields, where_str, order_str, limit, offset)) + if args.First == 0 && args.Last == 0 { + list = list[:10] + } - var images ImageList - var q = query.String() - if err := connection.Select(&images, q); err != nil { + // 存在外部筛选条件 + if len(id_list) > 0 && len(list) > 0 { + query = query.Where("id IN ?", list) + } + + // 输出 + if funk.Contains(fields, "user") { + query = query.Preload("User") + fmt.Println("加载 user") + } + + if funk.Contains(fields, "article") { + query = query.Preload("Article") + } + + if err := query.Find(&images).Error; err != nil { fmt.Println("获取图像列表失败", err) return nil, err } - // 获取用户信息(如果图像列表不为空且请求字段中包含user) - if len(images) > 0 && strings.Contains(fields, "user") { - user_ids_str := images.ToAllUserID().ToString() - var users []User - if err := connection.Select(&users, fmt.Sprintf("SELECT id,user_name,avatar,rank,create_time,update_time FROM web_member WHERE id IN (%s)", user_ids_str)); err != nil { - fmt.Println("获取用户列表失败", err) - return nil, err - } - // 将用户信息与图像信息关联 - images.SetUser(users) - } - - // 获取文章信息(如果图像列表不为空且请求字段中包含article) - if len(images) > 0 && strings.Contains(fields, "article") { - article_ids_str := images.ToAllArticleID().ToString() - var articles []Article - if err := connection.Select(&articles, fmt.Sprintf("SELECT id,title,tags,create_time,update_time FROM web_article WHERE id IN (%s)", article_ids_str)); err != nil { - fmt.Println("获取文章列表失败", err) - return nil, err - } - // 将文章信息与图像信息关联 - images.SetArticle(articles) - } - return map[string]interface{}{ "list": images, - "total": 0, + "total": total, }, nil }, }, diff --git a/api/struct.go b/api/struct.go index 5c57328..7195ecc 100644 --- a/api/struct.go +++ b/api/struct.go @@ -55,30 +55,8 @@ func (images *ImageList) ToAllUserID() (uniqueIds IDS) { return uniqueIds } -// 为每个图像设置用户信息 -func (images *ImageList) SetUser(userList []User) { - for i, image := range *images { - for _, user := range userList { - if image.UserID == user.ID { - (*images)[i].User = user - } - } - } -} - -// 为每个图像设置文章信息 -func (images *ImageList) SetArticle(articleList []Article) { - for i, image := range *images { - for _, article := range articleList { - if image.ArticleID == article.ID { - (*images)[i].Article = article - } - } - } -} - type Image struct { - ID int `json:"id" db:"id"` + ID int `json:"id" db:"id" gorm:"primaryKey"` Width int `json:"width" db:"width"` Height int `json:"height" db:"height"` Content string `json:"content" db:"content"` @@ -93,9 +71,13 @@ type Image struct { UserID int `json:"user_id" db:"user_id"` CreateTime time.Time `json:"create_time" db:"create_time"` UpdateTime time.Time `json:"update_time" db:"update_time"` - Text TextList `json:"text" db:"text"` - User User `json:"user" db:"-"` - Article Article `json:"article" db:"-"` + Text TextList `json:"text" db:"text" gorm:"type:json"` + User User `json:"user" gorm:"foreignKey:UserID"` + Article Article `json:"article" gorm:"foreignKey:ArticleID"` +} + +func (Image) TableName() string { + return "web_images" } type TextList []struct { @@ -109,7 +91,7 @@ func (a *TextList) Scan(value interface{}) error { } type User struct { - ID int `json:"id" db:"id"` + ID int `json:"id" db:"id" gorm:"primaryKey"` UserName *string `json:"user_name" db:"user_name"` Avatar *string `json:"avatar" db:"avatar"` Rank *string `json:"rank" db:"rank"` @@ -117,12 +99,21 @@ type User struct { UpdateTime time.Time `json:"update_time" db:"update_time"` } +func (User) TableName() string { + return "web_member" +} + type Article struct { - ID int `json:"id" db:"id"` - Title string `json:"title" db:"title"` - Tags string `json:"tags" db:"tags"` - CreateTime time.Time `json:"create_time" db:"create_time"` - UpdateTime time.Time `json:"update_time" db:"update_time"` + ID int `json:"id" db:"id" gorm:"primaryKey"` + Title string `json:"title" db:"title"` + Orientation string `json:"orientation" db:"orientation"` + Tags string `json:"tags" db:"tags"` + CreateTime time.Time `json:"create_time" db:"create_time"` + UpdateTime time.Time `json:"update_time" db:"update_time"` +} + +func (Article) TableName() string { + return "web_article" } type Category struct { diff --git a/go.mod b/go.mod index f3c15fa..3d876bf 100644 --- a/go.mod +++ b/go.mod @@ -51,7 +51,9 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.4.2 // indirect + github.com/thoas/go-funk v0.9.3 // indirect github.com/tjfoc/gmsm v1.3.2 // indirect + github.com/yalue/onnxruntime_go v1.12.1 // indirect golang.org/x/net v0.28.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/text v0.18.0 // indirect diff --git a/go.sum b/go.sum index 40b4529..6002e65 100644 --- a/go.sum +++ b/go.sum @@ -167,8 +167,12 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= +github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw= +github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tjfoc/gmsm v1.3.2 h1:7JVkAn5bvUJ7HtU08iW6UiD+UTmJTIToHCfeFzkcCxM= github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w= +github.com/yalue/onnxruntime_go v1.12.1 h1:joCCmBnNjHy04jK9EMP/UV6oPPqySXlRgf3gcUcnI/g= +github.com/yalue/onnxruntime_go v1.12.1/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/models/zincsearch.go b/models/zincsearch.go index 2cb19db..901b3e0 100644 --- a/models/zincsearch.go +++ b/models/zincsearch.go @@ -9,6 +9,7 @@ import ( "log" "net/http" "net/url" + "strconv" ) var ( @@ -63,13 +64,14 @@ type Response struct { } `json:"hits"` } -func (res Response) ToIDList(first, last int, after, before string) (id_list []string) { +func (res Response) ToIDList(first, last int, after, before int) (id_list []int) { for _, hit := range res.Hits.Hits { - id_list = append(id_list, hit.ID) + num, _ := strconv.Atoi(hit.ID) + id_list = append(id_list, num) } // 如果 after 不为 0, 从这个ID开始向后取切片 - if after != "" { + if after != 0 { for i, id := range id_list { if id == after { id_list = id_list[i+1:] @@ -79,7 +81,7 @@ func (res Response) ToIDList(first, last int, after, before string) (id_list []s } // 如果 before 不为 0, 从这个ID开始向前取切片 - if before != "" { + if before != 0 { for i, id := range id_list { if id == before { id_list = id_list[:i]