154 lines
7.1 KiB
Python
Executable File
154 lines
7.1 KiB
Python
Executable File
#!/usr/bin/env python3.10
|
|
|
|
import os
|
|
import io
|
|
import oss2
|
|
import time
|
|
import json
|
|
import base64
|
|
import dotenv
|
|
import pymysql
|
|
import requests
|
|
import numpy as np
|
|
import warnings
|
|
import logging
|
|
|
|
from PIL import Image, ImageFile
|
|
from paddleocr import PaddleOCR
|
|
|
|
# Silence logging: logging.disable(WARNING) suppresses every record at level
# WARNING and below (WARNING, INFO and DEBUG), so no separate DEBUG call is needed.
logging.disable(logging.WARNING)

# Ignore all Python warnings.
warnings.filterwarnings("ignore")

# Let Pillow decode truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Connection-pool size used by the oss2 client.
oss2.defaults.connection_pool_size = 100

# Settings are read from the ".env" file in the current working directory.
config = dotenv.dotenv_values(".env")

user = config['ZINCSEARCH_USER']
password = config['ZINCSEARCH_PASSWORD']
zinc_host = config['ZINCSEARCH_HOST']
index = config['ZINCSEARCH_INDEX']

# HTTP Basic credentials and request headers for the ZincSearch API.
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}

# Document endpoint of the target index. A trailing slash on the configured host
# is stripped so the URL never contains "//api".
zinc_url = f"{zinc_host.rstrip('/')}/api/{index}/_doc"
|
|
|
|
|
|
class MyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy values.

    - NumPy integer scalars are written as ``int``.
    - NumPy floating scalars are written as ``float``. They are deliberately not
      cast to ``int``: OCR confidences lie in [0, 1], and truncation would turn
      them into 0.
    - NumPy arrays are written as nested lists of ``int``, truncating each element
      (presumably text-box coordinates in whole pixels — confirm against callers).

    Anything else falls back to ``json.JSONEncoder.default``, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.astype(int).tolist()
        return super().default(obj)
|
|
|
|
|
|
def download_image(url: str, max_size=32767, timeout=30) -> Image.Image | None:
    """Load an image from a URL and return it in RGB mode, or None if it is unusable.

    For URLs under ``image.gameuiux.cn``, the host prefix is stripped to obtain an
    object key. The image is read from that local path when it exists; otherwise it
    is downloaded from the OSS bucket described by ``config``. Any other URL is
    fetched over HTTP with ``requests``.

    Args:
        url: Image URL. GIFs (by extension, case-insensitive) are skipped.
        max_size: Maximum allowed width or height in pixels; larger images are skipped.
        timeout: Seconds passed to ``requests.get`` as its timeout for HTTP downloads.

    Returns:
        The decoded image converted to RGB mode. Returns None for GIFs and oversize
        images, and when fetching or decoding fails; the reason is printed.
    """
    if url.lower().endswith('.gif'):
        print(f'跳过GIF {url}')
        return None

    try:
        if url.startswith(('http://image.gameuiux.cn/', 'https://image.gameuiux.cn/')):
            # Only strip the host as a prefix, never elsewhere in the key.
            url = url.removeprefix('http://image.gameuiux.cn/').removeprefix('https://image.gameuiux.cn/')
            if os.path.exists(url):
                print(f'从本地读取图片 {url}')
                img = Image.open(url)
            else:
                print(f'从OSS下载图片 {url}')
                oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
                bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
                img = Image.open(io.BytesIO(bucket.get_object(url).read()))
        else:
            # A timeout prevents a stalled server from blocking the worker forever;
            # raise_for_status reports HTTP errors instead of decoding an error page.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            img = Image.open(io.BytesIO(response.content))

        # Image.open reads only the header, so check the size before the costly
        # full decode that convert() triggers.
        if img.size[0] > max_size or img.size[1] > max_size:
            print(f'跳过尺寸过大的图像 {url}')
            return None

        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img
    except Exception as e:
        # Best effort: any failure skips this image instead of stopping the batch.
        print(f'图片从{url}下载失败,错误信息为:{e}')
        return None
|
|
|
|
|
|
def connect_to_mysql():
    """Open a PyMySQL connection using the MYSQL_* settings from ``config``.

    The connection's default cursor class is ``SSDictCursor`` (unbuffered, rows
    as dicts); callers may request another cursor class via ``conn.cursor(cls)``.
    ``charset='utf8mb4'`` is set explicitly because the OCR results written back
    may contain Chinese, Japanese, Korean and Cyrillic text, and the driver's
    default charset differs between PyMySQL versions.
    """
    return pymysql.connect(
        host=config['MYSQL_HOST'],
        user=config['MYSQL_USER'],
        password=config['MYSQL_PASSWORD'],
        database=config['MYSQL_NAME'],
        charset='utf8mb4',
        cursorclass=pymysql.cursors.SSDictCursor,
    )
|
|
|
|
|
|
|
|
# One PaddleOCR model per language: English, Chinese, Japanese, Korean, Russian.
# Every model is built with the text-angle classifier enabled, in the order listed.
EN, CH, JP, KR, RU = (
    PaddleOCR(use_angle_cls=True, lang=lang_code)
    for lang_code in ("en", "ch", "japan", "korean", "ru")
)
|
|
|
|
# Characters removed from a recognized line before its length counts toward the
# language score (digits and a set of ASCII/CJK punctuation marks).
_SCORE_STRIP_TABLE = str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')

# A recognized line is kept only if its confidence is strictly above this value.
_MIN_CONFIDENCE = 0.8

# Order in which languages are scored. max() keeps the first of several equal
# scores, so this order decides ties.
_SCORE_ORDER = ('jp', 'kr', 'ch', 'en', 'ru')


def _filter_lines(lines):
    """Keep OCR lines whose text is longer than 1 char, not all digits, and confident enough."""
    return [
        line for line in lines
        if len(line[1][0]) > 1 and not line[1][0].isdigit() and line[1][1] > _MIN_CONFIDENCE
    ]


def _language_score(lines):
    """Return mean confidence times the number of characters left after stripping digits/punctuation (0 if empty)."""
    if not lines:
        return 0
    mean_confidence = np.mean([line[1][1] for line in lines])
    char_count = len(''.join(line[1][0].translate(_SCORE_STRIP_TABLE) for line in lines))
    return mean_confidence * char_count


def _index_document(obj, timeout=30):
    """PUT ``obj`` to the ZincSearch document endpoint; return True iff the reply has message == 'ok'.

    Network errors, including a timeout, propagate to the caller. A reply body that
    is not a JSON object counts as a failure instead of raising.
    """
    res = requests.put(zinc_url, headers=headers, data=json.dumps(obj),
                       proxies={'http': '', 'https': ''}, timeout=timeout)
    try:
        body = res.json()
    except ValueError:
        return False
    return isinstance(body, dict) and body.get('message') == 'ok'


def process_images(conn, offset=0, batch_size=100) -> int:
    """OCR one batch of unprocessed images, index the text in ZincSearch and save it to MySQL.

    The batch is up to ``batch_size`` rows of ``web_images`` with an empty ``text``
    column and ``article_category_top_id=22``, ordered by id and starting at
    ``offset``. Each image is run through every OCR model. The language whose
    filtered lines score highest wins; its lines become a space-joined document in
    ZincSearch and a JSON array in ``web_images.text``.

    Args:
        conn: Open PyMySQL connection.
        offset: Number of pending rows to skip at the start of the ordered result set.
        batch_size: Maximum number of rows fetched for this batch.

    Returns:
        The offset for the next batch. Processed rows get a non-empty ``text`` and
        leave the ``text=''`` result set. The offset therefore advances only past
        rows that were skipped because their image could not be loaded.
    """
    ocr_models = {'en': EN, 'ch': CH, 'jp': JP, 'kr': KR, 'ru': RU}

    # ORDER BY makes LIMIT/OFFSET deterministic, which the offset arithmetic relies on.
    with conn.cursor(pymysql.cursors.SSCursor) as select_cursor:
        select_cursor.execute(
            "SELECT id, content FROM web_images WHERE text='' AND article_category_top_id=22 "
            "ORDER BY id LIMIT %s OFFSET %s",
            (batch_size, offset),
        )
        rows = select_cursor.fetchall()

    skipped = 0
    for image_id, content in rows:
        image = download_image(content)
        if image is None:
            # The row keeps text='' and stays in the result set, so step over it next time.
            skipped += 1
            continue
        if isinstance(image, Image.Image):
            image = np.array(image)
        print('---------------------', image_id, content)

        # Run every language model and keep only plausible lines. PaddleOCR yields
        # [None] when it finds no text, hence the "or []".
        results = {
            lang: _filter_lines(model.ocr(image, cls=True)[0] or [])
            for lang, model in ocr_models.items()
        }
        counts = ' '.join(f'{lang} {len(results[lang])}' for lang in _SCORE_ORDER)
        print(f'置信度大于 0.8 的行: {counts}')

        # Pick the best-scoring language and keep its lines as structured records.
        scores = {lang: _language_score(results[lang]) for lang in _SCORE_ORDER}
        best_lang = max(scores, key=scores.get)
        data = [
            {'text': text, 'confidence': confidence, 'coordinate': coord}
            for coord, (text, confidence) in results[best_lang]
        ]
        print("data:", data)

        # Index the joined text in ZincSearch.
        obj = {"_id": str(image_id), "text": ' '.join(item['text'] for item in data)}
        print("转换为字符串存储到索引库:", obj)
        indexed = _index_document(obj)
        print("\033[1;32m{}\033[0m".format(image_id) if indexed else obj["_id"], obj["text"])

        # Store the structured result as JSON in MySQL.
        with conn.cursor() as update_cursor:
            update_cursor.execute(
                "UPDATE web_images SET text = %s WHERE id = %s",
                (json.dumps(data, ensure_ascii=False, cls=MyEncoder), image_id),
            )
        conn.commit()

    return offset + skipped
|
|
|
|
def main(offset=1500):
    """Process image batches in an endless loop.

    Args:
        offset: Offset passed to ``process_images`` for the first batch. Each call
            returns the offset for the next one.

    The MySQL connection is closed when the loop ends, whether by an uncaught
    error or by KeyboardInterrupt.
    """
    conn = connect_to_mysql()
    try:
        while True:
            offset = process_images(conn, offset)
            # sleep(0) only yields the CPU; the database is polled continuously.
            time.sleep(0)
    finally:
        conn.close()


if __name__ == "__main__":
    main()
|