Files
ocr/pp.py
2024-11-19 04:26:17 +08:00

197 lines
8.1 KiB
Python
Executable File

#!/usr/bin/env python3
import os
import io
import oss2
import time
import json
import base64
import dotenv
import pymysql
import requests
import numpy as np
import warnings
import logging
from PIL import Image, ImageFile
from paddleocr import PaddleOCR
# Silence library output: logging.disable(level) suppresses every record at
# `level` or below, so after the second call DEBUG, INFO and WARNING are all off.
logging.disable(logging.DEBUG)  # turn off DEBUG log output
logging.disable(logging.WARNING)  # turn off WARNING (and lower) log output
warnings.filterwarnings("ignore")
# Let Pillow decode images whose data ends early instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
oss2.defaults.connection_pool_size = 100

# Settings are read from ".env" relative to the current working directory.
config = dotenv.dotenv_values(".env")

# ZincSearch credentials and target index (KeyError if a key is missing).
user, password, zinc_host, index = (
    config[key]
    for key in (
        'ZINCSEARCH_USER',
        'ZINCSEARCH_PASSWORD',
        'ZINCSEARCH_HOST',
        'ZINCSEARCH_INDEX',
    )
)
# HTTP Basic auth value: base64 of "user:password".
bas64encoded_creds = base64.b64encode(f"{user}:{password}".encode("utf-8")).decode("utf-8")
headers = {
    "Content-type": "application/json",
    "Authorization": "Basic {}".format(bas64encoded_creds),
}
# ZincSearch document endpoint for the configured index.
zinc_url = "{}/api/{}/_doc".format(zinc_host, index)
class MyEncoder(json.JSONEncoder):
    """JSON encoder that understands the NumPy values found in OCR results.

    * NumPy integer scalars become ``int``.
    * NumPy floating scalars become ``float``, so fractional values such as
      recognition confidences survive serialization instead of being
      truncated to 0 (the previous ``int(np.float32)`` conversion did that).
    * NumPy booleans become ``bool``.
    * ``ndarray`` values become nested lists of ``int``. ``astype(int)``
      truncates float elements toward zero, which suits pixel coordinates.

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.astype(int).tolist()
        return super().default(obj)
def download_image(url: str, max_size=32767, timeout=30) -> "Image.Image | None":
    """Fetch ``url`` and return it as a fully decoded RGB Pillow image.

    URLs under ``image.gameuiux.cn`` are mapped to an object key (the URL path).
    The key is read from the local filesystem when a file with that relative
    path exists. Otherwise it is downloaded from the OSS bucket configured in
    ``config``. Any other URL is fetched over HTTP(S).

    Args:
        url: image URL.
        max_size: maximum allowed width or height in pixels; larger images
            are skipped.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        An RGB ``Image.Image``, or ``None`` if any of these hold (the reason
        is printed):
        * the URL is a GIF;
        * the image is larger than ``max_size``;
        * fetching or decoding fails.
    """
    if url.lower().endswith('.gif'):
        print(f'跳过GIF {url}')
        return None
    cdn_host = 'image.gameuiux.cn/'
    # Best-effort boundary: any failure means "skip this image", never a crash.
    try:
        if url.startswith(('http://' + cdn_host, 'https://' + cdn_host)):
            # Strip scheme + host; the remaining path is the object key.
            url = url.split(cdn_host, 1)[1]
            if os.path.exists(url):
                # Read eagerly so no file handle is left open by Pillow.
                with open(url, 'rb') as fh:
                    raw = fh.read()
            else:
                print(f'从OSS下载图片 {url}')
                oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
                bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
                raw = bucket.get_object(url).read()
        else:
            response = requests.get(url, timeout=timeout)
            # Report 4xx/5xx directly instead of feeding an HTML error page to Pillow.
            response.raise_for_status()
            raw = response.content
        img = Image.open(io.BytesIO(raw))
        # Image.open only parses the header, so reject oversized images before
        # paying for a full decode / RGB conversion.
        if img.size[0] > max_size or img.size[1] > max_size:
            print(f'跳过尺寸过大的图像 {url}')
            return None
        # Pillow decodes lazily; force it here so corrupt data is caught by
        # this try instead of failing later in the caller (np.array(image)).
        img.load()
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img
    except Exception as e:
        print(f'图片从{url}下载失败,错误信息为:{e}')
        return None
def connect_to_mysql():
    """Open a MySQL connection from the MYSQL_* settings in ``config``.

    The connection's default cursor class is ``pymysql.cursors.SSDictCursor``.
    Callers may still ask ``conn.cursor(...)`` for a different cursor class.
    """
    settings = {
        'host': config['MYSQL_HOST'],
        'user': config['MYSQL_USER'],
        'password': config['MYSQL_PASSWORD'],
        'database': config['MYSQL_NAME'],
    }
    return pymysql.connect(cursorclass=pymysql.cursors.SSDictCursor, **settings)
# One angle-classifying PaddleOCR engine per supported language, constructed
# in this order: English, Chinese, Japanese, Korean, Russian.
EN, CH, JP, KR, RU = [
    PaddleOCR(use_angle_cls=True, lang=code)
    for code in ("en", "ch", "japan", "korean", "ru")
]
# Minimum OCR confidence for a recognised line to be kept.
_MIN_CONFIDENCE = 0.8
# Digits and punctuation ignored when counting characters for language scoring.
_STRIP_TABLE = str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')
# Timeout (seconds) for the ZincSearch indexing request.
_ZINC_TIMEOUT = 30
# IDs whose image could not be fetched/decoded during this process's lifetime.
# They are excluded from later SELECTs. Without this, permanently failing rows
# (GIFs, oversized images, dead links) keep text=''. Once there are 10 of them,
# the LIMIT 10 query returns the same rows forever and the worker stalls.
# Held in memory only, so a restart retries them.
_SKIPPED_IDS = set()


def _filter_lines(lines):
    """Return the OCR lines worth keeping from one engine's first-page result.

    ``lines`` holds items indexed as ``[box, (text, confidence)]``, or is
    ``None`` when the engine found nothing. A line is kept when all of these
    hold:
    * its text is longer than one character;
    * its text is not purely digits;
    * its confidence exceeds ``_MIN_CONFIDENCE``.
    """
    if lines is None:
        return []
    return [
        line for line in lines
        if len(line[1][0]) > 1
        and not line[1][0].isdigit()
        and line[1][1] > _MIN_CONFIDENCE
    ]


def _language_score(lines):
    """Score an engine's kept lines as mean confidence x character count.

    Characters are counted after removing the digits/punctuation in
    ``_STRIP_TABLE``. Returns 0 for an empty list.
    """
    if not lines:
        return 0
    mean_confidence = np.mean([line[1][1] for line in lines])
    char_count = sum(len(line[1][0].translate(_STRIP_TABLE)) for line in lines)
    return mean_confidence * char_count


def _best_language(scores):
    """Return the language of the highest-scoring ``(language, score)`` pair.

    On a tie the pair listed last wins.
    """
    best_language, best_score = None, None
    for language, score in scores:
        if best_score is None or score >= best_score:
            best_language, best_score = language, score
    return best_language


def _zinc_indexed_ok(response):
    """Return True when ZincSearch replied with a JSON object whose message is 'ok'."""
    try:
        body = json.loads(response.text)
    except ValueError:
        return False
    return isinstance(body, dict) and body.get('message') == 'ok'


def process_images(conn):
    """OCR one batch (at most 10) of unprocessed images and store the results.

    Selects ``web_images`` rows with empty ``text`` and
    ``article_category_top_id=22``. Each image is run through every OCR engine
    (EN, CH, JP, KR, RU), and the lines of the engine with the best
    ``_language_score`` are kept. The function then:

    * indexes the kept text, joined by spaces, in ZincSearch under the row id;
    * writes the kept lines as JSON (``text``/``confidence``/``coordinate``)
      to ``web_images.text`` and commits.

    Rows whose image cannot be downloaded are remembered in ``_SKIPPED_IDS``
    and left unchanged in the database.
    """
    query = ("SELECT id, content FROM web_images "
             "WHERE text='' AND article_category_top_id=22")
    params = None
    if _SKIPPED_IDS:
        placeholders = ", ".join(["%s"] * len(_SKIPPED_IDS))
        query += f" AND id NOT IN ({placeholders})"
        params = tuple(_SKIPPED_IDS)
    query += " LIMIT 10"
    # Drain and close the unbuffered cursor before issuing UPDATEs on the
    # same connection.
    with conn.cursor(pymysql.cursors.SSCursor) as cursor:
        cursor.execute(query, params)
        rows = cursor.fetchall()

    for image_id, content in rows:
        image = download_image(content)
        if image is None:
            _SKIPPED_IDS.add(image_id)
            continue
        pixels = np.array(image)
        print('---------------------', image_id, content)

        engines = (('en', EN), ('ch', CH), ('jp', JP), ('kr', KR), ('ru', RU))
        kept = {lang: _filter_lines(engine.ocr(pixels, cls=True)[0]) for lang, engine in engines}
        print(f'置信度大于 0.8 的行: jp {len(kept["jp"])} kr {len(kept["kr"])} ch {len(kept["ch"])} en {len(kept["en"])} ru {len(kept["ru"])}')

        # Order matters: on equal scores the later language wins.
        scores = [(lang, _language_score(kept[lang])) for lang in ('jp', 'kr', 'ch', 'en', 'ru')]
        for lang, score in scores:
            print(lang, score)
        best = _best_language(scores)

        data = []
        for box, recognition in kept[best]:
            print(best, box, recognition)
            data.append({'text': recognition[0], 'confidence': recognition[1], 'coordinate': box})

        # Full-text index entry: all recognised lines joined by spaces.
        doc = {"_id": str(image_id), "text": ' '.join(item['text'] for item in data)}
        res = requests.put(zinc_url, headers=headers, data=json.dumps(doc),
                           proxies={'http': '', 'https': ''}, timeout=_ZINC_TIMEOUT)
        label = "\033[1;32m{}\033[0m".format(image_id) if _zinc_indexed_ok(res) else image_id
        print(label, doc["text"])

        # Persist the structured result as JSON in the database.
        payload = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
        with conn.cursor() as cursor:
            cursor.execute("UPDATE web_images SET text = %s WHERE id = %s", (payload, image_id))
        conn.commit()
def main():
    """Run the OCR worker forever: process a batch, then sleep 10 seconds.

    The MySQL connection is pinged before every batch, and reopened if the
    server dropped it. It is closed when the loop exits, including on
    KeyboardInterrupt or an error.
    """
    conn = connect_to_mysql()
    try:
        while True:
            # Long-running worker: transparently reconnect after a server
            # restart or idle disconnect instead of crashing.
            conn.ping(reconnect=True)
            process_images(conn)
            time.sleep(10)
    finally:
        # Guard so a dead connection does not mask the original exception.
        if conn.open:
            conn.close()


if __name__ == "__main__":
    main()