diff --git a/consul/registry.go b/consul/registry.go index 10d38bf..5e18b6b 100644 --- a/consul/registry.go +++ b/consul/registry.go @@ -94,12 +94,21 @@ func (c *consulRegistry) Register(info *registry.Info) error { return fmt.Errorf("getting service id failed, err: %w", err) } + tags, err := convTagMapToSlice(info.Tags) + if err == nil { + for _, tag := range c.opts.AdditionInfo.Tags { + if !inArray(tag, tags) { + tags = append(tags, tag) + } + } + } + svcInfo := &api.AgentServiceRegistration{ ID: svcID, Name: info.ServiceName, Address: host, Port: port, - Tags: c.opts.AdditionInfo.Tags, + Tags: tags, Meta: c.opts.AdditionInfo.Meta, Weights: &api.AgentWeights{ Passing: info.Weight, diff --git a/consul/resolver.go b/consul/resolver.go index 6a178e7..b9fd8b3 100644 --- a/consul/resolver.go +++ b/consul/resolver.go @@ -60,11 +60,12 @@ func (c *consulResolver) Resolve(_ context.Context, desc string) (discovery.Resu if svc == nil || svc.Address == "" { continue } + tags := splitTags(svc.Tags) eps = append(eps, discovery.NewInstance( defaultNetwork, net.JoinHostPort(svc.Address, fmt.Sprintf("%d", svc.Port)), svc.Weights.Passing, - svc.Meta, + tags, )) } diff --git a/consul/utils.go b/consul/utils.go index 9b662bf..8910554 100644 --- a/consul/utils.go +++ b/consul/utils.go @@ -15,13 +15,20 @@ package consul import ( + "errors" "fmt" "net" "strconv" + "strings" "github.com/cloudwego/hertz/pkg/app/server/registry" + "github.com/cloudwego/hertz/pkg/common/utils" ) +const kvJoinChar = ":" + +var errIllegalTagChar = errors.New("illegal tag character") + func parseAddr(addr net.Addr) (host string, port int, err error) { host, portStr, err := net.SplitHostPort(addr.String()) if err != nil { @@ -29,7 +36,16 @@ func parseAddr(addr net.Addr) (host string, port int, err error) { } if host == "" || host == "::" { - return "", 0, fmt.Errorf("empty host") + detectHost := utils.LocalIP() + if detectHost == utils.UNKNOWN_IP_ADDR { + return "", 0, fmt.Errorf("get local ip error") + } + + host, _, err = net.SplitHostPort(detectHost) + + if err != nil { + return "", 0, fmt.Errorf("empty host") + } } port, err = strconv.Atoi(portStr) @@ -50,3 +66,56 @@ func getServiceId(info *registry.Info) (string, error) { } return fmt.Sprintf("%s:%s:%d", info.ServiceName, host, port), nil } + +// convTagMapToSlice Tags map be convert to slice. +// Keys must not contain `:`. +func convTagMapToSlice(tagMap map[string]string) ([]string, error) { + svcTags := make([]string, 0, len(tagMap)) + for k, v := range tagMap { + var tag string + if strings.Contains(k, kvJoinChar) { + return svcTags, errIllegalTagChar + } + if v == "" { + tag = k + } else { + tag = fmt.Sprintf("%s%s%s", k, kvJoinChar, v) + } + svcTags = append(svcTags, tag) + } + return svcTags, nil +} + +// splitTags Tags characters be separated to map. +func splitTags(tags []string) map[string]string { + n := len(tags) + tagMap := make(map[string]string, n) + if n == 0 { + return tagMap + } + + for _, tag := range tags { + if tag == "" { + continue + } + strArr := strings.SplitN(tag, kvJoinChar, 2) + if len(strArr) == 2 { + key := strArr[0] + tagMap[key] = strArr[1] + } + if len(strArr) == 1 { + tagMap[strArr[0]] = "" + } + } + + return tagMap +} + +func inArray(needle string, haystack []string) bool { + for _, k := range haystack { + if needle == k { + return true + } + } + return false +}