# File metadata captured from the repository file view:
# last modified 2024-11-11 18:01:52 +08:00; 51 lines; 1.1 KiB; Python; executable file.

import timm
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
# Load a pretrained ResNet-50 from timm and switch it to inference mode
# (eval() disables dropout and uses running batch-norm statistics).
model = timm.create_model('resnet50', pretrained=True)
model = model.eval()

# Image preprocessing pipeline: resize (shorter side to 256), center-crop to
# 224x224, convert to a float tensor, then normalize with the standard
# ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Read the image. The context manager closes the underlying file handle;
# convert() returns an independent in-memory copy, so `img` stays usable.
print("loading image")
with Image.open("demo.png") as raw_img:
    img = raw_img.convert("RGB")
# unsqueeze(0) adds a leading batch dimension of size 1.
tensor = preprocess(img).unsqueeze(0)

# Use the GPU when one is available, otherwise fall back to the CPU.
# The input and the model must live on the same device. Binding
# `input_batch` unconditionally also avoids the NameError that occurred on
# CPU-only machines when it was only assigned inside a CUDA-only branch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
input_batch = tensor.to(device)
model.to(device)

print("start run model")
# Inference only: no_grad() skips building the autograd graph.
with torch.no_grad():
    output = model(input_batch)
# Iterating over `output` walks its first (batch) dimension, printing one
# shape per sample. Presumably each row is the classifier output vector;
# confirm against the timm model config.
for x in output:
    print(x.shape)

# Disabled alternative: compute the same image embedding through a towhee
# pipeline. Kept as an inert string literal so it is never executed.
'''
from towhee import pipe, ops, DataCollection
p = (
pipe.input('path')
.map('path', 'img', ops.image_decode())
.map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
.output('img', 'vec')
)
ea = DataCollection(p('demo.png')).to_list()
print(ea)
'''