連接py
This commit is contained in:
		
							
								
								
									
										56
									
								
								bin/main.go
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								bin/main.go
									
									
									
									
									
								
							@@ -90,24 +90,58 @@ type ListView struct {
 | 
			
		||||
var mysqlConnection models.MysqlConnection
 | 
			
		||||
var milvusConnection models.MilvusConnection
 | 
			
		||||
 | 
			
		||||
func GetNetWorkEmbedding(id int) (embedding []float32) {
 | 
			
		||||
	host := viper.GetString("embedding.host")
 | 
			
		||||
	port := viper.GetInt("embedding.port")
 | 
			
		||||
	httpClient := &http.Client{}
 | 
			
		||||
	request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/reverse/%d", host, port, id), nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("请求失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	response, err := httpClient.Do(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("请求失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer response.Body.Close()
 | 
			
		||||
	var result map[string]interface{}
 | 
			
		||||
	err = json.NewDecoder(response.Body).Decode(&result)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("解析失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if result["code"] != 0 {
 | 
			
		||||
		log.Println("请求失败:", result["message"])
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	embedding = result["feature"].([]float32)
 | 
			
		||||
	return embedding
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
 | 
			
		||||
	// 先从milvus中查询图片的向量
 | 
			
		||||
	var embedding []float32
 | 
			
		||||
	result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", image.Id), []string{"embedding"})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("Milvus query failed:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var embedding []float32
 | 
			
		||||
	for _, item := range result {
 | 
			
		||||
		if item.Name() == "embedding" {
 | 
			
		||||
			embedding = item.FieldData().GetVectors().GetFloatVector().Data
 | 
			
		||||
			continue
 | 
			
		||||
		log.Println("查詢向量失敗:", err)
 | 
			
		||||
		embedding = GetNetWorkEmbedding(image.Id)
 | 
			
		||||
	} else {
 | 
			
		||||
		for _, item := range result {
 | 
			
		||||
			if item.Name() == "embedding" {
 | 
			
		||||
				embedding = item.FieldData().GetVectors().GetFloatVector().Data
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// TODO: 处理向量不存在的情况(生成)
 | 
			
		||||
	// TODO: 处理图片不存在的情况(404)
 | 
			
		||||
 | 
			
		||||
	// 处理向量不存在的情况
 | 
			
		||||
	if len(embedding) == 0 {
 | 
			
		||||
		log.Println("向量不存在, 也未能重新生成")
 | 
			
		||||
		return ids
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 用向量查询相似图片
 | 
			
		||||
	topk := 1000
 | 
			
		||||
@@ -115,7 +149,7 @@ func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64)
 | 
			
		||||
	vectors := []entity.Vector{entity.FloatVector(embedding)}
 | 
			
		||||
	resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("Milvus search failed:", err)
 | 
			
		||||
		log.Println("搜索相似失敗:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										73
									
								
								bin/resnet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								bin/resnet.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,73 @@
 | 
			
		||||
import json
 | 
			
		||||
import uuid
 | 
			
		||||
import towhee
 | 
			
		||||
 | 
			
		||||
from http.server import BaseHTTPRequestHandler, HTTPServer
 | 
			
		||||
 | 
			
		||||
img_path = './data/test.jpeg'
 | 
			
		||||
feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
 | 
			
		||||
print(feat[0])
 | 
			
		||||
 | 
			
		||||
# 定義一個服務器類
 | 
			
		||||
class ResNetServer(BaseHTTPRequestHandler):
 | 
			
		||||
    
 | 
			
		||||
        # 定義一個初始化方法
 | 
			
		||||
        def __init__(self, *args, **kwargs):
 | 
			
		||||
            super().__init__(*args, **kwargs)
 | 
			
		||||
    
 | 
			
		||||
        # 定義一個處理 GET 請求的方法
 | 
			
		||||
        def do_GET(self):
 | 
			
		||||
            self.send_response(200)                       # 設置服務器響應的狀態碼
 | 
			
		||||
            self.send_header('Content-type', 'text/html') # 設置服務器響應的標頭
 | 
			
		||||
            self.end_headers()                            # 完成服務器響應的標頭
 | 
			
		||||
            # 設置服務器響應的內容
 | 
			
		||||
            self.wfile.write(bytes('''
 | 
			
		||||
                <html>
 | 
			
		||||
                <head>
 | 
			
		||||
                <meta charset="utf-8">
 | 
			
		||||
                <title>ResNet Server</title>
 | 
			
		||||
                </head>
 | 
			
		||||
                <body>
 | 
			
		||||
                <h1>ResNet Server</h1>
 | 
			
		||||
                // 這裡是一個簡單的服務器,可以接收圖像文件,並將圖像文件轉換為ResNet50的特徵向量
 | 
			
		||||
                <form action="/" method="post" enctype="multipart/form-data">
 | 
			
		||||
                <input type="file" name="file" />
 | 
			
		||||
                <input type="submit" value="Upload" />
 | 
			
		||||
                </form>
 | 
			
		||||
                <p>Upload an image to get the prediction result.</p>
 | 
			
		||||
                // 發送JSON格式的表單
 | 
			
		||||
                </body>
 | 
			
		||||
                </html>
 | 
			
		||||
            ''', 'utf-8'))
 | 
			
		||||
    
 | 
			
		||||
        ## 定義一個處理 POST 請求的方法
 | 
			
		||||
        #def do_POST(self):
 | 
			
		||||
        #    self.send_response(200)                              # 設置服務器響應的狀態碼
 | 
			
		||||
        #    self.send_header('Content-type', 'application/json') # 設置服務器響應的標頭
 | 
			
		||||
        #    self.end_headers()                                   # 完成服務器響應的標頭
 | 
			
		||||
        #    content_length = int(self.headers['Content-Length']) # 獲取請求的內容長度
 | 
			
		||||
        #    body = self.rfile.read(content_length)               # 獲取請求的內容
 | 
			
		||||
        #    params = json.loads(body)                            # 將請求的內容解析為字典
 | 
			
		||||
        #    img_path = params['img_path']                        # 獲取圖像的路徑
 | 
			
		||||
        #    feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
 | 
			
		||||
        #    results = json.dumps(feat[0].tolist())               # 將結果轉換為JSON格式
 | 
			
		||||
        #    self.wfile.write(bytes(results, 'utf-8'))
 | 
			
		||||
 | 
			
		||||
        # 定義一個處理 POST 請求的方法, 接收二進制圖像文件
 | 
			
		||||
        def do_POST(self):
 | 
			
		||||
            self.send_response(200)                              # 設置服務器響應的狀態碼
 | 
			
		||||
            self.send_header('Content-type', 'application/json') # 設置服務器響應的標頭
 | 
			
		||||
            self.end_headers()                                   # 完成服務器響應的標頭
 | 
			
		||||
            content_length = int(self.headers['Content-Length']) # 獲取請求的內容長度
 | 
			
		||||
            body = self.rfile.read(content_length)               # 獲取請求的內容
 | 
			
		||||
            #img_path = './data/' + str(uuid.uuid4())             # 生成一個隨機的圖像文件名
 | 
			
		||||
            #with open(img_path, 'wb') as f:                      # 將二進制圖像文件解碼為圖像數據保存在本地
 | 
			
		||||
            #    f.write(body)
 | 
			
		||||
            feat = towhee.blob(body).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
 | 
			
		||||
            results = json.dumps(feat[0].tolist())               # 將結果轉換為JSON格式
 | 
			
		||||
            self.wfile.write(bytes(results, 'utf-8'))            # 將結果發送給客戶端
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 判斷是否是直接執行該文件
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    HTTPServer(('0.0.0.0', 8000), ResNetServer).serve_forever()
 | 
			
		||||
		Reference in New Issue
	
	Block a user