Skip to content

Commit

Permalink
Fixed a couple bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva authored and drdavella committed Nov 29, 2023
1 parent b77acd6 commit f097f4f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ def _parse_dependencies_from_toml(self, toml_data: dict):
# 1. no dependencies
return self._parse_dependencies(toml_data["project"]["dependencies"])

def _parse_py_versions(self, toml_data: dict):
def _parse_py_versions(self, toml_data: dict) -> list:
# todo: handle cases for
# 1. no requires-python
# 2. multiple requires-python such as "">3.5.2"", ">=3.11.1,<3.11.2"
return [toml_data["project"]["requires-python"]]
# 1. multiple requires-python such as "">3.5.2"", ">=3.11.1,<3.11.2"
maybe_project = toml_data.get("project")
maybe_python = maybe_project.get("requires-python") if maybe_project else None
return [maybe_python] if maybe_python else []

def _parse_file(self, file: Path):
data = toml.load(file)
Expand Down
9 changes: 7 additions & 2 deletions src/core_codemods/secure_flask_session_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,20 @@ def call_node_with_secure_configs(
self, original_node: cst.Call, updated_node: cst.Call
) -> cst.Call:
new_args = []
changed = False
for arg in updated_node.args:
if (key := arg.keyword.value) in self.SECURE_SESSION_CONFIGS:
if (
arg.keyword
and (key := arg.keyword.value) in self.SECURE_SESSION_CONFIGS
):
# self._remove_config(key)
if true_value(arg.value) not in self.SECURE_SESSION_CONFIGS[key]: # type: ignore
safe_value = self._get_secure_config_val(key)
arg = arg.with_changes(value=safe_value)
changed = True
new_args.append(arg)

if updated_node.args != new_args:
if changed:
self.report_change(original_node)
return updated_node.with_changes(args=new_args)

Expand Down
14 changes: 14 additions & 0 deletions tests/codemods/test_secure_flask_session_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def test_app_accessed_config_not_called(self, tmpdir):
self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
assert len(self.file_context.codemod_changes) == 0

def test_app_update_no_keyword(self, tmpdir):
input_code = """\
import flask
from flask import Flask
def foo(test_config=None):
app = Flask(__name__)
app.secret_key = "dev"
app.config.update(test_config)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
print(self.file_context.codemod_changes)
assert len(self.file_context.codemod_changes) == 0

def test_from_import(self, tmpdir):
input_code = """\
from flask import Flask
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ def pkg_with_pyproject_toml(tmp_path_factory):
return base_dir


@pytest.fixture(scope="module")
def pkg_with_pyproject_toml_no_python(tmp_path_factory):
base_dir = tmp_path_factory.mktemp("foo")
toml_file = base_dir / "pyproject.toml"
toml = """\
[project]
name = "js_example"
version = "1.1.0"
description = "Demonstrates making AJAX requests to Flask."
readme = "README.rst"
license = {file = "LICENSE.rst"}
maintainers = [{name = "Pallets", email = "[email protected]"}]
dependencies = ["flask"]
"""
toml_file.write_text(toml)
return base_dir


class TestPyprojectTomlParser:
def test_parse(self, pkg_with_pyproject_toml):
parser = PyprojectTomlParser(pkg_with_pyproject_toml)
Expand All @@ -42,6 +60,16 @@ def test_parse(self, pkg_with_pyproject_toml):
assert store.py_versions == [">=3.10.0"]
assert len(store.dependencies) == 6

def test_parse_no_python(self, pkg_with_pyproject_toml_no_python):
parser = PyprojectTomlParser(pkg_with_pyproject_toml_no_python)
found = parser.parse()
assert len(found) == 1
store = found[0]
assert store.type == "pyproject.toml"
assert store.file == str(pkg_with_pyproject_toml_no_python / parser.file_name)
assert store.py_versions == []
assert len(store.dependencies) == 1

def test_parse_no_file(self, pkg_with_pyproject_toml):
parser = PyprojectTomlParser(pkg_with_pyproject_toml / "foo")
found = parser.parse()
Expand Down

0 comments on commit f097f4f

Please sign in to comment.