198 lines
8.7 KiB
Python
198 lines
8.7 KiB
Python
import os
|
|
import random
|
|
import string
|
|
import sqlite3
|
|
import numpy as np
|
|
import time
|
|
import io
|
|
from PIL import Image
|
|
from towhee import pipe, ops
|
|
|
|
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
|
|
from models.milvus import get_collection, collection_name
|
|
from models.mysql import get_cursor, conn
|
|
|
|
from configs.config import UPLOAD_PATH
|
|
from utilities.download import download_image
|
|
|
|
# Router on which every endpoint of this module is registered.
router = APIRouter()

# towhee pipeline: accepts an 'img' input, maps it through the timm
# 'resnet50' image-embedding operator and outputs the resulting 'vec'.
_resnet50_embedding_op = ops.image_embedding.timm(model_name='resnet50')
RESNET50 = pipe.input('img').map('img', 'vec', _resnet50_embedding_op).output('vec')
|
|
|
|
# Status statistics
@router.get('', summary='状态统计', description='通过表名获取状态统计')
def count_images():
    """Report the number of entities held by the configured collection.

    Returns:
        A dict with ``status`` (always ``True``) and ``count`` (the
        collection's ``num_entities``).
    """
    entity_total = get_collection(collection_name).num_entities
    return {'status': True, 'count': entity_total}
|
|
|
|
|
|
# Manually rebuild the index
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
async def create_index():
    """Create the index on the ``embedding`` field and load the collection.

    The route is excluded from the generated API schema
    (``include_in_schema=False``).

    Returns:
        A dict with ``status`` (always ``True``) and ``count`` (the
        collection's ``num_entities`` after loading).
    """
    # IVF_SQ8 index with L2 distance, 16384 clusters.
    index_settings = {
        "index_type": "IVF_SQ8",
        "metric_type": 'L2',
        "params": {"nlist": 16384},
    }
    target = get_collection(collection_name)
    target.create_index(field_name="embedding", index_params=index_settings)
    target.load()
    return {'status': True, 'count': target.num_entities}
|
|
|
|
|
|
# Rebuild the vector of a specific image
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
async def rewrite_image(image_id: int):
    """Recompute and store the embedding of the image ``image_id``.

    The image row is read from ``web_images``, the picture referenced by its
    ``content`` column is fetched with ``download_image``, written to a
    temporary PNG under ``UPLOAD_PATH``, embedded with ``RESNET50`` and the
    collection entry for ``image_id`` is replaced.

    Args:
        image_id: Primary key of the row in ``web_images``.

    Returns:
        A 404 ``Response`` when the row does not exist or the download
        fails; otherwise a dict with ``code``, ``status``, ``message`` and
        the new ``feature`` vector as a list.
    """
    print('START', image_id, '重建向量')

    cursor = get_cursor()
    try:
        # image_id is validated as int by FastAPI, so formatting it in is safe.
        cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
        img = cursor.fetchone()
    finally:
        # Close the cursor even when the query itself raises.
        cursor.close()

    if img is None:
        print('mysql中原始图片不存在:', image_id)
        return Response('图片不存在', status_code=404)

    img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
    print('img_path', img_path)

    image = download_image(img['content'])
    if image is None:
        print('图片下载失败:', img['content'])
        return Response('图片下载失败', status_code=404)

    try:
        image.save(img_path, 'png', save_all=True)
        with Image.open(img_path) as imgx:
            feat = RESNET50(imgx).get()[0]

        # Replace any previous entry for this id with the fresh vector.
        collection = get_collection(collection_name)
        collection.delete(expr=f'id in [{image_id}]')
        rest = collection.insert([[image_id], [feat], [img['article_id']]])
    finally:
        # Never leave the temporary file behind, even if a step above failed.
        if os.path.exists(img_path):
            os.remove(img_path)

    print('END', image_id, '重建向量', rest.primary_keys)
    return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
|
|
|
|
|
|
# Get similar images (deprecated)
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
    """Return one page of images similar to the image ``image_id``.

    The stored embedding of ``image_id`` is looked up in the collection; when
    none exists it is generated from the ``web_images`` row (download, embed
    with ``RESNET50``, insert). The embedding is then used for an L2 search
    limited to 200 hits, the requested page of hit ids is resolved against
    ``web_images`` / ``web_member`` / ``web_article`` and the rows are
    returned ordered by distance.

    Args:
        image_id: Primary key of the reference image.
        page: 1-based page number.
        pageSize: Number of results per page.

    Returns:
        A 404 ``Response`` when the reference image is missing or cannot be
        downloaded; otherwise a dict with ``code``, ``pageSize``, ``page``,
        ``next`` (whether more hits exist beyond this page) and ``list``.
        When the page's ids match no ``web_images`` rows, the empty query
        result itself is returned.
    """
    collection = get_collection(collection_name)
    result = collection.query(expr=f'id in [{image_id}]', output_fields=['id', 'article_id', 'embedding'], top_k=1)

    if len(result) == 0:
        # No stored vector for this image: build one from the source picture.
        cursor = get_cursor()
        try:
            cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
            img = cursor.fetchone()
        finally:
            cursor.close()

        if img is None:
            print('mysql 中图片不存在:', image_id)
            return Response('图片不存在', status_code=404)

        img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
        image = download_image(img['content'])
        if image is None:
            print('图片下载失败:', img['content'])
            return Response('图片下载失败', status_code=404)

        try:
            image.save(img_path, 'png', save_all=True)
            with Image.open(img_path) as imgx:
                feat = RESNET50(imgx).get()[0]
            # Drop any stale record before inserting the new one.
            collection.delete(expr=f'id in [{image_id}]')
            collection.insert([[image_id], [feat], [img['article_id']]])
        finally:
            # Remove the temporary file even when embedding/insert fails.
            if os.path.exists(img_path):
                os.remove(img_path)
        print('生成')
    else:
        print('通过')
        feat = result[0]['embedding']

    res = collection.search([feat], anns_field="embedding", param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]

    # Paging: keep only the hits in [(page - 1) * pageSize, page * pageSize).
    ope = (page - 1) * pageSize
    end = page * pageSize
    ids = [hit.id for hit in res][ope:end]
    has_next = len(res) > end

    if not ids:
        print('没有更多数据了')
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}

    str_ids = ', '.join(map(str, ids))

    cursor = get_cursor()
    try:
        cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
        imgs = cursor.fetchall()
        if len(imgs) == 0:
            # NOTE(review): this path returns the bare empty result instead of
            # the paging envelope; kept as-is so the response shape is unchanged.
            return imgs

        # Collect the distinct user and article ids referenced by the rows.
        uids = ', '.join(map(str, {row['user_id'] for row in imgs}))
        tids = ', '.join(map(str, {row['article_id'] for row in imgs}))

        # User info
        cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({uids})")
        users = cursor.fetchall()

        # Article info
        cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({tids})")
        articles = cursor.fetchall()
    finally:
        # Runs on the early return and on errors too, so the cursor never leaks.
        cursor.close()

    # Merge the looked-up info into each image row.
    user = {row['id']: row for row in users}
    article = {row['id']: row for row in articles}

    # One pass over the hits instead of rescanning them for every row;
    # setdefault keeps the first distance seen for an id.
    distance_by_id = {}
    for hit in res:
        distance_by_id.setdefault(hit.id, hit.distance)

    for x in imgs:
        x['article'] = article[x['article_id']]
        x['user'] = user[x['user_id']]
        x['distance'] = distance_by_id[x['id']]
        if x['praise_count'] is None:
            x['praise_count'] = 0
        if x['collect_count'] is None:
            x['collect_count'] = 0
        # Expose the timestamp fields under camelCase names.
        x['createTime'] = x.pop('create_time')
        x['updateTime'] = x.pop('update_time')

    # Re-order the rows by distance (the SQL result order is unrelated).
    imgs = sorted(imgs, key=lambda row: row['distance'])
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}
|
|
|
|
|
|
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
    """Search the collection for images similar to an uploaded picture.

    The upload is decoded with PIL, embedded with ``RESNET50`` and used for
    an L2 search limited to 500 hits. The requested page of hit ids is
    resolved against ``web_images`` / ``web_member`` / ``web_article`` and
    the rows are returned ordered by distance.

    Args:
        image: The uploaded image file.
        page: 1-based page number.
        pageSize: Number of results per page.

    Returns:
        A dict with ``code``, ``pageSize``, ``page``, ``next`` (whether more
        hits exist beyond this page) and ``list``. When the page's ids match
        no ``web_images`` rows, the empty query result itself is returned.

    Raises:
        HTTPException: 400 when the upload cannot be opened as an image.
    """
    content = await image.read()
    try:
        # Decode from the bytes just read rather than from the consumed stream.
        img = Image.open(io.BytesIO(content))
    except OSError:
        # PIL signals an undecodable file with an OSError subclass.
        raise HTTPException(status_code=400, detail='无法识别上传的图片') from None
    with img:
        embedding = RESNET50(img).get()[0]

    # Same collection every other endpoint in this module reads and writes.
    collection = get_collection(collection_name)
    res = collection.search([embedding], anns_field="embedding", param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]

    # Paging: keep only the hits in [(page - 1) * pageSize, page * pageSize).
    ope, end = (page - 1) * pageSize, page * pageSize
    ids = [hit.id for hit in res][ope:end]
    has_next = len(res) > end

    if not ids:
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}

    cursor = get_cursor()
    try:
        cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
        imgs = cursor.fetchall()
        if not imgs:
            # NOTE(review): returns the bare empty result instead of the paging
            # envelope; kept as-is so the response shape is unchanged.
            return imgs

        uids = ', '.join(map(str, {row['user_id'] for row in imgs}))
        tids = ', '.join(map(str, {row['article_id'] for row in imgs}))

        cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({uids})")
        users = cursor.fetchall()
        cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({tids})")
        articles = cursor.fetchall()
    finally:
        # Runs on the early return and on errors too, so the cursor never leaks.
        cursor.close()

    user = {row['id']: row for row in users}
    article = {row['id']: row for row in articles}

    # One pass over the hits instead of rescanning them for every row;
    # setdefault keeps the first distance seen for an id.
    distance_by_id = {}
    for hit in res:
        distance_by_id.setdefault(hit.id, hit.distance)

    for x in imgs:
        x['article'] = article[x['article_id']]
        x['user'] = user[x['user_id']]
        x['distance'] = distance_by_id[x['id']]
        # The columns are always selected, so a dict.get default would never
        # apply; replace NULL counters explicitly.
        if x['praise_count'] is None:
            x['praise_count'] = 0
        if x['collect_count'] is None:
            x['collect_count'] = 0

    # sorted() also works when the driver hands back a tuple of rows.
    imgs = sorted(imgs, key=lambda row: row['distance'])
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}
|
|
|
|
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
async def delete_images(thread_id: str):
    """Delete every collection entry belonging to the given thread(s).

    Args:
        thread_id: One article id, or several separated by commas. The value
            is placed inside an ``article_id in [...]`` expression.

    Returns:
        A dict with ``status`` (always ``True``) and ``msg``.

    Raises:
        HTTPException: 400 when ``thread_id`` is not a comma-separated list
            of integers.
    """
    # The path segment is caller-controlled and ends up inside a filter
    # expression, so only integers are let through and the expression is
    # rebuilt from the parsed values.
    try:
        article_ids = [int(part) for part in thread_id.split(',')]
    except ValueError:
        raise HTTPException(status_code=400, detail='thread_id 必须是整数或以逗号分隔的整数') from None

    collection = get_collection(collection_name)
    collection.delete(expr=f"article_id in [{', '.join(map(str, article_ids))}]")
    collection.load()

    # IVF_SQ8 index with L2 distance, 16384 clusters.
    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
    collection.create_index(field_name="embedding", index_params=default_index)
    return {"status": True, 'msg': '删除完毕'}
|