Skip to content

Commit

Permalink
Expose dbsql.connOption type (#202)
Browse files Browse the repository at this point in the history
Fixes: #201

That way we can programmatically add arguments to the `NewConnector`
function instead of copy/pasting all of them across conditionals.

Signed-off-by: Miguel Palau <[email protected]>
  • Loading branch information
shelldandy authored Jul 25, 2024
1 parent a21b124 commit f0e3a08
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
},
CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs,
})

if err != nil {
return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err)
}
Expand Down Expand Up @@ -84,11 +83,11 @@ func (c *connector) Driver() driver.Driver {

var _ driver.Connector = (*connector)(nil)

type connOption func(*config.Config)
type ConnOption func(*config.Config)

// NewConnector creates a connection that can be used with `sql.OpenDB()`.
// This is an easier way to set up the DB instead of having to construct a DSN string.
func NewConnector(options ...connOption) (driver.Connector, error) {
func NewConnector(options ...ConnOption) (driver.Connector, error) {
// config with default options
cfg := config.WithDefaults()
cfg.DriverVersion = DriverVersion
Expand All @@ -102,14 +101,14 @@ func NewConnector(options ...connOption) (driver.Connector, error) {
return &connector{cfg: cfg, client: client}, nil
}

func withUserConfig(ucfg config.UserConfig) connOption {
func withUserConfig(ucfg config.UserConfig) ConnOption {
return func(c *config.Config) {
c.UserConfig = ucfg
}
}

// WithServerHostname sets up the server hostname. Mandatory.
func WithServerHostname(host string) connOption {
func WithServerHostname(host string) ConnOption {
return func(c *config.Config) {
protocol, hostname := parseHostName(host)
if protocol != "" {
Expand Down Expand Up @@ -143,7 +142,7 @@ func parseHostName(host string) (protocol, hostname string) {
}

// WithPort sets up the server port. Mandatory.
func WithPort(port int) connOption {
func WithPort(port int) ConnOption {
return func(c *config.Config) {
c.Port = port
}
Expand All @@ -153,7 +152,7 @@ func WithPort(port int) connOption {
// By default retryWaitMin = 1 * time.Second
// By default retryWaitMax = 30 * time.Second
// By default retryMax = 4
func WithRetries(retryMax int, retryWaitMin time.Duration, retryWaitMax time.Duration) connOption {
func WithRetries(retryMax int, retryWaitMin time.Duration, retryWaitMax time.Duration) ConnOption {
return func(c *config.Config) {
c.RetryWaitMax = retryWaitMax
c.RetryWaitMin = retryWaitMin
Expand All @@ -162,7 +161,7 @@ func WithRetries(retryMax int, retryWaitMin time.Duration, retryWaitMax time.Dur
}

// WithAccessToken sets up the Personal Access Token. Mandatory for now.
func WithAccessToken(token string) connOption {
func WithAccessToken(token string) ConnOption {
return func(c *config.Config) {
if token != "" {
c.AccessToken = token
Expand All @@ -175,7 +174,7 @@ func WithAccessToken(token string) connOption {
}

// WithHTTPPath sets up the endpoint to the warehouse. Mandatory.
func WithHTTPPath(path string) connOption {
func WithHTTPPath(path string) ConnOption {
return func(c *config.Config) {
if !strings.HasPrefix(path, "/") {
path = "/" + path
Expand All @@ -185,7 +184,7 @@ func WithHTTPPath(path string) connOption {
}

// WithMaxRows sets up the max rows fetched per request. Default is 10000
func WithMaxRows(n int) connOption {
func WithMaxRows(n int) ConnOption {
return func(c *config.Config) {
if n != 0 {
c.MaxRows = n
Expand All @@ -194,31 +193,31 @@ func WithMaxRows(n int) connOption {
}

// WithTimeout adds timeout for the server query execution. Default is no timeout.
func WithTimeout(n time.Duration) connOption {
func WithTimeout(n time.Duration) ConnOption {
return func(c *config.Config) {
c.QueryTimeout = n
}
}

// Sets the initial catalog name and schema name in the session.
// Use <select * from foo> instead of <select * from catalog.schema.foo>
func WithInitialNamespace(catalog, schema string) connOption {
func WithInitialNamespace(catalog, schema string) ConnOption {
return func(c *config.Config) {
c.Catalog = catalog
c.Schema = schema
}
}

// Used to identify partners. Set as a string with format <isv-name+product-name>.
func WithUserAgentEntry(entry string) connOption {
func WithUserAgentEntry(entry string) ConnOption {
return func(c *config.Config) {
c.UserAgentEntry = entry
}
}

// Sessions params will be set upon opening the session by calling SET function.
// If using connection pool, session params can avoid successive calls of "SET ..."
func WithSessionParams(params map[string]string) connOption {
func WithSessionParams(params map[string]string) ConnOption {
return func(c *config.Config) {
for k, v := range params {
if strings.ToLower(k) == "timezone" {
Expand All @@ -227,7 +226,6 @@ func WithSessionParams(params map[string]string) connOption {
} else {
c.Location = loc
}

}
}
c.SessionParams = params
Expand All @@ -249,35 +247,35 @@ func WithSkipTLSHostVerify() connOption {
}

// WithAuthenticator sets up the Authentication. Mandatory if access token is not provided.
func WithAuthenticator(authr auth.Authenticator) connOption {
func WithAuthenticator(authr auth.Authenticator) ConnOption {
return func(c *config.Config) {
c.Authenticator = authr
}
}

// WithTransport sets up the transport configuration to be used by the httpclient.
func WithTransport(t http.RoundTripper) connOption {
func WithTransport(t http.RoundTripper) ConnOption {
return func(c *config.Config) {
c.Transport = t
}
}

// WithCloudFetch sets up the use of cloud fetch for query execution. Default is false.
func WithCloudFetch(useCloudFetch bool) connOption {
func WithCloudFetch(useCloudFetch bool) ConnOption {
return func(c *config.Config) {
c.UseCloudFetch = useCloudFetch
}
}

// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
func WithMaxDownloadThreads(numThreads int) connOption {
func WithMaxDownloadThreads(numThreads int) ConnOption {
return func(c *config.Config) {
c.MaxDownloadThreads = numThreads
}
}

// Setup of Oauth M2m authentication
func WithClientCredentials(clientID, clientSecret string) connOption {
func WithClientCredentials(clientID, clientSecret string) ConnOption {
return func(c *config.Config) {
if clientID != "" && clientSecret != "" {
authr := m2m.NewAuthenticator(clientID, clientSecret, c.Host)
Expand Down

0 comments on commit f0e3a08

Please sign in to comment.