Skip to content

Commit

Permalink
Merge branch 'main' into behavior-flags
Browse files Browse the repository at this point in the history
  • Loading branch information
mikealfare authored Aug 24, 2024
2 parents 2d7655a + 35af654 commit 60a004a
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 94 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Security-20240808-154439.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Security
body: Fix arbitrary file write during tarfile extraction
time: 2024-08-08T15:44:39.601346-05:00
custom:
Author: aranke
PR: "182"
4 changes: 2 additions & 2 deletions dbt_common/clients/agate_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from codecs import BOM_UTF8

import agate # type: ignore
import agate
import datetime
import isodate
import json
Expand Down Expand Up @@ -149,7 +149,7 @@ def as_matrix(table):
return [r.values() for r in table.rows.values()]


def from_csv(abspath, text_columns, delimiter=","):
def from_csv(abspath, text_columns, delimiter=",") -> agate.Table:
type_tester = build_type_tester(text_columns=text_columns)
with open(abspath, encoding="utf-8") as fp:
if fp.read(1) != BOM:
Expand Down
14 changes: 8 additions & 6 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from typing_extensions import Protocol

import jinja2 # type: ignore
import jinja2.ext # type: ignore
import jinja2.nativetypes # type: ignore
import jinja2.nodes # type: ignore
import jinja2.parser # type: ignore
import jinja2.sandbox # type: ignore
import jinja2
import jinja2.ext
import jinja2.nativetypes
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox

from dbt_common.tests import test_caching_enabled
from dbt_common.utils.jinja import (
Expand Down Expand Up @@ -124,6 +124,7 @@ def new_context(
"shared or locals parameters."
)

vars = {} if vars is None else vars
parent = ChainMap(vars, self.globals) if self.globals else vars

return self.environment.context_class(self.environment, parent, self.name, self.blocks)
Expand Down Expand Up @@ -544,6 +545,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:

def _get_blocks_hash(text: str, allowed_blocks: Optional[Set[str]], collect_raw_data: bool) -> int:
"""Provides a hash function over the arguments to extract_toplevel_blocks, in order to support caching."""
allowed_blocks = allowed_blocks or set()
allowed_tuple = tuple(sorted(allowed_blocks) or [])
return text.__hash__() + allowed_tuple.__hash__() + collect_raw_data.__hash__()

Expand Down
28 changes: 27 additions & 1 deletion dbt_common/clients/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class FindMatchingParams:
root_path: str
relative_paths_to_search: List[str]
file_pattern: str

# ignore_spec: Optional[PathSpec] = None

def __init__(
Expand Down Expand Up @@ -608,11 +609,36 @@ def rename(from_path: str, to_path: str, force: bool = False) -> None:
shutil.move(from_path, to_path)


def safe_extract(tarball: tarfile.TarFile, path: str = ".") -> None:
"""
Fix for CWE-22: Improper Limitation of a Pathname to a Restricted Directory ('Path Traversal')
Solution copied from https://github.com/mindsdb/mindsdb/blob/main/mindsdb/utilities/fs.py
"""

def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

# for py >= 3.12
if hasattr(tarball, "data_filter"):
tarball.extractall(path, filter="data")
else:
members = tarball.getmembers()
for member in members:
member_path = os.path.join(path, member.name)
if not _is_within_directory(path, member_path):
raise tarfile.OutsideDestinationError(member, path)

tarball.extractall(path, members=members)


def untar_package(tar_path: str, dest_dir: str, rename_to: Optional[str] = None) -> None:
tar_path = convert_path(tar_path)
tar_dir_name = None
with tarfile.open(tar_path, "r:gz") as tarball:
tarball.extractall(dest_dir)
safe_extract(tarball, dest_dir)
tar_dir_name = os.path.commonprefix(tarball.getnames())
if rename_to:
downloaded_path = os.path.join(dest_dir, tar_dir_name)
Expand Down
11 changes: 8 additions & 3 deletions dbt_common/contracts/util.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import dataclasses
from typing import Any
from typing import Any, TypeVar

_R = TypeVar("_R", bound="Replaceable")


# TODO: remove from dbt_common.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs: Any):
def replace(self: _R, **kwargs: Any) -> _R:
return dataclasses.replace(self, **kwargs) # type: ignore


_M = TypeVar("_M", bound="Mergeable")


class Mergeable(Replaceable):
def merged(self, *args):
def merged(self: _M, *args: Any) -> _M:
"""Perform a shallow merge, where the last non-None write wins. This is
intended to merge dataclasses that are a collection of optional values.
"""
Expand Down
7 changes: 3 additions & 4 deletions dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple
from typing import Any, ClassVar, Dict, get_type_hints, List, Optional, Tuple, Union
import re
import jsonschema
from dataclasses import fields, Field
from enum import Enum
from datetime import datetime
from dateutil.parser import parse

# type: ignore
from mashumaro.config import (
TO_DICT_ADD_OMIT_NONE_FLAG,
ADD_SERIALIZATION_CONTEXT,
Expand All @@ -33,8 +32,8 @@ def serialize(self, value: datetime) -> str:
out += "Z"
return out

def deserialize(self, value) -> datetime:
return value if isinstance(value, datetime) else parse(cast(str, value))
def deserialize(self, value: Union[datetime, str]) -> datetime:
return value if isinstance(value, datetime) else parse(value)


class dbtMashConfig(MashBaseConfig):
Expand Down
4 changes: 2 additions & 2 deletions dbt_common/exceptions/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class WorkingDirectoryError(CommandError):
def __init__(self, cwd: str, cmd: List[str], msg: str) -> None:
super().__init__(cwd, cmd, msg)

def __str__(self):
def __str__(self, prefix: str = "! ") -> str:
return f'{self.msg}: "{self.cwd}"'


Expand All @@ -46,5 +46,5 @@ def __init__(
self.stderr = scrub_secrets(stderr.decode("utf-8"), env_secrets())
self.args = (cwd, self.cmd, returncode, self.stdout, self.stderr, msg)

def __str__(self):
def __str__(self, prefix: str = "! ") -> str:
return f"{self.msg} running: {self.cmd}"
4 changes: 2 additions & 2 deletions dbt_common/semver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
import re
from typing import List, Iterable
from typing import Iterable, List, Union

import dbt_common.exceptions.base
from dbt_common.exceptions import VersionsNotCompatibleError
Expand Down Expand Up @@ -378,7 +378,7 @@ def is_exact(self) -> bool:
return False


def reduce_versions(*args):
def reduce_versions(*args: Union[VersionSpecifier, VersionRange, str]) -> VersionRange:
version_specifiers = []

for version in args:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_connection_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from dbt_common.utils.connection import connection_exception_retry


def no_retry_fn():
def no_retry_fn() -> str:
return "success"


class TestNoRetries:
def test_no_retry(self):
def test_no_retry(self) -> None:
fn_to_retry = functools.partial(no_retry_fn)
result = connection_exception_retry(fn_to_retry, 3)

Expand Down
36 changes: 19 additions & 17 deletions tests/unit/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
from typing import Any, Dict
from typing import Any, Dict, List

import pytest
from dbt_common.record import Diff

Case = List[Dict[str, Any]]


@pytest.fixture
def current_query():
def current_query() -> Case:
return [
{
"params": {
Expand All @@ -21,7 +23,7 @@ def current_query():


@pytest.fixture
def query_modified_order():
def query_modified_order() -> Case:
return [
{
"params": {
Expand All @@ -36,7 +38,7 @@ def query_modified_order():


@pytest.fixture
def query_modified_value():
def query_modified_value() -> Case:
return [
{
"params": {
Expand All @@ -51,7 +53,7 @@ def query_modified_value():


@pytest.fixture
def current_simple():
def current_simple() -> Case:
return [
{
"params": {
Expand All @@ -65,7 +67,7 @@ def current_simple():


@pytest.fixture
def current_simple_modified():
def current_simple_modified() -> Case:
return [
{
"params": {
Expand All @@ -79,7 +81,7 @@ def current_simple_modified():


@pytest.fixture
def env_record():
def env_record() -> Case:
return [
{
"params": {},
Expand All @@ -94,7 +96,7 @@ def env_record():


@pytest.fixture
def modified_env_record():
def modified_env_record() -> Case:
return [
{
"params": {},
Expand All @@ -108,30 +110,30 @@ def modified_env_record():
]


def test_diff_query_records_no_diff(current_query, query_modified_order):
def test_diff_query_records_no_diff(current_query: Case, query_modified_order: Case) -> None:
# Setup: Create an instance of Diff
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
result = diff_instance.diff_query_records(current_query, query_modified_order)
# the order changed but the diff should be empty
expected_result = {}
expected_result: Dict[str, Any] = {}
assert result == expected_result # Replace expected_result with what you actually expect


def test_diff_query_records_with_diff(current_query, query_modified_value):
def test_diff_query_records_with_diff(current_query: Case, query_modified_value: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
result = diff_instance.diff_query_records(current_query, query_modified_value)
# the values changed this time
expected_result = {
expected_result: Dict[str, Any] = {
"values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}}
}
assert result == expected_result


def test_diff_env_records(env_record, modified_env_record):
def test_diff_env_records(env_record: Case, modified_env_record: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
Expand All @@ -147,17 +149,17 @@ def test_diff_env_records(env_record, modified_env_record):
assert result == expected_result


def test_diff_default_no_diff(current_simple):
def test_diff_default_no_diff(current_simple: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
# use the same list to ensure no diff
result = diff_instance.diff_default(current_simple, current_simple)
expected_result = {}
expected_result: Dict[str, Any] = {}
assert result == expected_result


def test_diff_default_with_diff(current_simple, current_simple_modified):
def test_diff_default_with_diff(current_simple: Case, current_simple_modified: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
Expand All @@ -170,7 +172,7 @@ def test_diff_default_with_diff(current_simple, current_simple_modified):

# Mock out reading the files so we don't have to
class MockFile:
def __init__(self, json_data):
def __init__(self, json_data) -> None:
self.json_data = json_data

def __enter__(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def valid_error_names() -> Set[str]:


class TestWarnOrError:
def test_fires_error(self, valid_error_names: Set[str]):
def test_fires_error(self, valid_error_names: Set[str]) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", valid_error_names=valid_error_names
)
Expand All @@ -49,8 +49,8 @@ def test_fires_warning(
self,
valid_error_names: Set[str],
event_catcher: EventCatcher,
set_event_manager_with_catcher,
):
set_event_manager_with_catcher: None,
) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", exclude=list(valid_error_names), valid_error_names=valid_error_names
)
Expand All @@ -62,8 +62,8 @@ def test_silenced(
self,
valid_error_names: Set[str],
event_catcher: EventCatcher,
set_event_manager_with_catcher,
):
set_event_manager_with_catcher: None,
) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", silence=list(valid_error_names), valid_error_names=valid_error_names
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_incomplete_block_failure(self) -> None:
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={"myblock"})

def test_wrong_end_failure(self):
def test_wrong_end_failure(self) -> None:
body = "{% myblock foo %} {% endotherblock %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"})
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_system_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,18 @@ def test_untar_package_empty(self) -> None:
with self.assertRaises(tarfile.ReadError) as exc:
dbt_common.clients.system.untar_package(named_file.name, self.tempdest)
self.assertEqual("empty file", str(exc.exception))

def test_untar_package_outside_directory(self) -> None:
with NamedTemporaryFile(
prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False
) as named_tar_file:
tar_file_full_path = named_tar_file.name
with NamedTemporaryFile(prefix="a", suffix=".txt", dir=self.tempdir) as file_a:
file_a.write(b"some text in the text file")
relative_file_a = "/../" + os.path.basename(file_a.name)
with tarfile.open(fileobj=named_tar_file, mode="w:gz") as tar:
tar.addfile(tarfile.TarInfo(relative_file_a), open(file_a.name))

assert tarfile.is_tarfile(tar.name)
with self.assertRaises(tarfile.OutsideDestinationError):
dbt_common.clients.system.untar_package(tar_file_full_path, self.tempdest)
Loading

0 comments on commit 60a004a

Please sign in to comment.