初始化默认服务器
This commit is contained in:
@@ -11,10 +11,10 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"image/png"
|
||||
|
||||
"github.com/chai2010/webp"
|
||||
@@ -22,26 +22,27 @@ import (
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ID int `json:"id" gorm:"primary_key"` // 模型ID
|
||||
Name string `json:"name"` // 模型名稱
|
||||
Info string `json:"info"` // 模型描述
|
||||
Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti)
|
||||
TriggerWords string `json:"trigger_words"` // 觸發詞
|
||||
BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2)
|
||||
ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑)
|
||||
Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error)
|
||||
Progress int `json:"progress"` // (0-100)
|
||||
Image string `json:"image"` // 封面圖片實際地址
|
||||
Hash string `json:"hash"` // 模型哈希值
|
||||
Epochs int `json:"epochs"` // 訓練步數
|
||||
LearningRate float32 `json:"learning_rate"` // 學習率(0.000005)
|
||||
Tags TagList `json:"tags"` // 模型標籤(標籤名數組)
|
||||
UserID int `json:"user_id"` // 模型的所有者
|
||||
DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID
|
||||
ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機)
|
||||
Stars StarList `json:"stars"` // 模型的收藏者
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
ID int `json:"id" gorm:"primary_key"` // 模型ID
|
||||
Name string `json:"name"` // 模型名稱
|
||||
ModelCheckpoint string `json:"model_checkpoint"` // 模型檢查點
|
||||
Info string `json:"info"` // 模型描述
|
||||
Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti)
|
||||
TriggerWords string `json:"trigger_words"` // 觸發詞
|
||||
BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2)
|
||||
ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑)
|
||||
Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error)
|
||||
Progress int `json:"progress"` // (0-100)
|
||||
Image string `json:"image"` // 封面圖片實際地址
|
||||
Hash string `json:"hash"` // 模型哈希值(sha256)
|
||||
Epochs int `json:"epochs"` // 訓練步數
|
||||
LearningRate float32 `json:"learning_rate"` // 學習率(0.000005)
|
||||
Tags TagList `json:"tags"` // 模型標籤(標籤名數組)
|
||||
UserID int `json:"user_id"` // 模型的所有者
|
||||
DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID
|
||||
ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機)
|
||||
Stars StarList `json:"stars"` // 模型的收藏者
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -51,6 +52,10 @@ func init() {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
// 清除所有hash长度小于32的模型
|
||||
configs.ORMDB().Where("length(hash) < 32").Delete(&Model{})
|
||||
// 清除所有type为空的模型
|
||||
configs.ORMDB().Where("type = ?", "").Delete(&Model{})
|
||||
}
|
||||
|
||||
// 从数据库加载
|
||||
@@ -66,17 +71,16 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
||||
if model.ServerID == "" {
|
||||
log.Println("模型未部署到推理機, 开始部署模型")
|
||||
|
||||
// 寻找一台就绪的推理机, 且已部署模型目标模型
|
||||
if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("models LIKE ?", "%"+strconv.Itoa(model.ID)+"%").First(&server).Error; err != nil {
|
||||
// 寻找一台就绪的推理机, 且模型位置仍有空余
|
||||
if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil {
|
||||
log.Println("创建一台新的推理机: 当前禁止创建新服务器")
|
||||
return
|
||||
}
|
||||
// 上传目标模型到推理机
|
||||
log.Println("上传模型到推理机: 当前禁止上传模型")
|
||||
// 寻找一台就绪的且模型位置仍有空余的推理机
|
||||
if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("length(models) < ?", 5).First(&server).Error; err != nil {
|
||||
log.Println("创建一台新的推理机: 当前禁止创建新服务器")
|
||||
return
|
||||
}
|
||||
|
||||
// 打印为格式化的json
|
||||
data, _ := json.MarshalIndent(server, "", " ")
|
||||
fmt.Println(string(data))
|
||||
|
||||
//var form = struct {
|
||||
// Components []struct {
|
||||
// ID int `json:"id"`
|
||||
@@ -106,7 +110,7 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
||||
//}
|
||||
|
||||
// 记录到推理机
|
||||
server.Models = append(server.Models, strconv.Itoa(model.ID))
|
||||
server.Models = append(server.Models, model.ID)
|
||||
configs.ORMDB().Save(&server)
|
||||
|
||||
// 记录到模型
|
||||
|
@@ -18,7 +18,7 @@ import (
|
||||
cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
|
||||
)
|
||||
|
||||
type ModelList []string
|
||||
type ModelList []int
|
||||
|
||||
func (list *ModelList) Scan(value interface{}) error {
|
||||
return json.Unmarshal(value.([]byte), list)
|
||||
@@ -75,9 +75,62 @@ func init() {
|
||||
|
||||
// 检查默认服务器是否存在, 不存在则添加
|
||||
func InitDefaultServer() (err error) {
|
||||
if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil {
|
||||
server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
|
||||
err = configs.ORMDB().Create(&server).Error
|
||||
var server Server
|
||||
if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil {
|
||||
server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
|
||||
if err = configs.ORMDB().Create(&server).Error; err != nil {
|
||||
return fmt.Errorf("创建默认服务器失败: %v", err)
|
||||
}
|
||||
}
|
||||
// 初始化服务器中的模型列表
|
||||
if err = server.InitModels(); err != nil {
|
||||
return fmt.Errorf("初始化服务器中的模型列表失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化服务器中的模型列表
|
||||
func (server *Server) InitModels() (err error) {
|
||||
// 获取服务器中的模型列表
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 解码JSON (数组)
|
||||
var data []map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 从数据库检查此模型hash是否存在
|
||||
for _, item := range data {
|
||||
var model Model
|
||||
if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
|
||||
// 不存在则添加
|
||||
model = Model{
|
||||
Name: item["model_name"].(string),
|
||||
Hash: item["sha256"].(string),
|
||||
ModelCheckpoint: item["title"].(string),
|
||||
ModelPath: item["filename"].(string), // TODO: 下载到本地
|
||||
ServerID: server.ID,
|
||||
Type: "ckp",
|
||||
}
|
||||
// TODO: 下载到本地
|
||||
// 添加到数据库
|
||||
if err := configs.ORMDB().Create(&model).Error; err != nil {
|
||||
return fmt.Errorf("添加模型到数据库失败: %v", err)
|
||||
}
|
||||
// 添加到模型列表
|
||||
server.Models = append(server.Models, model.ID)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
if err := configs.ORMDB().Save(&server).Error; err != nil {
|
||||
return fmt.Errorf("更新数据库失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user