diff --git a/README.md b/README.md index d70a1df..5389b65 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # ai 繪圖 +- [ ] 注册用户前验证手机号或者邮箱是否已经存在 +- [ ] 使用验证码登录功能同时创建账户 + + 用戶: ```go diff --git a/main.go b/main.go index a6810d9..6828a8f 100644 --- a/main.go +++ b/main.go @@ -12,84 +12,6 @@ import ( "github.com/gorilla/mux" ) -//// 獲取查詢參數(int64 類型) -//func ParamInt(value string, defaultValue int64) int64 { -// if value == "" { -// return defaultValue -// } -// if v, err := strconv.ParseInt(value, 10, 64); err == nil { -// return v -// } -// return defaultValue -//} -// -//func getlist(w http.ResponseWriter, r *http.Request) { -// listview := struct { -// Name string `json:"name"` -// Page int64 `json:"page"` -// PageSize int64 `json:"pageSize"` -// Total int64 `json:"total"` -// Next bool `json:"next"` -// List interface{} `json:"list"` -// }{ -// Name: mux.Vars(r)["name"], -// Page: ParamInt(r.URL.Query().Get("page"), 1), -// PageSize: ParamInt(r.URL.Query().Get("pageSize"), 10), -// Total: 0, -// Next: false, -// List: make([]interface{}, 0), -// } -// -// // 选择对应的模型 -// switch listview.Name { -// case "sessions": -// listview.List = []models.Session{} -// case "users": -// listview.List = []models.User{} -// case "models": -// listview.List = []models.Model{} -// case "images": -// listview.List = []models.Image{} -// case "tags": -// listview.List = []models.Tag{} -// case "servers": -// listview.List = []models.Server{} -// case "datasets": -// listview.List = []models.Dataset{} -// default: -// fmt.Fprintf(w, "404") -// return -// } -// -// // 从数据库中获取数据 -// db := configs.ORMDB() -// //if task := r.URL.Query().Get("task"); task != "" { -// // db = db.Where("task = ?", task) -// //} -// //if status := r.URL.Query().Get("status"); status != "" { -// // db = db.Where("status = ?", status) -// //} -// //if user_id := r.URL.Query().Get("user_id"); user_id != "" { -// // db = db.Where("user_id = ?", user_id) -// //} -// //if model_id := r.URL.Query().Get("model_id"); model_id != "" { -// // db = db.Where("model_id = ?", model_id) -// //} -// //// 获取指定用户喜欢的对象(图像) -// //if like := r.URL.Query().Get("like"); like != "" { -// // list, _ := models.LikeImage.GetA(like) -// // db = db.Where("id in (?)", list) -// //} -// // 分页数据 -// db = db.Offset(int((listview.Page - 1) * listview.PageSize)).Limit(int(listview.PageSize)) -// db.Find(&listview.List).Count(&listview.Total) -// -// // 轉換爲JSON並返回 -// data, _ := json.MarshalIndent(listview, "", " ") -// w.Header().Set("Content-Type", "application/json; charset=utf-8") -// w.Write(data) -//} - func main() { runtime.GOMAXPROCS(runtime.NumCPU()) log.SetFlags(log.Lshortfile | log.LstdFlags) @@ -114,7 +36,6 @@ func main() { // 設定路由 r.HandleFunc("/api", routers.GetDocs).Methods("GET") - //r.HandleFunc("/api/{name}", getlist).Methods("GET") r.HandleFunc("/api/sessions", routers.SessionsGet).Methods("GET") r.HandleFunc("/api/sessions", routers.SessionsPost).Methods("POST") diff --git a/models/User.go b/models/User.go index 0abff8b..a4e74ba 100644 --- a/models/User.go +++ b/models/User.go @@ -12,6 +12,7 @@ type User struct { Gold int `json:"gold"` Name string `json:"name"` Email string `json:"email" gorm:"unique;not null"` + Mobile string `json:"mobile" gorm:"unique;not null"` Password string `json:"-"` Slat string `json:"-"` Admin bool `json:"admin"` diff --git a/models/code.go b/models/code.go new file mode 100644 index 0000000..e858985 --- /dev/null +++ b/models/code.go @@ -0,0 +1,48 @@ +package models + +import ( + "fmt" + "main/configs" + "math/rand" + "time" +) + +type Code struct { + Email string `json:"email"` + Mobile string `json:"mobile"` + Code string `json:"code"` + Expire time.Time `json:"expire"` +} + +func init() { + configs.ORMDB().AutoMigrate(&Code{}) +} + +func CodeCreate(email, mobile string) string { + code := fmt.Sprintf("%06v", rand.New(rand.NewSource(time.Now().UnixNano())).Int31n(1000000)) + configs.ORMDB().Create(&Code{ + Email: email, + Mobile: mobile, + Code: code, + Expire: time.Now().Add(time.Minute * 5), + }) + return code +} + +func EmailCheck(email, code string) (err error) { + var data Code + configs.ORMDB().Where("email = ?", email).First(&data) + if data.Code == code && data.Expire.After(time.Now()) { + return nil + } + return fmt.Errorf("验证码错误") +} + +func MobileCheck(mobile, code string) (err error) { + var data Code + configs.ORMDB().Where("mobile = ?", mobile).First(&data) + if data.Code == code && data.Expire.After(time.Now()) { + return nil + } + return fmt.Errorf("验证码错误") +} diff --git a/routers/users.go b/routers/users.go index 0d78ea2..6d5abe9 100644 --- a/routers/users.go +++ b/routers/users.go @@ -2,6 +2,7 @@ package routers import ( "crypto/md5" + "encoding/json" "fmt" "main/configs" "main/models" @@ -13,7 +14,7 @@ import ( "github.com/gorilla/mux" ) -// 用戶列表 +// 获取用戶列表 func UsersGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) @@ -28,34 +29,59 @@ func UsersGet(w http.ResponseWriter, r *http.Request) { // 創建用戶 func UsersPost(w http.ResponseWriter, r *http.Request) { - var form map[string]interface{} = utils.BodyRead(r) - if form["name"] == nil || form["email"] == nil || form["password"] == nil { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("400 - name, email, password cannot be empty")) + var data struct { + Name string `json:"name"` + Email string `json:"email"` + Mobile string `json:"mobile"` + Password string `json:"password"` + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } - // 創建用戶 - var slat string = uuid.New().String() - var user models.User = models.User{ - Name: form["name"].(string), - Email: form["email"].(string), - Password: fmt.Sprintf("%x", md5.Sum([]byte(form["password"].(string)+slat))), - Slat: slat, - } - // 檢查郵箱是否已經存在, 郵箱不能重複 + var user models.User var count int64 - configs.ORMDB().Model(&models.User{}).Where("email = ?", user.Email).Count(&count) - if count > 0 { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("400 - email already exists")) - return + // 如果是帐号密码注册 + if data.Name != "" && data.Password != "" { + user.Name = data.Name + user.Slat = uuid.New().String() + user.Password = fmt.Sprintf("%x", md5.Sum([]byte(data.Password+user.Slat))) + configs.ORMDB().Model(&models.User{}).Where("name = ?", user.Name).Count(&count) + if count > 0 { + http.Error(w, "用户名已存在", http.StatusBadRequest) + return + } } - // 檢查用戶名是否已經存在, 用戶名不能重複 - configs.ORMDB().Model(&models.User{}).Where("name = ?", user.Name).Count(&count) - if count > 0 { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("400 - name already exists")) - return + // 如果是邮箱验证码注册 + if data.Email != "" && data.Code != "" { + // 检查验证码是否正确 + if err := models.EmailCheck(data.Email, data.Code); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + user.Email = data.Email + user.Name = fmt.Sprintf("user_%s", uuid.New().String()) // 设置一个随机用户名 + configs.ORMDB().Model(&models.User{}).Where("email = ?", user.Email).Count(&count) + if count > 0 { + http.Error(w, "邮箱已存在", http.StatusBadRequest) + return + } + } + // 如果是短信验证码注册 + if data.Mobile != "" && data.Code != "" { + // 检查验证码是否正确 + if err := models.MobileCheck(data.Mobile, data.Code); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + user.Mobile = data.Mobile + user.Name = fmt.Sprintf("user_%s", uuid.New().String()) // 设置一个随机用户名 + configs.ORMDB().Model(&models.User{}).Where("mobile = ?", user.Mobile).Count(&count) + if count > 0 { + http.Error(w, "手机号已存在", http.StatusBadRequest) + return + } } // 寫入數據庫 if err := configs.ORMDB().Create(&user).Error; err != nil { diff --git a/utils/params.go b/utils/params.go index 7876e7c..6ae4ff7 100644 --- a/utils/params.go +++ b/utils/params.go @@ -3,7 +3,7 @@ package utils import ( "encoding/json" "fmt" - "io/ioutil" + "io" "log" "math/rand" "net/http" @@ -12,7 +12,7 @@ import ( ) func BodyRead(r *http.Request) (form map[string]interface{}) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Println(err) return