diff --git a/README.md b/README.md index 3110a30..c8bb470 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ TEST: - [x] GET [/api/images](/api/images) 圖片列表 - [x] GET [/api/tasks](/api/tasks) 任務列表 - [x] GET [/api/tags](/api/tags) 標籤列表 +- [x] GET [/api/params](/api/params) 參數列表 +- [x] GET [/api/account](/api/account) 賬戶信息 TEST: diff --git a/main.go b/main.go index 86d1bf4..6379d41 100644 --- a/main.go +++ b/main.go @@ -97,6 +97,7 @@ func main() { r.HandleFunc("/api/servers/{id}", routers.ServersItemDelete).Methods("DELETE") r.HandleFunc("/api/params/model", routers.ParamsModelsGet).Methods("GET") + r.HandleFunc("/api/account", routers.AccountGet).Methods("GET") log.Println("Web Server is running on http://localhost:8080") http.ListenAndServe(":8080", r) diff --git a/models/User.go b/models/User.go index 4b60e36..a654241 100644 --- a/models/User.go +++ b/models/User.go @@ -38,7 +38,6 @@ func (user *User) Create(name, email, password string) error { return err } defer db.Close() - fmt.Println(user) stmt, err := db.Prepare("INSERT INTO users(name, email, password, slat, created_at, updated_at) values(?, ?, ?, ?, ?, ?)") if err != nil { log.Println(err) @@ -100,6 +99,21 @@ func (user *User) Update() error { return nil } +func (user *User) RoadByID(id int) (err error) { + db, err := configs.GetDB() + if err != nil { + log.Println(err) + return err + } + defer db.Close() + err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE id = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) + if err != nil { + log.Println(err) + return err + } + return nil +} + func (user *User) Get() error { db, err := configs.GetDB() if err != nil { @@ -107,7 +121,7 @@ func (user *User) Get() error { return err } defer db.Close() - err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE email = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) + err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE id = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) if err != nil { log.Println(err) return err diff --git a/models/session.go b/models/session.go index 7cd1e13..903f145 100644 --- a/models/session.go +++ b/models/session.go @@ -13,7 +13,7 @@ type Session struct { UpdatedAt string `json:"updated_at"` } -func (session *Session) Get() { +func (session *Session) Get() (err error) { db, err := configs.GetDB() if err != nil { log.Println(err) @@ -26,6 +26,7 @@ func (session *Session) Get() { log.Println(err) return } + return } func (session *Session) Create() error { @@ -94,31 +95,6 @@ func (session *Session) Update() error { return nil } -//func GetSessions() ([]Session, error) { -// db, err := configs.GetDB() -// if err != nil { -// log.Println(err) -// return nil, err -// } -// defer db.Close() -// rows, err := db.Query("SELECT id, name FROM sessions") -// if err != nil { -// log.Println(err) -// return nil, err -// } -// defer rows.Close() -// sessions := []Session{} -// for rows.Next() { -// var session Session -// if err := rows.Scan(&session.ID, &session.Name); err != nil { -// log.Println(err) -// return nil, err -// } -// sessions = append(sessions, session) -// } -// return sessions, nil -//} - func GetSession(id int) (*Session, error) { db, err := configs.GetDB() if err != nil { diff --git a/routers/account.go b/routers/account.go new file mode 100644 index 0000000..7b4cfc7 --- /dev/null +++ b/routers/account.go @@ -0,0 +1,45 @@ +package routers + +import ( + "fmt" + "main/models" + "main/utils" + "net/http" +) + +// 獲取當前賬戶信息(重寫, 爲輸出增加sid字段) +func AccountGet(w http.ResponseWriter, r *http.Request) { + var account struct { + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + SessionID string `json:"session_id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + } + + // 獲取Cookie + cookie, err := r.Cookie("session_id") + if err != nil { + fmt.Println(err) + return + } + + // 獲取會話 + session := models.Session{ID: cookie.Value} + session.Get() + + // 獲取用戶 + user := models.User{ID: session.UserID} + user.Get() + + account.ID = user.ID + account.Name = user.Name + account.Email = user.Email + account.SessionID = session.ID + account.CreatedAt = user.CreatedAt + account.UpdatedAt = user.UpdatedAt + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(account)) +} diff --git a/routers/sessions.go b/routers/sessions.go index a7a260e..9ea2879 100644 --- a/routers/sessions.go +++ b/routers/sessions.go @@ -84,7 +84,7 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) { // 獲取會話 func SessionsItemGet(w http.ResponseWriter, r *http.Request) { - session := models.Session{ID: mux.Vars(r)["id"]} + session := models.Session{ID: mux.Vars(r)["session_id"]} session.Get() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(session)) @@ -92,7 +92,7 @@ func SessionsItemGet(w http.ResponseWriter, r *http.Request) { // 更新會話 func SessionsItemPatch(w http.ResponseWriter, r *http.Request) { - session := models.Session{ID: mux.Vars(r)["id"]} + session := models.Session{ID: mux.Vars(r)["session_id"]} session.Get() session.Update() w.Header().Set("Content-Type", "application/json; charset=utf-8")