連接py
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -22,6 +22,7 @@
|
||||
go.work
|
||||
|
||||
data/
|
||||
venv/
|
||||
.env
|
||||
.env.example
|
||||
main
|
||||
|
48
bin/main.go
48
bin/main.go
@@ -90,24 +90,58 @@ type ListView struct {
|
||||
var mysqlConnection models.MysqlConnection
|
||||
var milvusConnection models.MilvusConnection
|
||||
|
||||
func GetNetWorkEmbedding(id int) (embedding []float32) {
|
||||
host := viper.GetString("embedding.host")
|
||||
port := viper.GetInt("embedding.port")
|
||||
httpClient := &http.Client{}
|
||||
request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s:%d/reverse/%d", host, port, id), nil)
|
||||
if err != nil {
|
||||
log.Println("请求失败:", err)
|
||||
return
|
||||
}
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
log.Println("请求失败:", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(response.Body).Decode(&result)
|
||||
if err != nil {
|
||||
log.Println("解析失败:", err)
|
||||
return
|
||||
}
|
||||
if result["code"] != 0 {
|
||||
log.Println("请求失败:", result["message"])
|
||||
return
|
||||
}
|
||||
embedding = result["feature"].([]float32)
|
||||
return embedding
|
||||
}
|
||||
|
||||
func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 先从milvus中查询图片的向量
|
||||
var embedding []float32
|
||||
result, err := milvusConnection.Client.Query(ctx, collection_name, nil, fmt.Sprintf("id in [%d]", image.Id), []string{"embedding"})
|
||||
if err != nil {
|
||||
log.Println("Milvus query failed:", err)
|
||||
return
|
||||
}
|
||||
var embedding []float32
|
||||
log.Println("查詢向量失敗:", err)
|
||||
embedding = GetNetWorkEmbedding(image.Id)
|
||||
} else {
|
||||
for _, item := range result {
|
||||
if item.Name() == "embedding" {
|
||||
embedding = item.FieldData().GetVectors().GetFloatVector().Data
|
||||
continue
|
||||
}
|
||||
}
|
||||
// TODO: 处理向量不存在的情况(生成)
|
||||
// TODO: 处理图片不存在的情况(404)
|
||||
}
|
||||
|
||||
// 处理向量不存在的情况
|
||||
if len(embedding) == 0 {
|
||||
log.Println("向量不存在, 也未能重新生成")
|
||||
return ids
|
||||
}
|
||||
|
||||
// 用向量查询相似图片
|
||||
topk := 1000
|
||||
@@ -115,7 +149,7 @@ func (image *Image) GetSimilarImagesIdList(collection_name string) (ids []int64)
|
||||
vectors := []entity.Vector{entity.FloatVector(embedding)}
|
||||
resultx, err := milvusConnection.Client.Search(ctx, collection_name, nil, "", []string{"id", "article_id"}, vectors, "embedding", entity.L2, topk, sp)
|
||||
if err != nil {
|
||||
log.Println("Milvus search failed:", err)
|
||||
log.Println("搜索相似失敗:", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
73
bin/resnet.py
Normal file
73
bin/resnet.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import json
|
||||
import uuid
|
||||
import towhee
|
||||
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
||||
img_path = './data/test.jpeg'
|
||||
feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
|
||||
print(feat[0])
|
||||
|
||||
# 定義一個服務器類
|
||||
class ResNetServer(BaseHTTPRequestHandler):
|
||||
|
||||
# 定義一個初始化方法
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# 定義一個處理 GET 請求的方法
|
||||
def do_GET(self):
|
||||
self.send_response(200) # 設置服務器響應的狀態碼
|
||||
self.send_header('Content-type', 'text/html') # 設置服務器響應的標頭
|
||||
self.end_headers() # 完成服務器響應的標頭
|
||||
# 設置服務器響應的內容
|
||||
self.wfile.write(bytes('''
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>ResNet Server</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>ResNet Server</h1>
|
||||
// 這裡是一個簡單的服務器,可以接收圖像文件,並將圖像文件轉換為ResNet50的特徵向量
|
||||
<form action="/" method="post" enctype="multipart/form-data">
|
||||
<input type="file" name="file" />
|
||||
<input type="submit" value="Upload" />
|
||||
</form>
|
||||
<p>Upload an image to get the prediction result.</p>
|
||||
// 發送JSON格式的表單
|
||||
</body>
|
||||
</html>
|
||||
''', 'utf-8'))
|
||||
|
||||
## 定義一個處理 POST 請求的方法
|
||||
#def do_POST(self):
|
||||
# self.send_response(200) # 設置服務器響應的狀態碼
|
||||
# self.send_header('Content-type', 'application/json') # 設置服務器響應的標頭
|
||||
# self.end_headers() # 完成服務器響應的標頭
|
||||
# content_length = int(self.headers['Content-Length']) # 獲取請求的內容長度
|
||||
# body = self.rfile.read(content_length) # 獲取請求的內容
|
||||
# params = json.loads(body) # 將請求的內容解析為字典
|
||||
# img_path = params['img_path'] # 獲取圖像的路徑
|
||||
# feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
|
||||
# results = json.dumps(feat[0].tolist()) # 將結果轉換為JSON格式
|
||||
# self.wfile.write(bytes(results, 'utf-8'))
|
||||
|
||||
# 定義一個處理 POST 請求的方法, 接收二進制圖像文件
|
||||
def do_POST(self):
|
||||
self.send_response(200) # 設置服務器響應的狀態碼
|
||||
self.send_header('Content-type', 'application/json') # 設置服務器響應的標頭
|
||||
self.end_headers() # 完成服務器響應的標頭
|
||||
content_length = int(self.headers['Content-Length']) # 獲取請求的內容長度
|
||||
body = self.rfile.read(content_length) # 獲取請求的內容
|
||||
#img_path = './data/' + str(uuid.uuid4()) # 生成一個隨機的圖像文件名
|
||||
#with open(img_path, 'wb') as f: # 將二進制圖像文件解碼為圖像數據保存在本地
|
||||
# f.write(body)
|
||||
feat = towhee.blob(body).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
|
||||
results = json.dumps(feat[0].tolist()) # 將結果轉換為JSON格式
|
||||
self.wfile.write(bytes(results, 'utf-8')) # 將結果發送給客戶端
|
||||
|
||||
|
||||
# 判斷是否是直接執行該文件
|
||||
if __name__ == '__main__':
|
||||
HTTPServer(('0.0.0.0', 8000), ResNetServer).serve_forever()
|
1
go.mod
1
go.mod
@@ -38,6 +38,7 @@ require (
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.7+incompatible
|
||||
github.com/chai2010/webp v1.1.1
|
||||
github.com/go-sql-driver/mysql v1.7.0
|
||||
github.com/sbinet/go-python v0.1.0
|
||||
github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect
|
||||
)
|
||||
|
2
go.sum
2
go.sum
@@ -171,6 +171,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
|
||||
github.com/sbinet/go-python v0.1.0 h1:WlS8dGoxKMt9/c54U4XQuVhQt79p0uJUdzopuDR4QaI=
|
||||
github.com/sbinet/go-python v0.1.0/go.mod h1:Pq31TCdgxj39xSYY/VAfsWWrFphYSmx3jmPHtotzQNY=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 h1:dCWBD4Xchp/XFIR/x6D2l74DtQHvIpHsmpPRHgH9oUo=
|
||||
github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591/go.mod h1:IXC7KN2FEuTEISdePm37qcFyXInAh6pfW35yDjbdfOM=
|
||||
|
Reference in New Issue
Block a user