diff --git a/go.mod b/go.mod index c5a03d5..67144e4 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.18 require ( github.com/gorilla/mux v1.8.0 + github.com/gorilla/websocket v1.5.0 github.com/mattn/go-sqlite3 v1.14.16 ) diff --git a/go.sum b/go.sum index 4f6d3fb..08efe41 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= diff --git a/main.go b/main.go index a198a18..b932173 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "main/models" "github.com/gorilla/mux" + "github.com/gorilla/websocket" ) func main() { @@ -75,7 +76,45 @@ func models_post(w http.ResponseWriter, r *http.Request) { w.Write(ToJSON(model)) } +func models_item_get_ws(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + id, _ := strconv.Atoi(vars["id"]) + model := models.QueryModel(id) + if model.ID == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer conn.Close() + for { + time.Sleep(1 * time.Second) + model = models.QueryModel(id) + if model.ID == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + if model.Status == "success" || model.Status == "error" { + break + } + err = conn.WriteJSON(model) + if err != nil { + log.Println(err) + return + } + } +} + func models_item_get(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + models_item_get_ws(w, r) + return + } + model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)} model.Get() w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -195,7 +234,38 @@ func tasks_post(w http.ResponseWriter, r *http.Request) { w.Write(ToJSON(task)) } +func tasks_item_get_ws(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + id, _ := strconv.Atoi(vars["id"]) + task := models.QueryTask(id) + if task.ID == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + upgrader := websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer ws.Close() + for { + _, message, err := ws.ReadMessage() + if err != nil { + log.Println(err) + break + } + task.Status = string(message) + task.Update() + } +} + func tasks_item_get(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + tasks_item_get_ws(w, r) + return + } + task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)} task.Get() w.Header().Set("Content-Type", "application/json; charset=utf-8") diff --git a/models/Model.go b/models/Model.go index 854ad1e..1fc5222 100644 --- a/models/Model.go +++ b/models/Model.go @@ -104,6 +104,21 @@ func (model *Model) Get() error { return nil } +func QueryModel(id int) (model Model) { + db, err := configs.GetDB() + if err != nil { + log.Println(err) + return + } + defer db.Close() + err = db.QueryRow("SELECT id, name, type, trigger_words, base_model, model_path, status, progress, tags, created_at, updated_at, user_id FROM models WHERE id = ?", id).Scan(&model.ID, &model.Name, &model.Type, &model.TriggerWords, &model.BaseModel, &model.ModelPath, &model.Status, &model.Progress, &model.Tags, &model.CreatedAt, &model.UpdatedAt, &model.UserID) + if err != nil { + log.Println(err) + return + } + return +} + func QueryModels(page int, pagesize int) (models []interface{}) { db, err := configs.GetDB() if err != nil { diff --git a/models/Task.go b/models/Task.go index 6e996ce..380459f 100644 --- a/models/Task.go +++ b/models/Task.go @@ -99,6 +99,21 @@ func (task *Task) Get() error { return nil } +func QueryTask(id int) (task Task) { + db, err := configs.GetDB() + if err != nil { + log.Println(err) + return + } + defer db.Close() + err = db.QueryRow("SELECT id, name, type, created_at, updated_at FROM tasks WHERE id = ?", id).Scan(&task.ID, &task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt) + if err != nil { + log.Println(err) + return + } + return +} + func QueryTasks(page int, pagesize int) (tasks []interface{}) { db, err := configs.GetDB() if err != nil {