ocr/pp.py
#!/usr/bin/env python3
import io
import requests
import oss2
import pymysql
import json
import numpy as np
import warnings
import logging
from PIL import Image, ImageFile
from dotenv import dotenv_values
from elasticsearch import Elasticsearch
from paddleocr import PaddleOCR
logging.disable(logging.DEBUG)  # suppress DEBUG log output
logging.disable(logging.WARNING)  # suppress WARNING log output
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
config = dotenv_values(".env")
oss2.defaults.connection_pool_size = 100
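# All credentials come from a local .env file read by dotenv_values above.
# A sketch of the keys this script expects (values are placeholders, not real settings):
#
#   OSS_ACCESS_KEY_ID=xxxxxxxx
#   OSS_ACCESS_KEY_SECRET=xxxxxxxx
#   OSS_HOST=oss-cn-example.aliyuncs.com
#   OSS_BUCKET_NAME=my-bucket
#   MYSQL_HOST=127.0.0.1
#   MYSQL_USER=root
#   MYSQL_PASSWORD=xxxxxxxx
#   MYSQL_NAME=mydb
#   ELASTICSEARCH_HOST=https://localhost:9200
#   ELASTICSEARCH_USERNAME=elastic
#   ELASTICSEARCH_PASSWORD=xxxxxxxx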
class MyEncoder(json.JSONEncoder):
    """JSON encoder for the numpy types that appear in OCR results."""

    def default(self, obj):
        if isinstance(obj, np.float32):
            return float(obj)  # keep confidence scores as floats rather than truncating to int
        if isinstance(obj, np.ndarray):
            return obj.astype(int).tolist()
        return super(MyEncoder, self).default(obj)
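# A minimal sketch of what the encoder handles: box coordinates arrive as an ndarray
# and confidences may arrive as np.float32 (values below are illustrative only):
#
#   json.dumps({'coordinate': np.array([[0, 0], [10, 0]]), 'confidence': np.float32(0.97)},
#              cls=MyEncoder)
#   # -> roughly '{"coordinate": [[0, 0], [10, 0]], "confidence": 0.9700000286102295}'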
def download_image(url: str) -> Image.Image:
    try:
        if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
            # Images on our own domain are fetched directly from the OSS bucket.
            url = url.replace('http://image.gameuiux.cn/',
                              '').replace('https://image.gameuiux.cn/', '')
            oss_auth = oss2.Auth(
                config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
            bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
            return Image.open(io.BytesIO(bucket.get_object(url).read()))
        else:
            response = requests.get(url)
            return Image.open(io.BytesIO(response.content))
    except Exception as e:
        print(f'Failed to download image from {url}: {e}')
        return None
def connect_to_mysql():
    return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'],
                           password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'],
                           cursorclass=pymysql.cursors.SSDictCursor)
def save_text(conn, id: int, text: str):
    with conn.cursor() as cursor:
        cursor.execute(
            "UPDATE web_images SET text = %s WHERE id = %s", (text, id))
# OCR models: Chinese, English, Japanese, Korean, Russian
EN = PaddleOCR(use_angle_cls=True, lang="en")
CH = PaddleOCR(use_angle_cls=True, lang="ch")
JP = PaddleOCR(use_angle_cls=True, lang="japan")
KR = PaddleOCR(use_angle_cls=True, lang="korean")
RU = PaddleOCR(use_angle_cls=True, lang="ru")
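# Each OCR result line is consumed below as [box, (text, confidence)]:
# x[0] is the quadrilateral box, x[1][0] the recognized text, x[1][1] the score.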
def process_images(conn, es):
    with conn.cursor(pymysql.cursors.SSCursor) as cursor:
        cursor.execute("SELECT id, content FROM web_images LIMIT 0,10")  # WHERE text!=''
        for id, content in cursor.fetchall():
            image = download_image(content)
            if image is None:
                continue
            if isinstance(image, Image.Image):
                image = np.array(image)
            print('---------------------', id, content)
            # Run every language model; [0] unpacks the single-image result.
            # Some PaddleOCR versions return [None] when nothing is detected,
            # so fall back to an empty list (defensive addition).
            en = EN.ocr(image, cls=True)[0] or []
            ch = CH.ocr(image, cls=True)[0] or []
            jp = JP.ocr(image, cls=True)[0] or []
            kr = KR.ocr(image, cls=True)[0] or []
            ru = RU.ocr(image, cls=True)[0] or []
            # Drop lines shorter than 2 characters
            jp = [x for x in jp if len(x[1][0]) > 1]
            kr = [x for x in kr if len(x[1][0]) > 1]
            ch = [x for x in ch if len(x[1][0]) > 1]
            en = [x for x in en if len(x[1][0]) > 1]
            ru = [x for x in ru if len(x[1][0]) > 1]
            # Drop lines that are purely digits
            jp = [x for x in jp if not x[1][0].isdigit()]
            kr = [x for x in kr if not x[1][0].isdigit()]
            ch = [x for x in ch if not x[1][0].isdigit()]
            en = [x for x in en if not x[1][0].isdigit()]
            ru = [x for x in ru if not x[1][0].isdigit()]
            # Drop lines with confidence below 0.8
            jp = [x for x in jp if x[1][1] > 0.8]
            kr = [x for x in kr if x[1][1] > 0.8]
            ch = [x for x in ch if x[1][1] > 0.8]
            en = [x for x in en if x[1][1] > 0.8]
            ru = [x for x in ru if x[1][1] > 0.8]
            print(f'Lines with confidence above 0.8: jp {len(jp)} kr {len(kr)} ch {len(ch)} en {len(en)} ru {len(ru)}')
            # Strip digits and punctuation so they do not count toward the character totals
            strip_table = str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')
            jp_ex = [[x[0], (x[1][0].translate(strip_table), x[1][1])] for x in jp]
            kr_ex = [[x[0], (x[1][0].translate(strip_table), x[1][1])] for x in kr]
            ch_ex = [[x[0], (x[1][0].translate(strip_table), x[1][1])] for x in ch]
            en_ex = [[x[0], (x[1][0].translate(strip_table), x[1][1])] for x in en]
            ru_ex = [[x[0], (x[1][0].translate(strip_table), x[1][1])] for x in ru]
            # Mean confidence per language
            jpx = np.mean([x[1][1] for x in jp_ex]) if len(jp_ex) > 0 else 0
            krx = np.mean([x[1][1] for x in kr_ex]) if len(kr_ex) > 0 else 0
            chx = np.mean([x[1][1] for x in ch_ex]) if len(ch_ex) > 0 else 0
            enx = np.mean([x[1][1] for x in en_ex]) if len(en_ex) > 0 else 0
            rux = np.mean([x[1][1] for x in ru_ex]) if len(ru_ex) > 0 else 0
            # Total character count per language
            jpt = len(''.join([x[1][0] for x in jp_ex]))
            krt = len(''.join([x[1][0] for x in kr_ex]))
            cht = len(''.join([x[1][0] for x in ch_ex]))
            ent = len(''.join([x[1][0] for x in en_ex]))
            rut = len(''.join([x[1][0] for x in ru_ex]))
            # Score = total character count x mean confidence
            jpx = jpx * jpt
            krx = krx * krt
            chx = chx * cht
            enx = enx * ent
            rux = rux * rut
            print('jp', jpx)
            print('kr', krx)
            print('ch', chx)
            print('en', enx)
            print('ru', rux)
            # Map each score (a float) to its language
            confidence_dict = {jpx: 'jp', krx: 'kr', chx: 'ch', enx: 'en', rux: 'ru'}
            # Find the language with the highest score
            max_score = np.max([jpx, krx, chx, enx, rux])
            max_confidence_language = confidence_dict[max_score]
            # Structured storage
            data = []
            # Use the winning language as the key into the per-language results
            results = {'en': en, 'ch': ch, 'jp': jp, 'kr': kr, 'ru': ru}
            # 坐标 = box coordinates, 文本 = (text, confidence)
            for 坐标, 文本 in results[max_confidence_language]:
                print(max_confidence_language, 坐标, 文本)
                data.append({'text': 文本[0], 'confidence': 文本[1], 'coordinate': 坐标})
            # Join into a plain string for the search index
            text = ' '.join([x['text'] for x in data])
            # Serialize to JSON for the database
            data = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
            print(id, text)
            save_text(conn, id, data)
            es.index(index='web_images', id=id, body={'content': text})
            conn.commit()
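# Note: es.index() with an explicit id replaces any existing document, and save_text()
# overwrites the `text` column, so re-running the script simply re-processes the rows
# selected by the fixed LIMIT 0,10 query above.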
def main():
    es = Elasticsearch(config['ELASTICSEARCH_HOST'],
                       basic_auth=(config['ELASTICSEARCH_USERNAME'], config['ELASTICSEARCH_PASSWORD']),
                       verify_certs=False)
    if not es.indices.exists(index='web_images'):
        es.indices.create(index='web_images')
    conn = connect_to_mysql()
    process_images(conn, es)

if __name__ == "__main__":
    main()
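# A rough sketch of how the indexed text could then be searched (not part of this
# script; the match query below is an assumption about how the index is consumed):
#
#   es.search(index='web_images', query={'match': {'content': 'login'}})
#
# Each hit's _id is the web_images.id whose OCR text matched.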