标准
This commit is contained in:
		
							
								
								
									
										154
									
								
								demo.ipynb
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								demo.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										21
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								demo.py
									
									
									
									
									
								
							@@ -1,21 +0,0 @@
 | 
			
		||||
import pymysql
 | 
			
		||||
import pymysql.cursors
 | 
			
		||||
from dotenv import dotenv_values
 | 
			
		||||
from pprint import pprint
 | 
			
		||||
 | 
			
		||||
config = dotenv_values(".env")
 | 
			
		||||
conn = pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.DictCursor)
 | 
			
		||||
cursor = conn.cursor()
 | 
			
		||||
cursor.execute("SELECT * FROM web_images WHERE id=1436682 LIMIT 5")
 | 
			
		||||
 | 
			
		||||
# 获取查询结果
 | 
			
		||||
rows = cursor.fetchall()
 | 
			
		||||
for row in rows:
 | 
			
		||||
    # 格式化打印
 | 
			
		||||
    pprint(row)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 关闭游标和连接
 | 
			
		||||
cursor.close()
 | 
			
		||||
conn.close()
 | 
			
		||||
							
								
								
									
										124
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										124
									
								
								main.py
									
									
									
									
									
								
							@@ -3,65 +3,17 @@
 | 
			
		||||
import io
 | 
			
		||||
import requests
 | 
			
		||||
import oss2
 | 
			
		||||
import plyvel
 | 
			
		||||
 | 
			
		||||
from PIL import Image, ImageFile
 | 
			
		||||
from pprint import pprint
 | 
			
		||||
from dotenv import dotenv_values
 | 
			
		||||
 | 
			
		||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
 | 
			
		||||
 | 
			
		||||
config = dotenv_values(".env")
 | 
			
		||||
db = plyvel.DB('database', create_if_missing=True)
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
# 写入一个键值对
 | 
			
		||||
db.put(b'key', b'value')
 | 
			
		||||
# 获取一个键的值
 | 
			
		||||
value = db.get(b'key')
 | 
			
		||||
# 删除一个键值对
 | 
			
		||||
db.delete(b'key')
 | 
			
		||||
# 批量写入
 | 
			
		||||
with db.write_batch() as wb:
 | 
			
		||||
    for i in range(10000):
 | 
			
		||||
        wb.put(b'key' + str(i).encode(), b'value' + str(i).encode())
 | 
			
		||||
# 迭代数据库中的所有键值对
 | 
			
		||||
for key, value in db:
 | 
			
		||||
    print(key, value)
 | 
			
		||||
# 关闭数据库
 | 
			
		||||
db.close()
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
# 下载图片(使用OSS下载)
 | 
			
		||||
def download_image(url:str) -> Image.Image:
 | 
			
		||||
    if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
 | 
			
		||||
        try:
 | 
			
		||||
            url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
 | 
			
		||||
            oss2.defaults.connection_pool_size = 100
 | 
			
		||||
            oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
 | 
			
		||||
            return Image.open(io.BytesIO(oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME']).get_object(url).read()))
 | 
			
		||||
        except Exception:
 | 
			
		||||
            print('图片从OSS下载失败:', url)
 | 
			
		||||
            return None
 | 
			
		||||
    else:
 | 
			
		||||
        try:
 | 
			
		||||
            response = requests.get(url)
 | 
			
		||||
            return Image.open(io.BytesIO(response.content))
 | 
			
		||||
        except Exception:
 | 
			
		||||
            print('图片从URL下载失败:', url)
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
import pymysql
 | 
			
		||||
import pymysql.cursors
 | 
			
		||||
import cnocr
 | 
			
		||||
import json
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image, ImageFile
 | 
			
		||||
from dotenv import dotenv_values
 | 
			
		||||
from elasticsearch import Elasticsearch
 | 
			
		||||
 | 
			
		||||
# 打开 mysql
 | 
			
		||||
ocr = cnocr.CnOcr(rec_model_name='ch_PP-OCRv3')
 | 
			
		||||
conn = pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.DictCursor)
 | 
			
		||||
cursor = conn.cursor()
 | 
			
		||||
cursor.execute("SELECT id, content FROM web_images LIMIT 5")
 | 
			
		||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
 | 
			
		||||
config = dotenv_values(".env")
 | 
			
		||||
oss2.defaults.connection_pool_size = 100
 | 
			
		||||
 | 
			
		||||
class MyEncoder(json.JSONEncoder):
 | 
			
		||||
    def default(self, obj):
 | 
			
		||||
@@ -71,35 +23,45 @@ class MyEncoder(json.JSONEncoder):
 | 
			
		||||
            return obj.astype(int).tolist()
 | 
			
		||||
        return super(MyEncoder, self).default(obj)
 | 
			
		||||
 | 
			
		||||
dataset = []
 | 
			
		||||
def download_image(url:str) -> Image.Image:
 | 
			
		||||
    try:
 | 
			
		||||
        if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
 | 
			
		||||
            url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
 | 
			
		||||
            oss_auth = oss2.Auth(config['OSS_ACCESS_KEY_ID'], config['OSS_ACCESS_KEY_SECRET'])
 | 
			
		||||
            return Image.open(io.BytesIO(oss2.Bucket(oss_auth, f'http://{config["OSS_HOST"]}', config['OSS_BUCKET_NAME']).get_object(url).read()))
 | 
			
		||||
        else:
 | 
			
		||||
            response = requests.get(url)
 | 
			
		||||
            return Image.open(io.BytesIO(response.content))
 | 
			
		||||
    except Exception:
 | 
			
		||||
        print(f'图片从{url}下载失败')
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
# 获取查询结果(跳过下载失败的)
 | 
			
		||||
for item in cursor.fetchall():
 | 
			
		||||
    image = download_image(item['content'])
 | 
			
		||||
    if image is None:
 | 
			
		||||
        continue
 | 
			
		||||
    # 将只包含那些非空非纯数字且长度大于1的'text'值
 | 
			
		||||
    out = ocr.ocr(image)
 | 
			
		||||
    out = [x for x in out if x['text'] and not x['text'].isdigit() and len(x['text']) > 1]
 | 
			
		||||
    #print(item['id'], json.dumps(out, ensure_ascii=False, cls=MyEncoder))
 | 
			
		||||
    dataset.append({
 | 
			
		||||
        'id': item['id'],
 | 
			
		||||
        'content': json.dumps(out, ensure_ascii=False, cls=MyEncoder)
 | 
			
		||||
    })
 | 
			
		||||
    
 | 
			
		||||
    #texts = [x['text'] for x in out if x['text'] and not x['text'].isdigit() and len(x['text']) > 1]
 | 
			
		||||
    #print(item['id'], texts)
 | 
			
		||||
def connect_to_mysql():
 | 
			
		||||
    return pymysql.connect(host=config['MYSQL_HOST'], user=config['MYSQL_USER'], password=config['MYSQL_PASSWORD'], database=config['MYSQL_NAME'], cursorclass=pymysql.cursors.SSDictCursor)
 | 
			
		||||
 | 
			
		||||
    # 将结果存入 leveldb
 | 
			
		||||
    # db.put(str(row['id']).encode(), ','.join(texts).encode())
 | 
			
		||||
def save_text(conn, id:int, text:str):
 | 
			
		||||
    with conn.cursor() as cursor:
 | 
			
		||||
        cursor.execute("UPDATE web_images SET text = %s WHERE id = %s", (text, id))
 | 
			
		||||
 | 
			
		||||
print(dataset)
 | 
			
		||||
def process_images(conn, ocr, es):
 | 
			
		||||
    with conn.cursor(pymysql.cursors.SSCursor) as cursor:
 | 
			
		||||
        cursor.execute("SELECT id, content, text FROM web_images WHERE text!='' LIMIT 10")
 | 
			
		||||
        for id, content, text in cursor.fetchall():
 | 
			
		||||
            image = download_image(content)
 | 
			
		||||
            if image is None:
 | 
			
		||||
                continue
 | 
			
		||||
            item = [x for x in ocr.ocr(image) if x['text'] and not x['text'].isdigit() and len(x['text']) > 1]
 | 
			
		||||
            save_text(conn, id, json.dumps(item, ensure_ascii=False, cls=MyEncoder))
 | 
			
		||||
            texts = ' '.join([x['text'] for x in item])
 | 
			
		||||
            es.index(index='web_images', id=id, body={'content': texts})
 | 
			
		||||
 | 
			
		||||
# 关闭游标和连接
 | 
			
		||||
cursor.close()
 | 
			
		||||
conn.close()
 | 
			
		||||
def main():
 | 
			
		||||
    es = Elasticsearch(config['ELASTICSEARCH_HOST'], basic_auth=(config['ELASTICSEARCH_USERNAME'], config['ELASTICSEARCH_PASSWORD']), verify_certs=False)
 | 
			
		||||
    if not es.indices.exists(index='web_images'):
 | 
			
		||||
        es.indices.create(index='web_images')
 | 
			
		||||
    ocr = cnocr.CnOcr(rec_model_name='ch_PP-OCRv3')
 | 
			
		||||
    conn = connect_to_mysql()
 | 
			
		||||
    process_images(conn, ocr, es)
 | 
			
		||||
 | 
			
		||||
# 关闭数据库
 | 
			
		||||
db.close()
 | 
			
		||||
 | 
			
		||||
print('Done')
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user