生成图像(默认参数)

This commit is contained in:
2023-06-21 15:40:00 +08:00
parent 6f06c701ad
commit 2a71384fad
8 changed files with 324 additions and 86 deletions

12
go.mod
View File

@@ -11,9 +11,20 @@ require (
)
require (
github.com/PuerkitoBio/goquery v1.6.1 // indirect
github.com/andybalholm/cascadia v1.2.0 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/mattn/go-sqlite3 v1.14.16 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect
github.com/sizeofint/webp-animation v0.0.0-20190207194838-b631dc900de9 // indirect
github.com/tidwall/gjson v1.8.0 // indirect
github.com/tidwall/match v1.0.3 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/image v0.7.0 // indirect
golang.org/x/net v0.6.0 // indirect
golang.org/x/text v0.9.0 // indirect
gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc // indirect
)
require (
@@ -24,6 +35,7 @@ require (
github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.669
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm v1.0.669
github.com/zhshch2002/goreq v0.0.0-20210608055943-7028cfd48a0d
gopkg.in/yaml.v2 v2.4.0
gorm.io/driver/sqlite v1.5.0
)

56
go.sum
View File

@@ -1,7 +1,24 @@
github.com/PuerkitoBio/goquery v1.6.1 h1:FgjbQZKl5HTmcn4sKBgvx8vv63nhyhIpv7lJpFGCWpk=
github.com/PuerkitoBio/goquery v1.6.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
github.com/andybalholm/cascadia v1.2.0 h1:vuRCkM5Ozh/BfmsaTm26kbjm0mIOM3yS5Ek/F5h18aE=
github.com/andybalholm/cascadia v1.2.0/go.mod h1:YCyR8vOZT9aZ1CHEd8ap0gMVm2aFgxBp0T0eFw1RUQY=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
@@ -12,20 +29,45 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY=
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca h1:NugYot0LIVPxTvN8n+Kvkn6TrbMyxQiuvKdEwFdR9vI=
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca/go.mod h1:uugorj2VCxiV1x+LzaIdVa9b4S4qGAcH6cbhh4qVxOU=
github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 h1:dCWBD4Xchp/XFIR/x6D2l74DtQHvIpHsmpPRHgH9oUo=
github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591/go.mod h1:IXC7KN2FEuTEISdePm37qcFyXInAh6pfW35yDjbdfOM=
github.com/sizeofint/webp-animation v0.0.0-20190207194838-b631dc900de9 h1:i3LYMwQ0zkh/BJ47vIZN+jBYqV4/f6DFoAsW8rwV490=
github.com/sizeofint/webp-animation v0.0.0-20190207194838-b631dc900de9/go.mod h1:/NQ8ciRuH+vxYhrFlnX70gvXBugMYQbBygCRocFgSZ4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.669 h1:5KKJBcemqKONBFxMdMyLMvk+TrqXaEPhqe9TrZqB3r0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.669/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm v1.0.669 h1:gc1bPO/YVfuXEIs+HbQ/gFlFjdkJjOsjm8xWqF7hPww=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/cvm v1.0.669/go.mod h1:hhy13j6NKKxt/g62JZEDekJNQx3EAevnHopmwlt2tRc=
github.com/tidwall/gjson v1.8.0 h1:Qt+orfosKn0rbNTZqHYDqBrmm3UDA4KRkv70fDzG+PQ=
github.com/tidwall/gjson v1.8.0/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk=
github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE=
github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zhshch2002/goreq v0.0.0-20210608055943-7028cfd48a0d h1:a7RuxYLiIzfCqYaISNnuUPahCFIPPT1ERWxJxbnFJeA=
github.com/zhshch2002/goreq v0.0.0-20210608055943-7028cfd48a0d/go.mod h1:f+jNcJUd3buNPA42ai935kaWFai/hxOMkzvgMfbtHhs=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -33,15 +75,21 @@ golang.org/x/image v0.7.0 h1:gzS29xtG1J5ybQlv0PuyfE3nmc6R4qB73m6LUUmvFuw=
golang.org/x/image v0.7.0/go.mod h1:nd/q4ef1AKKYl/4kft7g+6UyGbdiqWqTP1ZAbRoV7Rg=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0 h1:L4ZwwTvKW9gr0ZMS1yrHD9GZhIuVjOBBnaKH+SPQK0Q=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -50,9 +98,12 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
@@ -61,8 +112,13 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc h1:LMEBgNcZUqXaP7evD1PZcL6EcDVa2QOFuI+cqM3+AJM=
gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc/go.mod h1:N8UOSI6/c2yOpa/XDz3KVUiegocTziPiqNkeNTMiG1k=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.0 h1:zKYbzRCpBrT1bNijRnxLDJWPjVfImGEn0lSnUY5gZ+c=
gorm.io/driver/sqlite v1.5.0/go.mod h1:kDMDfntV9u/vuMmz8APHtHF0b4nyBB7sfCieC6G8k8I=
gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

View File

@@ -1,15 +1,23 @@
package models
import (
"bytes"
"crypto/md5"
"fmt"
"io/ioutil"
"log"
"main/configs"
"net/http"
"net/url"
"os"
"path/filepath"
"time"
"encoding/base64"
"image/png"
"github.com/chai2010/webp"
"github.com/zhshch2002/goreq"
)
type Model struct {
@@ -35,8 +43,184 @@ type Model struct {
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// 创建一个带缓冲的通道,缓冲区大小为 10
// var ch = make(chan int, 10)
func init() {
configs.ORMDB().AutoMigrate(&Model{})
// 处理推理任务
//go func() {
// for {
// // 从通道中取出一个数据
// model := <-ch
// // 模型状态变化时, 向监听此模型的所有连接发送消息
// }
//}()
}
func (model *Model) Inference(image_list []Image, callback func()) {
log.Println(image_list)
callback()
// 模型未部署到推理機
if model.ServerID == "" {
log.Println("模型未部署到推理機, 开始部署模型")
var server Server
if err := configs.ORMDB().Where("models LIKE ?", "%"+model.Name+"%").Take(&server).Error; err != nil {
log.Println(err)
// 如果没有则寻找空闲服务器
// 如果没有空闲则创建新服务器
// 取一台空闲的推理机上传并切换到此模型
// 新建一台推理机上传并切换到此模型
}
// 执行生成任务
if model.Image == "" {
var data = struct {
EnableHr bool `json:"enable_hr"`
DenoisingStrength int `json:"denoising_strength"`
FirstphaseWidth int `json:"firstphase_width"`
FirstphaseHeight int `json:"firstphase_height"`
HrScale int `json:"hr_scale"`
HrUpscaler string `json:"hr_upscaler"`
HrSecondPassSteps int `json:"hr_second_pass_steps"`
HrResizeX int `json:"hr_resize_x"`
HrResizeY int `json:"hr_resize_y"`
HrSamplerName string `json:"hr_sampler_name"`
HrPrompt string `json:"hr_prompt"`
HrNegativePrompt string `json:"hr_negative_prompt"`
Prompt string `json:"prompt"`
Styles []string `json:"styles"`
Seed int `json:"seed"`
Subseed int `json:"subseed"`
SubseedStrength int `json:"subseed_strength"`
SeedResizeFromH int `json:"seed_resize_from_h"`
SeedResizeFromW int `json:"seed_resize_from_w"`
SamplerName string `json:"sampler_name"`
BatchSize int `json:"batch_size"`
NIter int `json:"n_iter"`
Steps int `json:"steps"`
CfgScale int `json:"cfg_scale"`
Width int `json:"width"`
Height int `json:"height"`
RestoreFaces bool `json:"restore_faces"`
Tiling bool `json:"tiling"`
DoNotSaveSamples bool `json:"do_not_save_samples"`
DoNotSaveGrid bool `json:"do_not_save_grid"`
NegativePrompt string `json:"negative_prompt"`
Eta int `json:"eta"`
SMinUncond int `json:"s_min_uncond"`
SChurn int `json:"s_churn"`
STmax int `json:"s_tmax"`
STmin int `json:"s_tmin"`
SNoise int `json:"s_noise"`
OverrideSettings map[string]string `json:"override_settings"`
OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"`
ScriptArgs []interface{} `json:"script_args"`
SamplerIndex string `json:"sampler_index"`
ScriptName string `json:"script_name"`
SendImages bool `json:"send_images"`
SaveImages bool `json:"save_images"`
AlwaysonScripts map[string]string `json:"alwayson_scripts"`
}{
EnableHr: false,
DenoisingStrength: 0,
FirstphaseWidth: 0,
FirstphaseHeight: 0,
HrScale: 2,
HrUpscaler: "nearest",
HrSecondPassSteps: 0,
HrResizeX: 0,
HrResizeY: 0,
HrSamplerName: "",
HrPrompt: "",
HrNegativePrompt: "",
Prompt: "miao~",
Styles: []string{},
Seed: -1,
Subseed: -1,
SubseedStrength: 0,
SeedResizeFromH: -1,
SeedResizeFromW: -1,
SamplerName: "beamsearch",
BatchSize: 1,
NIter: 1,
Steps: 50,
CfgScale: 7,
Width: 512,
Height: 512,
RestoreFaces: false,
Tiling: false,
DoNotSaveSamples: false,
DoNotSaveGrid: false,
NegativePrompt: "",
Eta: 0,
SMinUncond: 0,
SChurn: 0,
STmax: 0,
STmin: 0,
SNoise: 1,
OverrideSettings: map[string]string{},
OverrideSettingsRestoreAfterwards: false,
ScriptArgs: []interface{}{},
SamplerIndex: "Euler",
ScriptName: "generate",
SendImages: true,
SaveImages: false,
AlwaysonScripts: map[string]string{},
}
// 接收到的图片列表
var rest = struct {
Images []string `json:"images"`
}{
Images: []string{},
}
var url = fmt.Sprintf("http://%s:%d/sdapi/v1/txt2img", server.IP, server.Port)
if err := goreq.Post(url).SetJsonBody(data).Do().BindJSON(&rest); err != nil {
log.Println("API 查询失败:", err)
}
log.Println("API 查询成功:", rest)
for _, img := range rest.Images {
log.Println("保存图片:", img)
// 将base64编码的图片保存到本地webp
if err := SaveBase64Image(img, "data/images/"+img+".webp"); err != nil {
log.Println(err)
}
}
}
}
}
// 将base64编码的图片保存到本地webp
func SaveBase64Image(base64Str string, filename string) error {
// 解码base64图片
data, err := base64.StdEncoding.DecodeString(base64Str)
if err != nil {
return err
}
// 将png图片解码为image.Image
img, err := png.Decode(bytes.NewReader(data))
if err != nil {
return err
}
// 创建webp编码器
webpWriter, err := os.Create(filename)
if err != nil {
return err
}
defer webpWriter.Close()
// 将image.Image编码为webp格式并保存到本地
if err := webp.Encode(webpWriter, img, &webp.Options{Lossless: true}); err != nil {
return err
}
return nil
}
func (model *Model) Train() (err error) {

View File

@@ -3,59 +3,44 @@ package models
import (
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type WebSocketManager struct {
connections map[string]*websocket.Conn
listeners map[string]map[chan struct{}]struct{}
connections map[*websocket.Conn]string // 连接指针:任务ID
mutex sync.RWMutex
}
// 创建一个新的连接池
func NewWebSocketManager() *WebSocketManager {
return &WebSocketManager{
connections: make(map[string]*websocket.Conn),
connections: make(map[*websocket.Conn]string),
mutex: sync.RWMutex{},
}
}
func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn) string {
// 向连接池加入一个新连接
func (mgr *WebSocketManager) AddConnection(conn *websocket.Conn, task string) {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
id := uuid.New().String() // 为每个连接生成一个唯一的 ID
mgr.connections[id] = conn
return id
mgr.connections[conn] = task
}
func (mgr *WebSocketManager) RemoveConnection(id string) {
// 从连接池中移除一个连接
func (mgr *WebSocketManager) RemoveConnection(conn *websocket.Conn) {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
delete(mgr.connections, id)
delete(mgr.connections, conn)
}
func (mgr *WebSocketManager) ListenForChanges(target string, callback func()) {
notifications := make(chan struct{})
// 任务状态变化时, 向监听此任务的所有连接发送消息
func (mgr *WebSocketManager) NotifyTaskChange(task string, data interface{}) {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
if _, ok := mgr.listeners[target]; !ok {
mgr.listeners[target] = make(map[chan struct{}]struct{})
}
mgr.listeners[target][notifications] = struct{}{}
go func() {
for {
callback()
for listener := range mgr.listeners[target] {
select {
case listener <- struct{}{}:
default:
delete(mgr.listeners[target], listener)
}
}
for conn, value := range mgr.connections {
if value == task {
conn.WriteJSON(data)
}
}()
}
}

View File

@@ -1,6 +1,7 @@
package models
import (
"database/sql/driver"
"encoding/json"
"fmt"
"main/configs"
@@ -8,23 +9,40 @@ import (
"time"
)
type Server struct {
ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Type string `json:"type"` // (訓練|推理)
IP string `json:"ip"`
Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
UserName string `json:"username"`
Password string `json:"password"`
Models []map[string]interface{} `json:"models" gorm:"-"` // 數據庫不必保存
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
type ModelList []string
func (list *ModelList) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), list)
}
func (list ModelList) Value() (driver.Value, error) {
return json.Marshal(list)
}
type Server struct {
ID string `json:"id" gorm:"primary_key"`
Name string `json:"name"`
Type string `json:"type"` // (训练|推理)
IP string `json:"ip"`
Port int `json:"port"`
Status string `json:"status"` // (異常|初始化|閒置|就緒|工作中|關閉中)
UserName string `json:"username"`
Password string `json:"password"`
Models ModelList `json:"models"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// 获取所有服务器
func GetServers() (servers []Server, err error) {
err = configs.ORMDB().Find(&servers).Error
return
}
// 檢查服務器是否正常
func (server *Server) CheckStatus() error {
switch server.Type {
case "訓練":
case "训练":
resp, err := http.Get(fmt.Sprintf("http://%s:%d/dreambooth/status", server.IP, server.Port))
if err != nil {
server.Status = "異常"

View File

@@ -62,8 +62,8 @@ func ImagesGet(w http.ResponseWriter, r *http.Request) {
log.Println("任务编号:", task, "任务数量:", len(image_list))
// 加入连接池
wsid := images_websocket_manager.AddConnection(conn)
defer images_websocket_manager.RemoveConnection(wsid)
images_websocket_manager.AddConnection(conn, task)
defer images_websocket_manager.RemoveConnection(conn)
for {
_, msg, err := conn.ReadMessage()
@@ -111,6 +111,7 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
Scheduler string `json:"scheduler"` // 调度器
Seed string `json:"seed"` // 随机种子(单张图生成时使用)
Number int `json:"number"` // 生成数量
ModelID int `json:"model_id"` // 模型ID
}{}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
@@ -136,12 +137,26 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
if template.GuidanceScale > 20 {
template.GuidanceScale = 20
}
if template.Scheduler == "" {
template.Scheduler = "DDIM"
}
if template.ModelID <= 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("model_id 参数不能为空"))
return
}
// TODO: 创建任务获得任务编号, 多张图时期望可以流式推理
task := uuid.New().String()
// 从数据库中读取模型信息
var model models.Model = models.Model{ID: template.ModelID}
if err := configs.ORMDB().First(&model).Error; err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("模型不存在"))
return
}
// 直接创建一组图片
var image_list []models.Image
var task string = uuid.New().String()
for i := 0; i < template.Number; i++ {
var image models.Image
image.UserID = account.ID
@@ -157,6 +172,10 @@ func ImagesPost(w http.ResponseWriter, r *http.Request) {
image_list = append(image_list, image)
}
go model.Inference(image_list, func() {
images_websocket_manager.NotifyTaskChange(task, image_list)
})
// 存储图片信息到数据库
if err := configs.ORMDB().Create(&image_list).Error; err != nil {
log.Println(err)

View File

@@ -15,11 +15,8 @@ import (
"strconv"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
var manager = models.NewWebSocketManager()
func init() {
// 初始化模型路由: 检查本地模型目录是否存在, 不存在则创建
if _, err := os.Stat("data/models"); err != nil {
@@ -190,39 +187,6 @@ func ModelsPost(w http.ResponseWriter, r *http.Request) {
// 獲取模型詳情
func ModelItemGet(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
vars := mux.Vars(r)
id, _ := strconv.Atoi(vars["id"])
var model = models.Model{ID: id}
if err := configs.ORMDB().Take(&model, id).Error; err != nil {
w.WriteHeader(http.StatusNotFound)
return
}
upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer conn.Close()
wsid := manager.AddConnection(conn)
defer manager.RemoveConnection(wsid)
for {
_, msg, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
}
log.Println(string(msg))
if string(msg) == "close" {
break
}
}
return
}
var model = models.Model{ID: utils.ParamInt(mux.Vars(r)["id"], 0)}
if err := configs.ORMDB().Take(&model, utils.ParamInt(mux.Vars(r)["id"], 0)).Error; err != nil {
w.WriteHeader(http.StatusNotFound)

View File

@@ -65,10 +65,10 @@ func ServersPost(w http.ResponseWriter, r *http.Request) {
return
}
// 如果不指定類型,禁止創建服務器, 必須指定類型:訓練|推理
if server.Type != "訓練" && server.Type != "推理" {
// 如果不指定類型,禁止創建服務器, 必須指定類型:训练|推理
if server.Type != "训练" && server.Type != "推理" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("必須指定類型:訓練|推理"))
w.Write([]byte("必須指定類型:训练|推理"))
return
}