import os
import random
import string
import sqlite3
import numpy as np
import time
import io
from PIL import Image
from towhee import pipe, ops
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
from models.milvus import get_collection, collection_name
from models.mysql import get_cursor, conn
from configs.config import UPLOAD_PATH
from utilities.download import download_image

router = APIRouter()

# Towhee pipeline: takes an image as 'img', runs the timm resnet50 embedding
# operator on it and outputs the resulting vector as 'vec'.
RESNET50 = (
    pipe.input('img')
    .map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
    .output('vec')
)

# Columns selected from `web_images` when building search responses.
_IMAGE_COLUMNS = (
    'id,user_id,article_id,width,height,content,larger_image,thumbnail_image,'
    'article_category_top_id,praise_count,collect_count,create_time,update_time'
)


def _index_params():
    """Return a fresh dict with the index settings for the ``embedding`` field."""
    return {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}


def _search_params():
    """Return a fresh dict with the vector-search settings used by both searches."""
    return {"metric_type": 'L2', "params": {"nprobe": 16}}


def _placeholders(values):
    """Return a comma-separated run of ``%s`` markers, one per element of *values*."""
    return ','.join(['%s'] * len(values))


def _fetch_image_row(image_id):
    """Return the `web_images` row with the given id, or None when it does not exist.

    The cursor is always closed, even if the query raises.
    """
    cursor = get_cursor()
    try:
        # NOTE(review): assumes get_cursor() yields a MySQL DB-API cursor that
        # accepts %s placeholders (format paramstyle) - confirm against models.mysql.
        cursor.execute("SELECT * FROM `web_images` WHERE id=%s", (image_id,))
        return cursor.fetchone()
    finally:
        cursor.close()


def _temp_image_path(img):
    """Return the scratch path under UPLOAD_PATH used while embedding *img*."""
    return os.path.join(UPLOAD_PATH, os.path.basename(img['content']))


def _embed_remote_image(img):
    """Download the image referenced by ``img['content']`` and return its embedding.

    Returns None when ``download_image`` returns None. The image is written to
    a scratch file, reopened with PIL and passed through RESNET50; the scratch
    file is removed afterwards even if saving or embedding fails.
    """
    image = download_image(img['content'])
    if image is None:
        return None
    img_path = _temp_image_path(img)
    try:
        image.save(img_path, 'png', save_all=True)
        with Image.open(img_path) as imgx:
            return RESNET50(imgx).get()[0]
    finally:
        if os.path.exists(img_path):
            os.remove(img_path)


def _page_slice(hits, page, pageSize):
    """Return ``(ids, has_next)`` for the requested page of search hits.

    ``ids`` are the hit ids in positions ``[(page-1)*pageSize, page*pageSize)``;
    ``has_next`` is True when hits remain beyond that window.
    """
    start, end = (page - 1) * pageSize, page * pageSize
    return [hit.id for hit in hits][start:end], len(hits) > end


def _load_image_records(ids, hits, camel_case_times=False):
    """Load the `web_images` rows for *ids* and attach user, article and distance.

    Args:
        ids: non-empty list of image ids to load.
        hits: the search hits the ids came from; supplies each row's distance.
        camel_case_times: when True, rename ``create_time``/``update_time`` to
            ``createTime``/``updateTime`` on every row.

    Returns:
        A list of row dicts sorted by ascending distance; empty when no row
        matches. ``praise_count``/``collect_count`` that are None become 0.
        ``article``/``user`` are None when the related row is missing.
    """
    cursor = get_cursor()
    try:
        cursor.execute(
            f"SELECT {_IMAGE_COLUMNS} FROM `web_images` WHERE id IN ({_placeholders(ids)})",
            tuple(ids),
        )
        imgs = cursor.fetchall()
        if not imgs:
            return []

        # Distinct user and article ids referenced by the loaded images.
        uids = list({x['user_id'] for x in imgs})
        tids = list({x['article_id'] for x in imgs})

        cursor.execute(
            f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({_placeholders(uids)})",
            tuple(uids),
        )
        users = cursor.fetchall()

        cursor.execute(
            f"SELECT id,title,tags FROM `web_article` WHERE id IN ({_placeholders(tids)})",
            tuple(tids),
        )
        articles = cursor.fetchall()
    finally:
        cursor.close()

    user_by_id = {x['id']: x for x in users}
    article_by_id = {x['id']: x for x in articles}

    # Build the id -> distance lookup once; keep the first hit for each id.
    distance_by_id = {}
    for hit in hits:
        distance_by_id.setdefault(hit.id, hit.distance)

    for x in imgs:
        x['article'] = article_by_id.get(x['article_id'])
        x['user'] = user_by_id.get(x['user_id'])
        x['distance'] = distance_by_id[x['id']]
        if x.get('praise_count') is None:
            x['praise_count'] = 0
        if x.get('collect_count') is None:
            x['collect_count'] = 0
        if camel_case_times:
            x['createTime'] = x.pop('create_time')
            x['updateTime'] = x.pop('update_time')

    return sorted(imgs, key=lambda x: x['distance'])


# Status statistics
@router.get('', summary='状态统计', description='通过表名获取状态统计')
def count_images():
    """Return the number of entities stored in the configured collection."""
    collection = get_collection(collection_name)
    return {'status': True, 'count': collection.num_entities}


# Manually rebuild the index
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
async def create_index():
    """Create the index on ``embedding``, load the collection and return its entity count."""
    collection = get_collection(collection_name)
    collection.create_index(field_name="embedding", index_params=_index_params())
    collection.load()
    return {'status': True, 'count': collection.num_entities}


# Rebuild the vector of a given image
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
async def rewrite_image(image_id: int):
    """Recompute the embedding of one image and replace its collection entry.

    Returns a 404 response when the image row does not exist or its file
    cannot be downloaded; otherwise returns the new feature vector as a list.
    """
    print('START', image_id, '重建向量')
    img = _fetch_image_row(image_id)
    if img is None:
        print('mysql中原始图片不存在:', image_id)
        return Response('图片不存在', status_code=404)

    print('img_path', _temp_image_path(img))
    feat = _embed_remote_image(img)
    if feat is None:
        print('图片下载失败:', img['content'])
        return Response('图片下载失败', status_code=404)

    collection = get_collection(collection_name)
    # Drop any existing entry for this id before inserting the fresh one.
    collection.delete(expr=f'id in [{image_id}]')
    rest = collection.insert([[image_id], [feat], [img['article_id']]])
    print('END', image_id, '重建向量', rest.primary_keys)
    return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}


# Get similar images (deprecated)
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
    """Return one page of images similar to the stored image *image_id*.

    If the collection holds no entry for the id, the embedding is generated
    from the MySQL row and inserted first. Up to 200 nearest hits are searched
    and then paged with *page*/*pageSize*.
    """
    collection = get_collection(collection_name)
    # NOTE(review): `top_k` is passed through unchanged from the original code;
    # confirm the installed pymilvus honours it for query().
    result = collection.query(
        expr=f'id in [{image_id}]',
        output_fields=['id', 'article_id', 'embedding'],
        top_k=1,
    )

    if not result:
        # No stored vector: regenerate the record from the source image.
        img = _fetch_image_row(image_id)
        if img is None:
            print('mysql 中图片不存在:', image_id)
            return Response('图片不存在', status_code=404)
        feat = _embed_remote_image(img)
        if feat is None:
            print('图片下载失败:', img['content'])
            return Response('图片下载失败', status_code=404)
        # Remove a possibly stale record, then insert the new one.
        collection.delete(expr=f'id in [{image_id}]')
        collection.insert([[image_id], [feat], [img['article_id']]])
        print('生成')
    else:
        print('通过')
        feat = result[0]['embedding']

    res = collection.search(
        [feat],
        anns_field="embedding",
        param=_search_params(),
        output_fields=["id", "article_id"],
        limit=200,
    )[0]

    ids, has_next = _page_slice(res, page, pageSize)
    if not ids:
        print('没有更多数据了')
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}

    imgs = _load_image_records(ids, res, camel_case_times=True)
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}


@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
    """Search for images similar to an uploaded picture and return one page.

    Raises:
        HTTPException: 400 when the upload cannot be opened as an image.
    """
    content = await image.read()
    try:
        # Open from the bytes already read instead of the consumed file object.
        img = Image.open(io.BytesIO(content))
    except OSError:
        # PIL raises an OSError subclass for unrecognised image data.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail='图片格式无法识别'
        ) from None
    embedding = RESNET50(img).get()[0]

    # NOTE(review): the other endpoints use `collection_name`; confirm that the
    # literal 'default' is intended here.
    collection = get_collection('default')
    res = collection.search(
        [embedding],
        anns_field="embedding",
        param=_search_params(),
        output_fields=["id", "article_id"],
        limit=500,
    )[0]

    ids, has_next = _page_slice(res, page, pageSize)
    if not ids:
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}

    imgs = _load_image_records(ids, res)
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}


@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
async def delete_images(thread_id: str):
    """Delete every vector whose ``article_id`` is listed in *thread_id*.

    *thread_id* is one integer id or a comma-separated list of them. It is
    validated before being placed in the delete expression so that arbitrary
    expression text cannot be injected.

    Raises:
        HTTPException: 400 when *thread_id* is not a list of integers.
    """
    try:
        article_ids = [int(part) for part in thread_id.split(',')]
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail='thread_id 必须为整数'
        ) from None

    collection = get_collection(collection_name)
    collection.delete(expr=f"article_id in [{','.join(map(str, article_ids))}]")
    # NOTE(review): load-before-create_index order is kept from the original.
    collection.load()
    collection.create_index(field_name="embedding", index_params=_index_params())
    return {"status": True, 'msg': '删除完毕'}