diff --git a/config/configuration.go b/config/configuration.go index 506641cf9..881ea8794 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -742,6 +742,47 @@ const ( // - A filepath to a file with read access. SocketCAFile string = "SocketCAFile" + // SocketPrivateKeyBytes is an optional value containing raw bytes of a PEM + // encoded private key to use for secure TLS communications. + // Must be used with SocketCertificateBytes. + // Must contain PEM encoded data. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - Raw bytes containing a valid PEM encoded private key. + SocketPrivateKeyBytes string = "SocketPrivateKeyBytes" + + // SocketCertificateBytes is an optional value containing raw bytes of a PEM + // encoded certificate to use for secure TLS communications. + // Must be used with SocketPrivateKeyBytes. + // Must contain PEM encoded data. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - Raw bytes containing a valid PEM encoded certificate. + SocketCertificateBytes string = "SocketCertificateBytes" + + // SocketCABytes is an optional value containing raw bytes of a PEM encoded + // root CA to use for secure TLS communications. For acceptors, client + // certificates will be verified against this CA. For initiators, clients + // will use the CA to verify the server certificate. If not configured, + // initiators will verify the server certificates using the host's root CA + // set. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - Raw bytes containing a valid PEM encoded CA. + SocketCABytes string = "SocketCABytes" + // SocketInsecureSkipVerify controls whether a client verifies the server's certificate chain and host name. // If SocketInsecureSkipVerify is set to Y, crypto/tls accepts any certificate presented by the server and any host name in that certificate. // In this mode, TLS is susceptible to machine-in-the-middle attacks unless custom verification is used. diff --git a/dialer.go b/dialer.go index a8e4c1893..de4419ff8 100644 --- a/dialer.go +++ b/dialer.go @@ -25,7 +25,7 @@ import ( "github.com/quickfixgo/quickfix/config" ) -func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) { +func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) { stdDialer := &net.Dialer{} if settings.HasSetting(config.SocketTimeout) { timeout, err := settings.DurationSetting(config.SocketTimeout) @@ -73,9 +73,23 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error } } - dialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, dialer) + var proxyDialer proxy.Dialer + + proxyDialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, stdDialer) + if err != nil { + return + } + + if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok { + dialer = contextDialer + } else { + err = fmt.Errorf("proxy does not support context dialer") + return + } + default: err = fmt.Errorf("unsupported proxy type %s", proxyType) } + return } diff --git a/field_map.go b/field_map.go index 18216f81c..4aac64b1d 100644 --- a/field_map.go +++ b/field_map.go @@ -115,6 +115,20 @@ func (m FieldMap) GetField(tag Tag, parser FieldValueReader) MessageRejectError return nil } +// GetField parses of a field with Tag tag. Returned reject may indicate the field is not present, or the field value is invalid. +func (m FieldMap) getFieldNoLock(tag Tag, parser FieldValueReader) MessageRejectError { + f, ok := m.tagLookup[tag] + if !ok { + return ConditionallyRequiredFieldMissing(tag) + } + + if err := parser.Read(f[0].value); err != nil { + return IncorrectDataFormatForValue(tag) + } + + return nil +} + // GetBytes is a zero-copy GetField wrapper for []bytes fields. func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) { m.rwLock.RLock() @@ -128,6 +142,16 @@ func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) { return f[0].value, nil } +// getBytesNoLock is a lock free zero-copy GetField wrapper for []bytes fields. +func (m FieldMap) getBytesNoLock(tag Tag) ([]byte, MessageRejectError) { + f, ok := m.tagLookup[tag] + if !ok { + return nil, ConditionallyRequiredFieldMissing(tag) + } + + return f[0].value, nil +} + // GetBool is a GetField wrapper for bool fields. func (m FieldMap) GetBool(tag Tag) (bool, MessageRejectError) { var val FIXBoolean @@ -152,6 +176,21 @@ func (m FieldMap) GetInt(tag Tag) (int, MessageRejectError) { return int(val), err } +// GetInt is a lock free GetField wrapper for int fields. +func (m FieldMap) getIntNoLock(tag Tag) (int, MessageRejectError) { + bytes, err := m.getBytesNoLock(tag) + if err != nil { + return 0, err + } + + var val FIXInt + if val.Read(bytes) != nil { + err = IncorrectDataFormatForValue(tag) + } + + return int(val), err +} + // GetTime is a GetField wrapper for utc timestamp fields. func (m FieldMap) GetTime(tag Tag) (t time.Time, err MessageRejectError) { m.rwLock.RLock() @@ -179,6 +218,15 @@ func (m FieldMap) GetString(tag Tag) (string, MessageRejectError) { return string(val), nil } +// GetString is a GetField wrapper for string fields. +func (m FieldMap) getStringNoLock(tag Tag) (string, MessageRejectError) { + var val FIXString + if err := m.getFieldNoLock(tag, &val); err != nil { + return "", err + } + return string(val), nil +} + // GetGroup is a Get function specific to Group Fields. func (m FieldMap) GetGroup(parser FieldGroupReader) MessageRejectError { m.rwLock.RLock() @@ -246,6 +294,13 @@ func (m *FieldMap) Clear() { } } +func (m *FieldMap) clearNoLock() { + m.tags = m.tags[0:0] + for k := range m.tagLookup { + delete(m.tagLookup, k) + } +} + // CopyInto overwrites the given FieldMap with this one. func (m *FieldMap) CopyInto(to *FieldMap) { m.rwLock.RLock() @@ -263,9 +318,6 @@ func (m *FieldMap) CopyInto(to *FieldMap) { } func (m *FieldMap) add(f field) { - m.rwLock.Lock() - defer m.rwLock.Unlock() - t := fieldTag(f) if _, ok := m.tagLookup[t]; !ok { m.tags = append(m.tags, t) diff --git a/initiator.go b/initiator.go index 8f7a76200..18451477e 100644 --- a/initiator.go +++ b/initiator.go @@ -17,6 +17,7 @@ package quickfix import ( "bufio" + "context" "crypto/tls" "strings" "sync" @@ -50,7 +51,7 @@ func (i *Initiator) Start() (err error) { return } - var dialer proxy.Dialer + var dialer proxy.ContextDialer if dialer, err = loadDialerConfig(settings); err != nil { return } @@ -142,7 +143,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo return true } -func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.Dialer) { +func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -162,6 +163,19 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di return } + ctx, cancel := context.WithCancel(context.Background()) + + // We start a goroutine in order to be able to cancel the dialer mid-connection + // on receiving a stop signal to stop the initiator. + go func() { + select { + case <-i.stopChan: + cancel() + case <-ctx.Done(): + return + } + }() + var disconnected chan interface{} var msgIn chan fixIn var msgOut chan []byte @@ -169,7 +183,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)] session.log.OnEventf("Connecting to: %v", address) - netConn, err := dialer.Dial("tcp", address) + netConn, err := dialer.DialContext(ctx, "tcp", address) if err != nil { session.log.OnEventf("Failed to connect: %v", err) goto reconnect @@ -208,6 +222,10 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di close(disconnected) }() + // This ensures we properly cleanup the goroutine and context used for + // dial cancelation after successful connection. + cancel() + select { case <-disconnected: case <-i.stopChan: @@ -215,6 +233,8 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di } reconnect: + cancel() + connectionAttempt++ session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval) if !i.waitForReconnectInterval(session.ReconnectInterval) { diff --git a/message.go b/message.go index 11afd7669..3e1e7202b 100644 --- a/message.go +++ b/message.go @@ -181,18 +181,20 @@ func ParseMessageWithDataDictionary( // doParsing executes the message parsing process. func doParsing(mp *msgParser) (err error) { + mp.msg.Header.rwLock.Lock() + defer mp.msg.Header.rwLock.Unlock() + mp.msg.Body.rwLock.Lock() + defer mp.msg.Body.rwLock.Unlock() + mp.msg.Trailer.rwLock.Lock() + defer mp.msg.Trailer.rwLock.Unlock() + // Initialize for parsing. - mp.msg.Header.Clear() - mp.msg.Body.Clear() - mp.msg.Trailer.Clear() + mp.msg.Header.clearNoLock() + mp.msg.Body.clearNoLock() + mp.msg.Trailer.clearNoLock() // Allocate expected message fields in one chunk. - fieldCount := 0 - for _, b := range mp.rawBytes { - if b == '\001' { - fieldCount++ - } - } + fieldCount := bytes.Count(mp.rawBytes, []byte{'\001'}) if fieldCount == 0 { return parseError{OrigError: fmt.Sprintf("No Fields detected in %s", string(mp.rawBytes))} } @@ -267,7 +269,7 @@ func doParsing(mp *msgParser) (err error) { } if mp.parsedFieldBytes.tag == tagXMLDataLen { - xmlDataLen, _ = mp.msg.Header.GetInt(tagXMLDataLen) + xmlDataLen, _ = mp.msg.Header.getIntNoLock(tagXMLDataLen) } mp.fieldIndex++ } @@ -292,7 +294,7 @@ func doParsing(mp *msgParser) (err error) { } } - bodyLength, err := mp.msg.Header.GetInt(tagBodyLength) + bodyLength, err := mp.msg.Header.getIntNoLock(tagBodyLength) if err != nil { err = parseError{OrigError: err.Error()} } else if length != bodyLength && !xmlDataMsg { @@ -373,7 +375,7 @@ func parseGroup(mp *msgParser, tags []Tag) { // tags slice will contain multiple tags if the tag in question is found while processing a group already. func isNumInGroupField(msg *Message, tags []Tag, appDataDictionary *datadictionary.DataDictionary) bool { if appDataDictionary != nil { - msgt, err := msg.MsgType() + msgt, err := msg.msgTypeNoLock() if err != nil { return false } @@ -406,7 +408,7 @@ func isNumInGroupField(msg *Message, tags []Tag, appDataDictionary *datadictiona // tags slice will contain multiple tags if the tag in question is found while processing a group already. func getGroupFields(msg *Message, tags []Tag, appDataDictionary *datadictionary.DataDictionary) (fields []*datadictionary.FieldDef) { if appDataDictionary != nil { - msgt, err := msg.MsgType() + msgt, err := msg.msgTypeNoLock() if err != nil { return } @@ -476,6 +478,10 @@ func (m *Message) MsgType() (string, MessageRejectError) { return m.Header.GetString(tagMsgType) } +func (m *Message) msgTypeNoLock() (string, MessageRejectError) { + return m.Header.getStringNoLock(tagMsgType) +} + // IsMsgTypeOf returns true if the Header contains MsgType (tag 35) field and its value is the specified one. func (m *Message) IsMsgTypeOf(msgType string) bool { if v, err := m.MsgType(); err == nil { diff --git a/session_factory.go b/session_factory.go index 37818deb9..d875762de 100644 --- a/session_factory.go +++ b/session_factory.go @@ -284,7 +284,7 @@ func (f sessionFactory) newSession( for _, dayStr := range dayStrs { day, ok := dayLookup[dayStr] if !ok { - err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: weekdaysStr} + err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: []byte(weekdaysStr)} return } weekdays = append(weekdays, day) @@ -315,7 +315,7 @@ func (f sessionFactory) newSession( parseDay := func(setting, dayStr string) (day time.Weekday, err error) { day, ok := dayLookup[dayStr] if !ok { - return day, IncorrectFormatForSetting{Setting: setting, Value: dayStr} + return day, IncorrectFormatForSetting{Setting: setting, Value: []byte(dayStr)} } return } @@ -355,7 +355,7 @@ func (f sessionFactory) newSession( s.timestampPrecision = Nanos default: - err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: precisionStr} + err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: []byte(precisionStr)} return } } diff --git a/session_settings.go b/session_settings.go index 323ac6234..74d6cdd27 100644 --- a/session_settings.go +++ b/session_settings.go @@ -23,7 +23,7 @@ import ( // SessionSettings maps session settings to values with typed accessors. type SessionSettings struct { - settings map[string]string + settings map[string][]byte } // ConditionallyRequiredSetting indicates a missing setting. @@ -37,8 +37,9 @@ func (e ConditionallyRequiredSetting) Error() string { // IncorrectFormatForSetting indicates a setting that is incorrectly formatted. type IncorrectFormatForSetting struct { - Setting, Value string - Err error + Setting string + Value []byte + Err error } func (e IncorrectFormatForSetting) Error() string { @@ -47,7 +48,7 @@ func (e IncorrectFormatForSetting) Error() string { // Init initializes or resets SessionSettings. func (s *SessionSettings) Init() { - s.settings = make(map[string]string) + s.settings = make(map[string][]byte) } // NewSessionSettings returns a newly initialized SessionSettings instance. @@ -58,8 +59,8 @@ func NewSessionSettings() *SessionSettings { return s } -// Set assigns a value to a setting on SessionSettings. -func (s *SessionSettings) Set(setting string, val string) { +// SetRaw assigns a value to a setting on SessionSettings. +func (s *SessionSettings) SetRaw(setting string, val []byte) { // Lazy init. if s.settings == nil { s.Init() @@ -68,69 +69,87 @@ func (s *SessionSettings) Set(setting string, val string) { s.settings[setting] = val } +// Set assigns a string value to a setting on SessionSettings. +func (s *SessionSettings) Set(setting string, val string) { + // Lazy init + if s.settings == nil { + s.Init() + } + + s.settings[setting] = []byte(val) +} + // HasSetting returns true if a setting is set, false if not. func (s *SessionSettings) HasSetting(setting string) bool { _, ok := s.settings[setting] return ok } -// Setting is a settings string accessor. Returns an error if the setting is missing. -func (s *SessionSettings) Setting(setting string) (string, error) { +// RawSetting is a settings accessor that returns the raw byte slice value of +// the setting. Returns an error if the setting is missing. +func (s *SessionSettings) RawSetting(setting string) ([]byte, error) { val, ok := s.settings[setting] if !ok { - return val, ConditionallyRequiredSetting{setting} + return nil, ConditionallyRequiredSetting{Setting: setting} } return val, nil } -// IntSetting returns the requested setting parsed as an int. Returns an errror if the setting is not set or cannot be parsed as an int. -func (s *SessionSettings) IntSetting(setting string) (val int, err error) { - stringVal, err := s.Setting(setting) +// Setting is a settings string accessor. Returns an error if the setting is missing. +func (s *SessionSettings) Setting(setting string) (string, error) { + val, err := s.RawSetting(setting) + if err != nil { + return "", err + } + return string(val), nil +} + +// IntSetting returns the requested setting parsed as an int. Returns an errror if the setting is not set or cannot be parsed as an int. +func (s *SessionSettings) IntSetting(setting string) (int, error) { + rawVal, err := s.RawSetting(setting) if err != nil { - return + return 0, err } - if val, err = strconv.Atoi(stringVal); err != nil { - return val, IncorrectFormatForSetting{Setting: setting, Value: stringVal, Err: err} + if val, err := strconv.Atoi(string(rawVal)); err == nil { + return val, nil } - return + return 0, IncorrectFormatForSetting{Setting: setting, Value: rawVal, Err: err} } // DurationSetting returns the requested setting parsed as a time.Duration. // Returns an error if the setting is not set or cannot be parsed as a time.Duration. -func (s *SessionSettings) DurationSetting(setting string) (val time.Duration, err error) { - stringVal, err := s.Setting(setting) - +func (s *SessionSettings) DurationSetting(setting string) (time.Duration, error) { + rawVal, err := s.RawSetting(setting) if err != nil { - return + return 0, err } - if val, err = time.ParseDuration(stringVal); err != nil { - return val, IncorrectFormatForSetting{Setting: setting, Value: stringVal, Err: err} + if val, err := time.ParseDuration(string(rawVal)); err == nil { + return val, nil } - return + return 0, IncorrectFormatForSetting{Setting: setting, Value: rawVal, Err: err} } // BoolSetting returns the requested setting parsed as a boolean. Returns an error if the setting is not set or cannot be parsed as a bool. func (s SessionSettings) BoolSetting(setting string) (bool, error) { - stringVal, err := s.Setting(setting) - + rawVal, err := s.RawSetting(setting) if err != nil { return false, err } - switch stringVal { + switch string(rawVal) { case "Y", "y": return true, nil case "N", "n": return false, nil } - return false, IncorrectFormatForSetting{Setting: setting, Value: stringVal} + return false, IncorrectFormatForSetting{Setting: setting, Value: rawVal} } func (s *SessionSettings) overlay(overlay *SessionSettings) { diff --git a/session_settings_test.go b/session_settings_test.go index 7fa9a94ad..d413f6866 100644 --- a/session_settings_test.go +++ b/session_settings_test.go @@ -16,7 +16,9 @@ package quickfix import ( + "bytes" "testing" + "time" "github.com/quickfixgo/quickfix/config" ) @@ -55,10 +57,15 @@ func TestSessionSettings_IntSettings(t *testing.T) { } s.Set(config.SocketAcceptPort, "notanint") - if _, err := s.IntSetting(config.SocketAcceptPort); err == nil { + _, err := s.IntSetting(config.SocketAcceptPort) + if err == nil { t.Error("Expected error for unparsable value") } + if err.Error() != `"notanint" is invalid for SocketAcceptPort` { + t.Errorf("Expected %s, got %s", `"notanint" is invalid for SocketAcceptPort`, err) + } + s.Set(config.SocketAcceptPort, "1005") val, err := s.IntSetting(config.SocketAcceptPort) if err != nil { @@ -77,10 +84,15 @@ func TestSessionSettings_BoolSettings(t *testing.T) { } s.Set(config.ResetOnLogon, "notabool") - if _, err := s.BoolSetting(config.ResetOnLogon); err == nil { + _, err := s.BoolSetting(config.ResetOnLogon) + if err == nil { t.Error("Expected error for unparsable value") } + if err.Error() != `"notabool" is invalid for ResetOnLogon` { + t.Errorf("Expected %s, got %s", `"notabool" is invalid for ResetOnLogon`, err) + } + var boolTests = []struct { input string expected bool @@ -105,6 +117,55 @@ func TestSessionSettings_BoolSettings(t *testing.T) { } } +func TestSessionSettings_DurationSettings(t *testing.T) { + s := NewSessionSettings() + if _, err := s.BoolSetting(config.ReconnectInterval); err == nil { + t.Error("Expected error for unknown setting") + } + + s.Set(config.ReconnectInterval, "not duration") + + _, err := s.DurationSetting(config.ReconnectInterval) + if err == nil { + t.Error("Expected error for unparsable value") + } + + if err.Error() != `"not duration" is invalid for ReconnectInterval` { + t.Errorf("Expected %s, got %s", `"not duration" is invalid for ReconnectInterval`, err) + } + + s.Set(config.ReconnectInterval, "10s") + + got, err := s.DurationSetting(config.ReconnectInterval) + if err != nil { + t.Error("Unexpected err", err) + } + + expected, _ := time.ParseDuration("10s") + + if got != expected { + t.Errorf("Expected %v, got %v", expected, got) + } +} + +func TestSessionSettings_ByteSettings(t *testing.T) { + s := NewSessionSettings() + if _, err := s.RawSetting(config.SocketPrivateKeyBytes); err == nil { + t.Error("Expected error for unknown setting") + } + + s.SetRaw(config.SocketPrivateKeyBytes, []byte("pembytes")) + + got, err := s.RawSetting(config.SocketPrivateKeyBytes) + if err != nil { + t.Error("Unexpected err", err) + } + + if !bytes.Equal([]byte("pembytes"), got) { + t.Errorf("Expected %v, got %v", []byte("pembytes"), got) + } +} + func TestSessionSettings_Clone(t *testing.T) { s := NewSessionSettings() diff --git a/tag_value.go b/tag_value.go index 6e721862f..28bba7a7a 100644 --- a/tag_value.go +++ b/tag_value.go @@ -39,8 +39,28 @@ func (tv *TagValue) init(tag Tag, value []byte) { } func (tv *TagValue) parse(rawFieldBytes []byte) error { - sepIndex := bytes.IndexByte(rawFieldBytes, '=') + var sepIndex int + // Most of the Fix tags are 4 or less characters long, so we can optimize + // for that by checking the 5 first characters without looping over the + // whole byte slice. + if len(rawFieldBytes) >= 5 { + if rawFieldBytes[1] == '=' { + sepIndex = 1 + goto PARSE + } else if rawFieldBytes[2] == '=' { + sepIndex = 2 + goto PARSE + } else if rawFieldBytes[3] == '=' { + sepIndex = 3 + goto PARSE + } else if rawFieldBytes[4] == '=' { + sepIndex = 4 + goto PARSE + } + } + + sepIndex = bytes.IndexByte(rawFieldBytes, '=') switch sepIndex { case -1: return fmt.Errorf("tagValue.Parse: No '=' in '%s'", rawFieldBytes) @@ -48,6 +68,7 @@ func (tv *TagValue) parse(rawFieldBytes []byte) error { return fmt.Errorf("tagValue.Parse: No tag in '%s'", rawFieldBytes) } +PARSE: parsedTag, err := atoi(rawFieldBytes[:sepIndex]) if err != nil { return fmt.Errorf("tagValue.Parse: %s", err.Error()) diff --git a/tls.go b/tls.go index b9978a727..335d8e5d7 100644 --- a/tls.go +++ b/tls.go @@ -18,18 +18,20 @@ package quickfix import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "os" "github.com/quickfixgo/quickfix/config" ) -func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) { +func loadTLSConfig(settings *SessionSettings) (*tls.Config, error) { + var err error allowSkipClientCerts := false if settings.HasSetting(config.SocketUseSSL) { allowSkipClientCerts, err = settings.BoolSetting(config.SocketUseSSL) if err != nil { - return + return nil, err } } @@ -37,7 +39,7 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if settings.HasSetting(config.SocketServerName) { serverName, err = settings.Setting(config.SocketServerName) if err != nil { - return + return nil, err } } @@ -45,17 +47,20 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if settings.HasSetting(config.SocketInsecureSkipVerify) { insecureSkipVerify, err = settings.BoolSetting(config.SocketInsecureSkipVerify) if err != nil { - return + return nil, err } } - if !settings.HasSetting(config.SocketPrivateKeyFile) && !settings.HasSetting(config.SocketCertificateFile) { + if !settings.HasSetting(config.SocketPrivateKeyFile) && + !settings.HasSetting(config.SocketCertificateFile) && + !settings.HasSetting(config.SocketPrivateKeyBytes) && + !settings.HasSetting(config.SocketCertificateBytes) { if !allowSkipClientCerts { - return + return nil, nil } } - tlsConfig = defaultTLSConfig() + tlsConfig := defaultTLSConfig() tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify setMinVersionExplicit(settings, tlsConfig) @@ -67,49 +72,80 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) privateKeyFile, err = settings.Setting(config.SocketPrivateKeyFile) if err != nil { - return + return nil, err } certificateFile, err = settings.Setting(config.SocketCertificateFile) if err != nil { - return + return nil, err } tlsConfig.Certificates = make([]tls.Certificate, 1) if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { - return + return nil, fmt.Errorf("failed to load key pair: %w", err) + } + } else if settings.HasSetting(config.SocketPrivateKeyBytes) || settings.HasSetting(config.SocketCertificateBytes) { + privateKeyBytes, err := settings.RawSetting(config.SocketPrivateKeyBytes) + if err != nil { + return nil, err + } + + certificateBytes, err := settings.RawSetting(config.SocketCertificateBytes) + if err != nil { + return nil, err } + + tlsConfig.Certificates = make([]tls.Certificate, 1) + + certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse key pair: %w", err) + } + + tlsConfig.Certificates[0] = certificate } if !allowSkipClientCerts { tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } - if !settings.HasSetting(config.SocketCAFile) { - return + if !settings.HasSetting(config.SocketCAFile) && !settings.HasSetting(config.SocketCABytes) { + return tlsConfig, nil } - caFile, err := settings.Setting(config.SocketCAFile) - if err != nil { - return - } + certPool := x509.NewCertPool() + if settings.HasSetting(config.SocketCAFile) { + caFile, err := settings.Setting(config.SocketCAFile) + if err != nil { + return nil, err + } - pem, err := os.ReadFile(caFile) - if err != nil { - return - } + pem, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA bundle: %w", err) + } - certPool := x509.NewCertPool() - if !certPool.AppendCertsFromPEM(pem) { - err = fmt.Errorf("Failed to parse %v", caFile) - return + if !certPool.AppendCertsFromPEM(pem) { + err = fmt.Errorf("failed to parse %v", caFile) + return nil, err + } + } else { + caBytes, err := settings.RawSetting(config.SocketCABytes) + if err != nil { + return nil, err + } + + if !certPool.AppendCertsFromPEM(caBytes) { + err = errors.New("failed to parse CA bundle from raw bytes") + return nil, err + } } tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool - return + return tlsConfig, nil } // defaultTLSConfig brought to you by https://github.com/gtank/cryptopasta/ diff --git a/tls_test.go b/tls_test.go index 0bae6cd54..ea9ea93a6 100644 --- a/tls_test.go +++ b/tls_test.go @@ -17,6 +17,7 @@ package quickfix import ( "crypto/tls" + "os" "testing" "github.com/stretchr/testify/suite" @@ -26,8 +27,9 @@ import ( type TLSTestSuite struct { suite.Suite - settings *Settings - PrivateKeyFile, CertificateFile, CAFile string + settings *Settings + PrivateKeyFile, CertificateFile, CAFile string + PrivateKeyBytes, CertificateBytes, CABytes []byte } func TestTLSTestSuite(t *testing.T) { @@ -39,6 +41,18 @@ func (s *TLSTestSuite) SetupTest() { s.PrivateKeyFile = "_test_data/localhost.key" s.CertificateFile = "_test_data/localhost.crt" s.CAFile = "_test_data/ca.crt" + + privateKeyBytes, err := os.ReadFile(s.PrivateKeyFile) + s.Require().NoError(err) + s.PrivateKeyBytes = privateKeyBytes + + certificateBytes, err := os.ReadFile(s.CertificateFile) + s.Require().NoError(err) + s.CertificateBytes = certificateBytes + + caBytes, err := os.ReadFile(s.CAFile) + s.Require().NoError(err) + s.CABytes = caBytes } func (s *TLSTestSuite) TestLoadTLSNoSettings() { @@ -51,11 +65,13 @@ func (s *TLSTestSuite) TestLoadTLSMissingKeyOrCert() { s.settings.GlobalSettings().Set(config.SocketPrivateKeyFile, s.PrivateKeyFile) _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketCertificateFile") s.SetupTest() s.settings.GlobalSettings().Set(config.SocketCertificateFile, s.CertificateFile) _, err = loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketPrivateKeyFile") } func (s *TLSTestSuite) TestLoadTLSInvalidKeyOrCert() { @@ -63,6 +79,7 @@ func (s *TLSTestSuite) TestLoadTLSInvalidKeyOrCert() { s.settings.GlobalSettings().Set(config.SocketCertificateFile, "foo") _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "failed to load key pair: open foo: no such file or directory") } func (s *TLSTestSuite) TestLoadTLSNoCA() { @@ -86,6 +103,7 @@ func (s *TLSTestSuite) TestLoadTLSWithBadCA() { _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "failed to read CA bundle: open bar: no such file or directory") } func (s *TLSTestSuite) TestLoadTLSWithCA() { @@ -223,3 +241,93 @@ func (s *TLSTestSuite) TestMinimumTLSVersion() { s.NotNil(tlsConfig) s.Equal(tlsConfig.MinVersion, uint16(tls.VersionTLS12)) } + +func (s *TLSTestSuite) TestLoadTLSBytesMissingKeyOrCert() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketCertificateBytes") + + s.SetupTest() + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + _, err = loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketPrivateKeyBytes") +} + +func (s *TLSTestSuite) TestLoadTLSBytesInvalidKeyOrCert() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, []byte("not a cert")) + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse key pair: tls: failed to find any PEM data in certificate input") + + s.SetupTest() + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, []byte("not a key")) + _, err = loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse key pair: tls: failed to find any PEM data in key input") +} + +func (s *TLSTestSuite) TestLoadTLSBytesNoCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NoError(err) + s.NotNil(tlsConfig) + + s.Len(tlsConfig.Certificates, 1) + s.Nil(tlsConfig.RootCAs) + s.Nil(tlsConfig.ClientCAs) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithBadCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, []byte("bar")) + + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse CA bundle from raw bytes") +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, s.CABytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.Len(tlsConfig.Certificates, 1) + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, s.CABytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) +} + +func (s *TLSTestSuite) TestServerNameWithCertsFromBytes() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameWithCerts") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + s.Equal("DummyServerNameWithCerts", tlsConfig.ServerName) +}