diff --git a/config/configuration.go b/config/configuration.go index bf3044dba..b0ff01347 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -19,6 +19,9 @@ const ( SocketPrivateKeyFile string = "SocketPrivateKeyFile" SocketCertificateFile string = "SocketCertificateFile" SocketCAFile string = "SocketCAFile" + SocketPrivateKeyBytes string = "SocketPrivateKeyBytes" + SocketCertificateBytes string = "SocketCertificateBytes" + SocketCABytes string = "SocketCABytes" SocketInsecureSkipVerify string = "SocketInsecureSkipVerify" SocketServerName string = "SocketServerName" SocketMinimumTLSVersion string = "SocketMinimumTLSVersion" diff --git a/config/doc.go b/config/doc.go index 911a4458a..7e0edb028 100644 --- a/config/doc.go +++ b/config/doc.go @@ -317,6 +317,18 @@ Certificate to use for secure TLS connections. Must be used with SocketPrivateKe Optional root CA to use for secure TLS connections. For acceptors, client certificates will be verified against this CA. For initiators, clients will use the CA to verify the server certificate. If not configurated, initiators will verify the server certificate using the host's root CA set. +# SocketPrivateKeyBytes + +Raw bytes of PEM encoded private to use for secure TLS connections. Must be used with SocketCertificateBytes + +# SocketCertificateBytes + +Raw bytes of PEM encoded certificate to use for secure TLS connections. Must be used with SocketPrivateKeyBytes + +# SocketCABytes + +Optional root CA to use for secure TLS connections as raw bytes. For acceptors, client certificates will be verified against this CA. For initiators, clients will use the CA to verify the server certificate. If not configurated, initiators will verify the server certificate using the host's root CA set. + # SocketServerName The expected server name on a returned certificate, unless SocketInsecureSkipVerify is true. This is for the TLS Server Name Indication extension. Initiator only. diff --git a/session_factory.go b/session_factory.go index 37818deb9..d875762de 100644 --- a/session_factory.go +++ b/session_factory.go @@ -284,7 +284,7 @@ func (f sessionFactory) newSession( for _, dayStr := range dayStrs { day, ok := dayLookup[dayStr] if !ok { - err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: weekdaysStr} + err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: []byte(weekdaysStr)} return } weekdays = append(weekdays, day) @@ -315,7 +315,7 @@ func (f sessionFactory) newSession( parseDay := func(setting, dayStr string) (day time.Weekday, err error) { day, ok := dayLookup[dayStr] if !ok { - return day, IncorrectFormatForSetting{Setting: setting, Value: dayStr} + return day, IncorrectFormatForSetting{Setting: setting, Value: []byte(dayStr)} } return } @@ -355,7 +355,7 @@ func (f sessionFactory) newSession( s.timestampPrecision = Nanos default: - err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: precisionStr} + err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: []byte(precisionStr)} return } } diff --git a/session_settings.go b/session_settings.go index 323ac6234..74d6cdd27 100644 --- a/session_settings.go +++ b/session_settings.go @@ -23,7 +23,7 @@ import ( // SessionSettings maps session settings to values with typed accessors. type SessionSettings struct { - settings map[string]string + settings map[string][]byte } // ConditionallyRequiredSetting indicates a missing setting. @@ -37,8 +37,9 @@ func (e ConditionallyRequiredSetting) Error() string { // IncorrectFormatForSetting indicates a setting that is incorrectly formatted. type IncorrectFormatForSetting struct { - Setting, Value string - Err error + Setting string + Value []byte + Err error } func (e IncorrectFormatForSetting) Error() string { @@ -47,7 +48,7 @@ func (e IncorrectFormatForSetting) Error() string { // Init initializes or resets SessionSettings. func (s *SessionSettings) Init() { - s.settings = make(map[string]string) + s.settings = make(map[string][]byte) } // NewSessionSettings returns a newly initialized SessionSettings instance. @@ -58,8 +59,8 @@ func NewSessionSettings() *SessionSettings { return s } -// Set assigns a value to a setting on SessionSettings. -func (s *SessionSettings) Set(setting string, val string) { +// SetRaw assigns a value to a setting on SessionSettings. +func (s *SessionSettings) SetRaw(setting string, val []byte) { // Lazy init. if s.settings == nil { s.Init() @@ -68,69 +69,87 @@ func (s *SessionSettings) Set(setting string, val string) { s.settings[setting] = val } +// Set assigns a string value to a setting on SessionSettings. +func (s *SessionSettings) Set(setting string, val string) { + // Lazy init + if s.settings == nil { + s.Init() + } + + s.settings[setting] = []byte(val) +} + // HasSetting returns true if a setting is set, false if not. func (s *SessionSettings) HasSetting(setting string) bool { _, ok := s.settings[setting] return ok } -// Setting is a settings string accessor. Returns an error if the setting is missing. -func (s *SessionSettings) Setting(setting string) (string, error) { +// RawSetting is a settings accessor that returns the raw byte slice value of +// the setting. Returns an error if the setting is missing. +func (s *SessionSettings) RawSetting(setting string) ([]byte, error) { val, ok := s.settings[setting] if !ok { - return val, ConditionallyRequiredSetting{setting} + return nil, ConditionallyRequiredSetting{Setting: setting} } return val, nil } -// IntSetting returns the requested setting parsed as an int. Returns an errror if the setting is not set or cannot be parsed as an int. -func (s *SessionSettings) IntSetting(setting string) (val int, err error) { - stringVal, err := s.Setting(setting) +// Setting is a settings string accessor. Returns an error if the setting is missing. +func (s *SessionSettings) Setting(setting string) (string, error) { + val, err := s.RawSetting(setting) + if err != nil { + return "", err + } + return string(val), nil +} + +// IntSetting returns the requested setting parsed as an int. Returns an errror if the setting is not set or cannot be parsed as an int. +func (s *SessionSettings) IntSetting(setting string) (int, error) { + rawVal, err := s.RawSetting(setting) if err != nil { - return + return 0, err } - if val, err = strconv.Atoi(stringVal); err != nil { - return val, IncorrectFormatForSetting{Setting: setting, Value: stringVal, Err: err} + if val, err := strconv.Atoi(string(rawVal)); err == nil { + return val, nil } - return + return 0, IncorrectFormatForSetting{Setting: setting, Value: rawVal, Err: err} } // DurationSetting returns the requested setting parsed as a time.Duration. // Returns an error if the setting is not set or cannot be parsed as a time.Duration. -func (s *SessionSettings) DurationSetting(setting string) (val time.Duration, err error) { - stringVal, err := s.Setting(setting) - +func (s *SessionSettings) DurationSetting(setting string) (time.Duration, error) { + rawVal, err := s.RawSetting(setting) if err != nil { - return + return 0, err } - if val, err = time.ParseDuration(stringVal); err != nil { - return val, IncorrectFormatForSetting{Setting: setting, Value: stringVal, Err: err} + if val, err := time.ParseDuration(string(rawVal)); err == nil { + return val, nil } - return + return 0, IncorrectFormatForSetting{Setting: setting, Value: rawVal, Err: err} } // BoolSetting returns the requested setting parsed as a boolean. Returns an error if the setting is not set or cannot be parsed as a bool. func (s SessionSettings) BoolSetting(setting string) (bool, error) { - stringVal, err := s.Setting(setting) - + rawVal, err := s.RawSetting(setting) if err != nil { return false, err } - switch stringVal { + switch string(rawVal) { case "Y", "y": return true, nil case "N", "n": return false, nil } - return false, IncorrectFormatForSetting{Setting: setting, Value: stringVal} + return false, IncorrectFormatForSetting{Setting: setting, Value: rawVal} } func (s *SessionSettings) overlay(overlay *SessionSettings) { diff --git a/session_settings_test.go b/session_settings_test.go index 7fa9a94ad..d413f6866 100644 --- a/session_settings_test.go +++ b/session_settings_test.go @@ -16,7 +16,9 @@ package quickfix import ( + "bytes" "testing" + "time" "github.com/quickfixgo/quickfix/config" ) @@ -55,10 +57,15 @@ func TestSessionSettings_IntSettings(t *testing.T) { } s.Set(config.SocketAcceptPort, "notanint") - if _, err := s.IntSetting(config.SocketAcceptPort); err == nil { + _, err := s.IntSetting(config.SocketAcceptPort) + if err == nil { t.Error("Expected error for unparsable value") } + if err.Error() != `"notanint" is invalid for SocketAcceptPort` { + t.Errorf("Expected %s, got %s", `"notanint" is invalid for SocketAcceptPort`, err) + } + s.Set(config.SocketAcceptPort, "1005") val, err := s.IntSetting(config.SocketAcceptPort) if err != nil { @@ -77,10 +84,15 @@ func TestSessionSettings_BoolSettings(t *testing.T) { } s.Set(config.ResetOnLogon, "notabool") - if _, err := s.BoolSetting(config.ResetOnLogon); err == nil { + _, err := s.BoolSetting(config.ResetOnLogon) + if err == nil { t.Error("Expected error for unparsable value") } + if err.Error() != `"notabool" is invalid for ResetOnLogon` { + t.Errorf("Expected %s, got %s", `"notabool" is invalid for ResetOnLogon`, err) + } + var boolTests = []struct { input string expected bool @@ -105,6 +117,55 @@ func TestSessionSettings_BoolSettings(t *testing.T) { } } +func TestSessionSettings_DurationSettings(t *testing.T) { + s := NewSessionSettings() + if _, err := s.BoolSetting(config.ReconnectInterval); err == nil { + t.Error("Expected error for unknown setting") + } + + s.Set(config.ReconnectInterval, "not duration") + + _, err := s.DurationSetting(config.ReconnectInterval) + if err == nil { + t.Error("Expected error for unparsable value") + } + + if err.Error() != `"not duration" is invalid for ReconnectInterval` { + t.Errorf("Expected %s, got %s", `"not duration" is invalid for ReconnectInterval`, err) + } + + s.Set(config.ReconnectInterval, "10s") + + got, err := s.DurationSetting(config.ReconnectInterval) + if err != nil { + t.Error("Unexpected err", err) + } + + expected, _ := time.ParseDuration("10s") + + if got != expected { + t.Errorf("Expected %v, got %v", expected, got) + } +} + +func TestSessionSettings_ByteSettings(t *testing.T) { + s := NewSessionSettings() + if _, err := s.RawSetting(config.SocketPrivateKeyBytes); err == nil { + t.Error("Expected error for unknown setting") + } + + s.SetRaw(config.SocketPrivateKeyBytes, []byte("pembytes")) + + got, err := s.RawSetting(config.SocketPrivateKeyBytes) + if err != nil { + t.Error("Unexpected err", err) + } + + if !bytes.Equal([]byte("pembytes"), got) { + t.Errorf("Expected %v, got %v", []byte("pembytes"), got) + } +} + func TestSessionSettings_Clone(t *testing.T) { s := NewSessionSettings() diff --git a/tls.go b/tls.go index b9978a727..335d8e5d7 100644 --- a/tls.go +++ b/tls.go @@ -18,18 +18,20 @@ package quickfix import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "os" "github.com/quickfixgo/quickfix/config" ) -func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) { +func loadTLSConfig(settings *SessionSettings) (*tls.Config, error) { + var err error allowSkipClientCerts := false if settings.HasSetting(config.SocketUseSSL) { allowSkipClientCerts, err = settings.BoolSetting(config.SocketUseSSL) if err != nil { - return + return nil, err } } @@ -37,7 +39,7 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if settings.HasSetting(config.SocketServerName) { serverName, err = settings.Setting(config.SocketServerName) if err != nil { - return + return nil, err } } @@ -45,17 +47,20 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if settings.HasSetting(config.SocketInsecureSkipVerify) { insecureSkipVerify, err = settings.BoolSetting(config.SocketInsecureSkipVerify) if err != nil { - return + return nil, err } } - if !settings.HasSetting(config.SocketPrivateKeyFile) && !settings.HasSetting(config.SocketCertificateFile) { + if !settings.HasSetting(config.SocketPrivateKeyFile) && + !settings.HasSetting(config.SocketCertificateFile) && + !settings.HasSetting(config.SocketPrivateKeyBytes) && + !settings.HasSetting(config.SocketCertificateBytes) { if !allowSkipClientCerts { - return + return nil, nil } } - tlsConfig = defaultTLSConfig() + tlsConfig := defaultTLSConfig() tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify setMinVersionExplicit(settings, tlsConfig) @@ -67,49 +72,80 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) privateKeyFile, err = settings.Setting(config.SocketPrivateKeyFile) if err != nil { - return + return nil, err } certificateFile, err = settings.Setting(config.SocketCertificateFile) if err != nil { - return + return nil, err } tlsConfig.Certificates = make([]tls.Certificate, 1) if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { - return + return nil, fmt.Errorf("failed to load key pair: %w", err) + } + } else if settings.HasSetting(config.SocketPrivateKeyBytes) || settings.HasSetting(config.SocketCertificateBytes) { + privateKeyBytes, err := settings.RawSetting(config.SocketPrivateKeyBytes) + if err != nil { + return nil, err + } + + certificateBytes, err := settings.RawSetting(config.SocketCertificateBytes) + if err != nil { + return nil, err } + + tlsConfig.Certificates = make([]tls.Certificate, 1) + + certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse key pair: %w", err) + } + + tlsConfig.Certificates[0] = certificate } if !allowSkipClientCerts { tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } - if !settings.HasSetting(config.SocketCAFile) { - return + if !settings.HasSetting(config.SocketCAFile) && !settings.HasSetting(config.SocketCABytes) { + return tlsConfig, nil } - caFile, err := settings.Setting(config.SocketCAFile) - if err != nil { - return - } + certPool := x509.NewCertPool() + if settings.HasSetting(config.SocketCAFile) { + caFile, err := settings.Setting(config.SocketCAFile) + if err != nil { + return nil, err + } - pem, err := os.ReadFile(caFile) - if err != nil { - return - } + pem, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA bundle: %w", err) + } - certPool := x509.NewCertPool() - if !certPool.AppendCertsFromPEM(pem) { - err = fmt.Errorf("Failed to parse %v", caFile) - return + if !certPool.AppendCertsFromPEM(pem) { + err = fmt.Errorf("failed to parse %v", caFile) + return nil, err + } + } else { + caBytes, err := settings.RawSetting(config.SocketCABytes) + if err != nil { + return nil, err + } + + if !certPool.AppendCertsFromPEM(caBytes) { + err = errors.New("failed to parse CA bundle from raw bytes") + return nil, err + } } tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool - return + return tlsConfig, nil } // defaultTLSConfig brought to you by https://github.com/gtank/cryptopasta/ diff --git a/tls_test.go b/tls_test.go index 0bae6cd54..ea9ea93a6 100644 --- a/tls_test.go +++ b/tls_test.go @@ -17,6 +17,7 @@ package quickfix import ( "crypto/tls" + "os" "testing" "github.com/stretchr/testify/suite" @@ -26,8 +27,9 @@ import ( type TLSTestSuite struct { suite.Suite - settings *Settings - PrivateKeyFile, CertificateFile, CAFile string + settings *Settings + PrivateKeyFile, CertificateFile, CAFile string + PrivateKeyBytes, CertificateBytes, CABytes []byte } func TestTLSTestSuite(t *testing.T) { @@ -39,6 +41,18 @@ func (s *TLSTestSuite) SetupTest() { s.PrivateKeyFile = "_test_data/localhost.key" s.CertificateFile = "_test_data/localhost.crt" s.CAFile = "_test_data/ca.crt" + + privateKeyBytes, err := os.ReadFile(s.PrivateKeyFile) + s.Require().NoError(err) + s.PrivateKeyBytes = privateKeyBytes + + certificateBytes, err := os.ReadFile(s.CertificateFile) + s.Require().NoError(err) + s.CertificateBytes = certificateBytes + + caBytes, err := os.ReadFile(s.CAFile) + s.Require().NoError(err) + s.CABytes = caBytes } func (s *TLSTestSuite) TestLoadTLSNoSettings() { @@ -51,11 +65,13 @@ func (s *TLSTestSuite) TestLoadTLSMissingKeyOrCert() { s.settings.GlobalSettings().Set(config.SocketPrivateKeyFile, s.PrivateKeyFile) _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketCertificateFile") s.SetupTest() s.settings.GlobalSettings().Set(config.SocketCertificateFile, s.CertificateFile) _, err = loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketPrivateKeyFile") } func (s *TLSTestSuite) TestLoadTLSInvalidKeyOrCert() { @@ -63,6 +79,7 @@ func (s *TLSTestSuite) TestLoadTLSInvalidKeyOrCert() { s.settings.GlobalSettings().Set(config.SocketCertificateFile, "foo") _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "failed to load key pair: open foo: no such file or directory") } func (s *TLSTestSuite) TestLoadTLSNoCA() { @@ -86,6 +103,7 @@ func (s *TLSTestSuite) TestLoadTLSWithBadCA() { _, err := loadTLSConfig(s.settings.GlobalSettings()) s.NotNil(err) + s.EqualError(err, "failed to read CA bundle: open bar: no such file or directory") } func (s *TLSTestSuite) TestLoadTLSWithCA() { @@ -223,3 +241,93 @@ func (s *TLSTestSuite) TestMinimumTLSVersion() { s.NotNil(tlsConfig) s.Equal(tlsConfig.MinVersion, uint16(tls.VersionTLS12)) } + +func (s *TLSTestSuite) TestLoadTLSBytesMissingKeyOrCert() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketCertificateBytes") + + s.SetupTest() + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + _, err = loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "Conditionally Required Setting: SocketPrivateKeyBytes") +} + +func (s *TLSTestSuite) TestLoadTLSBytesInvalidKeyOrCert() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, []byte("not a cert")) + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse key pair: tls: failed to find any PEM data in certificate input") + + s.SetupTest() + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, []byte("not a key")) + _, err = loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse key pair: tls: failed to find any PEM data in key input") +} + +func (s *TLSTestSuite) TestLoadTLSBytesNoCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NoError(err) + s.NotNil(tlsConfig) + + s.Len(tlsConfig.Certificates, 1) + s.Nil(tlsConfig.RootCAs) + s.Nil(tlsConfig.ClientCAs) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithBadCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, []byte("bar")) + + _, err := loadTLSConfig(s.settings.GlobalSettings()) + s.NotNil(err) + s.EqualError(err, "failed to parse CA bundle from raw bytes") +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithCA() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, s.CABytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.Len(tlsConfig.Certificates, 1) + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) +} + +func (s *TLSTestSuite) TestLoadTLSBytesWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + s.settings.GlobalSettings().SetRaw(config.SocketCABytes, s.CABytes) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) +} + +func (s *TLSTestSuite) TestServerNameWithCertsFromBytes() { + s.settings.GlobalSettings().SetRaw(config.SocketPrivateKeyBytes, s.PrivateKeyBytes) + s.settings.GlobalSettings().SetRaw(config.SocketCertificateBytes, s.CertificateBytes) + s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameWithCerts") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + s.Equal("DummyServerNameWithCerts", tlsConfig.ServerName) +}