Skip to content

Commit

Permalink
Allow use of the built-in ast.unparse if present
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Nov 25, 2024
1 parent e6b5b02 commit b30e549
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 28 deletions.
13 changes: 6 additions & 7 deletions dlt/cli/deploy_command_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from yaml import Dumper
from itertools import chain
from typing import List, Optional, Sequence, Tuple, Any, Dict
from astunparse import unparse

# optional dependencies
import pipdeptree
Expand All @@ -23,7 +22,7 @@
from dlt.common.git import get_origin, get_repo, Repo
from dlt.common.configuration.specs.runtime_configuration import get_default_pipeline_name
from dlt.common.typing import StrAny
from dlt.common.reflection.utils import evaluate_node_literal
from dlt.common.reflection.utils import evaluate_node_literal, ast_unparse
from dlt.common.pipeline import LoadInfo, TPipelineState, get_dlt_repos_dir
from dlt.common.storages import FileStorage
from dlt.common.utils import set_working_dir
Expand Down Expand Up @@ -313,7 +312,7 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio
if f_r_value is None:
fmt.warning(
"The value of `dev_mode` in call to `dlt.pipeline` cannot be"
f" determined from {unparse(f_r_node).strip()}. We assume that you know"
f" determined from {ast_unparse(f_r_node).strip()}. We assume that you know"
" what you are doing :)"
)
if f_r_value is True:
Expand All @@ -331,8 +330,8 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio
raise CliCommandInnerException(
"deploy",
"The value of 'pipelines_dir' argument in call to `dlt_pipeline` cannot be"
f" determined from {unparse(p_d_node).strip()}. Pipeline working dir will"
" be found. Pass it directly with --pipelines-dir option.",
f" determined from {ast_unparse(p_d_node).strip()}. Pipeline working dir"
" will be found. Pass it directly with --pipelines-dir option.",
)

p_n_node = call_args.arguments.get("pipeline_name")
Expand All @@ -342,8 +341,8 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio
raise CliCommandInnerException(
"deploy",
"The value of 'pipeline_name' argument in call to `dlt_pipeline` cannot be"
f" determined from {unparse(p_d_node).strip()}. Pipeline working dir will"
" be found. Pass it directly with --pipeline-name option.",
f" determined from {ast_unparse(p_d_node).strip()}. Pipeline working dir"
" will be found. Pass it directly with --pipeline-name option.",
)
pipelines.append((pipeline_name, pipelines_dir))

Expand Down
5 changes: 2 additions & 3 deletions dlt/cli/source_detection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import ast
import inspect
from astunparse import unparse
from typing import Dict, Tuple, Set, List

from dlt.common.configuration import is_secret_hint
from dlt.common.configuration.specs import BaseConfiguration
from dlt.common.reflection.utils import creates_func_def_name_node
from dlt.common.reflection.utils import creates_func_def_name_node, ast_unparse
from dlt.common.typing import is_optional_type

from dlt.sources import SourceReference
Expand Down Expand Up @@ -65,7 +64,7 @@ def find_source_calls_to_replace(
for calls in visitor.known_sources_resources_calls.values():
for call in calls:
transformed_nodes.append(
(call.func, ast.Name(id=pipeline_name + "_" + unparse(call.func)))
(call.func, ast.Name(id=pipeline_name + "_" + ast_unparse(call.func)))
)

return transformed_nodes
Expand Down
14 changes: 10 additions & 4 deletions dlt/common/reflection/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import ast
import inspect
import astunparse
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable

try:
import astunparse

ast_unparse: Callable[[ast.AST], str] = astunparse.unparse
except ImportError:
ast_unparse = ast.unparse # type: ignore[attr-defined, unused-ignore]

from dlt.common.typing import AnyFun

Expand All @@ -25,7 +31,7 @@ def get_literal_defaults(node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) ->
literal_defaults: Dict[str, str] = {}
for arg, default in zip(reversed(args), reversed(defaults)):
if default:
literal_defaults[str(arg.arg)] = astunparse.unparse(default).strip()
literal_defaults[str(arg.arg)] = ast_unparse(default).strip()

return literal_defaults

Expand Down Expand Up @@ -99,7 +105,7 @@ def rewrite_python_script(
script_lines.append(source_script_lines[last_line][last_offset : node.col_offset])

# replace node value
script_lines.append(astunparse.unparse(t_value).strip())
script_lines.append(ast_unparse(t_value).strip())
last_line = node.end_lineno - 1
last_offset = node.end_col_offset

Expand Down
4 changes: 0 additions & 4 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@
from contextlib import contextmanager

from dlt import version

from dlt.common.json import json

from dlt.common.normalizers.naming.naming import NamingConvention
from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination import AnyDestination
from dlt.common.destination.reference import (
SupportsReadableRelation,
Expand Down
3 changes: 1 addition & 2 deletions dlt/destinations/impl/filesystem/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TPipelineStateDoc,
load_package as current_load_package,
)
from dlt.destinations.sql_client import DBApiCursor, WithSqlClient, SqlClientBase
from dlt.destinations.sql_client import WithSqlClient, SqlClientBase
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
FollowupJobRequest,
Expand All @@ -63,7 +63,6 @@
from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration
from dlt.destinations import path_utils
from dlt.destinations.fs_client import FSClientBase
from dlt.destinations.dataset import ReadableDBAPIDataset
from dlt.destinations.utils import verify_schema_merge_disposition

INIT_FILE_NAME = "init"
Expand Down
9 changes: 4 additions & 5 deletions dlt/reflection/script_visitor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import inspect
import ast
import astunparse
from ast import NodeVisitor
from typing import Any, Dict, List
from dlt.common.reflection.utils import find_outer_func_def

from dlt.common.reflection.utils import find_outer_func_def, ast_unparse

import dlt.reflection.names as n

Expand Down Expand Up @@ -68,9 +67,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
for deco in node.decorator_list:
# decorators can be function calls, attributes or names
if isinstance(deco, (ast.Name, ast.Attribute)):
alias_name = astunparse.unparse(deco).strip()
alias_name = ast_unparse(deco).strip()
elif isinstance(deco, ast.Call):
alias_name = astunparse.unparse(deco.func).strip()
alias_name = ast_unparse(deco.func).strip()
else:
raise ValueError(
self.source_segment(deco), type(deco), "Unknown decorator form"
Expand All @@ -87,7 +86,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
def visit_Call(self, node: ast.Call) -> Any:
if self._curr_pass == 2:
# check if this is a call to any of known functions
alias_name = astunparse.unparse(node.func).strip()
alias_name = ast_unparse(node.func).strip()
fn = self.func_aliases.get(alias_name)
if not fn:
# try a fallback to "run" function that may be called on pipeline or source
Expand Down
5 changes: 2 additions & 3 deletions dlt/sources/sql_database/arrow_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

from dlt.common.configuration import with_config
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.libs.pyarrow import (
row_tuples_to_arrow as _row_tuples_to_arrow,
)


@with_config
Expand All @@ -20,6 +17,8 @@ def row_tuples_to_arrow(
is always the case if run within the pipeline. This will generate arrow schema compatible with the destination.
Otherwise generic capabilities are used
"""
from dlt.common.libs.pyarrow import row_tuples_to_arrow as _row_tuples_to_arrow

return _row_tuples_to_arrow(
rows, caps or DestinationCapabilitiesContext.generic_capabilities(), columns, tz
)

0 comments on commit b30e549

Please sign in to comment.