模型訓練
This commit is contained in:
		
							
								
								
									
										103
									
								
								models/Model.go
									
									
									
									
									
								
							
							
						
						
									
										103
									
								
								models/Model.go
									
									
									
									
									
								
							@@ -2,13 +2,12 @@ package models
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/md5"
 | 
						"crypto/md5"
 | 
				
			||||||
	"encoding/json"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"main/configs"
 | 
						"main/configs"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"os/exec"
 | 
					 | 
				
			||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -115,53 +114,91 @@ func (model *Model) Train() (err error) {
 | 
				
			|||||||
		return fmt.Errorf("目錄下沒有文件")
 | 
							return fmt.Errorf("目錄下沒有文件")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 將文件全部上傳到訓練機(使用scp命令)
 | 
						// 按類型執行訓練任務
 | 
				
			||||||
	err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
						if model.Type == "dreambooth" {
 | 
				
			||||||
 | 
							// 創建數據庫模型
 | 
				
			||||||
 | 
							fmt.Println("創建數據庫模型 ======================================")
 | 
				
			||||||
 | 
							resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/dreambooth/createModel?new_model_name=%s&new_model_src=%s", server.IP, server.Port, model.Name, model.ModelPath), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
		fmt.Println(err)
 | 
								fmt.Println("創建訓練任務失敗:", err.Error())
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 刪除本地臨時目錄
 | 
							// 打印返回的結果
 | 
				
			||||||
	if err := os.RemoveAll(dirPath); err != nil {
 | 
							body, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
		fmt.Println(err)
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// 将基础模型上传到训练机(使用scp命令)
 | 
					 | 
				
			||||||
	baseModelPath := filepath.Join("data/models", model.BaseModel)
 | 
					 | 
				
			||||||
	fmt.Println("baseModelPath:", baseModelPath)
 | 
					 | 
				
			||||||
	err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
		fmt.Println(err)
 | 
								fmt.Println("解碼任務數據失敗:", err)
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							fmt.Println("預覽:", string(body))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 進行訓練(訓練機上調用訓練webapi接口:參數)
 | 
							// 上傳數據到訓練機
 | 
				
			||||||
	resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
 | 
					
 | 
				
			||||||
 | 
							// 執行訓練命令
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if model.Type == "lora" {
 | 
				
			||||||
 | 
							// 創建數據庫模型
 | 
				
			||||||
 | 
							formData := url.Values{}
 | 
				
			||||||
 | 
							formData.Set("name", model.Name)
 | 
				
			||||||
 | 
							formData.Set("type", model.Type)
 | 
				
			||||||
 | 
							resp, err := http.PostForm(fmt.Sprintf("http://%s:%d/lora/createModel", server.IP, server.Port), formData)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			fmt.Println(err)
 | 
								fmt.Println(err)
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		defer resp.Body.Close()
 | 
							defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 循環監聽訓練進度
 | 
							// 上傳數據到訓練機
 | 
				
			||||||
	for {
 | 
					 | 
				
			||||||
		// 訓練機上調用訓練webapi接口:獲取訓練進度
 | 
					 | 
				
			||||||
		resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			fmt.Println(err)
 | 
					 | 
				
			||||||
			return err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		defer resp.Body.Close()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 更新本地訓練進度
 | 
							// 執行訓練命令
 | 
				
			||||||
		var progress int
 | 
					 | 
				
			||||||
		if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
 | 
					 | 
				
			||||||
			fmt.Println(err)
 | 
					 | 
				
			||||||
			return err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						//// 將文件全部上傳到訓練機, 使用scp命令,自動使用密碼登錄
 | 
				
			||||||
 | 
						//err = exec.Command("scp", "-r", dirPath, fmt.Sprintf("%s@%s:~/dataset_%d", server.UserName, server.IP, model.ID)).Run()
 | 
				
			||||||
 | 
						//if err != nil {
 | 
				
			||||||
 | 
						//	fmt.Println(err)
 | 
				
			||||||
 | 
						//	return err
 | 
				
			||||||
 | 
						//}
 | 
				
			||||||
 | 
						//// 刪除本地臨時目錄
 | 
				
			||||||
 | 
						//if err := os.RemoveAll(dirPath); err != nil {
 | 
				
			||||||
 | 
						//	fmt.Println(err)
 | 
				
			||||||
 | 
						//	return err
 | 
				
			||||||
 | 
						//}
 | 
				
			||||||
 | 
						//// 将基础模型上传到训练机(使用scp命令)
 | 
				
			||||||
 | 
						//baseModelPath := filepath.Join("data/models", model.BaseModel)
 | 
				
			||||||
 | 
						//fmt.Println("baseModelPath:", baseModelPath)
 | 
				
			||||||
 | 
						//err = exec.Command("scp", baseModelPath, fmt.Sprintf("%s:%s", server.IP, filepath.Dir(model.ModelPath))).Run()
 | 
				
			||||||
 | 
						//if err != nil {
 | 
				
			||||||
 | 
						//	fmt.Println(err)
 | 
				
			||||||
 | 
						//	return err
 | 
				
			||||||
 | 
						//}
 | 
				
			||||||
 | 
						//// 進行訓練(訓練機上調用訓練webapi接口:參數)
 | 
				
			||||||
 | 
						//resp, err := http.Post(fmt.Sprintf("http://%s:5000/train", server.IP), "application/json", nil)
 | 
				
			||||||
 | 
						//if err != nil {
 | 
				
			||||||
 | 
						//	fmt.Println(err)
 | 
				
			||||||
 | 
						//	return err
 | 
				
			||||||
 | 
						//}
 | 
				
			||||||
 | 
						//defer resp.Body.Close()
 | 
				
			||||||
 | 
						//// 循環監聽訓練進度
 | 
				
			||||||
 | 
						//for i := 0; i < 5; i++ {
 | 
				
			||||||
 | 
						//	// 訓練機上調用訓練webapi接口:獲取訓練進度
 | 
				
			||||||
 | 
						//	resp, err := http.Get(fmt.Sprintf("http://%s:5000/progress", server.IP))
 | 
				
			||||||
 | 
						//	if err != nil {
 | 
				
			||||||
 | 
						//		fmt.Println(err)
 | 
				
			||||||
 | 
						//		return err
 | 
				
			||||||
 | 
						//	}
 | 
				
			||||||
 | 
						//	defer resp.Body.Close()
 | 
				
			||||||
 | 
						//// 更新本地訓練進度
 | 
				
			||||||
 | 
						//	var progress int
 | 
				
			||||||
 | 
						//	if err := json.NewDecoder(resp.Body).Decode(&progress); err != nil {
 | 
				
			||||||
 | 
						//		fmt.Println(err)
 | 
				
			||||||
 | 
						//		return err
 | 
				
			||||||
 | 
						//	}
 | 
				
			||||||
 | 
						//}
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
	// TODO: 訓練完成後將模型下載到本地
 | 
						// TODO: 訓練完成後將模型下載到本地
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ type Server struct {
 | 
				
			|||||||
	IP        string                   `json:"ip"`
 | 
						IP        string                   `json:"ip"`
 | 
				
			||||||
	Port      int                      `json:"port"`
 | 
						Port      int                      `json:"port"`
 | 
				
			||||||
	Status    string                   `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
						Status    string                   `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
 | 
				
			||||||
	Username  string                   `json:"username"`
 | 
						UserName  string                   `json:"username"`
 | 
				
			||||||
	Password  string                   `json:"password"`
 | 
						Password  string                   `json:"password"`
 | 
				
			||||||
	Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
						Models    []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
 | 
				
			||||||
	CreatedAt time.Time                `json:"created_at" gorm:"autoCreateTime"`
 | 
						CreatedAt time.Time                `json:"created_at" gorm:"autoCreateTime"`
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ import (
 | 
				
			|||||||
	"github.com/gorilla/mux"
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 獲取用戶列表
 | 
					// 用戶列表
 | 
				
			||||||
func UsersGet(w http.ResponseWriter, r *http.Request) {
 | 
					func UsersGet(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	var listview models.ListView
 | 
						var listview models.ListView
 | 
				
			||||||
	listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
 | 
						listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										24
									
								
								test.sh
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								test.sh
									
									
									
									
									
								
							@@ -53,19 +53,19 @@ response=$(curl -X PATCH -H "Content-Type: application/json" -d '{"images":["htt
 | 
				
			|||||||
[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
 | 
					[[ ${response: -3} -eq 200 ]] && { echo "修改數據集成功: ${response%???}"; } || exit_service "修改數據集失敗: ${response%???}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 訓練模型 (POST /api/models)
 | 
					## 訓練模型 (POST /api/models)
 | 
				
			||||||
response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"lora","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
 | 
					#response=$(curl -X POST -H "Content-Type: application/json" -d '{"name":"test","type":"dreambooth","trigger_words":"miao~","base_model":"sd1.5","epochs":20,"description":"test","dataset_id":'$dataset_id'}' -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
 | 
				
			||||||
[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}"
 | 
					#[[ ${response: -3} -eq 200 ]] && { echo "訓練模型任務已創建: ${response%???}"; } || exit_service "訓練模型任務創建失敗: ${response%???}"
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					## 取模型id的值, 值爲 int
 | 
				
			||||||
 | 
					#model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}')
 | 
				
			||||||
 | 
					#echo "model_id: $model_id"
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 模型列表 (GET /api/models)
 | 
				
			||||||
# 取模型id的值, 值爲 int
 | 
					 response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
 | 
				
			||||||
model_id=$(echo "${response%???}" | grep -o '"id": [0-9]*' | awk '{print $2}')
 | 
					 [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
 | 
				
			||||||
echo "model_id: $model_id"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## 模型列表 (GET /api/models)
 | 
					 | 
				
			||||||
# response=$(curl -X GET -H "Content-Type: application/json" -b "session_id=$session_id" -s -w "%{http_code}" http://localhost:8080/api/models)
 | 
					 | 
				
			||||||
# [[ ${response: -3} -eq 200 ]] && { echo "獲取模型列表成功: ${response%???}"; } || exit_service "獲取模型列表失敗: ${response%???}"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 獲取模型訓練進度 (GET /api/models/:id)
 | 
					## 獲取模型訓練進度 (GET /api/models/:id)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ go build -o data/gameui-ai-server main.go
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# 上传文件
 | 
					# 上传文件
 | 
				
			||||||
scp ./data/gameui-ai-server root@47.103.40.152:~/gameui-ai-server_new
 | 
					scp ./data/gameui-ai-server root@47.103.40.152:~/gameui-ai-server_new
 | 
				
			||||||
 | 
					rm -rf ./data/gameui-ai-server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 重啓服務
 | 
					# 重啓服務
 | 
				
			||||||
ssh root@47.103.40.152 '''
 | 
					ssh root@47.103.40.152 '''
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user