Skip to content

Commit

Permalink
Merge pull request #179 from nbuilding/bug-fix
Browse files Browse the repository at this point in the history
Fixes circular imports
  • Loading branch information
SheepTester authored Jun 8, 2021
2 parents c67f89c + b4495b5 commit e981c04
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
10 changes: 1 addition & 9 deletions python/run.n
Original file line number Diff line number Diff line change
@@ -1,9 +1 @@
let i: () -> () -> int = [] -> () -> int {
return [] -> int {
return 3
}
}

let ii: () -> int = i()

print(ii())
let _ = imp "./runner.n"
3 changes: 1 addition & 2 deletions python/runner.n
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
let test:str = 0
let test = "ree"
let _ = imp "./run.n"
34 changes: 26 additions & 8 deletions python/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
syntaxpath = os.path.join(basepath, "syntax.lark")


def parse_file(file_path, base_path):
import_scope = Scope(base_path=base_path, file_path=file_path)
def parse_file(file_path, base_path, parent_imports):
import_scope = Scope(base_path=base_path, file_path=file_path, parent_imports=parent_imports)
native_functions.add_funcs(import_scope)

with open(syntaxpath, "r") as f:
Expand All @@ -70,8 +70,8 @@ def parse_file(file_path, base_path):
return import_scope, tree, file


async def eval_file(file_path, base_path):
import_scope, tree, _ = parse_file(file_path, base_path)
async def eval_file(file_path, base_path, parent_imports):
import_scope, tree, _ = parse_file(file_path, base_path, parent_imports)

import_scope.variables = {
**import_scope.variables,
Expand All @@ -80,8 +80,8 @@ async def eval_file(file_path, base_path):
return import_scope


def type_check_file(file_path, base_path):
import_scope, tree, text_file = parse_file(file_path, base_path)
def type_check_file(file_path, base_path, parent_imports):
import_scope, tree, text_file = parse_file(file_path, base_path, parent_imports)

scope = type_check(tree, import_scope)
import_scope.variables = {**import_scope.variables, **scope.variables}
Expand Down Expand Up @@ -173,6 +173,7 @@ def __init__(
warnings=None,
base_path="",
file_path="",
parent_imports=None
):
self.parent = parent
self.parent_function = parent_function
Expand All @@ -186,6 +187,8 @@ def __init__(
self.base_path = base_path
# The path of the file the Scope is associated with.
self.file_path = file_path
# The other files it has been imported from to prevent circular imports
self.parent_imports = parent_imports if parent_imports is not None else []

def new_scope(self, parent_function=None, inherit_errors=True):
return Scope(
Expand All @@ -195,6 +198,7 @@ def new_scope(self, parent_function=None, inherit_errors=True):
warnings=self.warnings if inherit_errors else [],
base_path=self.base_path,
file_path=self.file_path,
parent_imports=self.parent_imports,
)

def get_variable(self, name, err=True):
Expand Down Expand Up @@ -923,7 +927,7 @@ async def eval_expr(self, expr):
# Support old syntax
rel_file_path = expr.children[0].value + ".n"
file_path = os.path.join(os.path.dirname(self.file_path), rel_file_path)
val = await eval_file(file_path, self.base_path)
val = await eval_file(file_path, self.base_path, self.parent_imports + [os.path.normpath(self.file_path)])
holder = {}
for key in val.variables.keys():
if val.variables[key].public:
Expand Down Expand Up @@ -1623,8 +1627,22 @@ def type_check_expr(self, expr):
# Support old syntax
rel_file_path = expr.children[0].value + ".n"
file_path = os.path.join(os.path.dirname(self.file_path), rel_file_path)
if os.path.normpath(file_path) == os.path.normpath(self.file_path):
self.errors.append(
TypeCheckError(
expr.children[0], "You cannot import the file that is running"
)
)
return None
if os.path.isfile(file_path):
impn, f = type_check_file(file_path, self.base_path)
if os.path.normpath(file_path) in self.parent_imports:
self.errors.append(
TypeCheckError(
expr.children[0], "Circular imports are not allowed"
)
)
return None
impn, f = type_check_file(file_path, self.base_path, self.parent_imports + [os.path.normpath(self.file_path)])
if len(impn.errors) != 0:
self.errors.append(ImportedError(impn.errors[:], f))
if len(impn.warnings) != 0:
Expand Down

0 comments on commit e981c04

Please sign in to comment.