diff --git a/models/Model.go b/models/Model.go index 5f4fa89..fe6d106 100644 --- a/models/Model.go +++ b/models/Model.go @@ -1,6 +1,7 @@ package models import ( + "fmt" "main/configs" ) @@ -24,6 +25,26 @@ func init() { configs.ORMDB().AutoMigrate(&Model{}) } +func (model *Model) Train() (err error) { + if model.Type == "lora" { + fmt.Println("lora") + return + } + if model.Type == "ckp" { + fmt.Println("ckp") + return + } + if model.Type == "hyper" { + fmt.Println("hyper") + return + } + if model.Type == "ti" { + fmt.Println("ti") + return + } + return +} + //func (model *Model) SendToTrain() error { // db, err := configs.GetDB() // if err != nil { diff --git a/models/session.go b/models/session.go index 8b9e233..e986e84 100644 --- a/models/session.go +++ b/models/session.go @@ -6,7 +6,9 @@ import ( type Session struct { ID string `json:"id" gorm:"primary_key"` + IP string `json:"ip"` UserID int `json:"user_id"` + UserAgent string `json:"user_agent"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` } diff --git a/routers/models.go b/routers/models.go index 38274d7..e294823 100644 --- a/routers/models.go +++ b/routers/models.go @@ -10,6 +10,7 @@ import ( "main/utils" "net/http" "strconv" + "time" "github.com/gorilla/mux" "github.com/gorilla/websocket" @@ -134,19 +135,25 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) { if model_new.Status != "" && model_new.Status != model.Status { model.Status = model_new.Status // 如果狀態被改變爲 ready, 將模型發送到訓練隊列 - //if model.Status == "ready" { - // model.SendToTrain() - //} + if model.Status == "ready" { + model.Status = "training" + go model.Train() + } } if model_new.Image != "" && model_new.Image != model.Image { model.Image = model_new.Image } + // 更新時間 + model.UpdatedAt = time.Now().Format("2006-01-02 15:04:05") + + // 執行更新 if err := configs.ORMDB().Save(&model).Error; err != nil { log.Println(err) return } + // 返回更新後的數據 w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(utils.ToJSON(model)) } diff --git a/routers/sessions.go b/routers/sessions.go index 21144c2..ce5f5ee 100644 --- a/routers/sessions.go +++ b/routers/sessions.go @@ -76,7 +76,7 @@ func SessionsPost(w http.ResponseWriter, r *http.Request) { } // 創建會話(生成一個不重複的 uuid 作爲 sid) - session := &models.Session{ID: uuid.New().String(), UserID: user.ID} + session := &models.Session{ID: uuid.New().String(), UserID: user.ID, UserAgent: r.UserAgent(), IP: r.RemoteAddr} if err := configs.ORMDB().Create(session).Error; err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("500 - Internal Server Error"))