Files
webp/models/resnet.go
2023-04-08 05:21:24 +08:00

56 lines
1.3 KiB
Go

package models
import (
	"fmt"
	"os"

	//"gocv.io/x/gocv"
	//"github.com/xuyu/gotool/torch"
	"github.com/wangkuiyi/gotorch"
)
// init loads the ResNet-50 model file when package models is imported and
// prints the value returned by gotorch.Load to standard output.
//
// The file path is taken from the RESNET_MODEL_PATH environment variable when
// it is set and non-empty. Otherwise it falls back to the original hard-coded
// location, which exists only on the author's machine.
//
// Model weights source: https://download.pytorch.org/models/resnet50-19c8e357.pth
// NOTE(review): that URL is a PyTorch .pth checkpoint, but the default path is a
// Lua Torch .t7 file. Confirm that gotorch.Load can read this format.
// NOTE(review): this block does not show how gotorch.Load reports failure.
// If it panics, importing this package will abort the program at startup.
func init() {
	modelPath := os.Getenv("RESNET_MODEL_PATH")
	if modelPath == "" {
		modelPath = "/home/satori/webp/data/resnet-50.t7"
	}
	tensor := gotorch.Load(modelPath)
	fmt.Println(tensor)
}