生成图像(默认参数)
This commit is contained in:
		@@ -62,8 +62,8 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		log.Println("任务编号:", task, "任务数量:", len(image_list))
 | 
			
		||||
 | 
			
		||||
		// 加入连接池
 | 
			
		||||
		wsid := images_websocket_manager.AddConnection(conn)
 | 
			
		||||
		defer images_websocket_manager.RemoveConnection(wsid)
 | 
			
		||||
		images_websocket_manager.AddConnection(conn, task)
 | 
			
		||||
		defer images_websocket_manager.RemoveConnection(conn)
 | 
			
		||||
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := conn.ReadMessage()
 | 
			
		||||
@@ -111,6 +111,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				Scheduler         string  `json:"scheduler"`           // 调度器
 | 
			
		||||
				Seed              string  `json:"seed"`                // 随机种子(单张图生成时使用)
 | 
			
		||||
				Number            int     `json:"number"`              // 生成数量
 | 
			
		||||
				ModelID           int     `json:"model_id"`            // 模型ID
 | 
			
		||||
			}{}
 | 
			
		||||
			body, err := ioutil.ReadAll(r.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -136,12 +137,26 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			if template.GuidanceScale > 20 {
 | 
			
		||||
				template.GuidanceScale = 20
 | 
			
		||||
			}
 | 
			
		||||
			if template.Scheduler == "" {
 | 
			
		||||
				template.Scheduler = "DDIM"
 | 
			
		||||
			}
 | 
			
		||||
			if template.ModelID <= 0 {
 | 
			
		||||
				w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
				w.Write([]byte("model_id 参数不能为空"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// TODO: 创建任务获得任务编号, 多张图时期望可以流式推理
 | 
			
		||||
			task := uuid.New().String()
 | 
			
		||||
			// 从数据库中读取模型信息
 | 
			
		||||
			var model models.Model = models.Model{ID: template.ModelID}
 | 
			
		||||
			if err := configs.ORMDB().First(&model).Error; err != nil {
 | 
			
		||||
				w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
				w.Write([]byte("模型不存在"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 直接创建一组图片
 | 
			
		||||
			var image_list []models.Image
 | 
			
		||||
			var task string = uuid.New().String()
 | 
			
		||||
			for i := 0; i < template.Number; i++ {
 | 
			
		||||
				var image models.Image
 | 
			
		||||
				image.UserID = account.ID
 | 
			
		||||
@@ -157,6 +172,10 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				image_list = append(image_list, image)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			go model.Inference(image_list, func() {
 | 
			
		||||
				images_websocket_manager.NotifyTaskChange(task, image_list)
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			// 存储图片信息到数据库
 | 
			
		||||
			if err := configs.ORMDB().Create(&image_list).Error; err != nil {
 | 
			
		||||
				log.Println(err)
 | 
			
		||||
 
 | 
			
		||||
@@ -15,11 +15,8 @@ import (
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/mux"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var manager = models.NewWebSocketManager()
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	// 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建
 | 
			
		||||
	if _, err := os.Stat("data/models"); err != nil {
 | 
			
		||||
@@ -190,39 +187,6 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
 | 
			
		||||
// 獲取模型詳情
 | 
			
		||||
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	if r.Header.Get("Upgrade") == "websocket" {
 | 
			
		||||
		vars := mux.Vars(r)
 | 
			
		||||
		id, _ := strconv.Atoi(vars["id"])
 | 
			
		||||
 | 
			
		||||
		var model = models.Model{ID: id}
 | 
			
		||||
		if err := configs.ORMDB().Take(&model, id).Error; err != nil {
 | 
			
		||||
			w.WriteHeader(http.StatusNotFound)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		upgrader := websocket.Upgrader{}
 | 
			
		||||
		conn, err := upgrader.Upgrade(w, r, nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer conn.Close()
 | 
			
		||||
		wsid := manager.AddConnection(conn)
 | 
			
		||||
		defer manager.RemoveConnection(wsid)
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := conn.ReadMessage()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Println(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			log.Println(string(msg))
 | 
			
		||||
			if string(msg) == "close" {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
 | 
			
		||||
	if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
 | 
			
		||||
		w.WriteHeader(http.StatusNotFound)
 | 
			
		||||
 
 | 
			
		||||
@@ -65,10 +65,10 @@ func ServersPost(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果不指定類型,禁止創建服務器, 必須指定類型:訓練|推理
 | 
			
		||||
	if server.Type != "訓練" && server.Type != "推理" {
 | 
			
		||||
	// 如果不指定類型,禁止創建服務器, 必須指定類型:训练|推理
 | 
			
		||||
	if server.Type != "训练" && server.Type != "推理" {
 | 
			
		||||
		w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
		w.Write([]byte("必須指定類型:訓練|推理"))
 | 
			
		||||
		w.Write([]byte("必須指定類型:训练|推理"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user