移除配置
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import pymilvus
|
import pymilvus
|
||||||
|
|
||||||
from configs.config import MILVUS_HOST
|
from configs.config import MILVUS_HOST, MILVUS_PORT
|
||||||
|
|
||||||
# 连接 Milvus (开启 Milvus 服务)
|
# 连接 Milvus (开启 Milvus 服务)
|
||||||
collection_name = 'default'
|
collection_name = 'default'
|
||||||
@@ -8,7 +8,7 @@ collection_name = 'default'
|
|||||||
|
|
||||||
# 获取 Milvus 连接
|
# 获取 Milvus 连接
|
||||||
def get_collection(collection_name):
|
def get_collection(collection_name):
|
||||||
pymilvus.connections.connect(host=MILVUS_HOST, port='19530')
|
pymilvus.connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||||
if not pymilvus.utility.has_collection(collection_name):
|
if not pymilvus.utility.has_collection(collection_name):
|
||||||
field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
|
field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
|
||||||
field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
|
field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
|
||||||
|
@@ -1,24 +1,27 @@
|
|||||||
import pymysql
|
import pymysql
|
||||||
from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
|
from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
|
||||||
|
|
||||||
|
# 创建 MySQL 连接
|
||||||
|
def create_connection():
|
||||||
|
return pymysql.connect(
|
||||||
|
host=MYSQL_HOST,
|
||||||
|
user=MYSQL_USER,
|
||||||
|
port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
|
||||||
|
password=MYSQL_PASS,
|
||||||
|
database=MYSQL_NAME,
|
||||||
|
local_infile=True,
|
||||||
|
cursorclass=pymysql.cursors.DictCursor
|
||||||
|
)
|
||||||
|
|
||||||
# 连接 MySQL (开启 MySQL 服务)
|
# 连接 MySQL (开启 MySQL 服务)
|
||||||
conn = pymysql.connect(
|
conn = create_connection()
|
||||||
host=MYSQL_HOST,
|
|
||||||
user=MYSQL_USER,
|
|
||||||
port=MYSQL_HOST,
|
|
||||||
password=MYSQL_PASS,
|
|
||||||
database=MYSQL_NAME,
|
|
||||||
local_infile=True,
|
|
||||||
cursorclass=pymysql.cursors.DictCursor
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# 获取 MySQL 连接
|
# 获取 MySQL 连接
|
||||||
def get_cursor():
|
def get_cursor():
|
||||||
|
global conn
|
||||||
try:
|
try:
|
||||||
dx = conn.ping()
|
conn.ping()
|
||||||
return conn.cursor()
|
return conn.cursor()
|
||||||
except Exception:
|
except Exception:
|
||||||
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
|
conn = create_connection()
|
||||||
return conn.cursor()
|
return conn.cursor()
|
||||||
|
@@ -5,8 +5,6 @@ from pydantic import BaseModel
|
|||||||
from configs.config import SQLITE3_PATH
|
from configs.config import SQLITE3_PATH
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
|
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
|
||||||
__taskdb.row_factory = sqlite3.Row
|
__taskdb.row_factory = sqlite3.Row
|
||||||
@@ -104,81 +102,4 @@ class TaskManager:
|
|||||||
break
|
break
|
||||||
print(f"get task status:{task.id}")
|
print(f"get task status:{task.id}")
|
||||||
break
|
break
|
||||||
#data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
|
|
||||||
#print(data)
|
|
||||||
## 通过 http 请求获取任务的状态
|
|
||||||
## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
|
|
||||||
#if task.status != data.status:
|
|
||||||
# task.status = data.status
|
|
||||||
# print("send task status to websocket")
|
|
||||||
# for websocket in task.websockets:
|
|
||||||
# await websocket.send_json(task)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## 任务实体模型
|
|
||||||
#class TaskModel(BaseModel):
|
|
||||||
# id: str
|
|
||||||
# name: str
|
|
||||||
# user_id: int
|
|
||||||
# description: str
|
|
||||||
# created_at: str
|
|
||||||
# updated_at: str
|
|
||||||
# websockets: list=[]
|
|
||||||
#
|
|
||||||
## 向数据库中添加任务
|
|
||||||
#def add_task(name, user_id, description):
|
|
||||||
# print(name, user_id, description)
|
|
||||||
# date = datetime.datetime.now()
|
|
||||||
# cursor = __taskdb.cursor()
|
|
||||||
# cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
|
|
||||||
# id = cursor.lastrowid
|
|
||||||
# cursor.close()
|
|
||||||
# __taskdb.commit()
|
|
||||||
# return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
|
|
||||||
#
|
|
||||||
#
|
|
||||||
## 从数据库中获取任务列表
|
|
||||||
#def get_tasks(user_id: int=None):
|
|
||||||
# cursor = __taskdb.cursor()
|
|
||||||
# if user_id:
|
|
||||||
# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
|
|
||||||
# else:
|
|
||||||
# cursor.execute("SELECT * FROM tasks")
|
|
||||||
# tasks = cursor.fetchall()
|
|
||||||
# cursor.close()
|
|
||||||
# return [dict(task) for task in tasks]
|
|
||||||
#
|
|
||||||
#
|
|
||||||
## 从数据库中获取指定任务
|
|
||||||
#def get_task(task_id: int):
|
|
||||||
# cursor = __taskdb.cursor()
|
|
||||||
# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
|
|
||||||
# task = cursor.fetchone()
|
|
||||||
# cursor.close()
|
|
||||||
# return task
|
|
||||||
#
|
|
||||||
#
|
|
||||||
## 更新数据库中的任务
|
|
||||||
#def update_task(task: TaskModel):
|
|
||||||
# cursor = __taskdb.cursor()
|
|
||||||
# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
|
|
||||||
# task.name,
|
|
||||||
# task.description,
|
|
||||||
# datetime.datetime.now(),
|
|
||||||
# task.id
|
|
||||||
# ))
|
|
||||||
# cursor.close()
|
|
||||||
# __taskdb.commit()
|
|
||||||
# return task
|
|
||||||
#
|
|
||||||
#
|
|
||||||
## 从数据库中删除任务
|
|
||||||
#def delete_task(task_id: int):
|
|
||||||
# cursor = __taskdb.cursor()
|
|
||||||
# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
|
|
||||||
# cursor.close()
|
|
||||||
# __taskdb.commit()
|
|
||||||
# return task_id
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -11,13 +11,11 @@ from towhee import pipe, ops
|
|||||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
|
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
|
||||||
from models.milvus import get_collection, collection_name
|
from models.milvus import get_collection, collection_name
|
||||||
from models.mysql import get_cursor, conn
|
from models.mysql import get_cursor, conn
|
||||||
#from models.resnet import Resnet50
|
|
||||||
|
|
||||||
from configs.config import UPLOAD_PATH
|
from configs.config import UPLOAD_PATH
|
||||||
from utilities.download import download_image
|
from utilities.download import download_image
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
#MODEL = Resnet50()
|
|
||||||
|
|
||||||
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
|
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
|
||||||
|
|
||||||
@@ -37,40 +35,8 @@ async def create_index():
|
|||||||
collection.load()
|
collection.load()
|
||||||
return {'status': True, 'count': collection.num_entities}
|
return {'status': True, 'count': collection.num_entities}
|
||||||
|
|
||||||
'''
|
|
||||||
# 批量生成向量
|
|
||||||
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
|
|
||||||
async def create_vector(count: int = 10):
|
|
||||||
cursor = get_cursor()
|
|
||||||
cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
|
|
||||||
images = cursor.fetchall()
|
|
||||||
cursor.close()
|
|
||||||
for item in images:
|
|
||||||
print(item)
|
|
||||||
# 先查询 milvus 中是否存在
|
|
||||||
collection = get_collection(collection_name)
|
|
||||||
data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
|
|
||||||
if len(data) > 0:
|
|
||||||
cursor = get_cursor()
|
|
||||||
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
|
|
||||||
conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
|
|
||||||
download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
|
|
||||||
feat = MODEL.resnet50_extract_feat(img_path)
|
|
||||||
collection.insert([[item['id']], [feat], [item['article_id']]])
|
|
||||||
cursor = get_cursor()
|
|
||||||
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
|
|
||||||
conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
print('END')
|
|
||||||
return images
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
# 重建指定图像的向量
|
||||||
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
|
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
|
||||||
async def rewrite_image(image_id: int):
|
async def rewrite_image(image_id: int):
|
||||||
print('START', image_id, '重建向量')
|
print('START', image_id, '重建向量')
|
||||||
|
Reference in New Issue
Block a user