import os
import io
from PIL import Image
from towhee import pipe, ops
from fastapi import APIRouter, File, UploadFile, Response
from models.milvus import get_collection, collection_name
from models.mysql import pool
from configs.config import UPLOAD_PATH
from utilities.download import download_image

router = APIRouter()
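
# Towhee pipeline: map a decoded image to a ResNet-50 (timm) embedding; the 'vec'
# output is the float feature vector that gets stored and searched in Milvus.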
RESNET50 = (
    pipe.input('img')
    .map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
    .output('vec')
)

# Collection status / entity count
@router.get('', summary='Status count', description='Get status statistics for the collection')
def count_images():
    collection = get_collection(collection_name)
    return {'status': True, 'count': collection.num_entities}

# Manually rebuild the index
@router.get('/create_index', summary='Rebuild index', description='Manually rebuild the vector index', include_in_schema=False)
async def create_index():
    collection = get_collection(collection_name)
    default_index = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 16384}}
    collection.create_index(field_name="embedding", index_params=default_index)
    collection.load()
    return {'status': True, 'count': collection.num_entities}
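
# Index parameters: IVF_SQ8 stores 8-bit-quantized vectors to save memory; nlist=16384
# sets the number of cluster buckets, and the search handlers below probe nprobe=16
# of them per query.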

# Rebuild the vector of a specific image
@router.put('/{image_id}', summary='Overwrite vector', description='Rebuild the vector of the specified image')
async def rewrite_image(image_id: int):
    with pool.connection() as conn:
        with conn.cursor() as cursor:
            print('START rebuilding vector for', image_id)
            cursor.execute(f"SELECT content,article_id FROM `web_images` WHERE id={image_id}")
            img = cursor.fetchone()
            if img is None:
                print('Source image not found in MySQL:', image_id)
                return Response('Image not found', status_code=404)
            img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
            print('img_path', img_path)
            image = download_image(img['content'])
            if image is None:
                print('Image download failed:', img['content'])
                return Response('Image download failed', status_code=404)
            image.save(img_path, 'png', save_all=True)
            with Image.open(img_path) as imgx:
                feat = RESNET50(imgx).get()[0]
            # Replace any existing entity for this id with the fresh embedding
            collection = get_collection(collection_name)
            collection.delete(expr=f'id in [{image_id}]')
            rest = collection.insert([[image_id], [feat], [img['article_id']]])
            os.remove(img_path)
            print('END vector rebuilt for', image_id, rest.primary_keys)
            return {"code": 0, "status": True, "message": "Rebuild successful", "feature": feat.tolist()}

# Get similar images (deprecated)
@router.get('/{image_id}', summary='Similar images', description='Get similar images by image ID')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
    with pool.connection() as conn:
        collection = get_collection(collection_name)
        result = collection.query(expr=f'id in [{image_id}]', output_fields=['id', 'article_id', 'embedding'], limit=1)
        # If there is no stored vector, regenerate the record
        if len(result) == 0:
            with conn.cursor() as cursor:
                cursor.execute(f"SELECT content FROM `web_images` WHERE id={image_id}")
                img = cursor.fetchone()
                if img is None:
                    print('Image not found in MySQL:', image_id)
                    return Response('Image not found', status_code=404)
                img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
                image = download_image(img['content'])
                if image is None:
                    print('Image download failed:', img['content'])
                    return Response('Image download failed', status_code=404)
                image.save(img_path, 'png', save_all=True)
                with Image.open(img_path) as imgx:
                    feat = RESNET50(imgx).get()[0]
                # Remove any stale record and insert the fresh one
                collection.delete(expr=f'id in [{image_id}]')
                collection.insert([[image_id], [feat], [img['article_id']]])
                os.remove(img_path)
                print('embedding generated')
        else:
            print('embedding found')
            feat = result[0]['embedding']
        res = collection.search([feat], anns_field="embedding", param={"metric_type": "L2", "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
        # Pagination: slice the hit list to the requested page (page * pageSize)
        ope = (page - 1) * pageSize
        end = page * pageSize
        has_next = False
        ids = [i.id for i in res]  # all hit IDs
        if len(res) <= end:
            ids = ids[ope:]
            has_next = False
        else:
            ids = ids[ope:end]
            has_next = True
        str_ids = ','.join(map(str, ids))
        if str_ids == '':
            print('No more data')
            return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
            imgs = cursor.fetchall()
            if len(imgs) == 0:
                return imgs
            # Collect user IDs and article IDs
            uids = list(set(x['user_id'] for x in imgs))
            tids = list(set(x['article_id'] for x in imgs))
            # Fetch user info
            cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({','.join(map(str, uids))})")
            users = cursor.fetchall()
            # Fetch article info
            cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({','.join(map(str, tids))})")
            articles = cursor.fetchall()
            # Merge user and article info into each image record
            user, article = {}, {}
            for x in users: user[x['id']] = x
            for x in articles: article[x['id']] = x
            for x in imgs:
                x['article'] = article[x['article_id']]
                x['user'] = user[x['user_id']]
                x['distance'] = [i.distance for i in res if i.id == x['id']][0]
                if x['praise_count'] is None: x['praise_count'] = 0
                if x['collect_count'] is None: x['collect_count'] = 0
                # Convert field names to camelCase
                x['createTime'] = x.pop('create_time')
                x['updateTime'] = x.pop('update_time')
            # Re-sort imgs by vector distance (nearest first)
            imgs = sorted(imgs, key=lambda x: x['distance'])
            return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}
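
# Reverse image search by uploaded file (supersedes the deprecated GET /{image_id} handler above)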
@router.post(path='', summary='Reverse image search', description='Upload an image to search for similar ones')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
    content = await image.read()
    # Open from the bytes already read; image.file has been consumed by read() above
    img = Image.open(io.BytesIO(content))
    embedding = RESNET50(img).get()[0]
    collection = get_collection('default')
    res = collection.search([embedding], anns_field="embedding", param={"metric_type": "L2", "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]
    ope, end = (page - 1) * pageSize, page * pageSize
    ids, nextx = [x.id for x in res][ope:end], len(res) > end
    if not ids:
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
    with pool.connection() as conn:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
            imgs = cursor.fetchall()
            if not imgs:
                return imgs
            uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
            cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({','.join(map(str, uids))})")
            users = cursor.fetchall()
            cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({','.join(map(str, tids))})")
            articles = cursor.fetchall()
            user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
            for x in imgs:
                x.update({
                    'article': article[x['article_id']],
                    'user': user[x['user_id']],
                    'distance': [i.distance for i in res if i.id == x['id']][0],
                    # Treat NULL counts from MySQL as 0
                    'praise_count': x['praise_count'] or 0,
                    'collect_count': x['collect_count'] or 0
                })
            imgs.sort(key=lambda x: x['distance'])
            return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}

@router.delete('/{thread_id}', summary='Delete thread', description='Delete all images under the specified thread')
async def delete_images(thread_id: str):
    collection = get_collection(collection_name)
    collection.delete(expr=f"article_id in [{thread_id}]")
    collection.load()
    default_index = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 16384}}
    collection.create_index(field_name="embedding", index_params=default_index)
    return {"status": True, 'msg': 'Deletion complete'}