diff --git a/.gitignore b/.gitignore index c3340af..ea6b25d 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ go.work data/ +venv/ .env .env.example main diff --git a/bin/main.go b/bin/main.go index 92c26a8..acb536a 100644 --- a/bin/main.go +++ b/bin/main.go @@ -90,24 +90,58 @@ type ListView struct { var mysqlConnection models.MysqlConnection var milvusConnection models.MilvusConnection +func GetNetWorkEmbedding(id int) (embedding []float32) { + host := viper.GetString("embedding.host") + port := viper.GetInt("embedding.port") + httpClient := &http.Client{} + request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/reverse/%d", host, port, id), nil) + if err != nil { + log.Println("请求失败:", err) + return + } + response, err := httpClient.Do(request) + if err != nil { + log.Println("请求失败:", err) + return + } + defer response.Body.Close() + var result map[string]interface{} + err = json.NewDecoder(response.Body).Decode(&result) + if err != nil { + log.Println("解析失败:", err) + return + } + if result["code"] != 0 { + log.Println("请求失败:", result["message"]) + return + } + embedding = result["feature"].([]float32) + return embedding +} + func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64) { ctx := context.Background() // 先从milvus中查询图片的向量 + var embedding []float32 result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", image.Id), []string{"embedding"}) if err != nil { - log.Println("Milvus query failed:", err) - return - } - var embedding []float32 - for _, item := range result { - if item.Name() == "embedding" { - embedding = item.FieldData().GetVectors().GetFloatVector().Data - continue + log.Println("查詢向量失敗:", err) + embedding = GetNetWorkEmbedding(image.Id) + } else { + for _, item := range result { + if item.Name() == "embedding" { + embedding = item.FieldData().GetVectors().GetFloatVector().Data + continue + } } } - // TODO: 处理向量不存在的情况(生成) - // TODO: 处理图片不存在的情况(404) + + // 处理向量不存在的情况 + if len(embedding) == 0 { + log.Println("向量不存在, 也未能重新生成") + return ids + } // 用向量查询相似图片 topk := 1000 @@ -115,7 +149,7 @@ func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64) vectors := []entity.Vector{entity.FloatVector(embedding)} resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp) if err != nil { - log.Println("Milvus search failed:", err) + log.Println("搜索相似失敗:", err) return } diff --git a/bin/resnet.py b/bin/resnet.py new file mode 100644 index 0000000..29a0030 --- /dev/null +++ b/bin/resnet.py @@ -0,0 +1,73 @@ +import json +import uuid +import towhee + +from http.server import BaseHTTPRequestHandler, HTTPServer + +img_path = './data/test.jpeg' +feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list() +print(feat[0]) + +# 定義一個服務器類 +class ResNetServer(BaseHTTPRequestHandler): + + # 定義一個初始化方法 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # 定義一個處理 GET 請求的方法 + def do_GET(self): + self.send_response(200) # 設置服務器響應的狀態碼 + self.send_header('Content-type', 'text/html') # 設置服務器響應的標頭 + self.end_headers() # 完成服務器響應的標頭 + # 設置服務器響應的內容 + self.wfile.write(bytes(''' + +
+ +Upload an image to get the prediction result.
+ // 發送JSON格式的表單 + + + ''', 'utf-8')) + + ## 定義一個處理 POST 請求的方法 + #def do_POST(self): + # self.send_response(200) # 設置服務器響應的狀態碼 + # self.send_header('Content-type', 'application/json') # 設置服務器響應的標頭 + # self.end_headers() # 完成服務器響應的標頭 + # content_length = int(self.headers['Content-Length']) # 獲取請求的內容長度 + # body = self.rfile.read(content_length) # 獲取請求的內容 + # params = json.loads(body) # 將請求的內容解析為字典 + # img_path = params['img_path'] # 獲取圖像的路徑 + # feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list() + # results = json.dumps(feat[0].tolist()) # 將結果轉換為JSON格式 + # self.wfile.write(bytes(results, 'utf-8')) + + # 定義一個處理 POST 請求的方法, 接收二進制圖像文件 + def do_POST(self): + self.send_response(200) # 設置服務器響應的狀態碼 + self.send_header('Content-type', 'application/json') # 設置服務器響應的標頭 + self.end_headers() # 完成服務器響應的標頭 + content_length = int(self.headers['Content-Length']) # 獲取請求的內容長度 + body = self.rfile.read(content_length) # 獲取請求的內容 + #img_path = './data/' + str(uuid.uuid4()) # 生成一個隨機的圖像文件名 + #with open(img_path, 'wb') as f: # 將二進制圖像文件解碼為圖像數據保存在本地 + # f.write(body) + feat = towhee.blob(body).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list() + results = json.dumps(feat[0].tolist()) # 將結果轉換為JSON格式 + self.wfile.write(bytes(results, 'utf-8')) # 將結果發送給客戶端 + + +# 判斷是否是直接執行該文件 +if __name__ == '__main__': + HTTPServer(('0.0.0.0', 8000), ResNetServer).serve_forever() diff --git a/go.mod b/go.mod index 02c8404..4b4ebb7 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/aliyun/aliyun-oss-go-sdk v2.2.7+incompatible github.com/chai2010/webp v1.1.1 github.com/go-sql-driver/mysql v1.7.0 + github.com/sbinet/go-python v0.1.0 github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect ) diff --git a/go.sum b/go.sum index 47077f1..21bce40 100644 --- a/go.sum +++ b/go.sum @@ -171,6 +171,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/sbinet/go-python v0.1.0 h1:WlS8dGoxKMt9/c54U4XQuVhQt79p0uJUdzopuDR4QaI= +github.com/sbinet/go-python v0.1.0/go.mod h1:Pq31TCdgxj39xSYY/VAfsWWrFphYSmx3jmPHtotzQNY= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 h1:dCWBD4Xchp/XFIR/x6D2l74DtQHvIpHsmpPRHgH9oUo= github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591/go.mod h1:IXC7KN2FEuTEISdePm37qcFyXInAh6pfW35yDjbdfOM=