This commit is contained in:
2023-05-14 07:00:24 +08:00
parent 2423213e9a
commit ee3b60eccc
18 changed files with 273 additions and 1221 deletions

View File

@@ -1,7 +1,6 @@
package configs package configs
import ( import (
"database/sql"
"log" "log"
"os" "os"
@@ -10,9 +9,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// 使用SQLite3初始化數據庫
func init() { func init() {
// 設置日誌顯示文件名和行號 // 設置日誌顯示文件名和行號
log.SetFlags(log.Lshortfile | log.LstdFlags) log.SetFlags(log.Lshortfile | log.LstdFlags)
@@ -20,91 +17,16 @@ func init() {
if _, err := os.Stat("data"); os.IsNotExist(err) { if _, err := os.Stat("data"); os.IsNotExist(err) {
os.Mkdir("data", os.ModePerm) os.Mkdir("data", os.ModePerm)
} }
// 初始化數據庫
db, err := sql.Open("sqlite3", "data/sqlite3.db")
if err != nil {
log.Fatal(err)
}
// 一次性創建多個數據表(自增主鍵)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS images(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
width INTEGER,
height INTEGER,
prompt TEXT,
negative_prompt TEXT,
num_inference_steps INTEGER,
guidance_scale REAL,
scheduler TEXT,
seed INTEGER,
from_image TEXT,
created_at TEXT,
updated_at TEXT,
user_id INTEGER
);
CREATE TABLE IF NOT EXISTS models(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
type TEXT,
trigger_words TEXT,
base_model TEXT,
model_path TEXT,
status TEXT,
progress INTEGER,
image TEXT,
tags TEXT,
created_at TEXT,
updated_at TEXT,
user_id INTEGER
);
CREATE TABLE IF NOT EXISTS users(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
email TEXT,
password TEXT,
slat TEXT,
created_at TEXT,
updated_at TEXT
);
CREATE TABLE IF NOT EXISTS tags(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
created_at TEXT,
updated_at TEXT
);
CREATE TABLE IF NOT EXISTS tasks(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
status TEXT,
progress INTEGER,
created_at TEXT,
updated_at TEXT,
user_id INTEGER
);
CREATE TABLE IF NOT EXISTS sessions(
id TEXT PRIMARY KEY,
user_id INTEGER,
created_at TEXT,
updated_at TEXT
);
`)
defer db.Close()
if err != nil {
log.Fatal(err)
}
} }
// GetDB 獲取數據庫連接 //// GetDB 獲取數據庫連接
func GetDB() (*sql.DB, error) { //func GetDB() (*sql.DB, error) {
db, err := sql.Open("sqlite3", "data/sqlite3.db") // db, err := sql.Open("sqlite3", "data/sqlite3.db")
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
return db, nil // return db, nil
} //}
// ORMDB 使用 GORM // ORMDB 使用 GORM
func ORMDB() (db *gorm.DB) { func ORMDB() (db *gorm.DB) {

View File

@@ -1,12 +1,11 @@
package models package models
import ( import (
"log"
"main/configs" "main/configs"
) )
type Image struct { type Image struct {
ID int `json:"id"` ID int `json:"id" gorm:"primary_key"`
Name string `json:"name"` Name string `json:"name"`
Width int `json:"width"` Width int `json:"width"`
Height int `json:"height"` Height int `json:"height"`
@@ -22,125 +21,6 @@ type Image struct {
UserID int `json:"user_id"` UserID int `json:"user_id"`
} }
func (image *Image) Create() error { func init() {
db, err := configs.GetDB() configs.ORMDB().AutoMigrate(&Image{})
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("INSERT INTO images(name, created_at, updated_at) values(?, ?, ?)")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
result, err := stmt.Exec(image.Name, image.CreatedAt, image.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
image.ID = int(id)
return nil
}
func (image *Image) Delete() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("DELETE FROM images WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(image.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (image *Image) Update() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("UPDATE images SET name = ?, updated_at = ? WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(image.Name, image.UpdatedAt, image.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (image *Image) Get() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
err = db.QueryRow("SELECT id, name, created_at, updated_at FROM images WHERE id = ?", image.ID).Scan(&image.ID, &image.Name, &image.CreatedAt, &image.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func QueryImages(page int, pagesize int) (images []interface{}) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
rows, err := db.Query("SELECT id, name, created_at, updated_at FROM images LIMIT ?, ?", (page-1)*pagesize, pagesize)
if err != nil {
log.Println(err)
return
}
defer rows.Close()
for rows.Next() {
image := Image{}
err := rows.Scan(&image.ID, &image.Name, &image.CreatedAt, &image.UpdatedAt)
if err != nil {
log.Println(err)
return
}
images = append(images, image)
}
return
}
func CountImages() (count int) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT COUNT(*) FROM images").Scan(&count)
if err != nil {
log.Println(err)
return
}
return
} }

View File

@@ -9,7 +9,7 @@ import (
type ListView struct { type ListView struct {
Page int `json:"page"` Page int `json:"page"`
PageSize int `json:"page_size"` PageSize int `json:"page_size"`
Total int `json:"total"` Total int64 `json:"total"`
Next bool `json:"next"` Next bool `json:"next"`
List []interface{} `json:"list"` List []interface{} `json:"list"`
} }

View File

@@ -1,7 +1,6 @@
package models package models
import ( import (
"log"
"main/configs" "main/configs"
) )
@@ -25,25 +24,25 @@ func init() {
configs.ORMDB().AutoMigrate(&Model{}) configs.ORMDB().AutoMigrate(&Model{})
} }
func (model *Model) SendToTrain() error { //func (model *Model) SendToTrain() error {
db, err := configs.GetDB() // db, err := configs.GetDB()
if err != nil { // if err != nil {
log.Println(err) // log.Println(err)
return err // return err
} // }
defer db.Close() // defer db.Close()
stmt, err := db.Prepare("UPDATE models SET status = ?, progress = ?, updated_at = ? WHERE id = ?") // stmt, err := db.Prepare("UPDATE models SET status = ?, progress = ?, updated_at = ? WHERE id = ?")
if err != nil { // if err != nil {
log.Println(err) // log.Println(err)
return err // return err
} // }
defer stmt.Close() // defer stmt.Close()
_, err = stmt.Exec(model.Status, model.Progress, model.UpdatedAt, model.ID) // _, err = stmt.Exec(model.Status, model.Progress, model.UpdatedAt, model.ID)
if err != nil { // if err != nil {
log.Println(err) // log.Println(err)
return err // return err
} // }
// TODO: 創建一個新線程管理訓練任務 // // TODO: 創建一個新線程管理訓練任務
// 將任務放入隊列中, 將自動回調更新任務狀態 // // 將任務放入隊列中, 將自動回調更新任務狀態
return nil // return nil
} //}

View File

@@ -1,12 +1,11 @@
package models package models
import ( import (
"log"
"main/configs" "main/configs"
) )
type Server struct { type Server struct {
ID int `json:"id"` ID int `json:"id" gorm:"primary_key"`
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` // (訓練|推理) Type string `json:"type"` // (訓練|推理)
IP string `json:"ip"` IP string `json:"ip"`
@@ -17,125 +16,6 @@ type Server struct {
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updated_at"`
} }
func (server *Server) Create() error { func init() {
db, err := configs.GetDB() configs.ORMDB().AutoMigrate(&Server{})
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("INSERT INTO servers(name, ip, port, username, password, created_at, updated_at) values(?, ?, ?, ?, ?, ?, ?)")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
result, err := stmt.Exec(server.Name, server.IP, server.Port, server.Username, server.Password, server.CreatedAt, server.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
server.ID = int(id)
return nil
}
func (server *Server) Delete() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("DELETE FROM servers WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(server.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (server *Server) Update() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("UPDATE servers SET name = ?, ip = ?, port = ?, username = ?, password = ?, updated_at = ? WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(server.Name, server.IP, server.Port, server.Username, server.Password, server.UpdatedAt, server.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (server *Server) Get() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
err = db.QueryRow("SELECT id, name, ip, port, username, password, created_at, updated_at FROM servers WHERE id = ?", server.ID).Scan(&server.ID, &server.Name, &server.IP, &server.Port, &server.Username, &server.Password, &server.CreatedAt, &server.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func QueryServers(page int, pagesize int) (servers []interface{}) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
rows, err := db.Query("SELECT id, name, ip, port, username, password, created_at, updated_at FROM servers ORDER BY id DESC LIMIT ?, ?", page*pagesize, pagesize)
if err != nil {
log.Println(err)
return
}
defer rows.Close()
for rows.Next() {
server := Server{}
err := rows.Scan(&server.ID, &server.Name, &server.IP, &server.Port, &server.Username, &server.Password, &server.CreatedAt, &server.UpdatedAt)
if err != nil {
log.Println(err)
continue
}
servers = append(servers, server)
}
return
}
func CountServers() (count int) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT COUNT(*) FROM servers").Scan(&count)
if err != nil {
log.Println(err)
return
}
return
} }

View File

@@ -1,197 +1,16 @@
package models package models
import ( import (
"log"
"main/configs" "main/configs"
) )
type Tag struct { type Tag struct {
ID int `json:"id"` ID int `json:"id" gorm:"primary_key"`
Name string `json:"name"` Name string `json:"name"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updated_at"`
} }
func (tag *Tag) Create(name string) error { func init() {
db, err := configs.GetDB() configs.ORMDB().AutoMigrate(&Tag{})
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("INSERT INTO tags(name, created_at, updated_at) values(?, ?, ?)")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
result, err := stmt.Exec(name, tag.CreatedAt, tag.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
tag.ID = int(id)
return nil
}
func (tag *Tag) Delete() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("DELETE FROM tags WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(tag.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (tag *Tag) Update(name string) error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("UPDATE tags SET name = ?, updated_at = ? WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(name, tag.UpdatedAt, tag.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (tag *Tag) Get() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
err = db.QueryRow("SELECT * FROM tags WHERE id = ?", tag.ID).Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func GetTags() ([]Tag, error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return nil, err
}
defer db.Close()
rows, err := db.Query("SELECT * FROM tags")
if err != nil {
log.Println(err)
return nil, err
}
defer rows.Close()
var tags []Tag
for rows.Next() {
var tag Tag
err := rows.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt)
if err != nil {
log.Println(err)
return nil, err
}
tags = append(tags, tag)
}
return tags, nil
}
func GetTag(id int) (*Tag, error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return nil, err
}
defer db.Close()
row := db.QueryRow("SELECT * FROM tags WHERE id = ?", id)
var tag Tag
err = row.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt)
if err != nil {
log.Println(err)
return nil, err
}
return &tag, nil
}
func GetTagByName(name string) (*Tag, error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return nil, err
}
defer db.Close()
row := db.QueryRow("SELECT * FROM tags WHERE name = ?", name)
var tag Tag
err = row.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt)
if err != nil {
log.Println(err)
return nil, err
}
return &tag, nil
}
func QueryTags(page, pagesize int) (list []interface{}) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return nil
}
defer db.Close()
rows, err := db.Query("SELECT * FROM tags LIMIT ?, ?", page, pagesize)
if err != nil {
log.Println(err)
return nil
}
defer rows.Close()
for rows.Next() {
var tag Tag
err := rows.Scan(&tag.ID, &tag.Name, &tag.CreatedAt, &tag.UpdatedAt)
if err != nil {
log.Println(err)
return nil
}
list = append(list, tag)
}
return list
}
func CountTags() (count int) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return 0
}
defer db.Close()
row := db.QueryRow("SELECT COUNT(*) FROM tags")
err = row.Scan(&count)
if err != nil {
log.Println(err)
return 0
}
return count
} }

View File

@@ -1,13 +1,5 @@
package models package models
import (
"log"
"main/configs"
"net/http"
"strconv"
"time"
)
type Task struct { type Task struct {
ID int `json:"id"` ID int `json:"id"`
Name string `json:"name"` Name string `json:"name"`
@@ -19,183 +11,45 @@ type Task struct {
UserID int `json:"user_id"` UserID int `json:"user_id"`
} }
// 推理任務 //// 推理任務
func startInferenceTask(task *Task) { //func startInferenceTask(task *Task) {
//
// 獲取一臺可用的 GPU 資源 // // 獲取一臺可用的 GPU 資源
// ... // // ...
//
// 執行推理任務 // // 執行推理任務
// ... // // ...
//
// 更新任務狀態 // // 更新任務狀態
task.Status = "running" // task.Status = "running"
task.Progress = 0 // task.Progress = 0
task.Update() // task.Update()
//
// 監聽任務狀態 // // 監聽任務狀態
for { // for {
// 延遲 1 秒 // // 延遲 1 秒
time.Sleep(1 * time.Second) // time.Sleep(1 * time.Second)
//
// 查詢任務狀態 // // 查詢任務狀態
resp, err := http.Get("http://localhost:5000/api/v1/tasks/" + strconv.Itoa(task.ID)) // resp, err := http.Get("http://localhost:5000/api/v1/tasks/" + strconv.Itoa(task.ID))
if err != nil { // if err != nil {
log.Println(err) // log.Println(err)
continue // continue
} // }
defer resp.Body.Close() // defer resp.Body.Close()
//
// 解析任務狀態 // // 解析任務狀態
// ... // // ...
//
// 更新任務狀態 // // 更新任務狀態
task.Progress = 100 // task.Progress = 100
task.Status = "success" // task.Status = "success"
task.Update() // task.Update()
//
// 任務結束判定 // // 任務結束判定
if task.Progress == 100 { // if task.Progress == 100 {
break // break
} // }
} // }
//
} //}
func (task *Task) Create() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("INSERT INTO tasks(name, type, created_at, updated_at) values(?, ?, ?, ?)")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
result, err := stmt.Exec(task.Name, task.Type, task.CreatedAt, task.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
task.ID = int(id)
return nil
}
func (task *Task) Delete() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("DELETE FROM tasks WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(task.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (task *Task) Update() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("UPDATE tasks SET name = ?, type = ?, updated_at = ? WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(task.Name, task.Type, task.UpdatedAt, task.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (task *Task) Get() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
err = db.QueryRow("SELECT name, type, created_at, updated_at FROM tasks WHERE id = ?", task.ID).Scan(&task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func QueryTask(id int) (task Task) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT id, name, type, created_at, updated_at FROM tasks WHERE id = ?", id).Scan(&task.ID, &task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt)
if err != nil {
log.Println(err)
return
}
return
}
func QueryTasks(page int, pagesize int) (tasks []interface{}) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
rows, err := db.Query("SELECT id, name, type, created_at, updated_at FROM tasks LIMIT ?, ?", (page-1)*pagesize, pagesize)
if err != nil {
log.Println(err)
return
}
defer rows.Close()
for rows.Next() {
task := Task{}
err := rows.Scan(&task.ID, &task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt)
if err != nil {
log.Println(err)
return
}
tasks = append(tasks, task)
}
return
}
func CountTasks() (count int) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT COUNT(*) FROM tasks").Scan(&count)
if err != nil {
log.Println(err)
return
}
return
}

View File

@@ -3,14 +3,11 @@ package models
import ( import (
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"log"
"main/configs" "main/configs"
"main/utils"
"time"
) )
type User struct { type User struct {
ID int `json:"id"` ID int `json:"id" gorm:"primary_key"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
@@ -19,241 +16,11 @@ type User struct {
Slat string `json:"-"` Slat string `json:"-"`
} }
func (user *User) Create(name, email, password string) error { func init() {
configs.ORMDB().AutoMigrate(&User{})
if name == "" || email == "" || password == "" {
return fmt.Errorf("name, email and password can not be empty")
}
user.Slat = utils.RandomString(16)
user.Password = fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat)))
user.Name = name
user.Email = email
user.CreatedAt = time.Now().Format("2006-01-02 15:04:05")
user.UpdatedAt = user.CreatedAt
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("INSERT INTO users(name, email, password, slat, created_at, updated_at) values(?, ?, ?, ?, ?, ?)")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
result, err := stmt.Exec(user.Name, user.Email, user.Password, user.Slat, user.CreatedAt, user.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
id, err := result.LastInsertId()
if err != nil {
return err
}
user.ID = int(id)
return nil
}
func (user *User) Delete() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("DELETE FROM users WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(user.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (user *User) Update() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("UPDATE users SET name = ?, email = ?, updated_at = ? WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(user.Name, user.Email, user.UpdatedAt, user.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (user *User) RoadByID(id int) (err error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE id = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (user *User) Get() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE id = ?", user.ID).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (user *User) GetAll() ([]User, error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return nil, err
}
defer db.Close()
rows, err := db.Query("SELECT id, name, email, created_at, updated_at FROM users")
if err != nil {
log.Println(err)
return nil, err
}
defer rows.Close()
var users []User
for rows.Next() {
var user User
err := rows.Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return nil, err
}
users = append(users, user)
}
return users, nil
} }
// 驗證用戶密碼 // 驗證用戶密碼
func (user *User) CheckPassword(password string) bool { func (user *User) CheckPassword(password string) bool {
return user.Password == fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat))) return user.Password == fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat)))
} }
// 使用Email和Password驗證登錄
func (user *User) CheckLogin(email, password string) bool {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return false
}
defer db.Close()
err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE email = ?", email).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return false
}
if user.ID == 0 {
fmt.Println("user not found")
return false
}
if user.Password == "" {
fmt.Println("password is empty")
return false
}
if user.Password == fmt.Sprintf("%x", md5.Sum([]byte(password+user.Slat))) {
return true
}
return false
}
// 獲取用戶實體
func GetUserByEmail(email string) (user User, err error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT id, name, email, password, slat, created_at, updated_at FROM users WHERE email = ?", email).Scan(&user.ID, &user.Name, &user.Email, &user.Password, &user.Slat, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return
}
return
}
func QueryUserByEmail(email string) (user User, err error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT id, name, email, created_at, updated_at FROM users WHERE email = ?", email).Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return
}
return
}
func QueryUsers(page, pagesize int) (list []interface{}) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
rows, err := db.Query("SELECT id, name, email, created_at, updated_at FROM users LIMIT ?, ?", (page-1)*pagesize, pagesize)
if err != nil {
log.Println(err)
return
}
defer rows.Close()
for rows.Next() {
var user User
err := rows.Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
log.Println(err)
return
}
list = append(list, user)
}
return
}
func CountUsers() (count int) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
if err != nil {
log.Println(err)
return
}
return
}

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"main/configs"
"net/http" "net/http"
) )
@@ -24,7 +25,7 @@ func AccountRead(w http.ResponseWriter, r *http.Request, cb func(account *Accoun
// 獲取當前session // 獲取當前session
session := Session{ID: cookie.Value} session := Session{ID: cookie.Value}
if err := session.Get(); err != nil { if err := configs.ORMDB().Take(&session).Error; err != nil {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 - 會話已過期")) w.Write([]byte("401 - 會話已過期"))
return return
@@ -32,7 +33,7 @@ func AccountRead(w http.ResponseWriter, r *http.Request, cb func(account *Accoun
// 獲取當前用戶 // 獲取當前用戶
user := User{ID: session.UserID} user := User{ID: session.UserID}
if err := user.Get(); err != nil { if err := configs.ORMDB().Model(&user).Select("id, name, email, created_at, updated_at").Find(&user).Error; err != nil {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 - 用戶不存在")) w.Write([]byte("401 - 用戶不存在"))
return return

View File

@@ -1,151 +1,16 @@
package models package models
import ( import (
"log"
"main/configs" "main/configs"
"time"
) )
type Session struct { type Session struct {
ID string `json:"id"` ID string `json:"id" gorm:"primary_key"`
UserID int `json:"user_id"` UserID int `json:"user_id"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updated_at"`
} }
func (session *Session) Get() (err error) { func init() {
db, err := configs.GetDB() configs.ORMDB().AutoMigrate(&Session{})
if err != nil {
log.Println(err)
return
}
defer db.Close()
row := db.QueryRow("SELECT * FROM sessions WHERE id = ?", session.ID)
err = row.Scan(&session.ID, &session.UserID, &session.CreatedAt, &session.UpdatedAt)
if err != nil {
log.Println(err)
return
}
return
}
func (session *Session) Create() error {
session.CreatedAt = time.Now().Format("2006-01-02 15:04:05")
session.UpdatedAt = session.CreatedAt
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("INSERT INTO sessions (id, user_id, created_at, updated_at) VALUES (?, ?, ?, ?)")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(session.ID, session.UserID, session.CreatedAt, session.UpdatedAt)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (session *Session) Delete() error {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("DELETE FROM sessions WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(session.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func (session *Session) Update() error {
session.UpdatedAt = time.Now().Format("2006-01-02 15:04:05")
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return err
}
defer db.Close()
stmt, err := db.Prepare("UPDATE sessions SET user_id = ?, updated_at = ? WHERE id = ?")
if err != nil {
log.Println(err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(session.UpdatedAt, session.UserID, session.ID)
if err != nil {
log.Println(err)
return err
}
return nil
}
func GetSession(id int) (*Session, error) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return nil, err
}
defer db.Close()
row := db.QueryRow("SELECT id, user_id, created_at, updated_at FROM sessions WHERE id = ?", id)
var session Session
if err := row.Scan(&session.ID, &session.UserID, &session.CreatedAt, &session.UpdatedAt); err != nil {
log.Println(err)
return nil, err
}
return &session, nil
}
func QuerySessions(page, pagesize int) (list []interface{}) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
rows, err := db.Query("SELECT id, user_id, created_at, updated_at FROM sessions LIMIT ?, ?", (page-1)*pagesize, pagesize)
if err != nil {
log.Println(err)
return
}
defer rows.Close()
for rows.Next() {
var session Session
if err := rows.Scan(&session.ID, &session.UserID, &session.CreatedAt, &session.UpdatedAt); err != nil {
log.Println(err)
return
}
list = append(list, session)
}
return
}
func CountSessions() (count int) {
db, err := configs.GetDB()
if err != nil {
log.Println(err)
return
}
defer db.Close()
row := db.QueryRow("SELECT COUNT(*) FROM sessions")
if err := row.Scan(&count); err != nil {
log.Println(err)
return
}
return
} }

View File

@@ -2,6 +2,7 @@ package routers
import ( import (
"fmt" "fmt"
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
@@ -22,16 +23,23 @@ func AccountGet(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id") cookie, err := r.Cookie("session_id")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 - 未登錄"))
return return
} }
// 獲取會話 // 獲取會話
session := models.Session{ID: cookie.Value} session := models.Session{ID: cookie.Value}
session.Get() if err := configs.ORMDB().Take(&session).Error; err != nil {
fmt.Println(err)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 - 會話已過期"))
return
}
// 獲取用戶 // 獲取用戶
user := models.User{ID: session.UserID} user := models.User{ID: session.UserID}
user.Get() configs.ORMDB().Model(&user).Select("id, name, email, created_at, updated_at").Find(&user)
account.ID = user.ID account.ID = user.ID
account.Name = user.Name account.Name = user.Name
@@ -43,25 +51,3 @@ func AccountGet(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(account)) w.Write(utils.ToJSON(account))
} }
// 獲取當前賬戶, 並將其傳入回調函數
func get_account(w http.ResponseWriter, r *http.Request, callback func(*models.User)) (err error) {
// 獲取Cookie
cookie, err := r.Cookie("session_id")
if err != nil {
fmt.Println(err)
return nil
}
// 獲取會話
session := models.Session{ID: cookie.Value}
session.Get()
// 獲取用戶
user := models.User{ID: session.UserID}
user.Get()
callback(&user)
return nil
}

View File

@@ -4,9 +4,11 @@ import (
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"log" "log"
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
@@ -15,9 +17,17 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QueryImages(listview.Page, listview.PageSize)
listview.Total = models.CountImages() var image_list []models.Image
listview.Next = listview.Page*listview.PageSize < listview.Total db := configs.ORMDB()
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&image_list)
for _, image := range image_list {
listview.List = append(listview.List, image)
}
db.Model(&models.Image{}).Count(&listview.Total)
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w) listview.WriteJSON(w)
} }
@@ -33,14 +43,20 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
log.Println(err) log.Println(err)
return return
} }
image.Create() if err := configs.ORMDB().Create(&image).Error; err != nil {
log.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image)) w.Write(utils.ToJSON(image))
} }
func ImagesItemGet(w http.ResponseWriter, r *http.Request) { func ImagesItemGet(w http.ResponseWriter, r *http.Request) {
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
image.Get() if err := configs.ORMDB().First(&image).Error; err != nil {
log.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image)) w.Write(utils.ToJSON(image))
} }
@@ -58,14 +74,24 @@ func ImagesItemPatch(w http.ResponseWriter, r *http.Request) {
return return
} }
image.ID = utils.ParamInt(mux.Vars(r)["id"], 0) image.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
image.Update() image.UpdatedAt = time.Now().Format("2006-01-02 15:04:05")
if err := configs.ORMDB().Model(&image).Updates(image).Error; err != nil {
log.Println(err)
return
}
//image.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
//image.Update()
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image)) w.Write(utils.ToJSON(image))
} }
func ImagesItemDelete(w http.ResponseWriter, r *http.Request) { func ImagesItemDelete(w http.ResponseWriter, r *http.Request) {
image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
image.Delete() if err := configs.ORMDB().Delete(&image).Error; err != nil {
log.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(image)) w.Write(utils.ToJSON(image))
} }

View File

@@ -17,28 +17,24 @@ import (
var manager = models.NewWebSocketManager() var manager = models.NewWebSocketManager()
// 獲取模型列表
func ModelsGet(w http.ResponseWriter, r *http.Request) { func ModelsGet(w http.ResponseWriter, r *http.Request) {
// 初始化基本參數
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
// 獲取模型列表
var model_list []models.Model var model_list []models.Model
db := configs.ORMDB() db := configs.ORMDB()
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list) db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list)
for _, model := range model_list { for _, model := range model_list {
listview.List = append(listview.List, model) listview.List = append(listview.List, model)
} }
// 獲取總數 db.Model(&models.Model{}).Count(&listview.Total)
var total int64 listview.Next = listview.Page*listview.PageSize < int(listview.Total)
db.Model(&models.Model{}).Count(&total)
listview.Total = int(total)
listview.Next = listview.Page*listview.PageSize < listview.Total
listview.WriteJSON(w) listview.WriteJSON(w)
} }
// 創建模型
func ModelsPost(w http.ResponseWriter, r *http.Request) { func ModelsPost(w http.ResponseWriter, r *http.Request) {
// 取得用戶
models.AccountRead(w, r, func(account *models.Account) { models.AccountRead(w, r, func(account *models.Account) {
fmt.Println(account) fmt.Println(account)
// TODO: 判斷權限(是否可以創建) // TODO: 判斷權限(是否可以創建)
@@ -138,7 +134,7 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) {
model.Status = model_new.Status model.Status = model_new.Status
// 如果狀態被改變爲 ready, 將模型發送到訓練隊列 // 如果狀態被改變爲 ready, 將模型發送到訓練隊列
if model.Status == "ready" { if model.Status == "ready" {
model.SendToTrain() //model.SendToTrain()
} }
} }
if model_new.Image != "" && model_new.Image != model.Image { if model_new.Image != "" && model_new.Image != model.Image {

View File

@@ -1,6 +1,7 @@
package routers package routers
import ( import (
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
@@ -12,38 +13,42 @@ func ServersGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QueryServers(listview.Page, listview.PageSize) var server_list []models.Server
listview.Total = models.CountServers() db := configs.ORMDB()
listview.Next = listview.Page*listview.PageSize < listview.Total db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&server_list)
for _, server := range server_list {
listview.List = append(listview.List, server)
}
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w) listview.WriteJSON(w)
} }
func ServersPost(w http.ResponseWriter, r *http.Request) { func ServersPost(w http.ResponseWriter, r *http.Request) {
var server models.Server var server models.Server
server.Create() configs.ORMDB().Create(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server)) w.Write(utils.ToJSON(server))
} }
func ServersItemGet(w http.ResponseWriter, r *http.Request) { func ServersItemGet(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
server.Get() configs.ORMDB().First(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server)) w.Write(utils.ToJSON(server))
} }
func ServersItemPatch(w http.ResponseWriter, r *http.Request) { func ServersItemPatch(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
server.Get() configs.ORMDB().First(&server)
server.Update() // TODO: update server
configs.ORMDB().Save(&server)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server)) w.Write(utils.ToJSON(server))
} }
func ServersItemDelete(w http.ResponseWriter, r *http.Request) { func ServersItemDelete(w http.ResponseWriter, r *http.Request) {
server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} server := models.Server{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
server.Get() configs.ORMDB().Delete(&server)
server.Delete()
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(server)) w.Write(utils.ToJSON(server))
} }

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
@@ -17,9 +18,14 @@ func SessionsGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QuerySessions(listview.Page, listview.PageSize) var session_list []models.Session
listview.Total = models.CountSessions() db := configs.ORMDB()
listview.Next = listview.Page*listview.PageSize < listview.Total db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&session_list)
for _, session := range session_list {
listview.List = append(listview.List, session)
}
db.Model(&models.Session{}).Count(&listview.Total)
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w) listview.WriteJSON(w)
} }
@@ -55,8 +61,8 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) {
} }
// 使用Email獲取用戶 // 使用Email獲取用戶
user, err := models.GetUserByEmail(form.Email) var user models.User
if err != nil { if err := configs.ORMDB().Where("email = ?", form.Email).First(&user).Error; err != nil {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("404 - User Not Found")) w.Write([]byte("404 - User Not Found"))
return return
@@ -71,7 +77,11 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) {
// 創建會話(生成一個不重複的 uuid 作爲 sid) // 創建會話(生成一個不重複的 uuid 作爲 sid)
session := &models.Session{ID: uuid.New().String(), UserID: user.ID} session := &models.Session{ID: uuid.New().String(), UserID: user.ID}
session.Create() if err := configs.ORMDB().Create(session).Error; err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 - Internal Server Error"))
return
}
// 寫入Cookie // 寫入Cookie
cookie := http.Cookie{Name: "session_id", Value: session.ID, Path: "/", HttpOnly: true} cookie := http.Cookie{Name: "session_id", Value: session.ID, Path: "/", HttpOnly: true}
@@ -85,7 +95,11 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) {
// 獲取會話 // 獲取會話
func SessionsItemGet(w http.ResponseWriter, r *http.Request) { func SessionsItemGet(w http.ResponseWriter, r *http.Request) {
session := models.Session{ID: mux.Vars(r)["session_id"]} session := models.Session{ID: mux.Vars(r)["session_id"]}
session.Get() if err := configs.ORMDB().Find(&session).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 - Not Found"))
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(session)) w.Write(utils.ToJSON(session))
} }
@@ -93,8 +107,11 @@ func SessionsItemGet(w http.ResponseWriter, r *http.Request) {
// 更新會話 // 更新會話
func SessionsItemPatch(w http.ResponseWriter, r *http.Request) { func SessionsItemPatch(w http.ResponseWriter, r *http.Request) {
session := models.Session{ID: mux.Vars(r)["session_id"]} session := models.Session{ID: mux.Vars(r)["session_id"]}
session.Get() if err := configs.ORMDB().Model(&session).Updates(GetForm(r)); err != nil {
session.Update() w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 - Not Found"))
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(session)) w.Write(utils.ToJSON(session))
} }
@@ -113,14 +130,22 @@ func SessionsItemDelete(w http.ResponseWriter, r *http.Request) {
// 獲取當前session // 獲取當前session
session := models.Session{ID: cookie.Value} session := models.Session{ID: cookie.Value}
session.Get() if err := configs.ORMDB().Find(&session).Error; err != nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 - 會話已過期"))
return
}
// 獲取當前用戶 // 獲取當前用戶
user := models.User{ID: session.UserID} user := models.User{ID: session.UserID}
user.Get() configs.ORMDB().Find(&user)
sessionx := models.Session{ID: mux.Vars(r)["session_id"]} sessionx := models.Session{ID: mux.Vars(r)["session_id"]}
sessionx.Get() if err := configs.ORMDB().Find(&sessionx).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 - Not Found"))
return
}
if user.ID != sessionx.UserID { if user.ID != sessionx.UserID {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
@@ -128,7 +153,11 @@ func SessionsItemDelete(w http.ResponseWriter, r *http.Request) {
return return
} }
sessionx.Delete() if err := configs.ORMDB().Delete(&sessionx).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 - Not Found"))
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(sessionx)) w.Write(utils.ToJSON(sessionx))
} }

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
@@ -16,9 +17,14 @@ func TagsGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QueryTags(listview.Page, listview.PageSize) var tag_list []models.Tag
listview.Total = models.CountTags() db := configs.ORMDB()
listview.Next = listview.Page*listview.PageSize < listview.Total db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&tag_list)
for _, tag := range tag_list {
listview.List = append(listview.List, tag)
}
db.Model(&models.Tag{}).Count(&listview.Total)
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w) listview.WriteJSON(w)
} }
@@ -38,8 +44,8 @@ func TagsPost(w http.ResponseWriter, r *http.Request) {
return return
} }
// 創建標籤 // 創建標籤
var tag models.Tag var tag models.Tag = models.Tag{Name: form.Name}
if err := tag.Create(form.Name); err != nil { if err := configs.ORMDB().Create(&tag).Error; err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
@@ -50,13 +56,8 @@ func TagsPost(w http.ResponseWriter, r *http.Request) {
// 獲取標籤 // 獲取標籤
func TagsItemGet(w http.ResponseWriter, r *http.Request) { func TagsItemGet(w http.ResponseWriter, r *http.Request) {
var tag models.Tag var tag models.Tag = models.Tag{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
tag.ID = utils.ParamInt(mux.Vars(r)["id"], 0) if err := configs.ORMDB().First(&tag).Error; err != nil {
if tag.ID == 0 {
w.WriteHeader(http.StatusNotFound)
return
}
if err := tag.Get(); err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
@@ -66,16 +67,6 @@ func TagsItemGet(w http.ResponseWriter, r *http.Request) {
// 更新標籤 // 更新標籤
func TagsItemPatch(w http.ResponseWriter, r *http.Request) { func TagsItemPatch(w http.ResponseWriter, r *http.Request) {
var tag models.Tag
tag.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
if tag.ID == 0 {
w.WriteHeader(http.StatusNotFound)
return
}
if err := tag.Get(); err != nil {
fmt.Println(err)
return
}
var form struct { var form struct {
Name string `json:"name"` Name string `json:"name"`
} }
@@ -89,7 +80,8 @@ func TagsItemPatch(w http.ResponseWriter, r *http.Request) {
fmt.Println(err) fmt.Println(err)
return return
} }
if err := tag.Update(form.Name); err != nil { var tag models.Tag = models.Tag{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Model(&tag).Update("name", form.Name).Error; err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
@@ -99,13 +91,8 @@ func TagsItemPatch(w http.ResponseWriter, r *http.Request) {
// 刪除標籤 // 刪除標籤
func TagsItemDelete(w http.ResponseWriter, r *http.Request) { func TagsItemDelete(w http.ResponseWriter, r *http.Request) {
var tag models.Tag var tag models.Tag = models.Tag{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
tag.ID = utils.ParamInt(mux.Vars(r)["id"], 0) if err := configs.ORMDB().Delete(&tag).Error; err != nil {
if tag.ID == 0 {
w.WriteHeader(http.StatusNotFound)
return
}
if err := tag.Delete(); err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"log" "log"
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
@@ -17,9 +18,14 @@ func TasksGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QueryTasks(listview.Page, listview.PageSize) var task_list []models.Task
listview.Total = models.CountTasks() db := configs.ORMDB()
listview.Next = listview.Page*listview.PageSize < listview.Total db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&task_list)
for _, task := range task_list {
listview.List = append(listview.List, task)
}
db.Model(&models.Task{}).Count(&listview.Total)
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w) listview.WriteJSON(w)
} }
@@ -35,7 +41,7 @@ func TasksPost(w http.ResponseWriter, r *http.Request) {
log.Println(err) log.Println(err)
return return
} }
task.Create() configs.ORMDB().Create(&task)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(task)) w.Write(utils.ToJSON(task))
} }
@@ -44,8 +50,10 @@ func TasksItemGet(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" { if r.Header.Get("Upgrade") == "websocket" {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.Atoi(vars["id"]) id, _ := strconv.Atoi(vars["id"])
task := models.QueryTask(id)
if task.ID == 0 { var task models.Task = models.Task{ID: id}
if err := configs.ORMDB().First(&task, id).Error; err != nil {
log.Println(err)
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
return return
} }
@@ -63,12 +71,12 @@ func TasksItemGet(w http.ResponseWriter, r *http.Request) {
break break
} }
task.Status = string(message) task.Status = string(message)
task.Update() configs.ORMDB().Model(&task).Update("status", task.Status)
} }
return return
} }
task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
task.Get() configs.ORMDB().First(&task)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(task)) w.Write(utils.ToJSON(task))
} }
@@ -86,18 +94,13 @@ func TasksItemPatch(w http.ResponseWriter, r *http.Request) {
return return
} }
task.ID = utils.ParamInt(mux.Vars(r)["id"], 0) task.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
task.Update() configs.ORMDB().Model(&task).Updates(task)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(task)) w.Write(utils.ToJSON(task))
} }
func TasksItemDelete(w http.ResponseWriter, r *http.Request) { func TasksItemDelete(w http.ResponseWriter, r *http.Request) {
task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
task.Delete() configs.ORMDB().Delete(&task)
if task.ID == 0 {
w.WriteHeader(http.StatusNotFound)
return
}
task.Delete()
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }

View File

@@ -1,13 +1,17 @@
package routers package routers
import ( import (
"crypto/md5"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"main/configs"
"main/models" "main/models"
"main/utils" "main/utils"
"net/http" "net/http"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
@@ -16,9 +20,14 @@ func UsersGet(w http.ResponseWriter, r *http.Request) {
var listview models.ListView var listview models.ListView
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
listview.List = models.QueryUsers(listview.Page, listview.PageSize) var user_list []models.User
listview.Total = models.CountUsers() db := configs.ORMDB()
listview.Next = listview.Page*listview.PageSize < listview.Total db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&user_list)
for _, user := range user_list {
listview.List = append(listview.List, user)
}
db.Model(&models.User{}).Count(&listview.Total)
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
listview.WriteJSON(w) listview.WriteJSON(w)
} }
@@ -40,9 +49,22 @@ func UsersPost(w http.ResponseWriter, r *http.Request) {
return return
} }
if form.Name == "" || form.Email == "" || form.Password == "" {
fmt.Println("name, email, password cannot be empty")
return
}
// 創建用戶 // 創建用戶
var user models.User var slat string = uuid.New().String()
if err := user.Create(form.Name, form.Email, form.Password); err != nil { var user models.User = models.User{
Name: form.Name,
Email: form.Email,
Password: fmt.Sprintf("%x", md5.Sum([]byte(form.Password+slat))),
Slat: slat,
CreatedAt: time.Now().Format("2006-01-02 15:04:05"),
UpdatedAt: time.Now().Format("2006-01-02 15:04:05"),
}
if err := configs.ORMDB().Create(&user).Error; err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
@@ -55,7 +77,7 @@ func UsersPost(w http.ResponseWriter, r *http.Request) {
// 獲取用戶 // 獲取用戶
func UsersItemGet(w http.ResponseWriter, r *http.Request) { func UsersItemGet(w http.ResponseWriter, r *http.Request) {
user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
user.Get() configs.ORMDB().First(&user)
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(user)) w.Write(utils.ToJSON(user))
} }
@@ -63,8 +85,20 @@ func UsersItemGet(w http.ResponseWriter, r *http.Request) {
// 更新用戶 // 更新用戶
func UsersItemPatch(w http.ResponseWriter, r *http.Request) { func UsersItemPatch(w http.ResponseWriter, r *http.Request) {
user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
user.Get() body, err := ioutil.ReadAll(r.Body)
user.Update() if err != nil {
fmt.Println(err)
return
}
defer r.Body.Close()
if err = json.Unmarshal(body, &user); err != nil {
fmt.Println(err)
return
}
if err := configs.ORMDB().Save(&user).Error; err != nil {
fmt.Println(err)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(user)) w.Write(utils.ToJSON(user))
} }
@@ -72,8 +106,7 @@ func UsersItemPatch(w http.ResponseWriter, r *http.Request) {
// 刪除用戶 // 刪除用戶
func UsersItemDelete(w http.ResponseWriter, r *http.Request) { func UsersItemDelete(w http.ResponseWriter, r *http.Request) {
user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} user := models.User{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
user.Get() configs.ORMDB().Delete(&user)
user.Delete()
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(utils.ToJSON(user)) w.Write(utils.ToJSON(user))
} }