模型安全加载
This commit is contained in:
		@@ -94,7 +94,7 @@ func InitDefaultServer() (err error) {
 | 
			
		||||
 | 
			
		||||
// 初始化服务器中的模型列表
 | 
			
		||||
func (server *Server) InitModels() (err error) {
 | 
			
		||||
	// 先让服务器更新模型列表 /sdapi/v1/reload-checkpoint
 | 
			
		||||
	// 刷新检查点 /sdapi/v1/reload-checkpoint
 | 
			
		||||
	resp_update, err := http.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port), "application/json", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("更新服务器中的模型列表失败: %v", err)
 | 
			
		||||
@@ -117,10 +117,6 @@ func (server *Server) InitModels() (err error) {
 | 
			
		||||
 | 
			
		||||
	// 从数据库检查此模型hash是否存在
 | 
			
		||||
	for _, item := range data_lora {
 | 
			
		||||
		// 打印为格式化的JSON
 | 
			
		||||
		//b, _ := json.MarshalIndent(item, "", "  ")
 | 
			
		||||
		//fmt.Println(string(b))
 | 
			
		||||
 | 
			
		||||
		var model Model
 | 
			
		||||
		if err := configs.ORMDB().Where("hash = ?", item["path"].(string)).First(&model).Error; err != nil {
 | 
			
		||||
			// 不存在则添加
 | 
			
		||||
@@ -146,19 +142,27 @@ func (server *Server) InitModels() (err error) {
 | 
			
		||||
		return fmt.Errorf("更新数据库失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 模型HASH生成
 | 
			
		||||
	// 刷新检查点
 | 
			
		||||
	if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().Err; err != nil {
 | 
			
		||||
		return fmt.Errorf("刷新检查点失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取服务器中的模型列表(ckp)
 | 
			
		||||
	dexs := []struct {
 | 
			
		||||
		Name            string `json:"model_name"`
 | 
			
		||||
		Hash            string `json:"sha256"`
 | 
			
		||||
		ModelCheckpoint string `json:"title"`
 | 
			
		||||
		ModelPath       string `json:"filename"`
 | 
			
		||||
	}{}
 | 
			
		||||
	if err = goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/refresh-checkpoints", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
 | 
			
		||||
 | 
			
		||||
	// 获取服务器中的模型列表(ckp)
 | 
			
		||||
	if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
 | 
			
		||||
		return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	for _, item := range dexs {
 | 
			
		||||
		// 如果hash为空, 则逐一加载这些模型使其生成hash
 | 
			
		||||
		if item.Hash == "" {
 | 
			
		||||
			fmt.Println("加载模型:", item.ModelCheckpoint)
 | 
			
		||||
			if err := goreq.Post(fmt.Sprintf("http://%s:%d/sdapi/v1/options", server.IP, server.Port)).SetJsonBody(map[string]interface{}{
 | 
			
		||||
				"sd_model_checkpoint":      item.ModelCheckpoint,
 | 
			
		||||
				"CLIP_stop_at_last_layers": 2,
 | 
			
		||||
@@ -168,40 +172,27 @@ func (server *Server) InitModels() (err error) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取服务器中的模型列表(ckp)
 | 
			
		||||
	resp, err := http.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
	// 重新获取服务器中的模型列表(ckp), 忽略hash为空的模型
 | 
			
		||||
	if err = goreq.Get(fmt.Sprintf("http://%s:%d/sdapi/v1/sd-models", server.IP, server.Port)).Do().BindJSON(&dexs); err != nil {
 | 
			
		||||
		return fmt.Errorf("获取服务器中的模型列表失败: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	// 解码JSON (数组)
 | 
			
		||||
	var data []map[string]interface{}
 | 
			
		||||
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果hash为空, 则逐一加载这些模型使其生成hash
 | 
			
		||||
	// 然后再次更新服务器中的模型列表
 | 
			
		||||
 | 
			
		||||
	// 从数据库检查此模型hash是否存在
 | 
			
		||||
	for _, item := range data {
 | 
			
		||||
		// 打印为格式化的JSON
 | 
			
		||||
		b, _ := json.MarshalIndent(item, "", "  ")
 | 
			
		||||
		fmt.Println(string(b))
 | 
			
		||||
	for _, item := range dexs {
 | 
			
		||||
		if item.Hash == "" {
 | 
			
		||||
			fmt.Println("忽略模型:", item.ModelCheckpoint)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var model Model
 | 
			
		||||
		if err := configs.ORMDB().Where("hash = ?", item["sha256"].(string)).First(&model).Error; err != nil {
 | 
			
		||||
		if err := configs.ORMDB().Where("hash = ?", item.Hash).First(&model).Error; err != nil {
 | 
			
		||||
			// 不存在则添加
 | 
			
		||||
			model = Model{
 | 
			
		||||
				Name:            item["model_name"].(string),
 | 
			
		||||
				Hash:            item["sha256"].(string),
 | 
			
		||||
				ModelCheckpoint: item["title"].(string),
 | 
			
		||||
				ModelPath:       item["filename"].(string), // TODO: 下载到本地
 | 
			
		||||
				Name:            item.Name,
 | 
			
		||||
				Hash:            item.Hash,
 | 
			
		||||
				ModelCheckpoint: item.ModelCheckpoint,
 | 
			
		||||
				ModelPath:       item.ModelPath,
 | 
			
		||||
				ServerID:        server.ID,
 | 
			
		||||
				Type:            "ckp",
 | 
			
		||||
			}
 | 
			
		||||
			// TODO: 下载到本地
 | 
			
		||||
			// 添加到数据库
 | 
			
		||||
			if err := configs.ORMDB().Create(&model).Error; err != nil {
 | 
			
		||||
				return fmt.Errorf("添加模型到数据库失败: %v", err)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user