下載圖像
This commit is contained in:
		
							
								
								
									
										21
									
								
								models/Dataset.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								models/Dataset.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"main/configs"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 數據集模型
 | 
				
			||||||
 | 
					type Dataset struct {
 | 
				
			||||||
 | 
						ID        int       `json:"id" gorm:"primary_key"`
 | 
				
			||||||
 | 
						Name      string    `json:"name"`
 | 
				
			||||||
 | 
						Info      string    `json:"info"`
 | 
				
			||||||
 | 
						Images    ImageList `json:"images"`
 | 
				
			||||||
 | 
						UserID    int       `json:"user_id"`
 | 
				
			||||||
 | 
						CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
				
			||||||
 | 
						UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						configs.ORMDB().AutoMigrate(&Dataset{})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,6 +1,8 @@
 | 
				
			|||||||
package models
 | 
					package models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql/driver"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
	"main/configs"
 | 
						"main/configs"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -25,3 +27,13 @@ type Image struct {
 | 
				
			|||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	configs.ORMDB().AutoMigrate(&Image{})
 | 
						configs.ORMDB().AutoMigrate(&Image{})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ImageList []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (list *ImageList) Scan(value interface{}) error {
 | 
				
			||||||
 | 
						return json.Unmarshal(value.([]byte), list)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (list ImageList) Value() (driver.Value, error) {
 | 
				
			||||||
 | 
						return json.Marshal(list)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										137
									
								
								models/Model.go
									
									
									
									
									
								
							
							
						
						
									
										137
									
								
								models/Model.go
									
									
									
									
									
								
							@@ -1,8 +1,15 @@
 | 
				
			|||||||
package models
 | 
					package models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/md5"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
	"main/configs"
 | 
						"main/configs"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/exec"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -21,6 +28,7 @@ type Model struct {
 | 
				
			|||||||
	Tags         TagList   `json:"tags"`
 | 
						Tags         TagList   `json:"tags"`
 | 
				
			||||||
	UserID       int       `json:"user_id"`
 | 
						UserID       int       `json:"user_id"`
 | 
				
			||||||
	DatasetID    int       `json:"dataset_id"`
 | 
						DatasetID    int       `json:"dataset_id"`
 | 
				
			||||||
 | 
						ServerID     int       `json:"server_id"`
 | 
				
			||||||
	CreatedAt    time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
						CreatedAt    time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
				
			||||||
	UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
						UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -30,21 +38,130 @@ func init() {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (model *Model) Train() (err error) {
 | 
					func (model *Model) Train() (err error) {
 | 
				
			||||||
	if model.Type == "lora" {
 | 
					
 | 
				
			||||||
		fmt.Println("lora")
 | 
						// 獲取一臺空閒的訓練機
 | 
				
			||||||
 | 
						var server Server
 | 
				
			||||||
 | 
						if err = configs.ORMDB().Where("status = ?", "閒置").First(&server).Error; err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if model.Type == "ckp" {
 | 
					
 | 
				
			||||||
		fmt.Println("ckp")
 | 
						// 獲取數據集
 | 
				
			||||||
 | 
						var dataset Dataset = Dataset{ID: model.DatasetID}
 | 
				
			||||||
 | 
						if err = configs.ORMDB().First(&dataset).Error; err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if model.Type == "hyper" {
 | 
					
 | 
				
			||||||
		fmt.Println("hyper")
 | 
						// 更新模型狀態
 | 
				
			||||||
 | 
						model.ServerID = server.ID
 | 
				
			||||||
 | 
						model.Status = "training"
 | 
				
			||||||
 | 
						if err = configs.ORMDB().Save(&model).Error; err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if model.Type == "ti" {
 | 
					
 | 
				
			||||||
		fmt.Println("ti")
 | 
						// 創建數據集目錄
 | 
				
			||||||
		return
 | 
						dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images")
 | 
				
			||||||
 | 
						if err := os.MkdirAll(dirPath, 0755); err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
					
 | 
				
			||||||
 | 
						// 將數據下載到本地
 | 
				
			||||||
 | 
						for index, url := range dataset.Images {
 | 
				
			||||||
 | 
							fmt.Println("下載數據到本地:", index, url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 下載到臨時目錄
 | 
				
			||||||
 | 
							resp, err := http.Get(url)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								fmt.Println("下載失敗:", err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							data, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								fmt.Println("保存失敗:", err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 保存文件到本地目錄下(自動創建目錄,文件名為url的md5值)
 | 
				
			||||||
 | 
							filename := fmt.Sprintf("%x", md5.Sum([]byte(url)))
 | 
				
			||||||
 | 
							filePath := filepath.Join(dirPath, filename)
 | 
				
			||||||
 | 
							if err := os.MkdirAll(dirPath, 0755); err != nil {
 | 
				
			||||||
 | 
								fmt.Println(err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
 | 
				
			||||||
 | 
								fmt.Println(err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Println("數據下載完成")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 檢查目錄下是否有文件, 如果沒有文件則返回錯誤
 | 
				
			||||||
 | 
						files, err := ioutil.ReadDir(dirPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(files) == 0 {
 | 
				
			||||||
 | 
							fmt.Println("目錄下沒有文件")
 | 
				
			||||||
 | 
							return fmt.Errorf("目錄下沒有文件")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 將文件全部上傳到訓練機(使用scp命令)
 | 
				
			||||||
 | 
						err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 刪除本地臨時目錄
 | 
				
			||||||
 | 
						if err := os.RemoveAll(dirPath); err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 将基础模型上传到训练机(使用scp命令)
 | 
				
			||||||
 | 
						baseModelPath := filepath.Join("data/models", model.BaseModel)
 | 
				
			||||||
 | 
						fmt.Println("baseModelPath:", baseModelPath)
 | 
				
			||||||
 | 
						err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 進行訓練(訓練機上調用訓練webapi接口:參數)
 | 
				
			||||||
 | 
						resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Println(err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 循環監聽訓練進度
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							// 訓練機上調用訓練webapi接口:獲取訓練進度
 | 
				
			||||||
 | 
							resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								fmt.Println(err)
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 更新本地訓練進度
 | 
				
			||||||
 | 
							var progress int
 | 
				
			||||||
 | 
							if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
 | 
				
			||||||
 | 
								fmt.Println(err)
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO: 訓練完成後將模型下載到本地
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,7 +11,7 @@ type Server struct {
 | 
				
			|||||||
	Type      string                   `json:"type"` // (訓練|推理)
 | 
						Type      string                   `json:"type"` // (訓練|推理)
 | 
				
			||||||
	IP        string                   `json:"ip"`
 | 
						IP        string                   `json:"ip"`
 | 
				
			||||||
	Port      int                      `json:"port"`
 | 
						Port      int                      `json:"port"`
 | 
				
			||||||
	Status    string                   `json:"status"` // (異常|初始化|就緒|工作中|關閉中)
 | 
						Status    string                   `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
				
			||||||
	Username  string                   `json:"username"`
 | 
						Username  string                   `json:"username"`
 | 
				
			||||||
	Password  string                   `json:"password"`
 | 
						Password  string                   `json:"password"`
 | 
				
			||||||
	Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
						Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
				
			||||||
@@ -21,4 +21,13 @@ type Server struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	configs.ORMDB().AutoMigrate(&Server{})
 | 
						configs.ORMDB().AutoMigrate(&Server{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 添加一個預設的訓練機
 | 
				
			||||||
 | 
						configs.ORMDB().Create(&Server{
 | 
				
			||||||
 | 
							Name:   "GPU T4",
 | 
				
			||||||
 | 
							Type:   "train",
 | 
				
			||||||
 | 
							IP:     "106.15.192.42",
 | 
				
			||||||
 | 
							Port:   7860,
 | 
				
			||||||
 | 
							Status: "閒置",
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,57 +1,31 @@
 | 
				
			|||||||
package routers
 | 
					package routers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"database/sql/driver"
 | 
					 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"main/configs"
 | 
						"main/configs"
 | 
				
			||||||
	"main/models"
 | 
						"main/models"
 | 
				
			||||||
	"main/utils"
 | 
						"main/utils"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gorilla/mux"
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ImageList []string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (list *ImageList) Scan(value interface{}) error {
 | 
					 | 
				
			||||||
	return json.Unmarshal(value.([]byte), list)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
func (list ImageList) Value() (driver.Value, error) {
 | 
					 | 
				
			||||||
	return json.Marshal(list)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 數據集模型
 | 
					 | 
				
			||||||
type Dataset struct {
 | 
					 | 
				
			||||||
	ID        int       `json:"id" gorm:"primary_key"`
 | 
					 | 
				
			||||||
	Name      string    `json:"name"`
 | 
					 | 
				
			||||||
	Info      string    `json:"info"`
 | 
					 | 
				
			||||||
	Images    ImageList `json:"images"`
 | 
					 | 
				
			||||||
	UserID    int       `json:"user_id"`
 | 
					 | 
				
			||||||
	CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
					 | 
				
			||||||
	UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	configs.ORMDB().AutoMigrate(&Dataset{})
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 獲取數據集列表
 | 
					// 獲取數據集列表
 | 
				
			||||||
func DatasetsGet(w http.ResponseWriter, r *http.Request) {
 | 
					func DatasetsGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	var listview models.ListView
 | 
						var listview models.ListView
 | 
				
			||||||
	listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
 | 
						listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
 | 
				
			||||||
	listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
 | 
						listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
 | 
				
			||||||
	var dataset_list []Dataset
 | 
						var dataset_list []models.Dataset
 | 
				
			||||||
	db := configs.ORMDB()
 | 
						db := configs.ORMDB()
 | 
				
			||||||
	db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&dataset_list)
 | 
						db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&dataset_list)
 | 
				
			||||||
	for _, dataset := range dataset_list {
 | 
						for _, dataset := range dataset_list {
 | 
				
			||||||
		if dataset.Images == nil {
 | 
							if dataset.Images == nil {
 | 
				
			||||||
			dataset.Images = ImageList{}
 | 
								dataset.Images = models.ImageList{}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		listview.List = append(listview.List, dataset)
 | 
							listview.List = append(listview.List, dataset)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	db.Model(&Dataset{}).Count(&listview.Total)
 | 
						db.Model(&models.Dataset{}).Count(&listview.Total)
 | 
				
			||||||
	listview.Next = listview.Page*listview.PageSize < int(listview.Total)
 | 
						listview.Next = listview.Page*listview.PageSize < int(listview.Total)
 | 
				
			||||||
	listview.WriteJSON(w)
 | 
						listview.WriteJSON(w)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -59,7 +33,7 @@ func DatasetsGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
// 創建數據集
 | 
					// 創建數據集
 | 
				
			||||||
func DatasetsPost(w http.ResponseWriter, r *http.Request) {
 | 
					func DatasetsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	models.AccountRead(w, r, func(account *models.Account) {
 | 
						models.AccountRead(w, r, func(account *models.Account) {
 | 
				
			||||||
		var dataset Dataset
 | 
							var dataset models.Dataset
 | 
				
			||||||
		body, err := ioutil.ReadAll(r.Body)
 | 
							body, err := ioutil.ReadAll(r.Body)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			w.WriteHeader(http.StatusInternalServerError)
 | 
								w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
@@ -73,7 +47,7 @@ func DatasetsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if dataset.Images == nil {
 | 
							if dataset.Images == nil {
 | 
				
			||||||
			dataset.Images = ImageList{}
 | 
								dataset.Images = models.ImageList{}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		dataset.UserID = account.ID
 | 
							dataset.UserID = account.ID
 | 
				
			||||||
		if err := configs.ORMDB().Create(&dataset).Error; err != nil {
 | 
							if err := configs.ORMDB().Create(&dataset).Error; err != nil {
 | 
				
			||||||
@@ -88,7 +62,7 @@ func DatasetsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// 獲取數據集
 | 
					// 獲取數據集
 | 
				
			||||||
func DatasetsItemGet(w http.ResponseWriter, r *http.Request) {
 | 
					func DatasetsItemGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	dataset := Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)}
 | 
						dataset := models.Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)}
 | 
				
			||||||
	if err := configs.ORMDB().Find(&dataset).Error; err != nil {
 | 
						if err := configs.ORMDB().Find(&dataset).Error; err != nil {
 | 
				
			||||||
		w.WriteHeader(http.StatusNotFound)
 | 
							w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
		w.Write([]byte("404 - Not Found"))
 | 
							w.Write([]byte("404 - Not Found"))
 | 
				
			||||||
@@ -101,7 +75,7 @@ func DatasetsItemGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
// 修改數據集
 | 
					// 修改數據集
 | 
				
			||||||
func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
					func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	models.AccountRead(w, r, func(account *models.Account) {
 | 
						models.AccountRead(w, r, func(account *models.Account) {
 | 
				
			||||||
		var dataset Dataset = Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)}
 | 
							var dataset models.Dataset = models.Dataset{ID: utils.ParamInt(mux.Vars(r)["dataset_id"], 0)}
 | 
				
			||||||
		if err := configs.ORMDB().Find(&dataset).Error; err != nil {
 | 
							if err := configs.ORMDB().Find(&dataset).Error; err != nil {
 | 
				
			||||||
			w.WriteHeader(http.StatusNotFound)
 | 
								w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
			w.Write([]byte("404 - Not Found"))
 | 
								w.Write([]byte("404 - Not Found"))
 | 
				
			||||||
@@ -120,7 +94,7 @@ func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			dataset.Info = info
 | 
								dataset.Info = info
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if images, ok := form["images"].([]interface{}); ok {
 | 
							if images, ok := form["images"].([]interface{}); ok {
 | 
				
			||||||
			var image_list ImageList
 | 
								var image_list models.ImageList
 | 
				
			||||||
			for _, image := range images {
 | 
								for _, image := range images {
 | 
				
			||||||
				if image, ok := image.(string); ok {
 | 
									if image, ok := image.(string); ok {
 | 
				
			||||||
					image_list = append(image_list, image)
 | 
										image_list = append(image_list, image)
 | 
				
			||||||
@@ -142,7 +116,7 @@ func DatasetsItemPatch(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
func DatasetsItemDelete(w http.ResponseWriter, r *http.Request) {
 | 
					func DatasetsItemDelete(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	models.AccountRead(w, r, func(account *models.Account) {
 | 
						models.AccountRead(w, r, func(account *models.Account) {
 | 
				
			||||||
		// 獲取數據集
 | 
							// 獲取數據集
 | 
				
			||||||
		dataset := Dataset{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
							dataset := models.Dataset{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
				
			||||||
		if err := configs.ORMDB().Find(&dataset).Error; err != nil {
 | 
							if err := configs.ORMDB().Find(&dataset).Error; err != nil {
 | 
				
			||||||
			w.WriteHeader(http.StatusNotFound)
 | 
								w.WriteHeader(http.StatusNotFound)
 | 
				
			||||||
			w.Write([]byte("404 - Not Found"))
 | 
								w.Write([]byte("404 - Not Found"))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -92,8 +92,10 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			log.Println(err)
 | 
								log.Println(err)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 直接提交訓練任務
 | 
							// 直接提交訓練任務
 | 
				
			||||||
		go model.Train()
 | 
							go model.Train()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 返回創建的模型
 | 
							// 返回創建的模型
 | 
				
			||||||
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
 | 
							w.Header().Set("Content-Type", "application/json; charset=utf-8")
 | 
				
			||||||
		w.Write(utils.ToJSON(model))
 | 
							w.Write(utils.ToJSON(model))
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										38
									
								
								test.sh
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								test.sh
									
									
									
									
									
								
							@@ -2,6 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# 記錄開始時間戳
 | 
					# 記錄開始時間戳
 | 
				
			||||||
start_time=$(date +%s)
 | 
					start_time=$(date +%s)
 | 
				
			||||||
 | 
					rm -f data/sqlite3.db
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 流程測試, 啓動服務, 設定進程名 go_test
 | 
					# 流程測試, 啓動服務, 設定進程名 go_test
 | 
				
			||||||
go run main.go -procname go_test &
 | 
					go run main.go -procname go_test &
 | 
				
			||||||
@@ -48,7 +49,7 @@ echo "dataset_id: $dataset_id"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 修改數據集, images 中增加 url (PATCH /api/datasets/:id)
 | 
					# 修改數據集, images 中增加 url (PATCH /api/datasets/:id)
 | 
				
			||||||
response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"]}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets/$dataset_id)
 | 
					response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["https://img.gameui.net/article-7258-1677745322000@1x456.webp","https://img.gameui.net/article-6477-1682109454000@1x456.webp"]}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/datasets/$dataset_id)
 | 
				
			||||||
[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
 | 
					[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -72,22 +73,27 @@ echo "model_id: $model_id"
 | 
				
			|||||||
#[[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}"
 | 
					#[[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 循環獲取模型訓練進度, 直到訓練完成
 | 
					## 循環獲取模型訓練進度, 直到訓練完成
 | 
				
			||||||
while true; do
 | 
					#while true; do
 | 
				
			||||||
    # 獲取模型訓練進度 (GET /api/models/:id)
 | 
					#    # 獲取模型訓練進度 (GET /api/models/:id)
 | 
				
			||||||
    response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id)
 | 
					#    response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models/$model_id)
 | 
				
			||||||
    [[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}"
 | 
					#    [[ ${response: -3} -eq 200 ]] && { echo "獲取模型訓練進度成功: ${response%???}"; } || exit_service "獲取模型訓練進度失敗: ${response%???}"
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    # 取出進度字段的值, 值爲 int
 | 
				
			||||||
 | 
					#    progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}')
 | 
				
			||||||
 | 
					#    echo "progress: $progress"
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    # 如果進度爲 100, 訓練完成, 跳出循環
 | 
				
			||||||
 | 
					#    [[ $progress -eq 100 ]] && { echo "訓練完成"; break; }
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    # 休眠 5 秒
 | 
				
			||||||
 | 
					#    sleep 5
 | 
				
			||||||
 | 
					#done
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 取出進度字段的值, 值爲 int
 | 
					# 服務器列表
 | 
				
			||||||
    progress=$(echo "${response%???}" | grep -o '"progress": [0-9]*' | awk '{print $2}')
 | 
					response=$(curl -X GET -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/servers)
 | 
				
			||||||
    echo "progress: $progress"
 | 
					[[ ${response: -3} -eq 200 ]] && { echo "獲取服務器列表成功: ${response%???}"; } || exit_service "獲取服務器列表失敗: ${response%???}"
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # 如果進度爲 100, 訓練完成, 跳出循環
 | 
					 | 
				
			||||||
    [[ $progress -eq 100 ]] && { echo "訓練完成"; break; }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # 休眠 5 秒
 | 
					 | 
				
			||||||
    sleep 5
 | 
					 | 
				
			||||||
done
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					sleep 10
 | 
				
			||||||
 | 
					
 | 
				
			||||||
exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒"
 | 
					exit_service "測試結束, 全部通過, 用費時間: $(($(date +%s) - $start_time)) 秒"
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user