下載圖像
This commit is contained in:
		
							
								
								
									
										21
									
								
								models/Dataset.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								models/Dataset.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"main/configs"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 數據集模型
 | 
			
		||||
type Dataset struct {
 | 
			
		||||
	ID        int       `json:"id" gorm:"primary_key"`
 | 
			
		||||
	Name      string    `json:"name"`
 | 
			
		||||
	Info      string    `json:"info"`
 | 
			
		||||
	Images    ImageList `json:"images"`
 | 
			
		||||
	UserID    int       `json:"user_id"`
 | 
			
		||||
	CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
	UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	configs.ORMDB().AutoMigrate(&Dataset{})
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,8 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"main/configs"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@@ -25,3 +27,13 @@ type Image struct {
 | 
			
		||||
func init() {
 | 
			
		||||
	configs.ORMDB().AutoMigrate(&Image{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageList []string
 | 
			
		||||
 | 
			
		||||
func (list *ImageList) Scan(value interface{}) error {
 | 
			
		||||
	return json.Unmarshal(value.([]byte), list)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (list ImageList) Value() (driver.Value, error) {
 | 
			
		||||
	return json.Marshal(list)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										137
									
								
								models/Model.go
									
									
									
									
									
								
							
							
						
						
									
										137
									
								
								models/Model.go
									
									
									
									
									
								
							@@ -1,8 +1,15 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"main/configs"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -21,6 +28,7 @@ type Model struct {
 | 
			
		||||
	Tags         TagList   `json:"tags"`
 | 
			
		||||
	UserID       int       `json:"user_id"`
 | 
			
		||||
	DatasetID    int       `json:"dataset_id"`
 | 
			
		||||
	ServerID     int       `json:"server_id"`
 | 
			
		||||
	CreatedAt    time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
	UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
}
 | 
			
		||||
@@ -30,21 +38,130 @@ func init() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (model *Model) Train() (err error) {
 | 
			
		||||
	if model.Type == "lora" {
 | 
			
		||||
		fmt.Println("lora")
 | 
			
		||||
 | 
			
		||||
	// 獲取一臺空閒的訓練機
 | 
			
		||||
	var server Server
 | 
			
		||||
	if err = configs.ORMDB().Where("status = ?", "閒置").First(&server).Error; err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if model.Type == "ckp" {
 | 
			
		||||
		fmt.Println("ckp")
 | 
			
		||||
 | 
			
		||||
	// 獲取數據集
 | 
			
		||||
	var dataset Dataset = Dataset{ID: model.DatasetID}
 | 
			
		||||
	if err = configs.ORMDB().First(&dataset).Error; err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if model.Type == "hyper" {
 | 
			
		||||
		fmt.Println("hyper")
 | 
			
		||||
 | 
			
		||||
	// 更新模型狀態
 | 
			
		||||
	model.ServerID = server.ID
 | 
			
		||||
	model.Status = "training"
 | 
			
		||||
	if err = configs.ORMDB().Save(&model).Error; err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if model.Type == "ti" {
 | 
			
		||||
		fmt.Println("ti")
 | 
			
		||||
		return
 | 
			
		||||
 | 
			
		||||
	// 創建數據集目錄
 | 
			
		||||
	dirPath := filepath.Join("data/datasets", fmt.Sprint(dataset.ID), "images")
 | 
			
		||||
	if err := os.MkdirAll(dirPath, 0755); err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
 | 
			
		||||
	// 將數據下載到本地
 | 
			
		||||
	for index, url := range dataset.Images {
 | 
			
		||||
		fmt.Println("下載數據到本地:", index, url)
 | 
			
		||||
 | 
			
		||||
		// 下載到臨時目錄
 | 
			
		||||
		resp, err := http.Get(url)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("下載失敗:", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		data, err := ioutil.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("保存失敗:", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 保存文件到本地目錄下(自動創建目錄,文件名為url的md5值)
 | 
			
		||||
		filename := fmt.Sprintf("%x", md5.Sum([]byte(url)))
 | 
			
		||||
		filePath := filepath.Join(dirPath, filename)
 | 
			
		||||
		if err := os.MkdirAll(dirPath, 0755); err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Println("數據下載完成")
 | 
			
		||||
 | 
			
		||||
	// 檢查目錄下是否有文件, 如果沒有文件則返回錯誤
 | 
			
		||||
	files, err := ioutil.ReadDir(dirPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(files) == 0 {
 | 
			
		||||
		fmt.Println("目錄下沒有文件")
 | 
			
		||||
		return fmt.Errorf("目錄下沒有文件")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 將文件全部上傳到訓練機(使用scp命令)
 | 
			
		||||
	err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 刪除本地臨時目錄
 | 
			
		||||
	if err := os.RemoveAll(dirPath); err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 将基础模型上传到训练机(使用scp命令)
 | 
			
		||||
	baseModelPath := filepath.Join("data/models", model.BaseModel)
 | 
			
		||||
	fmt.Println("baseModelPath:", baseModelPath)
 | 
			
		||||
	err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 進行訓練(訓練機上調用訓練webapi接口:參數)
 | 
			
		||||
	resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	// 循環監聽訓練進度
 | 
			
		||||
	for {
 | 
			
		||||
		// 訓練機上調用訓練webapi接口:獲取訓練進度
 | 
			
		||||
		resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		// 更新本地訓練進度
 | 
			
		||||
		var progress int
 | 
			
		||||
		if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: 訓練完成後將模型下載到本地
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ type Server struct {
 | 
			
		||||
	Type      string                   `json:"type"` // (訓練|推理)
 | 
			
		||||
	IP        string                   `json:"ip"`
 | 
			
		||||
	Port      int                      `json:"port"`
 | 
			
		||||
	Status    string                   `json:"status"` // (異常|初始化|就緒|工作中|關閉中)
 | 
			
		||||
	Status    string                   `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
			
		||||
	Username  string                   `json:"username"`
 | 
			
		||||
	Password  string                   `json:"password"`
 | 
			
		||||
	Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
			
		||||
@@ -21,4 +21,13 @@ type Server struct {
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	configs.ORMDB().AutoMigrate(&Server{})
 | 
			
		||||
 | 
			
		||||
	// 添加一個預設的訓練機
 | 
			
		||||
	configs.ORMDB().Create(&Server{
 | 
			
		||||
		Name:   "GPU T4",
 | 
			
		||||
		Type:   "train",
 | 
			
		||||
		IP:     "106.15.192.42",
 | 
			
		||||
		Port:   7860,
 | 
			
		||||
		Status: "閒置",
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user