diff --git a/Repro.hs b/Repro.hs new file mode 100644 index 000000000..eae925d03 --- /dev/null +++ b/Repro.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Main where + +import Network.Socket +import Network.Wai +import Network.Wai.Handler.Warp +import Network.HTTP.Types (status200) +import Network.Wai.Handler.Warp.Internal +import Data.ByteString.Builder (byteString) +import Debug.Trace +import Control.Concurrent +import qualified Control.Exception as E + +main :: IO () +main = do + let settings = + defaultSettings { + settingsOnClose = \_ -> msg "closed!", + settingsOnException = \_ e -> msg ("Exception: " ++ show e) >> E.throw e + } + runSettings settings app + +msg :: String -> IO () +msg s = traceEventIO s + + +app :: Application +app req respond = E.handle onErr $ do + connectionIsInactive req + msg "starting handler" + threadDelay $ 10*1000*1000 + msg "handler responding..." + x <- respond $ responseBuilder status200 [("Content-Type", "text/plain")] (byteString "Hello, world!") + msg "handler done" + return x + where + onErr e = + msg ("Handler exception: " ++ show @E.SomeException e) >> E.throw e + diff --git a/Test.hs b/Test.hs new file mode 100644 index 000000000..22d3f84b9 --- /dev/null +++ b/Test.hs @@ -0,0 +1,20 @@ +import Control.Concurrent +import System.IO +import Network.Socket as N + +main :: IO () +main = do + addr:_ <- N.getAddrInfo (Just N.defaultHints) (Just "127.0.0.1") (Just "3000") + s <- N.openSocket addr + N.connect s (addrAddress addr) + putStrLn "Client connected" + hdl <- N.socketToHandle s ReadWriteMode + hPutStr hdl $ unlines + [ "GET / HTTP/1.1" + , "" + , "" + , "" + ] + threadDelay (100*1000) + putStrLn "Client closing" + N.close s diff --git a/run.sh b/run.sh new file mode 100644 index 000000000..351ee035f --- /dev/null +++ b/run.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +set -e + +GHC="$HOME/ghc/ghc-compare-3/_build/stage1/bin/ghc" +#GHC="$HOME/ghcs-nix/ghcs/9.4.5/bin/ghc" + +cabal build -w $GHC warp --write-ghc-environment-file=always +$GHC Repro.hs -threaded -debug +$GHC Test.hs -threaded -debug + +run() { + echo "Starting server..." + ./Repro +RTS -N2 -v-au 2>&1 & + sleep 1 + + echo "Starting client..." + ./Test + echo "Client done" + + sleep 15 + echo "Killing server..." + kill -INT %1 + + echo "Done" + #nix run nixpkgs#haskellPackages.ghc-events -- show Repro.eventlog +} + +run | nix shell nixpkgs#moreutils -c ts -i "%.S" diff --git a/warp/Network/Wai/Handler/Warp.hs b/warp/Network/Wai/Handler/Warp.hs index 223cea1c5..21ba122d9 100644 --- a/warp/Network/Wai/Handler/Warp.hs +++ b/warp/Network/Wai/Handler/Warp.hs @@ -111,6 +111,8 @@ module Network.Wai.Handler.Warp ( , openFreePort -- * Version , warpVersion + -- * Handling premature connection closure + , connectionIsInactive -- * HTTP/2 -- ** HTTP2 data , HTTP2Data @@ -141,6 +143,7 @@ import Network.Wai (Request, Response, vault) import System.TimeManager import Network.Wai.Handler.Warp.FileInfoCache +import Network.Wai.Handler.Warp.HTTP1 (connectionIsInactive) import Network.Wai.Handler.Warp.HTTP2.Request (getHTTP2Data, setHTTP2Data, modifyHTTP2Data) import Network.Wai.Handler.Warp.HTTP2.Types import Network.Wai.Handler.Warp.Imports diff --git a/warp/Network/Wai/Handler/Warp/HTTP1.hs b/warp/Network/Wai/Handler/Warp/HTTP1.hs index c6b43cac7..76bb82286 100644 --- a/warp/Network/Wai/Handler/Warp/HTTP1.hs +++ b/warp/Network/Wai/Handler/Warp/HTTP1.hs @@ -5,16 +5,18 @@ {-# LANGUAGE ScopedTypeVariables #-} module Network.Wai.Handler.Warp.HTTP1 ( - http1 + http1, + connectionIsInactive ) where import "iproute" Data.IP (toHostAddress, toHostAddress6) -import qualified Control.Concurrent as Conc (yield) +import qualified Control.Concurrent as Conc import qualified UnliftIO import UnliftIO (SomeException, fromException, throwIO) import qualified Data.ByteString as BS import Data.Char (chr) import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import qualified Data.Vault.Lazy as Vault import Network.Socket (SockAddr(SockAddrInet, SockAddrInet6)) import Network.Wai import Network.Wai.Internal (ResponseReceived (ResponseReceived)) @@ -30,11 +32,21 @@ import Network.Wai.Handler.Warp.Types http1 :: Settings -> InternalInfo -> Connection -> Transport -> Application -> SockAddr -> T.Handle -> ByteString -> IO () http1 settings ii conn transport app origAddr th bs0 = do + connActive <- mkConnActiveFlag + case connRegisterPeerClosedCb conn of + -- TODO Ignore only operation-not-supported exceptions + Just registerCb -> void $ UnliftIO.tryIO $ do + tid <- Conc.myThreadId + registerCb $ do + waitUntilConnInactive connActive + UnliftIO.throwTo tid PeerClosedException + Nothing -> return () + istatus <- newIORef True src <- mkSource (wrappedRecv conn istatus (settingsSlowlorisSize settings)) leftoverSource src bs0 addr <- getProxyProtocolAddr src - http1server settings ii conn transport app addr th istatus src + http1server settings ii conn transport connActive app addr th istatus src where wrappedRecv Connection { connRecv = recv } istatus slowlorisSize = do bs <- recv @@ -83,8 +95,8 @@ http1 settings ii conn transport app origAddr th bs0 = do decodeAscii = map (chr . fromEnum) . BS.unpack -http1server :: Settings -> InternalInfo -> Connection -> Transport -> Application -> SockAddr -> T.Handle -> IORef Bool -> Source -> IO () -http1server settings ii conn transport app addr th istatus src = +http1server :: Settings -> InternalInfo -> Connection -> Transport -> ConnActiveFlag -> Application -> SockAddr -> T.Handle -> IORef Bool -> Source -> IO () +http1server settings ii conn transport connActive app addr th istatus src = loop True `UnliftIO.catchAny` handler where handler e @@ -98,7 +110,8 @@ http1server settings ii conn transport app addr th istatus src = throwIO e loop firstRequest = do - (req, mremainingRef, idxhdr, nextBodyFlush) <- recvRequest firstRequest settings conn ii th addr src transport + setConnActiveFlag connActive True + (req, mremainingRef, idxhdr, nextBodyFlush) <- recvRequest firstRequest settings conn ii th addr src transport connActive keepAlive <- processRequest settings ii conn app th istatus src req mremainingRef idxhdr nextBodyFlush `UnliftIO.catchAny` \e -> do settingsOnException settings (Just req) e @@ -219,3 +232,12 @@ flushBody src = loop | BS.null bs -> return True | toRead' >= 0 -> loop toRead' | otherwise -> return False + +-- | Used by a handler to indicate that its current computation can be safely +-- killed if the requesting connection is shutdown. +connectionIsInactive :: Request -> IO () +connectionIsInactive req = do + case Vault.lookup connActiveFlagKey (vault req) of + Just flag -> setConnActiveFlag flag False + Nothing -> return () + diff --git a/warp/Network/Wai/Handler/Warp/Request.hs b/warp/Network/Wai/Handler/Warp/Request.hs index 0e358a0bf..d5c805783 100644 --- a/warp/Network/Wai/Handler/Warp/Request.hs +++ b/warp/Network/Wai/Handler/Warp/Request.hs @@ -12,6 +12,7 @@ module Network.Wai.Handler.Warp.Request ( #ifdef MIN_VERSION_crypton_x509 , getClientCertificateKey #endif + , connActiveFlagKey , NoKeepAliveRequest (..) ) where @@ -56,6 +57,7 @@ recvRequest :: Bool -- ^ first request on this connection? -> SockAddr -- ^ Peer's address. -> Source -- ^ Where HTTP request comes from. -> Transport + -> ConnActiveFlag -> IO (Request ,Maybe (I.IORef Int) ,IndexedHeader @@ -65,7 +67,7 @@ recvRequest :: Bool -- ^ first request on this connection? -- 'IndexedHeader' of HTTP request for internal use, -- Body producing action used for flushing the request body -recvRequest firstRequest settings conn ii th addr src transport = do +recvRequest firstRequest settings conn ii th addr src transport connActive = do hdrlines <- headerLines (settingsMaxTotalHeaderLength settings) firstRequest src (method, unparsedPath, path, query, httpversion, hdr) <- parseHeaderLines hdrlines let idxhdr = indexRequestHeader hdr @@ -76,6 +78,7 @@ recvRequest firstRequest settings conn ii th addr src transport = do rawPath = if settingsNoParsePath settings then unparsedPath else path vaultValue = Vault.insert pauseTimeoutKey (Timeout.pause th) $ Vault.insert getFileInfoKey (getFileInfo ii) + $ Vault.insert connActiveFlagKey connActive #ifdef MIN_VERSION_crypton_x509 $ Vault.insert getClientCertificateKey (getTransportClientCertificate transport) #endif @@ -328,3 +331,7 @@ getClientCertificateKey :: Vault.Key (Maybe CertificateChain) getClientCertificateKey = unsafePerformIO Vault.newKey {-# NOINLINE getClientCertificateKey #-} #endif + +connActiveFlagKey :: Vault.Key ConnActiveFlag +connActiveFlagKey = unsafePerformIO Vault.newKey +{-# NOINLINE connActiveFlagKey #-} diff --git a/warp/Network/Wai/Handler/Warp/Run.hs b/warp/Network/Wai/Handler/Warp/Run.hs index d273462be..cb4024056 100644 --- a/warp/Network/Wai/Handler/Warp/Run.hs +++ b/warp/Network/Wai/Handler/Warp/Run.hs @@ -8,6 +8,7 @@ module Network.Wai.Handler.Warp.Run where import Control.Arrow (first) +import Control.Concurrent import qualified Control.Exception import Control.Exception (allowInterrupt) import qualified Data.ByteString as S @@ -21,6 +22,12 @@ import Network.Socket (gracefulClose) #endif import Network.Socket.BufferPool import qualified Network.Socket.ByteString as Sock +#if MIN_VERSION_base(4,18,0) +-- For evtPeerClosed +import Network.Socket (withFdSocket) +import GHC.Event +import System.Posix.Types (Fd(Fd)) +#endif import Network.Wai import System.Environment (lookupEnv) import System.IO.Error (ioeGetErrorType) @@ -59,6 +66,14 @@ socketConnection _ s = do bufferPool <- newBufferPool 2048 16384 writeBuffer <- createWriteBuffer 16384 writeBufferRef <- newIORef writeBuffer +#if MIN_VERSION_base(4,18,0) + let registerPeerClosedCb = Just $ \cb -> withFdSocket s $ \fd -> do + Just mgr <- getSystemEventManager + _ <- registerFd mgr (\ _ _ -> cb) (Fd fd) evtPeerClosed OneShot + return () +#else + let registerPeerClosedCb = Nothing +#endif isH2 <- newIORef False -- HTTP/1.x return Connection { connSendMany = Sock.sendMany s @@ -80,6 +95,7 @@ socketConnection _ s = do , connRecvBuf = \_ _ -> return True -- obsoleted , connWriteBuffer = writeBufferRef , connHTTP2 = isH2 + , connRegisterPeerClosedCb = registerPeerClosedCb } where receive' sock pool = UnliftIO.handleIO handler $ receive sock pool diff --git a/warp/Network/Wai/Handler/Warp/Types.hs b/warp/Network/Wai/Handler/Warp/Types.hs index 9e1366fdd..0c668ec22 100644 --- a/warp/Network/Wai/Handler/Warp/Types.hs +++ b/warp/Network/Wai/Handler/Warp/Types.hs @@ -73,6 +73,15 @@ instance UnliftIO.Exception ExceptionInsideResponseBody ---------------------------------------------------------------- +-- | Exception thrown when the iniating client of a connection being handled by +-- a worker closes its end of the connection. +data PeerClosedException = PeerClosedException + deriving (Show) + +instance UnliftIO.Exception PeerClosedException + +---------------------------------------------------------------- + -- | Data type to abstract file identifiers. -- On Unix, a file descriptor would be specified to make use of -- the file descriptor cache. @@ -125,6 +134,7 @@ data Connection = Connection { , connWriteBuffer :: IORef WriteBuffer -- | Is this connection HTTP/2? , connHTTP2 :: IORef Bool + , connRegisterPeerClosedCb :: Maybe (IO () -> IO ()) } getConnHTTP2 :: Connection -> IO Bool @@ -144,6 +154,28 @@ data InternalInfo = InternalInfo { ---------------------------------------------------------------- +-- | In some HTTP/1 applications (e.g. those where requests are pure queries +-- which imply no "effects") it can make sense to abort running handlers when +-- the write-side of the client's connection closed (via @shutdown(2)@) before +-- a response has been sent. To facilitate this use-case, each handler thread +-- carries a 'ConnActiveFlag' which dictates whether the handler's current +-- computation can be safely aborted if the connection is shutdown. +newtype ConnActiveFlag = ConnActiveFlag (UnliftIO.TVar Bool) + +mkConnActiveFlag :: IO ConnActiveFlag +mkConnActiveFlag = ConnActiveFlag <$> UnliftIO.newTVarIO True + +setConnActiveFlag :: ConnActiveFlag -> Bool -> IO () +setConnActiveFlag (ConnActiveFlag v) active = UnliftIO.atomically $ + UnliftIO.writeTVar v active + +waitUntilConnInactive :: ConnActiveFlag -> IO () +waitUntilConnInactive (ConnActiveFlag v) = UnliftIO.atomically $ do + active <- UnliftIO.readTVar v + when active UnliftIO.retrySTM + +---------------------------------------------------------------- + -- | Type for input streaming. data Source = Source !(IORef ByteString) !(IO ByteString)