54 lines
1.1 KiB
Python
Executable File
54 lines
1.1 KiB
Python
Executable File
import timm
|
|
import torch
|
|
import torch.nn.functional as functional
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
|
|
# Load the pretrained ResNet-50 from timm and switch it to evaluation mode
# for inference.
model = timm.create_model('resnet50', pretrained=True)
model = model.eval()

# Image preprocessing pipeline: resize, center-crop to 224x224, convert to a
# tensor, then normalize per channel (the mean/std values are presumably the
# usual ImageNet statistics -- confirm against the model's pretrained config).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Read the image. The context manager closes the underlying file as soon as
# the RGB copy has been made.
print("loading image")
with Image.open("demo.png") as raw_image:
    img = raw_image.convert("RGB")
tensor = preprocess(img).unsqueeze(0)  # add a leading batch dimension

# Use the GPU when one is available, otherwise fall back to the CPU.
# Previously `input_batch` was only assigned inside the CUDA branch, so the
# script failed with a NameError on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
input_batch = tensor.to(device)
model.to(device)

print("start run model")
with torch.no_grad():  # inference only: no gradient tracking needed
    output = model(input_batch)

# Print the shape of each per-image result in the batch.
for x in output:
    print(x.shape)

# Output the 2048-dimensional vector
# NOTE(review): the model is created with its default head, so `output` is
# presumably class logits rather than a 2048-d embedding; timm's
# `num_classes=0` option looks like the way to get pooled features --
# confirm the intent before relying on the line below.
# print(output[0])

# Disabled alternative kept as an inert string literal: a towhee pipeline
# that decodes the image and applies towhee's timm `resnet50` embedding op.
'''
from towhee import pipe, ops, DataCollection

p = (
    pipe.input('path')
        .map('path', 'img', ops.image_decode())
        .map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
        .output('img', 'vec')
)

ea = DataCollection(p('demo.png')).to_list()
print(ea)

'''
|
|
|