websocket
This commit is contained in:
		
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							@@ -4,5 +4,6 @@ go 1.18
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/gorilla/mux v1.8.0
 | 
						github.com/gorilla/mux v1.8.0
 | 
				
			||||||
 | 
						github.com/gorilla/websocket v1.5.0
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.16
 | 
						github.com/mattn/go-sqlite3 v1.14.16
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,4 +1,6 @@
 | 
				
			|||||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
 | 
					github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
 | 
				
			||||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
 | 
					github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
 | 
				
			||||||
 | 
					github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
				
			||||||
 | 
					github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
				
			||||||
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
 | 
					github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
 | 
				
			||||||
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
 | 
					github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										70
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								main.go
									
									
									
									
									
								
							@@ -14,6 +14,7 @@ import (
 | 
				
			|||||||
	"main/models"
 | 
						"main/models"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gorilla/mux"
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
 | 
						"github.com/gorilla/websocket"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
@@ -75,7 +76,45 @@ func models_post(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	w.Write(ToJSON(model))
 | 
						w.Write(ToJSON(model))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func models_item_get_ws(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						vars := mux.Vars(r)
 | 
				
			||||||
 | 
						id, _ := strconv.Atoi(vars["id"])
 | 
				
			||||||
 | 
						model := models.QueryModel(id)
 | 
				
			||||||
 | 
						if model.ID == 0 {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						upgrader := websocket.Upgrader{}
 | 
				
			||||||
 | 
						conn, err := upgrader.Upgrade(w, r, nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							time.Sleep(1 * time.Second)
 | 
				
			||||||
 | 
							model = models.QueryModel(id)
 | 
				
			||||||
 | 
							if model.ID == 0 {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if model.Status == "success" || model.Status == "error" {
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							err = conn.WriteJSON(model)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Println(err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func models_item_get(w http.ResponseWriter, r *http.Request) {
 | 
					func models_item_get(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						if r.Header.Get("Upgrade") == "websocket" {
 | 
				
			||||||
 | 
							models_item_get_ws(w, r)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)}
 | 
						model := models.Model{ID: ParamInt(mux.Vars(r)["id"], 0)}
 | 
				
			||||||
	model.Get()
 | 
						model.Get()
 | 
				
			||||||
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
 | 
						w.Header().Set("Content-Type", "application/json; charset=utf-8")
 | 
				
			||||||
@@ -195,7 +234,38 @@ func tasks_post(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	w.Write(ToJSON(task))
 | 
						w.Write(ToJSON(task))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func tasks_item_get_ws(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						vars := mux.Vars(r)
 | 
				
			||||||
 | 
						id, _ := strconv.Atoi(vars["id"])
 | 
				
			||||||
 | 
						task := models.QueryTask(id)
 | 
				
			||||||
 | 
						if task.ID == 0 {
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						upgrader := websocket.Upgrader{}
 | 
				
			||||||
 | 
						ws, err := upgrader.Upgrade(w, r, nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer ws.Close()
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							_, message, err := ws.ReadMessage()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Println(err)
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							task.Status = string(message)
 | 
				
			||||||
 | 
							task.Update()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func tasks_item_get(w http.ResponseWriter, r *http.Request) {
 | 
					func tasks_item_get(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						if r.Header.Get("Upgrade") == "websocket" {
 | 
				
			||||||
 | 
							tasks_item_get_ws(w, r)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)}
 | 
						task := models.Task{ID: ParamInt(mux.Vars(r)["id"], 0)}
 | 
				
			||||||
	task.Get()
 | 
						task.Get()
 | 
				
			||||||
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
 | 
						w.Header().Set("Content-Type", "application/json; charset=utf-8")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -104,6 +104,21 @@ func (model *Model) Get() error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func QueryModel(id int) (model Model) {
 | 
				
			||||||
 | 
						db, err := configs.GetDB()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer db.Close()
 | 
				
			||||||
 | 
						err = db.QueryRow("SELECT id, name, type, trigger_words, base_model, model_path, status, progress, tags, created_at, updated_at, user_id FROM models WHERE id = ?", id).Scan(&model.ID, &model.Name, &model.Type, &model.TriggerWords, &model.BaseModel, &model.ModelPath, &model.Status, &model.Progress, &model.Tags, &model.CreatedAt, &model.UpdatedAt, &model.UserID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func QueryModels(page int, pagesize int) (models []interface{}) {
 | 
					func QueryModels(page int, pagesize int) (models []interface{}) {
 | 
				
			||||||
	db, err := configs.GetDB()
 | 
						db, err := configs.GetDB()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -99,6 +99,21 @@ func (task *Task) Get() error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func QueryTask(id int) (task Task) {
 | 
				
			||||||
 | 
						db, err := configs.GetDB()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer db.Close()
 | 
				
			||||||
 | 
						err = db.QueryRow("SELECT id, name, type, created_at, updated_at FROM tasks WHERE id = ?", id).Scan(&task.ID, &task.Name, &task.Type, &task.CreatedAt, &task.UpdatedAt)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func QueryTasks(page int, pagesize int) (tasks []interface{}) {
 | 
					func QueryTasks(page int, pagesize int) (tasks []interface{}) {
 | 
				
			||||||
	db, err := configs.GetDB()
 | 
						db, err := configs.GetDB()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user