diff --git a/tests/test_cli.py b/tests/test_cli.py index e3fe49ae8..de715403a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -149,9 +149,12 @@ def test_resolve_source(self): file = hello_path.joinpath("hello.xsd") url = "http://www.xsdata/schema.xsd" - self.assertEqual([file.as_uri()], list(resolve_source(str(file)))) - self.assertEqual([url], list(resolve_source(url))) - self.assertEqual(5, len(list(resolve_source(str(hello_path))))) + self.assertEqual([file.as_uri()], list(resolve_source(str(file), False))) + self.assertEqual([url], list(resolve_source(url, False))) + self.assertEqual(5, len(list(resolve_source(str(hello_path), False)))) def_xml_path = fixtures_dir.joinpath("calculator") - self.assertEqual(3, len(list(resolve_source(str(def_xml_path))))) + self.assertEqual(3, len(list(resolve_source(str(def_xml_path), False)))) + + actual = list(resolve_source(str(fixtures_dir), True)) + self.assertEqual(32, len(actual)) diff --git a/xsdata/cli.py b/xsdata/cli.py index 1f94a8ce6..2d27f6862 100644 --- a/xsdata/cli.py +++ b/xsdata/cli.py @@ -101,6 +101,13 @@ def download(source: str, output: str): @cli.command("generate") @click.argument("source", required=True) +@click.option( + "-r", + "--recursive", + is_flag=True, + default=False, + help="Search files recursively in the source directory", +) @click.option("-c", "--config", default=".xsdata.xml", help="Project configuration") @click.option("-pp", "--print", is_flag=True, default=False, help="Print output") @model_options(GeneratorOutput) @@ -114,6 +121,7 @@ def generate(**kwargs: Any): """ source = kwargs.pop("source") stdout = kwargs.pop("print") + recursive = kwargs.pop("recursive") config_file = Path(kwargs.pop("config")).resolve() params = {k.replace("__", "."): v for k, v in kwargs.items() if v is not None} @@ -121,21 +129,20 @@ def generate(**kwargs: Any): config.output.update(**params) transformer = SchemaTransformer(config=config, print=stdout) - transformer.process(list(resolve_source(source))) + transformer.process(list(resolve_source(source, recursive=recursive))) handler.emit_warnings() -def resolve_source(source: str) -> Iterator[str]: +def resolve_source(source: str, recursive: bool) -> Iterator[str]: if source.find("://") > -1 and not source.startswith("file://"): yield source else: path = Path(source).resolve() + match = "**/*" if recursive else "*" if path.is_dir(): - yield from (x.as_uri() for x in path.glob("*.wsdl")) - yield from (x.as_uri() for x in path.glob("*.xsd")) - yield from (x.as_uri() for x in path.glob("*.xml")) - yield from (x.as_uri() for x in path.glob("*.json")) + for ext in ["wsdl", "xsd", "xml", "json"]: + yield from (x.as_uri() for x in path.glob(f"{match}.{ext}")) else: # is file yield path.as_uri()