This commit is contained in:
2023-12-05 03:10:46 +08:00
parent 3786304926
commit f3a5d44c57
4 changed files with 12 additions and 76 deletions

View File

@@ -1,3 +1,9 @@
# ocr # OCR
基于深度学习的文字识别提取标记
- 由于当前没有较优的语言分类识别方案, 使用四倍算力换精度
- 当前支持 英文 中文 日文 韩文 俄文 的识别
- 去除纯数字和单字符以及置信度低于80的文字
- 数据转json存储于mysql web_images 每张图像对应的 text 字段
- 文字以空格分隔合并为字符串加入 Elasticsearch 索引
基于深度学习的文字识别提取标记

71
main.py
View File

@@ -1,71 +0,0 @@
#!/usr/bin/env python3
import io
import requests
import oss2
import pymysql
import cnocr
import json
import numpy as np
import warnings
from PIL import Image, ImageFile
from dotenv import dotenv_values
from elasticsearch import Elasticsearch
warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True
config = dotenv_values(".env")
oss2.defaults.connection_pool_size = 100
class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.float32):
return int(obj)
if isinstance(obj, np.ndarray):
return obj.astype(int).tolist()
return super(MyEncoder, self).default(obj)
def download_image(url:str) -> Image.Image:
try:
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
return Image.open(io.BytesIO(oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME']).get_object(url).read()))
else:
response = requests.get(url)
return Image.open(io.BytesIO(response.content))
except Exception:
print(f'图片从{url}下载失败')
return None
def connect_to_mysql():
return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
def save_text(conn, id:int, text:str):
with conn.cursor() as cursor:
cursor.execute("UPDATE web_images SET text = %s WHERE id = %s", (text, id))
def process_images(conn, ocr, es):
with conn.cursor(pymysql.cursors.SSCursor) as cursor:
cursor.execute("SELECT id, content FROM web_images WHERE text='' LIMIT 10000")
for id, content in cursor.fetchall():
image = download_image(content)
if image is None:
continue
item = [x for x in ocr.ocr(image) if x['text'] and not x['text'].isdigit() and len(x['text']) > 1]
text = ' '.join([x['text'] for x in item])
print(id, text)
#save_text(conn, id, json.dumps(item, ensure_ascii=False, cls=MyEncoder))
#es.index(index='web_images', id=id, body={'content': text})
#conn.commit()
def main():
es = Elasticsearch(config['ELASTICSEARCH_HOST'], basic_auth=(config['ELASTICSEARCH_USERNAME'], config['ELASTICSEARCH_PASSWORD']), verify_certs=False)
if not es.indices.exists(index='web_images'):
es.indices.create(index='web_images')
ocr = cnocr.CnOcr(rec_model_name='ch_PP-OCRv3')
conn = connect_to_mysql()
process_images(conn, ocr, es)
if __name__ == "__main__":
main()

5
pp.py
View File

@@ -30,7 +30,7 @@ class MyEncoder(json.JSONEncoder):
return super(MyEncoder, self).default(obj) return super(MyEncoder, self).default(obj)
def download_image(url: str) -> Image.Image: def download_image(url: str, max_size=32767) -> Image.Image:
if url.endswith('.gif') or url.endswith('.GIF'): if url.endswith('.gif') or url.endswith('.GIF'):
print(f'跳过GIF {url}') print(f'跳过GIF {url}')
return None return None
@@ -44,6 +44,9 @@ def download_image(url: str) -> Image.Image:
img = Image.open(io.BytesIO(response.content)) img = Image.open(io.BytesIO(response.content))
if img.mode != 'RGB': if img.mode != 'RGB':
img = img.convert('RGB') img = img.convert('RGB')
if img.size[0] > max_size or img.size[1] > max_size:
print(f'跳过尺寸过大的图像 {url}')
return None
return img return img
except Exception as e: except Exception as e:
print(f'图片从{url}下载失败,错误信息为:{e}') print(f'图片从{url}下载失败,错误信息为:{e}')

View File

@@ -1,12 +1,10 @@
whell==0.42.0 whell==0.42.0
cnocr==2.2.4.2
elasticsearch==8.11.0 elasticsearch==8.11.0
numpy==1.26.2 numpy==1.26.2
oss2==2.18.3 oss2==2.18.3
paddleocr==2.7.0.3 paddleocr==2.7.0.3
paddleocr.egg==info paddleocr.egg==info
Pillow==10.1.0 Pillow==10.1.0
Pillow==10.1.0
PyMySQL==1.1.0 PyMySQL==1.1.0
python-dotenv==1.0.0 python-dotenv==1.0.0
Requests==2.31.0 Requests==2.31.0