Skip to content

Commit

Permalink
better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Nov 21, 2023
1 parent ec2e0cc commit 3bbbfee
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 11 deletions.
17 changes: 17 additions & 0 deletions src/codemodder/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,20 @@ def true_value(node: cst.Name | cst.SimpleString) -> str | int | bool:
return False
return val
return ""


def extract_targets_of_assignment(
assignment: cst.AnnAssign | cst.Assign | cst.WithItem | cst.NamedExpr,
) -> list[cst.BaseExpression]:
match assignment:
case cst.AnnAssign():
if assignment.target:
return [assignment.target]
case cst.Assign():
return [t.target for t in assignment.targets]
case cst.NamedExpr():
return [assignment.target]
case cst.WithItem():
if assignment.asname:
return [assignment.asname.name]
return []
22 changes: 13 additions & 9 deletions src/core_codemods/secure_flask_session_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.utils.utils import true_value
from codemodder.utils.utils import extract_targets_of_assignment, true_value
from codemodder.codemods.base_visitor import BaseTransformer
from codemodder.change import Change
from codemodder.file_context import FileContext
Expand Down Expand Up @@ -103,8 +103,12 @@ def _store_flask_app(self, original_node) -> None:
flask_app_parent = self.get_metadata(ParentNodeProvider, original_node)
match flask_app_parent:
case cst.AnnAssign() | cst.Assign():
flask_app_attr = flask_app_parent.targets[0].target
self.flask_app_name = flask_app_attr.value
targets = extract_targets_of_assignment(flask_app_parent)
# TODO: handle other assignments ex. l[0] = Flask(...) , a.b = Flask(...)
if targets and matchers.matches(
first_target := targets[0], matchers.Name()
):
self.flask_app_name = first_target.value

# def _remove_config(self, key):
# try:
Expand Down Expand Up @@ -163,18 +167,18 @@ def assign_node_with_secure_config(
return updated_node

def _is_config_update_call(self, original_node: cst.Call):
config = cst.Name(value="config")
app_name = cst.Name(value=self.flask_app_name)
app_config_node = cst.Attribute(value=app_name, attr=config)
config = matchers.Name(value="config")
app_name = matchers.Name(value=self.flask_app_name)
app_config_node = matchers.Attribute(value=app_name, attr=config)
update = cst.Name(value="update")
return matchers.matches(
original_node.func, matchers.Attribute(value=app_config_node, attr=update)
)

def _is_config_subscript(self, original_node: cst.Assign):
config = cst.Name(value="config")
app_name = cst.Name(value=self.flask_app_name)
app_config_node = cst.Attribute(value=app_name, attr=config)
config = matchers.Name(value="config")
app_name = matchers.Name(value=self.flask_app_name)
app_config_node = matchers.Attribute(value=app_name, attr=config)
return matchers.matches(
original_node.targets[0].target, matchers.Subscript(value=app_config_node)
)
Expand Down
36 changes: 34 additions & 2 deletions tests/codemods/test_secure_flask_session_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ def test_from_import(self, tmpdir):
input_code = """\
from flask import Flask
app = flask.Flask(__name__)
app = Flask(__name__)
app.secret_key = "dev"
app.config.update(SESSION_COOKIE_SECURE=False)
"""
expexted_output = """\
from flask import Flask
app = flask.Flask(__name__)
app = Flask(__name__)
app.secret_key = "dev"
app.config.update(SESSION_COOKIE_SECURE=True)
"""
Expand All @@ -88,6 +88,38 @@ def test_import_alias(self, tmpdir):
self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output))
assert len(self.file_context.codemod_changes) == 1

def test_annotated_assign(self, tmpdir):
input_code = """\
import flask
app: flask.Flask = flask.Flask(__name__)
app.secret_key = "dev"
# more code
app.config.update(SESSION_COOKIE_SECURE=False)
"""
expexted_output = """\
import flask
app: flask.Flask = flask.Flask(__name__)
app.secret_key = "dev"
# more code
app.config.update(SESSION_COOKIE_SECURE=True)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expexted_output))
assert len(self.file_context.codemod_changes) == 1

def test_other_assignment_type(self, tmpdir):
input_code = """\
import flask
class AppStore:
pass
store = AppStore()
store.app = flask.Flask(__name__)
store.app.secret_key = "dev"
# more code
store.app.config.update(SESSION_COOKIE_SECURE=False)
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
assert len(self.file_context.codemod_changes) == 0

@pytest.mark.parametrize(
"config_lines,expected_config_lines",
[
Expand Down

0 comments on commit 3bbbfee

Please sign in to comment.