diff --git a/configs/sqlite3.go b/configs/sqlite3.go index 3915353..51e7fdb 100644 --- a/configs/sqlite3.go +++ b/configs/sqlite3.go @@ -1,7 +1,6 @@ package configs import ( - "database/sql" "log" "os" @@ -10,9 +9,7 @@ import ( "gorm.io/gorm" ) -// 使用SQLite3初始化數據庫 func init() { - // 設置日誌顯示文件名和行號 log.SetFlags(log.Lshortfile | log.LstdFlags) @@ -20,91 +17,16 @@ func init() { if _, err := os.Stat("data"); os.IsNotExist(err) { os.Mkdir("data", os.ModePerm) } - - // 初始化數據庫 - db, err := sql.Open("sqlite3", "data/sqlite3.db") - if err != nil { - log.Fatal(err) - } - - // 一次性創建多個數據表(自增主鍵) - _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS images( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - width INTEGER, - height INTEGER, - prompt TEXT, - negative_prompt TEXT, - num_inference_steps INTEGER, - guidance_scale REAL, - scheduler TEXT, - seed INTEGER, - from_image TEXT, - created_at TEXT, - updated_at TEXT, - user_id INTEGER - ); - CREATE TABLE IF NOT EXISTS models( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - type TEXT, - trigger_words TEXT, - base_model TEXT, - model_path TEXT, - status TEXT, - progress INTEGER, - image TEXT, - tags TEXT, - created_at TEXT, - updated_at TEXT, - user_id INTEGER - ); - CREATE TABLE IF NOT EXISTS users( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - email TEXT, - password TEXT, - slat TEXT, - created_at TEXT, - updated_at TEXT - ); - CREATE TABLE IF NOT EXISTS tags( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - created_at TEXT, - updated_at TEXT - ); - CREATE TABLE IF NOT EXISTS tasks( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - status TEXT, - progress INTEGER, - created_at TEXT, - updated_at TEXT, - user_id INTEGER - ); - CREATE TABLE IF NOT EXISTS sessions( - id TEXT PRIMARY KEY, - user_id INTEGER, - created_at TEXT, - updated_at TEXT - ); - `) - defer db.Close() - if err != nil { - log.Fatal(err) - } } -// GetDB 獲取數據庫連接 -func GetDB() (*sql.DB, error) { - db, err := sql.Open("sqlite3", "data/sqlite3.db") - if err != nil { - return nil, err - } - return db, nil -} +//// GetDB 獲取數據庫連接 +//func GetDB() (*sql.DB, error) { +// db, err := sql.Open("sqlite3", "data/sqlite3.db") +// if err != nil { +// return nil, err +// } +// return db, nil +//} // ORMDB 使用 GORM func ORMDB() (db *gorm.DB) { diff --git a/models/Image.go b/models/Image.go index d23fc4b..14c3b14 100644 --- a/models/Image.go +++ b/models/Image.go @@ -1,12 +1,11 @@ package models import ( - "log" "main/configs" ) type Image struct { - ID int `json:"id"` + ID int `json:"id" gorm:"primary_key"` Name string `json:"name"` Width int `json:"width"` Height int `json:"height"` @@ -22,125 +21,6 @@ type Image struct { UserID int `json:"user_id"` } -func (image *Image) Create() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("INSERT INTO images(name, created_at, updated_at) values(?, ?, ?)") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - result, err := stmt.Exec(image.Name, image.CreatedAt, image.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - id, err := result.LastInsertId() - if err != nil { - return err - } - image.ID = int(id) - return nil -} - -func (image *Image) Delete() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("DELETE FROM images WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(image.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (image *Image) Update() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE images SET name = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(image.Name, image.UpdatedAt, image.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (image *Image) Get() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - err = db.QueryRow("SELECT id, name, created_at, updated_at FROM images WHERE id = ?", image.ID).Scan(&image.ID, &image.Name, &image.CreatedAt, &image.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func QueryImages(page int, pagesize int) (images []interface{}) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - rows, err := db.Query("SELECT id, name, created_at, updated_at FROM images LIMIT ?, ?", (page-1)*pagesize, pagesize) - if err != nil { - log.Println(err) - return - } - defer rows.Close() - for rows.Next() { - image := Image{} - err := rows.Scan(&image.ID, &image.Name, &image.CreatedAt, &image.UpdatedAt) - if err != nil { - log.Println(err) - return - } - images = append(images, image) - } - return -} - -func CountImages() (count int) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT COUNT(*) FROM images").Scan(&count) - if err != nil { - log.Println(err) - return - } - return +func init() { + configs.ORMDB().AutoMigrate(&Image{}) } diff --git a/models/ListView.go b/models/ListView.go index 57437a1..e11ebf3 100644 --- a/models/ListView.go +++ b/models/ListView.go @@ -9,7 +9,7 @@ import ( type ListView struct { Page int `json:"page"` PageSize int `json:"page_size"` - Total int `json:"total"` + Total int64 `json:"total"` Next bool `json:"next"` List []interface{} `json:"list"` } diff --git a/models/Model.go b/models/Model.go index 68fb612..5f4fa89 100644 --- a/models/Model.go +++ b/models/Model.go @@ -1,7 +1,6 @@ package models import ( - "log" "main/configs" ) @@ -25,25 +24,25 @@ func init() { configs.ORMDB().AutoMigrate(&Model{}) } -func (model *Model) SendToTrain() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE models SET status = ?, progress = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(model.Status, model.Progress, model.UpdatedAt, model.ID) - if err != nil { - log.Println(err) - return err - } - // TODO: 創建一個新線程管理訓練任務 - // 將任務放入隊列中, 將自動回調更新任務狀態 - return nil -} +//func (model *Model) SendToTrain() error { +// db, err := configs.GetDB() +// if err != nil { +// log.Println(err) +// return err +// } +// defer db.Close() +// stmt, err := db.Prepare("UPDATE models SET status = ?, progress = ?, updated_at = ? WHERE id = ?") +// if err != nil { +// log.Println(err) +// return err +// } +// defer stmt.Close() +// _, err = stmt.Exec(model.Status, model.Progress, model.UpdatedAt, model.ID) +// if err != nil { +// log.Println(err) +// return err +// } +// // TODO: 創建一個新線程管理訓練任務 +// // 將任務放入隊列中, 將自動回調更新任務狀態 +// return nil +//} diff --git a/models/Server.go b/models/Server.go index 8eea7ab..2ec8e56 100644 --- a/models/Server.go +++ b/models/Server.go @@ -1,12 +1,11 @@ package models import ( - "log" "main/configs" ) type Server struct { - ID int `json:"id"` + ID int `json:"id" gorm:"primary_key"` Name string `json:"name"` Type string `json:"type"` // (訓練|推理) IP string `json:"ip"` @@ -17,125 +16,6 @@ type Server struct { UpdatedAt string `json:"updated_at"` } -func (server *Server) Create() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("INSERT INTO servers(name, ip, port, username, password, created_at, updated_at) values(?, ?, ?, ?, ?, ?, ?)") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - result, err := stmt.Exec(server.Name, server.IP, server.Port, server.Username, server.Password, server.CreatedAt, server.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - id, err := result.LastInsertId() - if err != nil { - return err - } - server.ID = int(id) - return nil -} - -func (server *Server) Delete() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("DELETE FROM servers WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(server.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (server *Server) Update() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE servers SET name = ?, ip = ?, port = ?, username = ?, password = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(server.Name, server.IP, server.Port, server.Username, server.Password, server.UpdatedAt, server.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (server *Server) Get() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - err = db.QueryRow("SELECT id, name, ip, port, username, password, created_at, updated_at FROM servers WHERE id = ?", server.ID).Scan(&server.ID, &server.Name, &server.IP, &server.Port, &server.Username, &server.Password, &server.CreatedAt, &server.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func QueryServers(page int, pagesize int) (servers []interface{}) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - rows, err := db.Query("SELECT id, name, ip, port, username, password, created_at, updated_at FROM servers ORDER BY id DESC LIMIT ?, ?", page*pagesize, pagesize) - if err != nil { - log.Println(err) - return - } - defer rows.Close() - for rows.Next() { - server := Server{} - err := rows.Scan(&server.ID, &server.Name, &server.IP, &server.Port, &server.Username, &server.Password, &server.CreatedAt, &server.UpdatedAt) - if err != nil { - log.Println(err) - continue - } - servers = append(servers, server) - } - return -} - -func CountServers() (count int) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT COUNT(*) FROM servers").Scan(&count) - if err != nil { - log.Println(err) - return - } - return +func init() { + configs.ORMDB().AutoMigrate(&Server{}) } diff --git a/models/Tag.go b/models/Tag.go index 49e6435..f3c6ffa 100644 --- a/models/Tag.go +++ b/models/Tag.go @@ -1,197 +1,16 @@ package models import ( - "log" "main/configs" ) type Tag struct { - ID int `json:"id"` + ID int `json:"id" gorm:"primary_key"` Name string `json:"name"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` } -func (tag *Tag) Create(name string) error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("INSERT INTO tags(name, created_at, updated_at) values(?, ?, ?)") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - result, err := stmt.Exec(name, tag.CreatedAt, tag.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - id, err := result.LastInsertId() - if err != nil { - return err - } - tag.ID = int(id) - return nil -} - -func (tag *Tag) Delete() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("DELETE FROM tags WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(tag.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (tag *Tag) Update(name string) error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE tags SET name = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(name, tag.UpdatedAt, tag.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (tag *Tag) Get() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - err = db.QueryRow("SELECT * FROM tags WHERE id = ?", tag.ID).Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func GetTags() ([]Tag, error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return nil, err - } - defer db.Close() - rows, err := db.Query("SELECT * FROM tags") - if err != nil { - log.Println(err) - return nil, err - } - defer rows.Close() - var tags []Tag - for rows.Next() { - var tag Tag - err := rows.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt) - if err != nil { - log.Println(err) - return nil, err - } - tags = append(tags, tag) - } - return tags, nil -} - -func GetTag(id int) (*Tag, error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return nil, err - } - defer db.Close() - row := db.QueryRow("SELECT * FROM tags WHERE id = ?", id) - var tag Tag - err = row.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt) - if err != nil { - log.Println(err) - return nil, err - } - return &tag, nil -} - -func GetTagByName(name string) (*Tag, error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return nil, err - } - defer db.Close() - row := db.QueryRow("SELECT * FROM tags WHERE name = ?", name) - var tag Tag - err = row.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt) - if err != nil { - log.Println(err) - return nil, err - } - return &tag, nil -} - -func QueryTags(page, pagesize int) (list []interface{}) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return nil - } - defer db.Close() - rows, err := db.Query("SELECT * FROM tags LIMIT ?, ?", page, pagesize) - if err != nil { - log.Println(err) - return nil - } - defer rows.Close() - for rows.Next() { - var tag Tag - err := rows.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt) - if err != nil { - log.Println(err) - return nil - } - list = append(list, tag) - } - return list -} - -func CountTags() (count int) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return 0 - } - defer db.Close() - row := db.QueryRow("SELECT COUNT(*) FROM tags") - err = row.Scan(&count) - if err != nil { - log.Println(err) - return 0 - } - return count +func init() { + configs.ORMDB().AutoMigrate(&Tag{}) } diff --git a/models/Task.go b/models/Task.go index bcdc2d6..ecc1848 100644 --- a/models/Task.go +++ b/models/Task.go @@ -1,13 +1,5 @@ package models -import ( - "log" - "main/configs" - "net/http" - "strconv" - "time" -) - type Task struct { ID int `json:"id"` Name string `json:"name"` @@ -19,183 +11,45 @@ type Task struct { UserID int `json:"user_id"` } -// 推理任務 -func startInferenceTask(task *Task) { - - // 獲取一臺可用的 GPU 資源 - // ... - - // 執行推理任務 - // ... - - // 更新任務狀態 - task.Status = "running" - task.Progress = 0 - task.Update() - - // 監聽任務狀態 - for { - // 延遲 1 秒 - time.Sleep(1 * time.Second) - - // 查詢任務狀態 - resp, err := http.Get("http://localhost:5000/api/v1/tasks/" + strconv.Itoa(task.ID)) - if err != nil { - log.Println(err) - continue - } - defer resp.Body.Close() - - // 解析任務狀態 - // ... - - // 更新任務狀態 - task.Progress = 100 - task.Status = "success" - task.Update() - - // 任務結束判定 - if task.Progress == 100 { - break - } - } - -} - -func (task *Task) Create() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("INSERT INTO tasks(name, type, created_at, updated_at) values(?, ?, ?, ?)") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - result, err := stmt.Exec(task.Name, task.Type, task.CreatedAt, task.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - id, err := result.LastInsertId() - if err != nil { - return err - } - task.ID = int(id) - return nil -} - -func (task *Task) Delete() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("DELETE FROM tasks WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(task.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (task *Task) Update() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE tasks SET name = ?, type = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(task.Name, task.Type, task.UpdatedAt, task.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (task *Task) Get() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - err = db.QueryRow("SELECT name, type, created_at, updated_at FROM tasks WHERE id = ?", task.ID).Scan(&task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func QueryTask(id int) (task Task) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT id, name, type, created_at, updated_at FROM tasks WHERE id = ?", id).Scan(&task.ID, &task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt) - if err != nil { - log.Println(err) - return - } - return -} - -func QueryTasks(page int, pagesize int) (tasks []interface{}) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - rows, err := db.Query("SELECT id, name, type, created_at, updated_at FROM tasks LIMIT ?, ?", (page-1)*pagesize, pagesize) - if err != nil { - log.Println(err) - return - } - defer rows.Close() - for rows.Next() { - task := Task{} - err := rows.Scan(&task.ID, &task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt) - if err != nil { - log.Println(err) - return - } - tasks = append(tasks, task) - } - return -} - -func CountTasks() (count int) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT COUNT(*) FROM tasks").Scan(&count) - if err != nil { - log.Println(err) - return - } - return -} +//// 推理任務 +//func startInferenceTask(task *Task) { +// +// // 獲取一臺可用的 GPU 資源 +// // ... +// +// // 執行推理任務 +// // ... +// +// // 更新任務狀態 +// task.Status = "running" +// task.Progress = 0 +// task.Update() +// +// // 監聽任務狀態 +// for { +// // 延遲 1 秒 +// time.Sleep(1 * time.Second) +// +// // 查詢任務狀態 +// resp, err := http.Get("http://localhost:5000/api/v1/tasks/" + strconv.Itoa(task.ID)) +// if err != nil { +// log.Println(err) +// continue +// } +// defer resp.Body.Close() +// +// // 解析任務狀態 +// // ... +// +// // 更新任務狀態 +// task.Progress = 100 +// task.Status = "success" +// task.Update() +// +// // 任務結束判定 +// if task.Progress == 100 { +// break +// } +// } +// +//} diff --git a/models/User.go b/models/User.go index a654241..180a88b 100644 --- a/models/User.go +++ b/models/User.go @@ -3,14 +3,11 @@ package models import ( "crypto/md5" "fmt" - "log" "main/configs" - "main/utils" - "time" ) type User struct { - ID int `json:"id"` + ID int `json:"id" gorm:"primary_key"` Name string `json:"name"` Email string `json:"email"` CreatedAt string `json:"created_at"` @@ -19,241 +16,11 @@ type User struct { Slat string `json:"-"` } -func (user *User) Create(name, email, password string) error { - - if name == "" || email == "" || password == "" { - return fmt.Errorf("name, email and password can not be empty") - } - - user.Slat = utils.RandomString(16) - user.Password = fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat))) - user.Name = name - user.Email = email - user.CreatedAt = time.Now().Format("2006-01-02 15:04:05") - user.UpdatedAt = user.CreatedAt - - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("INSERT INTO users(name, email, password, slat, created_at, updated_at) values(?, ?, ?, ?, ?, ?)") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - result, err := stmt.Exec(user.Name, user.Email, user.Password, user.Slat, user.CreatedAt, user.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - id, err := result.LastInsertId() - if err != nil { - return err - } - user.ID = int(id) - return nil -} - -func (user *User) Delete() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("DELETE FROM users WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(user.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (user *User) Update() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE users SET name = ?, email = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(user.Name, user.Email, user.UpdatedAt, user.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (user *User) RoadByID(id int) (err error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE id = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (user *User) Get() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE id = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (user *User) GetAll() ([]User, error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return nil, err - } - defer db.Close() - rows, err := db.Query("SELECT id, name, email, created_at, updated_at FROM users") - if err != nil { - log.Println(err) - return nil, err - } - defer rows.Close() - var users []User - for rows.Next() { - var user User - err := rows.Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return nil, err - } - users = append(users, user) - } - return users, nil +func init() { + configs.ORMDB().AutoMigrate(&User{}) } // 驗證用戶密碼 func (user *User) CheckPassword(password string) bool { return user.Password == fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat))) } - -// 使用Email和Password驗證登錄 -func (user *User) CheckLogin(email, password string) bool { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return false - } - defer db.Close() - err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE email = ?", email).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return false - } - if user.ID == 0 { - fmt.Println("user not found") - return false - } - if user.Password == "" { - fmt.Println("password is empty") - return false - } - if user.Password == fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat))) { - return true - } - return false -} - -// 獲取用戶實體 -func GetUserByEmail(email string) (user User, err error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE email = ?", email).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return - } - return -} - -func QueryUserByEmail(email string) (user User, err error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT id, name, email, created_at, updated_at FROM users WHERE email = ?", email).Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return - } - return -} - -func QueryUsers(page, pagesize int) (list []interface{}) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - rows, err := db.Query("SELECT id, name, email, created_at, updated_at FROM users LIMIT ?, ?", (page-1)*pagesize, pagesize) - if err != nil { - log.Println(err) - return - } - defer rows.Close() - for rows.Next() { - var user User - err := rows.Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt, &user.UpdatedAt) - if err != nil { - log.Println(err) - return - } - list = append(list, user) - } - return -} - -func CountUsers() (count int) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) - if err != nil { - log.Println(err) - return - } - return -} diff --git a/models/account.go b/models/account.go index ced6d1a..40b04d0 100644 --- a/models/account.go +++ b/models/account.go @@ -1,6 +1,7 @@ package models import ( + "main/configs" "net/http" ) @@ -24,7 +25,7 @@ func AccountRead(w http.ResponseWriter, r *http.Request, cb func(account *Accoun // 獲取當前session session := Session{ID: cookie.Value} - if err := session.Get(); err != nil { + if err := configs.ORMDB().Take(&session).Error; err != nil { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("401 - 會話已過期")) return @@ -32,7 +33,7 @@ func AccountRead(w http.ResponseWriter, r *http.Request, cb func(account *Accoun // 獲取當前用戶 user := User{ID: session.UserID} - if err := user.Get(); err != nil { + if err := configs.ORMDB().Model(&user).Select("id, name, email, created_at, updated_at").Find(&user).Error; err != nil { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("401 - 用戶不存在")) return diff --git a/models/session.go b/models/session.go index 903f145..8b9e233 100644 --- a/models/session.go +++ b/models/session.go @@ -1,151 +1,16 @@ package models import ( - "log" "main/configs" - "time" ) type Session struct { - ID string `json:"id"` + ID string `json:"id" gorm:"primary_key"` UserID int `json:"user_id"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` } -func (session *Session) Get() (err error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - row := db.QueryRow("SELECT * FROM sessions WHERE id = ?", session.ID) - err = row.Scan(&session.ID, &session.UserID, &session.CreatedAt, &session.UpdatedAt) - if err != nil { - log.Println(err) - return - } - return -} - -func (session *Session) Create() error { - session.CreatedAt = time.Now().Format("2006-01-02 15:04:05") - session.UpdatedAt = session.CreatedAt - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("INSERT INTO sessions (id, user_id, created_at, updated_at) VALUES (?, ?, ?, ?)") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(session.ID, session.UserID, session.CreatedAt, session.UpdatedAt) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (session *Session) Delete() error { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("DELETE FROM sessions WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(session.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func (session *Session) Update() error { - session.UpdatedAt = time.Now().Format("2006-01-02 15:04:05") - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return err - } - defer db.Close() - stmt, err := db.Prepare("UPDATE sessions SET user_id = ?, updated_at = ? WHERE id = ?") - if err != nil { - log.Println(err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(session.UpdatedAt, session.UserID, session.ID) - if err != nil { - log.Println(err) - return err - } - return nil -} - -func GetSession(id int) (*Session, error) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return nil, err - } - defer db.Close() - row := db.QueryRow("SELECT id, user_id, created_at, updated_at FROM sessions WHERE id = ?", id) - var session Session - if err := row.Scan(&session.ID, &session.UserID, &session.CreatedAt, &session.UpdatedAt); err != nil { - log.Println(err) - return nil, err - } - return &session, nil -} - -func QuerySessions(page, pagesize int) (list []interface{}) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - rows, err := db.Query("SELECT id, user_id, created_at, updated_at FROM sessions LIMIT ?, ?", (page-1)*pagesize, pagesize) - if err != nil { - log.Println(err) - return - } - defer rows.Close() - for rows.Next() { - var session Session - if err := rows.Scan(&session.ID, &session.UserID, &session.CreatedAt, &session.UpdatedAt); err != nil { - log.Println(err) - return - } - list = append(list, session) - } - return -} - -func CountSessions() (count int) { - db, err := configs.GetDB() - if err != nil { - log.Println(err) - return - } - defer db.Close() - row := db.QueryRow("SELECT COUNT(*) FROM sessions") - if err := row.Scan(&count); err != nil { - log.Println(err) - return - } - return +func init() { + configs.ORMDB().AutoMigrate(&Session{}) } diff --git a/routers/account.go b/routers/account.go index e7b72c0..0ba5b20 100644 --- a/routers/account.go +++ b/routers/account.go @@ -2,6 +2,7 @@ package routers import ( "fmt" + "main/configs" "main/models" "main/utils" "net/http" @@ -22,16 +23,23 @@ func AccountGet(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_id") if err != nil { fmt.Println(err) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 - 未登錄")) return } // 獲取會話 session := models.Session{ID: cookie.Value} - session.Get() + if err := configs.ORMDB().Take(&session).Error; err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 - 會話已過期")) + return + } // 獲取用戶 user := models.User{ID: session.UserID} - user.Get() + configs.ORMDB().Model(&user).Select("id, name, email, created_at, updated_at").Find(&user) account.ID = user.ID account.Name = user.Name @@ -43,25 +51,3 @@ func AccountGet(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(account)) } - -// 獲取當前賬戶, 並將其傳入回調函數 -func get_account(w http.ResponseWriter, r *http.Request, callback func(*models.User)) (err error) { - // 獲取Cookie - cookie, err := r.Cookie("session_id") - if err != nil { - fmt.Println(err) - return nil - } - - // 獲取會話 - session := models.Session{ID: cookie.Value} - session.Get() - - // 獲取用戶 - user := models.User{ID: session.UserID} - user.Get() - - callback(&user) - - return nil -} diff --git a/routers/images.go b/routers/images.go index ad877a8..7d24b47 100644 --- a/routers/images.go +++ b/routers/images.go @@ -4,9 +4,11 @@ import ( "encoding/json" "io/ioutil" "log" + "main/configs" "main/models" "main/utils" "net/http" + "time" "github.com/gorilla/mux" ) @@ -15,9 +17,17 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryImages(listview.Page, listview.PageSize) - listview.Total = models.CountImages() - listview.Next = listview.Page*listview.PageSize < listview.Total + + var image_list []models.Image + db := configs.ORMDB() + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&image_list) + for _, image := range image_list { + listview.List = append(listview.List, image) + } + + db.Model(&models.Image{}).Count(&listview.Total) + + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } @@ -33,14 +43,20 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) { log.Println(err) return } - image.Create() + if err := configs.ORMDB().Create(&image).Error; err != nil { + log.Println(err) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) } func ImagesItemGet(w http.ResponseWriter, r *http.Request) { image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - image.Get() + if err := configs.ORMDB().First(&image).Error; err != nil { + log.Println(err) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) } @@ -58,14 +74,24 @@ func ImagesItemPatch(w http.ResponseWriter, r *http.Request) { return } image.ID = utils.ParamInt(mux.Vars(r)["id"], 0) - image.Update() + image.UpdatedAt = time.Now().Format("2006-01-02 15:04:05") + if err := configs.ORMDB().Model(&image).Updates(image).Error; err != nil { + log.Println(err) + return + } + + //image.ID = utils.ParamInt(mux.Vars(r)["id"], 0) + //image.Update() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) } func ImagesItemDelete(w http.ResponseWriter, r *http.Request) { image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - image.Delete() + if err := configs.ORMDB().Delete(&image).Error; err != nil { + log.Println(err) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(image)) } diff --git a/routers/models.go b/routers/models.go index bebee3d..21608f3 100644 --- a/routers/models.go +++ b/routers/models.go @@ -17,28 +17,24 @@ import ( var manager = models.NewWebSocketManager() +// 獲取模型列表 func ModelsGet(w http.ResponseWriter, r *http.Request) { - // 初始化基本參數 var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - // 獲取模型列表 var model_list []models.Model db := configs.ORMDB() db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list) for _, model := range model_list { listview.List = append(listview.List, model) } - // 獲取總數 - var total int64 - db.Model(&models.Model{}).Count(&total) - listview.Total = int(total) - listview.Next = listview.Page*listview.PageSize < listview.Total + db.Model(&models.Model{}).Count(&listview.Total) + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } +// 創建模型 func ModelsPost(w http.ResponseWriter, r *http.Request) { - // 取得用戶 models.AccountRead(w, r, func(account *models.Account) { fmt.Println(account) // TODO: 判斷權限(是否可以創建) @@ -138,7 +134,7 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) { model.Status = model_new.Status // 如果狀態被改變爲 ready, 將模型發送到訓練隊列 if model.Status == "ready" { - model.SendToTrain() + //model.SendToTrain() } } if model_new.Image != "" && model_new.Image != model.Image { diff --git a/routers/servers.go b/routers/servers.go index 85b4fea..ed15c1e 100644 --- a/routers/servers.go +++ b/routers/servers.go @@ -1,6 +1,7 @@ package routers import ( + "main/configs" "main/models" "main/utils" "net/http" @@ -12,38 +13,42 @@ func ServersGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryServers(listview.Page, listview.PageSize) - listview.Total = models.CountServers() - listview.Next = listview.Page*listview.PageSize < listview.Total + var server_list []models.Server + db := configs.ORMDB() + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&server_list) + for _, server := range server_list { + listview.List = append(listview.List, server) + } + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } func ServersPost(w http.ResponseWriter, r *http.Request) { var server models.Server - server.Create() + configs.ORMDB().Create(&server) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(server)) } func ServersItemGet(w http.ResponseWriter, r *http.Request) { server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - server.Get() + configs.ORMDB().First(&server) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(server)) } func ServersItemPatch(w http.ResponseWriter, r *http.Request) { server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - server.Get() - server.Update() + configs.ORMDB().First(&server) + // TODO: update server + configs.ORMDB().Save(&server) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(server)) } func ServersItemDelete(w http.ResponseWriter, r *http.Request) { server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - server.Get() - server.Delete() + configs.ORMDB().Delete(&server) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(server)) } diff --git a/routers/sessions.go b/routers/sessions.go index 8dbcd2d..c839db4 100644 --- a/routers/sessions.go +++ b/routers/sessions.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "main/configs" "main/models" "main/utils" "net/http" @@ -17,9 +18,14 @@ func SessionsGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QuerySessions(listview.Page, listview.PageSize) - listview.Total = models.CountSessions() - listview.Next = listview.Page*listview.PageSize < listview.Total + var session_list []models.Session + db := configs.ORMDB() + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&session_list) + for _, session := range session_list { + listview.List = append(listview.List, session) + } + db.Model(&models.Session{}).Count(&listview.Total) + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } @@ -55,8 +61,8 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) { } // 使用Email獲取用戶 - user, err := models.GetUserByEmail(form.Email) - if err != nil { + var user models.User + if err := configs.ORMDB().Where("email = ?", form.Email).First(&user).Error; err != nil { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("404 - User Not Found")) return @@ -71,7 +77,11 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) { // 創建會話(生成一個不重複的 uuid 作爲 sid) session := &models.Session{ID: uuid.New().String(), UserID: user.ID} - session.Create() + if err := configs.ORMDB().Create(session).Error; err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Internal Server Error")) + return + } // 寫入Cookie cookie := http.Cookie{Name: "session_id", Value: session.ID, Path: "/", HttpOnly: true} @@ -85,7 +95,11 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) { // 獲取會話 func SessionsItemGet(w http.ResponseWriter, r *http.Request) { session := models.Session{ID: mux.Vars(r)["session_id"]} - session.Get() + if err := configs.ORMDB().Find(&session).Error; err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 - Not Found")) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(session)) } @@ -93,8 +107,11 @@ func SessionsItemGet(w http.ResponseWriter, r *http.Request) { // 更新會話 func SessionsItemPatch(w http.ResponseWriter, r *http.Request) { session := models.Session{ID: mux.Vars(r)["session_id"]} - session.Get() - session.Update() + if err := configs.ORMDB().Model(&session).Updates(GetForm(r)); err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 - Not Found")) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(session)) } @@ -113,14 +130,22 @@ func SessionsItemDelete(w http.ResponseWriter, r *http.Request) { // 獲取當前session session := models.Session{ID: cookie.Value} - session.Get() + if err := configs.ORMDB().Find(&session).Error; err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 - 會話已過期")) + return + } // 獲取當前用戶 user := models.User{ID: session.UserID} - user.Get() + configs.ORMDB().Find(&user) sessionx := models.Session{ID: mux.Vars(r)["session_id"]} - sessionx.Get() + if err := configs.ORMDB().Find(&sessionx).Error; err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 - Not Found")) + return + } if user.ID != sessionx.UserID { w.WriteHeader(http.StatusUnauthorized) @@ -128,7 +153,11 @@ func SessionsItemDelete(w http.ResponseWriter, r *http.Request) { return } - sessionx.Delete() + if err := configs.ORMDB().Delete(&sessionx).Error; err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 - Not Found")) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(sessionx)) } diff --git a/routers/tags.go b/routers/tags.go index 3db2001..b3de16a 100644 --- a/routers/tags.go +++ b/routers/tags.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "main/configs" "main/models" "main/utils" "net/http" @@ -16,9 +17,14 @@ func TagsGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryTags(listview.Page, listview.PageSize) - listview.Total = models.CountTags() - listview.Next = listview.Page*listview.PageSize < listview.Total + var tag_list []models.Tag + db := configs.ORMDB() + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&tag_list) + for _, tag := range tag_list { + listview.List = append(listview.List, tag) + } + db.Model(&models.Tag{}).Count(&listview.Total) + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } @@ -38,8 +44,8 @@ func TagsPost(w http.ResponseWriter, r *http.Request) { return } // 創建標籤 - var tag models.Tag - if err := tag.Create(form.Name); err != nil { + var tag models.Tag = models.Tag{Name: form.Name} + if err := configs.ORMDB().Create(&tag).Error; err != nil { fmt.Println(err) return } @@ -50,13 +56,8 @@ func TagsPost(w http.ResponseWriter, r *http.Request) { // 獲取標籤 func TagsItemGet(w http.ResponseWriter, r *http.Request) { - var tag models.Tag - tag.ID = utils.ParamInt(mux.Vars(r)["id"], 0) - if tag.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - if err := tag.Get(); err != nil { + var tag models.Tag = models.Tag{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + if err := configs.ORMDB().First(&tag).Error; err != nil { fmt.Println(err) return } @@ -66,16 +67,6 @@ func TagsItemGet(w http.ResponseWriter, r *http.Request) { // 更新標籤 func TagsItemPatch(w http.ResponseWriter, r *http.Request) { - var tag models.Tag - tag.ID = utils.ParamInt(mux.Vars(r)["id"], 0) - if tag.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - if err := tag.Get(); err != nil { - fmt.Println(err) - return - } var form struct { Name string `json:"name"` } @@ -89,7 +80,8 @@ func TagsItemPatch(w http.ResponseWriter, r *http.Request) { fmt.Println(err) return } - if err := tag.Update(form.Name); err != nil { + var tag models.Tag = models.Tag{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + if err := configs.ORMDB().Model(&tag).Update("name", form.Name).Error; err != nil { fmt.Println(err) return } @@ -99,13 +91,8 @@ func TagsItemPatch(w http.ResponseWriter, r *http.Request) { // 刪除標籤 func TagsItemDelete(w http.ResponseWriter, r *http.Request) { - var tag models.Tag - tag.ID = utils.ParamInt(mux.Vars(r)["id"], 0) - if tag.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - if err := tag.Delete(); err != nil { + var tag models.Tag = models.Tag{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + if err := configs.ORMDB().Delete(&tag).Error; err != nil { fmt.Println(err) return } diff --git a/routers/tasks.go b/routers/tasks.go index c97c866..bd847a3 100644 --- a/routers/tasks.go +++ b/routers/tasks.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io/ioutil" "log" + "main/configs" "main/models" "main/utils" "net/http" @@ -17,9 +18,14 @@ func TasksGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryTasks(listview.Page, listview.PageSize) - listview.Total = models.CountTasks() - listview.Next = listview.Page*listview.PageSize < listview.Total + var task_list []models.Task + db := configs.ORMDB() + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&task_list) + for _, task := range task_list { + listview.List = append(listview.List, task) + } + db.Model(&models.Task{}).Count(&listview.Total) + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } @@ -35,7 +41,7 @@ func TasksPost(w http.ResponseWriter, r *http.Request) { log.Println(err) return } - task.Create() + configs.ORMDB().Create(&task) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(task)) } @@ -44,8 +50,10 @@ func TasksItemGet(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") == "websocket" { vars := mux.Vars(r) id, _ := strconv.Atoi(vars["id"]) - task := models.QueryTask(id) - if task.ID == 0 { + + var task models.Task = models.Task{ID: id} + if err := configs.ORMDB().First(&task, id).Error; err != nil { + log.Println(err) w.WriteHeader(http.StatusNotFound) return } @@ -63,12 +71,12 @@ func TasksItemGet(w http.ResponseWriter, r *http.Request) { break } task.Status = string(message) - task.Update() + configs.ORMDB().Model(&task).Update("status", task.Status) } return } task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - task.Get() + configs.ORMDB().First(&task) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(task)) } @@ -86,18 +94,13 @@ func TasksItemPatch(w http.ResponseWriter, r *http.Request) { return } task.ID = utils.ParamInt(mux.Vars(r)["id"], 0) - task.Update() + configs.ORMDB().Model(&task).Updates(task) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(task)) } func TasksItemDelete(w http.ResponseWriter, r *http.Request) { task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - task.Delete() - if task.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - task.Delete() + configs.ORMDB().Delete(&task) w.WriteHeader(http.StatusNoContent) } diff --git a/routers/users.go b/routers/users.go index db4f20a..70f936d 100644 --- a/routers/users.go +++ b/routers/users.go @@ -1,13 +1,17 @@ package routers import ( + "crypto/md5" "encoding/json" "fmt" "io/ioutil" + "main/configs" "main/models" "main/utils" "net/http" + "time" + "github.com/google/uuid" "github.com/gorilla/mux" ) @@ -16,9 +20,14 @@ func UsersGet(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryUsers(listview.Page, listview.PageSize) - listview.Total = models.CountUsers() - listview.Next = listview.Page*listview.PageSize < listview.Total + var user_list []models.User + db := configs.ORMDB() + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&user_list) + for _, user := range user_list { + listview.List = append(listview.List, user) + } + db.Model(&models.User{}).Count(&listview.Total) + listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) } @@ -40,9 +49,22 @@ func UsersPost(w http.ResponseWriter, r *http.Request) { return } + if form.Name == "" || form.Email == "" || form.Password == "" { + fmt.Println("name, email, password cannot be empty") + return + } + // 創建用戶 - var user models.User - if err := user.Create(form.Name, form.Email, form.Password); err != nil { + var slat string = uuid.New().String() + var user models.User = models.User{ + Name: form.Name, + Email: form.Email, + Password: fmt.Sprintf("%x", md5.Sum([]byte(form.Password+slat))), + Slat: slat, + CreatedAt: time.Now().Format("2006-01-02 15:04:05"), + UpdatedAt: time.Now().Format("2006-01-02 15:04:05"), + } + if err := configs.ORMDB().Create(&user).Error; err != nil { fmt.Println(err) return } @@ -55,7 +77,7 @@ func UsersPost(w http.ResponseWriter, r *http.Request) { // 獲取用戶 func UsersItemGet(w http.ResponseWriter, r *http.Request) { user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - user.Get() + configs.ORMDB().First(&user) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(user)) } @@ -63,8 +85,20 @@ func UsersItemGet(w http.ResponseWriter, r *http.Request) { // 更新用戶 func UsersItemPatch(w http.ResponseWriter, r *http.Request) { user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - user.Get() - user.Update() + body, err := ioutil.ReadAll(r.Body) + if err != nil { + fmt.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &user); err != nil { + fmt.Println(err) + return + } + if err := configs.ORMDB().Save(&user).Error; err != nil { + fmt.Println(err) + return + } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(user)) } @@ -72,8 +106,7 @@ func UsersItemPatch(w http.ResponseWriter, r *http.Request) { // 刪除用戶 func UsersItemDelete(w http.ResponseWriter, r *http.Request) { user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - user.Get() - user.Delete() + configs.ORMDB().Delete(&user) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(user)) }