Skip to content

Commit

Permalink
update writer trait to use span
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Jul 9, 2024
1 parent 577c302 commit 87430e6
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 123 deletions.
17 changes: 1 addition & 16 deletions gojo/bufio/bufio.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ struct Writer[W: io.Writer, size: Int = io.BUFFER_SIZE](
"""
return self.bytes_written

fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
"""Writes the contents of src into the buffer.
It returns the number of bytes written.
If nn < len(src), it also returns an error explaining
Expand Down Expand Up @@ -817,21 +817,6 @@ struct Writer[W: io.Writer, size: Int = io.BUFFER_SIZE](
total_bytes_written += n
return total_bytes_written, err

fn write(inout self, src: List[UInt8]) -> (Int, Error):
"""
Appends a byte List to the builder buffer.
Args:
src: The byte array to append.
"""
var span = Span(src)

var bytes_read: Int
var err: Error
bytes_read, err = self._write(span)

return bytes_read, err

fn write_byte(inout self, src: UInt8) -> (Int, Error):
"""Writes a single byte to the internal buffer.
Expand Down
31 changes: 8 additions & 23 deletions gojo/bytes/buffer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ alias MIN_READ: Int = 512
# ERR_TOO_LARGE is passed to panic if memory cannot be allocated to store data in a buffer.
alias ERR_TOO_LARGE = "buffer.Buffer: too large"
alias ERR_NEGATIVE_READ = "buffer.Buffer: reader returned negative count from read"
alias ERR_SHORT_WRITE = "short write"
alias ERR_SHORTwrite = "short write"


struct Buffer(
Expand Down Expand Up @@ -159,7 +159,7 @@ struct Buffer(
"""
return self.as_string_slice()

fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
"""
Appends a byte Span to the builder buffer.
Expand All @@ -173,29 +173,14 @@ struct Buffer(

return len(src), Error()

fn write(inout self, src: List[UInt8]) -> (Int, Error):
"""
Appends a byte List to the builder buffer.
Args:
src: The byte array to append.
"""
var span = Span(src)

var bytes_read: Int
var err: Error
bytes_read, err = self._write(span)

return bytes_read, err

fn write_string(inout self, src: String) -> (Int, Error):
"""
Appends a string to the builder buffer.
Args:
src: The string to append.
"""
return self._write(src.as_bytes_slice())
return self.write(src.as_bytes_slice())

fn write_byte(inout self, byte: UInt8) -> (Int, Error):
"""Appends the byte c to the buffer, growing the buffer as needed.
Expand Down Expand Up @@ -421,14 +406,14 @@ struct Buffer(
# The number of bytes written to the writer.
# """
# self.last_read = OP_INVALID
# var bytes_to_write = len(self)
# var bytes_towrite = len(self)
# var total_bytes_written: Int = 0

# if bytes_to_write > 0:
# if bytes_towrite > 0:
# var bytes_written: Int
# var err: Error
# bytes_written, err = writer.write(self.as_bytes_slice()[self.offset :])
# if bytes_written > bytes_to_write:
# if bytes_written > bytes_towrite:
# panic("bytes.Buffer.write_to: invalid write count")

# self.offset += bytes_written
Expand All @@ -437,8 +422,8 @@ struct Buffer(
# return total_bytes_written, err

# # all bytes should have been written, by definition of write method in io.Writer
# if bytes_written != bytes_to_write:
# return total_bytes_written, Error(ERR_SHORT_WRITE)
# if bytes_written != bytes_towrite:
# return total_bytes_written, Error(ERR_SHORTwrite)

# # Buffer is now empty; reset.
# self.reset()
Expand Down
5 changes: 1 addition & 4 deletions gojo/io/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ trait Writer(Movable):
Implementations must not retain p.
"""

fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
...

fn write(inout self, src: List[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8, _]) -> (Int, Error):
...


Expand Down
5 changes: 1 addition & 4 deletions gojo/io/file.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct FileWrapper(io.ReadWriteCloser, io.ByteReader):
except e:
return 0, e

fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
if len(src) == 0:
return 0, Error("No data to write")

Expand All @@ -134,6 +134,3 @@ struct FileWrapper(io.ReadWriteCloser, io.ByteReader):
return len(src), io.EOF
except e:
return 0, Error(str(e))

fn write(inout self, src: List[UInt8]) -> (Int, Error):
return self._write(Span(src))
2 changes: 1 addition & 1 deletion gojo/io/io.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn write_string[W: Writer](inout writer: W, string: String) -> (Int, Error):
Returns:
The number of bytes written and an error, if any.
"""
return writer.write(string.as_bytes())
return writer.write(string.as_bytes_slice())


fn write_string[W: StringWriter](inout writer: W, string: String) -> (Int, Error):
Expand Down
22 changes: 3 additions & 19 deletions gojo/io/std.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ import ..io
struct STDWriter[file_descriptor: Int](Copyable, io.Writer, io.StringWriter):
"""A writer for POSIX file descriptors."""

@always_inline
fn __init__(inout self):
constrained[
file_descriptor == 1 or file_descriptor == 2,
"The STDWriter Struct is meant to write to STDOUT and STDERR. file_descriptor must be 1 or 2.",
]()

@always_inline
fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
"""Writes the given bytes to the file descriptor.
Args:
Expand All @@ -31,19 +29,6 @@ struct STDWriter[file_descriptor: Int](Copyable, io.Writer, io.StringWriter):

return write_count, Error()

@always_inline
fn write(inout self, src: List[UInt8]) -> (Int, Error):
"""Writes the given bytes to the file descriptor.
Args:
src: The bytes to write to the file descriptor.
Returns:
The number of bytes written to the file descriptor.
"""
return self._write(Span(src))

@always_inline
fn write_string(inout self, src: String) -> (Int, Error):
"""Writes the given string to the file descriptor.
Expand All @@ -53,9 +38,8 @@ struct STDWriter[file_descriptor: Int](Copyable, io.Writer, io.StringWriter):
Returns:
The number of bytes written to the file descriptor.
"""
return self._write(src.as_bytes_slice())
return self.write(src.as_bytes_slice())

@always_inline
fn read_from[R: io.Reader](inout self, inout reader: R) -> (Int, Error):
"""Reads from the given reader to a temporary buffer and writes to the file descriptor.
Expand All @@ -67,4 +51,4 @@ struct STDWriter[file_descriptor: Int](Copyable, io.Writer, io.StringWriter):
"""
var buffer = List[UInt8](capacity=io.BUFFER_SIZE)
_ = reader.read(buffer)
return self._write(Span(buffer))
return self.write(Span(buffer))
6 changes: 1 addition & 5 deletions gojo/net/fd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,7 @@ struct FileDescriptor(FileDescriptorBase):

return bytes_read, err

fn write(inout self, src: List[UInt8]) -> (Int, Error):
"""Write data from the buffer to the file descriptor."""
return self._write(Span(src))

fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
"""Write data from the buffer to the file descriptor."""
var bytes_sent = send(self.fd, src.unsafe_ptr(), len(src), 0)
if bytes_sent == -1:
Expand Down
13 changes: 1 addition & 12 deletions gojo/net/socket.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -350,18 +350,7 @@ struct Socket(FileDescriptorBase):
self.remote_address = BaseAddr(remote.host, remote.port)
return Error()

fn _write(inout self: Self, src: Span[UInt8]) -> (Int, Error):
"""Send data to the socket. The socket must be connected to a remote socket.
Args:
src: The data to send.
Returns:
The number of bytes sent.
"""
return self.fd._write(src)

fn write(inout self: Self, src: List[UInt8]) -> (Int, Error):
fn write(inout self: Self, src: Span[UInt8]) -> (Int, Error):
"""Send data to the socket. The socket must be connected to a remote socket.
Args:
Expand Down
22 changes: 1 addition & 21 deletions gojo/net/tcp.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,12 @@ struct TCPConnection(Movable):

var socket: Socket

@always_inline
fn __init__(inout self, owned socket: Socket):
self.socket = socket^

@always_inline
fn __moveinit__(inout self, owned existing: Self):
self.socket = existing.socket^

@always_inline
fn _read(inout self, inout dest: Span[UInt8], capacity: Int) -> (Int, Error):
"""Reads data from the underlying file descriptor.
Expand All @@ -74,7 +71,6 @@ struct TCPConnection(Movable):

return bytes_read, err

@always_inline
fn read(inout self, inout dest: List[UInt8]) -> (Int, Error):
"""Reads data from the underlying file descriptor.
Expand All @@ -93,20 +89,7 @@ struct TCPConnection(Movable):

return bytes_read, err

@always_inline
fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
"""Writes data to the underlying file descriptor.
Args:
src: The buffer to read data into.
Returns:
The number of bytes written, or an error if one occurred.
"""
return self.socket._write(src)

@always_inline
fn write(inout self, src: List[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
"""Writes data to the underlying file descriptor.
Args:
Expand All @@ -117,7 +100,6 @@ struct TCPConnection(Movable):
"""
return self.socket.write(src)

@always_inline
fn close(inout self) -> Error:
"""Closes the underlying file descriptor.
Expand All @@ -126,7 +108,6 @@ struct TCPConnection(Movable):
"""
return self.socket.close()

@always_inline
fn local_address(self) -> TCPAddr:
"""Returns the local network address.
The Addr returned is shared by all invocations of local_address, so do not modify it.
Expand All @@ -136,7 +117,6 @@ struct TCPConnection(Movable):
"""
return self.socket.local_address_as_tcp()

@always_inline
fn remote_address(self) -> TCPAddr:
"""Returns the remote network address.
The Addr returned is shared by all invocations of remote_address, so do not modify it.
Expand Down
19 changes: 2 additions & 17 deletions gojo/strings/builder.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct StringBuilder[growth_factor: Float32 = 2](
new_capacity = self._capacity + bytes_to_add
self._resize(new_capacity)

fn _write(inout self, src: Span[UInt8]) -> (Int, Error):
fn write(inout self, src: Span[UInt8]) -> (Int, Error):
"""
Appends a byte Span to the builder buffer.
Expand All @@ -121,29 +121,14 @@ struct StringBuilder[growth_factor: Float32 = 2](

return len(src), Error()

fn write(inout self, src: List[UInt8]) -> (Int, Error):
"""
Appends a byte List to the builder buffer.
Args:
src: The byte array to append.
"""
var span = Span(src)

var bytes_read: Int
var err: Error
bytes_read, err = self._write(span)

return bytes_read, err

fn write_string(inout self, src: String) -> (Int, Error):
"""
Appends a string to the builder buffer.
Args:
src: The string to append.
"""
return self._write(src.as_bytes_slice())
return self.write(src.as_bytes_slice())

fn write_byte(inout self, byte: UInt8) -> (Int, Error):
self._resize_if_needed(1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bufio.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ fn test_big_write():

# When writing, it should bypass the Bufio struct's buffer and write directly to the underlying bytes buffer writer. So, no need to flush.
var text = str(builder)
_ = writer.write(text.as_bytes())
_ = writer.write(text.as_bytes_slice())
test.assert_equal(len(writer.writer), 5000)
test.assert_equal(text[0], "0")
test.assert_equal(text[4999], "9")
Expand Down

0 comments on commit 87430e6

Please sign in to comment.