106 lines
2.9 KiB
Python
106 lines
2.9 KiB
Python
import asyncio
import datetime
import os
import sqlite3
import time
import uuid

from pydantic import BaseModel

from configs.config import SQLITE3_PATH
|
|
|
|
# Module-wide SQLite connection to tasks.db inside SQLITE3_PATH.
# check_same_thread=False turns off sqlite3's same-thread check so the
# connection object may be used from threads other than the creating one.
__taskdb = sqlite3.connect(
    os.path.join(SQLITE3_PATH, 'tasks.db'),
    check_same_thread=False,
)
# Rows are returned as sqlite3.Row, which allows access by column name.
__taskdb.row_factory = sqlite3.Row

# Create the tasks table; a no-op when the table already exists.
__taskdb.execute("""
    CREATE TABLE IF NOT EXISTS tasks (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT NOT NULL,
        user_id INTEGER NOT NULL,
        description TEXT NOT NULL,
        created_at TEXT NOT NULL,
        updated_at TEXT NOT NULL
    )
""")
|
|
|
|
|
|
# Task form model
|
|
class TaskForm(BaseModel):
    """Form payload describing a task: its name and its description."""

    # Name of the task.
    name: str
    # Free-text description of the task.
    description: str
|
|
|
|
|
|
class Task(BaseModel):
    """A task record together with its runtime state.

    ``TaskManager.add_task`` builds instances of this model and fills
    ``id`` with a UUID4 string and both timestamps with
    ``str(datetime.datetime.now())``.
    """

    # --- identity and ownership ---
    id: str
    name: str
    user_id: int
    description: str

    # --- timestamps, stored as strings ---
    created_at: str
    updated_at: str

    # --- runtime state ---
    status: str = "pending"
    progress: int = 0
    # Connection objects attached through TaskManager.add_websocket.
    websockets: list = []
|
|
|
|
|
|
class TaskManager:
    """In-memory registry of tasks, keyed by task id.

    The registry dict is a class attribute, so every ``TaskManager``
    instance reads and writes the same set of tasks.
    """

    __tasks = {}

    # Variant that builds a pydantic model.
    def add_task(self, name: str, user_id: int, description: str) -> "Task":
        """Create a task, store it in the registry and return it.

        The id is a fresh UUID4 string; ``created_at`` and ``updated_at``
        are both set to the current local time rendered with ``str()``.
        """
        task_id = str(uuid.uuid4())
        timestamp = str(datetime.datetime.now())
        task = Task(
            id=task_id,
            name=name,
            user_id=user_id,
            description=description,
            created_at=timestamp,
            updated_at=timestamp,
        )
        self.__tasks[task_id] = task
        return task

    def get_task(self, task_id: str):
        """Return the task as a dict (``task.dict()``), or None if unknown."""
        task = self.__tasks.get(task_id)
        if task is None:
            return None
        return task.dict()

    # Variant that returns the stored objects directly.
    def get_tasks(self, user_id: int = None):
        """Return the stored task objects, optionally filtered by owner.

        Args:
            user_id: When given (including ``0``), only tasks whose
                ``user_id`` equals it are returned. When None, every
                task is returned.
        """
        tasks = self.__tasks.values()
        # Compare against None explicitly: a plain truthiness test would
        # treat user_id 0 as "no filter" and leak every user's tasks.
        if user_id is not None:
            return [task for task in tasks if task.user_id == user_id]
        return list(tasks)

    def remove_task(self, task_id: str):
        """Delete a task; return True if it existed, False otherwise."""
        return self.__tasks.pop(task_id, None) is not None

    # Correct variant (returns directly, without awaiting another coroutine).
    async def add_websocket(self, task_id: str, websocket):
        """Attach a websocket to the task identified by ``task_id``.

        Raises:
            KeyError: If no task with ``task_id`` is registered.
        """
        print(self.__tasks)
        # NOTE(review): the original author observed that the count printed
        # below is 1 on every call, which is wrong; the cause is not visible
        # in this file.
        task = self.__tasks[task_id]
        task.websockets.append(websocket)
        print(f"当前连接数: {len(task.websockets)}")

        # Disabled: start listening once the first websocket is attached.
        # if len(task.websockets) > 0 and task.progress == 0 and task.status == "pending":
        #     print("task is not being listened to, starting listener")
        #     await self.start_listen_task(task)

    # Remove a websocket.
    def remove_websocket(self, task_id: str, websocket):
        """Detach a websocket from a task.

        Does nothing when the task is unknown. A websocket that is not
        (or no longer) attached is ignored rather than raising
        ``ValueError``, so a repeated disconnect is harmless.
        """
        task = self.__tasks.get(task_id)
        if task is None:
            return
        print("连接已断开, 移除 websocket")
        if websocket in task.websockets:
            task.websockets.remove(websocket)
        print(f"当前剩余连接数: {len(task.websockets)}")

    # Correct variant.
    async def start_listen_task(self, task: "Task"):
        """Wait 2.5 s, then report on the task if it still has listeners.

        The wait uses ``asyncio.sleep`` so the event loop keeps serving
        other coroutines; a blocking ``time.sleep`` here would stall it.
        """
        while True:
            print(f"start listen task: {len(task.websockets)}")
            await asyncio.sleep(2.5)
            if len(task.websockets) == 0:
                break
            print(f"get task status:{task.id}")
            # NOTE(review): this unconditional break means only one pass
            # ever runs; presumably a placeholder for real polling - confirm.
            break
|
|
|