344 lines
10 KiB
Go
344 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"time"
|
|
|
|
"regexp"
|
|
"strconv"
|
|
|
|
"encoding/json"
|
|
|
|
"git.satori.love/gameui/webp/api"
|
|
"git.satori.love/gameui/webp/models"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/graphql-go/graphql"
|
|
"github.com/graphql-go/handler"
|
|
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
|
"github.com/spf13/viper"
|
|
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
)
|
|
|
|
// stringToInt parses str as a base-10 int. It returns defaultValue when str is
// empty or is not a valid (in-range) integer.
func stringToInt(str string, defaultValue int) int {
	// strconv.Atoi rejects the empty string, so one error check covers both
	// the "empty" and the "unparseable" cases.
	if parsed, err := strconv.Atoi(str); err == nil {
		return parsed
	}
	return defaultValue
}
|
|
|
|
// LogComponent writes one access-log line for r. The line holds the HTTP
// method, the request URL, the time elapsed since startTime (a
// time.Now().UnixNano() value) in milliseconds, and the X-Forwarded-For header.
// Method, URL and duration are wrapped in ANSI bold colour codes. The duration's
// colour depends on how long the request took.
func LogComponent(startTime int64, r *http.Request) {
	elapsedMs := (time.Now().UnixNano() - startTime) / 1000000

	var durationFormat string
	switch {
	case elapsedMs > 800:
		durationFormat = "\033[1;31m%dms\033[0m" // bold red
	case elapsedMs > 500:
		durationFormat = "\033[1;33m%dms\033[0m" // bold yellow
	case elapsedMs > 300:
		durationFormat = "\033[1;32m%dms\033[0m" // bold green
	case elapsedMs > 200:
		durationFormat = "\033[1;34m%dms\033[0m" // bold blue
	case elapsedMs > 100:
		durationFormat = "\033[1;35m%dms\033[0m" // bold magenta
	default:
		durationFormat = "\033[1;36m%dms\033[0m" // bold cyan
	}

	log.Println(
		fmt.Sprintf("\033[1;32m%s\033[0m", r.Method), // bold green
		fmt.Sprintf("\033[1;34m%s\033[0m", r.URL),    // bold blue
		fmt.Sprintf(durationFormat, elapsedMs),
		r.Header.Get("X-Forwarded-For"),
	)
}
|
|
|
|
func LogRequest(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer LogComponent(time.Now().UnixNano(), r) // 最后打印日志
|
|
|
|
var user_id int
|
|
if token := r.Header.Get("token"); token != "" {
|
|
user_id = api.ParseToken(token)
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user_id", user_id)))
|
|
})
|
|
}
|
|
|
|
type Image struct {
|
|
Id int `json:"id" db:"id"`
|
|
Width int `json:"width" db:"width"`
|
|
Height int `json:"height" db:"height"`
|
|
Content string `json:"content" db:"content"`
|
|
ArticleCategoryTopId int `json:"article_category_top_id" db:"article_category_top_id"`
|
|
PraiseCount int `json:"praise_count" db:"praise_count"`
|
|
CollectCount int `json:"collect_count" db:"collect_count"`
|
|
CreateTime time.Time `json:"createTime" db:"createTime"`
|
|
UpdateTime time.Time `json:"updateTime" db:"updateTime"`
|
|
UserID int `json:"user_id" db:"user_id"`
|
|
User models.User `json:"user" db:"user"`
|
|
Article models.Article `json:"article" db:"article"`
|
|
}
|
|
|
|
// Tag is a named label with creation and last-update timestamps. It is
// serialised with snake_case JSON keys.
type Tag struct {
	Id         int       `json:"id"`
	Name       string    `json:"name"`
	CreateTime time.Time `json:"create_time"`
	UpdateTime time.Time `json:"update_time"`
}
|
|
|
|
// History pairs a type label and a creation time with an arbitrary JSON
// payload. The concrete type of Data presumably depends on Type; confirm
// against the code that builds these values.
type History struct {
	Type       string    `json:"type"`
	CreateTime time.Time `json:"create_time"`
	Data       any       `json:"data"`
}
|
|
|
|
// ListView is a JSON envelope for one page of a list response. It holds a
// status code, the page number and page size, the total item count, a Next flag
// (presumably "another page exists") and the items on this page.
type ListView struct {
	Code     int   `json:"code"`
	Page     int   `json:"page"`
	PageSize int   `json:"pageSize"`
	Total    int   `json:"total"`
	Next     bool  `json:"next"`
	List     []any `json:"list"`
}
|
|
|
|
var mysqlConnection models.MysqlConnection
|
|
var milvusConnection models.MilvusConnection
|
|
|
|
func GetNetWorkEmbedding(id int) (embedding []float32) {
|
|
host := viper.GetString("embedding.host")
|
|
port := viper.GetInt("embedding.port")
|
|
httpClient := &http.Client{}
|
|
request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/api/default/%d", host, port, id), nil)
|
|
if err != nil {
|
|
log.Println("请求失败1:", err)
|
|
return
|
|
}
|
|
response, err := httpClient.Do(request)
|
|
if err != nil {
|
|
log.Println("请求失败2:", err)
|
|
return
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
var result struct {
|
|
Code int `json:"code"`
|
|
Message string `json:"message"`
|
|
Feature []float32 `json:"feature"`
|
|
}
|
|
err = json.NewDecoder(response.Body).Decode(&result)
|
|
if err != nil {
|
|
log.Println("解析失败:", err)
|
|
return
|
|
}
|
|
if result.Code != 0 {
|
|
log.Println("请求失败3:", result.Message)
|
|
return
|
|
}
|
|
return result.Feature
|
|
}
|
|
|
|
var lruCache, _ = lru.New[int, []int64](100000)
|
|
|
|
func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64) {
|
|
ctx := context.Background()
|
|
|
|
// 先从 LRU 中查询缓存的结果, 如果缓存中有, 直接返回
|
|
if value, ok := lruCache.Get(image.Id); ok {
|
|
return value
|
|
}
|
|
|
|
// 先从milvus中查询图片的向量
|
|
var embedding []float32
|
|
result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", image.Id), []string{"embedding"})
|
|
if err != nil {
|
|
log.Println("查詢向量失敗:", err)
|
|
embedding = GetNetWorkEmbedding(image.Id)
|
|
} else {
|
|
for _, item := range result {
|
|
if item.Name() == "embedding" {
|
|
embedding = item.FieldData().GetVectors().GetFloatVector().Data
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// 处理向量不存在的情况
|
|
if len(embedding) == 0 {
|
|
log.Println("向量不存在, 也未能重新生成")
|
|
return ids
|
|
}
|
|
|
|
// 用向量查询相似图片
|
|
topk := 200
|
|
sp, _ := entity.NewIndexIvfFlatSearchParam(64)
|
|
vectors := []entity.Vector{entity.FloatVector(embedding)}
|
|
resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
|
|
if err != nil {
|
|
log.Println("搜索相似失敗:", err)
|
|
return
|
|
}
|
|
|
|
// 输出结果
|
|
for _, item := range resultx {
|
|
ids = item.IDs.FieldData().GetScalars().GetLongData().GetData()
|
|
}
|
|
|
|
// 将结果缓存到 LRU 中
|
|
lruCache.Add(image.Id, ids)
|
|
|
|
return ids
|
|
}
|
|
|
|
func main() {
|
|
runtime.GOMAXPROCS(runtime.NumCPU() - 1)
|
|
|
|
configFilePath := flag.String("config", "./data/config.yaml", "配置文件路径")
|
|
flag.Parse()
|
|
|
|
viper.SetConfigFile(*configFilePath)
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
log.Println("读取配置文件失败", err)
|
|
}
|
|
config := viper.GetViper()
|
|
|
|
models.InitConfig(config)
|
|
models.ZincInit()
|
|
api.InitDefault(config)
|
|
|
|
mysqlConnection.Init()
|
|
milvusConnection.Init()
|
|
err := milvusConnection.Client.LoadCollection(context.Background(), "default", false)
|
|
if err != nil {
|
|
log.Println("Milvus load collection failed:", err)
|
|
return
|
|
}
|
|
|
|
if config.GetBool("oss.local") {
|
|
fmt.Println("开启图像色调计算")
|
|
go api.CheckColorNullRows(0)
|
|
}
|
|
if config.GetBool("gorse.open") {
|
|
fmt.Println("开启用户偏好收集")
|
|
api.GorseInit(config.GetString("gorse.host"), config.GetInt("gorse.port"))
|
|
}
|
|
|
|
schema, err := graphql.NewSchema(graphql.SchemaConfig{Query: graphql.NewObject(graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{
|
|
"users": api.UserItems,
|
|
"games": api.GameItems,
|
|
"works": api.WorkItems,
|
|
"collections": api.CollectionItems,
|
|
"articles": api.ArticleItems,
|
|
"article": api.ArticleItem,
|
|
"images": api.ImageItems,
|
|
"image": api.ImageItem,
|
|
"searchs": api.SearchItems,
|
|
}})})
|
|
|
|
if err != nil {
|
|
log.Fatalf("failed to create new schema, error: %v", err)
|
|
}
|
|
|
|
http.Handle("/api", LogRequest(handler.New(&handler.Config{
|
|
Schema: &schema,
|
|
Playground: true,
|
|
Pretty: false,
|
|
})))
|
|
|
|
// URL 格式: /image/{type}-{id}-{width}x{height}-{fit}.{format}?version
|
|
http.HandleFunc("/image", func(w http.ResponseWriter, r *http.Request) {
|
|
defer LogComponent(time.Now().UnixNano(), r) // 最后打印日志
|
|
|
|
// 如果本地文件存在,直接输出
|
|
filePath := filepath.Join("data/webp", r.URL.Path)
|
|
if _, err := os.Stat(filePath); err == nil {
|
|
http.ServeFile(w, r, filePath)
|
|
return
|
|
}
|
|
|
|
reg := regexp.MustCompile(`^/image/([a-z]+)-([0-9]+)-([0-9]+)x([0-9]+)-([a-z]+).(jpg|jpeg|png|webp)$`)
|
|
matches := reg.FindStringSubmatch(r.URL.Path)
|
|
if len(matches) != 7 {
|
|
log.Println("URL 格式错误", matches)
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
group, id, width, height, fit, format := matches[1], matches[2], stringToInt(matches[3], 0), stringToInt(matches[4], 0), matches[5], matches[6]
|
|
content, err := mysqlConnection.GetImageContent(group, id)
|
|
if err != nil {
|
|
log.Println("获取图片失败", format, err)
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
var img models.Image
|
|
if err := img.Init(content); err != nil {
|
|
log.Println("初始化图片失败", format, err)
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
data, err := img.ToWebP(width, height, fit)
|
|
if err != nil {
|
|
log.Println("转换图片失败", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
err = os.MkdirAll(filepath.Dir(filePath), os.ModePerm)
|
|
if err != nil {
|
|
log.Println("创建文件目录失败:", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
err = os.WriteFile(filePath, data, 0644)
|
|
if err != nil {
|
|
log.Println("保存文件失败:", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "image/webp")
|
|
w.Header().Set("Cache-Control", "max-age=604800")
|
|
w.Write(data)
|
|
})
|
|
|
|
// 获取转换后的m3u8视频链接
|
|
http.HandleFunc("/video", func(w http.ResponseWriter, r *http.Request) {
|
|
defer LogComponent(time.Now().UnixNano(), r) // 最后打印日志
|
|
|
|
queryParam := r.URL.Query().Get("url")
|
|
safeParam, err := url.QueryUnescape(queryParam)
|
|
if err != nil {
|
|
log.Println("解码URL失败", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
fmt.Println("safeParam", safeParam)
|
|
urls, err := models.GetVideoM3U8(safeParam)
|
|
fmt.Println("urls", urls, err)
|
|
if err != nil {
|
|
log.Println("获取视频链接失败", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 将对象转换为有缩进的JSON输出
|
|
json, _ := json.MarshalIndent(urls, "", " ")
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write(json)
|
|
})
|
|
|
|
// 从Viper中读取配置
|
|
port := viper.GetString("server.port")
|
|
log.Println("Server is running at http://localhost:" + port)
|
|
http.ListenAndServe(":"+port, nil)
|
|
}
|