#!/usr/bin/env python3.10
"""Batch OCR of web images with PaddleOCR.

Reads rows from the MySQL ``web_images`` table whose ``text`` column is still
empty (category 22), runs five PaddleOCR language models (en, ch, japan,
korean, ru) over each image and keeps the lines of the best-scoring language.
The recognised text is indexed in ZincSearch, and the structured OCR result
is stored back into ``web_images.text`` as JSON.

Configuration (ZincSearch, OSS and MySQL credentials) is read from a ``.env``
file in the working directory.
"""
import base64
import functools
import gc
import io
import json
import logging
import os
import time
import warnings

import dotenv
import numpy as np
import oss2
import paddle
import pymysql
import requests
from PIL import Image, ImageFile
from paddleocr import PaddleOCR

paddle.set_flags({'FLAGS_fraction_of_gpu_memory_to_use': 0.4})  # Cap GPU memory usage at 40% of the device
# logging.disable(level) disables that level and everything below it,
# so this silences DEBUG, INFO and WARNING records.
logging.disable(logging.WARNING)
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
oss2.defaults.connection_pool_size = 100

config = dotenv.dotenv_values(".env")
user = config['ZINCSEARCH_USER']
password = config['ZINCSEARCH_PASSWORD']
zinc_host = config['ZINCSEARCH_HOST']
index = config['ZINCSEARCH_INDEX']
bas64encoded_creds = base64.b64encode(bytes(f"{user}:{password}", "utf-8")).decode("utf-8")
headers = {"Content-type": "application/json", "Authorization": f"Basic {bas64encoded_creds}"}
zinc_url = f"{zinc_host}/api/{index}/_doc"

# URL prefixes that map to objects in our own OSS bucket (or a local mirror).
CDN_PREFIXES = ('http://image.gameuiux.cn/', 'https://image.gameuiux.cn/')
# Seconds to wait on any single HTTP request before giving up.
REQUEST_TIMEOUT = 30
# Rows fetched from the database per page.
BATCH_SIZE = 100
# OCR lines at or below this confidence are discarded.
MIN_CONFIDENCE = 0.8
# Digits and punctuation are removed before counting characters for language scoring.
_STRIP_DIGITS_AND_PUNCT = str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')


class MyEncoder(json.JSONEncoder):
    """JSON encoder that understands the numpy types PaddleOCR may emit."""

    def default(self, obj):
        # Numpy scalars: keep floats as floats. Converting them with int()
        # would truncate confidences such as 0.93 to 0.
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        # Arrays (e.g. box coordinates) are stored as integer pixel lists.
        if isinstance(obj, np.ndarray):
            return obj.astype(int).tolist()
        return super(MyEncoder, self).default(obj)


@functools.cache
def _oss_bucket() -> oss2.Bucket:
    """Build the OSS bucket client once so it is reused across downloads."""
    oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
    return oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME'])


def download_image(url: str, max_size=32767) -> Image.Image | None:
    """Load an image from local disk, OSS or HTTP and return it decoded as RGB.

    URLs under ``image.gameuiux.cn`` have the host stripped. The remaining
    path is read from the working directory when that file exists; otherwise
    it is used as the OSS object key. Any other URL is fetched over HTTP.

    Args:
        url: Image URL as stored in ``web_images.content``.
        max_size: Images whose width or height exceeds this many pixels are skipped.

    Returns:
        The decoded RGB image, or ``None`` if the URL is empty, is a GIF, the
        image is too large, or downloading/decoding fails (the error is printed).
    """
    if not url:
        return None
    if url.lower().endswith('.gif'):
        print(f'跳过GIF {url}')
        return None
    try:
        if url.startswith(CDN_PREFIXES):
            for prefix in CDN_PREFIXES:
                url = url.removeprefix(prefix)
            if os.path.exists(url):
                print(f'从本地读取图片 {url}')
                img = Image.open(url)
            else:
                print(f'从OSS下载图片 {url}')
                img = Image.open(io.BytesIO(_oss_bucket().get_object(url).read()))
        else:
            print(f'从网络下载图片 {url}')
            response = requests.get(url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()
            img = Image.open(io.BytesIO(response.content))
        # Image.open only parses the header, so the size is known before the
        # pixel data is decoded. Reject oversized images without decoding them.
        if max(img.size) > max_size:
            print(f'跳过尺寸过大的图像 {url}')
            return None
        # Decode now so corrupt data is reported here instead of crashing the
        # caller later inside np.array().
        img.load()
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img
    except Exception as e:
        print(f'图片从{url}下载失败,错误信息为:{e}')
        return None


def connect_to_mysql():
    """Open a MySQL connection whose default cursor is an unbuffered dict cursor."""
    return pymysql.connect(host=config['MYSQL_HOST'],
                           user=config['MYSQL_USER'],
                           password=config['MYSQL_PASSWORD'],
                           database=config['MYSQL_NAME'],
                           cursorclass=pymysql.cursors.SSDictCursor)


# Chinese, English, Japanese, Korean and Russian models, loaded once at import time.
EN = PaddleOCR(use_angle_cls=True, lang="en")
CH = PaddleOCR(use_angle_cls=True, lang="ch")
JP = PaddleOCR(use_angle_cls=True, lang="japan")
KR = PaddleOCR(use_angle_cls=True, lang="korean")
RU = PaddleOCR(use_angle_cls=True, lang="ru")


def process_ocr(model, image):
    """Run one OCR model on ``image``, then release cached GPU memory.

    Returns the lines of the first (only) page, or ``[]`` when nothing was found.
    """
    result = model.ocr(image, cls=True)[0] or []
    paddle.device.cuda.empty_cache()  # release cached GPU memory
    gc.collect()  # force a garbage-collection pass
    return result


def _filter_lines(lines):
    """Keep OCR lines longer than one character, not purely digits, and above MIN_CONFIDENCE.

    Each line is ``[coordinates, (text, confidence)]`` as returned by PaddleOCR.
    """
    return [x for x in lines
            if len(x[1][0]) > 1 and not x[1][0].isdigit() and x[1][1] > MIN_CONFIDENCE]


def _language_score(lines) -> float:
    """Score = mean confidence x total character count, ignoring digits and punctuation."""
    if not lines:
        return 0
    mean_confidence = np.mean([x[1][1] for x in lines])
    char_count = len(''.join(x[1][0].translate(_STRIP_DIGITS_AND_PUNCT) for x in lines))
    return mean_confidence * char_count


def _recognise(image):
    """Run every language model on ``image`` and return the best language's lines.

    Returns a list of ``{'text', 'confidence', 'coordinate'}`` dicts.
    """
    # Models run in this order (same order as before the refactor).
    models = {'ru': RU, 'en': EN, 'ch': CH, 'jp': JP, 'kr': KR}
    filtered = {lang: _filter_lines(process_ocr(model, image)) for lang, model in models.items()}
    print(f'置信度大于 0.8 的行: jp {len(filtered["jp"])} kr {len(filtered["kr"])} '
          f'ch {len(filtered["ch"])} en {len(filtered["en"])} ru {len(filtered["ru"])}')
    # Insertion order matters: max() keeps the first maximum, so ties
    # (including "no text at all") resolve in jp, kr, ch, en, ru order.
    scores = {lang: _language_score(filtered[lang]) for lang in ('jp', 'kr', 'ch', 'en', 'ru')}
    best_language = max(scores, key=scores.get)
    return [{'text': text[0], 'confidence': text[1], 'coordinate': coord}
            for coord, text in filtered[best_language]]


def _index_text(image_id, data):
    """Upsert the concatenated OCR text for ``image_id`` into ZincSearch and print the outcome."""
    obj = {
        "_id": str(image_id),
        "text": ' '.join([x['text'] for x in data])
    }
    print("转换为字符串存储到索引库:", obj)
    res = requests.put(zinc_url, headers=headers, data=json.dumps(obj),
                       proxies={'http': '', 'https': ''}, timeout=REQUEST_TIMEOUT)
    try:
        body = res.json()
    except ValueError:  # error pages are not necessarily JSON
        body = {}
    ok = isinstance(body, dict) and body.get('message') == 'ok'
    # The id is printed in green on success and plain on failure.
    print("\033[1;32m{}\033[0m".format(image_id) if ok else obj["_id"], obj["text"])


def _save_result(conn, image_id, data):
    """Store the structured OCR result as JSON in ``web_images.text`` and commit."""
    with conn.cursor() as c:
        payload = json.dumps(data, ensure_ascii=False, cls=MyEncoder)
        c.execute("UPDATE web_images SET text = %s WHERE id = %s", (payload, image_id))
    conn.commit()


def _process_page(conn, offset):
    """Process one page of pending images.

    Returns:
        ``(rows_fetched, next_offset)``. ``rows_fetched == 0`` means no pending
        rows are left at or after ``offset``.
    """
    with conn.cursor(pymysql.cursors.SSCursor) as cursor:
        # ORDER BY makes paging deterministic. Without it the row order behind
        # LIMIT/OFFSET is unspecified.
        cursor.execute("SELECT id, content FROM web_images WHERE text='' AND article_category_top_id=22 "
                       "ORDER BY id LIMIT %s OFFSET %s", (BATCH_SIZE, offset))
        rows = cursor.fetchall()

    skipped = 0
    for image_id, content in rows:
        image = download_image(content)
        if image is None:
            skipped += 1
            continue
        image = np.array(image)
        print(image_id, content)

        data = _recognise(image)
        _index_text(image_id, data)
        _save_result(conn, image_id, data)

    paddle.device.cuda.empty_cache()  # release cached GPU memory
    gc.collect()  # force a garbage-collection pass
    # A processed row gets non-empty text and drops out of the ``text=''``
    # result set. Only skipped rows stay in front of the next page, so the
    # offset advances by the number of skipped rows, not by the page size.
    return len(rows), offset + skipped


def process_images(conn, offset=0) -> int:
    """Process up to BATCH_SIZE pending images, skipping the first ``offset`` pending rows.

    Returns:
        The offset to pass to the next call.
    """
    return _process_page(conn, offset)[1]


def main():
    """Process pending images page by page until none remain, then close the connection."""
    conn = connect_to_mysql()
    offset = 2000  # the first 2000 pending rows (by id) are intentionally skipped
    try:
        while True:
            print("LOOP:", offset)
            fetched, offset = _process_page(conn, offset)
            if fetched == 0:
                break
    finally:
        conn.close()


if __name__ == "__main__":
    main()