Files
reverse_image_search_gpu/routers/reverse.py
2024-11-04 05:20:42 +08:00

232 lines
10 KiB
Python

import os
import random
import string
import sqlite3
import numpy as np
import time
import io
from PIL import Image
from towhee import pipe, ops
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
from models.milvus import get_collection, collection_name
from models.mysql import get_cursor, conn
#from models.resnet import Resnet50
from configs.config import UPLOAD_PATH
from utilities.download import download_image
router = APIRouter()
# Legacy hand-written ResNet50 wrapper (models.resnet); superseded by the
# towhee pipeline below:
#MODEL = Resnet50()

# Shared embedding pipeline: feeds the input 'img' through towhee's timm
# image_embedding operator with model_name='resnet50' and emits 'vec'.
# Endpoints in this module call it as RESNET50(image).get()[0].
RESNET50 = (
    pipe.input('img')
    .map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
    .output('vec')
)
# Status statistics
@router.get('', summary='状态统计', description='通过表名获取状态统计')
def count_images():
    """Report how many entities the Milvus collection currently holds.

    Returns:
        dict: ``{'status': True, 'count': <collection.num_entities>}``.
    """
    entity_total = get_collection(collection_name).num_entities
    return {'status': True, 'count': entity_total}
# Manually rebuild the index
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
def create_index():
    """(Re)create the IVF_SQ8 / L2 index on ``embedding`` and load the collection.

    Declared as a plain ``def`` rather than ``async def``: the pymilvus calls
    below are synchronous (nothing is awaited), and FastAPI runs sync path
    operations in its threadpool instead of on the event loop, so other
    requests are not stalled while the index is built. This endpoint only
    talks to Milvus, not to the MySQL connection used elsewhere.

    Returns:
        dict: ``{'status': True, 'count': <collection.num_entities>}``.
    """
    collection = get_collection(collection_name)
    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
    collection.create_index(field_name="embedding", index_params=default_index)
    collection.load()
    return {'status': True, 'count': collection.num_entities}
# NOTE: The module-level string literal below holds a disabled
# `create_vector` endpoint. Because it is only a string, nothing inside it is
# executed or registered on the router. It was a batch job that selected up to
# `count` web_images rows not yet marked milvus_id=2048, skipped ids already
# present in Milvus, inserted features for the rest and marked each row with
# milvus_id=2048. To revive it, port it first: it calls MODEL, whose
# definition (`Resnet50()`) is commented out in this file, and it builds SQL
# with f-strings.
'''
# 批量生成向量
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
async def create_vector(count: int = 10):
cursor = get_cursor()
cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
images = cursor.fetchall()
cursor.close()
for item in images:
print(item)
# 先查询 milvus 中是否存在
collection = get_collection(collection_name)
data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
if len(data) > 0:
cursor = get_cursor()
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
conn.commit()
cursor.close()
continue
try:
img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
feat = MODEL.resnet50_extract_feat(img_path)
collection.insert([[item['id']], [feat], [item['article_id']]])
cursor = get_cursor()
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
conn.commit()
cursor.close()
except Exception as e:
print(e)
print('END')
return images
'''
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
async def rewrite_image(image_id: int):
    """Rebuild (overwrite) the Milvus vector for a single image.

    Loads the ``web_images`` row for *image_id*, downloads its ``content``
    URL, writes it to UPLOAD_PATH as PNG, embeds it with RESNET50, then
    deletes any Milvus entity with this id and inserts the new vector.

    Args:
        image_id: ``web_images.id``; also used as the Milvus id.

    Returns:
        A 404 ``Response`` if the row does not exist or ``download_image``
        returns None; otherwise a dict with ``code``, ``status``, ``message``
        and the new ``feature`` converted with ``tolist()``.
    """
    print('START', image_id, '重建向量')
    cursor = get_cursor()
    try:
        # image_id is declared `int`, so FastAPI has already validated it.
        cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
        img = cursor.fetchone()
    finally:
        cursor.close()
    if img is None:
        print('mysql中原始图片不存在:', image_id)
        return Response('图片不存在', status_code=404)
    img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
    print('img_path', img_path)
    image = download_image(img['content'])
    if image is None:
        print('图片下载失败:', img['content'])
        return Response('图片下载失败', status_code=404)
    try:
        # Presumably round-trips through PNG to normalise the source format
        # before embedding.
        image.save(img_path, 'png', save_all=True)
        with Image.open(img_path) as imgx:
            feat = RESNET50(imgx).get()[0]
    finally:
        # Always remove the temporary file, even if saving/embedding fails.
        try:
            os.remove(img_path)
        except FileNotFoundError:
            pass
    collection = get_collection(collection_name)
    # Replace any existing entity with this id by the fresh vector.
    collection.delete(expr=f'id in [{image_id}]')
    rest = collection.insert([[image_id], [feat], [img['article_id']]])
    print('END', image_id, '重建向量', rest.primary_keys)
    return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
# Find similar images (deprecated)
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
    """Return a page of images similar to an existing image.

    The image's vector is read from Milvus. If it is missing, it is rebuilt
    from the ``web_images.content`` URL and inserted first. The vector is then
    searched (L2, nprobe=16, top 200), and the requested page of hits is
    enriched with ``web_images``, ``web_member`` (``user``) and
    ``web_article`` (``article``) rows.

    Args:
        image_id: id of the query image (``web_images.id`` / Milvus id).
        page: 1-based page number; values below 1 yield an empty page.
        pageSize: hits per page; values below 1 yield an empty page.

    Returns:
        A 404 ``Response`` when the image row is missing or cannot be
        downloaded; an empty list when none of the hit ids exist in MySQL;
        otherwise ``{'code', 'pageSize', 'page', 'next', 'list'}`` with
        ``list`` sorted by ascending distance.
    """
    empty_page = {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
    collection = get_collection(collection_name)
    # NOTE(review): top_k is not a documented Collection.query() argument and
    # appears to be ignored; kept to avoid pymilvus version differences.
    result = collection.query(expr=f'id in [{image_id}]', output_fields=['id', 'article_id', 'embedding'], top_k=1)
    if len(result) == 0:
        # No stored vector: regenerate it from the original image.
        cursor = get_cursor()
        try:
            cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
            img = cursor.fetchone()
        finally:
            cursor.close()
        if img is None:
            print('mysql 中图片不存在:', image_id)
            return Response('图片不存在', status_code=404)
        img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
        image = download_image(img['content'])
        if image is None:
            print('图片下载失败:', img['content'])
            return Response('图片下载失败', status_code=404)
        try:
            image.save(img_path, 'png', save_all=True)
            with Image.open(img_path) as imgx:
                feat = RESNET50(imgx).get()[0]
        finally:
            # Always remove the temporary file, even if embedding fails.
            try:
                os.remove(img_path)
            except FileNotFoundError:
                pass
        # Drop any stale entity with this id before inserting the fresh one.
        collection.delete(expr=f'id in [{image_id}]')
        collection.insert([[image_id], [feat], [img['article_id']]])
        print('生成')
    else:
        print('通过')
        feat = result[0]['embedding']
    if page < 1 or pageSize < 1:
        # Non-positive paging values would produce negative slice bounds.
        return empty_page
    res = collection.search([feat], anns_field="embedding", param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
    # Paginate the hit list: [start, end) for this page.
    start, end = (page - 1) * pageSize, page * pageSize
    ids = [hit.id for hit in res][start:end]
    has_next = len(res) > end
    if not ids:
        print('没有更多数据了')
        return empty_page
    # First distance reported per id, for O(1) lookups when merging.
    distances = {}
    for hit in res:
        distances.setdefault(hit.id, hit.distance)
    cursor = get_cursor()
    try:
        cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
        imgs = cursor.fetchall()
        if len(imgs) == 0:
            return imgs
        # Skip NULL foreign keys: "None" inside IN (...) is invalid SQL.
        uids = {x['user_id'] for x in imgs if x['user_id'] is not None}
        tids = {x['article_id'] for x in imgs if x['article_id'] is not None}
        users, articles = [], []
        if uids:
            cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({','.join(map(str, uids))})")
            users = cursor.fetchall()
        if tids:
            cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({','.join(map(str, tids))})")
            articles = cursor.fetchall()
    finally:
        cursor.close()
    user = {x['id']: x for x in users}
    article = {x['id']: x for x in articles}
    for x in imgs:
        # .get(): a deleted user/article yields None instead of a 500.
        x['article'] = article.get(x['article_id'])
        x['user'] = user.get(x['user_id'])
        x['distance'] = distances[x['id']]
        if x['praise_count'] is None:
            x['praise_count'] = 0
        if x['collect_count'] is None:
            x['collect_count'] = 0
        # Convert field names to camelCase.
        x['createTime'] = x.pop('create_time')
        x['updateTime'] = x.pop('update_time')
    imgs = sorted(imgs, key=lambda x: x['distance'])
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
    """Search for images similar to an uploaded image.

    The upload is decoded with Pillow, embedded with RESNET50 and searched
    (L2, nprobe=16, top 500) in the ``'default'`` Milvus collection. The
    requested page of hits is enriched with ``web_images``, ``web_member``
    (``user``) and ``web_article`` (``article``) rows.

    Args:
        image: uploaded image file.
        page: 1-based page number; values below 1 yield an empty page.
        pageSize: hits per page; values below 1 yield an empty page.

    Returns:
        An empty list when none of the hit ids exist in MySQL; otherwise
        ``{'code', 'pageSize', 'page', 'next', 'list'}`` with ``list`` sorted
        by ascending distance.

    Raises:
        HTTPException: 400 if Pillow cannot open the upload as an image.
    """
    empty_page = {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
    if page < 1 or pageSize < 1:
        # Non-positive paging values would produce negative slice bounds.
        return empty_page
    content = await image.read()
    try:
        pil_image = Image.open(io.BytesIO(content))
    except OSError as err:
        # Pillow's UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='图片无法识别') from err
    with pil_image:
        embedding = RESNET50(pil_image).get()[0]
    # NOTE(review): other endpoints use `collection_name`; this one hard-codes
    # 'default'. Kept as-is — confirm both refer to the same collection.
    collection = get_collection('default')
    res = collection.search([embedding], anns_field="embedding", param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]
    start, end = (page - 1) * pageSize, page * pageSize
    ids = [hit.id for hit in res][start:end]
    has_next = len(res) > end
    if not ids:
        return empty_page
    # First distance reported per id, for O(1) lookups when merging.
    distances = {}
    for hit in res:
        distances.setdefault(hit.id, hit.distance)
    cursor = get_cursor()
    try:
        cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
        imgs = cursor.fetchall()
        if not imgs:
            # NOTE(review): returns a bare list, not the page dict; kept for
            # compatibility with existing clients.
            return imgs
        # Skip NULL foreign keys: "None" inside IN (...) is invalid SQL.
        uids = {x['user_id'] for x in imgs if x['user_id'] is not None}
        tids = {x['article_id'] for x in imgs if x['article_id'] is not None}
        users, articles = [], []
        if uids:
            cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({','.join(map(str, uids))})")
            users = cursor.fetchall()
        if tids:
            cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({','.join(map(str, tids))})")
            articles = cursor.fetchall()
    finally:
        cursor.close()
    user = {x['id']: x for x in users}
    article = {x['id']: x for x in articles}
    for x in imgs:
        x.update({
            # .get(): a deleted user/article yields None instead of a 500.
            'article': article.get(x['article_id']),
            'user': user.get(x['user_id']),
            'distance': distances[x['id']],
            # The columns are always selected, so the value (not the key) may
            # be NULL; map NULL to 0.
            'praise_count': x.get('praise_count') or 0,
            'collect_count': x.get('collect_count') or 0,
        })
    imgs.sort(key=lambda x: x['distance'])
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
def delete_images(thread_id: str):
    """Delete every Milvus entity whose ``article_id`` matches *thread_id*.

    After the delete, the collection is loaded and ``create_index`` is
    re-issued with the same IVF_SQ8 / L2 parameters used by ``create_index``.

    *thread_id* comes from the URL path and is therefore untrusted. It must
    be one or more comma-separated integers (e.g. ``"12"`` or ``"12,13"``).
    Each part is parsed with ``int()`` and re-serialised, so only plain
    integer literals ever reach the Milvus boolean expression. Previously the
    raw string was concatenated in, which allowed an arbitrary filter
    expression to be injected.

    Declared as a plain ``def``: the pymilvus calls are synchronous, so
    FastAPI runs this endpoint in its threadpool instead of blocking the
    event loop.

    Raises:
        HTTPException: 400 if *thread_id* is not a comma-separated integer list.
    """
    try:
        article_ids = [int(part) for part in thread_id.split(',')]
    except ValueError:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='thread_id 必须是以逗号分隔的整数') from None
    expr = f"article_id in [{','.join(map(str, article_ids))}]"
    collection = get_collection(collection_name)
    collection.delete(expr=expr)
    collection.load()
    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
    collection.create_index(field_name="embedding", index_params=default_index)
    return {"status": True, 'msg': '删除完毕'}