From 723eca3353fc25759abc4761e4f17c2421158246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A1=9C=E8=8F=AF?= Date: Wed, 28 Jun 2023 14:29:44 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E6=8E=A8=E7=90=86=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 18 --- main.go | 6 - models/Image.go | 1 + models/Model.go | 348 +++++++++++++++++++++++++---------------------- models/Task.go | 82 ----------- models/server.go | 34 +++-- routers/tasks.go | 106 --------------- 7 files changed, 211 insertions(+), 384 deletions(-) delete mode 100644 models/Task.go delete mode 100644 routers/tasks.go diff --git a/README.md b/README.md index d9c3453..d70a1df 100644 --- a/README.md +++ b/README.md @@ -170,24 +170,6 @@ message "$response" "上傳圖片" true // @fit: 裁切方式 cover contain fill auto ``` -任務: - -```go -type Task struct { - ID int `json:"id" gorm:"primary_key"` - Name string `json:"name"` - Type string `json:"type"` // 任務類型(訓練|推理) - Status string `json:"status"` // (initial|ready|waiting|running|success|error) - Progress int `json:"progress"` // (0-100) - UserID int `json:"user_id"` - CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` - UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` -} -``` - -- [x] GET [/api/tasks](/api/tasks) 任務列表(全部任務) -- [x] POST [/api/tasks](/api/tasks) 創建任務 - 參數: diff --git a/main.go b/main.go index 16f305c..06f7891 100644 --- a/main.go +++ b/main.go @@ -68,12 +68,6 @@ func main() { r.HandleFunc("/api/images/{id}/like", routers.ImagesItemLike).Methods("POST") // 添加一条喜欢 r.HandleFunc("/api/images/{id}/like", routers.ImagesItemUnlike).Methods("DELETE") // 移除一条喜欢 - r.HandleFunc("/api/tasks", routers.TasksGet).Methods("GET") - r.HandleFunc("/api/tasks", routers.TasksPost).Methods("POST") - r.HandleFunc("/api/tasks/{id}", routers.TasksItemGet).Methods("GET") - r.HandleFunc("/api/tasks/{id}", routers.TasksItemPatch).Methods("PATCH") - r.HandleFunc("/api/tasks/{id}", routers.TasksItemDelete).Methods("DELETE") - r.HandleFunc("/api/tags", routers.TagsGet).Methods("GET") r.HandleFunc("/api/tags", routers.TagsPost).Methods("POST") r.HandleFunc("/api/tags/{id}", routers.TagsItemGet).Methods("GET") diff --git a/models/Image.go b/models/Image.go index 6a9998c..9f54d26 100644 --- a/models/Image.go +++ b/models/Image.go @@ -37,6 +37,7 @@ type Image struct { Progress int `json:"progress"` // 任务进度(0-100) Public bool `json:"public"` // 是否公开 UserID int `json:"user_id"` // 用户ID + ModelID int `json:"model_id"` // 模型ID CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/models/Model.go b/models/Model.go index c416883..ec810c9 100644 --- a/models/Model.go +++ b/models/Model.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "path/filepath" + "strconv" "time" "encoding/base64" @@ -43,184 +44,112 @@ type Model struct { UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } -// 创建一个带缓冲的通道,缓冲区大小为 10 -// var ch = make(chan int, 10) - func init() { configs.ORMDB().AutoMigrate(&Model{}) - - // 检查 images 目录是否存在, 不存在则创建 if _, err := os.Stat("data/images"); err != nil { if err := os.MkdirAll("data/images", 0777); err != nil { log.Println(err) } } - - // 处理推理任务 - //go func() { - // for { - // // 从通道中取出一个数据 - // model := <-ch - // // 模型状态变化时, 向监听此模型的所有连接发送消息 - // } - //}() } +// 从数据库加载 +func (model *Model) Load() { + configs.ORMDB().First(&model) +} + +// 推理模型 func (model *Model) Inference(image_list []Image, callback func(Image)) { + var server Server + // 模型未部署到推理機 if model.ServerID == "" { - //log.Println("模型未部署到推理機, 开始部署模型") - log.Println("模型已部署到推理機, 开始推理模型") + log.Println("模型未部署到推理機, 开始部署模型") - var server Server - server.IP = "106.15.192.42" - server.Port = 7860 - //if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil { - // log.Println(err) - // // 如果没有则寻找空闲服务器 - // // 如果没有空闲则创建新服务器 - // // 取一台空闲的推理机上传并切换到此模型 - // // 新建一台推理机上传并切换到此模型 - //} - - // 执行生成任务 - if model.Image == "" { - img := image_list[0] - - // 发送的参数 - var datx map[string]interface{} = make(map[string]interface{}) - datx["prompt"] = img.Prompt // 提示词 - datx["seed"] = img.Seed // 随机数种子 - datx["n_iter"] = len(image_list) // 生成图像数量 - datx["steps"] = img.Steps // 迭代步数 - datx["cfg_scale"] = img.CfgScale // 提示词引导系数 (CFG Scale) - if img.SamplerName == "" { - datx["sampler_name"] = img.SamplerName // 采样器名称 + // 寻找一台就绪的推理机, 且已部署模型目标模型 + if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("models LIKE ?", "%"+strconv.Itoa(model.ID)+"%").First(&server).Error; err != nil { + // 寻找一台就绪的推理机, 且模型位置仍有空余 + if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "空闲").Where("length(models) < ?", 5).First(&server).Error; err != nil { + log.Println("创建一台新的推理机: 当前禁止创建新服务器") + return } - fmt.Println("image_list:", datx) + // 上传目标模型到推理机 + log.Println("上传模型到推理机: 当前禁止上传模型") + return + } - var data = struct { - //EnableHr bool `json:"enable_hr"` - //DenoisingStrength int `json:"denoising_strength"` - //FirstphaseWidth int `json:"firstphase_width"` - //FirstphaseHeight int `json:"firstphase_height"` - //HrScale int `json:"hr_scale"` - //HrUpscaler string `json:"hr_upscaler"` - //HrSecondPassSteps int `json:"hr_second_pass_steps"` - //HrResizeX int `json:"hr_resize_x"` - //HrResizeY int `json:"hr_resize_y"` - //HrSamplerName string `json:"hr_sampler_name"` - //HrPrompt string `json:"hr_prompt"` - //HrNegativePrompt string `json:"hr_negative_prompt"` - Prompt string `json:"prompt"` - //Styles []string `json:"styles"` - Seed int `json:"seed"` - //Subseed int `json:"subseed"` - //SubseedStrength int `json:"subseed_strength"` - //SeedResizeFromH int `json:"seed_resize_from_h"` - //SeedResizeFromW int `json:"seed_resize_from_w"` - SamplerName string `json:"sampler_name"` - //BatchSize int `json:"batch_size"` - NIter int `json:"n_iter"` - Steps int `json:"steps"` - CfgScale int `json:"cfg_scale"` - //Width int `json:"width"` - //Height int `json:"height"` - //RestoreFaces bool `json:"restore_faces"` - //Tiling bool `json:"tiling"` - //DoNotSaveSamples bool `json:"do_not_save_samples"` - //DoNotSaveGrid bool `json:"do_not_save_grid"` - //NegativePrompt string `json:"negative_prompt"` - //Eta int `json:"eta"` - //SMinUncond int `json:"s_min_uncond"` - //SChurn int `json:"s_churn"` - //STmax int `json:"s_tmax"` - //STmin int `json:"s_tmin"` - //SNoise int `json:"s_noise"` - //OverrideSettings map[string]string `json:"override_settings"` - //OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"` - //ScriptArgs []interface{} `json:"script_args"` - //SamplerIndex string `json:"sampler_index"` - //ScriptName string `json:"script_name"` - //SendImages bool `json:"send_images"` - //SaveImages bool `json:"save_images"` - //AlwaysonScripts map[string]string `json:"alwayson_scripts"` - }{ - //EnableHr: false, - //DenoisingStrength: 0, - //FirstphaseWidth: 0, - //FirstphaseHeight: 0, - //HrScale: 2, - //HrUpscaler: "nearest", - //HrSecondPassSteps: 0, - //HrResizeX: 0, - //HrResizeY: 0, - //HrSamplerName: "", - //HrPrompt: "", - //HrNegativePrompt: "", - Prompt: image_list[0].Prompt, - //Styles: []string{}, - Seed: image_list[0].Seed, - //Subseed: -1, - //SubseedStrength: 0, - //SeedResizeFromH: -1, - //SeedResizeFromW: -1, - SamplerName: image_list[0].SamplerName, // 采样器名称 - //BatchSize: 1, - NIter: len(image_list), // 1~100 - Steps: 50, // 1~150 - CfgScale: image_list[0].CfgScale, - //Width: 512, - //Height: 512, - //RestoreFaces: false, - //Tiling: false, - //DoNotSaveSamples: false, - //DoNotSaveGrid: false, - //NegativePrompt: "", - //Eta: 0, - //SMinUncond: 0, - //SChurn: 0, - //STmax: 0, - //STmin: 0, - //SNoise: 1, - //OverrideSettings: map[string]string{}, - //OverrideSettingsRestoreAfterwards: false, - //ScriptArgs: []interface{}{}, - //SamplerIndex: "Euler", - //ScriptName: "generate", - //SendImages: true, - //SaveImages: false, - //AlwaysonScripts: map[string]string{}, - } - fmt.Println("data:", data) - - // 接收到的图片列表 - var rest = struct { - Images []string `json:"images"` - }{} - var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port) - if err := goreq.Post(url).SetJsonBody(datx).Do().BindJSON(&rest); err != nil { - log.Println("API 查询失败:", err) - } - for index, img := range rest.Images { - var filename = fmt.Sprintf("%x", md5.Sum([]byte(img))) - log.Println("保存图片:", filename) - if err := SaveBase64Image(img, "data/images/"+filename+".webp"); err != nil { - log.Println(err) + var form = struct { + Components []struct { + ID int `json:"id"` + Type string `json:"type"` + Props struct { + Value string `json:"value"` } - image_list[index].Name = filename - image_list[index].Path = "data/images/" + filename + ".webp" - image_list[index].Hash = filename - image_list[index].Type = "image/webp" - image_list[index].Width = 512 - image_list[index].Height = 512 - image_list[index].Format = "webp" - image_list[index].Status = "success" - image_list[index].Progress = 100 - callback(image_list[index]) + } `json:"components"` + }{} + // 检查当前是否为目标模型, 不是则执行切换模型 http://106.15.192.42:7860/config + if err := goreq.Get(fmt.Sprintf("http://%s:%d/config", server.IP, server.Port)).Do().BindJSON(&form); err != nil { + log.Println("获取推理机配置失败:", err) + return + } + + var isSet = false + for _, component := range form.Components { + if component.Type == "dropdown" && component.ID == 1514 && component.Props.Value == model.Name { + log.Println("当前推理机已经部署了目标模型") + isSet = true + break } } - return + + if !isSet { + log.Println("当前推理机未部署目标模型, 开始部署目标模型") + // 没有切换模型接口 + return + } + + // 记录到模型 + model.ServerID = server.ID + configs.ORMDB().Save(&model) + } + + // 发送的参数 + var img = image_list[0] + var datx map[string]interface{} = make(map[string]interface{}) + datx["prompt"] = img.Prompt // 提示词 + datx["seed"] = img.Seed // 随机数种子 + datx["n_iter"] = len(image_list) // 生成图像数量 + datx["steps"] = img.Steps // 迭代步数 + datx["cfg_scale"] = img.CfgScale // 提示词引导系数 (CFG Scale) + if img.SamplerName == "" { + datx["sampler_name"] = img.SamplerName // 采样器名称 + } + fmt.Println("image_list:", datx) + + // 接收到的图片列表 + var rest = struct { + Images []string `json:"images"` + }{} + var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port) + if err := goreq.Post(url).SetJsonBody(datx).Do().BindJSON(&rest); err != nil { + log.Println("API 查询失败:", err) + } + for index, img := range rest.Images { + var filename = fmt.Sprintf("%x", md5.Sum([]byte(img))) + log.Println("保存图片:", filename) + if err := SaveBase64Image(img, "data/images/"+filename+".webp"); err != nil { + log.Println(err) + } + image_list[index].Name = filename + image_list[index].Path = "data/images/" + filename + ".webp" + image_list[index].Hash = filename + image_list[index].Type = "image/webp" + image_list[index].Width = 512 + image_list[index].Height = 512 + image_list[index].Format = "webp" + image_list[index].Status = "success" + image_list[index].Progress = 100 + callback(image_list[index]) } log.Println("模型未部署到推理機, 取消推理模型") } @@ -254,6 +183,7 @@ func SaveBase64Image(base64Str string, filename string) error { return nil } +// 训练模型 func (model *Model) Train() (err error) { // 獲取一臺空閒的訓練機 @@ -420,3 +350,101 @@ func (model *Model) Train() (err error) { return nil } + +/** + + var data = struct { + //EnableHr bool `json:"enable_hr"` + //DenoisingStrength int `json:"denoising_strength"` + //FirstphaseWidth int `json:"firstphase_width"` + //FirstphaseHeight int `json:"firstphase_height"` + //HrScale int `json:"hr_scale"` + //HrUpscaler string `json:"hr_upscaler"` + //HrSecondPassSteps int `json:"hr_second_pass_steps"` + //HrResizeX int `json:"hr_resize_x"` + //HrResizeY int `json:"hr_resize_y"` + //HrSamplerName string `json:"hr_sampler_name"` + //HrPrompt string `json:"hr_prompt"` + //HrNegativePrompt string `json:"hr_negative_prompt"` + Prompt string `json:"prompt"` + //Styles []string `json:"styles"` + Seed int `json:"seed"` + //Subseed int `json:"subseed"` + //SubseedStrength int `json:"subseed_strength"` + //SeedResizeFromH int `json:"seed_resize_from_h"` + //SeedResizeFromW int `json:"seed_resize_from_w"` + SamplerName string `json:"sampler_name"` + //BatchSize int `json:"batch_size"` + NIter int `json:"n_iter"` + Steps int `json:"steps"` + CfgScale int `json:"cfg_scale"` + //Width int `json:"width"` + //Height int `json:"height"` + //RestoreFaces bool `json:"restore_faces"` + //Tiling bool `json:"tiling"` + //DoNotSaveSamples bool `json:"do_not_save_samples"` + //DoNotSaveGrid bool `json:"do_not_save_grid"` + //NegativePrompt string `json:"negative_prompt"` + //Eta int `json:"eta"` + //SMinUncond int `json:"s_min_uncond"` + //SChurn int `json:"s_churn"` + //STmax int `json:"s_tmax"` + //STmin int `json:"s_tmin"` + //SNoise int `json:"s_noise"` + //OverrideSettings map[string]string `json:"override_settings"` + //OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"` + //ScriptArgs []interface{} `json:"script_args"` + //SamplerIndex string `json:"sampler_index"` + //ScriptName string `json:"script_name"` + //SendImages bool `json:"send_images"` + //SaveImages bool `json:"save_images"` + //AlwaysonScripts map[string]string `json:"alwayson_scripts"` + }{ + //EnableHr: false, + //DenoisingStrength: 0, + //FirstphaseWidth: 0, + //FirstphaseHeight: 0, + //HrScale: 2, + //HrUpscaler: "nearest", + //HrSecondPassSteps: 0, + //HrResizeX: 0, + //HrResizeY: 0, + //HrSamplerName: "", + //HrPrompt: "", + //HrNegativePrompt: "", + Prompt: image_list[0].Prompt, + //Styles: []string{}, + Seed: image_list[0].Seed, + //Subseed: -1, + //SubseedStrength: 0, + //SeedResizeFromH: -1, + //SeedResizeFromW: -1, + SamplerName: image_list[0].SamplerName, // 采样器名称 + //BatchSize: 1, + NIter: len(image_list), // 1~100 + Steps: 50, // 1~150 + CfgScale: image_list[0].CfgScale, + //Width: 512, + //Height: 512, + //RestoreFaces: false, + //Tiling: false, + //DoNotSaveSamples: false, + //DoNotSaveGrid: false, + //NegativePrompt: "", + //Eta: 0, + //SMinUncond: 0, + //SChurn: 0, + //STmax: 0, + //STmin: 0, + //SNoise: 1, + //OverrideSettings: map[string]string{}, + //OverrideSettingsRestoreAfterwards: false, + //ScriptArgs: []interface{}{}, + //SamplerIndex: "Euler", + //ScriptName: "generate", + //SendImages: true, + //SaveImages: false, + //AlwaysonScripts: map[string]string{}, + } + fmt.Println("data:", data) +**/ diff --git a/models/Task.go b/models/Task.go deleted file mode 100644 index 76e340b..0000000 --- a/models/Task.go +++ /dev/null @@ -1,82 +0,0 @@ -package models - -import "time" - -type Task struct { - ID int `json:"id" gorm:"primary_key"` - Name string `json:"name"` - Type string `json:"type"` // 任務類型(訓練|推理) - Status string `json:"status"` // (initial|ready|waiting|running|success|error) - Progress int `json:"progress"` // (0-100) - UserID int `json:"user_id"` - CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` - UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` -} - -// 异步任务调度管理 -type AsynchronousTaskSchedulingManagement struct { - tasks map[int]Task // 任务队列 - servers map[string]Server // 服务器队列 -} - -// 向任务队列添加任务 -func (m *AsynchronousTaskSchedulingManagement) AddTask(task Task) { - // 1. 任务加入队列, 任务状态为 waiting, 每次从最后一个任务开始执行, 并向上全取同模型的任务, 任务状态更新为 waiting(排队), 并更新此模型的排队状态到所有关注此模型的用户(管理员每个连接都关注所有模型) - // 2. 任务从队列中取出, 任务队列长度超过12则增加机器, 模型类型持续占用机器则增加机器 - - // 加入任务队列 - m.tasks[task.ID] = task - - // 检查任务队列长度, 持续增长超过12则增加机器 - if len(m.tasks) > 12 { - var server Server - m.servers[server.ID] = server - } - - // 向目标机器发送模型 - // 向目标机器切换模型 - // 向目标机器发送任务 -} - -//// 推理任務 -//func startInferenceTask(task *Task) { -// -// // 獲取一臺可用的 GPU 資源 -// // ... -// -// // 執行推理任務 -// // ... -// -// // 更新任務狀態 -// task.Status = "running" -// task.Progress = 0 -// task.Update() -// -// // 監聽任務狀態 -// for { -// // 延遲 1 秒 -// time.Sleep(1 * time.Second) -// -// // 查詢任務狀態 -// resp, err := http.Get("http://localhost:5000/api/v1/tasks/" + strconv.Itoa(task.ID)) -// if err != nil { -// log.Println(err) -// continue -// } -// defer resp.Body.Close() -// -// // 解析任務狀態 -// // ... -// -// // 更新任務狀態 -// task.Progress = 100 -// task.Status = "success" -// task.Update() -// -// // 任務結束判定 -// if task.Progress == 100 { -// break -// } -// } -// -//} diff --git a/models/server.go b/models/server.go index 52cc1b1..1e9d994 100644 --- a/models/server.go +++ b/models/server.go @@ -51,6 +51,13 @@ var config = struct { }{} func init() { + configs.ORMDB().AutoMigrate(&Server{}) + // 檢查所有服務器的狀態, 無效的服務器設置為異常 + var servers []Server + configs.ORMDB().Find(&servers) + for _, server := range servers { + server.CheckStatus() + } // 讀取配置文件 absPath, _ := filepath.Abs("./data/config.yaml") configFile, err := ioutil.ReadFile(absPath) @@ -60,10 +67,23 @@ func init() { if err := yaml.Unmarshal(configFile, &config); err != nil { panic(fmt.Errorf("格式化配置文件失敗: %v", err)) } + // 初始化检查默认服务器 + if err := InitDefaultServer(); err != nil { + panic(fmt.Errorf("初始化默认服务器失败: %v", err)) + } +} + +// 检查默认服务器是否存在, 不存在则添加 +func InitDefaultServer() (err error) { + if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil { + server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"} + err = configs.ORMDB().Create(&server).Error + } + return } // 创建一台新服务器 -func NewServer(server_type string) (server *Server, err error) { +func NewServer(server_type string) (server Server, err error) { // 调用 API 创建一台新服务器(通過腾讯云API創建服務器) client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile()) if err != nil { @@ -84,7 +104,7 @@ func NewServer(server_type string) (server *Server, err error) { fmt.Println("創建服務器成功:", response.Response.InstanceIdSet[0]) // 获取服务器信息 - var get_server_info = func(InstanceIdSet *string) (server *Server, err error) { + var get_server_info = func(InstanceIdSet *string) (server Server, err error) { response2, err := client.DescribeInstances(cvm.NewDescribeInstancesRequest()) if err != nil { return server, fmt.Errorf("獲取實例詳情失敗: %v", err) @@ -169,13 +189,3 @@ func (server *Server) CheckStatus() error { // 檢查服務器是否正常 return nil } - -func init() { - configs.ORMDB().AutoMigrate(&Server{}) - // 檢查所有服務器的狀態, 無效的服務器設置為異常 - var servers []Server - configs.ORMDB().Find(&servers) - for _, server := range servers { - server.CheckStatus() - } -} diff --git a/routers/tasks.go b/routers/tasks.go deleted file mode 100644 index bd847a3..0000000 --- a/routers/tasks.go +++ /dev/null @@ -1,106 +0,0 @@ -package routers - -import ( - "encoding/json" - "io/ioutil" - "log" - "main/configs" - "main/models" - "main/utils" - "net/http" - "strconv" - - "github.com/gorilla/mux" - "github.com/gorilla/websocket" -) - -func TasksGet(w http.ResponseWriter, r *http.Request) { - var listview models.ListView - listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1) - listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10) - var task_list []models.Task - db := configs.ORMDB() - db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&task_list) - for _, task := range task_list { - listview.List = append(listview.List, task) - } - db.Model(&models.Task{}).Count(&listview.Total) - listview.Next = listview.Page*listview.PageSize < int(listview.Total) - listview.WriteJSON(w) -} - -func TasksPost(w http.ResponseWriter, r *http.Request) { - var task models.Task - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &task); err != nil { - log.Println(err) - return - } - configs.ORMDB().Create(&task) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(utils.ToJSON(task)) -} - -func TasksItemGet(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - vars := mux.Vars(r) - id, _ := strconv.Atoi(vars["id"]) - - var task models.Task = models.Task{ID: id} - if err := configs.ORMDB().First(&task, id).Error; err != nil { - log.Println(err) - w.WriteHeader(http.StatusNotFound) - return - } - upgrader := websocket.Upgrader{} - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(err) - return - } - defer ws.Close() - for { - _, message, err := ws.ReadMessage() - if err != nil { - log.Println(err) - break - } - task.Status = string(message) - configs.ORMDB().Model(&task).Update("status", task.Status) - } - return - } - task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - configs.ORMDB().First(&task) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(utils.ToJSON(task)) -} - -func TasksItemPatch(w http.ResponseWriter, r *http.Request) { - var task models.Task - body, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println(err) - return - } - defer r.Body.Close() - if err = json.Unmarshal(body, &task); err != nil { - log.Println(err) - return - } - task.ID = utils.ParamInt(mux.Vars(r)["id"], 0) - configs.ORMDB().Model(&task).Updates(task) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(utils.ToJSON(task)) -} - -func TasksItemDelete(w http.ResponseWriter, r *http.Request) { - task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} - configs.ORMDB().Delete(&task) - w.WriteHeader(http.StatusNoContent) -}