更改推理结构
This commit is contained in:
18
README.md
18
README.md
@@ -170,24 +170,6 @@ message "$response" "上傳圖片" true
|
|||||||
// @fit: 裁切方式 cover contain fill auto
|
// @fit: 裁切方式 cover contain fill auto
|
||||||
```
|
```
|
||||||
|
|
||||||
任務:
|
|
||||||
|
|
||||||
```go
|
|
||||||
type Task struct {
|
|
||||||
ID int `json:"id" gorm:"primary_key"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"` // 任務類型(訓練|推理)
|
|
||||||
Status string `json:"status"` // (initial|ready|waiting|running|success|error)
|
|
||||||
Progress int `json:"progress"` // (0-100)
|
|
||||||
UserID int `json:"user_id"`
|
|
||||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
- [x] GET [/api/tasks](/api/tasks) 任務列表(全部任務)
|
|
||||||
- [x] POST [/api/tasks](/api/tasks) 創建任務
|
|
||||||
|
|
||||||
|
|
||||||
參數:
|
參數:
|
||||||
|
|
||||||
|
6
main.go
6
main.go
@@ -68,12 +68,6 @@ func main() {
|
|||||||
r.HandleFunc("/api/images/{id}/like", routers.ImagesItemLike).Methods("POST") // 添加一条喜欢
|
r.HandleFunc("/api/images/{id}/like", routers.ImagesItemLike).Methods("POST") // 添加一条喜欢
|
||||||
r.HandleFunc("/api/images/{id}/like", routers.ImagesItemUnlike).Methods("DELETE") // 移除一条喜欢
|
r.HandleFunc("/api/images/{id}/like", routers.ImagesItemUnlike).Methods("DELETE") // 移除一条喜欢
|
||||||
|
|
||||||
r.HandleFunc("/api/tasks", routers.TasksGet).Methods("GET")
|
|
||||||
r.HandleFunc("/api/tasks", routers.TasksPost).Methods("POST")
|
|
||||||
r.HandleFunc("/api/tasks/{id}", routers.TasksItemGet).Methods("GET")
|
|
||||||
r.HandleFunc("/api/tasks/{id}", routers.TasksItemPatch).Methods("PATCH")
|
|
||||||
r.HandleFunc("/api/tasks/{id}", routers.TasksItemDelete).Methods("DELETE")
|
|
||||||
|
|
||||||
r.HandleFunc("/api/tags", routers.TagsGet).Methods("GET")
|
r.HandleFunc("/api/tags", routers.TagsGet).Methods("GET")
|
||||||
r.HandleFunc("/api/tags", routers.TagsPost).Methods("POST")
|
r.HandleFunc("/api/tags", routers.TagsPost).Methods("POST")
|
||||||
r.HandleFunc("/api/tags/{id}", routers.TagsItemGet).Methods("GET")
|
r.HandleFunc("/api/tags/{id}", routers.TagsItemGet).Methods("GET")
|
||||||
|
@@ -37,6 +37,7 @@ type Image struct {
|
|||||||
Progress int `json:"progress"` // 任务进度(0-100)
|
Progress int `json:"progress"` // 任务进度(0-100)
|
||||||
Public bool `json:"public"` // 是否公开
|
Public bool `json:"public"` // 是否公开
|
||||||
UserID int `json:"user_id"` // 用户ID
|
UserID int `json:"user_id"` // 用户ID
|
||||||
|
ModelID int `json:"model_id"` // 模型ID
|
||||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
}
|
}
|
||||||
|
282
models/Model.go
282
models/Model.go
@@ -11,6 +11,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@@ -43,51 +44,77 @@ type Model struct {
|
|||||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建一个带缓冲的通道,缓冲区大小为 10
|
|
||||||
// var ch = make(chan int, 10)
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
configs.ORMDB().AutoMigrate(&Model{})
|
configs.ORMDB().AutoMigrate(&Model{})
|
||||||
|
|
||||||
// 检查 images 目录是否存在, 不存在则创建
|
|
||||||
if _, err := os.Stat("data/images"); err != nil {
|
if _, err := os.Stat("data/images"); err != nil {
|
||||||
if err := os.MkdirAll("data/images", 0777); err != nil {
|
if err := os.MkdirAll("data/images", 0777); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理推理任务
|
|
||||||
//go func() {
|
|
||||||
// for {
|
|
||||||
// // 从通道中取出一个数据
|
|
||||||
// model := <-ch
|
|
||||||
// // 模型状态变化时, 向监听此模型的所有连接发送消息
|
|
||||||
// }
|
|
||||||
//}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 从数据库加载
|
||||||
|
func (model *Model) Load() {
|
||||||
|
configs.ORMDB().First(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推理模型
|
||||||
func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
||||||
|
var server Server
|
||||||
|
|
||||||
// 模型未部署到推理機
|
// 模型未部署到推理機
|
||||||
if model.ServerID == "" {
|
if model.ServerID == "" {
|
||||||
//log.Println("模型未部署到推理機, 开始部署模型")
|
log.Println("模型未部署到推理機, 开始部署模型")
|
||||||
log.Println("模型已部署到推理機, 开始推理模型")
|
|
||||||
|
|
||||||
var server Server
|
// 寻找一台就绪的推理机, 且已部署模型目标模型
|
||||||
server.IP = "106.15.192.42"
|
if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "就绪").Where("models LIKE ?", "%"+strconv.Itoa(model.ID)+"%").First(&server).Error; err != nil {
|
||||||
server.Port = 7860
|
// 寻找一台就绪的推理机, 且模型位置仍有空余
|
||||||
//if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil {
|
if err := configs.ORMDB().Where("type = ?", "推理").Where("status = ?", "空闲").Where("length(models) < ?", 5).First(&server).Error; err != nil {
|
||||||
// log.Println(err)
|
log.Println("创建一台新的推理机: 当前禁止创建新服务器")
|
||||||
// // 如果没有则寻找空闲服务器
|
return
|
||||||
// // 如果没有空闲则创建新服务器
|
}
|
||||||
// // 取一台空闲的推理机上传并切换到此模型
|
// 上传目标模型到推理机
|
||||||
// // 新建一台推理机上传并切换到此模型
|
log.Println("上传模型到推理机: 当前禁止上传模型")
|
||||||
//}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 执行生成任务
|
var form = struct {
|
||||||
if model.Image == "" {
|
Components []struct {
|
||||||
img := image_list[0]
|
ID int `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Props struct {
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
} `json:"components"`
|
||||||
|
}{}
|
||||||
|
// 检查当前是否为目标模型, 不是则执行切换模型 http://106.15.192.42:7860/config
|
||||||
|
if err := goreq.Get(fmt.Sprintf("http://%s:%d/config", server.IP, server.Port)).Do().BindJSON(&form); err != nil {
|
||||||
|
log.Println("获取推理机配置失败:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var isSet = false
|
||||||
|
for _, component := range form.Components {
|
||||||
|
if component.Type == "dropdown" && component.ID == 1514 && component.Props.Value == model.Name {
|
||||||
|
log.Println("当前推理机已经部署了目标模型")
|
||||||
|
isSet = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSet {
|
||||||
|
log.Println("当前推理机未部署目标模型, 开始部署目标模型")
|
||||||
|
// 没有切换模型接口
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录到模型
|
||||||
|
model.ServerID = server.ID
|
||||||
|
configs.ORMDB().Save(&model)
|
||||||
|
}
|
||||||
|
|
||||||
// 发送的参数
|
// 发送的参数
|
||||||
|
var img = image_list[0]
|
||||||
var datx map[string]interface{} = make(map[string]interface{})
|
var datx map[string]interface{} = make(map[string]interface{})
|
||||||
datx["prompt"] = img.Prompt // 提示词
|
datx["prompt"] = img.Prompt // 提示词
|
||||||
datx["seed"] = img.Seed // 随机数种子
|
datx["seed"] = img.Seed // 随机数种子
|
||||||
@@ -99,101 +126,6 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
|||||||
}
|
}
|
||||||
fmt.Println("image_list:", datx)
|
fmt.Println("image_list:", datx)
|
||||||
|
|
||||||
var data = struct {
|
|
||||||
//EnableHr bool `json:"enable_hr"`
|
|
||||||
//DenoisingStrength int `json:"denoising_strength"`
|
|
||||||
//FirstphaseWidth int `json:"firstphase_width"`
|
|
||||||
//FirstphaseHeight int `json:"firstphase_height"`
|
|
||||||
//HrScale int `json:"hr_scale"`
|
|
||||||
//HrUpscaler string `json:"hr_upscaler"`
|
|
||||||
//HrSecondPassSteps int `json:"hr_second_pass_steps"`
|
|
||||||
//HrResizeX int `json:"hr_resize_x"`
|
|
||||||
//HrResizeY int `json:"hr_resize_y"`
|
|
||||||
//HrSamplerName string `json:"hr_sampler_name"`
|
|
||||||
//HrPrompt string `json:"hr_prompt"`
|
|
||||||
//HrNegativePrompt string `json:"hr_negative_prompt"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
//Styles []string `json:"styles"`
|
|
||||||
Seed int `json:"seed"`
|
|
||||||
//Subseed int `json:"subseed"`
|
|
||||||
//SubseedStrength int `json:"subseed_strength"`
|
|
||||||
//SeedResizeFromH int `json:"seed_resize_from_h"`
|
|
||||||
//SeedResizeFromW int `json:"seed_resize_from_w"`
|
|
||||||
SamplerName string `json:"sampler_name"`
|
|
||||||
//BatchSize int `json:"batch_size"`
|
|
||||||
NIter int `json:"n_iter"`
|
|
||||||
Steps int `json:"steps"`
|
|
||||||
CfgScale int `json:"cfg_scale"`
|
|
||||||
//Width int `json:"width"`
|
|
||||||
//Height int `json:"height"`
|
|
||||||
//RestoreFaces bool `json:"restore_faces"`
|
|
||||||
//Tiling bool `json:"tiling"`
|
|
||||||
//DoNotSaveSamples bool `json:"do_not_save_samples"`
|
|
||||||
//DoNotSaveGrid bool `json:"do_not_save_grid"`
|
|
||||||
//NegativePrompt string `json:"negative_prompt"`
|
|
||||||
//Eta int `json:"eta"`
|
|
||||||
//SMinUncond int `json:"s_min_uncond"`
|
|
||||||
//SChurn int `json:"s_churn"`
|
|
||||||
//STmax int `json:"s_tmax"`
|
|
||||||
//STmin int `json:"s_tmin"`
|
|
||||||
//SNoise int `json:"s_noise"`
|
|
||||||
//OverrideSettings map[string]string `json:"override_settings"`
|
|
||||||
//OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"`
|
|
||||||
//ScriptArgs []interface{} `json:"script_args"`
|
|
||||||
//SamplerIndex string `json:"sampler_index"`
|
|
||||||
//ScriptName string `json:"script_name"`
|
|
||||||
//SendImages bool `json:"send_images"`
|
|
||||||
//SaveImages bool `json:"save_images"`
|
|
||||||
//AlwaysonScripts map[string]string `json:"alwayson_scripts"`
|
|
||||||
}{
|
|
||||||
//EnableHr: false,
|
|
||||||
//DenoisingStrength: 0,
|
|
||||||
//FirstphaseWidth: 0,
|
|
||||||
//FirstphaseHeight: 0,
|
|
||||||
//HrScale: 2,
|
|
||||||
//HrUpscaler: "nearest",
|
|
||||||
//HrSecondPassSteps: 0,
|
|
||||||
//HrResizeX: 0,
|
|
||||||
//HrResizeY: 0,
|
|
||||||
//HrSamplerName: "",
|
|
||||||
//HrPrompt: "",
|
|
||||||
//HrNegativePrompt: "",
|
|
||||||
Prompt: image_list[0].Prompt,
|
|
||||||
//Styles: []string{},
|
|
||||||
Seed: image_list[0].Seed,
|
|
||||||
//Subseed: -1,
|
|
||||||
//SubseedStrength: 0,
|
|
||||||
//SeedResizeFromH: -1,
|
|
||||||
//SeedResizeFromW: -1,
|
|
||||||
SamplerName: image_list[0].SamplerName, // 采样器名称
|
|
||||||
//BatchSize: 1,
|
|
||||||
NIter: len(image_list), // 1~100
|
|
||||||
Steps: 50, // 1~150
|
|
||||||
CfgScale: image_list[0].CfgScale,
|
|
||||||
//Width: 512,
|
|
||||||
//Height: 512,
|
|
||||||
//RestoreFaces: false,
|
|
||||||
//Tiling: false,
|
|
||||||
//DoNotSaveSamples: false,
|
|
||||||
//DoNotSaveGrid: false,
|
|
||||||
//NegativePrompt: "",
|
|
||||||
//Eta: 0,
|
|
||||||
//SMinUncond: 0,
|
|
||||||
//SChurn: 0,
|
|
||||||
//STmax: 0,
|
|
||||||
//STmin: 0,
|
|
||||||
//SNoise: 1,
|
|
||||||
//OverrideSettings: map[string]string{},
|
|
||||||
//OverrideSettingsRestoreAfterwards: false,
|
|
||||||
//ScriptArgs: []interface{}{},
|
|
||||||
//SamplerIndex: "Euler",
|
|
||||||
//ScriptName: "generate",
|
|
||||||
//SendImages: true,
|
|
||||||
//SaveImages: false,
|
|
||||||
//AlwaysonScripts: map[string]string{},
|
|
||||||
}
|
|
||||||
fmt.Println("data:", data)
|
|
||||||
|
|
||||||
// 接收到的图片列表
|
// 接收到的图片列表
|
||||||
var rest = struct {
|
var rest = struct {
|
||||||
Images []string `json:"images"`
|
Images []string `json:"images"`
|
||||||
@@ -219,9 +151,6 @@ func (model *Model) Inference(image_list []Image, callback func(Image)) {
|
|||||||
image_list[index].Progress = 100
|
image_list[index].Progress = 100
|
||||||
callback(image_list[index])
|
callback(image_list[index])
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Println("模型未部署到推理機, 取消推理模型")
|
log.Println("模型未部署到推理機, 取消推理模型")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,6 +183,7 @@ func SaveBase64Image(base64Str string, filename string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 训练模型
|
||||||
func (model *Model) Train() (err error) {
|
func (model *Model) Train() (err error) {
|
||||||
|
|
||||||
// 獲取一臺空閒的訓練機
|
// 獲取一臺空閒的訓練機
|
||||||
@@ -420,3 +350,101 @@ func (model *Model) Train() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
|
||||||
|
var data = struct {
|
||||||
|
//EnableHr bool `json:"enable_hr"`
|
||||||
|
//DenoisingStrength int `json:"denoising_strength"`
|
||||||
|
//FirstphaseWidth int `json:"firstphase_width"`
|
||||||
|
//FirstphaseHeight int `json:"firstphase_height"`
|
||||||
|
//HrScale int `json:"hr_scale"`
|
||||||
|
//HrUpscaler string `json:"hr_upscaler"`
|
||||||
|
//HrSecondPassSteps int `json:"hr_second_pass_steps"`
|
||||||
|
//HrResizeX int `json:"hr_resize_x"`
|
||||||
|
//HrResizeY int `json:"hr_resize_y"`
|
||||||
|
//HrSamplerName string `json:"hr_sampler_name"`
|
||||||
|
//HrPrompt string `json:"hr_prompt"`
|
||||||
|
//HrNegativePrompt string `json:"hr_negative_prompt"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
//Styles []string `json:"styles"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
//Subseed int `json:"subseed"`
|
||||||
|
//SubseedStrength int `json:"subseed_strength"`
|
||||||
|
//SeedResizeFromH int `json:"seed_resize_from_h"`
|
||||||
|
//SeedResizeFromW int `json:"seed_resize_from_w"`
|
||||||
|
SamplerName string `json:"sampler_name"`
|
||||||
|
//BatchSize int `json:"batch_size"`
|
||||||
|
NIter int `json:"n_iter"`
|
||||||
|
Steps int `json:"steps"`
|
||||||
|
CfgScale int `json:"cfg_scale"`
|
||||||
|
//Width int `json:"width"`
|
||||||
|
//Height int `json:"height"`
|
||||||
|
//RestoreFaces bool `json:"restore_faces"`
|
||||||
|
//Tiling bool `json:"tiling"`
|
||||||
|
//DoNotSaveSamples bool `json:"do_not_save_samples"`
|
||||||
|
//DoNotSaveGrid bool `json:"do_not_save_grid"`
|
||||||
|
//NegativePrompt string `json:"negative_prompt"`
|
||||||
|
//Eta int `json:"eta"`
|
||||||
|
//SMinUncond int `json:"s_min_uncond"`
|
||||||
|
//SChurn int `json:"s_churn"`
|
||||||
|
//STmax int `json:"s_tmax"`
|
||||||
|
//STmin int `json:"s_tmin"`
|
||||||
|
//SNoise int `json:"s_noise"`
|
||||||
|
//OverrideSettings map[string]string `json:"override_settings"`
|
||||||
|
//OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"`
|
||||||
|
//ScriptArgs []interface{} `json:"script_args"`
|
||||||
|
//SamplerIndex string `json:"sampler_index"`
|
||||||
|
//ScriptName string `json:"script_name"`
|
||||||
|
//SendImages bool `json:"send_images"`
|
||||||
|
//SaveImages bool `json:"save_images"`
|
||||||
|
//AlwaysonScripts map[string]string `json:"alwayson_scripts"`
|
||||||
|
}{
|
||||||
|
//EnableHr: false,
|
||||||
|
//DenoisingStrength: 0,
|
||||||
|
//FirstphaseWidth: 0,
|
||||||
|
//FirstphaseHeight: 0,
|
||||||
|
//HrScale: 2,
|
||||||
|
//HrUpscaler: "nearest",
|
||||||
|
//HrSecondPassSteps: 0,
|
||||||
|
//HrResizeX: 0,
|
||||||
|
//HrResizeY: 0,
|
||||||
|
//HrSamplerName: "",
|
||||||
|
//HrPrompt: "",
|
||||||
|
//HrNegativePrompt: "",
|
||||||
|
Prompt: image_list[0].Prompt,
|
||||||
|
//Styles: []string{},
|
||||||
|
Seed: image_list[0].Seed,
|
||||||
|
//Subseed: -1,
|
||||||
|
//SubseedStrength: 0,
|
||||||
|
//SeedResizeFromH: -1,
|
||||||
|
//SeedResizeFromW: -1,
|
||||||
|
SamplerName: image_list[0].SamplerName, // 采样器名称
|
||||||
|
//BatchSize: 1,
|
||||||
|
NIter: len(image_list), // 1~100
|
||||||
|
Steps: 50, // 1~150
|
||||||
|
CfgScale: image_list[0].CfgScale,
|
||||||
|
//Width: 512,
|
||||||
|
//Height: 512,
|
||||||
|
//RestoreFaces: false,
|
||||||
|
//Tiling: false,
|
||||||
|
//DoNotSaveSamples: false,
|
||||||
|
//DoNotSaveGrid: false,
|
||||||
|
//NegativePrompt: "",
|
||||||
|
//Eta: 0,
|
||||||
|
//SMinUncond: 0,
|
||||||
|
//SChurn: 0,
|
||||||
|
//STmax: 0,
|
||||||
|
//STmin: 0,
|
||||||
|
//SNoise: 1,
|
||||||
|
//OverrideSettings: map[string]string{},
|
||||||
|
//OverrideSettingsRestoreAfterwards: false,
|
||||||
|
//ScriptArgs: []interface{}{},
|
||||||
|
//SamplerIndex: "Euler",
|
||||||
|
//ScriptName: "generate",
|
||||||
|
//SendImages: true,
|
||||||
|
//SaveImages: false,
|
||||||
|
//AlwaysonScripts: map[string]string{},
|
||||||
|
}
|
||||||
|
fmt.Println("data:", data)
|
||||||
|
**/
|
||||||
|
@@ -1,82 +0,0 @@
|
|||||||
package models
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
type Task struct {
|
|
||||||
ID int `json:"id" gorm:"primary_key"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"` // 任務類型(訓練|推理)
|
|
||||||
Status string `json:"status"` // (initial|ready|waiting|running|success|error)
|
|
||||||
Progress int `json:"progress"` // (0-100)
|
|
||||||
UserID int `json:"user_id"`
|
|
||||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// 异步任务调度管理
|
|
||||||
type AsynchronousTaskSchedulingManagement struct {
|
|
||||||
tasks map[int]Task // 任务队列
|
|
||||||
servers map[string]Server // 服务器队列
|
|
||||||
}
|
|
||||||
|
|
||||||
// 向任务队列添加任务
|
|
||||||
func (m *AsynchronousTaskSchedulingManagement) AddTask(task Task) {
|
|
||||||
// 1. 任务加入队列, 任务状态为 waiting, 每次从最后一个任务开始执行, 并向上全取同模型的任务, 任务状态更新为 waiting(排队), 并更新此模型的排队状态到所有关注此模型的用户(管理员每个连接都关注所有模型)
|
|
||||||
// 2. 任务从队列中取出, 任务队列长度超过12则增加机器, 模型类型持续占用机器则增加机器
|
|
||||||
|
|
||||||
// 加入任务队列
|
|
||||||
m.tasks[task.ID] = task
|
|
||||||
|
|
||||||
// 检查任务队列长度, 持续增长超过12则增加机器
|
|
||||||
if len(m.tasks) > 12 {
|
|
||||||
var server Server
|
|
||||||
m.servers[server.ID] = server
|
|
||||||
}
|
|
||||||
|
|
||||||
// 向目标机器发送模型
|
|
||||||
// 向目标机器切换模型
|
|
||||||
// 向目标机器发送任务
|
|
||||||
}
|
|
||||||
|
|
||||||
//// 推理任務
|
|
||||||
//func startInferenceTask(task *Task) {
|
|
||||||
//
|
|
||||||
// // 獲取一臺可用的 GPU 資源
|
|
||||||
// // ...
|
|
||||||
//
|
|
||||||
// // 執行推理任務
|
|
||||||
// // ...
|
|
||||||
//
|
|
||||||
// // 更新任務狀態
|
|
||||||
// task.Status = "running"
|
|
||||||
// task.Progress = 0
|
|
||||||
// task.Update()
|
|
||||||
//
|
|
||||||
// // 監聽任務狀態
|
|
||||||
// for {
|
|
||||||
// // 延遲 1 秒
|
|
||||||
// time.Sleep(1 * time.Second)
|
|
||||||
//
|
|
||||||
// // 查詢任務狀態
|
|
||||||
// resp, err := http.Get("http://localhost:5000/api/v1/tasks/" + strconv.Itoa(task.ID))
|
|
||||||
// if err != nil {
|
|
||||||
// log.Println(err)
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// defer resp.Body.Close()
|
|
||||||
//
|
|
||||||
// // 解析任務狀態
|
|
||||||
// // ...
|
|
||||||
//
|
|
||||||
// // 更新任務狀態
|
|
||||||
// task.Progress = 100
|
|
||||||
// task.Status = "success"
|
|
||||||
// task.Update()
|
|
||||||
//
|
|
||||||
// // 任務結束判定
|
|
||||||
// if task.Progress == 100 {
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//}
|
|
@@ -51,6 +51,13 @@ var config = struct {
|
|||||||
}{}
|
}{}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
configs.ORMDB().AutoMigrate(&Server{})
|
||||||
|
// 檢查所有服務器的狀態, 無效的服務器設置為異常
|
||||||
|
var servers []Server
|
||||||
|
configs.ORMDB().Find(&servers)
|
||||||
|
for _, server := range servers {
|
||||||
|
server.CheckStatus()
|
||||||
|
}
|
||||||
// 讀取配置文件
|
// 讀取配置文件
|
||||||
absPath, _ := filepath.Abs("./data/config.yaml")
|
absPath, _ := filepath.Abs("./data/config.yaml")
|
||||||
configFile, err := ioutil.ReadFile(absPath)
|
configFile, err := ioutil.ReadFile(absPath)
|
||||||
@@ -60,10 +67,23 @@ func init() {
|
|||||||
if err := yaml.Unmarshal(configFile, &config); err != nil {
|
if err := yaml.Unmarshal(configFile, &config); err != nil {
|
||||||
panic(fmt.Errorf("格式化配置文件失敗: %v", err))
|
panic(fmt.Errorf("格式化配置文件失敗: %v", err))
|
||||||
}
|
}
|
||||||
|
// 初始化检查默认服务器
|
||||||
|
if err := InitDefaultServer(); err != nil {
|
||||||
|
panic(fmt.Errorf("初始化默认服务器失败: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查默认服务器是否存在, 不存在则添加
|
||||||
|
func InitDefaultServer() (err error) {
|
||||||
|
if err = configs.ORMDB().Where("id = ?", "default").First(&Server{}).Error; err != nil {
|
||||||
|
server := Server{ID: "default", IP: "106.15.192.42", Port: 7860, Type: "推理", Status: "就绪"}
|
||||||
|
err = configs.ORMDB().Create(&server).Error
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建一台新服务器
|
// 创建一台新服务器
|
||||||
func NewServer(server_type string) (server *Server, err error) {
|
func NewServer(server_type string) (server Server, err error) {
|
||||||
// 调用 API 创建一台新服务器(通過腾讯云API創建服務器)
|
// 调用 API 创建一台新服务器(通過腾讯云API創建服務器)
|
||||||
client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile())
|
client, err := cvm.NewClient(common.NewCredential(config.TencentCloud.SecretId, config.TencentCloud.SecretKey), config.TencentCloud.Region, profile.NewClientProfile())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -84,7 +104,7 @@ func NewServer(server_type string) (server *Server, err error) {
|
|||||||
fmt.Println("創建服務器成功:", response.Response.InstanceIdSet[0])
|
fmt.Println("創建服務器成功:", response.Response.InstanceIdSet[0])
|
||||||
|
|
||||||
// 获取服务器信息
|
// 获取服务器信息
|
||||||
var get_server_info = func(InstanceIdSet *string) (server *Server, err error) {
|
var get_server_info = func(InstanceIdSet *string) (server Server, err error) {
|
||||||
response2, err := client.DescribeInstances(cvm.NewDescribeInstancesRequest())
|
response2, err := client.DescribeInstances(cvm.NewDescribeInstancesRequest())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return server, fmt.Errorf("獲取實例詳情失敗: %v", err)
|
return server, fmt.Errorf("獲取實例詳情失敗: %v", err)
|
||||||
@@ -169,13 +189,3 @@ func (server *Server) CheckStatus() error {
|
|||||||
// 檢查服務器是否正常
|
// 檢查服務器是否正常
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
configs.ORMDB().AutoMigrate(&Server{})
|
|
||||||
// 檢查所有服務器的狀態, 無效的服務器設置為異常
|
|
||||||
var servers []Server
|
|
||||||
configs.ORMDB().Find(&servers)
|
|
||||||
for _, server := range servers {
|
|
||||||
server.CheckStatus()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
106
routers/tasks.go
106
routers/tasks.go
@@ -1,106 +0,0 @@
|
|||||||
package routers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"main/configs"
|
|
||||||
"main/models"
|
|
||||||
"main/utils"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TasksGet(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var listview models.ListView
|
|
||||||
listview.Page = utils.ParamInt(r.URL.Query().Get("page"), 1)
|
|
||||||
listview.PageSize = utils.ParamInt(r.URL.Query().Get("pageSize"), 10)
|
|
||||||
var task_list []models.Task
|
|
||||||
db := configs.ORMDB()
|
|
||||||
db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&task_list)
|
|
||||||
for _, task := range task_list {
|
|
||||||
listview.List = append(listview.List, task)
|
|
||||||
}
|
|
||||||
db.Model(&models.Task{}).Count(&listview.Total)
|
|
||||||
listview.Next = listview.Page*listview.PageSize < int(listview.Total)
|
|
||||||
listview.WriteJSON(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TasksPost(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var task models.Task
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
if err = json.Unmarshal(body, &task); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
configs.ORMDB().Create(&task)
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
w.Write(utils.ToJSON(task))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TasksItemGet(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Header.Get("Upgrade") == "websocket" {
|
|
||||||
vars := mux.Vars(r)
|
|
||||||
id, _ := strconv.Atoi(vars["id"])
|
|
||||||
|
|
||||||
var task models.Task = models.Task{ID: id}
|
|
||||||
if err := configs.ORMDB().First(&task, id).Error; err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
upgrader := websocket.Upgrader{}
|
|
||||||
ws, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer ws.Close()
|
|
||||||
for {
|
|
||||||
_, message, err := ws.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
task.Status = string(message)
|
|
||||||
configs.ORMDB().Model(&task).Update("status", task.Status)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
||||||
configs.ORMDB().First(&task)
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
w.Write(utils.ToJSON(task))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TasksItemPatch(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var task models.Task
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
if err = json.Unmarshal(body, &task); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
task.ID = utils.ParamInt(mux.Vars(r)["id"], 0)
|
|
||||||
configs.ORMDB().Model(&task).Updates(task)
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
w.Write(utils.ToJSON(task))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TasksItemDelete(w http.ResponseWriter, r *http.Request) {
|
|
||||||
task := models.Task{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
|
|
||||||
configs.ORMDB().Delete(&task)
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
}
|
|
Reference in New Issue
Block a user