diff --git a/apduWrapper.go b/apduWrapper.go index 3fc21fb..e9103fb 100644 --- a/apduWrapper.go +++ b/apduWrapper.go @@ -126,7 +126,7 @@ func WrapCommandAPDU( } // UnwrapResponseAPDU parses a response of 64 byte packets into the real data -func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]byte, error) { +func UnwrapResponseAPDU(channel uint16, pipe <-chan []byte, packetSize int) ([]byte, error) { var sequenceIdx uint16 var totalResult []byte @@ -135,7 +135,7 @@ func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([] for !done { // Read next packet from the channel - buffer := <- pipe + buffer := <-pipe result, responseSize, err := DeserializePacket(channel, buffer, sequenceIdx) if err != nil { @@ -157,4 +157,4 @@ func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([] // Remove trailing zeros totalResult = totalResult[:totalSize] return totalResult, nil -} \ No newline at end of file +} diff --git a/apduWrapper_test.go b/apduWrapper_test.go index 9f82bd3..29d2b1c 100644 --- a/apduWrapper_test.go +++ b/apduWrapper_test.go @@ -26,7 +26,7 @@ import ( ) func Test_SerializePacket_EmptyCommand(t *testing.T) { - var command= make([]byte, 1) + var command = make([]byte, 1) _, _, err := SerializePacket(0x0101, command, 64, 0) assert.Nil(t, err, "Commands smaller than 3 bytes should return error") @@ -42,9 +42,9 @@ func Test_SerializePacket_PacketSize(t *testing.T) { commandLen uint16 } - h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 32} + h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 32} - var command= make([]byte, h.commandLen) + var command = make([]byte, h.commandLen) result, _, _ := SerializePacket( h.channel, @@ -65,9 +65,9 @@ func Test_SerializePacket_Header(t *testing.T) { commandLen uint16 } - h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 32} + h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 32} - var command= make([]byte, h.commandLen) + var command = make([]byte, h.commandLen) result, _, _ := SerializePacket( h.channel, @@ -91,9 +91,9 @@ func Test_SerializePacket_Offset(t *testing.T) { commandLen uint16 } - h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100} + h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100} - var command= make([]byte, h.commandLen) + var command = make([]byte, h.commandLen) _, offset, _ := SerializePacket( h.channel, @@ -101,7 +101,7 @@ func Test_SerializePacket_Offset(t *testing.T) { packetSize, h.sequenceIdx) - assert.Equal(t, packetSize - int(unsafe.Sizeof(h))+1, offset, "Wrong offset returned. Offset must point to the next comamnd byte that needs to be packet-ized.") + assert.Equal(t, packetSize-int(unsafe.Sizeof(h))+1, offset, "Wrong offset returned. Offset must point to the next comamnd byte that needs to be packet-ized.") } func Test_WrapCommandAPDU_NumberOfPackets(t *testing.T) { @@ -119,9 +119,9 @@ func Test_WrapCommandAPDU_NumberOfPackets(t *testing.T) { tag uint8 } - h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100} + h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100} - var command= make([]byte, h1.commandLen) + var command = make([]byte, h1.commandLen) result, _ := WrapCommandAPDU( h1.channel, @@ -146,9 +146,9 @@ func Test_WrapCommandAPDU_CheckHeaders(t *testing.T) { tag uint8 } - h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100} + h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100} - var command= make([]byte, h1.commandLen) + var command = make([]byte, h1.commandLen) result, _ := WrapCommandAPDU( h1.channel, @@ -181,9 +181,9 @@ func Test_WrapCommandAPDU_CheckData(t *testing.T) { tag uint8 } - h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 200} + h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 200} - var command= make([]byte, h1.commandLen) + var command = make([]byte, h1.commandLen) for i := range command { command[i] = byte(i % 256) @@ -228,9 +228,9 @@ func Test_DeserializePacket_FirstPacket(t *testing.T) { output, totalSize, err := DeserializePacket(0x0101, packet, 0) - assert.Nil(t,err, "Simple deserialize should not have errors") + assert.Nil(t, err, "Simple deserialize should not have errors") assert.Equal(t, len(sampleCommand), int(totalSize), "TotalSize is incorrect") - assert.Equal(t, packetSize - firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong") + assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong") assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original") } @@ -243,9 +243,9 @@ func Test_DeserializePacket_SecondMessage(t *testing.T) { output, totalSize, err := DeserializePacket(0x0101, packet, 1) - assert.Nil(t,err, "Simple deserialize should not have errors") + assert.Nil(t, err, "Simple deserialize should not have errors") assert.Equal(t, 0, int(totalSize), "TotalSize should not be returned from deserialization of non-first packet") - assert.Equal(t, packetSize - firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong") + assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong") assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original") } @@ -256,7 +256,7 @@ func Test_UnwrapApdu_SmokeTest(t *testing.T) { var packetSize int = 64 // Initialize some dummy input - var input= make([]byte, inputSize) + var input = make([]byte, inputSize) for i := range input { input[i] = byte(i % 256) } @@ -264,7 +264,7 @@ func Test_UnwrapApdu_SmokeTest(t *testing.T) { serialized, _ := WrapCommandAPDU(channel, input, packetSize) // Allocate enough buffers to keep all the packets - pipe := make(chan []byte, int(math.Ceil(float64(inputSize) / float64(packetSize)))) + pipe := make(chan []byte, int(math.Ceil(float64(inputSize)/float64(packetSize)))) // Send all the packets to the pipe for len(serialized) > 0 { pipe <- serialized[:packetSize] diff --git a/ledger.go b/ledger.go index 0642993..c3073ec 100644 --- a/ledger.go +++ b/ledger.go @@ -19,8 +19,9 @@ package ledger_go import ( "errors" "fmt" - "github.com/zondax/hid" "sync" + + "github.com/zondax/hid" ) const ( @@ -34,7 +35,7 @@ const ( type Ledger struct { device hid.Device readCo sync.Once - readChannel chan [] byte + readChannel chan []byte Logging bool } @@ -70,23 +71,17 @@ func FindLedger() (*Ledger, error) { devices := hid.Enumerate(VendorLedger, 0) for _, d := range devices { - if d.VendorID == VendorLedger && d.UsagePage == UsagePageLedger { - device, err := d.Open() - if err != nil { - return nil, err - } - return NewLedger(device), nil - } + deviceFound := d.UsagePage == UsagePageLedger + deviceFound = deviceFound || (d.Product == "Nano S" && d.Interface == 0) - // Linux discovery - if d.VendorID == VendorLedger && d.Product == "Nano S" && d.Interface == 0 { + if deviceFound { device, err := d.Open() - if err != nil { - return nil, err + if err == nil { + return NewLedger(device), nil } - return NewLedger(device), nil } } + return nil, errors.New("no ledger connected") } @@ -126,6 +121,10 @@ func ErrorMessage(errorCode uint16) string { } } +func (ledger *Ledger) Close() error { + return ledger.device.Close() +} + func (ledger *Ledger) Write(buffer []byte) (int, error) { totalBytes := len(buffer) totalWrittenBytes := 0 @@ -150,7 +149,7 @@ func (ledger *Ledger) Read() <-chan []byte { return ledger.readChannel } -func (ledger *Ledger) initReadChannel(){ +func (ledger *Ledger) initReadChannel() { ledger.readChannel = make(chan []byte, 30) go ledger.readThread() } diff --git a/ledger_test.go b/ledger_test.go index 298d7b8..01b83ba 100644 --- a/ledger_test.go +++ b/ledger_test.go @@ -21,8 +21,8 @@ package ledger_go import ( "encoding/hex" "fmt" - "github.com/zondax/hid" "github.com/stretchr/testify/assert" + "github.com/zondax/hid" "testing" ) @@ -41,7 +41,7 @@ func Test_FindLedger(t *testing.T) { fmt.Println("\n*********************************") fmt.Println("Did you enter the password??") fmt.Println("*********************************") - t.Fatalf( "Error: %s", err.Error()) + t.Fatalf("Error: %s", err.Error()) } assert.NotNil(t, ledger) } @@ -52,7 +52,7 @@ func Test_BasicExchange(t *testing.T) { fmt.Println("\n*********************************") fmt.Println("Did you enter the password??") fmt.Println("*********************************") - t.Fatalf( "Error: %s", err.Error()) + t.Fatalf("Error: %s", err.Error()) } assert.NotNil(t, ledger) @@ -63,7 +63,7 @@ func Test_BasicExchange(t *testing.T) { if err != nil { fmt.Printf("iteration %d\n", i) - t.Fatalf( "Error: %s", err.Error()) + t.Fatalf("Error: %s", err.Error()) } assert.Equal(t, 4, len(response)) @@ -76,23 +76,23 @@ func Test_LongExchange(t *testing.T) { fmt.Println("\n*********************************") fmt.Println("Did you enter the password??") fmt.Println("*********************************") - t.Fatalf( "Error: %s", err.Error()) + t.Fatalf("Error: %s", err.Error()) } assert.NotNil(t, ledger) - path := "052c000080760000800000008000000000000000000000000000000000000000000000000000000000"; + path := "052c000080760000800000008000000000000000000000000000000000000000000000000000000000" pathBytes, err := hex.DecodeString(path) if err != nil { t.Fatalf("invalid path in test") } - header := []byte { 0x55, 1, 0, 0, byte(len(pathBytes))} + header := []byte{0x55, 1, 0, 0, byte(len(pathBytes))} message := append(header, pathBytes...) response, err := ledger.Exchange(message) if err != nil { - t.Fatalf( "Error: %s", err.Error()) + t.Fatalf("Error: %s", err.Error()) } assert.Equal(t, 65, len(response))