diff --git a/pp.py b/pp.py index 961cbbe..d52c46a 100755 --- a/pp.py +++ b/pp.py @@ -20,6 +20,7 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True config = dotenv_values(".env") oss2.defaults.connection_pool_size = 100 + class MyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.float32): @@ -28,11 +29,14 @@ class MyEncoder(json.JSONEncoder): return obj.astype(int).tolist() return super(MyEncoder, self).default(obj) -def download_image(url:str) -> Image.Image: + +def download_image(url: str) -> Image.Image: try: if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'): - url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '') - oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET']) + url = url.replace('http://image.gameuiux.cn/', + '').replace('https://image.gameuiux.cn/', '') + oss_auth = oss2.Auth( + config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET']) return Image.open(io.BytesIO(oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME']).get_object(url).read())) else: response = requests.get(url) @@ -41,21 +45,27 @@ def download_image(url:str) -> Image.Image: print(f'图片从{url}下载失败,错误信息为:{e}') return None + def connect_to_mysql(): return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor) -def save_text(conn, id:int, text:str): - with conn.cursor() as cursor: - cursor.execute("UPDATE web_images SET text = %s WHERE id = %s", (text, id)) +def save_text(conn, id: int, text: str): + with conn.cursor() as cursor: + cursor.execute( + "UPDATE web_images SET text = %s WHERE id = %s", (text, id)) + + +# 中英日韩俄 EN = PaddleOCR(use_angle_cls=True, lang="en") CH = PaddleOCR(use_angle_cls=True, lang="ch") JP = PaddleOCR(use_angle_cls=True, lang="japan") KR = PaddleOCR(use_angle_cls=True, lang="korean") +RU = PaddleOCR(use_angle_cls=True, lang="ru") def process_images(conn, es): with conn.cursor(pymysql.cursors.SSCursor) as cursor: - cursor.execute("SELECT id, content FROM web_images WHERE text!='' LIMIT 10") + cursor.execute("SELECT id, content FROM web_images LIMIT 0,10") # WHERE text!='' for id, content in cursor.fetchall(): image = download_image(content) if image is None: @@ -63,53 +73,106 @@ def process_images(conn, es): if isinstance(image, Image.Image): image = np.array(image) print('---------------------', id, content) - for line in EN.ocr(image, cls=True)[0]: - print('EN', line) - for line in CH.ocr(image, cls=True)[0]: - print('CH', line) - for line in JP.ocr(image, cls=True)[0]: - print('JP', line) - for line in KR.ocr(image, cls=True)[0]: - print('KR', line) - #print(EN.ocr(image, cls=True)) - #print(CH.ocr(image, cls=True)) - #print(JP.ocr(image, cls=True)) - #print(KR.ocr(image, cls=True)) - # item = [x for x in ocr.ocr(image) if x['text'] and not x['text'].isdigit() and len(x['text']) > 1] - # text = ' '.join([x['text'] for x in item]) - # print(id, text) - # save_text(conn, id, json.dumps(item, ensure_ascii=False, cls=MyEncoder)) - # es.index(index='web_images', id=id, body={'content': text}) - #conn.commit() + en = EN.ocr(image, cls=True)[0] + ch = CH.ocr(image, cls=True)[0] + jp = JP.ocr(image, cls=True)[0] + kr = KR.ocr(image, cls=True)[0] + ru = RU.ocr(image, cls=True)[0] + + # 排除字符长度小于2的行 + jp = [x for x in jp if len(x[1][0]) > 1] + kr = [x for x in kr if len(x[1][0]) > 1] + ch = [x for x in ch if len(x[1][0]) > 1] + en = [x for x in en if len(x[1][0]) > 1] + ru = [x for x in ru if len(x[1][0]) > 1] + + # 排除纯数字的行 + jp = [x for x in jp if not x[1][0].isdigit()] + kr = [x for x in kr if not x[1][0].isdigit()] + ch = [x for x in ch if not x[1][0].isdigit()] + en = [x for x in en if not x[1][0].isdigit()] + ru = [x for x in ru if not x[1][0].isdigit()] + + # 排除置信度小于 0.8 的行 + jp = [x for x in jp if x[1][1] > 0.8] + kr = [x for x in kr if x[1][1] > 0.8] + ch = [x for x in ch if x[1][1] > 0.8] + en = [x for x in en if x[1][1] > 0.8] + ru = [x for x in ru if x[1][1] > 0.8] + print(f'置信度大于 0.8 的行: jp {len(jp)} kr {len(kr)} ch {len(ch)} en {len(en)} ru {len(ru)}') + + + # 去除字符串中包含的数字和标点(不作计数) + jp_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in jp] + kr_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in kr] + ch_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ch] + en_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in en] + ru_ex = [[x[0], (x[1][0].translate(str.maketrans('', '', '0123456789.,,。!?:;“”‘’\'\"')), x[1][1])] for x in ru] + + # 计算置信度平均数 + jpx = np.mean([x[1][1] for x in jp_ex]) if len(jp_ex) > 0 else 0 + krx = np.mean([x[1][1] for x in kr_ex]) if len(kr_ex) > 0 else 0 + chx = np.mean([x[1][1] for x in ch_ex]) if len(ch_ex) > 0 else 0 + enx = np.mean([x[1][1] for x in en_ex]) if len(en_ex) > 0 else 0 + rux = np.mean([x[1][1] for x in ru_ex]) if len(ru_ex) > 0 else 0 + + # 计算总字数 + jpt = len(''.join([x[1][0] for x in jp_ex])) + krt = len(''.join([x[1][0] for x in kr_ex])) + cht = len(''.join([x[1][0] for x in ch_ex])) + ent = len(''.join([x[1][0] for x in en_ex])) + rut = len(''.join([x[1][0] for x in ru_ex])) + + # 计算总字数 x 置信度平均数 + jpx = jpx * jpt + krx = krx * krt + chx = chx * cht + enx = enx * ent + rux = rux * rut + + print('jp', jpx) + print('kr', krx) + print('ch', chx) + print('en', enx) + print('ru', rux) + + # 创建一个新的字典,其中键是浮点数(置信度),值是语言 + confidence_dict = {jpx: 'jp', krx: 'kr', chx: 'ch', enx: 'en', rux: 'ru'} + + # 找出置信度最高的语言 + max = np.max([jpx, krx, chx, enx, rux]) + max_confidence_language = confidence_dict[max] + + # 结构化存储 + data = [] + + # 使用置信度最高的语言作为键来访问字典 + all = {'en': en, 'ch': ch, 'jp': jp, 'kr': kr, 'ru': ru} + for 坐标, 文本 in all[max_confidence_language]: + print(max_confidence_language, 坐标, 文本) + data.append({'text': 文本[0], 'confidence': 文本[1], 'coordinate': 坐标 }) + + # 转换为字符串存储到索引库 + text = ' '.join([x['text'] for x in data]) + + # 转换为 JSON 存储到数据库 + data = json.dumps(data, ensure_ascii=False, cls=MyEncoder) + + print(id, text) + + save_text(conn, id, data) + es.index(index='web_images', id=id, body={'content': text}) + conn.commit() + def main(): - es = Elasticsearch(config['ELASTICSEARCH_HOST'], basic_auth=(config['ELASTICSEARCH_USERNAME'], config['ELASTICSEARCH_PASSWORD']), verify_certs=False) + es = Elasticsearch(config['ELASTICSEARCH_HOST'], basic_auth=( + config['ELASTICSEARCH_USERNAME'], config['ELASTICSEARCH_PASSWORD']), verify_certs=False) if not es.indices.exists(index='web_images'): es.indices.create(index='web_images') conn = connect_to_mysql() process_images(conn, es) + if __name__ == "__main__": main() - - -# Paddleocr目前支持的多语言语种可以通过修改lang参数进行切换 -# 例如`ch`, `en`, `fr`, `german`, `korean`, `japan` -#ocr = PaddleOCR(use_angle_cls=True, lang="ch") -#img_path = './imgs/14.jpg' -#result = ocr.ocr(img_path, cls=True) -#for idx in range(len(result)): -# res = result[idx] -# for line in res: -# print(line) - -# 显示结果 -#from PIL import Image -#result = result[0] -#image = Image.open(img_path).convert('RGB') -#boxes = [line[0] for line in result] -#txts = [line[1][0] for line in result] -#scores = [line[1][1] for line in result] -#im_show = draw_ocr(image, boxes, txts, scores, font_path='./fonts/simfang.ttf') -#im_show = Image.fromarray(im_show) -#im_show.save('result.jpg') diff --git a/requirements.txt b/requirements.txt index 3b90f65..3e4b7d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +whell==0.42.0 cnocr==2.2.4.2 elasticsearch==8.11.0 numpy==1.26.2