Skip to content

Commit

Permalink
Merge pull request #317 from wneessen/more-auth-test-coverage
Browse files Browse the repository at this point in the history
More test coverage for smtp/auth
  • Loading branch information
wneessen authored Oct 3, 2024
2 parents a41639e + 4c8c0d8 commit ff5454a
Show file tree
Hide file tree
Showing 5 changed files with 384 additions and 14 deletions.
13 changes: 13 additions & 0 deletions smtp/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@

package smtp

import "errors"

var (
// ErrUnencrypted is an error indicating that the connection is not encrypted.
ErrUnencrypted = errors.New("unencrypted connection")
// ErrUnexpectedServerChallange is an error indicating that the server issued an unexpected challenge.
ErrUnexpectedServerChallange = errors.New("unexpected server challenge")
// ErrUnexpectedServerResponse is an error indicating that the server issued an unexpected response.
ErrUnexpectedServerResponse = errors.New("unexpected server response")
// ErrWrongHostname is an error indicating that the provided hostname does not match the expected value.
ErrWrongHostname = errors.New("wrong host name")
)

// Auth is implemented by an SMTP authentication mechanism.
type Auth interface {
// Start begins an authentication with a server.
Expand Down
8 changes: 2 additions & 6 deletions smtp/auth_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@
package smtp

import (
"errors"
"fmt"
)

// ErrUnencrypted is an error indicating that the connection is not encrypted.
var ErrUnencrypted = errors.New("unencrypted connection")

// loginAuth is the type that satisfies the Auth interface for the "SMTP LOGIN" auth
type loginAuth struct {
username, password string
Expand Down Expand Up @@ -55,7 +51,7 @@ func (a *loginAuth) Start(server *ServerInfo) (string, []byte, error) {
return "", nil, ErrUnencrypted
}
if server.Name != a.host {
return "", nil, errors.New("wrong host name")
return "", nil, ErrWrongHostname
}
a.respStep = 0
return "LOGIN", nil, nil
Expand All @@ -73,7 +69,7 @@ func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
a.respStep++
return []byte(a.password), nil
default:
return nil, fmt.Errorf("unexpected server response: %s", string(fromServer))
return nil, fmt.Errorf("%w: %s", ErrUnexpectedServerResponse, string(fromServer))
}
}
return nil, nil
Expand Down
10 changes: 3 additions & 7 deletions smtp/auth_plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@

package smtp

import (
"errors"
)

// plainAuth is the type that satisfies the Auth interface for the "SMTP PLAIN" auth
type plainAuth struct {
identity, username, password string
Expand All @@ -42,10 +38,10 @@ func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) {
// That might just be the attacker saying
// "it's ok, you can trust me with your password."
if !server.TLS && !isLocalhost(server.Name) {
return "", nil, errors.New("unencrypted connection")
return "", nil, ErrUnencrypted
}
if server.Name != a.host {
return "", nil, errors.New("wrong host name")
return "", nil, ErrWrongHostname
}
resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password)
return "PLAIN", resp, nil
Expand All @@ -54,7 +50,7 @@ func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) {
func (a *plainAuth) Next(_ []byte, more bool) ([]byte, error) {
if more {
// We've already sent everything.
return nil, errors.New("unexpected server challenge")
return nil, ErrUnexpectedServerChallange
}
return nil, nil
}
5 changes: 4 additions & 1 deletion smtp/auth_scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (a *scramAuth) Next(fromServer []byte, more bool) ([]byte, error) {
return resp, nil
default:
a.reset()
return nil, errors.New("unexpected server response")
return nil, fmt.Errorf("%w: %s", ErrUnexpectedServerResponse, string(fromServer))
}
}
return nil, nil
Expand Down Expand Up @@ -147,6 +147,9 @@ func (a *scramAuth) initialClientMessage() ([]byte, error) {

// SCRAM-SHA-X-PLUS auth requires channel binding
if a.isPlus {
if a.tlsConnState == nil {
return nil, errors.New("tls connection state is required for SCRAM-SHA-X-PLUS")
}
bindType := "tls-unique"
connState := a.tlsConnState
bindData := connState.TLSUnique
Expand Down
Loading

0 comments on commit ff5454a

Please sign in to comment.