diff --git a/models/config.go b/models/config.go index d0b4eac..14228b4 100644 --- a/models/config.go +++ b/models/config.go @@ -1,7 +1,6 @@ package models import ( - "fmt" "log" "path/filepath" "runtime" @@ -16,10 +15,7 @@ var ( ) func init() { - //如果命令行参数中有test,则使用测试环境的配置 config_file := filepath.Join(Root, "data", "config.yaml") - fmt.Println(config_file) - viper.SetConfigFile(config_file) if err := viper.ReadInConfig(); err != nil { log.Println("读取配置文件失败", err) diff --git a/models/elasticsearch.go b/models/elasticsearch.go index ac7fd1e..a478075 100644 --- a/models/elasticsearch.go +++ b/models/elasticsearch.go @@ -30,14 +30,33 @@ func elasticsearch_init() (es *elasticsearch.Client) { } type SearchData struct { - Total int64 `json:"total"` + _shards struct { + failed int + skipped int + successful int + total int + } + Hits struct { + Hits []struct { + ID string `json:"_id"` + Index string `json:"_index"` + Score float64 `json:"_score"` + Source struct { + Content string `json:"content"` + } `json:"_source"` + Type string `json:"_type"` + } `json:"hits"` + max_score float64 + total struct { + relation string + value int + } + } `json:"hits"` + timed_out bool + took int } -func ElasticsearchSearch(text string) map[string]interface{} { - var ( - r map[string]interface{} - ) - +func ElasticsearchSearch(text string) (r *SearchData) { // 通过字符串构建查询 var buf bytes.Buffer query := map[string]interface{}{ @@ -54,7 +73,7 @@ func ElasticsearchSearch(text string) map[string]interface{} { es := elasticsearch_init() - // Perform the search request. + // 执行查询 res, err := es.Search( es.Search.WithContext(context.Background()), es.Search.WithIndex("my_index"), @@ -68,30 +87,17 @@ func ElasticsearchSearch(text string) map[string]interface{} { } defer res.Body.Close() - // Check response status + // 处理错误 if res.IsError() { log.Printf("Error: %s", res.String()) return nil } - // Deserialize the response into a map. + // 转换返回结果 if err := json.NewDecoder(res.Body).Decode(&r); err != nil { log.Printf("Error parsing the response body: %s", err) return nil } - // Print the response status, number of results, and request duration. - log.Printf( - "[%s] %d hits; took: %dms", - res.Status(), - int(r["hits"].(map[string]interface{})["total"].(map[string]interface{})["value"].(float64)), - int(r["took"].(float64)), - ) - - // Print the ID and document source for each hit. - for _, hit := range r["hits"].(map[string]interface{})["hits"].([]interface{}) { - log.Printf(" * ID=%s, %s", hit.(map[string]interface{})["_id"], hit.(map[string]interface{})["_source"]) - } - return r } diff --git a/models/elasticsearch_test.go b/models/elasticsearch_test.go index 9c28ef8..a96e5d5 100644 --- a/models/elasticsearch_test.go +++ b/models/elasticsearch_test.go @@ -6,11 +6,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMyFunction(t *testing.T) { +func TestElasticsearchSearch(t *testing.T) { // 创建一个测试用例 - expected := 10 + expected := "植物学家 阿尔法 可复活一次 技能:召唤豌豆射手 转到设置" actual := ElasticsearchSearch("豌豆") // 使用 assert 包中的函数来验证函数的输出 - assert.Equal(t, expected, actual) + assert.Equal(t, expected, actual.Hits.Hits[0].Source.Content) }