移除配置
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import pymilvus
|
||||
|
||||
from configs.config import MILVUS_HOST
|
||||
from configs.config import MILVUS_HOST, MILVUS_PORT
|
||||
|
||||
# 连接 Milvus (开启 Milvus 服务)
|
||||
collection_name = 'default'
|
||||
@@ -8,7 +8,7 @@ collection_name = 'default'
|
||||
|
||||
# 获取 Milvus 连接
|
||||
def get_collection(collection_name):
|
||||
pymilvus.connections.connect(host=MILVUS_HOST, port='19530')
|
||||
pymilvus.connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
if not pymilvus.utility.has_collection(collection_name):
|
||||
field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
|
||||
field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
|
||||
|
@@ -1,24 +1,27 @@
|
||||
import pymysql
|
||||
from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
|
||||
|
||||
|
||||
# 连接 MySQL (开启 MySQL 服务)
|
||||
conn = pymysql.connect(
|
||||
# 创建 MySQL 连接
|
||||
def create_connection():
|
||||
return pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_HOST,
|
||||
port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
|
||||
password=MYSQL_PASS,
|
||||
database=MYSQL_NAME,
|
||||
local_infile=True,
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
)
|
||||
|
||||
# 连接 MySQL (开启 MySQL 服务)
|
||||
conn = create_connection()
|
||||
|
||||
# 获取 MySQL 连接
|
||||
def get_cursor():
|
||||
global conn
|
||||
try:
|
||||
dx = conn.ping()
|
||||
conn.ping()
|
||||
return conn.cursor()
|
||||
except Exception:
|
||||
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
|
||||
conn = create_connection()
|
||||
return conn.cursor()
|
||||
|
@@ -5,8 +5,6 @@ from pydantic import BaseModel
|
||||
from configs.config import SQLITE3_PATH
|
||||
import uuid
|
||||
import time
|
||||
import requests
|
||||
|
||||
|
||||
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
|
||||
__taskdb.row_factory = sqlite3.Row
|
||||
@@ -104,81 +102,4 @@ class TaskManager:
|
||||
break
|
||||
print(f"get task status:{task.id}")
|
||||
break
|
||||
#data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
|
||||
#print(data)
|
||||
## 通过 http 请求获取任务的状态
|
||||
## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
|
||||
#if task.status != data.status:
|
||||
# task.status = data.status
|
||||
# print("send task status to websocket")
|
||||
# for websocket in task.websockets:
|
||||
# await websocket.send_json(task)
|
||||
|
||||
|
||||
|
||||
## 任务实体模型
|
||||
#class TaskModel(BaseModel):
|
||||
# id: str
|
||||
# name: str
|
||||
# user_id: int
|
||||
# description: str
|
||||
# created_at: str
|
||||
# updated_at: str
|
||||
# websockets: list=[]
|
||||
#
|
||||
## 向数据库中添加任务
|
||||
#def add_task(name, user_id, description):
|
||||
# print(name, user_id, description)
|
||||
# date = datetime.datetime.now()
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
|
||||
# id = cursor.lastrowid
|
||||
# cursor.close()
|
||||
# __taskdb.commit()
|
||||
# return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
|
||||
#
|
||||
#
|
||||
## 从数据库中获取任务列表
|
||||
#def get_tasks(user_id: int=None):
|
||||
# cursor = __taskdb.cursor()
|
||||
# if user_id:
|
||||
# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
|
||||
# else:
|
||||
# cursor.execute("SELECT * FROM tasks")
|
||||
# tasks = cursor.fetchall()
|
||||
# cursor.close()
|
||||
# return [dict(task) for task in tasks]
|
||||
#
|
||||
#
|
||||
## 从数据库中获取指定任务
|
||||
#def get_task(task_id: int):
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
|
||||
# task = cursor.fetchone()
|
||||
# cursor.close()
|
||||
# return task
|
||||
#
|
||||
#
|
||||
## 更新数据库中的任务
|
||||
#def update_task(task: TaskModel):
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
|
||||
# task.name,
|
||||
# task.description,
|
||||
# datetime.datetime.now(),
|
||||
# task.id
|
||||
# ))
|
||||
# cursor.close()
|
||||
# __taskdb.commit()
|
||||
# return task
|
||||
#
|
||||
#
|
||||
## 从数据库中删除任务
|
||||
#def delete_task(task_id: int):
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
|
||||
# cursor.close()
|
||||
# __taskdb.commit()
|
||||
# return task_id
|
||||
|
||||
|
||||
|
@@ -11,13 +11,11 @@ from towhee import pipe, ops
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
|
||||
from models.milvus import get_collection, collection_name
|
||||
from models.mysql import get_cursor, conn
|
||||
#from models.resnet import Resnet50
|
||||
|
||||
from configs.config import UPLOAD_PATH
|
||||
from utilities.download import download_image
|
||||
|
||||
router = APIRouter()
|
||||
#MODEL = Resnet50()
|
||||
|
||||
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
|
||||
|
||||
@@ -37,40 +35,8 @@ async def create_index():
|
||||
collection.load()
|
||||
return {'status': True, 'count': collection.num_entities}
|
||||
|
||||
'''
|
||||
# 批量生成向量
|
||||
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
|
||||
async def create_vector(count: int = 10):
|
||||
cursor = get_cursor()
|
||||
cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
|
||||
images = cursor.fetchall()
|
||||
cursor.close()
|
||||
for item in images:
|
||||
print(item)
|
||||
# 先查询 milvus 中是否存在
|
||||
collection = get_collection(collection_name)
|
||||
data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
|
||||
if len(data) > 0:
|
||||
cursor = get_cursor()
|
||||
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
continue
|
||||
try:
|
||||
img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
|
||||
download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
|
||||
feat = MODEL.resnet50_extract_feat(img_path)
|
||||
collection.insert([[item['id']], [feat], [item['article_id']]])
|
||||
cursor = get_cursor()
|
||||
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print('END')
|
||||
return images
|
||||
'''
|
||||
|
||||
# 重建指定图像的向量
|
||||
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
|
||||
async def rewrite_image(image_id: int):
|
||||
print('START', image_id, '重建向量')
|
||||
|
Reference in New Issue
Block a user