Files
webp/models/elasticsearch.go
2023-12-06 00:09:34 +08:00

152 lines
3.1 KiB
Go

package models
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"log"
"net/http"
"github.com/elastic/go-elasticsearch/v8"
)
// esTransport is the HTTP transport shared by every client elasticsearch_init
// builds. Constructing a fresh http.Transport for each client (one per search
// call) leaves every discarded transport's idle keep-alive connections and
// their background goroutines alive, because a bare http.Transport has no idle
// timeout. Sharing one transport lets those connections be pooled and reused.
//
// NOTE(review): InsecureSkipVerify disables TLS certificate verification, so
// the connection (including the basic-auth username/password below) is open
// to man-in-the-middle interception. It is kept only for compatibility with
// the current deployment (presumably a self-signed cluster certificate —
// confirm); prefer trusting the cluster CA via tls.Config.RootCAs instead.
var esTransport = &http.Transport{
	TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}

// elasticsearch_init creates an Elasticsearch client from the
// "elasticsearch.host", "elasticsearch.user" and "elasticsearch.password"
// configuration keys, using the shared esTransport.
//
// If the client cannot be created the error is logged and nil is returned,
// so callers must check the result before using it.
func elasticsearch_init() *elasticsearch.Client {
	es, err := elasticsearch.NewClient(elasticsearch.Config{
		Addresses: []string{config.GetString("elasticsearch.host")},
		Username:  config.GetString("elasticsearch.user"),
		Password:  config.GetString("elasticsearch.password"),
		Transport: esTransport,
	})
	if err != nil {
		log.Printf("Error creating the client: %s", err)
		return nil
	}
	return es
}
// SearchData is the part of an Elasticsearch _search response body that this
// package decodes.
//
// encoding/json only populates exported fields, so every field that should be
// decoded is exported and tagged with its key in the response body.
type SearchData struct {
	// Shards mirrors the response's "_shards" object.
	Shards struct {
		Failed     int `json:"failed"`
		Skipped    int `json:"skipped"`
		Successful int `json:"successful"`
		Total      int `json:"total"`
	} `json:"_shards"`
	Hits struct {
		// Hits holds the matched documents in the order the server returned them.
		Hits []struct {
			ID     string  `json:"_id"`
			Index  string  `json:"_index"`
			Score  float64 `json:"_score"`
			Source struct {
				Content string `json:"content"`
			} `json:"_source"`
			Type string `json:"_type"`
		} `json:"hits"`
		MaxScore float64 `json:"max_score"`
		// Total mirrors the response's "hits.total" object.
		Total struct {
			Relation string `json:"relation"`
			Value    int    `json:"value"`
		} `json:"total"`
	} `json:"hits"`
	TimedOut bool `json:"timed_out"`
	Took     int  `json:"took"`
}

// GetIDList returns the IDs of the hits, in response order, narrowed by
// cursor-style pagination arguments. The arguments are applied in this order:
//
//   - after:  if non-zero and some ID equals its decimal form, keep only the
//     IDs after that one; otherwise the list is left unchanged.
//   - before: if non-zero and some remaining ID equals its decimal form, keep
//     only the IDs before that one; otherwise the list is left unchanged.
//   - first:  if positive, keep at most the first `first` IDs.
//   - last:   if positive, keep at most the last `last` IDs.
//
// Zero disables an argument. Negative first/last are treated as "not set"
// (previously they caused an out-of-range slice panic). The result is nil when
// there are no hits.
func (sd SearchData) GetIDList(first, last, after, before int) []string {
	var ids []string
	for _, hit := range sd.Hits.Hits {
		ids = append(ids, hit.ID)
	}

	// indexOf returns the position of id in ids, or -1 if it is absent.
	indexOf := func(id string) int {
		for i, candidate := range ids {
			if candidate == id {
				return i
			}
		}
		return -1
	}

	if after != 0 {
		if i := indexOf(fmt.Sprint(after)); i >= 0 {
			ids = ids[i+1:]
		}
	}
	if before != 0 {
		if i := indexOf(fmt.Sprint(before)); i >= 0 {
			ids = ids[:i]
		}
	}
	// Only trim when the limit is positive and actually shorter than the list;
	// this also keeps negative limits from producing invalid slice bounds.
	if first > 0 && first < len(ids) {
		ids = ids[:first]
	}
	if last > 0 && last < len(ids) {
		ids = ids[len(ids)-last:]
	}
	return ids
}
// ElasticsearchSearch runs a "match" query for text against the "content"
// field of the "web_images" index and returns the decoded response, asking
// for at most 200 hits and an exact total hit count.
//
// Failures are logged rather than returned. On any failure the zero
// SearchData is returned, so callers simply see no hits. Failures include
// query encoding, client creation, the request itself, an error status from
// Elasticsearch, or an undecodable body.
func ElasticsearchSearch(text string) SearchData {
	// maxHits caps the number of hits returned by a single search.
	const maxHits = 200

	query := map[string]interface{}{
		"query": map[string]interface{}{
			"match": map[string]interface{}{
				"content": text,
			},
		},
	}
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(query); err != nil {
		log.Printf("Error encoding query: %s", err)
		return SearchData{}
	}

	es := elasticsearch_init()
	if es == nil {
		// elasticsearch_init has already logged the cause; calling Search on a
		// nil client would panic.
		return SearchData{}
	}

	res, err := es.Search(
		es.Search.WithContext(context.Background()),
		es.Search.WithIndex("web_images"),
		es.Search.WithBody(&buf),
		es.Search.WithTrackTotalHits(true),
		es.Search.WithSize(maxHits),
	)
	if err != nil {
		log.Printf("Error getting response: %s", err)
		return SearchData{}
	}
	defer res.Body.Close()

	if res.IsError() {
		log.Printf("Error: %s", res.String())
		return SearchData{}
	}

	var r SearchData
	if err := json.NewDecoder(res.Body).Decode(&r); err != nil {
		log.Printf("Error parsing the response body: %s", err)
		// Do not hand back a partially decoded result.
		return SearchData{}
	}
	return r
}