diff --git a/models/Image.go b/models/Image.go index 3b4c075..fca58ce 100644 --- a/models/Image.go +++ b/models/Image.go @@ -16,30 +16,30 @@ import ( ) type Image struct { - ID int `json:"id" gorm:"primary_key"` // ID - Name string `json:"name"` // 名称 - Hash string `json:"hash"` // 哈希值 - Path string `json:"path"` // 路径 - Type string `json:"type"` // 类型 - Size int `json:"size"` // 大小 - Width int `json:"width"` // 宽度 - Height int `json:"height"` // 高度 - Format string `json:"format"` // 格式 - Prompt string `json:"prompt"` // 提示词 - NegativePrompt string `json:"negative_prompt"` // 负向提示 - Steps int `json:"steps"` // 迭代步数 (Steps 1~150) - CfgScale int `json:"cfg_scale"` // 引导比例(minimum: 1; maximum: 20) - SamplerName string `json:"sampler_name"` // 采样器名称 - Seed int `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647) - FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID) - Task string `json:"task"` // 任务编号(uuid) - Status string `json:"status"` // 任务状态(queued|running|finished|failed) - Progress int `json:"progress"` // 任务进度(0-100) - Public bool `json:"public"` // 是否公开 - UserID int `json:"user_id"` // 用户ID - ModelID int `json:"model_id"` // 模型ID - Preview string `json:"preview" gorm:"-"` // 实时预览 base64 - User User `json:"user" gorm:"-"` // 用户 + ID int `json:"id" gorm:"primary_key"` // ID + Name string `json:"name"` // 名称 + Hash string `json:"hash"` // 哈希值 + Path string `json:"path"` // 路径 + Type string `json:"type"` // 类型 + Size int `json:"size"` // 大小 + Width int `json:"width"` // 宽度 + Height int `json:"height"` // 高度 + Format string `json:"format"` // 格式 + Prompt string `json:"prompt"` // 提示词 + NegativePrompt string `json:"negative_prompt"` // 负向提示 + Steps int `json:"steps"` // 迭代步数 (Steps 1~150) + CfgScale int `json:"cfg_scale"` // 引导比例(minimum: 1; maximum: 20) + SamplerName string `json:"sampler_name"` // 采样器名称 + Seed int `json:"seed"` // 随机种子(minimum: 0; maximum: 2147483647) + FromImage int `json:"from_image"` // 来源图片(如果是从图片生成的, 则记录来源图片的ID) + Task string `json:"task"` // 任务编号(uuid) + Status string `json:"status"` // 任务状态(queued|running|finished|failed) + Progress int `json:"progress"` // 任务进度(0-100) + Public bool `json:"public"` // 是否公开 + UserID int `json:"user_id"` // 用户ID + ModelID int `json:"model_id"` // 模型ID + Preview string `json:"preview" gorm:"-"` // 实时预览 base64 或 url + User *User `json:"user" gorm:"foreignKey:UserID;"` // 用户 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/models/Model.go b/models/Model.go index e7d0e84..a77edda 100644 --- a/models/Model.go +++ b/models/Model.go @@ -22,26 +22,26 @@ import ( ) type Model struct { - ID int `json:"id" gorm:"primary_key"` // 模型ID - Name string `json:"name"` // 模型名稱 - ModelCheckpoint string `json:"model_checkpoint"` // 模型檢查點 - Info string `json:"info"` // 模型描述 - Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti) - TriggerWords string `json:"trigger_words"` // 觸發詞 - BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2) - ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑) - Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error|public) - Progress int `json:"progress"` // (0-100) - Preview string `json:"preview"` // 模型預覽圖片 - Hash string `json:"hash"` // 模型哈希值(sha256) - Epochs int `json:"epochs"` // 訓練步數 - LearningRate float32 `json:"learning_rate"` // 學習率(0.000005) - Tags TagList `json:"tags"` // 模型標籤(標籤名數組) - UserID int `json:"user_id"` // 模型的所有者 - DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID - ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機) - Stars StarList `json:"stars"` // 模型的收藏者 - User User `json:"user" gorm:"-"` // 模型的所有者 + ID int `json:"id" gorm:"primary_key"` // 模型ID + Name string `json:"name"` // 模型名稱 + ModelCheckpoint string `json:"model_checkpoint"` // 模型檢查點 + Info string `json:"info"` // 模型描述 + Type string `json:"type"` // 模型類型(lora|ckp|hyper|ti) + TriggerWords string `json:"trigger_words"` // 觸發詞 + BaseModel string `json:"base_model"` // 基礎模型(SD1.5|SD2) + ModelPath string `json:"model_path"` // 模型路徑(實際存放在服務器上的路徑) + Status string `json:"status" default:"initial"` // (initial|ready|waiting|running|success|error|public) + Progress int `json:"progress"` // (0-100) + Preview string `json:"preview"` // 模型預覽圖片 + Hash string `json:"hash"` // 模型哈希值(sha256) + Epochs int `json:"epochs"` // 訓練步數 + LearningRate float32 `json:"learning_rate"` // 學習率(0.000005) + Tags TagList `json:"tags"` // 模型標籤(標籤名數組) + UserID int `json:"user_id"` // 模型的所有者 + DatasetID int `json:"dataset_id"` // 模型所使用的數據集ID + ServerID string `json:"server_id"` // 模型所在服務器(訓練機或推理機) + Stars StarList `json:"stars"` // 模型的收藏者 + User *User `json:"user" gorm:"foreignKey:UserID;"` // 模型的所有者 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/routers/images.go b/routers/images.go index 183458b..d467ce2 100644 --- a/routers/images.go +++ b/routers/images.go @@ -116,10 +116,8 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) { db = db.Where("id IN (?)", list) } - db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&image_list).Count(&listview.Total) + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Preload("User").Find(&image_list).Count(&listview.Total) for _, image := range image_list { - image.User = models.User{ID: image.UserID} - db.First(&image.User) if image.Preview == "" { image.Preview = "https://image.gameuiux.cn/2023/06/27/1687851028u=3116699095,2862677591&fm=253&fmt=auto&app=120&f=JPEG.webp" } diff --git a/routers/models.go b/routers/models.go index f2c137e..ee7f677 100644 --- a/routers/models.go +++ b/routers/models.go @@ -142,11 +142,8 @@ func ModelsGet(w http.ResponseWriter, r *http.Request) { db = db.Where("id IN (?)", list) } - db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Find(&model_list).Count(&listview.Total) - for _, model := range model_list { - model.User = models.User{ID: model.UserID} - db.Take(&model.User) - } + db.Offset((listview.Page - 1) * listview.PageSize).Limit(listview.PageSize).Preload("User").Find(&model_list).Count(&listview.Total) + listview.List = model_list listview.Next = listview.Page*listview.PageSize < int(listview.Total) listview.WriteJSON(w) @@ -376,6 +373,11 @@ func ModelItemPatch(w http.ResponseWriter, r *http.Request) { model.Tags = model_new.Tags } + // TODO: 只允许管理员更新模型 + if model_new.UserID != 0 && model_new.UserID != model.UserID { + model.UserID = model_new.UserID + } + // 執行更新 if err := configs.ORMDB().Save(&model).Error; err != nil { log.Println(err)