Skip to content

Commit

Permalink
update socket stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Sep 3, 2024
1 parent e724031 commit a4ab45e
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 233 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,6 @@ cython_debug/

# Rattler
output

# Mojo
**/*.mojopkg
40 changes: 40 additions & 0 deletions scripts/publish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import sys
from pathlib import Path
import hashlib
import os
import requests

channel = "https://prefix.dev/api/v1/upload/mojo-community"
token = os.environ.get("PREFIX_API_KEY")
if not token:
print("Please set PREFIX_API_KEY to your Prefix API key.")
sys.exit(1)

def upload(fn):
data = fn.read_bytes()

# skip if larger than 100Mb
if len(data) > 100 * 1024 * 1024:
print("Skipping", fn, "because it is too large")
return

name = fn.name
sha256 = hashlib.sha256(data).hexdigest()
headers = {
"X-File-Name": name,
"X-File-SHA256": sha256,
"Authorization": f"Bearer {token}",
"Content-Length": str(len(data) + 1),
"Content-Type": "application/octet-stream",
}

r = requests.post(channel, data=data, headers=headers)
print(f"Uploaded package {name} with status {r.status_code}")


if __name__ == "__main__":
if len(sys.argv) > 1:
upload(Path(sys.argv[1]))
else:
print("Usage: upload.py <package>")
sys.exit(1)
2 changes: 1 addition & 1 deletion src/gojo/net/address.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct HostPort(Stringable):
self.port = port

fn __str__(self) -> String:
return join_host_port(self.host, str(self.port))
return self.host + ":" + str(self.port)


fn join_host_port(host: String, port: String) -> String:
Expand Down
37 changes: 19 additions & 18 deletions src/gojo/net/ip.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ fn get_addr_info(host: String) raises -> AddrInfo:
var status = getaddrinfo(
host.unsafe_ptr(),
UnsafePointer[UInt8](),
UnsafePointer.address_of(hints),
UnsafePointer.address_of(servinfo),
Reference(hints),
Reference(servinfo),
)
if status != 0:
print("getaddrinfo failed to execute with status:", status)
Expand All @@ -64,8 +64,8 @@ fn get_addr_info(host: String) raises -> AddrInfo:
var status = getaddrinfo_unix(
host.unsafe_ptr(),
UnsafePointer[UInt8](),
UnsafePointer.address_of(hints),
UnsafePointer.address_of(servinfo),
Reference(hints),
Reference(servinfo),
)
if status != 0:
print("getaddrinfo failed to execute with status:", status)
Expand Down Expand Up @@ -116,12 +116,12 @@ fn convert_binary_port_to_int(port: UInt16) -> Int:


fn convert_ip_to_binary(ip_address: String, address_family: Int) -> UInt32:
var ip_buffer = UnsafePointer[UInt8].alloc(4)
var status = inet_pton(address_family, ip_address.unsafe_ptr(), ip_buffer)
var ip = List[UInt8, True](0, 0, 0, 0)
var status = inet_pton(address_family, ip_address.unsafe_ptr(), ip.unsafe_ptr())
if status == -1:
print("Failed to convert IP address to binary")

return ip_buffer.bitcast[c_uint]().take_pointee()
return ip.steal_data().bitcast[c_uint]().take_pointee()


fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32, address_length: UInt32) -> String:
Expand All @@ -137,20 +137,21 @@ fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32,
"""
# It seems like the len of the buffer depends on the length of the string IP.
# Allocating 10 works for localhost (127.0.0.1) which I suspect is 9 bytes + 1 null terminator byte. So max should be 16 (15 + 1).
var ip_buffer = UnsafePointer[c_void].alloc(16)
var ip_address_ptr = UnsafePointer.address_of(ip_address).bitcast[c_void]()
_ = inet_ntop(address_family, ip_address_ptr, ip_buffer, 16)
var ip = String(List[UInt8, True](0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
_ = inet_ntop(address_family, UnsafePointer.address_of(ip_address).bitcast[UInt8](), ip.unsafe_ptr(), 16)

var index = 0
while True:
if ip_buffer[index] == 0:
break
index += 1
if ip._buffer[index] == 0:
break

return StringRef(ip_buffer, index)
ip._buffer.size = index
ip._buffer.append(0)
return ip


fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) -> UnsafePointer[sockaddr]:
fn build_sockaddr(ip_address: String, port: Int, address_family: Int) -> sockaddr:
"""Build a sockaddr pointer from an IP address and port number.
https://learn.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-sockaddr_in.
Expand All @@ -159,7 +160,7 @@ fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) ->
var bin_ip = convert_ip_to_binary(ip_address, address_family)

var ai = sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0))
return UnsafePointer.address_of(ai).bitcast[sockaddr]()
return UnsafePointer.address_of(ai).bitcast[sockaddr]().take_pointee()


fn build_sockaddr_in(ip_address: String, port: Int, address_family: Int) -> sockaddr_in:
Expand All @@ -173,7 +174,7 @@ fn build_sockaddr_in(ip_address: String, port: Int, address_family: Int) -> sock
return sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0))


fn convert_sockaddr_to_host_port(sockaddr: UnsafePointer[sockaddr]) -> (HostPort, Error):
fn convert_sockaddr_to_host_port(owned sockaddr: sockaddr) -> (HostPort, Error):
"""Casts a sockaddr pointer to a sockaddr_in pointer and converts the binary IP and port to a string and int respectively.
Args:
Expand All @@ -182,11 +183,11 @@ fn convert_sockaddr_to_host_port(sockaddr: UnsafePointer[sockaddr]) -> (HostPort
Returns:
A tuple containing the HostPort and an Error if any occurred,.
"""
if not sockaddr:
if not UnsafePointer.address_of(sockaddr):
return HostPort(), Error("sockaddr is null, nothing to convert.")

# Cast sockaddr struct to sockaddr_in to convert binary IP to string.
var addr_in = sockaddr.bitcast[sockaddr_in]().take_pointee()
var addr_in = UnsafePointer.address_of(sockaddr).bitcast[sockaddr_in]().take_pointee()

return (
HostPort(
Expand Down
78 changes: 36 additions & 42 deletions src/gojo/net/socket.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ from ..syscall import (
from .fd import FileDescriptor, FileDescriptorBase
from .ip import (
convert_binary_ip_to_string,
build_sockaddr_pointer,
build_sockaddr,
build_sockaddr_in,
convert_binary_port_to_int,
convert_sockaddr_to_host_port,
Expand Down Expand Up @@ -182,16 +182,17 @@ struct Socket(FileDescriptorBase):
var remote_address = sockaddr()
var new_fd = accept(
self.fd.fd,
UnsafePointer.address_of(remote_address),
UnsafePointer.address_of(socklen_t(sizeof[socklen_t]())),
Reference(remote_address),
Reference(socklen_t(sizeof[socklen_t]())),
)
if new_fd == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("accept").unsafe_ptr())
raise Error("Failed to accept connection")

# TODO: Switch to reference here
var remote: HostPort
var err: Error
remote, err = convert_sockaddr_to_host_port(UnsafePointer.address_of(remote_address))
remote, err = convert_sockaddr_to_host_port(remote_address)
if err:
raise err
_ = remote_address
Expand Down Expand Up @@ -232,14 +233,12 @@ struct Socket(FileDescriptorBase):
address: String - The IP address to bind the socket to.
port: The port number to bind the socket to.
"""
# var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family)
var sa_in = build_sockaddr_in(address, port, self.address_family)
# var sa_in = build_sockaddr_in(address, port, self.address_family)
if bind(self.fd.fd, UnsafePointer.address_of(sa_in), sizeof[sockaddr_in]()) == -1:
var local_address = build_sockaddr_in(address, port, self.address_family)
if bind(self.fd.fd, Reference(local_address), sizeof[sockaddr_in]()) == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("bind").unsafe_ptr())
_ = shutdown(self.fd.fd, SHUT_RDWR)
raise Error("Binding socket failed. Wait a few seconds and try again?")
_ = sa_in
_ = local_address

var local = self.get_sock_name()
self.local_address = BaseAddr(local.host, local.port)
Expand All @@ -254,26 +253,20 @@ struct Socket(FileDescriptorBase):
raise SocketClosedError

# TODO: Add check to see if the socket is bound and error if not.
var sa = sockaddr()
# print(sa.sa_family)
var local_address = sockaddr()
var local_address_size = socklen_t(sizeof[sockaddr]())
var status = getsockname(
self.fd.fd,
UnsafePointer.address_of(sa),
UnsafePointer.address_of(socklen_t(sizeof[sockaddr]())),
Reference(local_address),
Reference(local_address_size),
)
if status == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("getsockname").unsafe_ptr())
_ = external_call["perror", c_void, UnsafePointer[UInt8]]("getsockname".unsafe_ptr())
raise Error("Socket.get_sock_name: Failed to get address of local socket.")
# print(sa.sa_family)
var addr_in = UnsafePointer.address_of(sa).bitcast[sockaddr_in]()
# print(sa.sa_family, addr_in.sin_addr.s_addr, addr_in.sin_port)
# var addr_in = local_address_ptr.bitcast[sockaddr_in]().take_pointee()
# print(convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AddressFamily.AF_INET, 16), convert_binary_port_to_int(addr_in.sin_port))
# _ = sa
_ = addr_in
var addr_in = UnsafePointer.address_of(local_address).bitcast[sockaddr_in]().take_pointee()
return HostPort(
host=convert_binary_ip_to_string(addr_in[].sin_addr.s_addr, AddressFamily.AF_INET, 16),
port=convert_binary_port_to_int(addr_in[].sin_port),
host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AddressFamily.AF_INET, 16),
port=convert_binary_port_to_int(addr_in.sin_port),
)

fn get_peer_name(self) -> (HostPort, Error):
Expand All @@ -282,22 +275,21 @@ struct Socket(FileDescriptorBase):
return HostPort(), SocketClosedError

# TODO: Add check to see if the socket is bound and error if not.
var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var remote_address_ptr_size = socklen_t(sizeof[sockaddr]())
var remote_address = sockaddr()
var remote_address_size = socklen_t(sizeof[sockaddr]())
var status = getpeername(
self.fd.fd,
remote_address_ptr,
UnsafePointer[socklen_t].address_of(remote_address_ptr_size),
Reference(remote_address),
Reference(remote_address_size),
)
if status == -1:
return HostPort(), Error("Socket.get_peer_name: Failed to get address of remote socket.")

var remote: HostPort
var err: Error
remote, err = convert_sockaddr_to_host_port(remote_address_ptr)
remote, err = convert_sockaddr_to_host_port(remote_address)
if err:
return HostPort(), err

return remote, Error()

fn get_socket_option(self, option_name: Int) raises -> Int:
Expand All @@ -307,14 +299,13 @@ struct Socket(FileDescriptorBase):
option_name: The socket option to get.
"""
var option_value_pointer = UnsafePointer[c_void].alloc(1)
var option_len = socklen_t(sizeof[socklen_t]())
var option_len_pointer = UnsafePointer.address_of(option_len)
var option_len = socklen_t(sizeof[c_void]())
var status = getsockopt(
self.fd.fd,
SOL_SOCKET,
option_name,
option_value_pointer,
option_len_pointer,
Reference(option_len),
)
if status == -1:
raise Error("Socket.get_sock_opt failed with status: " + str(status))
Expand Down Expand Up @@ -348,7 +339,7 @@ struct Socket(FileDescriptorBase):
port: The port number to connect to.
"""
var sa_in = build_sockaddr_in(address, port, self.address_family)
if connect(self.fd.fd, UnsafePointer.address_of(sa_in), sizeof[sockaddr_in]()) == -1:
if connect(self.fd.fd, Reference(sa_in), sizeof[sockaddr_in]()) == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("connect").unsafe_ptr())
self.shutdown()
return Error("Socket.connect: Failed to connect to the remote socket at: " + address + ":" + str(port))
Expand Down Expand Up @@ -412,12 +403,13 @@ struct Socket(FileDescriptorBase):
address: The IP address to connect to.
port: The port number to connect to.
"""
var sa = build_sockaddr(address, port, self.address_family)
var bytes_sent = sendto(
self.fd.fd,
src.unsafe_ptr(),
len(src),
0,
build_sockaddr_pointer(address, port, self.address_family),
Reference(sa),
sizeof[sockaddr_in](),
)

Expand Down Expand Up @@ -494,24 +486,25 @@ struct Socket(FileDescriptorBase):
Returns:
The number of bytes read, the remote address, and an error if one occurred.
"""
var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var remote_address = sockaddr()
# var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var remote_address_ptr_size = socklen_t(sizeof[sockaddr]())
var buffer = UnsafePointer[UInt8].alloc(size)
var bytes_received = recvfrom(
self.fd.fd,
buffer,
size,
0,
remote_address_ptr,
UnsafePointer[socklen_t].address_of(remote_address_ptr_size),
Reference(remote_address),
Reference(remote_address_ptr_size),
)

if bytes_received == -1:
return List[UInt8, True](), HostPort(), Error("Failed to read from socket, received a -1 response.")

var remote: HostPort
var err: Error
remote, err = convert_sockaddr_to_host_port(remote_address_ptr)
remote, err = convert_sockaddr_to_host_port(remote_address)
if err:
return List[UInt8, True](), HostPort(), err

Expand All @@ -523,15 +516,16 @@ struct Socket(FileDescriptorBase):

fn receive_from_into(inout self, inout dest: List[UInt8, True]) -> (Int, HostPort, Error):
"""Receive data from the socket into the buffer dest."""
var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var remote_address = sockaddr()
# var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var remote_address_ptr_size = socklen_t(sizeof[sockaddr]())
var bytes_read = recvfrom(
self.fd.fd,
dest.unsafe_ptr() + dest.size,
dest.capacity - dest.size,
0,
remote_address_ptr,
UnsafePointer[socklen_t].address_of(remote_address_ptr_size),
Reference(remote_address),
Reference(remote_address_ptr_size),
)
dest.size += bytes_read

Expand All @@ -540,7 +534,7 @@ struct Socket(FileDescriptorBase):

var remote: HostPort
var err: Error
remote, err = convert_sockaddr_to_host_port(remote_address_ptr)
remote, err = convert_sockaddr_to_host_port(remote_address)
if err:
return 0, HostPort(), err

Expand Down
Loading

0 comments on commit a4ab45e

Please sign in to comment.