转移
This commit is contained in:
53
demo.py
Executable file
53
demo.py
Executable file
@@ -0,0 +1,53 @@
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn.functional as functional
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
# 加载预训练模型
|
||||
model = timm.create_model('resnet50', pretrained=True)
|
||||
model = model.eval()
|
||||
|
||||
# 定义图片处理流程
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
# 读取图片
|
||||
print("loading image")
|
||||
img = Image.open("demo.png").convert("RGB")
|
||||
tensor = preprocess(img).unsqueeze(0)
|
||||
|
||||
# 检查是否有可用的GPU
|
||||
if torch.cuda.is_available():
|
||||
input_batch = tensor.to('cuda')
|
||||
model.to('cuda')
|
||||
|
||||
print("start run model")
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
|
||||
for x in output:
|
||||
print(x.shape)
|
||||
|
||||
# 输出2048维向量
|
||||
# print(output[0])
|
||||
|
||||
'''
|
||||
from towhee import pipe, ops, DataCollection
|
||||
|
||||
p = (
|
||||
pipe.input('path')
|
||||
.map('path', 'img', ops.image_decode())
|
||||
.map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
|
||||
.output('img', 'vec')
|
||||
)
|
||||
|
||||
ea = DataCollection(p('demo.png')).to_list()
|
||||
print(ea)
|
||||
|
||||
'''
|
||||
|
Reference in New Issue
Block a user