diff --git a/routers/models.go b/routers/models.go index 08882cc..bbd229a 100644 --- a/routers/models.go +++ b/routers/models.go @@ -242,71 +242,129 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) { return } - // 取出更新数据 - var model_new models.Model - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &model_new); err != nil { - log.Println(err) - return - } - - // 字段不爲空且不等於原始數據時更新 - if model_new.Name != "" && model_new.Name != model.Name { - model.Name = model_new.Name - } - if model_new.Info != "" && model_new.Info != model.Info { - model.Info = model_new.Info - } - if model_new.Type != "" && model_new.Type != model.Type { - model.Type = model_new.Type - } - if model_new.Status != "" && model_new.Status != model.Status { - model.Status = model_new.Status - // 如果狀態被改變爲 ready, 將模型發送到訓練隊列 - if model.Status == "ready" { - model.Status = "training" - //go model.Train() + // 判断数据类型是否二进制文件 + if r.Header.Get("Content-Type") == "multipart/form-data" { + // 解析表单取出图片文件 (32MB) + if err := r.ParseMultipartForm(32 << 20); err != nil { + log.Println(err) + return } - } - if model_new.Preview != "" && model_new.Preview != model.Preview { - model.Preview = model_new.Preview - } - if model_new.TriggerWords != "" && model_new.TriggerWords != model.TriggerWords { - model.TriggerWords = model_new.TriggerWords - } - if model_new.BaseModel != "" && model_new.BaseModel != model.BaseModel { - model.BaseModel = model_new.BaseModel - } - if model_new.ModelPath != "" && model_new.ModelPath != model.ModelPath { - model.ModelPath = model_new.ModelPath - } - if model_new.Hash != "" && model_new.Hash != model.Hash { - model.Hash = model_new.Hash - } - if model_new.Epochs != 0 && model_new.Epochs != model.Epochs { - model.Epochs = model_new.Epochs - } - if model_new.Progress != 0 && model_new.Progress != model.Progress { - model.Progress = model_new.Progress - } - if model_new.Tags != nil && len(model_new.Tags) != len(model.Tags) { - model.Tags = model_new.Tags - } - // 執行更新 - if err := configs.ORMDB().Save(&model).Error; err != nil { - log.Println(err) + // 检查文件目录是否存在 + os.MkdirAll(fmt.Sprintf("data/models/%d", model.ID), 0777) + + // 上传文件 + for _, headers := range r.MultipartForm.File { + for _, header := range headers { + // 打开本地文件 + file, err := os.Create(fmt.Sprintf("data/models/%d/%s", model.ID, header.Filename)) + if err != nil { + log.Println(err) + return + } + defer file.Close() + + // 打开上传文件 + f, err := header.Open() + if err != nil { + log.Println(err) + return + } + + // 拷贝文件到本地 + _, err = io.Copy(file, f) + if err != nil { + log.Println(err) + return + } + + // 更新模型 + model.Preview = fmt.Sprintf("data/models/%d/%s", model.ID, header.Filename) + if err := configs.ORMDB().Save(&model).Error; err != nil { + log.Println(err) + return + } + } + } + + // 返回更新后的数据 + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(model)) + + return + } + + // 判断数据类型是否JSON + if r.Header.Get("Content-Type") == "application/json" { + + // 取出更新数据 + var model_new models.Model + body, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println(err) + return + } + defer r.Body.Close() + if err = json.Unmarshal(body, &model_new); err != nil { + log.Println(err) + return + } + + // 字段不爲空且不等於原始數據時更新 + if model_new.Name != "" && model_new.Name != model.Name { + model.Name = model_new.Name + } + if model_new.Info != "" && model_new.Info != model.Info { + model.Info = model_new.Info + } + if model_new.Type != "" && model_new.Type != model.Type { + model.Type = model_new.Type + } + if model_new.Status != "" && model_new.Status != model.Status { + model.Status = model_new.Status + // 如果狀態被改變爲 ready, 將模型發送到訓練隊列 + if model.Status == "ready" { + model.Status = "training" + //go model.Train() + } + } + if model_new.Preview != "" && model_new.Preview != model.Preview { + model.Preview = model_new.Preview + } + if model_new.TriggerWords != "" && model_new.TriggerWords != model.TriggerWords { + model.TriggerWords = model_new.TriggerWords + } + if model_new.BaseModel != "" && model_new.BaseModel != model.BaseModel { + model.BaseModel = model_new.BaseModel + } + if model_new.ModelPath != "" && model_new.ModelPath != model.ModelPath { + model.ModelPath = model_new.ModelPath + } + if model_new.Hash != "" && model_new.Hash != model.Hash { + model.Hash = model_new.Hash + } + if model_new.Epochs != 0 && model_new.Epochs != model.Epochs { + model.Epochs = model_new.Epochs + } + if model_new.Progress != 0 && model_new.Progress != model.Progress { + model.Progress = model_new.Progress + } + if model_new.Tags != nil && len(model_new.Tags) != len(model.Tags) { + model.Tags = model_new.Tags + } + + // 執行更新 + if err := configs.ORMDB().Save(&model).Error; err != nil { + log.Println(err) + return + } + + // 返回更新後的數據 + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(model)) return } - // 返回更新後的數據 - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(utils.ToJSON(model)) } // 刪除模型