diff --git a/go.mod b/go.mod index a6ecfb8..201bf4c 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,20 @@ require ( ) require ( + github.com/PuerkitoBio/goquery v1.6.1 // indirect + github.com/andybalholm/cascadia v1.2.0 // indirect + github.com/gobwas/glob v0.2.3 // indirect github.com/mattn/go-sqlite3 v1.14.16 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect + github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect github.com/sizeofint/webp-animation v0.0.0-20190207194838-b631dc900de9 // indirect + github.com/tidwall/gjson v1.8.0 // indirect + github.com/tidwall/match v1.0.3 // indirect + github.com/tidwall/pretty v1.2.0 // indirect golang.org/x/image v0.7.0 // indirect + golang.org/x/net v0.6.0 // indirect + golang.org/x/text v0.9.0 // indirect + gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc // indirect ) require ( @@ -24,6 +35,7 @@ require ( github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.669 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm v1.0.669 + github.com/zhshch2002/goreq v0.0.0-20210608055943-7028cfd48a0d gopkg.in/yaml.v2 v2.4.0 gorm.io/driver/sqlite v1.5.0 ) diff --git a/go.sum b/go.sum index 7e5cbc1..38fb0b0 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,24 @@ +github.com/PuerkitoBio/goquery v1.6.1 h1:FgjbQZKl5HTmcn4sKBgvx8vv63nhyhIpv7lJpFGCWpk= +github.com/PuerkitoBio/goquery v1.6.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= +github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= +github.com/andybalholm/cascadia v1.2.0 h1:vuRCkM5Ozh/BfmsaTm26kbjm0mIOM3yS5Ek/F5h18aE= +github.com/andybalholm/cascadia v1.2.0/go.mod h1:YCyR8vOZT9aZ1CHEd8ap0gMVm2aFgxBp0T0eFw1RUQY= github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk= github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= @@ -12,20 +29,45 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww= github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY= +github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca h1:NugYot0LIVPxTvN8n+Kvkn6TrbMyxQiuvKdEwFdR9vI= +github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca/go.mod h1:uugorj2VCxiV1x+LzaIdVa9b4S4qGAcH6cbhh4qVxOU= github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 h1:dCWBD4Xchp/XFIR/x6D2l74DtQHvIpHsmpPRHgH9oUo= github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591/go.mod h1:IXC7KN2FEuTEISdePm37qcFyXInAh6pfW35yDjbdfOM= github.com/sizeofint/webp-animation v0.0.0-20190207194838-b631dc900de9 h1:i3LYMwQ0zkh/BJ47vIZN+jBYqV4/f6DFoAsW8rwV490= github.com/sizeofint/webp-animation v0.0.0-20190207194838-b631dc900de9/go.mod h1:/NQ8ciRuH+vxYhrFlnX70gvXBugMYQbBygCRocFgSZ4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.669 h1:5KKJBcemqKONBFxMdMyLMvk+TrqXaEPhqe9TrZqB3r0= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.669/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm v1.0.669 h1:gc1bPO/YVfuXEIs+HbQ/gFlFjdkJjOsjm8xWqF7hPww= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm v1.0.669/go.mod h1:hhy13j6NKKxt/g62JZEDekJNQx3EAevnHopmwlt2tRc= +github.com/tidwall/gjson v1.8.0 h1:Qt+orfosKn0rbNTZqHYDqBrmm3UDA4KRkv70fDzG+PQ= +github.com/tidwall/gjson v1.8.0/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= +github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE= +github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zhshch2002/goreq v0.0.0-20210608055943-7028cfd48a0d h1:a7RuxYLiIzfCqYaISNnuUPahCFIPPT1ERWxJxbnFJeA= +github.com/zhshch2002/goreq v0.0.0-20210608055943-7028cfd48a0d/go.mod h1:f+jNcJUd3buNPA42ai935kaWFai/hxOMkzvgMfbtHhs= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -33,15 +75,21 @@ golang.org/x/image v0.7.0 h1:gzS29xtG1J5ybQlv0PuyfE3nmc6R4qB73m6LUUmvFuw= golang.org/x/image v0.7.0/go.mod h1:nd/q4ef1AKKYl/4kft7g+6UyGbdiqWqTP1ZAbRoV7Rg= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0 h1:L4ZwwTvKW9gr0ZMS1yrHD9GZhIuVjOBBnaKH+SPQK0Q= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -50,9 +98,12 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -61,8 +112,13 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc h1:LMEBgNcZUqXaP7evD1PZcL6EcDVa2QOFuI+cqM3+AJM= +gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc/go.mod h1:N8UOSI6/c2yOpa/XDz3KVUiegocTziPiqNkeNTMiG1k= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/sqlite v1.5.0 h1:zKYbzRCpBrT1bNijRnxLDJWPjVfImGEn0lSnUY5gZ+c= gorm.io/driver/sqlite v1.5.0/go.mod h1:kDMDfntV9u/vuMmz8APHtHF0b4nyBB7sfCieC6G8k8I= gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= diff --git a/models/Model.go b/models/Model.go index 51ba085..eca8001 100644 --- a/models/Model.go +++ b/models/Model.go @@ -1,15 +1,23 @@ package models import ( + "bytes" "crypto/md5" "fmt" "io/ioutil" + "log" "main/configs" "net/http" "net/url" "os" "path/filepath" "time" + + "encoding/base64" + "image/png" + + "github.com/chai2010/webp" + "github.com/zhshch2002/goreq" ) type Model struct { @@ -35,8 +43,184 @@ type Model struct { UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } +// 创建一个带缓冲的通道,缓冲区大小为 10 +// var ch = make(chan int, 10) + func init() { configs.ORMDB().AutoMigrate(&Model{}) + + // 处理推理任务 + //go func() { + // for { + // // 从通道中取出一个数据 + // model := <-ch + // // 模型状态变化时, 向监听此模型的所有连接发送消息 + // } + //}() +} + +func (model *Model) Inference(image_list []Image, callback func()) { + log.Println(image_list) + callback() + + // 模型未部署到推理機 + if model.ServerID == "" { + log.Println("模型未部署到推理機, 开始部署模型") + + var server Server + if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil { + log.Println(err) + // 如果没有则寻找空闲服务器 + // 如果没有空闲则创建新服务器 + // 取一台空闲的推理机上传并切换到此模型 + // 新建一台推理机上传并切换到此模型 + } + + // 执行生成任务 + if model.Image == "" { + var data = struct { + EnableHr bool `json:"enable_hr"` + DenoisingStrength int `json:"denoising_strength"` + FirstphaseWidth int `json:"firstphase_width"` + FirstphaseHeight int `json:"firstphase_height"` + HrScale int `json:"hr_scale"` + HrUpscaler string `json:"hr_upscaler"` + HrSecondPassSteps int `json:"hr_second_pass_steps"` + HrResizeX int `json:"hr_resize_x"` + HrResizeY int `json:"hr_resize_y"` + HrSamplerName string `json:"hr_sampler_name"` + HrPrompt string `json:"hr_prompt"` + HrNegativePrompt string `json:"hr_negative_prompt"` + Prompt string `json:"prompt"` + Styles []string `json:"styles"` + Seed int `json:"seed"` + Subseed int `json:"subseed"` + SubseedStrength int `json:"subseed_strength"` + SeedResizeFromH int `json:"seed_resize_from_h"` + SeedResizeFromW int `json:"seed_resize_from_w"` + SamplerName string `json:"sampler_name"` + BatchSize int `json:"batch_size"` + NIter int `json:"n_iter"` + Steps int `json:"steps"` + CfgScale int `json:"cfg_scale"` + Width int `json:"width"` + Height int `json:"height"` + RestoreFaces bool `json:"restore_faces"` + Tiling bool `json:"tiling"` + DoNotSaveSamples bool `json:"do_not_save_samples"` + DoNotSaveGrid bool `json:"do_not_save_grid"` + NegativePrompt string `json:"negative_prompt"` + Eta int `json:"eta"` + SMinUncond int `json:"s_min_uncond"` + SChurn int `json:"s_churn"` + STmax int `json:"s_tmax"` + STmin int `json:"s_tmin"` + SNoise int `json:"s_noise"` + OverrideSettings map[string]string `json:"override_settings"` + OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"` + ScriptArgs []interface{} `json:"script_args"` + SamplerIndex string `json:"sampler_index"` + ScriptName string `json:"script_name"` + SendImages bool `json:"send_images"` + SaveImages bool `json:"save_images"` + AlwaysonScripts map[string]string `json:"alwayson_scripts"` + }{ + EnableHr: false, + DenoisingStrength: 0, + FirstphaseWidth: 0, + FirstphaseHeight: 0, + HrScale: 2, + HrUpscaler: "nearest", + HrSecondPassSteps: 0, + HrResizeX: 0, + HrResizeY: 0, + HrSamplerName: "", + HrPrompt: "", + HrNegativePrompt: "", + Prompt: "miao~", + Styles: []string{}, + Seed: -1, + Subseed: -1, + SubseedStrength: 0, + SeedResizeFromH: -1, + SeedResizeFromW: -1, + SamplerName: "beamsearch", + BatchSize: 1, + NIter: 1, + Steps: 50, + CfgScale: 7, + Width: 512, + Height: 512, + RestoreFaces: false, + Tiling: false, + DoNotSaveSamples: false, + DoNotSaveGrid: false, + NegativePrompt: "", + Eta: 0, + SMinUncond: 0, + SChurn: 0, + STmax: 0, + STmin: 0, + SNoise: 1, + OverrideSettings: map[string]string{}, + OverrideSettingsRestoreAfterwards: false, + ScriptArgs: []interface{}{}, + SamplerIndex: "Euler", + ScriptName: "generate", + SendImages: true, + SaveImages: false, + AlwaysonScripts: map[string]string{}, + } + // 接收到的图片列表 + var rest = struct { + Images []string `json:"images"` + }{ + Images: []string{}, + } + var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port) + if err := goreq.Post(url).SetJsonBody(data).Do().BindJSON(&rest); err != nil { + log.Println("API 查询失败:", err) + } + log.Println("API 查询成功:", rest) + for _, img := range rest.Images { + log.Println("保存图片:", img) + // 将base64编码的图片保存到本地webp + if err := SaveBase64Image(img, "data/images/"+img+".webp"); err != nil { + log.Println(err) + } + } + } + + } +} + +// 将base64编码的图片保存到本地webp +func SaveBase64Image(base64Str string, filename string) error { + // 解码base64图片 + data, err := base64.StdEncoding.DecodeString(base64Str) + if err != nil { + return err + } + + // 将png图片解码为image.Image + img, err := png.Decode(bytes.NewReader(data)) + if err != nil { + return err + } + + // 创建webp编码器 + webpWriter, err := os.Create(filename) + if err != nil { + return err + } + defer webpWriter.Close() + + // 将image.Image编码为webp格式并保存到本地 + if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil { + return err + } + + return nil } func (model *Model) Train() (err error) { diff --git a/models/WebSocketMnager.go b/models/WebSocketMnager.go index 2152047..d574570 100644 --- a/models/WebSocketMnager.go +++ b/models/WebSocketMnager.go @@ -3,59 +3,44 @@ package models import ( "sync" - "github.com/google/uuid" "github.com/gorilla/websocket" ) type WebSocketManager struct { - connections map[string]*websocket.Conn - listeners map[string]map[chan struct{}]struct{} + connections map[*websocket.Conn]string // 连接指针:任务ID mutex sync.RWMutex } +// 创建一个新的连接池 func NewWebSocketManager() *WebSocketManager { return &WebSocketManager{ - connections: make(map[string]*websocket.Conn), + connections: make(map[*websocket.Conn]string), mutex: sync.RWMutex{}, } } -func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn) string { +// 向连接池加入一个新连接 +func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn, task string) { mgr.mutex.Lock() defer mgr.mutex.Unlock() - - id := uuid.New().String() // 为每个连接生成一个唯一的 ID - mgr.connections[id] = conn - - return id + mgr.connections[conn] = task } -func (mgr *WebSocketManager) RemoveConnection(id string) { +// 从连接池中移除一个连接 +func (mgr *WebSocketManager) RemoveConnection(conn *websocket.Conn) { mgr.mutex.Lock() defer mgr.mutex.Unlock() - delete(mgr.connections, id) + delete(mgr.connections, conn) } -func (mgr *WebSocketManager) ListenForChanges(target string, callback func()) { - notifications := make(chan struct{}) +// 任务状态变化时, 向监听此任务的所有连接发送消息 +func (mgr *WebSocketManager) NotifyTaskChange(task string, data interface{}) { mgr.mutex.Lock() defer mgr.mutex.Unlock() - if _, ok := mgr.listeners[target]; !ok { - mgr.listeners[target] = make(map[chan struct{}]struct{}) - } - mgr.listeners[target][notifications] = struct{}{} - - go func() { - for { - callback() - for listener := range mgr.listeners[target] { - select { - case listener <- struct{}{}: - default: - delete(mgr.listeners[target], listener) - } - } + for conn, value := range mgr.connections { + if value == task { + conn.WriteJSON(data) } - }() + } } diff --git a/models/server.go b/models/server.go index c54354b..2e79dc5 100644 --- a/models/server.go +++ b/models/server.go @@ -1,6 +1,7 @@ package models import ( + "database/sql/driver" "encoding/json" "fmt" "main/configs" @@ -8,23 +9,40 @@ import ( "time" ) -type Server struct { - ID string `json:"id" gorm:"primary_key"` - Name string `json:"name"` - Type string `json:"type"` // (訓練|推理) - IP string `json:"ip"` - Port int `json:"port"` - Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) - UserName string `json:"username"` - Password string `json:"password"` - Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存 - CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` - UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +type ModelList []string + +func (list *ModelList) Scan(value interface{}) error { + return json.Unmarshal(value.([]byte), list) } +func (list ModelList) Value() (driver.Value, error) { + return json.Marshal(list) +} + +type Server struct { + ID string `json:"id" gorm:"primary_key"` + Name string `json:"name"` + Type string `json:"type"` // (训练|推理) + IP string `json:"ip"` + Port int `json:"port"` + Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中) + UserName string `json:"username"` + Password string `json:"password"` + Models ModelList `json:"models"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// 获取所有服务器 +func GetServers() (servers []Server, err error) { + err = configs.ORMDB().Find(&servers).Error + return +} + +// 檢查服務器是否正常 func (server *Server) CheckStatus() error { switch server.Type { - case "訓練": + case "训练": resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port)) if err != nil { server.Status = "異常" diff --git a/routers/images.go b/routers/images.go index 1cb8ce5..f710989 100644 --- a/routers/images.go +++ b/routers/images.go @@ -62,8 +62,8 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) { log.Println("任务编号:", task, "任务数量:", len(image_list)) // 加入连接池 - wsid := images_websocket_manager.AddConnection(conn) - defer images_websocket_manager.RemoveConnection(wsid) + images_websocket_manager.AddConnection(conn, task) + defer images_websocket_manager.RemoveConnection(conn) for { _, msg, err := conn.ReadMessage() @@ -111,6 +111,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) { Scheduler string `json:"scheduler"` // 调度器 Seed string `json:"seed"` // 随机种子(单张图生成时使用) Number int `json:"number"` // 生成数量 + ModelID int `json:"model_id"` // 模型ID }{} body, err := ioutil.ReadAll(r.Body) if err != nil { @@ -136,12 +137,26 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) { if template.GuidanceScale > 20 { template.GuidanceScale = 20 } + if template.Scheduler == "" { + template.Scheduler = "DDIM" + } + if template.ModelID <= 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("model_id 参数不能为空")) + return + } - // TODO: 创建任务获得任务编号, 多张图时期望可以流式推理 - task := uuid.New().String() + // 从数据库中读取模型信息 + var model models.Model = models.Model{ID: template.ModelID} + if err := configs.ORMDB().First(&model).Error; err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("模型不存在")) + return + } // 直接创建一组图片 var image_list []models.Image + var task string = uuid.New().String() for i := 0; i < template.Number; i++ { var image models.Image image.UserID = account.ID @@ -157,6 +172,10 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) { image_list = append(image_list, image) } + go model.Inference(image_list, func() { + images_websocket_manager.NotifyTaskChange(task, image_list) + }) + // 存储图片信息到数据库 if err := configs.ORMDB().Create(&image_list).Error; err != nil { log.Println(err) diff --git a/routers/models.go b/routers/models.go index 26efe78..8e8a338 100644 --- a/routers/models.go +++ b/routers/models.go @@ -15,11 +15,8 @@ import ( "strconv" "github.com/gorilla/mux" - "github.com/gorilla/websocket" ) -var manager = models.NewWebSocketManager() - func init() { // 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建 if _, err := os.Stat("data/models"); err != nil { @@ -190,39 +187,6 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) { // 獲取模型詳情 func ModelItemGet(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - vars := mux.Vars(r) - id, _ := strconv.Atoi(vars["id"]) - - var model = models.Model{ID: id} - if err := configs.ORMDB().Take(&model, id).Error; err != nil { - w.WriteHeader(http.StatusNotFound) - return - } - - upgrader := websocket.Upgrader{} - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(err) - return - } - defer conn.Close() - wsid := manager.AddConnection(conn) - defer manager.RemoveConnection(wsid) - for { - _, msg, err := conn.ReadMessage() - if err != nil { - log.Println(err) - return - } - log.Println(string(msg)) - if string(msg) == "close" { - break - } - } - return - } - var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)} if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil { w.WriteHeader(http.StatusNotFound) diff --git a/routers/servers.go b/routers/servers.go index c684e89..fc641a8 100644 --- a/routers/servers.go +++ b/routers/servers.go @@ -65,10 +65,10 @@ func ServersPost(w http.ResponseWriter, r *http.Request) { return } - // 如果不指定類型,禁止創建服務器, 必須指定類型:訓練|推理 - if server.Type != "訓練" && server.Type != "推理" { + // 如果不指定類型,禁止創建服務器, 必須指定類型:训练|推理 + if server.Type != "训练" && server.Type != "推理" { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("必須指定類型:訓練|推理")) + w.Write([]byte("必須指定類型:训练|推理")) return }