diff --git a/api/graphql.go b/api/graphql.go index dbcb7df..858800f 100644 --- a/api/graphql.go +++ b/api/graphql.go @@ -173,6 +173,22 @@ func NewSchema(config Config) (graphql.Schema, error) { "create_time": &graphql.Field{Type: graphql.DateTime, Description: "图像创建时间"}, "update_time": &graphql.Field{Type: graphql.DateTime, Description: "图像更新时间"}, "article": &graphql.Field{Type: article, Description: "图像所属文章"}, + "demo": &graphql.Field{Type: graphql.String, Description: "demo", Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return "ok", nil + }}, + "praise": &graphql.Field{Type: graphql.Boolean, Description: "当前用户是否点赞", Resolve: func(p graphql.ResolveParams) (interface{}, error) { + var user_id = p.Context.Value("user_id").(int) + if user_id != 0 { + var praise int64 + if err := db.Table("web_praise").Where("user_id = ?", user_id).Where("image_id = ?", p.Source.(Image).ID).Count(&praise); err != nil { + return false, nil + } + if praise > 0 { + return true, nil + } + } + return false, nil + }}, "text": &graphql.Field{ Type: graphql.NewList(graphql.NewObject(graphql.ObjectConfig{ Name: "Text", @@ -409,13 +425,9 @@ func NewSchema(config Config) (graphql.Schema, error) { Orientation string Sort string Order string - CreateTime string `json:"create_time"` - UpdateTime string } mapstructure.Decode(p.Args, &args) - fmt.Println("args.CreateTime:", args.CreateTime) - // 限制长度防止全表扫描 var limit = 10 if args.First != 0 { diff --git a/bin/main.go b/bin/main.go index d3225ee..99e1f07 100644 --- a/bin/main.go +++ b/bin/main.go @@ -63,7 +63,29 @@ func LogComponent(startTime int64, r *http.Request) { func LogRequest(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer LogComponent(time.Now().UnixNano(), r) // 最后打印日志 - next.ServeHTTP(w, r) + + var user_id int + if token := r.Header.Get("token"); token != "" { + fmt.Println("token:", token) + rows, err := mysqlConnection.Database.Query("SELECT user_id FROM web_auth WHERE token = ? LIMIT 1", token) + if err != nil { + log.Println("查询失败:", err) + return + } + defer rows.Close() + if rows.Next() { + err = rows.Scan(&user_id) + if err != nil { + log.Println("扫描失败:", err) + return + } + } + } + + ctx := context.WithValue(r.Context(), "user_id", user_id) + fmt.Println("user_id:", user_id) + + next.ServeHTTP(w, r.WithContext(ctx)) }) }