初始化默认服务器
This commit is contained in:
@@ -18,7 +18,7 @@ import (
|
||||
cvm "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm/v20170312"
|
||||
)
|
||||
|
||||
type ModelList []string
|
||||
type ModelList []int
|
||||
|
||||
func (list *ModelList) Scan(value interface{}) error {
|
||||
return json.Unmarshal(value.([]byte), list)
|
||||
@@ -75,9 +75,62 @@ func init() {
|
||||
|
||||
// 检查默认服务器是否存在, 不存在则添加
|
||||
func InitDefaultServer() (err error) {
|
||||
if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil {
|
||||
server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
|
||||
err = configs.ORMDB().Create(&server).Error
|
||||
var server Server
|
||||
if err = configs.ORMDB().Where("id = ?", "default").First(&server).Error; err != nil {
|
||||
server = Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
|
||||
if err = configs.ORMDB().Create(&server).Error; err != nil {
|
||||
return fmt.Errorf("创建默认服务器失败: %v", err)
|
||||
}
|
||||
}
|
||||
// 初始化服务器中的模型列表
|
||||
if err = server.InitModels(); err != nil {
|
||||
return fmt.Errorf("初始化服务器中的模型列表失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化服务器中的模型列表
|
||||
func (server *Server) InitModels() (err error) {
|
||||
// 获取服务器中的模型列表
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 解码JSON (数组)
|
||||
var data []map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 从数据库检查此模型hash是否存在
|
||||
for _, item := range data {
|
||||
var model Model
|
||||
if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
|
||||
// 不存在则添加
|
||||
model = Model{
|
||||
Name: item["model_name"].(string),
|
||||
Hash: item["sha256"].(string),
|
||||
ModelCheckpoint: item["title"].(string),
|
||||
ModelPath: item["filename"].(string), // TODO: 下载到本地
|
||||
ServerID: server.ID,
|
||||
Type: "ckp",
|
||||
}
|
||||
// TODO: 下载到本地
|
||||
// 添加到数据库
|
||||
if err := configs.ORMDB().Create(&model).Error; err != nil {
|
||||
return fmt.Errorf("添加模型到数据库失败: %v", err)
|
||||
}
|
||||
// 添加到模型列表
|
||||
server.Models = append(server.Models, model.ID)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
if err := configs.ORMDB().Save(&server).Error; err != nil {
|
||||
return fmt.Errorf("更新数据库失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user