From 95107e7bdcabf0f730900020f11c04fa2d338d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Fri, 28 Apr 2023 16:53:23 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B7=AF=E7=94=B1=E5=88=86=E6=8B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 +- go.mod | 1 + go.sum | 2 + main.go | 363 ++------------------------------------ models/WebSocketMnager.go | 61 +++++++ routers/images.go | 72 ++++++++ routers/models.go | 105 +++++++++++ routers/params.go | 15 ++ routers/tasks.go | 104 +++++++++++ utils/params.go | 73 ++++++++ 10 files changed, 457 insertions(+), 345 deletions(-) create mode 100644 models/WebSocketMnager.go create mode 100644 routers/images.go create mode 100644 routers/models.go create mode 100644 routers/params.go create mode 100644 routers/tasks.go create mode 100644 utils/params.go diff --git a/README.md b/README.md index 5d06145..6ef660e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,11 @@ ai 繪圖服務端(快速重構) - [ ] /api/tags [#標籤詳情](#標籤列表) - [ ] /api/users [#用戶詳情](#用戶列表) - [ ] /api/tasks [#任務詳情](#任務列表) -- [ ] /api/models [#模型列表](#模型列表) +- [x] /api/models [#模型列表](#模型列表) + - [x] GET /api/models/{id} + - [x] PATCH /api/models/{id} + - [x] DELETE /api/models/{id} + - [ ] WebSocket /api/models/{id} - [ ] /api/images [#圖片列表](#圖片列表) - [ ] /api/params [#參數列表](#參數列表) diff --git a/go.mod b/go.mod index 7ab8a13..0e5c960 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/go-sql-driver/mysql v1.7.1 + github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 7cb9a38..f7da4ca 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= diff --git a/main.go b/main.go index 746714a..a24ca23 100644 --- a/main.go +++ b/main.go @@ -1,20 +1,16 @@ package main import ( - "encoding/json" - "fmt" - "io/ioutil" "log" "net/http" "runtime" - "strconv" "text/template" "time" - "main/models" + "main/routers" + "main/utils" "github.com/gorilla/mux" - "github.com/gorilla/websocket" ) func main() { @@ -25,7 +21,7 @@ func main() { // 設定中間件 r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer LogComponent(time.Now().UnixNano(), r) // 最后打印日志 + defer utils.LogComponent(time.Now().UnixNano(), r) // 最后打印日志 w.Header().Set("Access-Control-Allow-Origin", "*") // 處理跨域請求 w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") @@ -42,347 +38,26 @@ func main() { t, _ := template.ParseFiles("templates/index.html") t.Execute(w, nil) }) - r.HandleFunc("/api/models", models_get).Methods("GET") - r.HandleFunc("/api/models", models_post).Methods("POST") - r.HandleFunc("/api/models/{id}", models_item_get).Methods("GET") - r.HandleFunc("/api/models/{id}", models_item_patch).Methods("PATCH") - r.HandleFunc("/api/models/{id}", models_item_delete).Methods("DELETE") + r.HandleFunc("/api/models", routers.ModelsGet).Methods("GET") + r.HandleFunc("/api/models", routers.ModelsPost).Methods("POST") + r.HandleFunc("/api/models/{id}", routers.ModelItemGet).Methods("GET") + r.HandleFunc("/api/models/{id}", routers.ModelItemPatch).Methods("PATCH") + r.HandleFunc("/api/models/{id}", routers.ModelItemDelete).Methods("DELETE") - r.HandleFunc("/api/images", images_get).Methods("GET") - r.HandleFunc("/api/images", images_post).Methods("POST") - r.HandleFunc("/api/images/{id}", images_item_get).Methods("GET") - r.HandleFunc("/api/images/{id}", images_item_patch).Methods("PATCH") - r.HandleFunc("/api/images/{id}", images_item_delete).Methods("DELETE") + r.HandleFunc("/api/images", routers.ImagesGet).Methods("GET") + r.HandleFunc("/api/images", routers.ImagesPost).Methods("POST") + r.HandleFunc("/api/images/{id}", routers.ImagesItemGet).Methods("GET") + r.HandleFunc("/api/images/{id}", routers.ImagesItemPatch).Methods("PATCH") + r.HandleFunc("/api/images/{id}", routers.ImagesItemDelete).Methods("DELETE") - r.HandleFunc("/api/tasks", tasks_get).Methods("GET") - r.HandleFunc("/api/tasks", tasks_post).Methods("POST") - r.HandleFunc("/api/tasks/{id}", tasks_item_get).Methods("GET") - r.HandleFunc("/api/tasks/{id}", tasks_item_patch).Methods("PATCH") - r.HandleFunc("/api/tasks/{id}", tasks_item_delete).Methods("DELETE") + r.HandleFunc("/api/tasks", routers.TasksGet).Methods("GET") + r.HandleFunc("/api/tasks", routers.TasksPost).Methods("POST") + r.HandleFunc("/api/tasks/{id}", routers.TasksItemGet).Methods("GET") + r.HandleFunc("/api/tasks/{id}", routers.TasksItemPatch).Methods("PATCH") + r.HandleFunc("/api/tasks/{id}", routers.TasksItemDelete).Methods("DELETE") - r.HandleFunc("/api/params/model", models_params_get).Methods("GET") + r.HandleFunc("/api/params/model", routers.ParamsModelsGet).Methods("GET") log.Println("Web Server is running on http://localhost:8080") http.ListenAndServe(":8080", r) } - -func models_params_get(w http.ResponseWriter, r *http.Request) { - params := make(map[string]interface{}) - params["type"] = []string{"lora", "ckp", "hyper", "ti"} - params["status"] = []string{"pending", "running", "finished", "failed"} - params["base_model"] = []string{"SD1.5", "SD2"} - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(params)) -} - -func models_get(w http.ResponseWriter, r *http.Request) { - var listview models.ListView - listview.Page = ParamInt(r.URL.Query().Get("page"), 1) - listview.PageSize = ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryModels(listview.Page, listview.PageSize) - listview.Total = models.CountModels() - listview.Next = listview.Page*listview.PageSize < listview.Total - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(listview.ToJSON()) -} - -func models_post(w http.ResponseWriter, r *http.Request) { - var model models.Model - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &model); err != nil { - log.Println(err) - return - } - model.Create() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(model)) -} - -func models_item_get_ws(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - id, _ := strconv.Atoi(vars["id"]) - model := models.QueryModel(id) - if model.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - upgrader := websocket.Upgrader{} - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(err) - return - } - defer conn.Close() - for { - time.Sleep(1 * time.Second) - model = models.QueryModel(id) - if model.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - if model.Status == "success" || model.Status == "error" { - break - } - err = conn.WriteJSON(model) - if err != nil { - log.Println(err) - return - } - } -} - -func models_item_get(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - models_item_get_ws(w, r) - return - } - - model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)} - model.Get() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(model)) -} - -func models_item_patch(w http.ResponseWriter, r *http.Request) { - model := models.Model{} - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &model); err != nil { - log.Println(err) - return - } - model.ID = ParamInt(mux.Vars(r)["id"], 0) - model.Update() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(model)) -} - -func models_item_delete(w http.ResponseWriter, r *http.Request) { - model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)} - model.Delete() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(model)) -} - -func images_get(w http.ResponseWriter, r *http.Request) { - var listview models.ListView - listview.Page = ParamInt(r.URL.Query().Get("page"), 1) - listview.PageSize = ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryImages(listview.Page, listview.PageSize) - listview.Total = models.CountImages() - listview.Next = listview.Page*listview.PageSize < listview.Total - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(listview.ToJSON()) -} - -func images_post(w http.ResponseWriter, r *http.Request) { - var image models.Image - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &image); err != nil { - log.Println(err) - return - } - image.Create() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(image)) -} - -func images_item_get(w http.ResponseWriter, r *http.Request) { - image := models.Image{ID: ParamInt(mux.Vars(r)["id"], 0)} - image.Get() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(image)) -} - -func images_item_patch(w http.ResponseWriter, r *http.Request) { - image := models.Image{} - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &image); err != nil { - log.Println(err) - return - } - image.ID = ParamInt(mux.Vars(r)["id"], 0) - image.Update() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(image)) -} - -func images_item_delete(w http.ResponseWriter, r *http.Request) { - image := models.Image{ID: ParamInt(mux.Vars(r)["id"], 0)} - image.Delete() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(image)) -} - -func tasks_get(w http.ResponseWriter, r *http.Request) { - var listview models.ListView - listview.Page = ParamInt(r.URL.Query().Get("page"), 1) - listview.PageSize = ParamInt(r.URL.Query().Get("pageSize"), 10) - listview.List = models.QueryTasks(listview.Page, listview.PageSize) - listview.Total = models.CountTasks() - listview.Next = listview.Page*listview.PageSize < listview.Total - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(listview)) -} - -func tasks_post(w http.ResponseWriter, r *http.Request) { - var task models.Task - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &task); err != nil { - log.Println(err) - return - } - task.Create() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(task)) -} - -func tasks_item_get_ws(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - id, _ := strconv.Atoi(vars["id"]) - task := models.QueryTask(id) - if task.ID == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - upgrader := websocket.Upgrader{} - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(err) - return - } - defer ws.Close() - for { - _, message, err := ws.ReadMessage() - if err != nil { - log.Println(err) - break - } - task.Status = string(message) - task.Update() - } -} - -func tasks_item_get(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - tasks_item_get_ws(w, r) - return - } - - task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)} - task.Get() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(task)) -} - -func tasks_item_patch(w http.ResponseWriter, r *http.Request) { - task := models.Task{} - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &task); err != nil { - log.Println(err) - return - } - task.ID = ParamInt(mux.Vars(r)["id"], 0) - task.Update() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(task)) -} - -func tasks_item_delete(w http.ResponseWriter, r *http.Request) { - task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)} - task.Delete() - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(ToJSON(task)) -} - -func ToJSON(object interface{}) []byte { - json, err := json.MarshalIndent(object, "", " ") - if err != nil { - log.Println(err) - return []byte{} - } - return json -} - -func LogComponent(startTime int64, r *http.Request) { - ms := (time.Now().UnixNano() - startTime) / 1000000 - color := "\033[1;32m%d\033[0m" - if ms > 800 { - color = "\033[1;31m%dms\033[0m" // 紅色加重 - } else if ms > 500 { - color = "\033[1;33m%dms\033[0m" // 黃色加重 - } else if ms > 300 { - color = "\033[1;32m%dms\033[0m" // 綠色加重 - } else if ms > 200 { - color = "\033[1;34m%dms\033[0m" // 藍色加重 - } else if ms > 100 { - color = "\033[1;35m%dms\033[0m" // 紫色加重 - } else { - color = "\033[1;36m%dms\033[0m" // 黑色加重 - } - endTime := fmt.Sprintf(color, ms) - method := fmt.Sprintf("\033[1;32m%s\033[0m", r.Method) // 綠色加重 - url := fmt.Sprintf("\033[1;34m%s\033[0m", r.URL) // 藍色加重 - log.Println(method, url, endTime) -} - -// 獲取查詢參數(int 類型) -func ParamInt(value string, defaultValue int) int { - if value == "" { - return defaultValue - } - result, err := strconv.Atoi(value) - if err != nil { - return defaultValue - } - return result -} - -// 獲取查詢參數(string 類型) -func ParamString(value string, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -// 獲取查詢參數(bool 類型) -func ParamBool(value string, defaultValue bool) bool { - if value == "" { - return defaultValue - } - result, err := strconv.ParseBool(value) - if err != nil { - return defaultValue - } - return result -} diff --git a/models/WebSocketMnager.go b/models/WebSocketMnager.go new file mode 100644 index 0000000..2152047 --- /dev/null +++ b/models/WebSocketMnager.go @@ -0,0 +1,61 @@ +package models + +import ( + "sync" + + "github.com/google/uuid" + "github.com/gorilla/websocket" +) + +type WebSocketManager struct { + connections map[string]*websocket.Conn + listeners map[string]map[chan struct{}]struct{} + mutex sync.RWMutex +} + +func NewWebSocketManager() *WebSocketManager { + return &WebSocketManager{ + connections: make(map[string]*websocket.Conn), + mutex: sync.RWMutex{}, + } +} + +func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn) string { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + id := uuid.New().String() // 为每个连接生成一个唯一的 ID + mgr.connections[id] = conn + + return id +} + +func (mgr *WebSocketManager) RemoveConnection(id string) { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + delete(mgr.connections, id) +} + +func (mgr *WebSocketManager) ListenForChanges(target string, callback func()) { + notifications := make(chan struct{}) + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + if _, ok := mgr.listeners[target]; !ok { + mgr.listeners[target] = make(map[chan struct{}]struct{}) + } + mgr.listeners[target][notifications] = struct{}{} + + go func() { + for { + callback() + for listener := range mgr.listeners[target] { + select { + case listener <- struct{}{}: + default: + delete(mgr.listeners[target], listener) + } + } + } + }() +} diff --git a/routers/images.go b/routers/images.go new file mode 100644 index 0000000..55c0433 --- /dev/null +++ b/routers/images.go @@ -0,0 +1,72 @@ +package routers + +import ( + "encoding/json" + "io/ioutil" + "log" + "main/models" + "main/utils" + "net/http" + + "github.com/gorilla/mux" +) + +func ImagesGet(w http.ResponseWriter, r *http.Request) { + var listview models.ListView + listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) + listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) + listview.List = models.QueryImages(listview.Page, listview.PageSize) + listview.Total = models.CountImages() + listview.Next = listview.Page*listview.PageSize < listview.Total + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(listview.ToJSON()) +} + +func ImagesPost(w http.ResponseWriter, r *http.Request) { + var image models.Image + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &image); err != nil { + log.Println(err) + return + } + image.Create() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(image)) +} + +func ImagesItemGet(w http.ResponseWriter, r *http.Request) { + image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + image.Get() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(image)) +} + +func ImagesItemPatch(w http.ResponseWriter, r *http.Request) { + image := models.Image{} + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &image); err != nil { + log.Println(err) + return + } + image.ID = utils.ParamInt(mux.Vars(r)["id"], 0) + image.Update() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(image)) +} + +func ImagesItemDelete(w http.ResponseWriter, r *http.Request) { + image := models.Image{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + image.Delete() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(image)) +} diff --git a/routers/models.go b/routers/models.go new file mode 100644 index 0000000..cd427f5 --- /dev/null +++ b/routers/models.go @@ -0,0 +1,105 @@ +package routers + +import ( + "encoding/json" + "io/ioutil" + "log" + "main/models" + "main/utils" + "net/http" + "strconv" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" +) + +var manager = models.NewWebSocketManager() + +func ModelsGet(w http.ResponseWriter, r *http.Request) { + var listview models.ListView + listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) + listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) + listview.List = models.QueryModels(listview.Page, listview.PageSize) + listview.Total = models.CountModels() + listview.Next = listview.Page*listview.PageSize < listview.Total + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(listview.ToJSON()) +} + +func ModelsPost(w http.ResponseWriter, r *http.Request) { + var model models.Model + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &model); err != nil { + log.Println(err) + return + } + model.Create() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(model)) +} + +func ModelItemGet(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + vars := mux.Vars(r) + id, _ := strconv.Atoi(vars["id"]) + model := models.QueryModel(id) + if model.ID == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer conn.Close() + wsid := manager.AddConnection(conn) + defer manager.RemoveConnection(wsid) + for { + _, msg, err := conn.ReadMessage() + if err != nil { + log.Println(err) + return + } + log.Println(string(msg)) + if string(msg) == "close" { + break + } + } + return + } + model := models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + model.Get() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(model)) +} + +func ModelItemPatch(w http.ResponseWriter, r *http.Request) { + var model models.Model + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &model); err != nil { + log.Println(err) + return + } + model.ID = utils.ParamInt(mux.Vars(r)["id"], 0) + model.Update() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(model)) +} + +func ModelItemDelete(w http.ResponseWriter, r *http.Request) { + model := models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + model.Delete() + w.WriteHeader(http.StatusNoContent) +} diff --git a/routers/params.go b/routers/params.go new file mode 100644 index 0000000..9d48438 --- /dev/null +++ b/routers/params.go @@ -0,0 +1,15 @@ +package routers + +import ( + "main/utils" + "net/http" +) + +func ParamsModelsGet(w http.ResponseWriter, r *http.Request) { + params := make(map[string]interface{}) + params["type"] = []string{"lora", "ckp", "hyper", "ti"} + params["status"] = []string{"pending", "running", "finished", "failed"} + params["base_model"] = []string{"SD1.5", "SD2"} + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(params)) +} diff --git a/routers/tasks.go b/routers/tasks.go new file mode 100644 index 0000000..58a8d02 --- /dev/null +++ b/routers/tasks.go @@ -0,0 +1,104 @@ +package routers + +import ( + "encoding/json" + "io/ioutil" + "log" + "main/models" + "main/utils" + "net/http" + "strconv" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" +) + +func TasksGet(w http.ResponseWriter, r *http.Request) { + var listview models.ListView + listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) + listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) + listview.List = models.QueryTasks(listview.Page, listview.PageSize) + listview.Total = models.CountTasks() + listview.Next = listview.Page*listview.PageSize < listview.Total + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(listview.ToJSON()) +} + +func TasksPost(w http.ResponseWriter, r *http.Request) { + var task models.Task + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &task); err != nil { + log.Println(err) + return + } + task.Create() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(task)) +} + +func TasksItemGet(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + vars := mux.Vars(r) + id, _ := strconv.Atoi(vars["id"]) + task := models.QueryTask(id) + if task.ID == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + upgrader := websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer ws.Close() + for { + _, message, err := ws.ReadMessage() + if err != nil { + log.Println(err) + break + } + task.Status = string(message) + task.Update() + } + return + } + task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + task.Get() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(task)) +} + +func TasksItemPatch(w http.ResponseWriter, r *http.Request) { + var task models.Task + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &task); err != nil { + log.Println(err) + return + } + task.ID = utils.ParamInt(mux.Vars(r)["id"], 0) + task.Update() + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(task)) +} + +func TasksItemDelete(w http.ResponseWriter, r *http.Request) { + task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} + task.Delete() + if task.ID == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + task.Delete() + w.WriteHeader(http.StatusNoContent) +} diff --git a/utils/params.go b/utils/params.go new file mode 100644 index 0000000..f387e1c --- /dev/null +++ b/utils/params.go @@ -0,0 +1,73 @@ +package utils + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strconv" + "time" +) + +// 獲取查詢參數(int 類型) +func ParamInt(value string, defaultValue int) int { + if value == "" { + return defaultValue + } + result, err := strconv.Atoi(value) + if err != nil { + return defaultValue + } + return result +} + +// 獲取查詢參數(string 類型) +func ParamString(value string, defaultValue string) string { + if value == "" { + return defaultValue + } + return value +} + +// 獲取查詢參數(bool 類型) +func ParamBool(value string, defaultValue bool) bool { + if value == "" { + return defaultValue + } + result, err := strconv.ParseBool(value) + if err != nil { + return defaultValue + } + return result +} + +func ToJSON(object interface{}) []byte { + json, err := json.MarshalIndent(object, "", " ") + if err != nil { + log.Println(err) + return []byte{} + } + return json +} + +func LogComponent(startTime int64, r *http.Request) { + ms := (time.Now().UnixNano() - startTime) / 1000000 + color := "\033[1;32m%d\033[0m" + if ms > 800 { + color = "\033[1;31m%dms\033[0m" // 紅色加重 + } else if ms > 500 { + color = "\033[1;33m%dms\033[0m" // 黃色加重 + } else if ms > 300 { + color = "\033[1;32m%dms\033[0m" // 綠色加重 + } else if ms > 200 { + color = "\033[1;34m%dms\033[0m" // 藍色加重 + } else if ms > 100 { + color = "\033[1;35m%dms\033[0m" // 紫色加重 + } else { + color = "\033[1;36m%dms\033[0m" // 黑色加重 + } + endTime := fmt.Sprintf(color, ms) + method := fmt.Sprintf("\033[1;32m%s\033[0m", r.Method) // 綠色加重 + url := fmt.Sprintf("\033[1;34m%s\033[0m", r.URL) // 藍色加重 + log.Println(method, url, endTime) +}