下載圖像

This commit is contained in:
2023-05-27 16:40:04 +08:00
parent 788f166909
commit 0966a3c83e
7 changed files with 203 additions and 62 deletions

21
models/Dataset.go Normal file
View File

@@ -0,0 +1,21 @@
package models
import (
"main/configs"
"time"
)
// 數據集模型
type Dataset struct {
ID int `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Info string `json:"info"`
Images ImageList `json:"images"`
UserID int `json:"user_id"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
func init() {
configs.ORMDB().AutoMigrate(&Dataset{})
}

View File

@@ -1,6 +1,8 @@
package models
import (
"database/sql/driver"
"encoding/json"
"main/configs"
"time"
)
@@ -25,3 +27,13 @@ type Image struct {
func init() {
configs.ORMDB().AutoMigrate(&Image{})
}
type ImageList []string
func (list *ImageList) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), list)
}
func (list ImageList) Value() (driver.Value, error) {
return json.Marshal(list)
}

View File

@@ -1,8 +1,15 @@
package models
import (
"crypto/md5"
"encoding/json"
"fmt"
"io/ioutil"
"main/configs"
"net/http"
"os"
"os/exec"
"path/filepath"
"time"
)
@@ -21,6 +28,7 @@ type Model struct {
Tags TagList `json:"tags"`
UserID int `json:"user_id"`
DatasetID int `json:"dataset_id"`
ServerID int `json:"server_id"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
@@ -30,21 +38,130 @@ func init() {
}
func (model *Model) Train() (err error) {
if model.Type == "lora" {
fmt.Println("lora")
// 獲取一臺空閒的訓練機
var server Server
if err = configs.ORMDB().Where("status = ?", "閒置").First(&server).Error; err != nil {
fmt.Println(err)
return
}
if model.Type == "ckp" {
fmt.Println("ckp")
// 獲取數據集
var dataset Dataset = Dataset{ID: model.DatasetID}
if err = configs.ORMDB().First(&dataset).Error; err != nil {
fmt.Println(err)
return
}
if model.Type == "hyper" {
fmt.Println("hyper")
// 更新模型狀態
model.ServerID = server.ID
model.Status = "training"
if err = configs.ORMDB().Save(&model).Error; err != nil {
fmt.Println(err)
return
}
if model.Type == "ti" {
fmt.Println("ti")
return
// 創建數據集目錄
dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images")
if err := os.MkdirAll(dirPath, 0755); err != nil {
fmt.Println(err)
return err
}
return
// 將數據下載到本地
for index, url := range dataset.Images {
fmt.Println("下載數據到本地:", index, url)
// 下載到臨時目錄
resp, err := http.Get(url)
if err != nil {
fmt.Println("下載失敗:", err)
continue
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("保存失敗:", err)
continue
}
// 保存文件到本地目錄下(自動創建目錄,文件名為url的md5值)
filename := fmt.Sprintf("%x", md5.Sum([]byte(url)))
filePath := filepath.Join(dirPath, filename)
if err := os.MkdirAll(dirPath, 0755); err != nil {
fmt.Println(err)
continue
}
if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
fmt.Println(err)
continue
}
}
fmt.Println("數據下載完成")
// 檢查目錄下是否有文件, 如果沒有文件則返回錯誤
files, err := ioutil.ReadDir(dirPath)
if err != nil {
fmt.Println(err)
return err
}
if len(files) == 0 {
fmt.Println("目錄下沒有文件")
return fmt.Errorf("目錄下沒有文件")
}
// 將文件全部上傳到訓練機(使用scp命令)
err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
if err != nil {
fmt.Println(err)
return err
}
// 刪除本地臨時目錄
if err := os.RemoveAll(dirPath); err != nil {
fmt.Println(err)
return err
}
// 将基础模型上传到训练机(使用scp命令)
baseModelPath := filepath.Join("data/models", model.BaseModel)
fmt.Println("baseModelPath:", baseModelPath)
err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
if err != nil {
fmt.Println(err)
return err
}
// 進行訓練(訓練機上調用訓練webapi接口:參數)
resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
if err != nil {
fmt.Println(err)
return err
}
defer resp.Body.Close()
// 循環監聽訓練進度
for {
// 訓練機上調用訓練webapi接口:獲取訓練進度
resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
if err != nil {
fmt.Println(err)
return err
}
defer resp.Body.Close()
// 更新本地訓練進度
var progress int
if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
fmt.Println(err)
return err
}
}
// TODO: 訓練完成後將模型下載到本地
}

View File

@@ -11,7 +11,7 @@ type Server struct {
Type string `json:"type"` // (訓練|推理)
IP string `json:"ip"`
Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|就緒|工作中|關閉中)
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
Username string `json:"username"`
Password string `json:"password"`
Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
@@ -21,4 +21,13 @@ type Server struct {
func init() {
configs.ORMDB().AutoMigrate(&Server{})
// 添加一個預設的訓練機
configs.ORMDB().Create(&Server{
Name: "GPU T4",
Type: "train",
IP: "106.15.192.42",
Port: 7860,
Status: "閒置",
})
}