转移
This commit is contained in:
184
models/task.py
Normal file
184
models/task.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import datetime
|
||||
from pydantic import BaseModel
|
||||
from configs.config import SQLITE3_PATH
|
||||
import uuid
|
||||
import time
|
||||
import requests
|
||||
|
||||
|
||||
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
|
||||
__taskdb.row_factory = sqlite3.Row
|
||||
|
||||
# 创建任务表
|
||||
__taskdb.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# 任务表单模型
|
||||
class TaskForm(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
user_id: int
|
||||
description: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
status: str = "pending"
|
||||
progress: int = 0
|
||||
websockets: list = []
|
||||
|
||||
|
||||
class TaskManager:
|
||||
__tasks = {}
|
||||
|
||||
# 使用pydantic模型的写法
|
||||
def add_task(self, name: str, user_id: int, description: str):
|
||||
id = str(uuid.uuid4())
|
||||
date = str(datetime.datetime.now())
|
||||
task = Task(id=id, name=name, user_id=user_id, description=description, created_at=date, updated_at=date)
|
||||
self.__tasks[id] = task
|
||||
return task
|
||||
|
||||
def get_task(self, task_id: str):
|
||||
task = self.__tasks.get(task_id)
|
||||
if task:
|
||||
return task.dict()
|
||||
else:
|
||||
return None
|
||||
|
||||
# 直接返回对象的写法
|
||||
def get_tasks(self, user_id: int=None):
|
||||
if user_id:
|
||||
return [task for task in self.__tasks.values() if task.user_id == user_id]
|
||||
else:
|
||||
return [task for task in self.__tasks.values()]
|
||||
|
||||
|
||||
def remove_task(self, task_id: str):
|
||||
if self.__tasks.get(task_id):
|
||||
self.__tasks.pop(task_id)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# 正确的写法(不等待async函数直接返回结果)
|
||||
async def add_websocket(self, task_id: str, websocket):
|
||||
print(self.__tasks)
|
||||
# 每次调用连接数都是1, 这是错误的
|
||||
self.__tasks[task_id].websockets.append(websocket)
|
||||
print(f"当前连接数: {len(self.__tasks[task_id].websockets)}")
|
||||
|
||||
|
||||
#if len(task.websockets) > 0 and task.progress == 0 and task.status == "pending":
|
||||
# print("任务未被监听, 开启监听")
|
||||
# await self.start_listen_task(task)
|
||||
|
||||
# 移除 websocket
|
||||
def remove_websocket(self, task_id: str, websocket):
|
||||
task = self.__tasks.get(task_id)
|
||||
if task:
|
||||
print("连接已断开, 移除 websocket")
|
||||
task.websockets.remove(websocket)
|
||||
print(f"当前剩余连接数: {len(task.websockets)}")
|
||||
|
||||
# 正确的写法
|
||||
async def start_listen_task(self, task: Task):
|
||||
while True:
|
||||
print(f"start listen task: {len(task.websockets)}")
|
||||
time.sleep(2.5)
|
||||
if len(task.websockets) == 0:
|
||||
break
|
||||
print(f"get task status:{task.id}")
|
||||
break
|
||||
#data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
|
||||
#print(data)
|
||||
## 通过 http 请求获取任务的状态
|
||||
## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
|
||||
#if task.status != data.status:
|
||||
# task.status = data.status
|
||||
# print("send task status to websocket")
|
||||
# for websocket in task.websockets:
|
||||
# await websocket.send_json(task)
|
||||
|
||||
|
||||
|
||||
## 任务实体模型
|
||||
#class TaskModel(BaseModel):
|
||||
# id: str
|
||||
# name: str
|
||||
# user_id: int
|
||||
# description: str
|
||||
# created_at: str
|
||||
# updated_at: str
|
||||
# websockets: list=[]
|
||||
#
|
||||
## 向数据库中添加任务
|
||||
#def add_task(name, user_id, description):
|
||||
# print(name, user_id, description)
|
||||
# date = datetime.datetime.now()
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
|
||||
# id = cursor.lastrowid
|
||||
# cursor.close()
|
||||
# __taskdb.commit()
|
||||
# return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
|
||||
#
|
||||
#
|
||||
## 从数据库中获取任务列表
|
||||
#def get_tasks(user_id: int=None):
|
||||
# cursor = __taskdb.cursor()
|
||||
# if user_id:
|
||||
# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
|
||||
# else:
|
||||
# cursor.execute("SELECT * FROM tasks")
|
||||
# tasks = cursor.fetchall()
|
||||
# cursor.close()
|
||||
# return [dict(task) for task in tasks]
|
||||
#
|
||||
#
|
||||
## 从数据库中获取指定任务
|
||||
#def get_task(task_id: int):
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
|
||||
# task = cursor.fetchone()
|
||||
# cursor.close()
|
||||
# return task
|
||||
#
|
||||
#
|
||||
## 更新数据库中的任务
|
||||
#def update_task(task: TaskModel):
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
|
||||
# task.name,
|
||||
# task.description,
|
||||
# datetime.datetime.now(),
|
||||
# task.id
|
||||
# ))
|
||||
# cursor.close()
|
||||
# __taskdb.commit()
|
||||
# return task
|
||||
#
|
||||
#
|
||||
## 从数据库中删除任务
|
||||
#def delete_task(task_id: int):
|
||||
# cursor = __taskdb.cursor()
|
||||
# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
|
||||
# cursor.close()
|
||||
# __taskdb.commit()
|
||||
# return task_id
|
||||
|
||||
|
Reference in New Issue
Block a user