diff --git a/config/configuration.go b/config/configuration.go index f524f5005..754f587b1 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -739,6 +739,47 @@ const ( // - A filepath to a file with read access. SocketCAFile string = "SocketCAFile" + // SocketPrivateKeyBytes is an optional value containing raw bytes of a PEM + // encoded private key to use for secure TLS communications. + // Must be used with SocketCertificateBytes. + // Must contain PEM encoded data. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - Raw bytes containing a valid PEM encoded private key. + SocketPrivateKeyBytes string = "SocketPrivateKeyBytes" + + // SocketCertificateBytes is an optional value containing raw bytes of a PEM + // encoded certificate to use for secure TLS communications. + // Must be used with SocketPrivateKeyBytes. + // Must contain PEM encoded data. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - Raw bytes containing a valid PEM encoded certificate. + SocketCertificateBytes string = "SocketCertificateBytes" + + // SocketCABytes is an optional value containing raw bytes of a PEM encoded + // root CA to use for secure TLS communications. For acceptors, client + // certificates will be verified against this CA. For initiators, clients + // will use the CA to verify the server certificate. If not configured, + // initiators will verify the server certificates using the host's root CA + // set. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - Raw bytes containing a valid PEM encoded CA. + SocketCABytes string = "SocketCABytes" + // SocketInsecureSkipVerify controls whether a client verifies the server's certificate chain and host name. // If SocketInsecureSkipVerify is set to Y, crypto/tls accepts any certificate presented by the server and any host name in that certificate. // In this mode, TLS is susceptible to machine-in-the-middle attacks unless custom verification is used. diff --git a/session_factory.go b/session_factory.go index 37818deb9..d875762de 100644 --- a/session_factory.go +++ b/session_factory.go @@ -284,7 +284,7 @@ func (f sessionFactory) newSession( for _, dayStr := range dayStrs { day, ok := dayLookup[dayStr] if !ok { - err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: weekdaysStr} + err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: []byte(weekdaysStr)} return } weekdays = append(weekdays, day) @@ -315,7 +315,7 @@ func (f sessionFactory) newSession( parseDay := func(setting, dayStr string) (day time.Weekday, err error) { day, ok := dayLookup[dayStr] if !ok { - return day, IncorrectFormatForSetting{Setting: setting, Value: dayStr} + return day, IncorrectFormatForSetting{Setting: setting, Value: []byte(dayStr)} } return } @@ -355,7 +355,7 @@ func (f sessionFactory) newSession( s.timestampPrecision = Nanos default: - err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: precisionStr} + err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: []byte(precisionStr)} return } } diff --git a/session_settings.go b/session_settings.go index 323ac6234..74d6cdd27 100644 --- a/session_settings.go +++ b/session_settings.go @@ -23,7 +23,7 @@ import ( // SessionSettings maps session settings to values with typed accessors. type SessionSettings struct { - settings map[string]string + settings map[string][]byte } // ConditionallyRequiredSetting indicates a missing setting. @@ -37,8 +37,9 @@ func (e ConditionallyRequiredSetting) Error() string { // IncorrectFormatForSetting indicates a setting that is incorrectly formatted. type IncorrectFormatForSetting struct { - Setting, Value string - Err error + Setting string + Value []byte + Err error } func (e IncorrectFormatForSetting) Error() string { @@ -47,7 +48,7 @@ func (e IncorrectFormatForSetting) Error() string { // Init initializes or resets SessionSettings. func (s *SessionSettings) Init() { - s.settings = make(map[string]string) + s.settings = make(map[string][]byte) } // NewSessionSettings returns a newly initialized SessionSettings instance. @@ -58,8 +59,8 @@ func NewSessionSettings() *SessionSettings { return s } -// Set assigns a value to a setting on SessionSettings. -func (s *SessionSettings) Set(setting string, val string) { +// SetRaw assigns a value to a setting on SessionSettings. +func (s *SessionSettings) SetRaw(setting string, val []byte) { // Lazy init. if s.settings == nil { s.Init() @@ -68,69 +69,87 @@ func (s *SessionSettings) Set(setting string, val string) { s.settings[setting] = val } +// Set assigns a string value to a setting on SessionSettings. +func (s *SessionSettings) Set(setting string, val string) { + // Lazy init + if s.settings == nil { + s.Init() + } + + s.settings[setting] = []byte(val) +} + // HasSetting returns true if a setting is set, false if not. func (s *SessionSettings) HasSetting(setting string) bool { _, ok := s.settings[setting] return ok } -// Setting is a settings string accessor. Returns an error if the setting is missing. -func (s *SessionSettings) Setting(setting string) (string, error) { +// RawSetting is a settings accessor that returns the raw byte slice value of +// the setting. Returns an error if the setting is missing. +func (s *SessionSettings) RawSetting(setting string) ([]byte, error) { val, ok := s.settings[setting] if !ok { - return val, ConditionallyRequiredSetting{setting} + return nil, ConditionallyRequiredSetting{Setting: setting} } return val, nil } -// IntSetting returns the requested setting parsed as an int. Returns an errror if the setting is not set or cannot be parsed as an int. -func (s *SessionSettings) IntSetting(setting string) (val int, err error) { - stringVal, err := s.Setting(setting) +// Setting is a settings string accessor. Returns an error if the setting is missing. +func (s *SessionSettings) Setting(setting string) (string, error) { + val, err := s.RawSetting(setting) + if err != nil { + return "", err + } + return string(val), nil +} + +// IntSetting returns the requested setting parsed as an int. Returns an errror if the setting is not set or cannot be parsed as an int. +func (s *SessionSettings) IntSetting(setting string) (int, error) { + rawVal, err := s.RawSetting(setting) if err != nil { - return + return 0, err } - if val, err = strconv.Atoi(stringVal); err != nil { - return val, IncorrectFormatForSetting{Setting: setting, Value: stringVal, Err: err} + if val, err := strconv.Atoi(string(rawVal)); err == nil { + return val, nil } - return + return 0, IncorrectFormatForSetting{Setting: setting, Value: rawVal, Err: err} } // DurationSetting returns the requested setting parsed as a time.Duration. // Returns an error if the setting is not set or cannot be parsed as a time.Duration. -func (s *SessionSettings) DurationSetting(setting string) (val time.Duration, err error) { - stringVal, err := s.Setting(setting) - +func (s *SessionSettings) DurationSetting(setting string) (time.Duration, error) { + rawVal, err := s.RawSetting(setting) if err != nil { - return + return 0, err } - if val, err = time.ParseDuration(stringVal); err != nil { - return val, IncorrectFormatForSetting{Setting: setting, Value: stringVal, Err: err} + if val, err := time.ParseDuration(string(rawVal)); err == nil { + return val, nil } - return + return 0, IncorrectFormatForSetting{Setting: setting, Value: rawVal, Err: err} } // BoolSetting returns the requested setting parsed as a boolean. Returns an error if the setting is not set or cannot be parsed as a bool. func (s SessionSettings) BoolSetting(setting string) (bool, error) { - stringVal, err := s.Setting(setting) - + rawVal, err := s.RawSetting(setting) if err != nil { return false, err } - switch stringVal { + switch string(rawVal) { case "Y", "y": return true, nil case "N", "n": return false, nil } - return false, IncorrectFormatForSetting{Setting: setting, Value: stringVal} + return false, IncorrectFormatForSetting{Setting: setting, Value: rawVal} } func (s *SessionSettings) overlay(overlay *SessionSettings) { diff --git a/session_settings_test.go b/session_settings_test.go index 7fa9a94ad..d413f6866 100644 --- a/session_settings_test.go +++ b/session_settings_test.go @@ -16,7 +16,9 @@ package quickfix import ( + "bytes" "testing" + "time" "github.com/quickfixgo/quickfix/config" ) @@ -55,10 +57,15 @@ func TestSessionSettings_IntSettings(t *testing.T) { } s.Set(config.SocketAcceptPort, "notanint") - if _, err := s.IntSetting(config.SocketAcceptPort); err == nil { + _, err := s.IntSetting(config.SocketAcceptPort) + if err == nil { t.Error("Expected error for unparsable value") } + if err.Error() != `"notanint" is invalid for SocketAcceptPort` { + t.Errorf("Expected %s, got %s", `"notanint" is invalid for SocketAcceptPort`, err) + } + s.Set(config.SocketAcceptPort, "1005") val, err := s.IntSetting(config.SocketAcceptPort) if err != nil { @@ -77,10 +84,15 @@ func TestSessionSettings_BoolSettings(t *testing.T) { } s.Set(config.ResetOnLogon, "notabool") - if _, err := s.BoolSetting(config.ResetOnLogon); err == nil { + _, err := s.BoolSetting(config.ResetOnLogon) + if err == nil { t.Error("Expected error for unparsable value") } + if err.Error() != `"notabool" is invalid for ResetOnLogon` { + t.Errorf("Expected %s, got %s", `"notabool" is invalid for ResetOnLogon`, err) + } + var boolTests = []struct { input string expected bool @@ -105,6 +117,55 @@ func TestSessionSettings_BoolSettings(t *testing.T) { } } +func TestSessionSettings_DurationSettings(t *testing.T) { + s := NewSessionSettings() + if _, err := s.BoolSetting(config.ReconnectInterval); err == nil { + t.Error("Expected error for unknown setting") + } + + s.Set(config.ReconnectInterval, "not duration") + + _, err := s.DurationSetting(config.ReconnectInterval) + if err == nil { + t.Error("Expected error for unparsable value") + } + + if err.Error() != `"not duration" is invalid for ReconnectInterval` { + t.Errorf("Expected %s, got %s", `"not duration" is invalid for ReconnectInterval`, err) + } + + s.Set(config.ReconnectInterval, "10s") + + got, err := s.DurationSetting(config.ReconnectInterval) + if err != nil { + t.Error("Unexpected err", err) + } + + expected, _ := time.ParseDuration("10s") + + if got != expected { + t.Errorf("Expected %v, got %v", expected, got) + } +} + +func TestSessionSettings_ByteSettings(t *testing.T) { + s := NewSessionSettings() + if _, err := s.RawSetting(config.SocketPrivateKeyBytes); err == nil { + t.Error("Expected error for unknown setting") + } + + s.SetRaw(config.SocketPrivateKeyBytes, []byte("pembytes")) + + got, err := s.RawSetting(config.SocketPrivateKeyBytes) + if err != nil { + t.Error("Unexpected err", err) + } + + if !bytes.Equal([]byte("pembytes"), got) { + t.Errorf("Expected %v, got %v", []byte("pembytes"), got) + } +} + func TestSessionSettings_Clone(t *testing.T) { s := NewSessionSettings() diff --git a/tls.go b/tls.go index b9978a727..335d8e5d7 100644 --- a/tls.go +++ b/tls.go @@ -18,18 +18,20 @@ package quickfix import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "os" "github.com/quickfixgo/quickfix/config" ) -func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) { +func loadTLSConfig(settings *SessionSettings) (*tls.Config, error) { + var err error allowSkipClientCerts := false if settings.HasSetting(config.SocketUseSSL) { allowSkipClientCerts, err = settings.BoolSetting(config.SocketUseSSL) if err != nil { - return + return nil, err } } @@ -37,7 +39,7 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if settings.HasSetting(config.SocketServerName) { serverName, err = settings.Setting(config.SocketServerName) if err != nil { - return + return nil, err } } @@ -45,17 +47,20 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if settings.HasSetting(config.SocketInsecureSkipVerify) { insecureSkipVerify, err = settings.BoolSetting(config.SocketInsecureSkipVerify) if err != nil { - return + return nil, err } } - if !settings.HasSetting(config.SocketPrivateKeyFile) && !settings.HasSetting(config.SocketCertificateFile) { + if !settings.HasSetting(config.SocketPrivateKeyFile) && + !settings.HasSetting(config.SocketCertificateFile) && + !settings.HasSetting(config.SocketPrivateKeyBytes) && + !settings.HasSetting(config.SocketCertificateBytes) { if !allowSkipClientCerts { - return + return nil, nil } } - tlsConfig = defaultTLSConfig() + tlsConfig := defaultTLSConfig() tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify setMinVersionExplicit(settings, tlsConfig) @@ -67,49 +72,80 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) privateKeyFile, err = settings.Setting(config.SocketPrivateKeyFile) if err != nil { - return + return nil, err } certificateFile, err = settings.Setting(config.SocketCertificateFile) if err != nil { - return + return nil, err } tlsConfig.Certificates = make([]tls.Certificate, 1) if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { - return + return nil, fmt.Errorf("failed to load key pair: %w", err) + } + } else if settings.HasSetting(config.SocketPrivateKeyBytes) || settings.HasSetting(config.SocketCertificateBytes) { + privateKeyBytes, err := settings.RawSetting(config.SocketPrivateKeyBytes) + if err != nil { + return nil, err + } + + certificateBytes, err := settings.RawSetting(config.SocketCertificateBytes) + if err != nil { + return nil, err } + + tlsConfig.Certificates = make([]tls.Certificate, 1) + + certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse key pair: %w", err) + } + + tlsConfig.Certificates[0] = certificate } if !allowSkipClientCerts { tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } - if !settings.HasSetting(config.SocketCAFile) { - return + if !settings.HasSetting(config.SocketCAFile) && !settings.HasSetting(config.SocketCABytes) { + return tlsConfig, nil } - caFile, err := settings.Setting(config.SocketCAFile) - if err != nil { - return - } + certPool := x509.NewCertPool() + if settings.HasSetting(config.SocketCAFile) { + caFile, err := settings.Setting(config.SocketCAFile) + if err != nil { + return nil, err + } - pem, err := os.ReadFile(caFile) - if err != nil { - return - } + pem, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA bundle: %w", err) + } - certPool := x509.NewCertPool() - if !certPool.AppendCertsFromPEM(pem) { - err = fmt.Errorf("Failed to parse %v", caFile) - return + if !certPool.AppendCertsFromPEM(pem) { + err = fmt.Errorf("failed to parse %v", caFile) + return nil, err + } + } else { + caBytes, err := settings.RawSetting(config.SocketCABytes) + if err != nil { + return nil, err + } + + if !certPool.AppendCertsFromPEM(caBytes) { + err = errors.New("failed to parse CA bundle from raw bytes") + return nil, err + } } tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool - return + return tlsConfig, nil } // defaultTLSConfig brought to you by https://github.com/gtank/cryptopasta/ diff --git a/tls_test.go b/tls_test.go index 0bae6cd54..ea9ea93a6 100644 --- a/tls_test.go +++ b/tls_test.go @@ -17,6 +17,7 @@ package quickfix import ( "crypto/tls" + "os" "testing" "github.com/stretchr/testify/suite" @@ -26,8 +27,9 @@ import ( type TLSTestSuite struct { suite.Suite - settings *Settings - PrivateKeyFile, CertificateFile, CAFile string + settings *Settings + PrivateKeyFile, CertificateFile, CAFile string + PrivateKeyBytes, CertificateBytes, CABytes []byte } func TestTLSTestSuite(t *testing.T) { @@ -39,6 +41,18 @@ func (s *TLSTestSuite) SetupTest() { s.PrivateKeyFile = "_test_data/localhost.key" s.CertificateFile = "_test_data/localhost.crt" s.CAFile = "_test_data/ca.crt" + + privateKeyBytes, err := os.ReadFile(s.PrivateKeyFile) + s.Require().NoError(err) + s.PrivateKeyBytes = privateKeyBytes + + certificateBytes, err := os.ReadFile(s.CertificateFile) + s.Require().NoError(err) + s.CertificateBytes = certificateBytes + + caBytes, err := os.ReadFile(s.CAFile) + s.Require().NoError(err) + s.CABytes = caBytes } func (s *TLSTestSuite) TestLoadTLSNoSettings() { @@ -51,11 +65,13 @@ func (s *TLSTestSuite) TestLoadTLSMissingKeyOrCert() { s.settings.GlobalSettings().Set(config.SocketPrivateKeyFile, s.PrivateKeyFile) _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketCertificateFile") s.SetupTest() s.settings.GlobalSettings().Set(config.SocketCertificateFile, s.CertificateFile) _, err = loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketPrivateKeyFile") } func (s *TLSTestSuite) TestLoadTLSInvalidKeyOrCert() { @@ -63,6 +79,7 @@ func (s *TLSTestSuite) TestLoadTLSInvalidKeyOrCert() { s.settings.GlobalSettings().Set(config.SocketCertificateFile, "foo") _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "failed to load key pair: open foo: no such file or directory") } func (s *TLSTestSuite) TestLoadTLSNoCA() { @@ -86,6 +103,7 @@ func (s *TLSTestSuite) TestLoadTLSWithBadCA() { _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "failed to read CA bundle: open bar: no such file or directory") } func (s *TLSTestSuite) TestLoadTLSWithCA() { @@ -223,3 +241,93 @@ func (s *TLSTestSuite) TestMinimumTLSVersion() { s.NotNil(tlsConfig) s.Equal(tlsConfig.MinVersion, uint16(tls.VersionTLS12)) } + +func (s *TLSTestSuite) TestLoadTLSBytesMissingKeyOrCert() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketCertificateBytes") + + s.SetupTest() + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + _, err = loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketPrivateKeyBytes") +} + +func (s *TLSTestSuite) TestLoadTLSBytesInvalidKeyOrCert() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, []byte("not a cert")) + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse key pair: tls: failed to find any PEM data in certificate input") + + s.SetupTest() + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, []byte("not a key")) + _, err = loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse key pair: tls: failed to find any PEM data in key input") +} + +func (s *TLSTestSuite) TestLoadTLSBytesNoCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NoError(err) + s.NotNil(tlsConfig) + + s.Len(tlsConfig.Certificates, 1) + s.Nil(tlsConfig.RootCAs) + s.Nil(tlsConfig.ClientCAs) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithBadCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, []byte("bar")) + + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse CA bundle from raw bytes") +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, s.CABytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.Len(tlsConfig.Certificates, 1) + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, s.CABytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) +} + +func (s *TLSTestSuite) TestServerNameWithCertsFromBytes() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameWithCerts") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + s.Equal("DummyServerNameWithCerts", tlsConfig.ServerName) +}