Compare commits

..

28 Commits

Author SHA1 Message Date
散仙
887e0b97bf 版本導致內存泄漏 2024-11-21 20:31:03 +08:00
d0b9bdb1cb 全局變量 2024-11-21 20:02:22 +08:00
散仙
6e34525536 DEBUG 2024-11-21 19:56:16 +08:00
散仙
6500fd5b92 debug 2024-11-21 18:52:59 +08:00
16163e42c0 簡化合併 2024-11-19 15:36:01 +08:00
4465f1f9f0 簡化跳行邏輯 2024-11-19 15:27:40 +08:00
ef0ef48b87 簡化合併 2024-11-19 15:18:41 +08:00
散仙
d7161c7df1 合併寫入 2024-11-19 14:03:17 +08:00
0d43a639da 持续运行 2024-11-19 04:26:17 +08:00
77589044c9 存储到zinc 2023-12-09 00:21:09 +08:00
c1257c6d29 将数据刷入zinc 2023-12-08 19:23:56 +08:00
f3a5d44c57 归并 2023-12-05 03:10:46 +08:00
3786304926 将不是三通道的图像转换为三通道 2023-12-03 17:41:25 +08:00
cbcbc64899 将不是三通道的图像转换为三通道 2023-12-03 17:40:29 +08:00
1e8be5dd82 同步更改 2023-12-03 17:34:25 +08:00
137e8b556e 多语言支持 2023-12-02 20:38:56 +08:00
d78b9f63b2 列印 2023-12-02 15:18:15 +08:00
0dbd957454 同步 2023-12-02 03:50:05 +08:00
655fc8c1c0 多语言调试 2023-12-02 02:41:36 +08:00
990f702c9f 重新生成依赖列表 2023-12-02 02:14:17 +08:00
92921f99eb 移除 text 预览 2023-12-01 02:39:13 +08:00
dfb7041746 标准 2023-12-01 02:33:36 +08:00
57b63ff305 test Elasticsearch 2023-11-20 02:19:39 +08:00
3f5a196da7 Add database functionality and update image
processing
2023-11-17 22:07:11 +08:00
61aea5c20e Remove unused OCR code 2023-11-17 04:13:54 +08:00
65689e3df5 Refactor download_image function and add print
statement for image OCR results
2023-11-16 07:31:19 +08:00
627300d8fd 筛选 2023-11-16 06:25:12 +08:00
50e45944d9 env 2023-11-16 05:37:16 +08:00
7 changed files with 252 additions and 130 deletions

3
.gitignore vendored
View File

@@ -121,6 +121,7 @@ celerybeat.pid
*.sage.py
# Environments
data
.env
.venv
env/
@@ -129,6 +130,8 @@ ENV/
env.bak/
venv.bak/
database
# Spyder project settings
.spyderproject
.spyproject

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"editor.inlineSuggest.showToolbar": "onHover"
}

View File

@@ -1,3 +1,11 @@
# ocr
# OCR
基于深度学习的文字识别提取标记
- 由于当前没有较优的语言分类识别方案, 使用四倍算力换精度
- 当前支持 英文 中文 日文 韩文 俄文 的识别
- 去除纯数字和单字符以及置信度低于80的文字
- 数据转json存储于mysql web_images 每张图像对应的 text 字段
- 文字以空格分隔合并为字符串加入 Elasticsearch 索引
勿使用 paddleocr==2.9.1 存在顯存泄漏問題, 應使用 paddleocr==2.7.3

100
main.py
View File

@@ -1,55 +1,59 @@
import io
import oss2
import requests
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# 下载图片(使用OSS下载)
def download_image(url:str) -> Image:
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
try:
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
oss2.defaults.connection_pool_size = 100
oss_host = 'oss-cn-shanghai-internal.aliyuncs.com'
oss_auth = oss2.Auth('LTAI4GH3qP6VA3QpmTYCgXEW', 'r2wz4bJty8iYfGIcFmEqlY1yon2Ruy')
return Image.open(io.BytesIO(oss2.Bucket(oss_auth, f'http://{oss_host}', 'gameui-image2').get_object(url).read()))
except Exception:
return None
else:
try:
response = requests.get(url)
return Image.open(io.BytesIO(response.content))
except Exception:
print('图片下载失败:', url)
return None
import json
import base64
import pymysql
import pymysql.cursors
import requests
import dotenv
conn = pymysql.connect(host='172.21.216.35', user='gameui', password='gameui@2022', database='gameui', cursorclass=pymysql.cursors.DictCursor)
cursor = conn.cursor()
cursor.execute("SELECT id, content FROM web_images LIMIT 10")
CONFIG = dotenv.dotenv_values(".env")
user = "admin"
password = "Complexpass#123"
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
index = "images"
zinc_host = "https://zincsearch.gameui.net"
zinc_url = f"{zinc_host}/api/{index}/_doc"
# 获取查询结果
rows = cursor.fetchall()
for row in rows:
print(row)
image = download_image(row.content)
# 将数据刷入zinc, 并保持同步更新
# 如果SQL中某一条数据被删除, 那么zinc中也要删除
# 关闭游标和连接
cursor.close()
conn.close()
def connect_to_mysql():
return pymysql.connect(host=CONFIG['MYSQL_HOST'], user=CONFIG['MYSQL_USER'], password=CONFIG['MYSQL_PASSWORD'], database=CONFIG['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
# 查询已存在的数据写入zinc LIMIT 0,10
def query_data(conn):
with conn.cursor(pymysql.cursors.SSCursor) as cursor:
cursor.execute("SELECT id, text FROM web_images WHERE text!='' AND text!='[]'")
for id, text in cursor:
data = { "_id": str(id), "text": " ".join([item['text'] for item in json.loads(text)]) }
res = requests.put(zinc_url, headers=headers, data=json.dumps(data), proxies={'http': '', 'https': ''})
print("\033[1;32m{}\033[0m".format(id) if json.loads(res.text)['message'] == 'ok' else id, data['text'])
query_data(connect_to_mysql())
'''
from cnocr import CnOcr
# TODO 数据被删除时, zinc中也要删除
# TODO 可以监听SQL日志, 一旦有数据变动, 就更新zinc
# TODO 为数据之间建立事件关联, 当删除一条图像数据时, 也要删除对应的图像
img_fp = './x.jpg'
ocr = CnOcr(rec_model_name='ch_PP-OCRv3') # 所有参数都使用默认值
out = ocr.ocr(img_fp)
print(out)
'''
## 查询数据
#query = {
# "query": {
# "bool": {
# "must": [
# {
# "query_string": {
# "query": "City:是否"
# }
# }
# ]
# }
# },
# "sort": [
# "-@timestamp"
# ],
# "from": 0,
# "size": 100
#}
#zinc_url = zinc_host + "/es/" + index + "/_search"
#res = requests.post(zinc_url, headers=headers, data=json.dumps(query), proxies={'http': '', 'https': ''})
#print(json.dumps(json.loads(res.text), indent=4, ensure_ascii=False))

177
pp.py Executable file
View File

@@ -0,0 +1,177 @@
#!/usr/bin/env python3.10
import gc
import os
import io
import oss2
import time
import json
import base64
import dotenv
import pymysql
import requests
import numpy as np
import warnings
import logging
import paddle
from PIL import Image, ImageFile
from paddleocr import PaddleOCR
paddle.set_flags({'FLAGS_fraction_of_gpu_memory_to_use': 0.4}) # 限制显存占用为GPU的80%
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
oss2.defaults.connection_pool_size = 100
config = dotenv.dotenv_values(".env")
user = config['ZINCSEARCH_USER']
password = config['ZINCSEARCH_PASSWORD']
zinc_host = config['ZINCSEARCH_HOST']
index = config['ZINCSEARCH_INDEX']
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
zinc_url = f"{zinc_host}/api/{index}/_doc"
class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.float32):
return int(obj)
if isinstance(obj, np.ndarray):
return obj.astype(int).tolist()
return super(MyEncoder, self).default(obj)
def download_image(url: str, max_size=32767) -> Image.Image:
if url.endswith('.gif') or url.endswith('.GIF'):
print(f'跳过GIF {url}')
return None
try:
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
if os.path.exists(url):
print(f'从本地读取图片 {url}')
img = Image.open(url)
else:
print(f'从OSS下载图片 {url}')
oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
img = Image.open(io.BytesIO(bucket.get_object(url).read()))
else:
print(f'从网络下载图片 {url}')
response = requests.get(url)
img = Image.open(io.BytesIO(response.content))
if img.mode != 'RGB':
img = img.convert('RGB')
if max(img.size) > max_size:
print(f'跳过尺寸过大的图像 {url}')
return None
return img
except Exception as e:
print(f'图片从{url}下载失败,错误信息为:{e}')
return None
def connect_to_mysql():
return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
# 中英日韩俄
EN = PaddleOCR(use_angle_cls=True, lang="en")
CH = PaddleOCR(use_angle_cls=True, lang="ch")
JP = PaddleOCR(use_angle_cls=True, lang="japan")
KR = PaddleOCR(use_angle_cls=True, lang="korean")
RU = PaddleOCR(use_angle_cls=True, lang="ru")
# 运行OCR并清理内存
def process_ocr(model, image):
result = model.ocr(image, cls=True)[0] or []
paddle.device.cuda.empty_cache() # 清理缓存
gc.collect() # 强制垃圾回收
return result
def process_images(conn, offset=0) -> int:
global EN, CH, JP, KR, RU
with conn.cursor(pymysql.cursors.SSCursor) as cursor:
cursor.execute("SELECT id, content FROM web_images WHERE text='' AND article_category_top_id=22 LIMIT 100 OFFSET %s", (offset,))
for id, content in cursor.fetchall():
image = download_image(content)
if image is None:
continue
if isinstance(image, Image.Image):
image = np.array(image)
print(id, content)
# 執行提取文字
#en = EN.ocr(image, cls=True)[0] or []
#ch = CH.ocr(image, cls=True)[0] or []
#jp = JP.ocr(image, cls=True)[0] or []
#kr = KR.ocr(image, cls=True)[0] or []
#ru = RU.ocr(image, cls=True)[0] or []
# 处理每个模型
ru = process_ocr(RU, image)
en = process_ocr(EN, image)
ch = process_ocr(CH, image)
jp = process_ocr(JP, image)
kr = process_ocr(KR, image)
# 排除字符长度小于2的行, 排除纯数字的行, 排除置信度小于 0.8 的行
jp = [x for x in jp if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
kr = [x for x in kr if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
ch = [x for x in ch if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
en = [x for x in en if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
ru = [x for x in ru if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
print(f'置信度大于 0.8 的行: jp {len(jp)} kr {len(kr)} ch {len(ch)} en {len(en)} ru {len(ru)}')
# 去除字符串中包含的数字和标点(不作计数)
jp_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in jp]
kr_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in kr]
ch_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ch]
en_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in en]
ru_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ru]
# 计算置信度平均值 x 计算总字数
jpx = (np.mean([x[1][1] for x in jp_ex]) if jp_ex else 0) * len(''.join([x[1][0] for x in jp_ex]))
krx = (np.mean([x[1][1] for x in kr_ex]) if kr_ex else 0) * len(''.join([x[1][0] for x in kr_ex]))
chx = (np.mean([x[1][1] for x in ch_ex]) if ch_ex else 0) * len(''.join([x[1][0] for x in ch_ex]))
enx = (np.mean([x[1][1] for x in en_ex]) if en_ex else 0) * len(''.join([x[1][0] for x in en_ex]))
rux = (np.mean([x[1][1] for x in ru_ex]) if ru_ex else 0) * len(''.join([x[1][0] for x in ru_ex]))
# 找出置信度最高的语言, 结构化存储
confidences = {'jp': jpx, 'kr': krx, 'ch': chx, 'en': enx, 'ru': rux}
max_confidence_language = max(confidences, key=confidences.get)
languages = {'en': en, 'ch': ch, 'jp': jp, 'kr': kr, 'ru': ru}
data = [{'text': text[0], 'confidence': text[1], 'coordinate': coord} for coord, text in languages[max_confidence_language]]
#print("data:", data)
# 转换为字符串存储到索引库
obj = { "_id": str(id), "text": ' '.join([x['text'] for x in data]) }
print("转换为字符串存储到索引库:", obj)
res = requests.put(zinc_url, headers=headers, data=json.dumps(obj), proxies={'http': '', 'https': ''})
print("\033[1;32m{}\033[0m".format(id) if json.loads(res.text)['message'] == 'ok' else obj["id"], obj["text"])
# 转换为 JSON 存储到数据库
with conn.cursor() as c:
data = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
c.execute("UPDATE web_images SET text = %s WHERE id = %s", (data, id))
conn.commit()
paddle.device.cuda.empty_cache() # 清理缓存
gc.collect() # 强制垃圾回收
paddle.device.cuda.empty_cache() # 清理缓存
gc.collect() # 强制垃圾回收
return offset+100
def main():
conn = connect_to_mysql()
offset = 2000
while True:
print("LOOP:", offset)
offset = process_images(conn, offset)
time.sleep(0)
if __name__ == "__main__":
main()

View File

@@ -1,82 +1,9 @@
aiohttp==3.8.6
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
certifi==2023.7.22
charset-normalizer==3.3.2
click==8.1.7
cnocr==2.2.4.2
cnstd==1.2.3.5
coloredlogs==15.0.1
contourpy==1.2.0
cycler==0.12.1
docker-pycreds==0.4.0
filelock==3.13.1
flatbuffers==23.5.26
fonttools==4.44.0
frozenlist==1.4.0
fsspec==2023.10.0
gitdb==4.0.11
GitPython==3.1.40
huggingface-hub==0.19.0
humanfriendly==10.0
idna==3.4
Jinja2==3.1.2
kiwisolver==1.4.5
lightning-utilities==0.9.0
MarkupSafe==2.1.3
matplotlib==3.8.1
mpmath==1.3.0
multidict==6.0.4
networkx==3.2.1
numpy==1.26.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
onnx==1.15.0
onnxruntime==1.16.2
opencv-python==4.8.1.78
packaging==23.2
pandas==2.1.3
wheel==0.45.0
numpy==1.26.2
oss2==2.18.3
paddleocr==2.7.3
paddlepaddle-gpu=2.6.2
Pillow==10.1.0
Polygon3==3.0.9.1
protobuf==4.25.0
psutil==5.9.6
pyclipper==1.3.0.post5
PyMySQL==1.1.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytorch-lightning==2.1.1
pytz==2023.3.post1
PyYAML==6.0.1
requests==2.31.0
scipy==1.11.3
seaborn==0.13.0
sentry-sdk==1.34.0
setproctitle==1.3.3
shapely==2.0.2
six==1.16.0
smmap==5.0.1
sympy==1.12
torch==2.1.0+cpu
torchaudio==2.1.0
torchmetrics==1.2.0
torchvision==0.16.0+cpu
tqdm==4.66.1
triton==2.1.0
typing_extensions==4.8.0
tzdata==2023.3
Unidecode==1.3.7
urllib3==2.0.7
wandb==0.16.0
yarl==1.9.2
python-dotenv==1.0.0
Requests==2.31.0

BIN
x.jpg

Binary file not shown.

Before

Width:  |  Height:  |  Size: 61 KiB