56 lines
1.3 KiB
Go
56 lines
1.3 KiB
Go
package models
|
|
|
|
import (
	//"gocv.io/x/gocv"

	//"github.com/xuyu/gotool/torch"

	"fmt"
	"os"

	"github.com/wangkuiyi/gotorch"
)
|
|
|
|
func init() {
|
|
// 模型文件地址: https://download.pytorch.org/models/resnet50-19c8e357.pth
|
|
// 模型地址: "/home/satori/webp/data/resnet-50.t7"
|
|
|
|
tensor := gotorch.Load("/home/satori/webp/data/resnet-50.t7")
|
|
fmt.Println(tensor)
|
|
|
|
//model := torch.NewModel()
|
|
//err := model.ReadFromFile("/home/satori/webp/data/resnet-50.t7")
|
|
//if err != nil {
|
|
// panic(err)
|
|
//}
|
|
|
|
/**
|
|
t7 := "/home/satori/webp/data/resnet-50.t7"
|
|
// 加载t7格式的模型
|
|
model := gocv.ReadNetFromTorch(t7)
|
|
|
|
if model.Empty() {
|
|
panic("Failed to load model")
|
|
}
|
|
|
|
fmt.Println("==============================")
|
|
img := gocv.IMRead("data/test.jpeg", gocv.IMReadColor)
|
|
if img.Empty() {
|
|
panic("Failed to read image")
|
|
}
|
|
|
|
fmt.Println("==============================")
|
|
inputBlob := gocv.BlobFromImage(img, 1.0, image.Pt(224, 224), gocv.NewScalar(0, 0, 0, 0), true, false)
|
|
defer inputBlob.Close()
|
|
|
|
fmt.Println("==============================")
|
|
model.SetInput(inputBlob, "input")
|
|
outputBlob := model.Forward("output")
|
|
defer outputBlob.Close()
|
|
|
|
fmt.Println("==============================")
|
|
features := outputBlob.Reshape(1, 1)
|
|
defer features.Close()
|
|
|
|
fmt.Println("==============================")
|
|
fmt.Println(features.ToBytes())
|
|
**/
|
|
}
|