Skip to content

Commit

Permalink
Merge branch 'quickfixgo:main' into dynamic_session_2
Browse files Browse the repository at this point in the history
  • Loading branch information
haoyang1994 authored Aug 14, 2024
2 parents 80a93d4 + 9191a58 commit a5f5bfd
Show file tree
Hide file tree
Showing 11 changed files with 459 additions and 81 deletions.
41 changes: 41 additions & 0 deletions config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,47 @@ const (
// - A filepath to a file with read access.
SocketCAFile string = "SocketCAFile"

// SocketPrivateKeyBytes is an optional value containing raw bytes of a PEM
// encoded private key to use for secure TLS communications.
// Must be used with SocketCertificateBytes.
// Must contain PEM encoded data.
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - Raw bytes containing a valid PEM encoded private key.
SocketPrivateKeyBytes string = "SocketPrivateKeyBytes"

// SocketCertificateBytes is an optional value containing raw bytes of a PEM
// encoded certificate to use for secure TLS communications.
// Must be used with SocketPrivateKeyBytes.
// Must contain PEM encoded data.
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - Raw bytes containing a valid PEM encoded certificate.
SocketCertificateBytes string = "SocketCertificateBytes"

// SocketCABytes is an optional value containing raw bytes of a PEM encoded
// root CA to use for secure TLS communications. For acceptors, client
// certificates will be verified against this CA. For initiators, clients
// will use the CA to verify the server certificate. If not configured,
// initiators will verify the server certificates using the host's root CA
// set.
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - Raw bytes containing a valid PEM encoded CA.
SocketCABytes string = "SocketCABytes"

// SocketInsecureSkipVerify controls whether a client verifies the server's certificate chain and host name.
// If SocketInsecureSkipVerify is set to Y, crypto/tls accepts any certificate presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to machine-in-the-middle attacks unless custom verification is used.
Expand Down
18 changes: 16 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"github.com/quickfixgo/quickfix/config"
)

func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) {
func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) {
stdDialer := &net.Dialer{}
if settings.HasSetting(config.SocketTimeout) {
timeout, err := settings.DurationSetting(config.SocketTimeout)
Expand Down Expand Up @@ -73,9 +73,23 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error
}
}

dialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, dialer)
var proxyDialer proxy.Dialer

proxyDialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, stdDialer)
if err != nil {
return
}

if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok {
dialer = contextDialer
} else {
err = fmt.Errorf("proxy does not support context dialer")
return
}

default:
err = fmt.Errorf("unsupported proxy type %s", proxyType)
}

return
}
58 changes: 55 additions & 3 deletions field_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ func (m FieldMap) GetField(tag Tag, parser FieldValueReader) MessageRejectError
return nil
}

// GetField parses of a field with Tag tag. Returned reject may indicate the field is not present, or the field value is invalid.
func (m FieldMap) getFieldNoLock(tag Tag, parser FieldValueReader) MessageRejectError {
f, ok := m.tagLookup[tag]
if !ok {
return ConditionallyRequiredFieldMissing(tag)
}

if err := parser.Read(f[0].value); err != nil {
return IncorrectDataFormatForValue(tag)
}

return nil
}

// GetBytes is a zero-copy GetField wrapper for []bytes fields.
func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) {
m.rwLock.RLock()
Expand All @@ -128,6 +142,16 @@ func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) {
return f[0].value, nil
}

// getBytesNoLock is a lock free zero-copy GetField wrapper for []bytes fields.
func (m FieldMap) getBytesNoLock(tag Tag) ([]byte, MessageRejectError) {
f, ok := m.tagLookup[tag]
if !ok {
return nil, ConditionallyRequiredFieldMissing(tag)
}

return f[0].value, nil
}

// GetBool is a GetField wrapper for bool fields.
func (m FieldMap) GetBool(tag Tag) (bool, MessageRejectError) {
var val FIXBoolean
Expand All @@ -152,6 +176,21 @@ func (m FieldMap) GetInt(tag Tag) (int, MessageRejectError) {
return int(val), err
}

// GetInt is a lock free GetField wrapper for int fields.
func (m FieldMap) getIntNoLock(tag Tag) (int, MessageRejectError) {
bytes, err := m.getBytesNoLock(tag)
if err != nil {
return 0, err
}

var val FIXInt
if val.Read(bytes) != nil {
err = IncorrectDataFormatForValue(tag)
}

return int(val), err
}

// GetTime is a GetField wrapper for utc timestamp fields.
func (m FieldMap) GetTime(tag Tag) (t time.Time, err MessageRejectError) {
m.rwLock.RLock()
Expand Down Expand Up @@ -179,6 +218,15 @@ func (m FieldMap) GetString(tag Tag) (string, MessageRejectError) {
return string(val), nil
}

// GetString is a GetField wrapper for string fields.
func (m FieldMap) getStringNoLock(tag Tag) (string, MessageRejectError) {
var val FIXString
if err := m.getFieldNoLock(tag, &val); err != nil {
return "", err
}
return string(val), nil
}

// GetGroup is a Get function specific to Group Fields.
func (m FieldMap) GetGroup(parser FieldGroupReader) MessageRejectError {
m.rwLock.RLock()
Expand Down Expand Up @@ -246,6 +294,13 @@ func (m *FieldMap) Clear() {
}
}

func (m *FieldMap) clearNoLock() {
m.tags = m.tags[0:0]
for k := range m.tagLookup {
delete(m.tagLookup, k)
}
}

// CopyInto overwrites the given FieldMap with this one.
func (m *FieldMap) CopyInto(to *FieldMap) {
m.rwLock.RLock()
Expand All @@ -263,9 +318,6 @@ func (m *FieldMap) CopyInto(to *FieldMap) {
}

func (m *FieldMap) add(f field) {
m.rwLock.Lock()
defer m.rwLock.Unlock()

t := fieldTag(f)
if _, ok := m.tagLookup[t]; !ok {
m.tags = append(m.tags, t)
Expand Down
26 changes: 23 additions & 3 deletions initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package quickfix

import (
"bufio"
"context"
"crypto/tls"
"strings"
"sync"
Expand Down Expand Up @@ -50,7 +51,7 @@ func (i *Initiator) Start() (err error) {
return
}

var dialer proxy.Dialer
var dialer proxy.ContextDialer
if dialer, err = loadDialerConfig(settings); err != nil {
return
}
Expand Down Expand Up @@ -142,7 +143,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo
return true
}

func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.Dialer) {
func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand All @@ -162,14 +163,27 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
return
}

ctx, cancel := context.WithCancel(context.Background())

// We start a goroutine in order to be able to cancel the dialer mid-connection
// on receiving a stop signal to stop the initiator.
go func() {
select {
case <-i.stopChan:
cancel()
case <-ctx.Done():
return
}
}()

var disconnected chan interface{}
var msgIn chan fixIn
var msgOut chan []byte

address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

netConn, err := dialer.Dial("tcp", address)
netConn, err := dialer.DialContext(ctx, "tcp", address)
if err != nil {
session.log.OnEventf("Failed to connect: %v", err)
goto reconnect
Expand Down Expand Up @@ -208,13 +222,19 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
close(disconnected)
}()

// This ensures we properly cleanup the goroutine and context used for
// dial cancelation after successful connection.
cancel()

select {
case <-disconnected:
case <-i.stopChan:
return
}

reconnect:
cancel()

connectionAttempt++
session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval)
if !i.waitForReconnectInterval(session.ReconnectInterval) {
Expand Down
32 changes: 19 additions & 13 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,20 @@ func ParseMessageWithDataDictionary(

// doParsing executes the message parsing process.
func doParsing(mp *msgParser) (err error) {
mp.msg.Header.rwLock.Lock()
defer mp.msg.Header.rwLock.Unlock()
mp.msg.Body.rwLock.Lock()
defer mp.msg.Body.rwLock.Unlock()
mp.msg.Trailer.rwLock.Lock()
defer mp.msg.Trailer.rwLock.Unlock()

// Initialize for parsing.
mp.msg.Header.Clear()
mp.msg.Body.Clear()
mp.msg.Trailer.Clear()
mp.msg.Header.clearNoLock()
mp.msg.Body.clearNoLock()
mp.msg.Trailer.clearNoLock()

// Allocate expected message fields in one chunk.
fieldCount := 0
for _, b := range mp.rawBytes {
if b == '\001' {
fieldCount++
}
}
fieldCount := bytes.Count(mp.rawBytes, []byte{'\001'})
if fieldCount == 0 {
return parseError{OrigError: fmt.Sprintf("No Fields detected in %s", string(mp.rawBytes))}
}
Expand Down Expand Up @@ -267,7 +269,7 @@ func doParsing(mp *msgParser) (err error) {
}

if mp.parsedFieldBytes.tag == tagXMLDataLen {
xmlDataLen, _ = mp.msg.Header.GetInt(tagXMLDataLen)
xmlDataLen, _ = mp.msg.Header.getIntNoLock(tagXMLDataLen)
}
mp.fieldIndex++
}
Expand All @@ -292,7 +294,7 @@ func doParsing(mp *msgParser) (err error) {
}
}

bodyLength, err := mp.msg.Header.GetInt(tagBodyLength)
bodyLength, err := mp.msg.Header.getIntNoLock(tagBodyLength)
if err != nil {
err = parseError{OrigError: err.Error()}
} else if length != bodyLength && !xmlDataMsg {
Expand Down Expand Up @@ -373,7 +375,7 @@ func parseGroup(mp *msgParser, tags []Tag) {
// tags slice will contain multiple tags if the tag in question is found while processing a group already.
func isNumInGroupField(msg *Message, tags []Tag, appDataDictionary *datadictionary.DataDictionary) bool {
if appDataDictionary != nil {
msgt, err := msg.MsgType()
msgt, err := msg.msgTypeNoLock()
if err != nil {
return false
}
Expand Down Expand Up @@ -406,7 +408,7 @@ func isNumInGroupField(msg *Message, tags []Tag, appDataDictionary *datadictiona
// tags slice will contain multiple tags if the tag in question is found while processing a group already.
func getGroupFields(msg *Message, tags []Tag, appDataDictionary *datadictionary.DataDictionary) (fields []*datadictionary.FieldDef) {
if appDataDictionary != nil {
msgt, err := msg.MsgType()
msgt, err := msg.msgTypeNoLock()
if err != nil {
return
}
Expand Down Expand Up @@ -476,6 +478,10 @@ func (m *Message) MsgType() (string, MessageRejectError) {
return m.Header.GetString(tagMsgType)
}

func (m *Message) msgTypeNoLock() (string, MessageRejectError) {
return m.Header.getStringNoLock(tagMsgType)
}

// IsMsgTypeOf returns true if the Header contains MsgType (tag 35) field and its value is the specified one.
func (m *Message) IsMsgTypeOf(msgType string) bool {
if v, err := m.MsgType(); err == nil {
Expand Down
6 changes: 3 additions & 3 deletions session_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func (f sessionFactory) newSession(
for _, dayStr := range dayStrs {
day, ok := dayLookup[dayStr]
if !ok {
err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: weekdaysStr}
err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: []byte(weekdaysStr)}
return
}
weekdays = append(weekdays, day)
Expand Down Expand Up @@ -315,7 +315,7 @@ func (f sessionFactory) newSession(
parseDay := func(setting, dayStr string) (day time.Weekday, err error) {
day, ok := dayLookup[dayStr]
if !ok {
return day, IncorrectFormatForSetting{Setting: setting, Value: dayStr}
return day, IncorrectFormatForSetting{Setting: setting, Value: []byte(dayStr)}
}
return
}
Expand Down Expand Up @@ -355,7 +355,7 @@ func (f sessionFactory) newSession(
s.timestampPrecision = Nanos

default:
err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: precisionStr}
err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: []byte(precisionStr)}
return
}
}
Expand Down
Loading

0 comments on commit a5f5bfd

Please sign in to comment.