模型訓練
This commit is contained in:
		
							
								
								
									
										117
									
								
								models/Model.go
									
									
									
									
									
								
							
							
						
						
									
										117
									
								
								models/Model.go
									
									
									
									
									
								
							@@ -2,13 +2,12 @@ package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"main/configs"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@@ -115,53 +114,91 @@ func (model *Model) Train() (err error) {
 | 
			
		||||
		return fmt.Errorf("目錄下沒有文件")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 將文件全部上傳到訓練機(使用scp命令)
 | 
			
		||||
	err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	// 按類型執行訓練任務
 | 
			
		||||
	if model.Type == "dreambooth" {
 | 
			
		||||
		// 創建數據庫模型
 | 
			
		||||
		fmt.Println("創建數據庫模型 ======================================")
 | 
			
		||||
		resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("創建訓練任務失敗:", err.Error())
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		// 打印返回的結果
 | 
			
		||||
		body, err := ioutil.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("解碼任務數據失敗:", err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		fmt.Println("預覽:", string(body))
 | 
			
		||||
 | 
			
		||||
		// 上傳數據到訓練機
 | 
			
		||||
 | 
			
		||||
		// 執行訓練命令
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 刪除本地臨時目錄
 | 
			
		||||
	if err := os.RemoveAll(dirPath); err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 将基础模型上传到训练机(使用scp命令)
 | 
			
		||||
	baseModelPath := filepath.Join("data/models", model.BaseModel)
 | 
			
		||||
	fmt.Println("baseModelPath:", baseModelPath)
 | 
			
		||||
	err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 進行訓練(訓練機上調用訓練webapi接口:參數)
 | 
			
		||||
	resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	// 循環監聽訓練進度
 | 
			
		||||
	for {
 | 
			
		||||
		// 訓練機上調用訓練webapi接口:獲取訓練進度
 | 
			
		||||
		resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
 | 
			
		||||
	if model.Type == "lora" {
 | 
			
		||||
		// 創建數據庫模型
 | 
			
		||||
		formData := url.Values{}
 | 
			
		||||
		formData.Set("name", model.Name)
 | 
			
		||||
		formData.Set("type", model.Type)
 | 
			
		||||
		resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		// 更新本地訓練進度
 | 
			
		||||
		var progress int
 | 
			
		||||
		if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
 | 
			
		||||
			fmt.Println(err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// 上傳數據到訓練機
 | 
			
		||||
 | 
			
		||||
		// 執行訓練命令
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
 | 
			
		||||
	//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	fmt.Println(err)
 | 
			
		||||
	//	return err
 | 
			
		||||
	//}
 | 
			
		||||
	//// 刪除本地臨時目錄
 | 
			
		||||
	//if err := os.RemoveAll(dirPath); err != nil {
 | 
			
		||||
	//	fmt.Println(err)
 | 
			
		||||
	//	return err
 | 
			
		||||
	//}
 | 
			
		||||
	//// 将基础模型上传到训练机(使用scp命令)
 | 
			
		||||
	//baseModelPath := filepath.Join("data/models", model.BaseModel)
 | 
			
		||||
	//fmt.Println("baseModelPath:", baseModelPath)
 | 
			
		||||
	//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	fmt.Println(err)
 | 
			
		||||
	//	return err
 | 
			
		||||
	//}
 | 
			
		||||
	//// 進行訓練(訓練機上調用訓練webapi接口:參數)
 | 
			
		||||
	//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	fmt.Println(err)
 | 
			
		||||
	//	return err
 | 
			
		||||
	//}
 | 
			
		||||
	//defer resp.Body.Close()
 | 
			
		||||
	//// 循環監聽訓練進度
 | 
			
		||||
	//for i := 0; i < 5; i++ {
 | 
			
		||||
	//	// 訓練機上調用訓練webapi接口:獲取訓練進度
 | 
			
		||||
	//	resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
 | 
			
		||||
	//	if err != nil {
 | 
			
		||||
	//		fmt.Println(err)
 | 
			
		||||
	//		return err
 | 
			
		||||
	//	}
 | 
			
		||||
	//	defer resp.Body.Close()
 | 
			
		||||
	//// 更新本地訓練進度
 | 
			
		||||
	//	var progress int
 | 
			
		||||
	//	if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
 | 
			
		||||
	//		fmt.Println(err)
 | 
			
		||||
	//		return err
 | 
			
		||||
	//	}
 | 
			
		||||
	//}
 | 
			
		||||
	//
 | 
			
		||||
	// TODO: 訓練完成後將模型下載到本地
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ type Server struct {
 | 
			
		||||
	IP        string                   `json:"ip"`
 | 
			
		||||
	Port      int                      `json:"port"`
 | 
			
		||||
	Status    string                   `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
			
		||||
	Username  string                   `json:"username"`
 | 
			
		||||
	UserName  string                   `json:"username"`
 | 
			
		||||
	Password  string                   `json:"password"`
 | 
			
		||||
	Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
			
		||||
	CreatedAt time.Time                `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user