diff --git a/servant-client-core/servant-client-core.cabal b/servant-client-core/servant-client-core.cabal index b93fc8103..9d1560e08 100644 --- a/servant-client-core/servant-client-core.cabal +++ b/servant-client-core/servant-client-core.cabal @@ -91,6 +91,7 @@ library Servant.Client.Core.Reexport Servant.Client.Core.Request Servant.Client.Core.Response + Servant.Client.Core.ResponseUnrender Servant.Client.Core.RunClient Servant.Client.Free Servant.Client.Generic diff --git a/servant-client-core/src/Servant/Client/Core/HasClient.hs b/servant-client-core/src/Servant/Client/Core/HasClient.hs index 5b00b1fc0..50ec5c0cb 100644 --- a/servant-client-core/src/Servant/Client/Core/HasClient.hs +++ b/servant-client-core/src/Servant/Client/Core/HasClient.hs @@ -1,6 +1,6 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE ApplicativeDo #-} {-# OPTIONS_GHC -Wno-missing-methods #-} -{-# LANGUAGE EmptyCase #-} module Servant.Client.Core.HasClient ( clientIn, HasClient (..), @@ -9,7 +9,8 @@ module Servant.Client.Core.HasClient ( (//), (/:), foldMapUnion, - matchUnion + matchUnion, + fromSomeClientResponse ) where import Prelude () @@ -17,9 +18,10 @@ import Prelude.Compat import Control.Arrow (left, (+++)) +import qualified Data.Text as Text import Control.Monad (unless) -import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy as BSL import Data.Either (partitionEithers) import Data.Constraint (Dict(..)) @@ -43,13 +45,11 @@ import Data.SOP.Constraint import Data.SOP.NP (NP (..), cpure_NP) import Data.SOP.NS - (NS (S)) + (NS (..)) import Data.String (fromString) import Data.Text (Text, pack) -import Data.Proxy - (Proxy (Proxy)) import GHC.TypeLits (KnownNat, KnownSymbol, TypeError, symbolVal) import Network.HTTP.Types @@ -71,7 +71,7 @@ import Servant.API.Generic (GenericMode(..), ToServant, ToServantApi , GenericServant, toServant, fromServant) import Servant.API.ContentTypes - (contentTypes, AllMime (allMime), AllMimeUnrender (allMimeUnrender), AcceptHeader) + (contentTypes, AllMime (allMime), AllMimeUnrender (allMimeUnrender)) import Servant.API.QueryString (ToDeepQuery(..), generateDeepParam) import Servant.API.Status (statusFromNat) @@ -87,9 +87,12 @@ import Servant.Client.Core.BasicAuth import Servant.Client.Core.ClientError import Servant.Client.Core.Request import Servant.Client.Core.Response +import Servant.Client.Core.ResponseUnrender +import qualified Servant.Client.Core.Response as Response import Servant.Client.Core.RunClient -import Servant.API.MultiVerb +import Servant.API.MultiVerb import qualified Network.HTTP.Media as M +import Data.Typeable -- * Accessing APIs as a Client @@ -325,7 +328,7 @@ data ClientParseError = ClientParseError MediaType String | ClientStatusMismatch deriving (Eq, Show) class UnrenderResponse (cts :: [Type]) (a :: Type) where - unrenderResponse :: Seq.Seq H.Header -> BL.ByteString -> Proxy cts + unrenderResponse :: Seq.Seq H.Header -> BSL.ByteString -> Proxy cts -> [Either (MediaType, String) a] instance {-# OVERLAPPABLE #-} AllMimeUnrender cts a => UnrenderResponse cts a where @@ -367,15 +370,13 @@ instance {-# OVERLAPPING #-} method = reflectMethod $ Proxy @method acceptStatus = statuses (Proxy @as) - response <- runRequestAcceptStatus (Just acceptStatus) request {requestMethod = method, requestAccept = accept} + response@Response{responseBody=body, responseStatusCode=status, responseHeaders=headers} + <- runRequestAcceptStatus (Just acceptStatus) (request {requestMethod = method, requestAccept = accept}) responseContentType <- checkContentTypeHeader response unless (any (matches responseContentType) accept) $ do throwClientError $ UnsupportedContentType responseContentType response - let status = responseStatusCode response - body = responseBody response - headers = responseHeaders response - res = tryParsers status $ mimeUnrenders (Proxy @contentTypes) headers body + let res = tryParsers status $ mimeUnrenders (Proxy @contentTypes) headers body case res of Left errors -> throwClientError $ DecodeFailure (T.pack (show errors)) response Right x -> return x @@ -399,7 +400,7 @@ instance {-# OVERLAPPING #-} All (UnrenderResponse cts) xs => Proxy cts -> Seq.Seq H.Header -> - BL.ByteString -> + BSL.ByteString -> NP ([] :.: Either (MediaType, String)) xs mimeUnrenders ctp headers body = cpure_NP (Proxy @(UnrenderResponse cts)) @@ -416,10 +417,10 @@ instance {-# OVERLAPPABLE #-} hoistClientMonad _ _ f ma = f ma - clientWithRoute _pm Proxy req = withStreamingRequest req' $ \gres -> do - let mimeUnrender' = mimeUnrender (Proxy :: Proxy ct) :: BL.ByteString -> Either String chunk + clientWithRoute _pm Proxy req = withStreamingRequest req' $ \Response{responseBody=body} -> do + let mimeUnrender' = mimeUnrender (Proxy :: Proxy ct) :: BSL.ByteString -> Either String chunk framingUnrender' = framingUnrender (Proxy :: Proxy framing) mimeUnrender' - fromSourceIO $ framingUnrender' $ responseBody gres + fromSourceIO $ framingUnrender' body where req' = req { requestAccept = fromList [contentType (Proxy :: Proxy ct)] @@ -436,13 +437,14 @@ instance {-# OVERLAPPING #-} hoistClientMonad _ _ f ma = f ma - clientWithRoute _pm Proxy req = withStreamingRequest req' $ \gres -> do - let mimeUnrender' = mimeUnrender (Proxy :: Proxy ct) :: BL.ByteString -> Either String chunk + clientWithRoute _pm Proxy req = withStreamingRequest req' $ + \Response{responseBody=body, responseHeaders=headers} -> do + let mimeUnrender' = mimeUnrender (Proxy :: Proxy ct) :: BSL.ByteString -> Either String chunk framingUnrender' = framingUnrender (Proxy :: Proxy framing) mimeUnrender' - val <- fromSourceIO $ framingUnrender' $ responseBody gres + val <- fromSourceIO $ framingUnrender' body return $ Headers { getResponse = val - , getHeadersHList = buildHeadersTo . toList $ responseHeaders gres + , getHeadersHList = buildHeadersTo $ toList headers } where @@ -760,7 +762,7 @@ instance sourceIO = framingRender framingP - (mimeRender ctypeP :: chunk -> BL.ByteString) + (mimeRender ctypeP :: chunk -> BSL.ByteString) (toSourceIO body) -- | Make the querying function append @path@ to the request path. @@ -975,19 +977,9 @@ x // f = f x (/:) :: (a -> b -> c) -> b -> a -> c (/:) = flip -class IsResponseList cs as where - responseListRender :: AcceptHeader -> Union (ResponseTypes as) -> Maybe InternalResponse - responseListUnrender :: M.MediaType -> InternalResponse -> UnrenderResult (Union (ResponseTypes as)) - - responseListStatuses :: [Status] - -instance IsResponseList cs '[] where - responseListRender _ x = case x of {} - responseListUnrender _ _ = empty - responseListStatuses = [] instance - ( IsResponseList cs as, + ( ResponseListUnrender cs as, AllMime cs, ReflectMethod method, AsUnion as r, @@ -998,7 +990,7 @@ instance type Client m (MultiVerb method cs as r) = m r clientWithRoute _ _ req = do - response <- + response@Response{responseBody=body} <- runRequestAcceptStatus (Just (responseListStatuses @cs @as)) req @@ -1012,9 +1004,9 @@ instance -- FUTUREWORK: support streaming let sresp = - if LBS.null (responseBody response) - then SomeResponse response {responseBody = ()} - else SomeResponse response + if BSL.null body + then SomeClientResponse $ response {Response.responseBody = ()} + else SomeClientResponse response case responseListUnrender @cs @as c sresp of StatusMismatch -> throwClientError (DecodeFailure "Status mismatch" response) UnrenderError e -> throwClientError (DecodeFailure (Text.pack e) response) @@ -1064,11 +1056,11 @@ checkContentTypeHeader response = decodedAs :: forall ct a m. (MimeUnrender ct a, RunClient m) => Response -> Proxy ct -> m a -decodedAs response ct = do +decodedAs response@Response{responseBody=body} ct = do responseContentType <- checkContentTypeHeader response unless (any (matches responseContentType) accept) $ throwClientError $ UnsupportedContentType responseContentType response - case mimeUnrender ct $ responseBody response of + case mimeUnrender ct body of Left err -> throwClientError $ DecodeFailure (T.pack err) response Right val -> return val where diff --git a/servant-client-core/src/Servant/Client/Core/Response.hs b/servant-client-core/src/Servant/Client/Core/Response.hs index 16ca0667a..59aaaf38b 100644 --- a/servant-client-core/src/Servant/Client/Core/Response.hs +++ b/servant-client-core/src/Servant/Client/Core/Response.hs @@ -1,17 +1,17 @@ {-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE NamedFieldPuns #-} module Servant.Client.Core.Response ( Response, StreamingResponse, ResponseF (..), + responseToInternalResponse, ) where import Prelude () @@ -31,6 +31,7 @@ import Network.HTTP.Types import Servant.API.Stream (SourceIO) +import Servant.Types.ResponseList data ResponseF a = Response { responseStatusCode :: Status @@ -51,3 +52,7 @@ instance NFData a => NFData (ResponseF a) where type Response = ResponseF LBS.ByteString type StreamingResponse = ResponseF (SourceIO BS.ByteString) + +responseToInternalResponse :: ResponseF a -> InternalResponse a +responseToInternalResponse Response{responseStatusCode, responseHeaders,responseBody} = + InternalResponse responseStatusCode responseHeaders responseBody diff --git a/servant-client-core/src/Servant/Client/Core/ResponseUnrender.hs b/servant-client-core/src/Servant/Client/Core/ResponseUnrender.hs new file mode 100644 index 000000000..3d13daab8 --- /dev/null +++ b/servant-client-core/src/Servant/Client/Core/ResponseUnrender.hs @@ -0,0 +1,134 @@ +{-# LANGUAGE ApplicativeDo #-} +module Servant.Client.Core.ResponseUnrender where + +import Control.Applicative +import Control.Monad +import Data.Kind (Type) +import Data.SOP +import Data.Typeable +import GHC.TypeLits +import Network.HTTP.Types.Status (Status) +import qualified Data.ByteString.Lazy as BSL +import qualified Network.HTTP.Media as M + +import Servant.API.ContentTypes +import Servant.API.MultiVerb +import Servant.API.Status +import Servant.API.UVerb.Union (Union) +import Servant.Client.Core.Response (ResponseF(..)) +import qualified Servant.Client.Core.Response as Response +import Servant.API.Stream (SourceIO) +import Data.ByteString (ByteString) + +data SomeClientResponse = forall a. Typeable a => SomeClientResponse (ResponseF a) + +fromSomeClientResponse + :: forall a m. (Alternative m, Typeable a) + => SomeClientResponse + -> m (ResponseF a) +fromSomeClientResponse (SomeClientResponse Response {..}) = do + body <- maybe empty pure $ cast @_ @a responseBody + pure $ + Response + { responseBody = body, + .. + } + + +class ResponseUnrender cs a where + type ResponseBody a :: Type + type ResponseStatus a :: Nat + responseUnrender + :: M.MediaType + -> ResponseF (ResponseBody a) + -> UnrenderResult (ResponseType a) + +-- +-- FIXME: Move this to the client in its own module +class (Typeable as) => ResponseListUnrender cs as where + responseListUnrender + :: M.MediaType + -> SomeClientResponse + -> UnrenderResult (Union (ResponseTypes as)) + + responseListStatuses :: [Status] + +instance ResponseListUnrender cs '[] where + responseListUnrender _ _ = StatusMismatch + responseListStatuses = [] + +instance + ( Typeable a, + Typeable (ResponseBody a), + ResponseUnrender cs a, + ResponseListUnrender cs as, + KnownStatus (ResponseStatus a) + ) => + ResponseListUnrender cs (a ': as) + where + responseListUnrender c output = + Z . I <$> (responseUnrender @cs @a c =<< fromSomeClientResponse output) + <|> S <$> responseListUnrender @cs @as c output + + responseListStatuses = statusVal (Proxy @(ResponseStatus a)) : responseListStatuses @cs @as + +instance + ( KnownStatus s, + MimeUnrender ct a + ) => + ResponseUnrender cs (RespondAs (ct :: Type) s desc a) + where + type ResponseStatus (RespondAs ct s desc a) = s + type ResponseBody (RespondAs ct s desc a) = BSL.ByteString + + responseUnrender _ output = do + guard (responseStatusCode output == statusVal (Proxy @s)) + either UnrenderError UnrenderSuccess $ + mimeUnrender (Proxy @ct) (Response.responseBody output) + +instance (KnownStatus s) => ResponseUnrender cs (RespondAs '() s desc ()) where + type ResponseStatus (RespondAs '() s desc ()) = s + type ResponseBody (RespondAs '() s desc ()) = () + + responseUnrender _ output = + guard (responseStatusCode output == statusVal (Proxy @s)) + +instance + (KnownStatus s) + => ResponseUnrender cs (RespondStreaming s desc framing ct) + where + type ResponseStatus (RespondStreaming s desc framing ct) = s + type ResponseBody (RespondStreaming s desc framing ct) = SourceIO ByteString + + responseUnrender _ resp = do + guard (Response.responseStatusCode resp == statusVal (Proxy @s)) + pure $ Response.responseBody resp + +instance + (AllMimeUnrender cs a, KnownStatus s) + => ResponseUnrender cs (Respond s desc a) where + type ResponseStatus (Respond s desc a) = s + type ResponseBody (Respond s desc a) = BSL.ByteString + + responseUnrender c output = do + guard (responseStatusCode output == statusVal (Proxy @s)) + let results = allMimeUnrender (Proxy @cs) + case lookup c results of + Nothing -> empty + Just f -> either UnrenderError UnrenderSuccess (f (responseBody output)) + +instance + ( AsHeaders xs (ResponseType r) a, + ServantHeaders hs xs, + ResponseUnrender cs r + ) => + ResponseUnrender cs (WithHeaders hs a r) + where + type ResponseStatus (WithHeaders hs a r) = ResponseStatus r + type ResponseBody (WithHeaders hs a r) = ResponseBody r + + responseUnrender c output = do + x <- responseUnrender @cs @r c output + case extractHeaders @hs (responseHeaders output) of + Nothing -> UnrenderError "Failed to parse headers" + Just hs -> pure $ fromHeaders @xs (hs, x) diff --git a/servant-client/test/Servant/ClientTestUtils.hs b/servant-client/test/Servant/ClientTestUtils.hs index 1d6b57b19..e2579312f 100644 --- a/servant-client/test/Servant/ClientTestUtils.hs +++ b/servant-client/test/Servant/ClientTestUtils.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} @@ -75,6 +74,7 @@ import qualified Servant.Client.Core.Auth as Auth import Servant.Server import Servant.Server.Experimental.Auth import Servant.Test.ComprehensiveAPI +import Servant.API.MultiVerb -- This declaration simply checks that all instances are in place. _ = client comprehensiveAPIWithoutStreaming @@ -145,6 +145,23 @@ instance ToDeepQuery Filter where , (["name"], Just (Text.pack name')) ] +-- MultiVerb test endpoint + +data ErrorResponse a + +data UserNotFound + +type instance ResponseType (ErrorResponse a ) = a + +type GetUserVerb = + MultiVerb + 'GET + '[JSON] + '[ ErrorResponse UserNotFound, + Respond 200 "User found" Person + ] + (Maybe Person) + type Api = Get '[JSON] Person :<|> "get" :> Get '[JSON] Person @@ -221,6 +238,7 @@ uverbGetSuccessOrRedirect :: Bool WithStatus 301 Text]) uverbGetCreated :: ClientM (Union '[WithStatus 201 Person]) recordRoutes :: RecordRoutes (AsClientT ClientM) +captureVerbatim :: Verbatim -> ClientM Text getRoot :<|> getGet @@ -282,15 +300,15 @@ server = serve api ( } ) :<|> return alice - :<|> (Tagged $ \ _request respond -> respond $ Wai.responseLBS HTTP.ok200 [] "rawSuccess") - :<|> (Tagged $ \ request respond -> (respond $ Wai.responseLBS HTTP.ok200 (Wai.requestHeaders $ request) "rawSuccess")) - :<|> (Tagged $ \ _request respond -> respond $ Wai.responseLBS HTTP.badRequest400 [] "rawFailure") + :<|> Tagged (\ _request respond -> respond $ Wai.responseLBS HTTP.ok200 [] "rawSuccess") + :<|> Tagged (\ request respond -> respond $ Wai.responseLBS HTTP.ok200 (Wai.requestHeaders request) "rawSuccess") + :<|> Tagged (\ _request respond -> respond $ Wai.responseLBS HTTP.badRequest400 [] "rawFailure") :<|> (\ a b c d -> return (a, b, c, d)) - :<|> (return $ addHeader 1729 $ addHeader "eg2" True) + :<|> return (addHeader 1729 $ addHeader "eg2" True) :<|> (pure . Z . I . WithStatus $ addHeader 1729 $ addHeader "eg2" True) - :<|> (return $ addHeader "cookie1" $ addHeader "cookie2" True) + :<|> return (addHeader "cookie1" $ addHeader "cookie2" True) :<|> return NoContent - :<|> (Tagged $ \ _request respond -> respond $ Wai.responseLBS HTTP.found302 [("Location", "testlocation"), ("Set-Cookie", "testcookie=test")] "") + :<|> Tagged (\ _request respond -> respond $ Wai.responseLBS HTTP.found302 [("Location", "testlocation"), ("Set-Cookie", "testcookie=test")] "") :<|> emptyServer :<|> (\shouldRedirect -> if shouldRedirect then respond (WithStatus @301 ("redirecting" :: Text)) @@ -318,10 +336,10 @@ failApi = Proxy failServer :: Application failServer = serve failApi ( - (Tagged $ \ _request respond -> respond $ Wai.responseLBS HTTP.ok200 [] "") + Tagged (\ _request respond -> respond $ Wai.responseLBS HTTP.ok200 [] "") :<|> (\ _capture -> Tagged $ \_request respond -> respond $ Wai.responseLBS HTTP.ok200 [("content-type", "application/json")] "") - :<|> (Tagged $ \_request respond -> respond $ Wai.responseLBS HTTP.ok200 [("content-type", "fooooo")] "") - :<|> (Tagged $ \_request respond -> respond $ Wai.responseLBS HTTP.ok200 [("content-type", "application/x-www-form-urlencoded"), ("X-Example1", "1"), ("X-Example2", "foo")] "") + :<|> Tagged (\_request respond -> respond $ Wai.responseLBS HTTP.ok200 [("content-type", "fooooo")] "") + :<|> Tagged (\_request respond -> respond $ Wai.responseLBS HTTP.ok200 [("content-type", "application/x-www-form-urlencoded"), ("X-Example1", "1"), ("X-Example2", "foo")] "") ) -- * basic auth stuff diff --git a/servant-client/test/Servant/MiddlewareSpec.hs b/servant-client/test/Servant/MiddlewareSpec.hs index 648ca1311..9b7c2a943 100644 --- a/servant-client/test/Servant/MiddlewareSpec.hs +++ b/servant-client/test/Servant/MiddlewareSpec.hs @@ -16,9 +16,7 @@ module Servant.MiddlewareSpec (spec) where -import Control.Arrow - ( left, - ) +import Control.Arrow (left) import Control.Concurrent (newEmptyMVar, putMVar, takeMVar) import Control.Exception (Exception, throwIO, try) import Control.Monad.IO.Class @@ -114,4 +112,4 @@ spec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do left show <$> runClientWithMiddleware getGet mid baseUrl `shouldReturn` Right alice ref <- readIORef ref - ref `shouldBe` ["req1", "req2", "req3", "resp3", "resp2", "resp1"] \ No newline at end of file + ref `shouldBe` ["req1", "req2", "req3", "resp3", "resp2", "resp1"] diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index 31da0e164..cdab2b9bd 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -103,8 +103,9 @@ library Servant.Server.Internal.DelayedIO Servant.Server.Internal.ErrorFormatter Servant.Server.Internal.Handler - Servant.Server.Internal.Router + Servant.Server.Internal.ResponseRender Servant.Server.Internal.RouteResult + Servant.Server.Internal.Router Servant.Server.Internal.RoutingApplication Servant.Server.Internal.ServerError Servant.Server.StaticFiles diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 5933e0d2c..a8e0e5834 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -1,5 +1,5 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE EmptyCase #-} module Servant.Server.Internal ( module Servant.Server.Internal @@ -40,7 +40,7 @@ import Data.Tagged import qualified Data.Text as T import Data.Typeable import GHC.Generics -import GHC.TypeLits (KnownNat, KnownSymbol, TypeError, ErrorMessage (..), symbolVal, Nat) +import GHC.TypeLits (KnownNat, KnownSymbol, TypeError, ErrorMessage (..), symbolVal) import qualified Network.HTTP.Media as NHM import Network.HTTP.Types hiding (statusCode, Header, ResponseHeaders) @@ -63,7 +63,7 @@ import Servant.API.Generic (GenericMode(..), ToServant, ToServantApi, import Servant.API.ContentTypes (AcceptHeader (..), AllCTRender (..), AllCTUnrender (..), AllMime, MimeRender (..), MimeUnrender (..), NoContent, - canHandleAcceptH, AllMimeRender, AllMimeUnrender) + canHandleAcceptH) import Servant.API.Modifiers (FoldLenient, FoldRequired, RequestArgument, unfoldRequestArgument) @@ -71,18 +71,12 @@ import Servant.API.QueryString (FromDeepQuery(..)) import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, getResponse) import Servant.API.Status - (statusFromNat, KnownStatus) + (statusFromNat) import qualified Servant.Types.SourceT as S import Servant.API.TypeErrors import Web.HttpApiData (FromHttpApiData, parseHeader, parseQueryParam, parseUrlPiece, parseUrlPieces) -import Network.HTTP.Types (Header) -import Data.Sequence (Seq) -import qualified Network.Wai as Wai -import Data.ByteString (ByteString) -import qualified Network.HTTP.Media as M -import Control.Applicative (Alternative) import Servant.Server.Internal.BasicAuth import Servant.Server.Internal.Context @@ -94,9 +88,9 @@ import Servant.Server.Internal.Router import Servant.Server.Internal.RouteResult import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServerError +import Servant.Server.Internal.ResponseRender import Servant.API.MultiVerb import Servant.API.TypeLevel (AtMostOneFragment, FragmentUnique) -import Data.SOP class HasServer api context where -- | The type of a server for this API, given a monad to run effects in. @@ -1130,159 +1124,10 @@ instance servantSrvN :: ServerT (ToServantApi api) n = hoistServerWithContext (Proxy @(ToServantApi api)) pctx nat servantSrvM -data InternalResponse a = InternalResponse - { statusCode :: Status - , headers :: Seq Header - , responseBody :: a - } deriving stock (Eq, Show, Generic, Typeable, Functor, Foldable, Traversable) - -class (Typeable a) => IsWaiBody a where - responseToWai :: InternalResponse a -> Wai.Response - -instance IsWaiBody BSL.ByteString where - responseToWai r = - Wai.responseLBS - (statusCode r) - (toList (headers r)) - (responseBody r) - -instance IsWaiBody () where - responseToWai r = - Wai.responseLBS - (statusCode r) - (toList (headers r)) - mempty - -instance IsWaiBody (SourceIO ByteString) where - responseToWai r = - Wai.responseStream - (statusCode r) - (toList (headers r)) - $ \output flush -> do - foreach - (const (pure ())) - (\chunk -> output (byteString chunk) *> flush) - (responseBody r) - - -class (IsWaiBody (ResponseBody a)) => IsResponse cs a where - type ResponseStatus a :: Nat - type ResponseBody a :: Type - - responseRender :: AcceptHeader -> ResponseType a -> Maybe (InternalResponse (ResponseBody a)) - responseUnrender :: M.MediaType -> InternalResponse (ResponseBody a) -> UnrenderResult (ResponseType a) - -data SomeResponse = forall a. (IsWaiBody a) => SomeResponse (InternalResponse a) - -addContentType :: forall ct a. (Accept ct) => InternalResponse a -> InternalResponse a -addContentType = addContentType' (contentType (Proxy @ct)) - -addContentType' :: M.MediaType -> InternalResponse a -> InternalResponse a -addContentType' c r = r {headers = (hContentType, M.renderHeader c) <| headers r} - -setEmptyBody :: SomeResponse -> SomeResponse -setEmptyBody (SomeResponse r) = SomeResponse (go r) - where - go :: InternalResponse a -> InternalResponse BSL.ByteString - go InternalResponse {..} = InternalResponse {responseBody = mempty, ..} - -someResponseToWai :: SomeResponse -> Wai.Response -someResponseToWai (SomeResponse r) = responseToWai r - -fromSomeResponse :: (Alternative m, Typeable a) => SomeResponse -> m (InternalResponse a) -fromSomeResponse (SomeResponse InternalResponse {..}) = do - body <- maybe empty pure $ cast responseBody - pure $ - InternalResponse - { responseBody = body, - .. - } - -instance - ( KnownStatus s, - MimeRender ct a, - MimeUnrender ct a - ) => - IsResponse cs (RespondAs (ct :: Type) s desc a) - where - type ResponseStatus (RespondAs ct s desc a) = s - type ResponseBody (RespondAs ct s desc a) = BSL.ByteString - - responseRender _ x = - pure . addContentType @ct $ - InternalResponse - { statusCode = statusVal (Proxy @s), - responseBody = mimeRender (Proxy @ct) x, - headers = mempty - } - - responseUnrender _ output = do - guard (statusCode output == statusVal (Proxy @s)) - either UnrenderError UnrenderSuccess $ - mimeUnrender (Proxy @ct) (responseBody output) - -instance (KnownStatus s) => IsResponse cs (RespondAs '() s desc ()) where - type ResponseStatus (RespondAs '() s desc ()) = s - type ResponseBody (RespondAs '() s desc ()) = () - - responseRender _ _ = - pure $ - InternalResponse - { statusCode = statusVal (Proxy @s), - responseBody = (), - headers = mempty - } - - responseUnrender _ output = - guard (statusCode output == statusVal (Proxy @s)) - -instance - (Accept ct, KnownStatus s) => - IsResponse cs (RespondStreaming s desc framing ct) - where - type ResponseStatus (RespondStreaming s desc framing ct) = s - type ResponseBody (RespondStreaming s desc framing ct) = SourceIO ByteString - responseRender _ x = - pure . addContentType @ct $ - InternalResponse - { statusCode = statusVal (Proxy @s), - responseBody = x - } - - responseUnrender _ resp = do - guard (statusCode resp == statusVal (Proxy @s)) - pure $ responseBody resp - -instance (AllMimeRender cs a, AllMimeUnrender cs a, KnownStatus s) => IsResponse cs (Respond s desc a) where - type ResponseStatus (Respond s desc a) = s - type ResponseBody (Respond s desc a) = BSL.ByteString - - -- Note: here it seems like we are rendering for all possible content types, - -- only to choose the correct one afterwards. However, render results besides the - -- one picked by 'M.mapAcceptMedia' are not evaluated, and therefore nor are the - -- corresponding rendering functions. - responseRender (AcceptHeader acc) x = - M.mapAcceptMedia (map (uncurry mkRenderOutput) (allMimeRender (Proxy @cs) x)) acc - where - mkRenderOutput :: M.MediaType -> BSL.ByteString -> (M.MediaType, Response) - mkRenderOutput c body = - (c,) . addContentType' c $ - InternalResponse - { statusCode = statusVal (Proxy @s), - responseBody = body, - headers = mempty - } - - responseUnrender c output = do - guard (statusCode output == statusVal (Proxy @s)) - let results = allMimeUnrender (Proxy @cs) - case lookup c results of - Nothing -> empty - Just f -> either UnrenderError UnrenderSuccess (f (responseBody output)) instance ( HasAcceptCheck cs, - IsResponseList cs as, + ResponseListRender cs as, AsUnion as r, ReflectMethod method ) => @@ -1314,45 +1159,6 @@ instance where method = reflectMethod (Proxy @method) -instance - ( AsHeaders xs (ResponseType r) a, - ServantHeaders hs xs, - IsResponse cs r - ) => - IsResponse cs (WithHeaders hs a r) - where - type ResponseStatus (WithHeaders hs a r) = ResponseStatus r - type ResponseBody (WithHeaders hs a r) = ResponseBody r - - responseRender acc x = addHeaders <$> responseRender @cs @r acc y - where - (hs, y) = toHeaders @xs x - addHeaders r = - r - { headers = headers r <> Seq.fromList (constructHeaders @hs hs) - } - - responseUnrender c output = do - x <- responseUnrender @cs @r c output - case extractHeaders @hs (headers output) of - Nothing -> UnrenderError "Failed to parse headers" - Just hs -> pure $ fromHeaders @xs (hs, x) - -instance - ( IsResponse cs a, - KnownStatus (ResponseStatus a) - ) => - IsResponseList cs (a ': as) - where - responseListRender acc (Z (I x)) = fmap SomeResponse (responseRender @cs @a acc x) - responseListRender acc (S x) = responseListRender @cs @as acc x - - responseListUnrender c output = - Z . I <$> (responseUnrender @cs @a c =<< fromSomeResponse output) - <|> S <$> responseListUnrender @cs @as c output - - responseListStatuses = statusVal (Proxy @(ResponseStatus a)) : responseListStatuses @cs @as - class HasAcceptCheck cs where acceptCheck' :: Proxy cs -> AcceptHeader -> DelayedIO () @@ -1361,4 +1167,3 @@ instance (AllMime cs) => HasAcceptCheck cs where instance HasAcceptCheck '() where acceptCheck' _ _ = pure () - diff --git a/servant-server/src/Servant/Server/Internal/ResponseRender.hs b/servant-server/src/Servant/Server/Internal/ResponseRender.hs new file mode 100644 index 000000000..3806ede49 --- /dev/null +++ b/servant-server/src/Servant/Server/Internal/ResponseRender.hs @@ -0,0 +1,186 @@ +{-# LANGUAGE EmptyCase #-} + +module Servant.Server.Internal.ResponseRender where + +import Data.ByteString (ByteString) +import Data.Kind (Type) +import Data.Typeable +import GHC.TypeLits +import qualified Data.ByteString.Lazy as BSL +import qualified Network.Wai as Wai +import Network.HTTP.Types (Status, hContentType) +import Data.SOP +import qualified Servant.Types.SourceT as S +import qualified Data.ByteString.Builder as BB +import qualified Data.Sequence as Seq + +import Servant.API.ContentTypes (AcceptHeader (..), AllMimeRender, MimeRender, Accept, allMimeRender, mimeRender, contentType) +import Servant.API.MultiVerb +import Servant.API.Status +import Servant.API.Stream (SourceIO) +import Servant.API.UVerb.Union +import Servant.Types.ResponseList +import qualified Network.HTTP.Media as M +import Data.Foldable (toList) +import Data.Sequence ((<|)) + +class (Typeable a) => IsWaiBody a where + responseToWai :: InternalResponse a -> Wai.Response + +instance IsWaiBody BSL.ByteString where + responseToWai r = + Wai.responseLBS + (statusCode r) + (toList (headers r)) + (responseBody r) + +instance IsWaiBody () where + responseToWai r = + Wai.responseLBS + (statusCode r) + (toList (headers r)) + mempty + +instance IsWaiBody (SourceIO ByteString) where + responseToWai r = + Wai.responseStream + (statusCode r) + (toList (headers r)) + $ \output flush -> do + S.foreach + (const (pure ())) + (\chunk -> output (BB.byteString chunk) *> flush) + (responseBody r) + +data SomeResponse = forall a. (IsWaiBody a) => SomeResponse (InternalResponse a) + + +class ResponseListRender cs as where + responseListRender + :: AcceptHeader + -> Union (ResponseTypes as) + -> Maybe SomeResponse + responseListStatuses :: [Status] + +instance ResponseListRender cs '[] where + responseListRender _ x = case x of {} + responseListStatuses = [] + + +class (IsWaiBody (ResponseBody a)) => ResponseRender cs a where + type ResponseStatus a :: Nat + type ResponseBody a :: Type + responseRender + :: AcceptHeader + -> ResponseType a + -> Maybe (InternalResponse (ResponseBody a)) + +instance + ( ResponseRender cs a, + ResponseListRender cs as, + KnownStatus (ResponseStatus a) + ) => + ResponseListRender cs (a ': as) + where + responseListRender acc (Z (I x)) = fmap SomeResponse (responseRender @cs @a acc x) + responseListRender acc (S x) = responseListRender @cs @as acc x + + responseListStatuses = statusVal (Proxy @(ResponseStatus a)) : responseListStatuses @cs @as + +instance + ( AsHeaders xs (ResponseType r) a, + ServantHeaders hs xs, + ResponseRender cs r + ) => + ResponseRender cs (WithHeaders hs a r) + where + type ResponseStatus (WithHeaders hs a r) = ResponseStatus r + type ResponseBody (WithHeaders hs a r) = ResponseBody r + + responseRender acc x = addHeaders <$> responseRender @cs @r acc y + where + (hs, y) = toHeaders @xs x + addHeaders r = + r + { headers = headers r <> Seq.fromList (constructHeaders @hs hs) + } + +instance + ( KnownStatus s, + MimeRender ct a + ) => + ResponseRender cs (RespondAs (ct :: Type) s desc a) + where + type ResponseStatus (RespondAs ct s desc a) = s + type ResponseBody (RespondAs ct s desc a) = BSL.ByteString + + responseRender _ x = + pure . addContentType @ct $ + InternalResponse + { statusCode = statusVal (Proxy @s), + responseBody = mimeRender (Proxy @ct) x, + headers = mempty + } + +instance (KnownStatus s) => ResponseRender cs (RespondAs '() s desc ()) where + type ResponseStatus (RespondAs '() s desc ()) = s + type ResponseBody (RespondAs '() s desc ()) = () + + responseRender _ _ = + pure $ + InternalResponse + { statusCode = statusVal (Proxy @s), + responseBody = (), + headers = mempty + } + +instance + (Accept ct, KnownStatus s) + => ResponseRender cs (RespondStreaming s desc framing ct) + where + type ResponseStatus (RespondStreaming s desc framing ct) = s + type ResponseBody (RespondStreaming s desc framing ct) = SourceIO ByteString + responseRender _ x = + pure . addContentType @ct $ + InternalResponse + { statusCode = statusVal (Proxy @s), + responseBody = x, + headers = mempty + } + +instance + (AllMimeRender cs a, KnownStatus s) + => ResponseRender cs (Respond s desc a) where + type ResponseStatus (Respond s desc a) = s + type ResponseBody (Respond s desc a) = BSL.ByteString + + -- Note: here it seems like we are rendering for all possible content types, + -- only to choose the correct one afterwards. However, render results besides the + -- one picked by 'M.mapAcceptMedia' are not evaluated, and therefore nor are the + -- corresponding rendering functions. + responseRender (AcceptHeader acc) x = + M.mapAcceptMedia (map (uncurry mkRenderOutput) (allMimeRender (Proxy @cs) x)) acc + where + mkRenderOutput :: M.MediaType -> BSL.ByteString -> (M.MediaType, InternalResponse BSL.ByteString) + mkRenderOutput c body = + (c,) . addContentType' c $ + InternalResponse + { statusCode = statusVal (Proxy @s), + responseBody = body, + headers = mempty + } + +addContentType :: forall ct a. (Accept ct) => InternalResponse a -> InternalResponse a +addContentType = addContentType' (contentType (Proxy @ct)) + +addContentType' :: M.MediaType -> InternalResponse a -> InternalResponse a +addContentType' c r = r {headers = (hContentType, M.renderHeader c) <| headers r} + +setEmptyBody :: SomeResponse -> SomeResponse +setEmptyBody (SomeResponse r) = SomeResponse (go r) + where + go :: InternalResponse a -> InternalResponse BSL.ByteString + go InternalResponse {..} = InternalResponse {responseBody = mempty, ..} + +someResponseToWai :: SomeResponse -> Wai.Response +someResponseToWai (SomeResponse r) = responseToWai r diff --git a/servant/servant.cabal b/servant/servant.cabal index 1bd5e7303..aa5fb89a2 100644 --- a/servant/servant.cabal +++ b/servant/servant.cabal @@ -116,7 +116,9 @@ library Servant.API.WithResource -- Types - exposed-modules: Servant.Types.SourceT + exposed-modules: + Servant.Types.SourceT + Servant.Types.ResponseList -- Test stuff exposed-modules: Servant.Test.ComprehensiveAPI diff --git a/servant/src/Servant/API/MultiVerb.hs b/servant/src/Servant/API/MultiVerb.hs index aee2d855c..66be34a6c 100644 --- a/servant/src/Servant/API/MultiVerb.hs +++ b/servant/src/Servant/API/MultiVerb.hs @@ -23,7 +23,8 @@ module Servant.API.MultiVerb GenericAsUnion (..), ResponseType, ResponseTypes, - UnrenderResult(..) + UnrenderResult(..), + ServantHeaders(..) ) where diff --git a/servant/src/Servant/Types/ResponseList.hs b/servant/src/Servant/Types/ResponseList.hs new file mode 100644 index 000000000..6e7b9af49 --- /dev/null +++ b/servant/src/Servant/Types/ResponseList.hs @@ -0,0 +1,14 @@ +{-# LANGUAGE DeriveTraversable #-} + +module Servant.Types.ResponseList where + +import Network.HTTP.Types (Status, Header) +import Data.Sequence (Seq) +import GHC.Generics (Generic) +import Data.Data (Typeable) + +data InternalResponse a = InternalResponse + { statusCode :: Status + , headers :: Seq Header + , responseBody :: a + } deriving stock (Eq, Show, Generic, Typeable, Functor, Foldable, Traversable)