Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting mongo+srv support from new, offical mongo driver #394

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,20 @@ We really appreciate contributions, but they must meet the following requirement
We merge PRs into `development`, which is then tested in a sharded, replicated environment in our datacenter for regressions. Once everyone is happy, we merge to master - this is to maintain a bit of quality control past the usual PR process.

**Thanks** for helping!

# How to test the code
In order to run the tests, you need the following installed (assuming Ubuntu)

* daemontools (for svstat)
* mongo (for mongo client)
* mongodb (for mongo server)

Before running the tests, you need to start the test mongo server with `make startdb`. After the tests are done, you
can tear it down with `make stopdb`.

The tests to run are defined in `.travis.yml` under the *script* section.

## Note about DNS Lookup of SRV records
If you are testing on Linux, you may run into an error that net.LookupSRV cannot parse the DNS response.
In this case, you are likely using a distribution that uses systemd for DNS and go does not handle it yet.
The easiest work around is to change your /etc/resolv.conf file and set a remote nameserver; 8.8.8.8 or 9.9.9.9.
238 changes: 220 additions & 18 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"net"
"net/url"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -82,6 +83,9 @@ const (
DefaultConnectionPoolLimit = 4096

zeroDuration = time.Duration(0)

// SchemeMongoDBSRV denotes that the user wants to use a DNS service record to determine the cluster
SchemeMongoDBSRV = "mongodb+srv://"
)

// When changing the Session type, check if newSession and copySession
Expand Down Expand Up @@ -795,9 +799,43 @@ type urlInfoOption struct {
}

func extractURL(s string) (*urlInfo, error) {
s = strings.TrimPrefix(s, "mongodb://")
info := &urlInfo{options: []urlInfoOption{}}

if strings.HasPrefix(s, SchemeMongoDBSRV) {
hosts, err := parseHosts(s)
if err != nil {
return nil, err
}

connectionArgs, err := getConnectionArgsFromTXT(s)
if err != nil {
return nil, err
}

temp := strings.TrimPrefix(s, SchemeMongoDBSRV)
_, err = parseUsernamePassword(temp, info)
if err != nil {
return nil, err
}

// using what we looked up, construct a new mongo URL that the rest of this function can use
// the construction has a few different conditionals, so it was easier to build it up
// manually instead of using fmt.Sprintf.
url, err := url.Parse(s)
s = "mongodb://"
if len(info.user) > 0 {
s += info.user + ":" + info.pass + "@"
}
s += hosts + url.Path
s += "?" + connectionArgs

if len(url.RawQuery) > 0 {
s += "&" + url.RawQuery
}
}

s = strings.TrimPrefix(s, "mongodb://")

if c := strings.Index(s, "?"); c != -1 {
for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) {
l := strings.SplitN(pair, "=", 2)
Expand All @@ -808,23 +846,9 @@ func extractURL(s string) (*urlInfo, error) {
}
s = s[:c]
}
if c := strings.Index(s, "@"); c != -1 {
pair := strings.SplitN(s[:c], ":", 2)
if len(pair) > 2 || pair[0] == "" {
return nil, errors.New("credentials must be provided as user:pass@host")
}
var err error
info.user, err = url.QueryUnescape(pair[0])
if err != nil {
return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0])
}
if len(pair) > 1 {
info.pass, err = url.QueryUnescape(pair[1])
if err != nil {
return nil, fmt.Errorf("cannot unescape password in URL")
}
}
s = s[c+1:]
s, err := parseUsernamePassword(s, info)
if err != nil {
return nil, err
}
if c := strings.LastIndex(s, "/"); c != -1 && !strings.HasSuffix(s[c+1:], ".sock") {
info.db = s[c+1:]
Expand All @@ -849,6 +873,184 @@ func extractURL(s string) (*urlInfo, error) {
return info, nil
}

func parseUsernamePassword(host string, info *urlInfo) (string, error) {
if c := strings.Index(host, "@"); c != -1 {
pair := strings.SplitN(host[:c], ":", 2)
if len(pair) > 2 || pair[0] == "" {
return "", errors.New("credentials must be provided as user:pass@host")
}
var err error
info.user, err = url.QueryUnescape(pair[0])
if err != nil {
return "", fmt.Errorf("cannot unescape username in URL: %q", pair[0])
}
if len(pair) > 1 {
info.pass, err = url.QueryUnescape(pair[1])
if err != nil {
return "", fmt.Errorf("cannot unescape password in URL")
}
}
host = host[c+1:]
}

return host, nil
}

// parseHosts is largely adopted from the new, official mongo driver at
// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L28
//
// Instead of returning an array of hosts, return all of the hosts in
// a single string using comma separators. This will make consuming it easier
// for the mgo driver.
func parseHosts(host string) (string, error) {
// The 'host' field starts as the entire string. We need to pull out just the host portion
url, err := url.Parse(host)

if err != nil {
return "", err
}

parsedHosts := strings.Split(url.Host, ",")

if len(parsedHosts) != 1 {
return "", fmt.Errorf("URI with SRV must include one and only one hostname")
}

if len(url.Port()) > 0 {
// we were able to successfully extract a port from the host,
// but should not be able to when using SRV
return "", fmt.Errorf("URI with srv must not include a port number")
}

_, addresses, err := lookupSRV("mongodb", "tcp", url.Host)
if err != nil {
fmt.Printf("Failed lookupSRV: addresses=%v, err=%v\n", addresses, err)
return "", err
}

trimmedHost := strings.TrimSuffix(url.Host, ".")
parsedHosts = make([]string, 5) // reset so we don't keep the original hostname

for _, address := range addresses {
trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
err := validateSRVResult(trimmedAddressTarget, trimmedHost)
if err != nil {
continue
}
parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
}

hosts := ""

for _, host := range parsedHosts {
if len(hosts) > 0 {
hosts += ","
}

hosts += host
}

return hosts, nil
}

// LookupSRV was largely pulled from the new, offical mongo driver at
// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/topology/polling_srv_records_test.go#L44
// As the globalsign driver has no good place to keep the state, all of the extra state
// related code was removed. This means a DNS lookup is going to happen every time
// even if it failed the time before.
func lookupSRV(service, proto, name string) (string, []*net.SRV, error) {
str, addresses, err := net.LookupSRV("mongodb", "tcp", name)
if err != nil {
return str, addresses, err
}

return str, addresses, err
}

// validateSRVResult was pulled from the new, official mongo driver at
// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L100
func validateSRVResult(recordFromSRV, inputHostName string) error {
separatedInputDomain := strings.Split(inputHostName, ".")
separatedRecord := strings.Split(recordFromSRV, ".")
if len(separatedRecord) < 2 {
return errors.New("DNS name must contain at least 2 labels")
}
if len(separatedRecord) < len(separatedInputDomain) {
return errors.New("Domain suffix from SRV record not matched input domain")
}

inputDomainSuffix := separatedInputDomain[1:]
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)

recordDomainSuffix := separatedRecord[domainSuffixOffset:]
for ix, label := range inputDomainSuffix {
if label != recordDomainSuffix[ix] {
return errors.New("Domain suffix from SRV record not matched input domain")
}
}
return nil
}

// getConnectionArgsFromTXT is largely taken from the new, official mongo driver at
// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L38
func getConnectionArgsFromTXT(host string) (string, error) {
var connectionArgsFromTXT []string

// error ignored because not finding a TXT record should not be
// considered an error.
url, err := url.Parse(host)

if err != nil {
return "", err
}

recordsFromTXT, _ := net.LookupTXT(url.Host)

// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
// It will currently incorrectly concatenate multiple TXT records to one
// on windows.
if runtime.GOOS == "windows" {
recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
}

if len(recordsFromTXT) > 1 {
return "", errors.New("multiple records from TXT not supported")
}
if len(recordsFromTXT) == 1 {
connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })

err := validateTXTResult(connectionArgsFromTXT)
if err != nil {
return "", err
}

return recordsFromTXT[0], nil
}

return "", errors.New("unexpected case, should never fall through")
}

var allowedTXTOptions = map[string]struct{}{
"authsource": {},
"replicaset": {},
}

// validateTXTResult is largely taken from the new, offical mongo driver at
// https://github.com/mongodb/mongo-go-driver/blob/f1f16a1f4d769d844812278841a184ae7f301732/x/mongo/driver/dns/dns.go#L127
func validateTXTResult(paramsFromTXT []string) error {
for _, param := range paramsFromTXT {
kv := strings.SplitN(param, "=", 2)
if len(kv) != 2 {
return errors.New("Invalid TXT record")
}
key := strings.ToLower(kv[0])
if _, ok := allowedTXTOptions[key]; !ok {
return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0])
}
}
return nil
}

func newSession(consistency Mode, cluster *mongoCluster, info *DialInfo) (session *Session) {
cluster.Acquire()
session = &Session{
Expand Down
29 changes: 29 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,35 @@ func (s *S) TestURLInvalidSafe(c *C) {
}
}

// it is difficult to set up a service record on demand as it requires new DNS entries.
// This cannot just be done using an /etc/hosts file modification.
// This test will be disabled by default. If you want to test the code
// that performs a mongo connection using a DNS SVR (service) record,
// you will need to set that up yourself. I would recommend using a free tier
// of Mongo Atlas for the test.
//
// For more information, please see
// https://www.mongodb.com/blog/post/mongodb-3-6-here-to-SRV-you-with-easier-replica-set-connections.
func (s *S) TestURLMongoServiceRecord(c *C) {
c.Skip("Requires DNS configuration / Atlas")
url := "mongodb+srv://<username>:<password>@<host>.mongodb.net/<db>?ssl=true"
dialInfo, err := mgo.ParseURL(url)
c.Assert(err, IsNil)

session, err := mgo.DialWithInfo(dialInfo)
c.Assert(err, IsNil)

defer session.Close()

db := session.DB("id")

var result map[string]interface{}

err = db.C("<collection>").Find(nil).Sort("-_id").One(&result)
c.Assert(err, IsNil)
c.Assert(len(result), Not(Equals), 0)
}

func (s *S) TestURLUnixSocket(c *C) {
type test struct {
url string
Expand Down