#!/usr/bin/env python3.10
"""OCR images referenced in MySQL and index the recognized text in ZincSearch.

For every row of ``web_images`` whose ``text`` column is still empty, the
image is fetched (local file, OSS bucket or plain HTTP), run through five
PaddleOCR language models, and the output of the best-scoring language is
written to the ZincSearch index (plain text) and back to MySQL (JSON).
"""
import base64
import io
import json
import logging
import os
import time
import warnings

import dotenv
import numpy as np
import oss2
import pymysql
import requests
from paddleocr import PaddleOCR
from PIL import Image, ImageFile

# Silence log records of level WARNING and below (this also covers DEBUG).
logging.disable(logging.WARNING)
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
oss2.defaults.connection_pool_size = 100

config = dotenv.dotenv_values(".env")
user = config['ZINCSEARCH_USER']
password = config['ZINCSEARCH_PASSWORD']
zinc_host = config['ZINCSEARCH_HOST']
index = config['ZINCSEARCH_INDEX']
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
zinc_url = f"{zinc_host}/api/{index}/_doc"

# URL prefixes that are stripped to obtain a local path / OSS object key.
IMAGE_URL_PREFIXES = ('http://image.gameuiux.cn/', 'https://image.gameuiux.cn/')
# Seconds to wait on HTTP calls so a stalled server cannot hang the run.
REQUEST_TIMEOUT = 60
# Number of rows fetched from MySQL per batch.
BATCH_SIZE = 100
# Offset at which main() starts paging.
START_OFFSET = 1500
# OCR lines at or below this confidence are discarded.
MIN_CONFIDENCE = 0.8
# Deletes digits and common punctuation so they do not count towards the
# per-language score. Built once instead of once per OCR line.
_STRIP_TABLE = str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')


class MyEncoder(json.JSONEncoder):
    """JSON encoder that accepts the NumPy values found in OCR results."""

    def default(self, obj):
        """Convert NumPy scalars/arrays to plain Python values."""
        # NOTE(review): np.float32 values are truncated to int; if a
        # confidence score ever arrives as np.float32 it is stored as 0.
        # Kept as-is because the intended storage format is not visible here.
        if isinstance(obj, np.float32):
            return int(obj)
        if isinstance(obj, np.ndarray):
            return obj.astype(int).tolist()
        return super().default(obj)


def download_image(url: str, max_size=32767) -> Image.Image | None:
    """Load the image behind *url* as an RGB PIL image.

    URLs under ``image.gameuiux.cn`` are stripped of their prefix and read
    from the local file system when such a path exists, otherwise from the
    configured OSS bucket. Any other URL is fetched over HTTP.

    Returns ``None`` (after printing the reason) for GIFs, for images with a
    side longer than *max_size* pixels, and when loading fails for any reason.
    """
    if url.endswith(('.gif', '.GIF')):
        print(f'跳过GIF {url}')
        return None
    try:
        if url.startswith(IMAGE_URL_PREFIXES):
            for prefix in IMAGE_URL_PREFIXES:
                url = url.replace(prefix, '')
            if os.path.exists(url):
                print(f'从本地读取图片 {url}')
                img = Image.open(url)
            else:
                print(f'从OSS下载图片 {url}')
                oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
                bucket = oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])
                img = Image.open(io.BytesIO(bucket.get_object(url).read()))
        else:
            response = requests.get(url, timeout=REQUEST_TIMEOUT)
            img = Image.open(io.BytesIO(response.content))
        # Check the size before converting so oversized images are rejected
        # without first allocating an RGB copy.
        if img.size[0] > max_size or img.size[1] > max_size:
            print(f'跳过尺寸过大的图像 {url}')
            return None
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img
    except Exception as e:
        # Best effort: one unreadable image must not abort the whole batch.
        print(f'图片从{url}下载失败,错误信息为:{e}')
        return None


def connect_to_mysql():
    """Open a MySQL connection using the credentials from ``.env``."""
    return pymysql.connect(
        host=config['MYSQL_HOST'],
        user=config['MYSQL_USER'],
        password=config['MYSQL_PASSWORD'],
        database=config['MYSQL_NAME'],
        cursorclass=pymysql.cursors.SSDictCursor,
    )


# Chinese, English, Japanese, Korean, Russian
EN = PaddleOCR(use_angle_cls=True, lang="en")
CH = PaddleOCR(use_angle_cls=True, lang="ch")
JP = PaddleOCR(use_angle_cls=True, lang="japan")
KR = PaddleOCR(use_angle_cls=True, lang="korean")
RU = PaddleOCR(use_angle_cls=True, lang="ru")

# Insertion order is the tie-break order when two languages score equally.
_OCR_ENGINES = {'jp': JP, 'kr': KR, 'ch': CH, 'en': EN, 'ru': RU}


def _keep_line(line) -> bool:
    """Return True for OCR lines that are long, non-numeric and confident."""
    text, confidence = line[1][0], line[1][1]
    return len(text) > 1 and not text.isdigit() and confidence > MIN_CONFIDENCE


def _language_score(lines) -> float:
    """Score a language: mean confidence times the count of remaining characters.

    Digits and punctuation are removed before counting characters.
    """
    if not lines:
        return 0
    mean_confidence = np.mean([line[1][1] for line in lines])
    char_count = len(''.join(line[1][0].translate(_STRIP_TABLE) for line in lines))
    return mean_confidence * char_count


def _recognize(image) -> list:
    """Run every OCR engine on *image* and return the best language's lines.

    Each returned item is a dict with ``text``, ``confidence`` and
    ``coordinate`` keys.
    """
    kept = {}
    for lang, engine in _OCR_ENGINES.items():
        lines = engine.ocr(image, cls=True)[0] or []
        kept[lang] = [line for line in lines if _keep_line(line)]
    counts = ' '.join(f'{lang} {len(lines)}' for lang, lines in kept.items())
    print(f'置信度大于 0.8 的行: {counts}')

    scores = {lang: _language_score(lines) for lang, lines in kept.items()}
    best_language = max(scores, key=scores.get)
    return [
        {'text': text[0], 'confidence': text[1], 'coordinate': coord}
        for coord, text in kept[best_language]
    ]


def _index_succeeded(response) -> bool:
    """Return True when the ZincSearch response body reports ``message: ok``."""
    try:
        payload = json.loads(response.text)
    except ValueError:
        # Non-JSON body (e.g. a proxy error page) counts as a failure.
        return False
    return isinstance(payload, dict) and payload.get('message') == 'ok'


def _process_batch(conn, offset) -> tuple[int, int]:
    """Process one batch of pending rows starting at *offset*.

    Returns ``(fetched, skipped)``: the number of rows the query returned and
    how many of them were left unprocessed because the image could not be
    loaded.
    """
    with conn.cursor(pymysql.cursors.SSCursor) as cursor:
        # ORDER BY keeps paging deterministic between batches.
        cursor.execute(
            "SELECT id, content FROM web_images WHERE text='' AND article_category_top_id=22 "
            "ORDER BY id LIMIT %s OFFSET %s",
            (BATCH_SIZE, offset),
        )
        rows = cursor.fetchall()

    skipped = 0
    for image_id, content in rows:
        image = download_image(content)
        if image is None:
            skipped += 1
            continue
        image = np.array(image)
        print('---------------------', image_id, content)

        # Run text extraction and keep the best language.
        data = _recognize(image)
        print("data:", data)

        # Store the plain text in the search index.
        obj = {
            "_id": str(image_id),
            "text": ' '.join([x['text'] for x in data])
        }
        print("转换为字符串存储到索引库:", obj)
        res = requests.put(zinc_url, headers=headers, data=json.dumps(obj),
                           proxies={'http': '', 'https': ''}, timeout=REQUEST_TIMEOUT)
        # Green id on success, plain id otherwise.
        label = "\033[1;32m{}\033[0m".format(image_id) if _index_succeeded(res) else obj["_id"]
        print(label, obj["text"])

        # Store the structured result as JSON in the database.
        with conn.cursor() as update_cursor:
            payload = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
            update_cursor.execute("UPDATE web_images SET text = %s WHERE id = %s", (payload, image_id))
        conn.commit()
    return len(rows), skipped


def process_images(conn, offset=0) -> int:
    """Process one batch of pending images and return the next offset to use.

    Rows that were updated no longer match ``text=''`` and therefore drop out
    of the result set, so the offset only advances past rows that were left
    unprocessed; advancing by the full batch size would skip pending rows.
    """
    _, skipped = _process_batch(conn, offset)
    return offset + skipped


def main():
    """Process pending images batch by batch until the query returns no rows."""
    conn = connect_to_mysql()
    offset = START_OFFSET
    try:
        while True:
            fetched, skipped = _process_batch(conn, offset)
            if not fetched:
                # Nothing left at or beyond this offset; stop instead of
                # polling the database in a tight loop.
                break
            offset += skipped
            time.sleep(0)
    finally:
        conn.close()


if __name__ == "__main__":
    main()