diff --git a/api/graphql.go b/api/graphql.go index 7d0fb6a..7911d8f 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -17,6 +17,30 @@ import ( "gorm.io/gorm" ) +var db *gorm.DB +var err error + +func LoadItem(requestedFields []ast.Selection) (data []string) { + var items = []string{"user", "article"} + for _, field := range requestedFields { + fieldAST, _ := field.(*ast.Field) + if funk.Contains(items, fieldAST.Name.Value) { + data = append(data, fieldAST.Name.Value) + for _, str := range LoadItem(fieldAST.SelectionSet.Selections) { + str = strings.ToUpper(string(str[0])) + str[1:] + data = append(data, fieldAST.Name.Value+"."+str) + } + } + if fieldAST.Name.Value == "list" { + for _, str := range LoadItem(fieldAST.SelectionSet.Selections) { + str = strings.ToUpper(string(str[0])) + str[1:] + data = append(data, str) + } + } + } + return data +} + // 自动生成 GraphQL 类型的函数 func generateGraphQLType(model interface{}) (*graphql.Object, error) { modelType := reflect.TypeOf(model) @@ -52,7 +76,7 @@ func generateGraphQLType(model interface{}) (*graphql.Object, error) { func NewSchema(config Config) (graphql.Schema, error) { - db, err := gorm.Open(mysql.Open(fmt.Sprintf( + db, err = gorm.Open(mysql.Open(fmt.Sprintf( "%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.Mysql.UserName, config.Mysql.Password, @@ -88,7 +112,7 @@ func NewSchema(config Config) (graphql.Schema, error) { "user_name": &graphql.Field{Type: graphql.String, Description: "用户名"}, "avatar": &graphql.Field{Type: graphql.String, Description: "用户头像"}, "rank": &graphql.Field{Type: graphql.String, Description: "用户等级"}, - "price": &graphql.Field{Type: graphql.Float, Description: "用户金币"}, + "price": &graphql.Field{Type: graphql.Int, Description: "用户金币"}, "create_time": &graphql.Field{Type: graphql.DateTime, Description: "用户创建时间"}, "update_time": &graphql.Field{Type: graphql.DateTime, Description: "用户更新时间"}, }, @@ -190,13 +214,11 @@ func NewSchema(config Config) (graphql.Schema, error) { // 将 list 中的字段提取出来用于查询 get_fields := func(requestedFields []ast.Selection) (fields []string) { for _, field := range requestedFields { - fieldAST, ok := field.(*ast.Field) - if ok && fieldAST.Name.Value == "list" { + fieldAST, _ := field.(*ast.Field) + if fieldAST.Name.Value == "list" { for _, field := range fieldAST.SelectionSet.Selections { - fieldAST, ok := field.(*ast.Field) - if ok { - fields = append(fields, fieldAST.Name.Value) - } + fieldAST, _ := field.(*ast.Field) + fields = append(fields, fieldAST.Name.Value) } } } @@ -380,7 +402,6 @@ func NewSchema(config Config) (graphql.Schema, error) { var total int var images []Image - var fields = get_fields(p.Info.FieldASTs[0].SelectionSet.Selections) var query = db.Limit(limit) // 参数映射 @@ -518,16 +539,12 @@ func NewSchema(config Config) (graphql.Schema, error) { query = query.Where("id IN ?", list) } + for index, item := range LoadItem(p.Info.FieldASTs[0].SelectionSet.Selections) { + fmt.Println(index, item) + query = query.Preload(item) + } + // 输出 - if funk.Contains(fields, "user") { - query = query.Preload("User") - fmt.Println("加载 user") - } - - if funk.Contains(fields, "article") { - query = query.Preload("Article") - } - if err := query.Find(&images).Error; err != nil { fmt.Println("获取图像列表失败", err) return nil, err diff --git a/api/struct.go b/api/struct.go index 7195ecc..9bdb8a5 100644 --- a/api/struct.go +++ b/api/struct.go @@ -92,9 +92,10 @@ func (a *TextList) Scan(value interface{}) error { type User struct { ID int `json:"id" db:"id" gorm:"primaryKey"` - UserName *string `json:"user_name" db:"user_name"` - Avatar *string `json:"avatar" db:"avatar"` - Rank *string `json:"rank" db:"rank"` + UserName string `json:"user_name" db:"user_name"` + Avatar string `json:"avatar" db:"avatar"` + Rank string `json:"rank" db:"rank"` + Price int `json:"price" db:"price"` CreateTime time.Time `json:"create_time" db:"create_time"` UpdateTime time.Time `json:"update_time" db:"update_time"` } @@ -108,6 +109,8 @@ type Article struct { Title string `json:"title" db:"title"` Orientation string `json:"orientation" db:"orientation"` Tags string `json:"tags" db:"tags"` + UserId int `json:"user_id" db:"user_id"` + User User `json:"user" gorm:"foreignKey:UserId"` CreateTime time.Time `json:"create_time" db:"create_time"` UpdateTime time.Time `json:"update_time" db:"update_time"` }