替换关联ID

This commit is contained in:
2024-08-01 02:13:02 +08:00
parent afa866e6fe
commit ccd773ddad

View File

@@ -91,6 +91,30 @@ func NewSchema(config Config) (graphql.Schema, error) {
}, },
}) })
// 将 list 中的字段提取出来用于查询
get_fields := func(requestedFields []ast.Selection) (fields []string) {
for _, field := range requestedFields {
fieldAST, ok := field.(*ast.Field)
if ok && fieldAST.Name.Value == "list" {
for _, field := range fieldAST.SelectionSet.Selections {
fieldAST, ok := field.(*ast.Field)
if ok {
if fieldAST.Name.Value == "user" {
fields = append(fields, "user_id")
continue
}
if fieldAST.Name.Value == "article" {
fields = append(fields, "article_id")
continue
}
fields = append(fields, fieldAST.Name.Value)
}
}
}
}
return fields
}
schema, err := graphql.NewSchema(graphql.SchemaConfig{Query: graphql.NewObject(graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ schema, err := graphql.NewSchema(graphql.SchemaConfig{Query: graphql.NewObject(graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{
"users": &graphql.Field{ "users": &graphql.Field{
Name: "users", Name: "users",
@@ -117,29 +141,7 @@ func NewSchema(config Config) (graphql.Schema, error) {
"before": &graphql.ArgumentConfig{Type: graphql.String, Description: "翻页参数(傳回清單中指定遊標之前的元素)"}, "before": &graphql.ArgumentConfig{Type: graphql.String, Description: "翻页参数(傳回清單中指定遊標之前的元素)"},
}, },
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
var fields []string fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",")
requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections
for _, field := range requestedFields {
fieldAST, ok := field.(*ast.Field)
if ok {
switch fieldAST.Name.Value {
case "list":
for _, field := range fieldAST.SelectionSet.Selections {
fieldAST, ok := field.(*ast.Field)
if ok {
fields = append(fields, fieldAST.Name.Value)
}
}
case "next":
fmt.Println("next")
case "text":
fmt.Println("text")
default:
fmt.Println(fieldAST.Name.Value)
}
}
}
fields_str := strings.Join(fields, ",")
var where []string var where []string
if p.Args["id"] != nil { if p.Args["id"] != nil {
where = append(where, fmt.Sprintf("id=%d", p.Args["id"])) where = append(where, fmt.Sprintf("id=%d", p.Args["id"]))
@@ -157,7 +159,7 @@ func NewSchema(config Config) (graphql.Schema, error) {
var users []User var users []User
var total int var total int
query.WriteString(fmt.Sprintf("SELECT %s FROM web_member %s LIMIT %d OFFSET %d", fields_str, where_str, 10, 0)) query.WriteString(fmt.Sprintf("SELECT %s FROM web_member %s LIMIT %d OFFSET %d", fields, where_str, 10, 0))
if err := connection.Select(&users, query.String()); err != nil { if err := connection.Select(&users, query.String()); err != nil {
fmt.Println("获取用户列表失败", err) fmt.Println("获取用户列表失败", err)
return nil, err return nil, err
@@ -226,37 +228,6 @@ func NewSchema(config Config) (graphql.Schema, error) {
} }
mapstructure.Decode(p.Args, &args) mapstructure.Decode(p.Args, &args)
// 处理要求返回的字段
var fields []string
requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections
for _, field := range requestedFields {
fieldAST, ok := field.(*ast.Field)
if ok {
switch fieldAST.Name.Value {
case "list":
for _, field := range fieldAST.SelectionSet.Selections {
fieldAST, ok := field.(*ast.Field)
if ok {
switch fieldAST.Name.Value {
case "user":
fields = append(fields, "user_id")
case "article":
fields = append(fields, "article_id")
default:
fields = append(fields, fieldAST.Name.Value)
}
}
}
case "total":
fmt.Println("total")
default:
fmt.Println(fieldAST.Name.Value)
}
}
}
fields_str := strings.Join(fields, ",")
// 参数到 SQL 格式字符串的映射 // 参数到 SQL 格式字符串的映射
var argToSQLFormat = map[string]string{ var argToSQLFormat = map[string]string{
"id": "id=%d", "id": "id=%d",
@@ -361,8 +332,8 @@ func NewSchema(config Config) (graphql.Schema, error) {
// 执行查询 // 执行查询
var query strings.Builder var query strings.Builder
query.WriteString(fmt.Sprintf("SELECT %s FROM web_images %s", fields_str, where_str)) fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",")
log.Println("query:", query.String()) query.WriteString(fmt.Sprintf("SELECT %s FROM web_images %s", fields, where_str))
var images ImageList var images ImageList
if err := connection.Select(&images, query.String()); err != nil { if err := connection.Select(&images, query.String()); err != nil {
@@ -376,7 +347,7 @@ func NewSchema(config Config) (graphql.Schema, error) {
} }
// 获取用户信息(如果图像列表不为空且请求字段中包含user) // 获取用户信息(如果图像列表不为空且请求字段中包含user)
if len(images) > 0 && strings.Contains(fields_str, "user") { if len(images) > 0 && strings.Contains(fields, "user") {
user_ids_str := images.ToAllUserID().ToString() user_ids_str := images.ToAllUserID().ToString()
var users []User var users []User
if err := connection.Select(&users, fmt.Sprintf("SELECT id,user_name,avatar,rank,create_time,update_time FROM web_member WHERE id IN (%s)", user_ids_str)); err != nil { if err := connection.Select(&users, fmt.Sprintf("SELECT id,user_name,avatar,rank,create_time,update_time FROM web_member WHERE id IN (%s)", user_ids_str)); err != nil {
@@ -388,7 +359,7 @@ func NewSchema(config Config) (graphql.Schema, error) {
} }
// 获取文章信息(如果图像列表不为空且请求字段中包含article) // 获取文章信息(如果图像列表不为空且请求字段中包含article)
if len(images) > 0 && strings.Contains(fields_str, "article") { if len(images) > 0 && strings.Contains(fields, "article") {
article_ids_str := images.ToAllArticleID().ToString() article_ids_str := images.ToAllArticleID().ToString()
var articles []Article var articles []Article
if err := connection.Select(&articles, fmt.Sprintf("SELECT id,title,tags,create_time,update_time FROM web_article WHERE id IN (%s)", article_ids_str)); err != nil { if err := connection.Select(&articles, fmt.Sprintf("SELECT id,title,tags,create_time,update_time FROM web_article WHERE id IN (%s)", article_ids_str)); err != nil {
@@ -424,29 +395,9 @@ func NewSchema(config Config) (graphql.Schema, error) {
"update_time": &graphql.ArgumentConfig{Type: graphql.DateTime, Description: "筛选文章中更新时间等于指定值的"}, "update_time": &graphql.ArgumentConfig{Type: graphql.DateTime, Description: "筛选文章中更新时间等于指定值的"},
}, },
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
var fields []string first := 10 // p.Args["first"]
requestedFields := p.Info.FieldASTs[0].SelectionSet.Selections after := 0 // p.Args["after"]
for _, field := range requestedFields { fields := strings.Join(get_fields(p.Info.FieldASTs[0].SelectionSet.Selections), ",")
fieldAST, ok := field.(*ast.Field)
if ok {
switch fieldAST.Name.Value {
case "list":
for _, field := range fieldAST.SelectionSet.Selections {
fieldAST, ok := field.(*ast.Field)
if ok {
fields = append(fields, fieldAST.Name.Value)
}
}
case "total":
fmt.Println("total")
default:
fmt.Println(fieldAST.Name.Value)
}
}
}
first := p.Args["first"]
after := p.Args["after"]
fields_str := strings.Join(fields, ",")
var where []string var where []string
if p.Args["id"] != nil { if p.Args["id"] != nil {
where = append(where, fmt.Sprintf("id=%d", p.Args["id"])) where = append(where, fmt.Sprintf("id=%d", p.Args["id"]))
@@ -461,7 +412,7 @@ func NewSchema(config Config) (graphql.Schema, error) {
} }
var query strings.Builder var query strings.Builder
query.WriteString(fmt.Sprintf("SELECT %s FROM web_article %s LIMIT %d OFFSET %s", fields_str, where_str, first, after)) query.WriteString(fmt.Sprintf("SELECT %s FROM web_article %s LIMIT %d OFFSET %d", fields, where_str, first, after))
// 返回翻页信息 // 返回翻页信息
var articles []Article var articles []Article
if err := connection.Select(&articles, query.String()); err != nil { if err := connection.Select(&articles, query.String()); err != nil {