Compare commits
26 Commits
627300d8fd
...
main
Author | SHA1 | Date | |
---|---|---|---|
|
887e0b97bf | ||
d0b9bdb1cb | |||
|
6e34525536 | ||
|
6500fd5b92 | ||
16163e42c0 | |||
4465f1f9f0 | |||
ef0ef48b87 | |||
|
d7161c7df1 | ||
0d43a639da | |||
77589044c9 | |||
c1257c6d29 | |||
f3a5d44c57 | |||
3786304926 | |||
cbcbc64899 | |||
1e8be5dd82 | |||
137e8b556e | |||
d78b9f63b2 | |||
0dbd957454 | |||
655fc8c1c0 | |||
990f702c9f | |||
92921f99eb | |||
dfb7041746 | |||
57b63ff305 | |||
3f5a196da7 | |||
61aea5c20e | |||
65689e3df5 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -121,6 +121,7 @@ celerybeat.pid
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
data
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
@@ -129,6 +130,8 @@ ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
database
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"editor.inlineSuggest.showToolbar": "onHover"
|
||||
}
|
10
README.md
10
README.md
@@ -1,3 +1,11 @@
|
||||
# ocr
|
||||
# OCR
|
||||
|
||||
基于深度学习的文字识别提取标记
|
||||
- 由于当前没有较优的语言分类识别方案, 使用四倍算力换精度
|
||||
- 当前支持 英文 中文 日文 韩文 俄文 的识别
|
||||
- 去除纯数字和单字符以及置信度低于80的文字
|
||||
- 数据转json存储于mysql web_images 每张图像对应的 text 字段
|
||||
- 文字以空格分隔合并为字符串加入 Elasticsearch 索引
|
||||
|
||||
勿使用 paddleocr==2.9.1 存在顯存泄漏問題, 應使用 paddleocr==2.7.3
|
||||
|
||||
|
116
main.py
116
main.py
@@ -1,69 +1,59 @@
|
||||
import io
|
||||
import oss2
|
||||
import requests
|
||||
|
||||
from PIL import Image, ImageFile
|
||||
|
||||
# 读取 .env
|
||||
from dotenv import dotenv_values
|
||||
config = dotenv_values(".env")
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
# 下载图片(使用OSS下载)
|
||||
def download_image(url:str) -> Image:
|
||||
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
|
||||
try:
|
||||
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
|
||||
oss2.defaults.connection_pool_size = 100
|
||||
oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
|
||||
return Image.open(io.BytesIO(oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME']).get_object(url).read()))
|
||||
except Exception:
|
||||
print('图片下载失败:', url)
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
return Image.open(io.BytesIO(response.content))
|
||||
except Exception:
|
||||
print('图片下载失败:', url)
|
||||
return None
|
||||
|
||||
|
||||
import json
|
||||
import base64
|
||||
import pymysql
|
||||
import pymysql.cursors
|
||||
import cnocr
|
||||
import requests
|
||||
import dotenv
|
||||
|
||||
ocr = cnocr.CnOcr(rec_model_name='ch_PP-OCRv3')
|
||||
conn = pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.DictCursor)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT id, content FROM web_images LIMIT 5")
|
||||
CONFIG = dotenv.dotenv_values(".env")
|
||||
user = "admin"
|
||||
password = "Complexpass#123"
|
||||
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
|
||||
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
|
||||
index = "images"
|
||||
zinc_host = "https://zincsearch.gameui.net"
|
||||
zinc_url = f"{zinc_host}/api/{index}/_doc"
|
||||
|
||||
# 获取查询结果
|
||||
rows = cursor.fetchall()
|
||||
for row in rows:
|
||||
#print(row)
|
||||
image = download_image(row['content'])
|
||||
if image is None:
|
||||
print('图片下载失败,跳过')
|
||||
continue
|
||||
out = ocr.ocr(image)
|
||||
# 这段代码将只包含那些非空、不是纯数字且长度大于1的'text'值
|
||||
texts = [item['text'] for item in out if item['text'] and not item['text'].isdigit() and len(item['text']) > 1]
|
||||
print(texts)
|
||||
# 将数据刷入zinc, 并保持同步更新
|
||||
# 如果SQL中某一条数据被删除, 那么zinc中也要删除
|
||||
|
||||
def connect_to_mysql():
|
||||
return pymysql.connect(host=CONFIG['MYSQL_HOST'], user=CONFIG['MYSQL_USER'], password=CONFIG['MYSQL_PASSWORD'], database=CONFIG['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
|
||||
|
||||
# 查询已存在的数据写入zinc LIMIT 0,10
|
||||
def query_data(conn):
|
||||
with conn.cursor(pymysql.cursors.SSCursor) as cursor:
|
||||
cursor.execute("SELECT id, text FROM web_images WHERE text!='' AND text!='[]'")
|
||||
for id, text in cursor:
|
||||
data = { "_id": str(id), "text": " ".join([item['text'] for item in json.loads(text)]) }
|
||||
res = requests.put(zinc_url, headers=headers, data=json.dumps(data), proxies={'http': '', 'https': ''})
|
||||
print("\033[1;32m{}\033[0m".format(id) if json.loads(res.text)['message'] == 'ok' else id, data['text'])
|
||||
|
||||
query_data(connect_to_mysql())
|
||||
|
||||
|
||||
# 关闭游标和连接
|
||||
cursor.close()
|
||||
conn.close()
|
||||
# TODO 数据被删除时, zinc中也要删除
|
||||
# TODO 可以监听SQL日志, 一旦有数据变动, 就更新zinc
|
||||
# TODO 为数据之间建立事件关联, 当删除一条图像数据时, 也要删除对应的图像
|
||||
|
||||
|
||||
'''
|
||||
from cnocr import CnOcr
|
||||
|
||||
img_fp = './x.jpg'
|
||||
ocr = CnOcr(rec_model_name='ch_PP-OCRv3') # 所有参数都使用默认值
|
||||
out = ocr.ocr(img_fp)
|
||||
|
||||
print(out)
|
||||
'''
|
||||
## 查询数据
|
||||
#query = {
|
||||
# "query": {
|
||||
# "bool": {
|
||||
# "must": [
|
||||
# {
|
||||
# "query_string": {
|
||||
# "query": "City:是否"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# },
|
||||
# "sort": [
|
||||
# "-@timestamp"
|
||||
# ],
|
||||
# "from": 0,
|
||||
# "size": 100
|
||||
#}
|
||||
#zinc_url = zinc_host + "/es/" + index + "/_search"
|
||||
#res = requests.post(zinc_url, headers=headers, data=json.dumps(query), proxies={'http': '', 'https': ''})
|
||||
#print(json.dumps(json.loads(res.text), indent=4, ensure_ascii=False))
|
||||
|
177
pp.py
Executable file
177
pp.py
Executable file
@@ -0,0 +1,177 @@
|
||||
#!/usr/bin/env python3.10
|
||||
|
||||
import gc
|
||||
import os
|
||||
import io
|
||||
import oss2
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
import dotenv
|
||||
import pymysql
|
||||
import requests
|
||||
import numpy as np
|
||||
import warnings
|
||||
import logging
|
||||
import paddle
|
||||
|
||||
from PIL import Image, ImageFile
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
paddle.set_flags({'FLAGS_fraction_of_gpu_memory_to_use': 0.4}) # 限制显存占用为GPU的80%
|
||||
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
|
||||
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
|
||||
warnings.filterwarnings("ignore")
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
oss2.defaults.connection_pool_size = 100
|
||||
|
||||
config = dotenv.dotenv_values(".env")
|
||||
user = config['ZINCSEARCH_USER']
|
||||
password = config['ZINCSEARCH_PASSWORD']
|
||||
zinc_host = config['ZINCSEARCH_HOST']
|
||||
index = config['ZINCSEARCH_INDEX']
|
||||
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
|
||||
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
|
||||
zinc_url = f"{zinc_host}/api/{index}/_doc"
|
||||
|
||||
|
||||
class MyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.float32):
|
||||
return int(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.astype(int).tolist()
|
||||
return super(MyEncoder, self).default(obj)
|
||||
|
||||
|
||||
def download_image(url: str, max_size=32767) -> Image.Image:
|
||||
if url.endswith('.gif') or url.endswith('.GIF'):
|
||||
print(f'跳过GIF {url}')
|
||||
return None
|
||||
try:
|
||||
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
|
||||
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
|
||||
if os.path.exists(url):
|
||||
print(f'从本地读取图片 {url}')
|
||||
img = Image.open(url)
|
||||
else:
|
||||
print(f'从OSS下载图片 {url}')
|
||||
oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
|
||||
bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
|
||||
img = Image.open(io.BytesIO(bucket.get_object(url).read()))
|
||||
else:
|
||||
print(f'从网络下载图片 {url}')
|
||||
response = requests.get(url)
|
||||
img = Image.open(io.BytesIO(response.content))
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
if max(img.size) > max_size:
|
||||
print(f'跳过尺寸过大的图像 {url}')
|
||||
return None
|
||||
return img
|
||||
except Exception as e:
|
||||
print(f'图片从{url}下载失败,错误信息为:{e}')
|
||||
return None
|
||||
|
||||
|
||||
def connect_to_mysql():
|
||||
return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
|
||||
|
||||
|
||||
|
||||
# 中英日韩俄
|
||||
EN = PaddleOCR(use_angle_cls=True, lang="en")
|
||||
CH = PaddleOCR(use_angle_cls=True, lang="ch")
|
||||
JP = PaddleOCR(use_angle_cls=True, lang="japan")
|
||||
KR = PaddleOCR(use_angle_cls=True, lang="korean")
|
||||
RU = PaddleOCR(use_angle_cls=True, lang="ru")
|
||||
|
||||
# 运行OCR并清理内存
|
||||
def process_ocr(model, image):
|
||||
result = model.ocr(image, cls=True)[0] or []
|
||||
paddle.device.cuda.empty_cache() # 清理缓存
|
||||
gc.collect() # 强制垃圾回收
|
||||
return result
|
||||
|
||||
def process_images(conn, offset=0) -> int:
|
||||
global EN, CH, JP, KR, RU
|
||||
with conn.cursor(pymysql.cursors.SSCursor) as cursor:
|
||||
cursor.execute("SELECT id, content FROM web_images WHERE text='' AND article_category_top_id=22 LIMIT 100 OFFSET %s", (offset,))
|
||||
for id, content in cursor.fetchall():
|
||||
image = download_image(content)
|
||||
if image is None:
|
||||
continue
|
||||
if isinstance(image, Image.Image):
|
||||
image = np.array(image)
|
||||
print(id, content)
|
||||
|
||||
# 執行提取文字
|
||||
#en = EN.ocr(image, cls=True)[0] or []
|
||||
#ch = CH.ocr(image, cls=True)[0] or []
|
||||
#jp = JP.ocr(image, cls=True)[0] or []
|
||||
#kr = KR.ocr(image, cls=True)[0] or []
|
||||
#ru = RU.ocr(image, cls=True)[0] or []
|
||||
# 处理每个模型
|
||||
ru = process_ocr(RU, image)
|
||||
en = process_ocr(EN, image)
|
||||
ch = process_ocr(CH, image)
|
||||
jp = process_ocr(JP, image)
|
||||
kr = process_ocr(KR, image)
|
||||
|
||||
# 排除字符长度小于2的行, 排除纯数字的行, 排除置信度小于 0.8 的行
|
||||
jp = [x for x in jp if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
|
||||
kr = [x for x in kr if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
|
||||
ch = [x for x in ch if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
|
||||
en = [x for x in en if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
|
||||
ru = [x for x in ru if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
|
||||
print(f'置信度大于 0.8 的行: jp {len(jp)} kr {len(kr)} ch {len(ch)} en {len(en)} ru {len(ru)}')
|
||||
|
||||
# 去除字符串中包含的数字和标点(不作计数)
|
||||
jp_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in jp]
|
||||
kr_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in kr]
|
||||
ch_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ch]
|
||||
en_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in en]
|
||||
ru_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ru]
|
||||
|
||||
# 计算置信度平均值 x 计算总字数
|
||||
jpx = (np.mean([x[1][1] for x in jp_ex]) if jp_ex else 0) * len(''.join([x[1][0] for x in jp_ex]))
|
||||
krx = (np.mean([x[1][1] for x in kr_ex]) if kr_ex else 0) * len(''.join([x[1][0] for x in kr_ex]))
|
||||
chx = (np.mean([x[1][1] for x in ch_ex]) if ch_ex else 0) * len(''.join([x[1][0] for x in ch_ex]))
|
||||
enx = (np.mean([x[1][1] for x in en_ex]) if en_ex else 0) * len(''.join([x[1][0] for x in en_ex]))
|
||||
rux = (np.mean([x[1][1] for x in ru_ex]) if ru_ex else 0) * len(''.join([x[1][0] for x in ru_ex]))
|
||||
|
||||
# 找出置信度最高的语言, 结构化存储
|
||||
confidences = {'jp': jpx, 'kr': krx, 'ch': chx, 'en': enx, 'ru': rux}
|
||||
max_confidence_language = max(confidences, key=confidences.get)
|
||||
languages = {'en': en, 'ch': ch, 'jp': jp, 'kr': kr, 'ru': ru}
|
||||
data = [{'text': text[0], 'confidence': text[1], 'coordinate': coord} for coord, text in languages[max_confidence_language]]
|
||||
#print("data:", data)
|
||||
|
||||
# 转换为字符串存储到索引库
|
||||
obj = { "_id": str(id), "text": ' '.join([x['text'] for x in data]) }
|
||||
print("转换为字符串存储到索引库:", obj)
|
||||
res = requests.put(zinc_url, headers=headers, data=json.dumps(obj), proxies={'http': '', 'https': ''})
|
||||
print("\033[1;32m{}\033[0m".format(id) if json.loads(res.text)['message'] == 'ok' else obj["id"], obj["text"])
|
||||
|
||||
# 转换为 JSON 存储到数据库
|
||||
with conn.cursor() as c:
|
||||
data = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
|
||||
c.execute("UPDATE web_images SET text = %s WHERE id = %s", (data, id))
|
||||
conn.commit()
|
||||
paddle.device.cuda.empty_cache() # 清理缓存
|
||||
gc.collect() # 强制垃圾回收
|
||||
paddle.device.cuda.empty_cache() # 清理缓存
|
||||
gc.collect() # 强制垃圾回收
|
||||
return offset+100
|
||||
|
||||
def main():
|
||||
conn = connect_to_mysql()
|
||||
offset = 2000
|
||||
while True:
|
||||
print("LOOP:", offset)
|
||||
offset = process_images(conn, offset)
|
||||
time.sleep(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,82 +1,9 @@
|
||||
aiohttp==3.8.6
|
||||
aiosignal==1.3.1
|
||||
appdirs==1.4.4
|
||||
async-timeout==4.0.3
|
||||
attrs==23.1.0
|
||||
certifi==2023.7.22
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
cnocr==2.2.4.2
|
||||
cnstd==1.2.3.5
|
||||
coloredlogs==15.0.1
|
||||
contourpy==1.2.0
|
||||
cycler==0.12.1
|
||||
docker-pycreds==0.4.0
|
||||
filelock==3.13.1
|
||||
flatbuffers==23.5.26
|
||||
fonttools==4.44.0
|
||||
frozenlist==1.4.0
|
||||
fsspec==2023.10.0
|
||||
gitdb==4.0.11
|
||||
GitPython==3.1.40
|
||||
huggingface-hub==0.19.0
|
||||
humanfriendly==10.0
|
||||
idna==3.4
|
||||
Jinja2==3.1.2
|
||||
kiwisolver==1.4.5
|
||||
lightning-utilities==0.9.0
|
||||
MarkupSafe==2.1.3
|
||||
matplotlib==3.8.1
|
||||
mpmath==1.3.0
|
||||
multidict==6.0.4
|
||||
networkx==3.2.1
|
||||
numpy==1.26.1
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
nvidia-nccl-cu12==2.18.1
|
||||
nvidia-nvjitlink-cu12==12.3.52
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
onnx==1.15.0
|
||||
onnxruntime==1.16.2
|
||||
opencv-python==4.8.1.78
|
||||
packaging==23.2
|
||||
pandas==2.1.3
|
||||
wheel==0.45.0
|
||||
numpy==1.26.2
|
||||
oss2==2.18.3
|
||||
paddleocr==2.7.3
|
||||
paddlepaddle-gpu=2.6.2
|
||||
Pillow==10.1.0
|
||||
Polygon3==3.0.9.1
|
||||
protobuf==4.25.0
|
||||
psutil==5.9.6
|
||||
pyclipper==1.3.0.post5
|
||||
PyMySQL==1.1.0
|
||||
pyparsing==3.1.1
|
||||
python-dateutil==2.8.2
|
||||
pytorch-lightning==2.1.1
|
||||
pytz==2023.3.post1
|
||||
PyYAML==6.0.1
|
||||
requests==2.31.0
|
||||
scipy==1.11.3
|
||||
seaborn==0.13.0
|
||||
sentry-sdk==1.34.0
|
||||
setproctitle==1.3.3
|
||||
shapely==2.0.2
|
||||
six==1.16.0
|
||||
smmap==5.0.1
|
||||
sympy==1.12
|
||||
torch==2.1.0+cpu
|
||||
torchaudio==2.1.0
|
||||
torchmetrics==1.2.0
|
||||
torchvision==0.16.0+cpu
|
||||
tqdm==4.66.1
|
||||
triton==2.1.0
|
||||
typing_extensions==4.8.0
|
||||
tzdata==2023.3
|
||||
Unidecode==1.3.7
|
||||
urllib3==2.0.7
|
||||
wandb==0.16.0
|
||||
yarl==1.9.2
|
||||
python-dotenv==1.0.0
|
||||
Requests==2.31.0
|
||||
|
Reference in New Issue
Block a user