初始化默认服务器
This commit is contained in:
		@@ -11,10 +11,10 @@ import (
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"image/png"
 | 
			
		||||
 | 
			
		||||
	"github.com/chai2010/webp"
 | 
			
		||||
@@ -22,26 +22,27 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Model struct {
 | 
			
		||||
	ID           int       `json:"id" gorm:"primary_key"`    // 模型ID
 | 
			
		||||
	Name         string    `json:"name"`                     // 模型名稱
 | 
			
		||||
	Info         string    `json:"info"`                     // 模型描述
 | 
			
		||||
	Type         string    `json:"type"`                     // 模型類型(lora|ckp|hyper|ti)
 | 
			
		||||
	TriggerWords string    `json:"trigger_words"`            // 觸發詞
 | 
			
		||||
	BaseModel    string    `json:"base_model"`               // 基礎模型(SD1.5|SD2)
 | 
			
		||||
	ModelPath    string    `json:"model_path"`               // 模型路徑(實際存放在服務器上的路徑)
 | 
			
		||||
	Status       string    `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error)
 | 
			
		||||
	Progress     int       `json:"progress"`                 // (0-100)
 | 
			
		||||
	Image        string    `json:"image"`                    // 封面圖片實際地址
 | 
			
		||||
	Hash         string    `json:"hash"`                     // 模型哈希值
 | 
			
		||||
	Epochs       int       `json:"epochs"`                   // 訓練步數
 | 
			
		||||
	LearningRate float32   `json:"learning_rate"`            // 學習率(0.000005)
 | 
			
		||||
	Tags         TagList   `json:"tags"`                     // 模型標籤(標籤名數組)
 | 
			
		||||
	UserID       int       `json:"user_id"`                  // 模型的所有者
 | 
			
		||||
	DatasetID    int       `json:"dataset_id"`               // 模型所使用的數據集ID
 | 
			
		||||
	ServerID     string    `json:"server_id"`                // 模型所在服務器(訓練機或推理機)
 | 
			
		||||
	Stars        StarList  `json:"stars"`                    // 模型的收藏者
 | 
			
		||||
	CreatedAt    time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
	UpdatedAt    time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
	ID              int       `json:"id" gorm:"primary_key"`    // 模型ID
 | 
			
		||||
	Name            string    `json:"name"`                     // 模型名稱
 | 
			
		||||
	ModelCheckpoint string    `json:"model_checkpoint"`         // 模型檢查點
 | 
			
		||||
	Info            string    `json:"info"`                     // 模型描述
 | 
			
		||||
	Type            string    `json:"type"`                     // 模型類型(lora|ckp|hyper|ti)
 | 
			
		||||
	TriggerWords    string    `json:"trigger_words"`            // 觸發詞
 | 
			
		||||
	BaseModel       string    `json:"base_model"`               // 基礎模型(SD1.5|SD2)
 | 
			
		||||
	ModelPath       string    `json:"model_path"`               // 模型路徑(實際存放在服務器上的路徑)
 | 
			
		||||
	Status          string    `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error)
 | 
			
		||||
	Progress        int       `json:"progress"`                 // (0-100)
 | 
			
		||||
	Image           string    `json:"image"`                    // 封面圖片實際地址
 | 
			
		||||
	Hash            string    `json:"hash"`                     // 模型哈希值(sha256)
 | 
			
		||||
	Epochs          int       `json:"epochs"`                   // 訓練步數
 | 
			
		||||
	LearningRate    float32   `json:"learning_rate"`            // 學習率(0.000005)
 | 
			
		||||
	Tags            TagList   `json:"tags"`                     // 模型標籤(標籤名數組)
 | 
			
		||||
	UserID          int       `json:"user_id"`                  // 模型的所有者
 | 
			
		||||
	DatasetID       int       `json:"dataset_id"`               // 模型所使用的數據集ID
 | 
			
		||||
	ServerID        string    `json:"server_id"`                // 模型所在服務器(訓練機或推理機)
 | 
			
		||||
	Stars           StarList  `json:"stars"`                    // 模型的收藏者
 | 
			
		||||
	CreatedAt       time.Time `json:"created_at" gorm:"autoCreateTime"`
 | 
			
		||||
	UpdatedAt       time.Time `json:"updated_at" gorm:"autoUpdateTime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
@@ -51,6 +52,10 @@ func init() {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// 清除所有hash长度小于32的模型
 | 
			
		||||
	configs.ORMDB().Where("length(hash) < 32").Delete(&Model{})
 | 
			
		||||
	// 清除所有type为空的模型
 | 
			
		||||
	configs.ORMDB().Where("type = ?", "").Delete(&Model{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从数据库加载
 | 
			
		||||
@@ -66,17 +71,16 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
 | 
			
		||||
	if model.ServerID == "" {
 | 
			
		||||
		log.Println("模型未部署到推理機, 开始部署模型")
 | 
			
		||||
 | 
			
		||||
		// 寻找一台就绪的推理机, 且已部署模型目标模型
 | 
			
		||||
		if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("models LIKE ?", "%"+strconv.Itoa(model.ID)+"%").First(&server).Error; err != nil {
 | 
			
		||||
			// 寻找一台就绪的推理机, 且模型位置仍有空余
 | 
			
		||||
			if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil {
 | 
			
		||||
				log.Println("创建一台新的推理机: 当前禁止创建新服务器")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			// 上传目标模型到推理机
 | 
			
		||||
			log.Println("上传模型到推理机: 当前禁止上传模型")
 | 
			
		||||
		// 寻找一台就绪的且模型位置仍有空余的推理机
 | 
			
		||||
		if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil {
 | 
			
		||||
			log.Println("创建一台新的推理机: 当前禁止创建新服务器")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 打印为格式化的json
 | 
			
		||||
		data, _ := json.MarshalIndent(server, "", "    ")
 | 
			
		||||
		fmt.Println(string(data))
 | 
			
		||||
 | 
			
		||||
		//var form = struct {
 | 
			
		||||
		//	Components []struct {
 | 
			
		||||
		//		ID    int    `json:"id"`
 | 
			
		||||
@@ -106,7 +110,7 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
 | 
			
		||||
		//}
 | 
			
		||||
 | 
			
		||||
		// 记录到推理机
 | 
			
		||||
		server.Models = append(server.Models, strconv.Itoa(model.ID))
 | 
			
		||||
		server.Models = append(server.Models, model.ID)
 | 
			
		||||
		configs.ORMDB().Save(&server)
 | 
			
		||||
 | 
			
		||||
		// 记录到模型
 | 
			
		||||
 
 | 
			
		||||
@@ -18,7 +18,7 @@ import (
 | 
			
		||||
	cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ModelList []string
 | 
			
		||||
type ModelList []int
 | 
			
		||||
 | 
			
		||||
func (list *ModelList) Scan(value interface{}) error {
 | 
			
		||||
	return json.Unmarshal(value.([]byte), list)
 | 
			
		||||
@@ -75,9 +75,62 @@ func init() {
 | 
			
		||||
 | 
			
		||||
// 检查默认服务器是否存在, 不存在则添加
 | 
			
		||||
func InitDefaultServer() (err error) {
 | 
			
		||||
	if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil {
 | 
			
		||||
		server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
 | 
			
		||||
		err = configs.ORMDB().Create(&server).Error
 | 
			
		||||
	var server Server
 | 
			
		||||
	if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil {
 | 
			
		||||
		server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
 | 
			
		||||
		if err = configs.ORMDB().Create(&server).Error; err != nil {
 | 
			
		||||
			return fmt.Errorf("创建默认服务器失败: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// 初始化服务器中的模型列表
 | 
			
		||||
	if err = server.InitModels(); err != nil {
 | 
			
		||||
		return fmt.Errorf("初始化服务器中的模型列表失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 初始化服务器中的模型列表
 | 
			
		||||
func (server *Server) InitModels() (err error) {
 | 
			
		||||
	// 获取服务器中的模型列表
 | 
			
		||||
	resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	// 解码JSON (数组)
 | 
			
		||||
	var data []map[string]interface{}
 | 
			
		||||
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 从数据库检查此模型hash是否存在
 | 
			
		||||
	for _, item := range data {
 | 
			
		||||
		var model Model
 | 
			
		||||
		if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
 | 
			
		||||
			// 不存在则添加
 | 
			
		||||
			model = Model{
 | 
			
		||||
				Name:            item["model_name"].(string),
 | 
			
		||||
				Hash:            item["sha256"].(string),
 | 
			
		||||
				ModelCheckpoint: item["title"].(string),
 | 
			
		||||
				ModelPath:       item["filename"].(string), // TODO: 下载到本地
 | 
			
		||||
				ServerID:        server.ID,
 | 
			
		||||
				Type:            "ckp",
 | 
			
		||||
			}
 | 
			
		||||
			// TODO: 下载到本地
 | 
			
		||||
			// 添加到数据库
 | 
			
		||||
			if err := configs.ORMDB().Create(&model).Error; err != nil {
 | 
			
		||||
				return fmt.Errorf("添加模型到数据库失败: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			// 添加到模型列表
 | 
			
		||||
			server.Models = append(server.Models, model.ID)
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新数据库
 | 
			
		||||
	if err := configs.ORMDB().Save(&server).Error; err != nil {
 | 
			
		||||
		return fmt.Errorf("更新数据库失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user