ocr/pp.py

#!/usr/bin/env python3.10
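"""OCR backfill worker.

Pulls image URLs from the `web_images` MySQL table, fetches each image from a
local path, Aliyun OSS or the web, runs PaddleOCR with Chinese, English,
Japanese, Korean and Russian models, keeps the language whose filtered result
scores highest (mean confidence x character count), then indexes the extracted
text in ZincSearch and writes the structured result back to MySQL.
"""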
import gc
import os
import io
import oss2
import time
import json
import base64
import dotenv
import pymysql
import requests
import numpy as np
import warnings
import logging
import paddle
from PIL import Image, ImageFile
from paddleocr import PaddleOCR
paddle.set_flags({'FLAGS_fraction_of_gpu_memory_to_use': 0.6}) # limit GPU memory usage to 60% of the card
logging.disable(logging.DEBUG) # suppress DEBUG log output
logging.disable(logging.WARNING) # suppress WARNING log output
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
oss2.defaults.connection_pool_size = 100
config = dotenv.dotenv_values(".env")
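# The .env file is expected to define:
#   ZINCSEARCH_USER, ZINCSEARCH_PASSWORD, ZINCSEARCH_HOST, ZINCSEARCH_INDEX
#   OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET, OSS_HOST, OSS_BUCKET_NAME
#   MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_NAME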
user = config['ZINCSEARCH_USER']
password = config['ZINCSEARCH_PASSWORD']
zinc_host = config['ZINCSEARCH_HOST']
index = config['ZINCSEARCH_INDEX']
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
zinc_url = f"{zinc_host}/api/{index}/_doc"
class MyEncoder(json.JSONEncoder):
    """JSON encoder for numpy types in PaddleOCR results: float32 scalars and arrays are cast to int."""
    def default(self, obj):
        if isinstance(obj, np.float32):
            return int(obj)
        if isinstance(obj, np.ndarray):
            return obj.astype(int).tolist()
        return super(MyEncoder, self).default(obj)
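# Illustration with hypothetical values:
#   json.dumps({'coordinate': np.array([[10.5, 12.2]])}, cls=MyEncoder)
# yields '{"coordinate": [[10, 12]]}' -- values are truncated to integers.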
def download_image(url: str, max_size=32767) -> Image.Image:
    """Fetch an image from a local path, OSS or the web; return an RGB PIL image or None."""
    if url.endswith('.gif') or url.endswith('.GIF'):
        print(f'Skipping GIF {url}')
        return None
    try:
        if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
            url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
            if os.path.exists(url):
                print(f'Reading image from local path {url}')
                img = Image.open(url)
            else:
                print(f'Downloading image from OSS {url}')
                oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
                bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
                img = Image.open(io.BytesIO(bucket.get_object(url).read()))
        else:
            print(f'Downloading image from the web {url}')
            response = requests.get(url)
            img = Image.open(io.BytesIO(response.content))
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if max(img.size) > max_size:
            print(f'Skipping oversized image {url}')
            return None
        return img
    except Exception as e:
        print(f'Failed to download image from {url}, error: {e}')
        return None
def connect_to_mysql():
    return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
# Chinese, English, Japanese, Korean and Russian OCR models
EN = PaddleOCR(use_angle_cls=True, lang="en")
CH = PaddleOCR(use_angle_cls=True, lang="ch")
JP = PaddleOCR(use_angle_cls=True, lang="japan")
KR = PaddleOCR(use_angle_cls=True, lang="korean")
RU = PaddleOCR(use_angle_cls=True, lang="ru")
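# All five PaddleOCR instances are created up front and stay resident for the
# lifetime of the process; each `lang` value loads its own recognition model,
# which is presumably why GPU memory is released explicitly after every OCR call below.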
# Run OCR and release memory
def process_ocr(model, image):
    result = model.ocr(image, cls=True)[0] or []
    paddle.device.cuda.empty_cache() # release cached GPU memory
    gc.collect() # force garbage collection
    return result
def process_images(conn, offset=0) -> int:
    with conn.cursor(pymysql.cursors.SSCursor) as cursor:
        cursor.execute("SELECT id, content FROM web_images WHERE text='' AND article_category_top_id=22 LIMIT 100 OFFSET %s", (offset,))
        for id, content in cursor.fetchall():
            image = download_image(content)
            if image is None:
                continue
            if isinstance(image, Image.Image):
                image = np.array(image)
            print(id, content)
            # Extract the text
            #en = EN.ocr(image, cls=True)[0] or []
            #ch = CH.ocr(image, cls=True)[0] or []
            #jp = JP.ocr(image, cls=True)[0] or []
            #kr = KR.ocr(image, cls=True)[0] or []
            #ru = RU.ocr(image, cls=True)[0] or []
            # Run each language model
            ru = process_ocr(RU, image)
            en = process_ocr(EN, image)
            ch = process_ocr(CH, image)
            jp = process_ocr(JP, image)
            kr = process_ocr(KR, image)
            # Drop lines shorter than 2 characters, purely numeric lines, and lines with confidence below 0.8
            jp = [x for x in jp if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
            kr = [x for x in kr if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
            ch = [x for x in ch if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
            en = [x for x in en if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
            ru = [x for x in ru if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > 0.8]
            print(f'Lines with confidence above 0.8: jp {len(jp)} kr {len(kr)} ch {len(ch)} en {len(en)} ru {len(ru)}')
            # Strip digits and punctuation from each line (so they are not counted)
            jp_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in jp]
            kr_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in kr]
            ch_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ch]
            en_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in en]
            ru_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ru]
            # Score per language = mean confidence x total character count
            jpx = (np.mean([x[1][1] for x in jp_ex]) if jp_ex else 0) * len(''.join([x[1][0] for x in jp_ex]))
            krx = (np.mean([x[1][1] for x in kr_ex]) if kr_ex else 0) * len(''.join([x[1][0] for x in kr_ex]))
            chx = (np.mean([x[1][1] for x in ch_ex]) if ch_ex else 0) * len(''.join([x[1][0] for x in ch_ex]))
            enx = (np.mean([x[1][1] for x in en_ex]) if en_ex else 0) * len(''.join([x[1][0] for x in en_ex]))
            rux = (np.mean([x[1][1] for x in ru_ex]) if ru_ex else 0) * len(''.join([x[1][0] for x in ru_ex]))
            # Pick the language with the highest score and build the structured result
            confidences = {'jp': jpx, 'kr': krx, 'ch': chx, 'en': enx, 'ru': rux}
            max_confidence_language = max(confidences, key=confidences.get)
            languages = {'en': en, 'ch': ch, 'jp': jp, 'kr': kr, 'ru': ru}
            data = [{'text': text[0], 'confidence': text[1], 'coordinate': coord} for coord, text in languages[max_confidence_language]]
            #print("data:", data)
            # Join into a single string and index it in ZincSearch
            obj = { "_id": str(id), "text": ' '.join([x['text'] for x in data]) }
            print("Indexing text in ZincSearch:", obj)
            res = requests.put(zinc_url, headers=headers, data=json.dumps(obj), proxies={'http': '', 'https': ''})
            print("\033[1;32m{}\033[0m".format(id) if json.loads(res.text)['message'] == 'ok' else obj["_id"], obj["text"])
            # Serialize to JSON and store in the database
            with conn.cursor() as c:
                data = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
                c.execute("UPDATE web_images SET text = %s WHERE id = %s", (data, id))
                conn.commit()
            paddle.device.cuda.empty_cache() # release cached GPU memory
            gc.collect() # force garbage collection
    paddle.device.cuda.empty_cache() # release cached GPU memory
    gc.collect() # force garbage collection
    return offset+100
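# Note: process_images always returns offset + 100, so main() walks the table in
# fixed batches of 100 rows regardless of how many rows were actually processed.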
def main():
    conn = connect_to_mysql()
    offset = 2000
    while True:
        print("LOOP:", offset)
        offset = process_images(conn, offset)
        time.sleep(0)
if __name__ == "__main__":
    main()