Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making getAddrInfo polymorphic #587

Merged
merged 12 commits into from
Sep 11, 2024
6 changes: 4 additions & 2 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
-- > import qualified Control.Exception as E
-- > import Control.Monad (unless, forever, void)
-- > import qualified Data.ByteString as S
-- > import qualified Data.List.NonEmpty as NE
-- > import Network.Socket
-- > import Network.Socket.ByteString (recv, sendAll)
-- >
Expand All @@ -56,7 +57,7 @@
-- > addrFlags = [AI_PASSIVE]
-- > , addrSocketType = Stream
-- > }
-- > head <$> getAddrInfo (Just hints) mhost (Just port)
-- > NE.head <$> getAddrInfo (Just hints) mhost (Just port)
-- > open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
-- > setSocketOption sock ReuseAddr 1
-- > withFdSocket sock setCloseOnExecIfNeeded
Expand All @@ -77,6 +78,7 @@
-- >
-- > import qualified Control.Exception as E
-- > import qualified Data.ByteString.Char8 as C
-- > import qualified Data.List.NonEmpty as NE
-- > import Network.Socket
-- > import Network.Socket.ByteString (recv, sendAll)
-- >
Expand All @@ -95,7 +97,7 @@
-- > where
-- > resolve = do
-- > let hints = defaultHints { addrSocketType = Stream }
-- > head <$> getAddrInfo (Just hints) (Just host) (Just port)
-- > NE.head <$> getAddrInfo (Just hints) (Just host) (Just port)
-- > open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
-- > connect sock $ addrAddress addr
-- > return sock
Expand Down
138 changes: 84 additions & 54 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

module Network.Socket.Info where

import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Marshal.Utils (maybeWith, with)
import GHC.IO.Exception (IOErrorType(NoSuchThing))
Expand Down Expand Up @@ -200,53 +202,66 @@ defaultHints = AddrInfo {
, addrCanonName = Nothing
}

-----------------------------------------------------------------------------
-- | Resolve a host or service name to one or more addresses.
-- The 'AddrInfo' values that this function returns contain 'SockAddr'
-- values that you can pass directly to 'connect' or
-- 'bind'.
--
-- This function is protocol independent. It can return both IPv4 and
-- IPv6 address information.
--
-- The 'AddrInfo' argument specifies the preferred query behaviour,
-- socket options, or protocol. You can override these conveniently
-- using Haskell's record update syntax on 'defaultHints', for example
-- as follows:
--
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST], addrSocketType = Stream }
--
-- You must provide a 'Just' value for at least one of the 'HostName'
-- or 'ServiceName' arguments. 'HostName' can be either a numeric
-- network address (dotted quad for IPv4, colon-separated hex for
-- IPv6) or a hostname. In the latter case, its addresses will be
-- looked up unless 'AI_NUMERICHOST' is specified as a hint. If you
-- do not provide a 'HostName' value /and/ do not set 'AI_PASSIVE' as
-- a hint, network addresses in the result will contain the address of
-- the loopback interface.
--
-- If the query fails, this function throws an IO exception instead of
-- returning an empty list. Otherwise, it returns a non-empty list
-- of 'AddrInfo' values.
--
-- There are several reasons why a query might result in several
-- values. For example, the queried-for host could be multihomed, or
-- the service might be available via several protocols.
--
-- Note: the order of arguments is slightly different to that defined
-- for @getaddrinfo@ in RFC 2553. The 'AddrInfo' parameter comes first
-- to make partial application easier.
--
-- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "http")
-- >>> addrAddress addr
-- 127.0.0.1:80

getAddrInfo
class GetAddrInfo t where
-----------------------------------------------------------------------------
-- | Resolve a host or service name to one or more addresses.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rendered the haddocks, and it doesn't give the reader a way to see what instances of GetAddrInfo exist. I think we should document them here.

Alternatively, we could use the IsList class? But I think that will be more confusing and should only be for -XOverloadedLists support.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that IsList is confusing.
1411634 improves the doc.

-- The 'AddrInfo' values that this function returns contain 'SockAddr'
-- values that you can pass directly to 'connect' or
-- 'bind'.
--
-- This function is protocol independent. It can return both IPv4 and
-- IPv6 address information.
--
-- The 'AddrInfo' argument specifies the preferred query behaviour,
-- socket options, or protocol. You can override these conveniently
-- using Haskell's record update syntax on 'defaultHints', for example
-- as follows:
--
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST], addrSocketType = Stream }
--
-- You must provide a 'Just' value for at least one of the 'HostName'
-- or 'ServiceName' arguments. 'HostName' can be either a numeric
-- network address (dotted quad for IPv4, colon-separated hex for
-- IPv6) or a hostname. In the latter case, its addresses will be
-- looked up unless 'AI_NUMERICHOST' is specified as a hint. If you
-- do not provide a 'HostName' value /and/ do not set 'AI_PASSIVE' as
-- a hint, network addresses in the result will contain the address of
-- the loopback interface.
--
-- If the query fails, this function throws an IO exception instead of
-- returning an empty list. Otherwise, it returns a non-empty list
-- of 'AddrInfo' values.
--
-- There are several reasons why a query might result in several
-- values. For example, the queried-for host could be multihomed, or
-- the service might be available via several protocols.
--
-- Note: the order of arguments is slightly different to that defined
-- for @getaddrinfo@ in RFC 2553. The 'AddrInfo' parameter comes first
-- to make partial application easier.
--
-- >>> import qualified Data.List.NonEmpty as NE
-- >>> addr <- NE.head <$> getAddrInfo (Just hints) (Just "127.0.0.1") (Just "http")
-- >>> addrAddress addr
-- 127.0.0.1:80
getAddrInfo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A @since line might also be useful for future readers, to know when a NonEmpty is available.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

:: Maybe AddrInfo -- ^ preferred socket type or protocol
-> Maybe HostName -- ^ host name to look up
-> Maybe ServiceName -- ^ service name to look up
-> IO (t AddrInfo) -- ^ resolved addresses, with "best" first

instance GetAddrInfo [] where
getAddrInfo = getAddrInfoList

instance GetAddrInfo NE.NonEmpty where
getAddrInfo = getAddrInfoNE

getAddrInfoNE
:: Maybe AddrInfo -- ^ preferred socket type or protocol
-> Maybe HostName -- ^ host name to look up
-> Maybe ServiceName -- ^ service name to look up
-> IO [AddrInfo] -- ^ resolved addresses, with "best" first
getAddrInfo hints node service = alloc getaddrinfo
-> IO (NonEmpty AddrInfo) -- ^ resolved addresses, with "best" first
getAddrInfoNE hints node service = alloc getaddrinfo
where
alloc body = withSocketsDo $ maybeWith withCString node $ \c_node ->
maybeWith withCString service $ \c_service ->
Expand All @@ -257,13 +272,10 @@ getAddrInfo hints node service = alloc getaddrinfo
ret <- c_getaddrinfo c_node c_service c_hints ptr_ptr_addrs
if ret == 0 then do
ptr_addrs <- peek ptr_ptr_addrs
ais <- followAddrInfo ptr_addrs
c_freeaddrinfo ptr_addrs
-- POSIX requires that getaddrinfo(3) returns at least one addrinfo.
-- See: http://pubs.opengroup.org/onlinepubs/9699919799/functions/getaddrinfo.html
kazu-yamamoto marked this conversation as resolved.
Show resolved Hide resolved
case ais of
[] -> ioError $ mkIOError NoSuchThing message Nothing Nothing
_ -> return ais
ais <- followAddrInfo ptr_addrs
return ais
else do
err <- gai_strerror ret
ioError $ ioeSetErrorString
Expand All @@ -290,13 +302,31 @@ getAddrInfo hints node service = alloc getaddrinfo
filteredHints = hints
#endif

followAddrInfo :: Ptr AddrInfo -> IO [AddrInfo]
getAddrInfoList
:: Maybe AddrInfo
-> Maybe HostName
-> Maybe ServiceName
-> IO [AddrInfo]
getAddrInfoList hints node service =
-- getAddrInfo never returns an empty list.
NE.toList <$> getAddrInfoNE hints node service

followAddrInfo :: Ptr AddrInfo -> IO (NonEmpty AddrInfo)
followAddrInfo ptr_ai
| ptr_ai == nullPtr = return []
| ptr_ai == nullPtr = ioError $ mkIOError NoSuchThing "getaddrinfo must retuan at least one addrinfo" Nothing Nothing
kazu-yamamoto marked this conversation as resolved.
Show resolved Hide resolved
| otherwise = do
a <- peek ptr_ai
as <- (# peek struct addrinfo, ai_next) ptr_ai >>= followAddrInfo
return (a : as)
a <- peek ptr_ai
ptr <- (# peek struct addrinfo, ai_next) ptr_ai
(a :|) <$> go ptr
where
go :: Ptr AddrInfo -> IO [AddrInfo]
go ptr
| ptr == nullPtr = return []
| otherwise = do
a' <- peek ptr
ptr' <- (# peek struct addrinfo, ai_next) ptr
as' <- go ptr'
return (a':as')

foreign import ccall safe "hsnet_getaddrinfo"
c_getaddrinfo :: CString -> CString -> Ptr AddrInfo -> Ptr (Ptr AddrInfo)
Expand Down
3 changes: 2 additions & 1 deletion Network/Socket/Syscall.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#endif

#if defined(mingw32_HOST_OS)
import Foreign (FunPtr)

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

The import of ‘Foreign’ is redundant

Check warning on line 14 in Network/Socket/Syscall.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

The import of ‘Foreign’ is redundant
import GHC.Conc (asyncDoProc)
#else
import Foreign.C.Error (getErrno, eINTR, eINPROGRESS)
Expand Down Expand Up @@ -63,8 +63,9 @@
-- can be handled with one socket.
--
-- >>> import Network.Socket
-- >>> import qualified Data.List.NonEmpty as NE
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST, AI_NUMERICSERV], addrSocketType = Stream }
-- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "5000")
-- >>> addr <- NE.head <$> getAddrInfo (Just hints) (Just "127.0.0.1") (Just "5000")
-- >>> sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
-- >>> Network.Socket.bind sock (addrAddress addr)
-- >>> getSocketName sock
Expand Down
3 changes: 2 additions & 1 deletion examples/EchoClient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Main (main) where

import qualified Control.Exception as E
import qualified Data.ByteString.Char8 as C
import qualified Data.List.NonEmpty as NE
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)

Expand All @@ -23,7 +24,7 @@ runTCPClient host port client = withSocketsDo $ do
where
resolve = do
let hints = defaultHints{addrSocketType = Stream}
head <$> getAddrInfo (Just hints) (Just host) (Just port)
NE.head <$> getAddrInfo (Just hints) (Just host) (Just port)
open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
connect sock $ addrAddress addr
return sock
3 changes: 2 additions & 1 deletion examples/EchoServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Control.Concurrent (forkFinally)
import qualified Control.Exception as E
import Control.Monad (forever, unless, void)
import qualified Data.ByteString as S
import qualified Data.List.NonEmpty as NE
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)

Expand All @@ -29,7 +30,7 @@ runTCPServer mhost port server = withSocketsDo $ do
{ addrFlags = [AI_PASSIVE]
, addrSocketType = Stream
}
head <$> getAddrInfo (Just hints) mhost (Just port)
NE.head <$> getAddrInfo (Just hints) mhost (Just port)
open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
setSocketOption sock ReuseAddr 1
withFdSocket sock setCloseOnExecIfNeeded
Expand Down
2 changes: 1 addition & 1 deletion tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@

it "does not cause segfault on macOS 10.8.2 due to AI_NUMERICSERV" $ do
let hints = defaultHints { addrFlags = [AI_NUMERICSERV] }
void $ getAddrInfo (Just hints) (Just "localhost") Nothing
void (getAddrInfo (Just hints) (Just "localhost") Nothing :: IO [AddrInfo])

#if defined(mingw32_HOST_OS)
let lpdevname = "loopback_0"
Expand Down Expand Up @@ -423,7 +423,7 @@
cmsgidGen = biasedGen (\g -> CmsgId <$> g <*> g) cmsgidPatterns arbitrary

genFds :: Gen [Fd]
genFds = listOf (Fd <$> arbitrary)

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 8.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 8.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

Defined but not used: ‘genFds’

-- pruned lists of pattern synonym values for each type to generate values from

Expand Down
5 changes: 3 additions & 2 deletions tests/Network/Test/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import qualified Data.List.NonEmpty as NE
import Network.Socket
import System.Directory
import System.Timeout (timeout)
import Test.Hspec

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 8.4)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 8.4)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

The import of ‘Test.Hspec’ is redundant

Check warning on line 42 in tests/Network/Test/Common.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

The import of ‘Test.Hspec’ is redundant

serverAddr :: String
serverAddr = "127.0.0.1"
Expand Down Expand Up @@ -244,7 +245,7 @@

resolveClient :: SocketType -> HostName -> PortNumber -> IO AddrInfo
resolveClient socketType host port =
head <$> getAddrInfo (Just hints) (Just host) (Just $ show port)
NE.head <$> getAddrInfo (Just hints) (Just host) (Just $ show port)
where
hints = defaultHints {
addrSocketType = socketType
Expand All @@ -253,7 +254,7 @@

resolveServer :: SocketType -> HostName -> IO AddrInfo
resolveServer socketType host =
head <$> getAddrInfo (Just hints) (Just host) Nothing
NE.head <$> getAddrInfo (Just hints) (Just host) Nothing
where
hints = defaultHints {
addrSocketType = socketType
Expand Down
Loading