diff --git a/api/graphql.go b/api/graphql.go index 4eb0eb9..2ebaf62 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -83,6 +83,16 @@ func generateGraphQLType(model interface{}) (*graphql.Object, error) { }), nil } +// 判断指定字段是否存在 +func existField(selections []ast.Selection, name string) bool { + for _, field := range selections { + if f, ok := field.(*ast.Field); ok && f.Name.Value == name { + return true + } + } + return false +} + func NewSchema(config Config) (graphql.Schema, error) { if db, err = gorm.Open(mysql.Open(fmt.Sprintf( @@ -571,11 +581,9 @@ func NewSchema(config Config) (graphql.Schema, error) { } // 如果查询了 total 字段 - if p.Info.FieldASTs[0].SelectionSet.Selections[1] != nil { + if existField(p.Info.FieldASTs[0].SelectionSet.Selections, "total") { sql, _, _ := query.ToSQL() - fmt.Println(sql) sql = strings.Replace(sql, "SELECT *", "SELECT COUNT(*)", 1) - fmt.Println(sql) if err := db.Raw(sql).Scan(&total).Error; err != nil { return nil, err } @@ -1181,15 +1189,7 @@ func NewSchema(config Config) (graphql.Schema, error) { } } - existField := func(selections []ast.Selection, name string) bool { - for _, field := range selections { - if f, ok := field.(*ast.Field); ok && f.Name.Value == name { - return true - } - } - return false - } - + // 如果查询了 total 字段 if existField(p.Info.FieldASTs[0].SelectionSet.Selections, "total") { sql, _, _ := query.ToSQL() sql = strings.Replace(sql, "SELECT *", "SELECT COUNT(*)", 1)