package main import ( "encoding/json" "fmt" "io/ioutil" "log" "net/http" "runtime" "strconv" "text/template" "time" "main/models" "github.com/gorilla/mux" "github.com/gorilla/websocket" ) func main() { runtime.GOMAXPROCS(runtime.NumCPU()) r := mux.NewRouter() r.Use(middleware) r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { t, _ := template.ParseFiles("templates/index.html") t.Execute(w, nil) }) r.HandleFunc("/api/models", models_get).Methods("GET") r.HandleFunc("/api/models", models_post).Methods("POST") r.HandleFunc("/api/models/{id}", models_item_get).Methods("GET") r.HandleFunc("/api/models/{id}", models_item_patch).Methods("PATCH") r.HandleFunc("/api/models/{id}", models_item_delete).Methods("DELETE") r.HandleFunc("/api/images", images_get).Methods("GET") r.HandleFunc("/api/images", images_post).Methods("POST") r.HandleFunc("/api/images/{id}", images_item_get).Methods("GET") r.HandleFunc("/api/images/{id}", images_item_patch).Methods("PATCH") r.HandleFunc("/api/images/{id}", images_item_delete).Methods("DELETE") r.HandleFunc("/api/tasks", tasks_get).Methods("GET") r.HandleFunc("/api/tasks", tasks_post).Methods("POST") r.HandleFunc("/api/tasks/{id}", tasks_item_get).Methods("GET") r.HandleFunc("/api/tasks/{id}", tasks_item_patch).Methods("PATCH") r.HandleFunc("/api/tasks/{id}", tasks_item_delete).Methods("DELETE") r.HandleFunc("/api/params/model", models_params_get).Methods("GET") log.Println("Web Server is running on http://localhost:8080") http.ListenAndServe(":8080", r) } func models_params_get(w http.ResponseWriter, r *http.Request) { params := make(map[string]interface{}) params["type"] = []string{"lora", "ckp", "hyper", "ti"} params["status"] = []string{"pending", "running", "finished", "failed"} params["base_model"] = []string{"SD1.5", "SD2"} w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(params)) } func models_get(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = ParamInt(r.URL.Query().Get("pageSize"), 10) listview.List = models.QueryModels(listview.Page, listview.PageSize) listview.Total = models.CountModels() listview.Next = listview.Page*listview.PageSize < listview.Total w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(listview.ToJSON()) } func models_post(w http.ResponseWriter, r *http.Request) { var model models.Model body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &model); err != nil { log.Println(err) return } model.Create() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(model)) } func models_item_get_ws(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.Atoi(vars["id"]) model := models.QueryModel(id) if model.ID == 0 { w.WriteHeader(http.StatusNotFound) return } upgrader := websocket.Upgrader{} conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println(err) return } defer conn.Close() for { time.Sleep(1 * time.Second) model = models.QueryModel(id) if model.ID == 0 { w.WriteHeader(http.StatusNotFound) return } if model.Status == "success" || model.Status == "error" { break } err = conn.WriteJSON(model) if err != nil { log.Println(err) return } } } func models_item_get(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") == "websocket" { models_item_get_ws(w, r) return } model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)} model.Get() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(model)) } func models_item_patch(w http.ResponseWriter, r *http.Request) { model := models.Model{} body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &model); err != nil { log.Println(err) return } model.ID = ParamInt(mux.Vars(r)["id"], 0) model.Update() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(model)) } func models_item_delete(w http.ResponseWriter, r *http.Request) { model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)} model.Delete() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(model)) } func images_get(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = ParamInt(r.URL.Query().Get("pageSize"), 10) listview.List = models.QueryImages(listview.Page, listview.PageSize) listview.Total = models.CountImages() listview.Next = listview.Page*listview.PageSize < listview.Total w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(listview.ToJSON()) } func images_post(w http.ResponseWriter, r *http.Request) { var image models.Image body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &image); err != nil { log.Println(err) return } image.Create() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(image)) } func images_item_get(w http.ResponseWriter, r *http.Request) { image := models.Image{ID: ParamInt(mux.Vars(r)["id"], 0)} image.Get() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(image)) } func images_item_patch(w http.ResponseWriter, r *http.Request) { image := models.Image{} body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &image); err != nil { log.Println(err) return } image.ID = ParamInt(mux.Vars(r)["id"], 0) image.Update() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(image)) } func images_item_delete(w http.ResponseWriter, r *http.Request) { image := models.Image{ID: ParamInt(mux.Vars(r)["id"], 0)} image.Delete() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(image)) } func tasks_get(w http.ResponseWriter, r *http.Request) { var listview models.ListView listview.Page = ParamInt(r.URL.Query().Get("page"), 1) listview.PageSize = ParamInt(r.URL.Query().Get("pageSize"), 10) listview.List = models.QueryTasks(listview.Page, listview.PageSize) listview.Total = models.CountTasks() listview.Next = listview.Page*listview.PageSize < listview.Total w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(listview)) } func tasks_post(w http.ResponseWriter, r *http.Request) { var task models.Task body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &task); err != nil { log.Println(err) return } task.Create() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(task)) } func tasks_item_get_ws(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.Atoi(vars["id"]) task := models.QueryTask(id) if task.ID == 0 { w.WriteHeader(http.StatusNotFound) return } upgrader := websocket.Upgrader{} ws, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println(err) return } defer ws.Close() for { _, message, err := ws.ReadMessage() if err != nil { log.Println(err) break } task.Status = string(message) task.Update() } } func tasks_item_get(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") == "websocket" { tasks_item_get_ws(w, r) return } task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)} task.Get() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(task)) } func tasks_item_patch(w http.ResponseWriter, r *http.Request) { task := models.Task{} body, err := ioutil.ReadAll(r.Body) if err != nil { log.Println(err) return } defer r.Body.Close() if err = json.Unmarshal(body, &task); err != nil { log.Println(err) return } task.ID = ParamInt(mux.Vars(r)["id"], 0) task.Update() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(task)) } func tasks_item_delete(w http.ResponseWriter, r *http.Request) { task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)} task.Delete() w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(ToJSON(task)) } // 中間件, 通用 func middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer LogComponent(time.Now().UnixNano(), r) // 最后打印日志 // 處理跨域請求 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") // 處理OPTIONS請求 if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } func ToJSON(object interface{}) []byte { json, err := json.MarshalIndent(object, "", " ") if err != nil { log.Println(err) return []byte{} } return json } func LogComponent(startTime int64, r *http.Request) { ms := (time.Now().UnixNano() - startTime) / 1000000 color := "\033[1;32m%d\033[0m" if ms > 800 { color = "\033[1;31m%dms\033[0m" // 紅色加重 } else if ms > 500 { color = "\033[1;33m%dms\033[0m" // 黃色加重 } else if ms > 300 { color = "\033[1;32m%dms\033[0m" // 綠色加重 } else if ms > 200 { color = "\033[1;34m%dms\033[0m" // 藍色加重 } else if ms > 100 { color = "\033[1;35m%dms\033[0m" // 紫色加重 } else { color = "\033[1;36m%dms\033[0m" // 黑色加重 } endTime := fmt.Sprintf(color, ms) method := fmt.Sprintf("\033[1;32m%s\033[0m", r.Method) // 綠色加重 url := fmt.Sprintf("\033[1;34m%s\033[0m", r.URL) // 藍色加重 log.Println(method, url, endTime) } // 獲取查詢參數(int 類型) func ParamInt(value string, defaultValue int) int { if value == "" { return defaultValue } result, err := strconv.Atoi(value) if err != nil { return defaultValue } return result } // 獲取查詢參數(string 類型) func ParamString(value string, defaultValue string) string { if value == "" { return defaultValue } return value } // 獲取查詢參數(bool 類型) func ParamBool(value string, defaultValue bool) bool { if value == "" { return defaultValue } result, err := strconv.ParseBool(value) if err != nil { return defaultValue } return result }