Skip to content

Commit

Permalink
Tests, docs and bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Oct 23, 2023
1 parent db0a008 commit d394942
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 17 deletions.
23 changes: 23 additions & 0 deletions src/core_codemods/docs/pixee_python_sql-parameterization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
This codemod refactors SQL statements to be parameterized, rather than built by hand.

Without parameterization, developers must remember to escape string inputs using the rules for that column type and database. This usually results in bugs -- and sometimes vulnerabilities. Although it's not clear if this code is exploitable today, this change will make the code more robust in case the conditions which prevent exploitation today ever go away.

Our changes look something like this:

```diff
import sqlite3

name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
- cursor.execute("SELECT * from USERS WHERE name ='" + name + "'")
+ cursor.execute("SELECT * from USERS WHERE name =?", (name, ))
```

If you have feedback on this codemod, [please let us know](mailto:feedback@pixee.ai)!

## F.A.Q.

### Why is this codemod marked as Merge After Cursory Review?

Python has a wealth of database drivers that all use the same interface. Different drivers may require different string tokens used for parameterization, and Python's dynamic typing makes it quite hard, and sometimes impossible, to detect which driver is being used just by looking at the code.
39 changes: 26 additions & 13 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def on_leave(self, original_node, updated_node):
changes_dict[key] = [
*getattr(updated_node, key)
] + value.sequence
print(changes_dict)
case _:
changes_dict[key] = value
return updated_node.with_changes(**changes_dict)
Expand All @@ -82,7 +81,16 @@ class SQLQueryParameterization(BaseCodemod, Codemod):
DESCRIPTION=("Parameterize SQL queries."),
NAME="sql-parameterization",
REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW,
REFERENCES=[],
REFERENCES=[
{
"url": "https://cwe.mitre.org/data/definitions/89.html",
"description": "",
},
{
"url": "https://owasp.org/www-community/attacks/SQL_Injection",
"description": "",
},
],
)
SUMMARY = "Parameterize SQL queries."
CHANGE_DESCRIPTION = ""
Expand Down Expand Up @@ -145,9 +153,10 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:

# TODO research if named parameters are widely supported
# it could solve for the case of existing parameters
tuple_arg = cst.Arg(cst.Tuple(elements=params_elements))
# self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg])
self.changed_nodes[call] = {"args": Append([tuple_arg])}
if params_elements:
tuple_arg = cst.Arg(cst.Tuple(elements=params_elements))
# self.changed_nodes[call] = call.with_changes(args=[*call.args, tuple_arg])
self.changed_nodes[call] = {"args": Append([tuple_arg])}

# made changes
if self.changed_nodes:
Expand All @@ -173,10 +182,12 @@ def _fix_injection(
self.changed_nodes[expr] = cst.parse_expression('""')
# remove quote literal from end
match end:
# TODO test with escaped strings here...
case cst.SimpleString():
current_end = self.changed_nodes.get(end) or end
new_raw_value = current_end.raw_value[1:]
if current_end.raw_value.startswith("\\'"):
new_raw_value = current_end.raw_value[2:]
else:
new_raw_value = current_end.raw_value[1:]
new_value = (
current_end.prefix
+ current_end.quote
Expand All @@ -192,7 +203,10 @@ def _fix_injection(
match start:
case cst.SimpleString():
current_start = self.changed_nodes.get(start) or start
new_raw_value = current_start.raw_value[:-1] + parameter_token
if current_start.raw_value.endswith("\\'"):
new_raw_value = current_start.raw_value[:-2] + parameter_token
else:
new_raw_value = current_start.raw_value[:-1] + parameter_token
new_value = (
current_start.prefix
+ current_start.quote
Expand Down Expand Up @@ -279,6 +293,7 @@ def recurse_Name(self, node: cst.Name) -> list[cst.CSTNode]:

def recurse_Attribute(self, node: cst.Attribute) -> list[cst.CSTNode]:
# TODO attributes may have been assigned, should those be modified?
# research how to detect attribute assigns in libcst
return [node]

def _find_gparent(self, n: cst.CSTNode) -> Optional[cst.CSTNode]:
Expand Down Expand Up @@ -328,16 +343,14 @@ def leave_Module(self, original_node: cst.Module):
self.injection_patterns.append((start, middle, end))
# end may contain the start of another literal, put it back
# should not be a single quote

# TODO think of a better solution here
if self._is_literal_start(end, 0) and self._is_not_a_single_quote(end):
modulo_2 = 0
leaves.append(end)
else:
modulo_2 = 1

# TODO use changed nodes to detect if start has already been modified before
# this can happen if start = end of another expression

def _is_not_a_single_quote(self, expression: cst.CSTNode) -> bool:
match expression:
case cst.SimpleString():
Expand All @@ -358,11 +371,11 @@ def _is_injectable(self, expression: cst.CSTNode) -> bool:
match expression:
case cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString():
return False
case cst.Call(func=cst.Name(value="str"), args=[arg, *_]):
case cst.Call(func=cst.Name(value="str"), args=[cst.Arg(value=arg), *_]):
# TODO
# treat str(encoding = 'utf-8', object=obj)
# ensure this is the built-in
if matchers.matches(arg, literal_number): # type: ignore
if matchers.matches(arg, literal): # type: ignore
return False
case cst.FormattedStringExpression() if matchers.matches(
expression, literal
Expand Down
207 changes: 203 additions & 4 deletions tests/codemods/test_sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,218 @@ def test_name(self):
def test_simple(self, tmpdir):
    """Check the basic case: one quoted concatenation inside ``execute``.

    Both fixtures are dedented and handed to ``run_and_assert``; afterwards
    exactly one entry must be present in ``file_context.codemod_changes``.
    """
    # NOTE(review): each fixture below contains ``from a import name``
    # directly followed by ``name = input()``, plus two queries that differ
    # only in the case of WHERE. This looks like the removed and added lines
    # of a diff shown together -- confirm against the repository before
    # relying on this text.
    input_code = """\
import sqlite3
from a import name
name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS where name ='" + name + "'")
cursor.execute("SELECT * from USERS WHERE name ='" + name + "'")
"""
    expected = """\
import sqlite3
from a import name
name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS where name =?", (name, ))
cursor.execute("SELECT * from USERS WHERE name =?", (name, ))
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
    # A single recorded change is expected for this file.
    assert len(self.file_context.codemod_changes) == 1

def test_multiple(self, tmpdir):
    """Two quoted concatenations in one query each become a ``?`` placeholder.

    The expected output keeps the remaining string pieces concatenated and
    appends a single tuple ``(name, phone, )`` as the parameters argument.
    """
    source_before = dedent(
        """\
import sqlite3
name = input()
phone = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name ='" + name + "' AND phone ='" + phone + "'" )
"""
    )
    source_after = dedent(
        """\
import sqlite3
name = input()
phone = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name =?" + " AND phone =?", (name, phone, ))
"""
    )
    self.run_and_assert(tmpdir, source_before, source_after)
    # Both placeholders belong to the same call, so one change is recorded.
    recorded = self.file_context.codemod_changes
    assert len(recorded) == 1

def test_can_deal_with_multiple_variables(self, tmpdir):
    """A query assembled from several local variables is parameterized.

    The query is split over the assignments ``a``, ``b`` and ``c``; the
    expected output rewrites the literals where the quotes live and appends
    the parameter tuple to the ``execute`` call.
    """
    # The body of ``foo`` is indented inside the fixtures so that the sample
    # is syntactically valid Python (a ``def`` needs an indented suite).
    input_code = """\
import sqlite3
def foo(self, cursor, name, phone):
    a = "SELECT * from USERS "
    b = "WHERE name = '" + name
    c = "' AND phone = '" + phone + "'"
    return cursor.execute(a + b + c)
"""

    expected = """\
import sqlite3
def foo(self, cursor, name, phone):
    a = "SELECT * from USERS "
    b = "WHERE name = ?"
    c = " AND phone = ?"
    return cursor.execute(a + b + c, (name, phone, ))
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
    assert len(self.file_context.codemod_changes) == 1

def test_simple_if(self, tmpdir):
    """A conditional expression between the quotes is moved into the parameters.

    The whole parenthesized ``'Jenny' if True else name`` expression is
    expected to end up as the single element of the parameter tuple.
    """
    source_before = dedent(
        """\
import sqlite3
name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else name) + "'")
"""
    )
    source_after = dedent(
        """\
import sqlite3
name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name =?", (('Jenny' if True else name), ))
"""
    )
    self.run_and_assert(tmpdir, source_before, source_after)
    recorded = self.file_context.codemod_changes
    assert len(recorded) == 1

def test_multiple_escaped_quote(self, tmpdir):
    """Backslash-escaped quotes in single-quoted literals are handled.

    The fixture uses ``\\\\'`` so that the code under test contains ``\\'``;
    the expected output drops those escaped quotes together with the
    concatenated variables and inserts ``?`` placeholders instead.
    """
    source_before = dedent(
        """\
import sqlite3
name = input()
phone = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute('SELECT * from USERS WHERE name =\\'' + name + '\\' AND phone =\\'' + phone + '\\'' )
"""
    )
    source_after = dedent(
        """\
import sqlite3
name = input()
phone = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute('SELECT * from USERS WHERE name =?' + ' AND phone =?', (name, phone, ))
"""
    )
    self.run_and_assert(tmpdir, source_before, source_after)
    recorded = self.file_context.codemod_changes
    assert len(recorded) == 1

# negative tests below

def test_no_sql_keyword(self, tmpdir):
    """A string that does not start with an SQL keyword is left untouched.

    The fixture begins with ``COLLECT`` instead of ``SELECT``; the same text
    is passed as input and as expected output, and no change may be recorded.
    """
    # The body of ``foo`` is indented so the fixture is valid Python; an
    # unparsable sample would not exercise the keyword check at all.
    input_code = """\
import sqlite3
def foo(self, cursor, name, phone):
    a = "COLLECT * from USERS "
    b = "WHERE name = '" + name
    c = "' AND phone = '" + phone + "'"
    return cursor.execute(a + b + c)
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
    assert len(self.file_context.codemod_changes) == 0

def test_multiple_expressions_injection(self, tmpdir):
    """Several expressions between the quotes become one parameter.

    ``name + "_username"`` sits between the opening and closing quote; the
    expected output passes that whole concatenation as a single tuple element.
    """
    source_before = dedent(
        """\
import sqlite3
name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name ='" + name + "_username" + "'")
"""
    )
    source_after = dedent(
        """\
import sqlite3
name = input()
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name =?", (name + "_username", ))
"""
    )
    self.run_and_assert(tmpdir, source_before, source_after)
    recorded = self.file_context.codemod_changes
    assert len(recorded) == 1

def test_wont_parameterize_literals(self, tmpdir):
    """``str(<literal>)`` between the quotes must not be parameterized.

    Input and expected output are the same text and no change may be recorded.
    """
    fixture = """\
import sqlite3
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name ='" + str(1234) + "'")
"""
    untouched = dedent(fixture)
    self.run_and_assert(tmpdir, untouched, untouched)
    recorded = self.file_context.codemod_changes
    assert len(recorded) == 0

def test_wont_parameterize_literals_if(self, tmpdir):
    """A conditional whose branches are both literals must not be parameterized.

    Input and expected output are the same text and no change may be recorded.
    """
    fixture = """\
import sqlite3
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name ='" + ('Jenny' if True else 123) + "'")
"""
    untouched = dedent(fixture)
    self.run_and_assert(tmpdir, untouched, untouched)
    recorded = self.file_context.codemod_changes
    assert len(recorded) == 0

def test_will_ignore_escaped_quote(self, tmpdir):
    """An escaped quote inside a single literal must not trigger a rewrite.

    Input and expected output are the same text and no change may be recorded.
    """
    # The backslash is doubled on purpose: this fixture is a normal (non-raw)
    # string, so a lone ``\'`` would collapse to ``'`` and the code handed to
    # the codemod would contain no escape sequence at all.
    input_code = """\
import sqlite3
connection = sqlite3.connect("my_db.db")
cursor = connection.cursor()
cursor.execute("SELECT * from USERS WHERE name ='Jenny\\'s username'")
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
    assert len(self.file_context.codemod_changes) == 0

def test_already_has_parameters(self, tmpdir):
    """A call that already passes a parameters argument is left untouched.

    ``execute`` here already receives ``(phone,)``; the same text is passed
    as input and as expected output, and no change may be recorded.
    """
    # The body of ``foo`` is indented so the fixture is valid Python.
    input_code = """\
import sqlite3
def foo(self, cursor, name, phone):
    a = "SELECT * from USERS "
    b = "WHERE name = '" + name
    c = "' AND phone = ?"
    return cursor.execute(a + b + c, (phone,))
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
    assert len(self.file_context.codemod_changes) == 0

def test_wont_change_class_attribute(self, tmpdir):
    """A query prefix stored in a class-level assignment is not rewritten.

    The same text is passed as input and as expected output, and no change
    may be recorded.
    """
    # The class body and the method body are indented so the fixture is valid
    # Python and ``query`` really is assigned at class level.
    # NOTE(review): ``foo`` is assumed to be a method of ``A`` (it takes
    # ``self``) -- confirm against the repository.
    input_code = """\
import sqlite3
class A():
    query = "SELECT * from USERS WHERE name ='"
    def foo(self, name, cursor):
        return cursor.execute(query + name + "'")
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
    assert len(self.file_context.codemod_changes) == 0

def test_wont_change_module_variable(self, tmpdir):
    """A query prefix stored in a module-level variable is not rewritten.

    The same text is passed as input and as expected output, and no change
    may be recorded.
    """
    # The body of ``foo`` is indented so the fixture is valid Python.
    input_code = """\
import sqlite3
query = "SELECT * from USERS WHERE name ='"
def foo(name, cursor):
    return cursor.execute(query + name + "'")
"""
    self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
    assert len(self.file_context.codemod_changes) == 0

0 comments on commit d394942

Please sign in to comment.