diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 795399550..44ee9cea1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,3 +12,20 @@ We really appreciate contributions, but they must meet the following requirement We merge PRs into `development`, which is then tested in a sharded, replicated environment in our datacenter for regressions. Once everyone is happy, we merge to master - this is to maintain a bit of quality control past the usual PR process. **Thanks** for helping! + +# How to test the code +In order to run the tests, you need the following installed (assuming Ubuntu) + +* daemontools (for svstat) +* mongo (for mongo client) +* mongodb (for mongo server) + +Before running the tests, you need to start the test mongo server with `make startdb`. After the tests are done, you +can tear it down with `make stopdb`. + +The tests to run are defined in `.travis.yml` under the *script* section. + +## Note about DNS Lookup of SRV records +If you are testing on Linux, you may run into an error that net.LookupSRV cannot parse the DNS response. +In this case, you are likely using a distribution that uses systemd for DNS and go does not handle it yet. +The easiest work around is to change your /etc/resolv.conf file and set a remote nameserver; 8.8.8.8 or 9.9.9.9. \ No newline at end of file diff --git a/session.go b/session.go index abc8e22a7..e58d30f33 100644 --- a/session.go +++ b/session.go @@ -39,6 +39,7 @@ import ( "net" "net/url" "reflect" + "runtime" "sort" "strconv" "strings" @@ -82,6 +83,9 @@ const ( DefaultConnectionPoolLimit = 4096 zeroDuration = time.Duration(0) + + // SchemeMongoDBSRV denotes that the user wants to use a DNS service record to determine the cluster + SchemeMongoDBSRV = "mongodb+srv://" ) // When changing the Session type, check if newSession and copySession @@ -795,9 +799,43 @@ type urlInfoOption struct { } func extractURL(s string) (*urlInfo, error) { - s = strings.TrimPrefix(s, "mongodb://") info := &urlInfo{options: []urlInfoOption{}} + if strings.HasPrefix(s, SchemeMongoDBSRV) { + hosts, err := parseHosts(s) + if err != nil { + return nil, err + } + + connectionArgs, err := getConnectionArgsFromTXT(s) + if err != nil { + return nil, err + } + + temp := strings.TrimPrefix(s, SchemeMongoDBSRV) + _, err = parseUsernamePassword(temp, info) + if err != nil { + return nil, err + } + + // using what we looked up, construct a new mongo URL that the rest of this function can use + // the construction has a few different conditionals, so it was easier to build it up + // manually instead of using fmt.Sprintf. + url, err := url.Parse(s) + s = "mongodb://" + if len(info.user) > 0 { + s += info.user + ":" + info.pass + "@" + } + s += hosts + url.Path + s += "?" + connectionArgs + + if len(url.RawQuery) > 0 { + s += "&" + url.RawQuery + } + } + + s = strings.TrimPrefix(s, "mongodb://") + if c := strings.Index(s, "?"); c != -1 { for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) { l := strings.SplitN(pair, "=", 2) @@ -808,23 +846,9 @@ func extractURL(s string) (*urlInfo, error) { } s = s[:c] } - if c := strings.Index(s, "@"); c != -1 { - pair := strings.SplitN(s[:c], ":", 2) - if len(pair) > 2 || pair[0] == "" { - return nil, errors.New("credentials must be provided as user:pass@host") - } - var err error - info.user, err = url.QueryUnescape(pair[0]) - if err != nil { - return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) - } - if len(pair) > 1 { - info.pass, err = url.QueryUnescape(pair[1]) - if err != nil { - return nil, fmt.Errorf("cannot unescape password in URL") - } - } - s = s[c+1:] + s, err := parseUsernamePassword(s, info) + if err != nil { + return nil, err } if c := strings.LastIndex(s, "/"); c != -1 && !strings.HasSuffix(s[c+1:], ".sock") { info.db = s[c+1:] @@ -849,6 +873,184 @@ func extractURL(s string) (*urlInfo, error) { return info, nil } +func parseUsernamePassword(host string, info *urlInfo) (string, error) { + if c := strings.Index(host, "@"); c != -1 { + pair := strings.SplitN(host[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { + return "", errors.New("credentials must be provided as user:pass@host") + } + var err error + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return "", fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return "", fmt.Errorf("cannot unescape password in URL") + } + } + host = host[c+1:] + } + + return host, nil +} + +// parseHosts is largely adopted from the new, official mongo driver at +// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L28 +// +// Instead of returning an array of hosts, return all of the hosts in +// a single string using comma separators. This will make consuming it easier +// for the mgo driver. +func parseHosts(host string) (string, error) { + // The 'host' field starts as the entire string. We need to pull out just the host portion + url, err := url.Parse(host) + + if err != nil { + return "", err + } + + parsedHosts := strings.Split(url.Host, ",") + + if len(parsedHosts) != 1 { + return "", fmt.Errorf("URI with SRV must include one and only one hostname") + } + + if len(url.Port()) > 0 { + // we were able to successfully extract a port from the host, + // but should not be able to when using SRV + return "", fmt.Errorf("URI with srv must not include a port number") + } + + _, addresses, err := lookupSRV("mongodb", "tcp", url.Host) + if err != nil { + fmt.Printf("Failed lookupSRV: addresses=%v, err=%v\n", addresses, err) + return "", err + } + + trimmedHost := strings.TrimSuffix(url.Host, ".") + parsedHosts = make([]string, 5) // reset so we don't keep the original hostname + + for _, address := range addresses { + trimmedAddressTarget := strings.TrimSuffix(address.Target, ".") + err := validateSRVResult(trimmedAddressTarget, trimmedHost) + if err != nil { + continue + } + parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port)) + } + + hosts := "" + + for _, host := range parsedHosts { + if len(hosts) > 0 { + hosts += "," + } + + hosts += host + } + + return hosts, nil +} + +// LookupSRV was largely pulled from the new, offical mongo driver at +// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/topology/polling_srv_records_test.go#L44 +// As the globalsign driver has no good place to keep the state, all of the extra state +// related code was removed. This means a DNS lookup is going to happen every time +// even if it failed the time before. +func lookupSRV(service, proto, name string) (string, []*net.SRV, error) { + str, addresses, err := net.LookupSRV("mongodb", "tcp", name) + if err != nil { + return str, addresses, err + } + + return str, addresses, err +} + +// validateSRVResult was pulled from the new, official mongo driver at +// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L100 +func validateSRVResult(recordFromSRV, inputHostName string) error { + separatedInputDomain := strings.Split(inputHostName, ".") + separatedRecord := strings.Split(recordFromSRV, ".") + if len(separatedRecord) < 2 { + return errors.New("DNS name must contain at least 2 labels") + } + if len(separatedRecord) < len(separatedInputDomain) { + return errors.New("Domain suffix from SRV record not matched input domain") + } + + inputDomainSuffix := separatedInputDomain[1:] + domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1) + + recordDomainSuffix := separatedRecord[domainSuffixOffset:] + for ix, label := range inputDomainSuffix { + if label != recordDomainSuffix[ix] { + return errors.New("Domain suffix from SRV record not matched input domain") + } + } + return nil +} + +// getConnectionArgsFromTXT is largely taken from the new, official mongo driver at +// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L38 +func getConnectionArgsFromTXT(host string) (string, error) { + var connectionArgsFromTXT []string + + // error ignored because not finding a TXT record should not be + // considered an error. + url, err := url.Parse(host) + + if err != nil { + return "", err + } + + recordsFromTXT, _ := net.LookupTXT(url.Host) + + // This is a temporary fix to get around bug https://github.com/golang/go/issues/21472. + // It will currently incorrectly concatenate multiple TXT records to one + // on windows. + if runtime.GOOS == "windows" { + recordsFromTXT = []string{strings.Join(recordsFromTXT, "")} + } + + if len(recordsFromTXT) > 1 { + return "", errors.New("multiple records from TXT not supported") + } + if len(recordsFromTXT) == 1 { + connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' }) + + err := validateTXTResult(connectionArgsFromTXT) + if err != nil { + return "", err + } + + return recordsFromTXT[0], nil + } + + return "", errors.New("unexpected case, should never fall through") +} + +var allowedTXTOptions = map[string]struct{}{ + "authsource": {}, + "replicaset": {}, +} + +// validateTXTResult is largely taken from the new, offical mongo driver at +// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L127 +func validateTXTResult(paramsFromTXT []string) error { + for _, param := range paramsFromTXT { + kv := strings.SplitN(param, "=", 2) + if len(kv) != 2 { + return errors.New("Invalid TXT record") + } + key := strings.ToLower(kv[0]) + if _, ok := allowedTXTOptions[key]; !ok { + return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0]) + } + } + return nil +} + func newSession(consistency Mode, cluster *mongoCluster, info *DialInfo) (session *Session) { cluster.Acquire() session = &Session{ diff --git a/session_test.go b/session_test.go index 864d6593a..620f7241d 100644 --- a/session_test.go +++ b/session_test.go @@ -233,6 +233,35 @@ func (s *S) TestURLInvalidSafe(c *C) { } } +// it is difficult to set up a service record on demand as it requires new DNS entries. +// This cannot just be done using an /etc/hosts file modification. +// This test will be disabled by default. If you want to test the code +// that performs a mongo connection using a DNS SVR (service) record, +// you will need to set that up yourself. I would recommend using a free tier +// of Mongo Atlas for the test. +// +// For more information, please see +// https://www.mongodb.com/blog/post/mongodb-3-6-here-to-SRV-you-with-easier-replica-set-connections. +func (s *S) TestURLMongoServiceRecord(c *C) { + c.Skip("Requires DNS configuration / Atlas") + url := "mongodb+srv://:@.mongodb.net/?ssl=true" + dialInfo, err := mgo.ParseURL(url) + c.Assert(err, IsNil) + + session, err := mgo.DialWithInfo(dialInfo) + c.Assert(err, IsNil) + + defer session.Close() + + db := session.DB("id") + + var result map[string]interface{} + + err = db.C("").Find(nil).Sort("-_id").One(&result) + c.Assert(err, IsNil) + c.Assert(len(result), Not(Equals), 0) +} + func (s *S) TestURLUnixSocket(c *C) { type test struct { url string