From 2aadd70c992d2457d4a70c04d6ea6552a4880b01 Mon Sep 17 00:00:00 2001 From: mu <59917266+4eUeP@users.noreply.github.com> Date: Thu, 14 Sep 2023 17:23:28 +0800 Subject: [PATCH] Support all header versions --- hstream-kafka/src/Kafka/Protocol/Encoding.hs | 21 ++ hstream-kafka/src/Kafka/Protocol/Message.hs | 68 +++- .../src/Kafka/Protocol/Message/Struct.hs | 25 ++ hstream-kafka/src/Kafka/Server.hs | 17 +- script/kafka_gen.py | 304 +++++++++++++----- 5 files changed, 344 insertions(+), 91 deletions(-) diff --git a/hstream-kafka/src/Kafka/Protocol/Encoding.hs b/hstream-kafka/src/Kafka/Protocol/Encoding.hs index 885bfaec7..8cd8a5b02 100644 --- a/hstream-kafka/src/Kafka/Protocol/Encoding.hs +++ b/hstream-kafka/src/Kafka/Protocol/Encoding.hs @@ -7,6 +7,8 @@ module Kafka.Protocol.Encoding ( Serializable (..) + , putEither + , getEither , runGet , runGet' , runPut @@ -108,6 +110,25 @@ instance (Serializable a) => GSerializable (K1 i a) where gget = K1 <$> get gput (K1 x) = put x +-- There is no easy way to support Sum types for Generic instance. +-- +-- So here we give a special case for Either +putEither :: (Serializable a, Serializable b) => Either a b -> Builder +putEither (Left x) = put x +putEither (Right x) = put x +{-# INLINE putEither #-} + +-- There is no way to support Sum types for Generic instance. 
+-- +-- So here we give a special case for Either +getEither + :: (Serializable a, Serializable b) + => Bool -- ^ True for Right, False for Left + -> Parser (Either a b) +getEither True = Right <$> get +getEither False = Left <$> get +{-# INLINE getEither #-} + ------------------------------------------------------------------------------- newtype DecodeError = DecodeError String diff --git a/hstream-kafka/src/Kafka/Protocol/Message.hs b/hstream-kafka/src/Kafka/Protocol/Message.hs index 0bf9913be..2c085f116 100644 --- a/hstream-kafka/src/Kafka/Protocol/Message.hs +++ b/hstream-kafka/src/Kafka/Protocol/Message.hs @@ -1,29 +1,73 @@ +{-# LANGUAGE DuplicateRecordFields #-} + module Kafka.Protocol.Message ( RequestHeader (..) , ResponseHeader (..) + , putResponseHeader + , runPutResponseHeaderLazy + , Unsupported (..) , module Kafka.Protocol.Message.Struct ) where +import qualified Data.ByteString.Lazy as BL import Data.Int -import Data.Text (Text) import GHC.Generics import Kafka.Protocol.Encoding import Kafka.Protocol.Message.Struct --- TODO: Support Optional Tagged Fields +data Unsupported = Unsupported + deriving (Show, Eq, Generic) + +instance Serializable Unsupported + data RequestHeader = RequestHeader - { requestApiKey :: !ApiKey - , requestApiVersion :: !Int16 - , requestCorrelationId :: !Int32 - , requestClientId :: !(Maybe Text) - } deriving (Show, Eq, Generic) + { requestApiKey :: {-# UNPACK #-} !ApiKey + , requestApiVersion :: {-# UNPACK #-} !Int16 + , requestCorrelationId :: {-# UNPACK #-} !Int32 + , requestClientId :: !(Either Unsupported NullableString) + , requesteTaggedFields :: !(Either Unsupported TaggedFields) + } deriving (Show, Eq) + +instance Serializable RequestHeader where + get = do + requestApiKey <- get + requestApiVersion <- get + requestCorrelationId <- get + let (reqHeaderVer, _) = getHeaderVersion requestApiKey requestApiVersion + case reqHeaderVer of + 2 -> do requestClientId <- getEither True + requesteTaggedFields <- getEither True + 
pure RequestHeader{..} + 1 -> do requestClientId <- getEither True + let requesteTaggedFields = Left Unsupported + in pure RequestHeader{..} + 0 -> let requestClientId = Left Unsupported + requesteTaggedFields = Left Unsupported + in pure RequestHeader{..} + v -> error $ "Unknown request header version" <> show v + {-# INLINE get #-} + + put RequestHeader{..} = + put requestApiKey + <> put requestApiVersion + <> put requestCorrelationId + <> putEither requestClientId + <> putEither requesteTaggedFields + {-# INLINE put #-} -instance Serializable RequestHeader +data ResponseHeader = ResponseHeader + { responseCorrelationId :: {-# UNPACK #-} !Int32 + , responseTaggedFields :: !(Either Unsupported TaggedFields) + } deriving (Show, Eq) -newtype ResponseHeader = ResponseHeader - { responseCorrelationId :: Int32 - } deriving (Show, Eq, Generic) +putResponseHeader :: ResponseHeader -> Builder +putResponseHeader ResponseHeader{..} = + put responseCorrelationId + <> putEither responseTaggedFields +{-# INLINE putResponseHeader #-} -instance Serializable ResponseHeader +runPutResponseHeaderLazy :: ResponseHeader -> BL.ByteString +runPutResponseHeaderLazy = toLazyByteString . 
putResponseHeader +{-# INLINE runPutResponseHeaderLazy #-} diff --git a/hstream-kafka/src/Kafka/Protocol/Message/Struct.hs b/hstream-kafka/src/Kafka/Protocol/Message/Struct.hs index f2db3dab3..ad6a48c3e 100644 --- a/hstream-kafka/src/Kafka/Protocol/Message/Struct.hs +++ b/hstream-kafka/src/Kafka/Protocol/Message/Struct.hs @@ -951,3 +951,28 @@ supportedApiVersions = , ApiVersionV0 (ApiKey 19) 0 0 , ApiVersionV0 (ApiKey 20) 0 0 ] + +getHeaderVersion :: ApiKey -> Int16 -> (Int16, Int16) +getHeaderVersion (ApiKey 0) 0 = (1, 0) +getHeaderVersion (ApiKey 0) 1 = (1, 0) +getHeaderVersion (ApiKey 0) 2 = (1, 0) +getHeaderVersion (ApiKey 1) 0 = (1, 0) +getHeaderVersion (ApiKey 2) 0 = (1, 0) +getHeaderVersion (ApiKey 3) 0 = (1, 0) +getHeaderVersion (ApiKey 3) 1 = (1, 0) +getHeaderVersion (ApiKey 8) 0 = (1, 0) +getHeaderVersion (ApiKey 9) 0 = (1, 0) +getHeaderVersion (ApiKey 10) 0 = (1, 0) +getHeaderVersion (ApiKey 11) 0 = (1, 0) +getHeaderVersion (ApiKey 12) 0 = (1, 0) +getHeaderVersion (ApiKey 13) 0 = (1, 0) +getHeaderVersion (ApiKey 14) 0 = (1, 0) +getHeaderVersion (ApiKey 15) 0 = (1, 0) +getHeaderVersion (ApiKey 16) 0 = (1, 0) +getHeaderVersion (ApiKey 18) 0 = (1, 0) +getHeaderVersion (ApiKey 18) 1 = (1, 0) +getHeaderVersion (ApiKey 18) 2 = (1, 0) +getHeaderVersion (ApiKey 19) 0 = (1, 0) +getHeaderVersion (ApiKey 20) 0 = (1, 0) +getHeaderVersion k v = error $ "Unknown " <> show k <> " v" <> show v +{-# INLINE getHeaderVersion #-} diff --git a/hstream-kafka/src/Kafka/Server.hs b/hstream-kafka/src/Kafka/Server.hs index 65a82d360..4cbc2468f 100644 --- a/hstream-kafka/src/Kafka/Server.hs +++ b/hstream-kafka/src/Kafka/Server.hs @@ -63,24 +63,29 @@ runServer ServerOptions{..} handlers = runHandler reqBs = do headerResult <- runParser @RequestHeader get reqBs case headerResult of - Done l requestHeader -> do - let ServiceHandler{..} = findHandler requestHeader + Done l RequestHeader{..} -> do + let ServiceHandler{..} = findHandler requestApiKey requestApiVersion case rpcHandler of 
UnaryHandler rpcHandler' -> do req <- runGet l resp <- rpcHandler' RequestContext req let respBs = runPutLazy resp - respHeaderBs = runPutLazy $ ResponseHeader (requestCorrelationId requestHeader) + (_, respHeaderVer) = getHeaderVersion requestApiKey requestApiVersion + respHeaderBs = + case respHeaderVer of + 0 -> runPutResponseHeaderLazy $ ResponseHeader requestCorrelationId (Left Unsupported) + 1 -> runPutResponseHeaderLazy $ ResponseHeader requestCorrelationId (Right EmptyTaggedFields) + _ -> error $ "Unknown response header version" <> show respHeaderVer let len = BSL.length (respHeaderBs <> respBs) lenBs = runPutLazy @Int32 (fromIntegral len) pure $ lenBs <> respHeaderBs <> respBs Fail _ err -> E.throwIO $ DecodeError $ "Fail, " <> err More _ -> E.throwIO $ DecodeError $ "More" - findHandler RequestHeader{..} = do + findHandler (ApiKey key) version = do let m_handler = find (\ServiceHandler{..} -> - rpcMethod == (fromIntegral requestApiKey, requestApiVersion)) handlers - errmsg = "NotImplemented: " <> show requestApiKey <> ":v" <> show requestApiVersion + rpcMethod == (key, version)) handlers + errmsg = "NotImplemented: " <> show key <> ":v" <> show version fromMaybe (error errmsg) m_handler -- from the "network-run" package. 
diff --git a/script/kafka_gen.py b/script/kafka_gen.py index 5107e733a..6caad6838 100755 --- a/script/kafka_gen.py +++ b/script/kafka_gen.py @@ -34,9 +34,8 @@ # ----------------------------------------------------------------------------- - - # Constants + RENAMES = {"Records": "RecordBytes"} TYPE_MAPS = { @@ -57,7 +56,18 @@ "string": "!NullableString", "bytes": "!NullableBytes", "records": "!NullableBytes", - "array": "!(KaArray {})", +} +COMPACT_TYPE_MAPS = { + "string": "!CompactString", + "bytes": "!CompactBytes", + "records": "!CompactBytes", + "array": "!(CompactKaArray {})", +} +COMPACT_NULLABLE_TYPE_MAPS = { + "string": "!CompactNullableString", + "bytes": "!CompactNullableBytes", + "records": "!CompactNullableBytes", + "array": "!(CompactKaArray {})", } GLOBAL_API_VERSION_PATCH = (0, 0) @@ -67,11 +77,27 @@ "Produce": (0, 2), } +# ----------------------------------------------------------------------------- # Variables + + +# Since order is True, the fields order is important +@dataclass(eq=True, frozen=True, order=True) +class ApiVersion: + api_key: int + api_name: str + min_version: int + max_version: int + min_flex_version: int + max_flex_version: int + + DATA_TYPES = [] SUB_DATA_TYPES = [] -# {(api_key, min_version, max_version, api_name)} + +# Set of ApiVersion API_VERSIONS = set() + DATA_TYPE_RENAMES = {} @@ -115,6 +141,7 @@ class HsDataField: name: str ty: str doc: Optional[str] = None + is_tagged: bool = False class HsData: @@ -123,21 +150,47 @@ def __init__( name, fields: List[HsDataField] | int, version, + is_flexible=False, cons=None, doc=None, ): self.name = name - self.fields = fields + self.is_flexible = is_flexible + self._init_fields(fields) self.version = version self.doc = doc self._name = name + f"V{version}" self._cons = cons + f"V{version} " if cons else self._name + def _init_fields(self, fields): + self.fields = [] + self.tagged_fields = [] + + if isinstance(fields, int): + self.fields = fields + else: + for f in fields: + if 
f.is_tagged: + self.tagged_fields.append(f) + else: + self.fields.append(f) + def format(self): if isinstance(self.fields, int): # Maybe unused return f"type {self._name} = {self.name}V{self.fields}" + # TODO: tagged_fields + # + # FIXME: + # + # 1. We assume that tagged_fields is always the last field + # 2. We assume that flexible message always has tagged_fields + if self.tagged_fields or self.is_flexible: + self.fields.append( + HsDataField("taggedFields", "!TaggedFields", is_tagged=True) + ) + if len(self.fields) == 0: data_type = f"data {self._name} = {self._cons}" data_fields = " " @@ -173,7 +226,11 @@ def append_hs_datas(datas: List[HsData], data: HsData): assert data.version != data_.version # Use the first same type is OK - if not same_found and data.fields == data_.fields: + if ( + not same_found + and data.is_flexible == data_.is_flexible # noqa: W503 + and data.fields == data_.fields # noqa: W503 + ): # An int mean use this fileds instead DATA_TYPE_RENAMES[ f"{data.name}V{data.version}" @@ -216,46 +273,78 @@ def lines(): def in_version_range(version, min_version, max_version): - if min_version is None: - raise Exception("min version is None") - if min_version > version: + if min_version is None or min_version > version: return False if max_version is not None and max_version < version: return False return True +# https://github.com/apache/kafka/blob/3.5.1/generator/src/main/java/org/apache/kafka/message/ApiMessageTypeGenerator.java#L329 +def get_header_version(v, api): + resp_version = None + req_version = None + + in_flex = in_version_range(v, api.min_flex_version, api.max_flex_version) + if in_flex: + req_version, resp_version = 2, 1 + else: + req_version, resp_version = 1, 0 + + # Hardcoded Exception: ApiVersionsResponse always includes a v0 header + if api.api_key == 18: + resp_version = 0 + + # Hardcoded Exception: + # + # Version 0 of ControlledShutdownRequest has a non-standard request header + # which does not include clientId. 
Version 1 of ControlledShutdownRequest + # and later use the standard request header. + if api.api_key == 7: + if v == 0: + req_version = 0 + + assert req_version is not None + assert resp_version is not None + return (req_version, resp_version) + + # ----------------------------------------------------------------------------- # Parsers def parse_version(spec): + # e.g. 2 -> (2, 2) match = re.match(r"^(?P<min>\d+)$", spec) if match: min_ = int(match.group("min")) return min_, min_ + # e.g. 2+ -> (2, None) match = re.match(r"^(?P<min>\d+)\+$", spec) if match: min_ = int(match.group("min")) return min_, None + # e.g. 1-2 -> (1, 2) match = re.match(r"^(?P<min>\d+)\-(?P<max>\d+)$", spec) if match: min_ = int(match.group("min")) max_ = int(match.group("max")) return min_, max_ + # invalid return None, None -def parse_field(field, api_version=0): +def parse_field(field, api_version=0, flexible=False): about = field.get("about") # TODO name = RENAMES.get(field["name"], field["name"]) type_type = field["type"] type_name = None type_maps = TYPE_MAPS with_extra_version_suffix = False + is_tagged = False # Versions min_field_version, max_field_version = parse_version( @@ -264,28 +353,48 @@ def parse_field(field, api_version=0): min_tagged_version, max_tagged_version = parse_version( field.get("taggedVersions", "") ) + min_null_version, max_null_version = parse_version( + field.get("nullableVersions", "") + ) + # field has no "versions" if min_field_version is None: - if in_version_range( - api_version, min_tagged_version, max_tagged_version - ): - raise NotImplementedError("taggedVersions") - return + # a "taggedVersions" must be present + assert min_tagged_version is not None - if not in_version_range(api_version, min_field_version, max_field_version): - return - if nullableVersions := field.get("nullableVersions"): - min_null_version, max_null_version = parse_version(nullableVersions) - if in_version_range(api_version, min_null_version, max_null_version): - type_maps = NULLABLE_TYPE_MAPS +
in_api_version = in_version_range( + api_version, min_field_version, max_field_version + ) + in_null_version = in_version_range( + api_version, min_null_version, max_null_version + ) + in_tagged_version = in_version_range( + api_version, min_tagged_version, max_tagged_version + ) - # TODO + if (in_api_version, in_tagged_version) == (False, False): + return + elif (in_api_version, in_tagged_version) == (True, False): + pass + elif (in_api_version, in_tagged_version) == (False, True): + raise NotImplementedError("Only taggedVersions") + elif (in_api_version, in_tagged_version) == (True, True): + is_tagged = True + + # Note that tagged fields can only be added to "flexible" message versions. + if min_tagged_version is not None: + assert flexible + + if (flexible, in_null_version) == (True, True): + type_maps = {**type_maps, **COMPACT_NULLABLE_TYPE_MAPS} + elif (flexible, in_null_version) == (True, False): + type_maps = {**type_maps, **COMPACT_TYPE_MAPS} + elif (flexible, in_null_version) == (False, True): + type_maps = {**type_maps, **NULLABLE_TYPE_MAPS} + + # XXX, maybe unused since flexibleVersions should not appear at field level (?)
if "flexibleVersions" in field: - raise NotImplementedError("flexibleVersions") - - # TODO - if "taggedVersions" in field: - raise NotImplementedError("taggedVersions") + raise NotImplementedError("flexibleVersions in field!") # Error code if name == "ErrorCode" and type_type == "int16": @@ -316,10 +425,15 @@ def parse_field(field, api_version=0): data_sub_fields = list( filter( None, - (parse_field(f, api_version=api_version) for f in sub_fields), + ( + parse_field(f, api_version=api_version, flexible=flexible) + for f in sub_fields + ), ) ) - hs_data = HsData(type_name, data_sub_fields, api_version) + hs_data = HsData( + type_name, data_sub_fields, api_version, is_flexible=flexible + ) append_hs_datas(SUB_DATA_TYPES, hs_data) data_name = lower_fst(name) @@ -328,19 +442,26 @@ def parse_field(field, api_version=0): _type_name = f"{type_name}V{api_version}" type_name = DATA_TYPE_RENAMES.get(_type_name, _type_name) data_type = type_maps[type_type].format(type_name) - data_field = HsDataField(data_name, data_type, doc=about) + data_field = HsDataField( + data_name, + data_type, + doc=about, + is_tagged=is_tagged, + ) return data_field def parse(msg): api_key = msg["apiKey"] min_api_version, max_api_version = parse_version(msg["validVersions"]) - flex_min_version, flex_max_version = parse_version(msg["flexibleVersions"]) + min_flex_version, max_flex_version = parse_version(msg["flexibleVersions"]) name = msg["name"] fields = msg["fields"] api_type = msg["type"] api_name = name.removesuffix(upper_fst(api_type)) + assert api_type in ["request", "response"] + # Get api_version (glo_min_api_version, glo_max_api_version) = GLOBAL_API_VERSION_PATCH if glo_min_api_version > min_api_version: @@ -351,29 +472,41 @@ def parse(msg): min_api_version = api_version_patch[0] max_api_version = api_version_patch[1] - # TODO - if flex_min_version in range( - min_api_version, max_api_version + 1 - ) or flex_max_version in range(min_api_version, max_api_version + 1): - raise 
NotImplementedError("flexibleVersions") - - api = (api_key, min_api_version, max_api_version, api_name) - for k, minv, maxv, _ in API_VERSIONS: - if api_key == k: - assert min_api_version == minv - assert max_api_version == maxv - API_VERSIONS.add(api) + for api in API_VERSIONS: + if api_key == api.api_key: + assert min_api_version == api.min_version + assert max_api_version == api.max_version + + API_VERSIONS.add( + ApiVersion( + api_key=api_key, + api_name=api_name, + min_version=min_api_version, + max_version=max_api_version, + min_flex_version=min_flex_version, + max_flex_version=max_flex_version, + ) + ) for v in range(min_api_version, max_api_version + 1): - fs = list(filter(None, (parse_field(f, api_version=v) for f in fields))) - hs_data = HsData(name, fs, v) + flexible = in_version_range(v, min_flex_version, max_flex_version) + fs = list( + filter( + None, + ( + parse_field(f, api_version=v, flexible=flexible) + for f in fields + ), + ) + ) + hs_data = HsData(name, fs, v, is_flexible=flexible) append_hs_datas(DATA_TYPES, hs_data) # ----------------------------------------------------------------------------- -def gen_header(): +def gen_haskell_header(): return """ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -416,8 +549,10 @@ def gen_api_keys(): """ ) api_keys.append("instance Show ApiKey where") - for k, _, _, n in sorted(API_VERSIONS): - api_keys.append(f' show (ApiKey {k}) = "{n}({k})"') + for api in sorted(API_VERSIONS): + api_keys.append( + f' show (ApiKey {api.api_key}) = "{api.api_name}({api.api_key})"' + ) api_keys.append(' show (ApiKey n) = "Unknown " <> show n') return "\n".join(api_keys) @@ -427,8 +562,8 @@ def gen_supported_api_versions(): result += "supportedApiVersions =\n" result += format_hs_list( ( - f"ApiVersionV0 (ApiKey {k}) {minv} {maxv}" - for (k, minv, maxv, _) in sorted(API_VERSIONS) + f"ApiVersionV0 (ApiKey {api.api_key}) {api.min_version} {api.max_version}" + for api in sorted(API_VERSIONS) ), indent=2, ) 
@@ -436,28 +571,28 @@ def gen_supported_api_versions(): def gen_services(): - max_api_version = max(x[2] for x in API_VERSIONS) + max_api_version = max(x.max_version for x in API_VERSIONS) services = [] srv_methods = lambda v: format_hs_list( ( - '"' + lower_fst(n) + '"' - for (k, _, max_v, n) in sorted(API_VERSIONS) - if v <= max_v + '"' + lower_fst(api.api_name) + '"' + for api in sorted(API_VERSIONS) + if v <= api.max_version ), indent=4, prefix="'", ) method_impl_ins = lambda v: "\n".join( f"""\ -instance HasMethodImpl {srv_name} "{lower_fst(n)}" where - type MethodName {srv_name} "{lower_fst(n)}" = "{lower_fst(n)}" - type MethodKey {srv_name} "{lower_fst(n)}" = {k} - type MethodVersion {srv_name} "{lower_fst(n)}" = {v} - type MethodInput {srv_name} "{lower_fst(n)}" = {n}RequestV{v} - type MethodOutput {srv_name} "{lower_fst(n)}" = {n}ResponseV{v} +instance HasMethodImpl {srv_name} "{lower_fst(api.api_name)}" where + type MethodName {srv_name} "{lower_fst(api.api_name)}" = "{lower_fst(api.api_name)}" + type MethodKey {srv_name} "{lower_fst(api.api_name)}" = {api.api_key} + type MethodVersion {srv_name} "{lower_fst(api.api_name)}" = {v} + type MethodInput {srv_name} "{lower_fst(api.api_name)}" = {api.api_name}RequestV{v} + type MethodOutput {srv_name} "{lower_fst(api.api_name)}" = {api.api_name}ResponseV{v} """ - for (k, _, max_v, n) in sorted(API_VERSIONS) - if v <= max_v + for api in sorted(API_VERSIONS) + if v <= api.max_version ) # for all supported api_version @@ -477,9 +612,23 @@ def gen_services(): return "".join(services).strip() +def gen_api_header_version(): + hs_type = "getHeaderVersion :: ApiKey -> Int16 -> (Int16, Int16)" + hs_impl = "\n".join( + f"getHeaderVersion (ApiKey {api.api_key}) {v} = {get_header_version(v, api)}" + for api in sorted(API_VERSIONS) + for v in range(api.min_version, api.max_version + 1) + ) + hs_math_other = ( + 'getHeaderVersion k v = error $ "Unknown " <> show k <> " v" <> show v' + ) + hs_inline = "{-# INLINE 
getHeaderVersion #-}" + return f"{hs_type}\n{hs_impl}\n{hs_math_other}\n{hs_inline}" + + def gen_struct(): return f""" -{gen_header()} +{gen_haskell_header()} \ {gen_splitter()} \ @@ -498,6 +647,8 @@ def gen_struct(): {gen_api_keys()} {gen_supported_api_versions()} + +{gen_api_header_version()} """.strip() @@ -550,19 +701,26 @@ def cli_get_json(path): action="store_true", help="Don't run stylish-haskell to format the result", ) + parser_run.add_argument( + "--dry-run", + action="store_true", + help="perform a trial run with no outputs", + ) argcomplete.autocomplete(parser) args = parser.parse_args() if args.sub_command == "run": run_parse(args.files) - if not args.no_format: - result = subprocess.run( - "stylish-haskell", - input=gen_struct().encode(), - stdout=subprocess.PIPE, - ) - if result and result.stdout: - print(result.stdout.decode().strip()) - else: - print(gen_struct()) + outputs = gen_struct() + if not args.dry_run: + if not args.no_format: + result = subprocess.run( + "stylish-haskell", + input=outputs.encode(), + stdout=subprocess.PIPE, + ) + if result and result.stdout: + print(result.stdout.decode().strip()) + else: + print(outputs)