Skip to content

Commit

Permalink
Clean up + report exception message in call graph tracer
Browse files Browse the repository at this point in the history
  • Loading branch information
Domiii committed Oct 8, 2024
1 parent 6f41f08 commit cba291c
Showing 1 changed file with 58 additions and 46 deletions.
104 changes: 58 additions & 46 deletions prediction_assets/call_graph_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,43 @@
import traceback
from json.decoder import JSONDecodeError
from types import FrameType, TracebackType
from typing import Callable, Dict, List, Optional, Union # noqa: UP035
from typing import Callable, Dict, List, Optional, Union, Any # noqa: UP035

TargetConfig = Dict[str, Union[Optional[str], Optional[int]]]


def parse_json(json_string):
try:
return json.loads(json_string)
except JSONDecodeError as e:
# Get the position of the error
pos = e.pos

# Get the line and column of the error
lineno = json_string.count('\n', 0, pos) + 1
colno = pos - json_string.rfind('\n', 0, pos)
lineno = json_string.count("\n", 0, pos) + 1
colno = pos - json_string.rfind("\n", 0, pos)

# Get the problematic lines (including context)
lines = json_string.splitlines()
context_range = 2 # Number of lines to show before and after the error
start = max(0, lineno - context_range - 1)
end = min(len(lines), lineno + context_range)
context_lines = lines[start:end]

# Create the context string with line numbers
context = ""
for i, line in enumerate(context_lines, start=start+1):
for i, line in enumerate(context_lines, start=start + 1):
if i == lineno:
context += f"{i:4d} > {line}\n"
context += " " + " " * (colno - 1) + "^\n"
else:
context += f"{i:4d} {line}\n"

# Construct and raise a new error with more information
error_msg = f"JSON parsing failed at line {lineno}, column {colno}:\n\n{context.rstrip()}\nError: {str(e)}"
raise ValueError(error_msg) from e


# Parse config.
TRACE_TARGET_CONFIG_STR = os.environ.get("TDD_TRACE_TARGET_CONFIG")
TRACE_TARGET_CONFIG: Optional[TargetConfig] = None
Expand All @@ -55,7 +57,9 @@ def parse_json(json_string):
if "target_file" not in TRACE_TARGET_CONFIG:
raise ValueError("TDD_TRACE_TARGET_CONFIG must provide 'target_file'.")
if "target_function_name" not in TRACE_TARGET_CONFIG:
raise ValueError("TDD_TRACE_TARGET_CONFIG must provide 'target_function_name' if 'target_file' is provided.")
raise ValueError(
"TDD_TRACE_TARGET_CONFIG must provide 'target_function_name' if 'target_file' is provided."
)

# Record parameter values and return values, only if target region is sufficiently scoped.
RECORD_VALUES = not not TRACE_TARGET_CONFIG
Expand Down Expand Up @@ -90,7 +94,6 @@ def __init__(
self.code_context = code_context
self.index = index


@classmethod
def from_frame(cls, frame: FrameType) -> "FrameInfo":
return cls(frame=frame)
Expand All @@ -104,7 +107,6 @@ def from_frame_summary(cls, summary: traceback.FrameSummary) -> "FrameInfo":
return cls(
call_filename=summary.filename,
call_lineno=summary.lineno,

function=summary.name,
code_context=summary.line and [summary.line] or None,
)
Expand All @@ -120,7 +122,6 @@ def get_name(self) -> str:
code = frame.f_code
return code.co_name


def get_locals(self) -> Dict[str, any]:
return self.frame.f_locals

Expand Down Expand Up @@ -154,6 +155,7 @@ def get_relative_filename(filename: str) -> str:
except Exception:
return filename


class BaseNode:
def __init__(self):
self.children: List[BaseNode] = []
Expand All @@ -175,6 +177,7 @@ def __str__(self, level=0, visited=None):
def to_dict(self):
return {"name": self.name, "type": "OmittedNode"}


class CallGraphNode(BaseNode):
is_partial: bool = False

Expand Down Expand Up @@ -231,6 +234,7 @@ def __str__(self, level=0, visited=None):
result += child.__str__(level + 1, visited)
return result


class CallGraph:
def __init__(self):
self.call_stack: List[CallGraphNode] = ContextVar("call_stack", default=[])
Expand All @@ -256,7 +260,7 @@ def should_trace(self, frame: FrameType) -> bool:
return abs_filename.startswith(self.cwd)
else:
return True

def access_call_stack(self) -> List[CallGraphNode]:
call_stack = self.call_stack.get()
res: List[CallGraphNode] = call_stack.copy()
Expand Down Expand Up @@ -291,24 +295,18 @@ def trace_calls(self, event_frame: FrameType, event: str, arg: any) -> Optional[
call_stack.pop()
self.call_stack.set(call_stack)
elif event == "exception":
exc_type, _, _ = arg
exc_type, exc_str, _ = arg
if exc_type is GeneratorExit:
return None
call_stack = self.access_call_stack()
if call_stack:
call_stack[-1].set_exception(exc_type.__name__)
test_node = (
next(
(
node
for node in reversed(call_stack)
if node.name.startswith("test_")
),
None,
)
test_node = next(
(node for node in reversed(call_stack) if node.name.startswith("test_")),
None,
)
if test_node:
self.print_graph_on_exception("EXCEPTION", test_node)
self.print_graph_on_exception("EXCEPTION", test_node, exc_str)
return self.trace_calls
except Exception as err:
if not MUTE_EXCEPTIONS:
Expand All @@ -321,10 +319,13 @@ def find_node(self, target_config: TargetConfig) -> Optional[CallGraphNode]:
stack = [root]
while stack:
node = stack.pop()
if (node.name == target_config.get('target_function_name') and
node.decl_filename == target_config.get('target_file') and
(not target_config.get('decl_lineno') or
node.frame_info.decl_lineno == target_config['decl_lineno'])):
if (
node.name == target_config.get("target_function_name")
and node.decl_filename == target_config.get("target_file")
and (
not target_config.get("decl_lineno") or node.frame_info.decl_lineno == target_config["decl_lineno"]
)
):
return node
stack.extend(reversed(node.children))
return None
Expand Down Expand Up @@ -372,25 +373,36 @@ def create_partial_node(node: CallGraphNode) -> CallGraphNode:

return root

def print_graph_on_exception(self, cause: str, node: BaseNode):
result: str = None
if TRACE_TARGET_CONFIG:
partial_graph = self.get_partial_graph(TRACE_TARGET_CONFIG)
partial_info = f" PARTIAL='{str(TRACE_TARGET_CONFIG)}'"
if partial_graph:
result = str(partial_graph)
def print_graph_on_exception(self, cause: str, node: BaseNode, exception_details: Optional[Any]):
try:
result: str = None
if TRACE_TARGET_CONFIG:
partial_graph = self.get_partial_graph(TRACE_TARGET_CONFIG)
partial_info = f" PARTIAL='{str(TRACE_TARGET_CONFIG)}'"
if partial_graph:
result = str(partial_graph)
else:
# Hackfix: Stringify without values, if we could not target the function.
global RECORD_VALUES
RECORD_VALUES = False
result = (
"(❌ ERROR: Could not find target function. Providing high-level call graph instead. ❌)\n"
+ str(node)
)
RECORD_VALUES = True
else:
# Hackfix: Stringify without values, if we could not target the function.
global RECORD_VALUES
RECORD_VALUES = False
result = "(❌ ERROR: Could not find target function. Providing high-level call graph instead. ❌)\n" + str(node)
RECORD_VALUES = True
else:
partial_info = ""
result = str(node)
print("\n\n" + f"<CALL_GRAPH_ON_EXCEPTION cause='{cause}'{partial_info}>", file=sys.stderr)
print(result, file=sys.stderr)
print("\n</CALL_GRAPH_ON_EXCEPTION>", file=sys.stderr)
partial_info = ""
result = str(node)

print("\n\n" + f"<EXCEPTION_EVENT origin='{cause}'>", file=sys.stderr)
if exception_details:
print("<EXCEPTION_DETAILS>\n" + str(exception_details) + "\n</EXCEPTION_DETAILS>", file=sys.stderr)
print(f"<CALL_GRAPH_ON_EXCEPTION{partial_info}>", file=sys.stderr)
print(result, file=sys.stderr)
print("</CALL_GRAPH_ON_EXCEPTION>", file=sys.stderr)
print("</EXCEPTION_EVENT>", file=sys.stderr)
except Exception as err:
print(f"INTERNAL ERROR when printing EXCEPTION_EVENT: {err}")

if python_version >= (3, 7):

Expand Down

0 comments on commit cba291c

Please sign in to comment.