From ccd773ddad863cdbf0030e5dfad2a43e3e7d22ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A7=89?= Date: Thu, 1 Aug 2024 02:13:02 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=BF=E6=8D=A2=E5=85=B3=E8=81=94ID?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/graphql.go | 117 ++++++++++++++----------------------------------- 1 file changed, 34 insertions(+), 83 deletions(-) diff --git a/api/graphql.go b/api/graphql.go index c270205..c50211e 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -91,6 +91,30 @@ func NewSchema(config Config) (graphql.Schema, error) { }, }) + // 将 list 中的字段提取出来用于查询 + get_fields := func(requestedFields []ast.Selection) (fields []string) { + for _, field := range requestedFields { + fieldAST, ok := field.(*ast.Field) + if ok && fieldAST.Name.Value == "list" { + for _, field := range fieldAST.SelectionSet.Selections { + fieldAST, ok := field.(*ast.Field) + if ok { + if fieldAST.Name.Value == "user" { + fields = append(fields, "user_id") + continue + } + if fieldAST.Name.Value == "article" { + fields = append(fields, "article_id") + continue + } + fields = append(fields, fieldAST.Name.Value) + } + } + } + } + return fields + } + schema, err := graphql.NewSchema(graphql.SchemaConfig{Query: graphql.NewObject(graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ "users": &graphql.Field{ Name: "users", @@ -117,29 +141,7 @@ func NewSchema(config Config) (graphql.Schema, error) { "before": &graphql.ArgumentConfig{Type: graphql.String, Description: "翻页参数(傳回清單中指定遊標之前的元素)"}, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - var fields []string - requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections - for _, field := range requestedFields { - fieldAST, ok := field.(*ast.Field) - if ok { - switch fieldAST.Name.Value { - case "list": - for _, field := range fieldAST.SelectionSet.Selections { - fieldAST, ok := field.(*ast.Field) - if ok { - fields = append(fields, fieldAST.Name.Value) - } - } - case "next": - fmt.Println("next") - case "text": - fmt.Println("text") - default: - fmt.Println(fieldAST.Name.Value) - } - } - } - fields_str := strings.Join(fields, ",") + fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",") var where []string if p.Args["id"] != nil { where = append(where, fmt.Sprintf("id=%d", p.Args["id"])) @@ -157,7 +159,7 @@ func NewSchema(config Config) (graphql.Schema, error) { var users []User var total int - query.WriteString(fmt.Sprintf("SELECT %s FROM web_member %s LIMIT %d OFFSET %d", fields_str, where_str, 10, 0)) + query.WriteString(fmt.Sprintf("SELECT %s FROM web_member %s LIMIT %d OFFSET %d", fields, where_str, 10, 0)) if err := connection.Select(&users, query.String()); err != nil { fmt.Println("获取用户列表失败", err) return nil, err @@ -226,37 +228,6 @@ func NewSchema(config Config) (graphql.Schema, error) { } mapstructure.Decode(p.Args, &args) - // 处理要求返回的字段 - var fields []string - requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections - for _, field := range requestedFields { - fieldAST, ok := field.(*ast.Field) - if ok { - switch fieldAST.Name.Value { - case "list": - for _, field := range fieldAST.SelectionSet.Selections { - fieldAST, ok := field.(*ast.Field) - if ok { - switch fieldAST.Name.Value { - case "user": - fields = append(fields, "user_id") - case "article": - fields = append(fields, "article_id") - default: - fields = append(fields, fieldAST.Name.Value) - } - } - } - case "total": - fmt.Println("total") - default: - fmt.Println(fieldAST.Name.Value) - } - } - } - - fields_str := strings.Join(fields, ",") - // 参数到 SQL 格式字符串的映射 var argToSQLFormat = map[string]string{ "id": "id=%d", @@ -361,8 +332,8 @@ func NewSchema(config Config) (graphql.Schema, error) { // 执行查询 var query strings.Builder - query.WriteString(fmt.Sprintf("SELECT %s FROM web_images %s", fields_str, where_str)) - log.Println("query:", query.String()) + fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",") + query.WriteString(fmt.Sprintf("SELECT %s FROM web_images %s", fields, where_str)) var images ImageList if err := connection.Select(&images, query.String()); err != nil { @@ -376,7 +347,7 @@ func NewSchema(config Config) (graphql.Schema, error) { } // 获取用户信息(如果图像列表不为空且请求字段中包含user) - if len(images) > 0 && strings.Contains(fields_str, "user") { + if len(images) > 0 && strings.Contains(fields, "user") { user_ids_str := images.ToAllUserID().ToString() var users []User if err := connection.Select(&users, fmt.Sprintf("SELECT id,user_name,avatar,rank,create_time,update_time FROM web_member WHERE id IN (%s)", user_ids_str)); err != nil { @@ -388,7 +359,7 @@ func NewSchema(config Config) (graphql.Schema, error) { } // 获取文章信息(如果图像列表不为空且请求字段中包含article) - if len(images) > 0 && strings.Contains(fields_str, "article") { + if len(images) > 0 && strings.Contains(fields, "article") { article_ids_str := images.ToAllArticleID().ToString() var articles []Article if err := connection.Select(&articles, fmt.Sprintf("SELECT id,title,tags,create_time,update_time FROM web_article WHERE id IN (%s)", article_ids_str)); err != nil { @@ -424,29 +395,9 @@ func NewSchema(config Config) (graphql.Schema, error) { "update_time": &graphql.ArgumentConfig{Type: graphql.DateTime, Description: "筛选文章中更新时间等于指定值的"}, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - var fields []string - requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections - for _, field := range requestedFields { - fieldAST, ok := field.(*ast.Field) - if ok { - switch fieldAST.Name.Value { - case "list": - for _, field := range fieldAST.SelectionSet.Selections { - fieldAST, ok := field.(*ast.Field) - if ok { - fields = append(fields, fieldAST.Name.Value) - } - } - case "total": - fmt.Println("total") - default: - fmt.Println(fieldAST.Name.Value) - } - } - } - first := p.Args["first"] - after := p.Args["after"] - fields_str := strings.Join(fields, ",") + first := 10 // p.Args["first"] + after := 0 // p.Args["after"] + fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",") var where []string if p.Args["id"] != nil { where = append(where, fmt.Sprintf("id=%d", p.Args["id"])) @@ -461,7 +412,7 @@ func NewSchema(config Config) (graphql.Schema, error) { } var query strings.Builder - query.WriteString(fmt.Sprintf("SELECT %s FROM web_article %s LIMIT %d OFFSET %s", fields_str, where_str, first, after)) + query.WriteString(fmt.Sprintf("SELECT %s FROM web_article %s LIMIT %d OFFSET %d", fields, where_str, first, after)) // 返回翻页信息 var articles []Article if err := connection.Select(&articles, query.String()); err != nil {