From 050637dcd87fe1e1133a381b7d7e9f849bb61828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Tue, 11 Jul 2023 14:37:03 +0800 Subject: [PATCH] =?UTF-8?q?dataset=20=20=E4=B8=8A=E4=BC=A0=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/Model.go | 6 ++++ routers/datasets.go | 78 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/models/Model.go b/models/Model.go index 7a5d35e..dc324ca 100644 --- a/models/Model.go +++ b/models/Model.go @@ -63,6 +63,12 @@ func (model *Model) Load() { configs.ORMDB().First(&model) } +// 从数据库加载指定的模型 +func ModelLoad(id int) (model Model, err error) { + err = configs.ORMDB().First(&model, id).Error + return +} + // 推理模型 func (model *Model) Inference(image_list []Image, callback func(Image)) { var server Server diff --git a/routers/datasets.go b/routers/datasets.go index 5b88fac..6b203f8 100644 --- a/routers/datasets.go +++ b/routers/datasets.go @@ -2,11 +2,14 @@ package routers import ( "encoding/json" + "fmt" + "io" "io/ioutil" "main/configs" "main/models" "main/utils" "net/http" + "os" "github.com/gorilla/mux" ) @@ -60,6 +63,81 @@ func DatasetsPost(w http.ResponseWriter, r *http.Request) { }) } +// 上传图片文件 +func DatasetsUpload(w http.ResponseWriter, r *http.Request) { + models.AccountRead(w, r, func(account *models.Account) { + // 獲取數據集 + dataset := models.Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)} + if err := configs.ORMDB().Find(&dataset).Error; err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 - Not Found")) + return + } + // 只能修改自己的數據集, 除非是管理員 + if dataset.UserID != account.ID && !account.Admin { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("403 - Forbidden")) + return + } + // 解析 HTTP 请求中的多个文件 (限制上传文件的大小为 32MB) + if err := r.ParseMultipartForm(32 << 20); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // 遍历所有上传的文件 + for _, fileHeaders := range r.MultipartForm.File { + for _, fileHeader := range fileHeaders { + // 打开上传的文件 + file, err := fileHeader.Open() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer file.Close() + + // 在本地文件夹中创建一个新文件 + localFile, err := os.Create(fmt.Sprintf("data/dataset/%d/%s", dataset.ID, fileHeader.Filename)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer localFile.Close() + + // 将上传文件的内容复制到本地文件 + _, err = io.Copy(localFile, file) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 将文件名添加到數據集中 + dataset.Images = append(dataset.Images, fileHeader.Filename) + } + } + + // 去除重复项 + uniqueImages := make(map[string]bool) + for _, image := range dataset.Images { + uniqueImages[image] = true + } + + // 转换为切片 + dataset.Images = []string{} + for image := range uniqueImages { + dataset.Images = append(dataset.Images, image) + } + + // 保存數據集 + if err := configs.ORMDB().Save(&dataset).Error; err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Internal Server Error")) + return + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(utils.ToJSON(dataset)) + }) +} + // 獲取數據集 func DatasetsItemGet(w http.ResponseWriter, r *http.Request) { dataset := models.Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)}