Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
Merge pull request #95 from maralla/msg_corrupt
Browse files Browse the repository at this point in the history
fix read_struct
  • Loading branch information
lxyu committed Mar 8, 2015
2 parents 50fa59f + e363d07 commit 8c61f7d
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 10 deletions.
114 changes: 114 additions & 0 deletions tests/test_type_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from unittest import TestCase

from thriftpy.thrift import TType, TPayload

from thriftpy.transport.memory import TMemoryBuffer
from thriftpy.protocol.binary import TBinaryProtocol

from thriftpy._compat import CYTHON


class Struct(TPayload):
thrift_spec = {
1: (TType.I32, 'a', False),
2: (TType.STRING, 'b', False),
3: (TType.DOUBLE, 'c', False)
}
default_spec = [('a', None), ('b', None), ('c', None)]


class TItem(TPayload):
thrift_spec = {
1: (TType.I32, "id", False),
2: (TType.LIST, "phones", TType.STRING, False),
3: (TType.MAP, "addr", (TType.I32, TType.STRING), False),
4: (TType.LIST, "data", (TType.STRUCT, Struct), False)
}
default_spec = [("id", None), ("phones", None), ("addr", None),
("data", None)]


class MismatchTestCase(TestCase):
BUFFER = TMemoryBuffer
PROTO = TBinaryProtocol

def test_list_type_mismatch(self):
class TMismatchItem(TPayload):
thrift_spec = {
1: (TType.I32, "id", False),
2: (TType.LIST, "phones", (TType.I32, False), False),
}
default_spec = [("id", None), ("phones", None)]

t = self.BUFFER()
p = self.PROTO(t)

item = TItem(id=37, phones=["23424", "235125"])
p.write_struct(item)
p.write_message_end()

item2 = TMismatchItem()
p.read_struct(item2)

assert item2.phones == []

def test_map_type_mismatch(self):
class TMismatchItem(TPayload):
thrift_spec = {
1: (TType.I32, "id", False),
3: (TType.MAP, "addr", (TType.STRING, TType.STRING), False)
}
default_spec = [("id", None), ("addr", None)]

t = self.BUFFER()
p = self.PROTO(t)

item = TItem(id=37, addr={1: "hello", 2: "world"})
p.write_struct(item)
p.write_message_end()

item2 = TMismatchItem()
p.read_struct(item2)

assert item2.addr == {}

def test_struct_mismatch(self):
class MismatchStruct(TPayload):
thrift_spec = {
1: (TType.STRING, 'a', False),
2: (TType.STRING, 'b', False)
}
default_spec = [('a', None), ('b', None)]

class TMismatchItem(TPayload):
thrift_spec = {
1: (TType.I32, "id", False),
2: (TType.LIST, "phones", TType.STRING, False),
3: (TType.MAP, "addr", (TType.I32, TType.STRING), False),
4: (TType.LIST, "data", (TType.STRUCT, MismatchStruct), False)
}
default_spec = [("id", None), ("phones", None), ("addr", None)]

t = self.BUFFER()
p = self.PROTO(t)

item = TItem(id=37, data=[Struct(a=1, b="hello", c=0.123),
Struct(a=2, b="world", c=34.342346),
Struct(a=3, b="when", c=25235.14)])
p.write_struct(item)
p.write_message_end()

item2 = TMismatchItem()
p.read_struct(item2)

assert len(item2.data) == 3
assert all([i.b for i in item2.data])


if CYTHON:
from thriftpy.transport.memory import TCyMemoryBuffer
from thriftpy.protocol.cybin import TCyBinaryProtocol

class CyMismatchTestCase(MismatchTestCase):
BUFFER = TCyMemoryBuffer
PROTO = TCyBinaryProtocol
15 changes: 10 additions & 5 deletions thriftpy/protocol/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ def read_val(inbuf, ttype, spec=None):
r_type, sz = read_list_begin(inbuf)
# the v_type is useless here since we already get it from spec
if r_type != v_type:
raise Exception("Message Corrupt")
for _ in range(sz):
skip(inbuf, r_type)
return []

for i in range(sz):
result.append(read_val(inbuf, v_type, v_spec))
Expand All @@ -265,7 +267,10 @@ def read_val(inbuf, ttype, spec=None):
result = {}
sk_type, sv_type, sz = read_map_begin(inbuf)
if sk_type != k_type or sv_type != v_type:
raise Exception("Message Corrupt")
for _ in range(sz):
skip(inbuf, sk_type)
skip(inbuf, sv_type)
return {}

for i in range(sz):
k_val = read_val(inbuf, k_type, k_spec)
Expand All @@ -281,8 +286,7 @@ def read_val(inbuf, ttype, spec=None):


def read_struct(inbuf, obj):
# The max loop count equals field count + a final stop byte.
for i in range(len(obj.thrift_spec) + 1):
while True:
f_type, fid = read_field_begin(inbuf)
if f_type == TType.STOP:
break
Expand All @@ -300,7 +304,8 @@ def read_struct(inbuf, obj):
# it really should equal here. but since we already wasted
# space storing the duplicate info, let's check it.
if f_type != sf_type:
raise Exception("Message Corrupt")
skip(inbuf, f_type)
continue

setattr(obj, f_name, read_val(inbuf, f_type, f_container_spec))

Expand Down
16 changes: 11 additions & 5 deletions thriftpy/protocol/cybin/cybin.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ cdef inline int write_double(CyTransportBase buf, double val) except -1:

cdef inline read_struct(CyTransportBase buf, obj):
cdef dict field_specs = obj.thrift_spec
cdef int fid, i
cdef int fid
cdef TType field_type, ttype
cdef tuple field_spec
cdef str name

for i in range(len(field_specs) + 1):
while True:
field_type = <TType>read_i08(buf)
if field_type == T_STOP:
break
Expand All @@ -114,7 +114,8 @@ cdef inline read_struct(CyTransportBase buf, obj):
field_spec = field_specs[fid]
ttype = field_spec[0]
if field_type != ttype:
raise ProtocolError("Message Corrupt")
skip(buf, field_type)
continue

name = field_spec[1]
if len(field_spec) == 2:
Expand Down Expand Up @@ -211,7 +212,9 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):
size = read_i32(buf)

if orig_type != v_type:
raise ProtocolError("Message Corrupt")
for _ in range(size):
skip(buf, orig_type)
return []

return [c_read_val(buf, v_type, v_spec) for _ in range(size)]

Expand All @@ -237,7 +240,10 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):
size = read_i32(buf)

if orig_key_type != k_type or orig_type != v_type:
raise ProtocolError("Message Corrupt")
for _ in range(size):
skip(buf, orig_key_type)
skip(buf, orig_type)
return {}

return {c_read_val(buf, k_type, k_spec): c_read_val(buf, v_type, v_spec)
for _ in range(size)}
Expand Down

0 comments on commit 8c61f7d

Please sign in to comment.