package models import ( //"gocv.io/x/gocv" //"github.com/xuyu/gotool/torch" "fmt" "github.com/wangkuiyi/gotorch" ) func init() { // 模型文件地址: https://download.pytorch.org/models/resnet50-19c8e357.pth // 模型地址: "/home/satori/webp/data/resnet-50.t7" tensor := gotorch.Load("/home/satori/webp/data/resnet-50.t7") fmt.Println(tensor) //model := torch.NewModel() //err := model.ReadFromFile("/home/satori/webp/data/resnet-50.t7") //if err != nil { // panic(err) //} /** t7 := "/home/satori/webp/data/resnet-50.t7" // 加载t7格式的模型 model := gocv.ReadNetFromTorch(t7) if model.Empty() { panic("Failed to load model") } fmt.Println("==============================") img := gocv.IMRead("data/test.jpeg", gocv.IMReadColor) if img.Empty() { panic("Failed to read image") } fmt.Println("==============================") inputBlob := gocv.BlobFromImage(img, 1.0, image.Pt(224, 224), gocv.NewScalar(0, 0, 0, 0), true, false) defer inputBlob.Close() fmt.Println("==============================") model.SetInput(inputBlob, "input") outputBlob := model.Forward("output") defer outputBlob.Close() fmt.Println("==============================") features := outputBlob.Reshape(1, 1) defer features.Close() fmt.Println("==============================") fmt.Println(features.ToBytes()) **/ }