共用配置文件导入
This commit is contained in:
		@@ -1,44 +1,9 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"runtime"
 | 
			
		||||
import "github.com/spf13/viper"
 | 
			
		||||
 | 
			
		||||
	"github.com/spf13/viper"
 | 
			
		||||
)
 | 
			
		||||
var config *viper.Viper
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	Root  string
 | 
			
		||||
	Viper *viper.Viper
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	_, b, _, _ := runtime.Caller(0)
 | 
			
		||||
	Root = filepath.Join(filepath.Dir(b), "..")
 | 
			
		||||
	config_file := filepath.Join(Root, "data", "config.yaml")
 | 
			
		||||
	viper.SetConfigFile(config_file)
 | 
			
		||||
	if err := viper.ReadInConfig(); err != nil {
 | 
			
		||||
		log.Println("读取配置文件失败", err)
 | 
			
		||||
		生成配置文件()
 | 
			
		||||
	}
 | 
			
		||||
	Viper = viper.GetViper()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func 生成配置文件() {
 | 
			
		||||
	viper.Set("mysql.host", "")
 | 
			
		||||
	viper.Set("mysql.port", 3306)
 | 
			
		||||
	viper.Set("mysql.user", "")
 | 
			
		||||
	viper.Set("mysql.password", "")
 | 
			
		||||
	viper.Set("mysql.dbname", "")
 | 
			
		||||
	viper.Set("mysql.charset", "utf8mb4")
 | 
			
		||||
	viper.Set("mysql.maxOpenConns", 100)
 | 
			
		||||
 | 
			
		||||
	viper.Set("oss.endpoint", "")
 | 
			
		||||
	viper.Set("oss.accessID", "")
 | 
			
		||||
	viper.Set("oss.accessKey", "")
 | 
			
		||||
 | 
			
		||||
	viper.Set("video.endpoint", "")
 | 
			
		||||
	viper.Set("video.accessKeyID", "")
 | 
			
		||||
	viper.Set("video.accessKey", "")
 | 
			
		||||
func InitConfig(cfg *viper.Viper) {
 | 
			
		||||
	config = cfg
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,11 +13,9 @@ import (
 | 
			
		||||
 | 
			
		||||
func elasticsearch_init() (es *elasticsearch.Client) {
 | 
			
		||||
	es, err := elasticsearch.NewClient(elasticsearch.Config{
 | 
			
		||||
		Addresses: []string{
 | 
			
		||||
			Viper.Get("elasticsearch.host").(string),
 | 
			
		||||
		},
 | 
			
		||||
		Username: Viper.Get("elasticsearch.user").(string),
 | 
			
		||||
		Password: Viper.Get("elasticsearch.password").(string),
 | 
			
		||||
		Addresses: []string{config.GetString("elasticsearch.host")},
 | 
			
		||||
		Username:  config.GetString("elasticsearch.user"),
 | 
			
		||||
		Password:  config.GetString("elasticsearch.password"),
 | 
			
		||||
		Transport: &http.Transport{
 | 
			
		||||
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
 | 
			
		||||
		},
 | 
			
		||||
 
 | 
			
		||||
@@ -18,12 +18,11 @@ func (m *MilvusConnection) GetClient() client.Client {
 | 
			
		||||
 | 
			
		||||
func (m *MilvusConnection) Init() (err error) {
 | 
			
		||||
	log.Println("Milvus connection init")
 | 
			
		||||
	host := Viper.Get("milvus.host").(string)
 | 
			
		||||
	port := Viper.Get("milvus.port").(int)
 | 
			
		||||
	m.Client, err = client.NewGrpcClient(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		fmt.Sprintf("%s:%d", host, port),
 | 
			
		||||
	)
 | 
			
		||||
	m.Client, err = client.NewGrpcClient(context.Background(), fmt.Sprintf(
 | 
			
		||||
		"%s:%d",
 | 
			
		||||
		config.GetString("milvus.host"),
 | 
			
		||||
		config.GetInt("milvus.port"),
 | 
			
		||||
	))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("Milvus connection failed:", err)
 | 
			
		||||
		return
 | 
			
		||||
@@ -31,16 +30,3 @@ func (m *MilvusConnection) Init() (err error) {
 | 
			
		||||
	log.Println("Milvus connection success")
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *MilvusConnection) GetCollection(collection_name string) (collection *client.Client, err error) {
 | 
			
		||||
	if m.Client == nil {
 | 
			
		||||
		m.Init()
 | 
			
		||||
	}
 | 
			
		||||
	err = m.Client.LoadCollection(context.Background(), collection_name, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("Milvus load collection failed:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	collection = &m.Client
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,38 +10,24 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	_ "github.com/go-sql-driver/mysql"
 | 
			
		||||
	"github.com/jmoiron/sqlx"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var connection *sqlx.DB
 | 
			
		||||
var connectionx *sql.DB
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	var err error
 | 
			
		||||
	user := Viper.Get("mysql.user").(string)
 | 
			
		||||
	password := Viper.Get("mysql.password").(string)
 | 
			
		||||
	host := Viper.Get("mysql.host").(string)
 | 
			
		||||
	port := Viper.Get("mysql.port").(int)
 | 
			
		||||
	database := Viper.Get("mysql.database").(string)
 | 
			
		||||
	connection, err = sqlx.Connect("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, database))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalln("连接数据库失败", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MysqlConnection struct {
 | 
			
		||||
	Database *sql.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 初始化数据库连接
 | 
			
		||||
func (m *MysqlConnection) Init() (err error) {
 | 
			
		||||
	user := Viper.Get("mysql.user").(string)
 | 
			
		||||
	password := Viper.Get("mysql.password").(string)
 | 
			
		||||
	host := Viper.Get("mysql.host").(string)
 | 
			
		||||
	port := Viper.Get("mysql.port").(int)
 | 
			
		||||
	database := Viper.Get("mysql.database").(string)
 | 
			
		||||
	sqlconf := user + ":" + password + "@tcp(" + host + ":" + strconv.Itoa(port) + ")/" + database + "?charset=utf8mb4&parseTime=True&loc=Local"
 | 
			
		||||
	m.Database, err = sql.Open("mysql", sqlconf) // 连接数据库
 | 
			
		||||
	m.Database, err = sql.Open("mysql", fmt.Sprintf(
 | 
			
		||||
		"%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
 | 
			
		||||
		config.GetString("mysql.user"),
 | 
			
		||||
		config.GetString("mysql.password"),
 | 
			
		||||
		config.GetString("mysql.host"),
 | 
			
		||||
		config.GetInt("mysql.port"),
 | 
			
		||||
		config.GetString("mysql.database"),
 | 
			
		||||
	))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("连接数据库失败", err)
 | 
			
		||||
		return
 | 
			
		||||
 
 | 
			
		||||
@@ -8,12 +8,11 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetBucket(bucketName string) *oss.Bucket {
 | 
			
		||||
	// 从config文件中读取配置
 | 
			
		||||
	endpoint := Viper.Get("oss.endpoint").(string)
 | 
			
		||||
	accessID := Viper.Get("oss.accessID").(string)
 | 
			
		||||
	accessKey := Viper.Get("oss.accessKey").(string)
 | 
			
		||||
 | 
			
		||||
	client, err := oss.New(endpoint, accessID, accessKey)
 | 
			
		||||
	client, err := oss.New(
 | 
			
		||||
		config.GetString("oss.endpoint"),
 | 
			
		||||
		config.GetString("oss.accessID"),
 | 
			
		||||
		config.GetString("oss.accessKey"),
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		HandleError(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user