Skip to content

Commit

Permalink
Fix dialect selector (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
nolanbconaway authored Sep 30, 2024
1 parent 2dfa0f1 commit b7ad7f7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def inject_vars():
"""Inject arbitrary data into all templates."""
return dict(
all_rules=config.VALID_RULES,
all_dialects=config.VALID_DIALECTS,
all_dialects=list(config.VALID_DIALECTS.values()),
sqlfluff_version=config.SQLFLUFF_VERSION,
)

Expand Down
2 changes: 1 addition & 1 deletion src/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

SQLFLUFF_VERSION = sqlfluff.__version__

VALID_DIALECTS = tuple(d.name for d in sqlfluff.list_dialects())
VALID_DIALECTS = {d.label: d.name for d in sqlfluff.list_dialects()}

# dict mapping string rule names to descriptions
VALID_RULES = {r.code: r.description for r in sqlfluff.list_rules()}
22 changes: 19 additions & 3 deletions src/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from flask import Blueprint, redirect, render_template, request, url_for
from sqlfluff.api import fix, lint
from .config import VALID_DIALECTS

bp = Blueprint("routes", __name__)

Expand Down Expand Up @@ -43,10 +44,25 @@ def fluff_results():
sql = sql_decode(request.args["sql"]).strip()
sql = "\n".join(sql.splitlines()) + "\n"

# dialect must be a dialect label for `load_raw_dialect`. VALID_DIALECTS is a
# dictionary of dialect labels to dialect names. If we have a name, we need to
# get the label.
#
# However, the frontend logic runs on dialect names, so we need to convert the
# label back to a name for the frontend.
dialect = request.args["dialect"]
if dialect in VALID_DIALECTS.values():
dialect_name = dialect
dialect_label = next(
label for label, name in VALID_DIALECTS.items() if name == dialect
)
else:
dialect_label = dialect
dialect_name = VALID_DIALECTS[dialect]

try:
linted = lint(sql, dialect=dialect)
fixed_sql = fix(sql, dialect=dialect)
linted = lint(sql, dialect=dialect_label)
fixed_sql = fix(sql, dialect=dialect_label)
except RuntimeError as e:
linted = [
{
Expand All @@ -61,7 +77,7 @@ def fluff_results():
"index.html",
results=True,
sql=sql,
dialect=dialect,
dialect=dialect_name,
lint_errors=linted,
fixed_sql=fixed_sql,
)
19 changes: 16 additions & 3 deletions test/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,28 @@ def test_post_redirect(client):
assert rv.status_code == 302 and "/fluffed?sql" in rv.headers["location"]


def test_results_no_errors(client):
"""Test that the results is good to go when there is no error."""
@pytest.mark.parametrize("dialect", ["sparksql", "Apache Spark SQL"])
def test_results_no_errors(client, dialect):
"""Test that the results is good to go when there is no error.
Parameterized dialect asserts that either the formatted name or label can be used
as the dialect parameter.
"""
sql_encoded = sql_encode("select * from table")
rv = client.get("/fluffed", query_string=f"""dialect=ansi&sql={sql_encoded}""")
rv = client.get("/fluffed", query_string=f"""dialect={dialect}&sql={sql_encoded}""")
html = rv.data.decode().lower()
assert "sqlfluff online" in html
assert "fixed sql" in html
assert "select * from table" in html

# Test that the dialect is correctly selected in the results page.
selected_dialect = (
BeautifulSoup(html, "html.parser")
.find("select", {"id": "sql_dialect"})
.find("option", {"selected": "selected"})
)
assert selected_dialect.text.strip() == "apache spark sql"


def test_results_some_errors(client):
"""Test that the results is good to go with one obvious error."""
Expand Down

0 comments on commit b7ad7f7

Please sign in to comment.