初始化默认服务器

This commit is contained in:
2023-07-02 04:22:24 +08:00
parent 6c4cd0cb92
commit 4afb719d13
3 changed files with 165 additions and 98 deletions

View File

@@ -18,7 +18,7 @@ import (
cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
)
type ModelList []string
type ModelList []int
func (list *ModelList) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), list)
@@ -75,9 +75,62 @@ func init() {
// 检查默认服务器是否存在, 不存在则添加
func InitDefaultServer() (err error) {
if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil {
server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
err = configs.ORMDB().Create(&server).Error
var server Server
if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil {
server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
if err = configs.ORMDB().Create(&server).Error; err != nil {
return fmt.Errorf("创建默认服务器失败: %v", err)
}
}
// 初始化服务器中的模型列表
if err = server.InitModels(); err != nil {
return fmt.Errorf("初始化服务器中的模型列表失败: %v", err)
}
return
}
// 初始化服务器中的模型列表
func (server *Server) InitModels() (err error) {
// 获取服务器中的模型列表
resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
if err != nil {
return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
}
defer resp.Body.Close()
// 解码JSON (数组)
var data []map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return err
}
// 从数据库检查此模型hash是否存在
for _, item := range data {
var model Model
if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
// 不存在则添加
model = Model{
Name: item["model_name"].(string),
Hash: item["sha256"].(string),
ModelCheckpoint: item["title"].(string),
ModelPath: item["filename"].(string), // TODO: 下载到本地
ServerID: server.ID,
Type: "ckp",
}
// TODO: 下载到本地
// 添加到数据库
if err := configs.ORMDB().Create(&model).Error; err != nil {
return fmt.Errorf("添加模型到数据库失败: %v", err)
}
// 添加到模型列表
server.Models = append(server.Models, model.ID)
}
}
// 更新数据库
if err := configs.ORMDB().Save(&server).Error; err != nil {
return fmt.Errorf("更新数据库失败: %v", err)
}
return
}