diff --git a/.gitignore b/.gitignore index 9464d4b..ef29e5a 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ Local* *~ *.swp /core +gen gen-* diff --git a/tests/Makefile b/tests/Makefile index 2d3987d..6667360 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -1,17 +1,17 @@ all: gen-cpp gen-python gen-cpp: - cd tests/protos && mkdir -p ../gen-cpp && \ - protoc --cpp_out ../gen-cpp --grpc_out ../gen-cpp \ + cd tests/protos && mkdir -p ../gen/cpp && \ + protoc --cpp_out ../gen/cpp --grpc_out ../gen/cpp \ --plugin=protoc-gen-grpc=/usr/local/bin/grpc_cpp_plugin \ *.proto gen-python: - mkdir -p tests/gen-py && \ + mkdir -p tests/gen/python && \ python3 -m grpc_tools.protoc -I ./tests/protos \ - --python_out=tests/gen-py \ - --grpc_python_out=tests/gen-py \ + --python_out=tests/gen/python \ + --grpc_python_out=tests/gen/python \ ./tests/protos/*.proto clean: - rm -rf ./tests/gen-* + rm -rf ./tests/gen diff --git a/tests/cases/AsyncSchedule001_client.py b/tests/cases/AsyncSchedule001_client.py index 39eab54..78901b0 100644 --- a/tests/cases/AsyncSchedule001_client.py +++ b/tests/cases/AsyncSchedule001_client.py @@ -5,9 +5,9 @@ import timeit DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, ".."))) -sys.path.insert(0, os.path.join(DIR, "gen-py")) +sys.path.insert(0, os.path.join(DIR, "gen/python")) -import AsyncSchedule_pb2 as P +import msg_pb2 as P import AsyncSchedule_pb2_grpc as G @@ -48,6 +48,7 @@ async def reqs(): for _ in range(2): req = P.Request(msg="hi") yield req + _call = stub.BidiStream(reqs()) count = 0 async for r in _call: diff --git a/tests/cases/AsyncSchedule001_server.hs b/tests/cases/AsyncSchedule001_server.hs index 2dd3944..7f9c36a 100644 --- a/tests/cases/AsyncSchedule001_server.hs +++ b/tests/cases/AsyncSchedule001_server.hs @@ -11,6 +11,8 @@ import HsGrpc.Common.Log import HsGrpc.Server import Proto.AsyncSchedule as P import Proto.AsyncSchedule_Fields as P +import Proto.Msg as P +import Proto.Msg_Fields as P main :: IO () main = do @@ -21,7 +23,7 @@ main = do , serverOnStarted = Just onStarted } --gprSetLogVerbosity GprLogSeverityDebug - runServer opts $ handlers + runServer opts handlers onStarted :: IO () onStarted = putStrLn "Server listening on 0.0.0.0:50051" @@ -29,9 +31,9 @@ onStarted = putStrLn "Server listening on 0.0.0.0:50051" handlers :: [ServiceHandler] handlers = -- With using 'shortUnary', the test case should not pass. - [ unary (GRPC :: GRPC P.Service "slowUnary") handleSlowUnary - , unary (GRPC :: GRPC P.Service "depUnary") handleDepUnary - , bidiStream (GRPC :: GRPC P.Service "bidiStream") handleBidiStream + [ unary (GRPC :: GRPC P.AsyncScheduleService "slowUnary") handleSlowUnary + , unary (GRPC :: GRPC P.AsyncScheduleService "depUnary") handleDepUnary + , bidiStream (GRPC :: GRPC P.AsyncScheduleService "bidiStream") handleBidiStream ] handleSlowUnary :: UnaryHandler P.Request P.Reply @@ -43,11 +45,11 @@ handleSlowUnary _ctx _req = do -- NOTE: not thread-safe notifyMVar :: MVar () -notifyMVar = unsafePerformIO $ newEmptyMVar +notifyMVar = unsafePerformIO newEmptyMVar {-# NOINLINE notifyMVar #-} exitedMVar :: MVar () -exitedMVar = unsafePerformIO $ newEmptyMVar +exitedMVar = unsafePerformIO newEmptyMVar {-# NOINLINE exitedMVar #-} handleDepUnary :: UnaryHandler P.Request P.Reply diff --git a/tests/cases/credentials/README.md b/tests/cases/credentials/README.md new file mode 100644 index 0000000..be0009f --- /dev/null +++ b/tests/cases/credentials/README.md @@ -0,0 +1 @@ +Cert from diff --git a/tests/cases/credentials/root.crt b/tests/cases/credentials/root.crt new file mode 100644 index 0000000..0fa644d --- /dev/null +++ b/tests/cases/credentials/root.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWTCCAkGgAwIBAgIJAPOConZMwykwMA0GCSqGSIb3DQEBCwUAMEIxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMQ8wDQYDVQQKDAZHb29nbGUxDTAL +BgNVBAsMBGdSUEMwIBcNMTkwNjI0MjIyMDA3WhgPMjExOTA1MzEyMjIwMDdaMEIx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMQ8wDQYDVQQKDAZHb29n +bGUxDTALBgNVBAsMBGdSUEMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQCwqei3TfyLidnQNDJ2lierMYo229K92DuORni7nSjJQ59Jc3dNMsmqGQJjCD8o +6mTlKM/oCbs27Wpx+OxcOLvT95j2kiDGca1fCvaMdguIod09SWiyMpv/hp0trLv7 +NJIKHznath6rHYX2Ii3fZ1yCPzyQbEPSAA+GNpoNm1v1ZWmWKke9v7vLlS3inNlW +Mt9jepK7DrtbNZnVDjeItnppBSbVYRMxIyNHkepFbqXx5TpkCvl4M4XQZw9bfSxQ +i3WZ3q+T1Tw//OUdPNc+OfMhu0MA0QoMwikskP0NaIC3dbJZ5Ogx0RcnaB4E+9C6 +O/znUEh3WuKVl5HXBF+UwWoFAgMBAAGjUDBOMB0GA1UdDgQWBBRm3JIgzgK4G97J +fbMGatWMZc7V3jAfBgNVHSMEGDAWgBRm3JIgzgK4G97JfbMGatWMZc7V3jAMBgNV +HRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCNiV8x41if094ry2srS0YucpiN +3rTPk08FOLsENTMYai524TGXJti1P6ofGr5KXCL0uxTByHE3fEiMMud2TIY5iHQo +Y4mzDTTcb+Q7yKHwYZMlcp6nO8W+NeY5t+S0JPHhb8deKWepcN2UpXBUYQLw7AiE +l96T9Gi+vC9h/XE5IVwHFQXTxf5UYzXtW1nfapvrOONg/ms41dgmrRKIi+knWfiJ +FdHpHX2sfDAoJtnpEISX+nxRGNVTLY64utXWm4yxaZJshvy2s8zWJgRg7rtwAhTT +Np9E9MnihXLEmDI4Co9XlLPJyZFmqImsbmVuKFeQOCiLAoPJaMI2lbi7fiTo +-----END CERTIFICATE----- diff --git a/tests/cases/test_auth_client.py b/tests/cases/test_auth_client.py new file mode 100644 index 0000000..c39b54f --- /dev/null +++ b/tests/cases/test_auth_client.py @@ -0,0 +1,109 @@ +import os +import asyncio +import grpc +import sys + +DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, ".."))) +sys.path.insert(0, os.path.join(DIR, "gen/python")) + +import msg_pb2 as P +import auth_pb2_grpc as G + +host = "localhost" +port = 50051 +TOKEN = "Basic dXNlcjpwYXNzd2Q=" + +# TODO + + +def _load_credential_from_file(filepath): + real_path = os.path.join(os.path.dirname(__file__), filepath) + with open(real_path, "rb") as f: + return f.read() + + +SERVER_CERTIFICATE = _load_credential_from_file("credentials/localhost.crt") +SERVER_CERTIFICATE_KEY = _load_credential_from_file("credentials/localhost.key") +ROOT_CERTIFICATE = _load_credential_from_file("credentials/root.crt") + + +class BasicAuth(grpc.AuthMetadataPlugin): + def __init__(self, token): + self.token = token + + def __call__( + self, + context: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ) -> None: + callback((("authorization", self.token),), None) + + +def create_secure_channel(addr, token=None) -> grpc.aio.Channel: + # Channel credential will be valid for the entire channel + channel_credential = grpc.ssl_channel_credentials( + root_certificates=ROOT_CERTIFICATE, + # private_key=SERVER_CERTIFICATE_KEY, + # certificate_chain=SERVER_CERTIFICATE, + ) + # Call credential object will be invoked for every single RPC + if token: + call_credentials = grpc.metadata_call_credentials( + BasicAuth(token), name="basic auth" + ) + # Combining channel credentials and call credentials together + composite_credentials = grpc.composite_channel_credentials( + channel_credential, + call_credentials, + ) + channel = grpc.aio.secure_channel(addr, composite_credentials) + else: + channel = grpc.aio.secure_channel(addr, channel_credential) + return channel + + +async def unary(): + channel = grpc.aio.insecure_channel(f"{host}:{port}") + stub = G.AuthServiceStub(channel) + req = P.Request(msg="x") + r = await stub.Unary(req) + assert r.msg == "x" + + +async def secure_unary(): + channel = create_secure_channel(f"{host}:{port}") + stub = G.AuthServiceStub(channel) + req = P.Request(msg="x") + r = await stub.Unary(req) + assert r.msg == "x" + + +async def secure_token_unary(): + channel = create_secure_channel(f"{host}:{port}", token=TOKEN) + stub = G.AuthServiceStub(channel) + req = P.Request(msg="x") + r = await stub.Unary(req) + assert r.msg == "x" + + +async def token_bidi(): + channel = grpc.aio.insecure_channel(f"{host}:{port}") + stub = G.AuthServiceStub(channel) + + async def reqs(): + for _ in range(2): + req = P.Request(msg="hi") + yield req + + _call = stub.BidiStream(reqs(), metadata=[("authorization", TOKEN)]) + count = 0 + async for r in _call: + count += 1 + print("=> ", count) + + +def test(): + # asyncio.run(unary()) + # asyncio.run(secure_unary()) + # asyncio.run(secure_token_unary()) + asyncio.run(token_bidi()) diff --git a/tests/cases/test_auth_server.hs b/tests/cases/test_auth_server.hs new file mode 100644 index 0000000..3575de8 --- /dev/null +++ b/tests/cases/test_auth_server.hs @@ -0,0 +1,84 @@ +module Main (main) where + +import Control.Monad +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import Data.Either +import Data.ProtoLens (defMessage) +import Lens.Micro + +import HsGrpc.Common.Log +import HsGrpc.Server +import HsGrpc.Server.Types +import Proto.Auth as P +import Proto.Auth_Fields as P +import Proto.Msg as P +import Proto.Msg_Fields as P + +main :: IO () +main = do + let tokens = [ "dXNlcjpwYXNzd2Q=" -- echo -n "user:passwd" | base64 + , "dXNlcjE6cGFzc3dkMQ==" -- echo -n "user1:passwd1" | base64 + ] + --sslOptions <- Just <$>readTlsPemFile + -- "tests/cases/credentials/localhost.key" + -- "tests/cases/credentials/localhost.crt" + -- Nothing + -- (Just "tests/cases/credentials/root.crt") + sslOptions <- pure Nothing + + let opts = defaultServerOpts + { serverHost = "127.0.0.1" + , serverPort = 50051 + , serverSslOptions = sslOptions + , serverAuthTokens = tokens + , serverParallelism = 1 + , serverOnStarted = Just onStarted + } + runServer opts handlers + +readTlsPemFile + :: String -> String -> Maybe String + -> IO SslServerCredentialsOptions +readTlsPemFile keyPath certPath caPath = do + key <- BS.readFile keyPath + cert <- BS.readFile certPath + ca <- mapM BS.readFile caPath + let authType = maybe GrpcSslDontRequestClientCertificate + (const GrpcSslRequestAndRequireClientCertificateAndVerify) + caPath + pure $ SslServerCredentialsOptions{ pemKeyCertPairs = [(key, cert)] + , pemRootCerts = ca + , clientAuthType = authType + } + + +onStarted :: IO () +onStarted = putStrLn "Server listening on 0.0.0.0:50051" + + +handlers :: [ServiceHandler] +handlers = + [ unary (GRPC :: GRPC P.AuthService "unary") handleUnary + , clientStream (GRPC :: GRPC P.AuthService "clientStream") handleClientStream + , serverStream (GRPC :: GRPC P.AuthService "serverStream") handleServerStream + , bidiStream (GRPC :: GRPC P.AuthService "bidiStream") handleBidiStream + ] + +handleUnary :: UnaryHandler P.Request P.Reply +handleUnary _ctx req = pure $ defMessage & P.msg .~ (req ^. P.msg) + +handleClientStream :: ClientStreamHandler P.Request P.Reply +handleClientStream = undefined + +handleServerStream :: ServerStreamHandler P.Request P.Reply () +handleServerStream = undefined + +handleBidiStream :: BidiStreamHandler P.Request P.Reply () +handleBidiStream _ctx stream = whileM $ do + m_req <- streamRead stream + case m_req of + Just req -> do + let reply = defMessage & P.msg .~ ("hi, " <> req ^. P.msg) + isRight <$> streamWrite stream (Just reply) + Nothing -> putStrLn "Client closed" >> pure False diff --git a/tests/hs-grpc-tests.cabal b/tests/hs-grpc-tests.cabal index 2a1a62b..e46acde 100644 --- a/tests/hs-grpc-tests.cabal +++ b/tests/hs-grpc-tests.cabal @@ -1,4 +1,4 @@ -cabal-version: 2.4 +cabal-version: 3.4 name: hs-grpc-tests version: 0.1.0.0 synopsis: Tests for hs-grpc @@ -19,6 +19,11 @@ custom-setup , Cabal >=2.4 && <4 , proto-lens-setup ^>=0.4 +-- FIXME: Support to build with template-haskell +flag hsgrpc_enable_asan + default: False + description: Enable AddressSanitizer. + library hs-source-dirs: . build-depends: @@ -28,10 +33,18 @@ library exposed-modules: Proto.AsyncSchedule Proto.AsyncSchedule_Fields + Proto.Auth + Proto.Auth_Fields + Proto.Msg + Proto.Msg_Fields autogen-modules: + Proto.Msg + Proto.Msg_Fields Proto.AsyncSchedule Proto.AsyncSchedule_Fields + Proto.Auth + Proto.Auth_Fields default-language: GHC2021 @@ -52,6 +65,16 @@ common common-exe DataKinds OverloadedStrings + -- XXX: Tricky options to link static archive, see: https://github.com/haskell/cabal/issues/4677 + if (flag(hsgrpc_enable_asan) && os(osx)) + ghc-options: "-optl-Wl,-lasan" + + if (flag(hsgrpc_enable_asan) && !os(osx)) + ghc-options: + -pgml g++ "-optl-Wl,--allow-multiple-definition" + "-optl-Wl,--whole-archive" "-optl-Wl,-Bstatic" "-optl-Wl,-lasan" + "-optl-Wl,-Bdynamic" "-optl-Wl,--no-whole-archive" + ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wpartial-fields -Wredundant-constraints @@ -60,3 +83,7 @@ common common-exe executable AsyncSchedule001_server import: common-exe main-is: AsyncSchedule001_server.hs + +executable test_auth_server + import: common-exe + main-is: test_auth_server.hs diff --git a/tests/protos/AsyncSchedule.proto b/tests/protos/AsyncSchedule.proto index d9d1f26..9881b1c 100644 --- a/tests/protos/AsyncSchedule.proto +++ b/tests/protos/AsyncSchedule.proto @@ -1,17 +1,10 @@ syntax = "proto3"; - package test; -service Service { +import "msg.proto"; + +service AsyncScheduleService { rpc SlowUnary(Request) returns (Reply) {} rpc DepUnary(Request) returns (Reply) {} rpc BidiStream(stream Request) returns (stream Reply) {} } - -message Request { - string msg = 1; -} - -message Reply { - string msg = 1; -} diff --git a/tests/protos/auth.proto b/tests/protos/auth.proto new file mode 100644 index 0000000..69eeb4d --- /dev/null +++ b/tests/protos/auth.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; +package test; + +import "msg.proto"; + +service AuthService { + rpc Unary(Request) returns (Reply) {} + rpc ClientStream(stream Request) returns (Reply) {} + rpc ServerStream(Request) returns (stream Reply) {} + rpc BidiStream(stream Request) returns (stream Reply) {} +} diff --git a/tests/protos/msg.proto b/tests/protos/msg.proto new file mode 100644 index 0000000..14d0970 --- /dev/null +++ b/tests/protos/msg.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; +package test; + +message Request { + string msg = 1; +} + +message Reply { + string msg = 1; +}