Merge branch 'main' into vmilosevic/nightly_ci
vmilosevic authored Nov 7, 2024
2 parents fabdd6c + e089fdc commit 58ce1a7
Showing 129 changed files with 2,148 additions and 109 deletions.
72 changes: 72 additions & 0 deletions conftest.py
@@ -2,7 +2,11 @@

# SPDX-License-Identifier: Apache-2.0

import time
import pytest
import psutil
import threading
from loguru import logger
from datetime import datetime


@@ -13,3 +17,71 @@ def record_test_timestamp(record_property):
yield
end_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z")
record_property("end_timestamp", end_timestamp)


@pytest.fixture(autouse=True)
def memory_usage_tracker():
"""
A pytest fixture that tracks memory usage during the execution of a test.
This fixture automatically tracks the memory usage of the process running the tests.
It starts tracking before the test runs, continues tracking in a background thread during the test,
and stops tracking after the test completes. It logs memory usage statistics: the minimum,
maximum, and average readings, and the net change in memory over the test.
The memory usage is measured in megabytes (MB).
Note:
- This fixture is automatically used for all tests due to the `autouse=True` parameter.
- The interval for memory readings can be adjusted by changing the sleep duration in the `track_memory` function.
- Min, max, and avg memory usage are calculated from the periodically recorded RSS readings of the test process.
"""
process = psutil.Process()

# Initialize memory tracking variables
start_mem = process.memory_info().rss / (1024 * 1024) # MB
min_mem = start_mem
max_mem = start_mem
total_mem = start_mem
count = 1

# Start a background thread or loop to collect memory usage over time
tracking = True

def track_memory():
nonlocal min_mem, max_mem, total_mem, count
while tracking:
current_mem = process.memory_info().rss / (1024 * 1024)
min_mem = min(min_mem, current_mem)
max_mem = max(max_mem, current_mem)
total_mem += current_mem
count += 1
time.sleep(0.1) # Adjust the interval as needed

# Start tracking in a background thread
tracker_thread = threading.Thread(target=track_memory)
tracker_thread.start()

# Run the test
yield

# Stop tracking and wait for the thread to finish
tracking = False
tracker_thread.join()

# Calculate end memory and memory usage stats
end_mem = process.memory_info().rss / (1024 * 1024) # MB
min_mem = min(min_mem, end_mem)
max_mem = max(max_mem, end_mem)
total_mem += end_mem
count += 1
avg_mem = total_mem / count

# Log memory usage statistics
logger.info(f"Test memory usage:")
logger.info(f" By test: {end_mem - start_mem:.2f} MB")
logger.info(f" Minimum: {min_mem:.2f} MB")
logger.info(f" Maximum: {max_mem:.2f} MB")
logger.info(f" Average: {avg_mem:.2f} MB")
28 changes: 15 additions & 13 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -543,37 +543,39 @@ class MLIRGenerator
return string_value;
}

/// Initialize lowering handler map
/// Initialize lowering handler map, keep in lexicographical order
void init_lowering_handler_map()
{
lowering_handler_map["abs"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::AbsOp>;
lowering_handler_map["add"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::AddOp>;
lowering_handler_map["cast"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TypecastOp>;
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
lowering_handler_map["exp"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ExpOp>;
lowering_handler_map["greater_equal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::GreaterEqualOp>;
lowering_handler_map["greater"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::GreaterThanOp>;
lowering_handler_map["less"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::LessEqualOp>;
lowering_handler_map["matmul"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MatmulOp>;
lowering_handler_map["max_pool2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MaxPool2dOp>;
lowering_handler_map["maximum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MaximumOp>;
lowering_handler_map["multiply"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MultiplyOp>;
lowering_handler_map["not_equal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::NotEqualOp>;
lowering_handler_map["reciprocal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReciprocalOp>;
lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
lowering_handler_map["reduce_max"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MaxOp>;
lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReluOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
lowering_handler_map["sine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SinOp>;
lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqrtOp>;
lowering_handler_map["squeeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqueezeOp>;
lowering_handler_map["subtract"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SubtractOp>;
lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TransposeOp>;
lowering_handler_map["greater_equal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::GreaterEqualOp>;
lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::UnsqueezeOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
lowering_handler_map["max_pool2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MaxPool2dOp>;
lowering_handler_map["abs"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::AbsOp>;
lowering_handler_map["exp"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ExpOp>;
lowering_handler_map["maximum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MaximumOp>;
lowering_handler_map["less"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::LessEqualOp>;
lowering_handler_map["greater"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::GreaterThanOp>;
lowering_handler_map["not_equal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::NotEqualOp>;
lowering_handler_map["cast"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TypecastOp>;
}
};
} // namespace
3 changes: 3 additions & 0 deletions forge/csrc/passes/mlir_compiler.cpp
@@ -17,6 +17,7 @@
#pragma clang diagnostic pop

// MLIR headers
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/IR/BuiltinOps.h"
#include "utils/logger.hpp"

@@ -50,6 +51,8 @@ runtime::Binary run_mlir_compiler(tt::ForgeGraphModule& module)
mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect>();

mlir::func::registerInlinerExtension(registry);

// Create a context with all registered dialects.
mlir::MLIRContext context(registry);

3 changes: 3 additions & 0 deletions forge/forge/config.py
@@ -184,6 +184,9 @@ class CompilerConfig:
# Number of patterns to match for each module
tvm_module_to_num_patterns: Dict[str, int] = field(default_factory=lambda: dict())

# If enabled, generate a Forge module (as a standalone pytest) for each op that exists in the given module
tvm_generate_op_tests: bool = False

# Enables a transform for conv that directly reads input, such that it goes from stride > 1 to stride = 1
# This usually translates to lower DRAM BW and less math as the input better populates tiles
enable_conv_prestride: bool = True
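A rough usage sketch for the new flag (the import path follows this file's location; how the config object reaches the compile call is an assumption, not something this diff shows):

# Assumed import path based on forge/forge/config.py; adjust to the real package layout.
from forge.config import CompilerConfig

compiler_cfg = CompilerConfig()
# When enabled, codegen emits a standalone pytest-style Forge module for each op
# found in the given module.
compiler_cfg.tvm_generate_op_tests = True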
12 changes: 0 additions & 12 deletions forge/forge/op/eltwise_unary.py
@@ -380,12 +380,6 @@ def Sine(name: str, operandA: Tensor) -> Tensor:
operandA: Tensor
First operand
min: float
Minimum value
max: float
Maximum value
Returns
-------
Tensor
@@ -407,12 +401,6 @@ def Cosine(name: str, operandA: Tensor) -> Tensor:
operandA: Tensor
First operand
min: float
Minimum value
max: float
Maximum value
Returns
-------
Tensor
54 changes: 49 additions & 5 deletions forge/forge/python_codegen.py
@@ -132,6 +132,8 @@ def write_header(self):
self.wl("from loguru import logger")

self.wl("import torch")
self.wl("from forge import Tensor, compile")
self.wl("from forge.op.eval.common import compare_with_golden_pcc, compare_with_golden")
if self.framework == "tensorflow":
self.wl("import tensorflow as tf")
self.wl("from forge.tvm_utils import map_tf_dtype_to_pt")
@@ -257,17 +259,19 @@ def write_forward(self, ops, inputs, outputs):
self.indent = 0
self.wl("")

def write_param_parser(self, param_names, param_file_name):
def write_param_parser(
self, param_names, param_file_name, names_params_file_name=None, named_buffers_file_name=None
):
self.indent = 1

if self.framework == "pytorch":
self.wl(f"def process_framework_parameters(self, model):")
self.wl(f"def process_framework_parameters(self):")
self.indent += 1
self.wl(f"named_parameters = dict(model.state_dict().items())")
self.wl(f"named_parameters = torch.load('{names_params_file_name}')")
if param_file_name is not None:
self.wl(f'serialized_params = torch.load("{param_file_name}")')
self.wl(f"named_parameters.update(serialized_params)")
self.wl("named_buffers = dict(model.named_buffers())")
self.wl(f"named_buffers = torch.load('{named_buffers_file_name}')")
self.wl("named_parameters.update(named_buffers)")

if len(param_names):
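Since process_framework_parameters no longer receives the live model, the generator presumably saves the state dict and buffers to disk beforehand; a hedged sketch of that saving side (file names are hypothetical and the real code lives elsewhere in this commit, not in this hunk):

# Hypothetical saving side implied by the torch.load calls above; it mirrors what the
# old in-place code collected from the model object.
torch.save(dict(model.state_dict().items()), "my_model_named_params.pt")
torch.save(dict(model.named_buffers()), "my_model_named_buffers.pt")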
@@ -949,6 +953,46 @@ def write_param_parser(self, param_names, param_file_name):
else:
assert False, "TODO: Add other framework param parsers"

def write_pytest_function(self, module_name, input_shapes):
"""
Generates a pytest function to test a module with given input shapes.
This function writes a pytest function that:
1. Creates input tensors based on the provided shapes.
2. Initializes the framework model with the specified module name.
3. Processes the framework parameters.
4. Runs the framework model with the created inputs.
5. Compiles the framework model.
6. Runs the compiled model with the same inputs.
7. Asserts that the outputs of the framework model and the compiled model are similar within a specified tolerance.
Args:
module_name (str): The name of the module to be tested.
input_shapes (list): A list of shapes for the input tensors.
"""
self.wl("")
self.wl("")
self.wl("def test_module():")
self.indent += 1
self.wl("inputs = [")
self.indent += 1
for shape in input_shapes:
self.wl(f"Tensor.create_from_torch(torch.rand({shape})),")
self.indent -= 1
self.wl("]")
self.wl("")
self.wl(f"framework_model = {self.class_name}('{module_name}')")
self.wl("framework_model.process_framework_parameters()")
self.wl("fw_out = framework_model(*inputs)")
self.wl("")
self.wl("compiled_model = compile(framework_model, sample_inputs=inputs)")
self.wl("co_out = compiled_model(*inputs)")
self.wl("")
self.wl(
"assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])"
)
self.indent -= 1


class PyTorchWriter(PythonWriter):
incompatible_np_float_types = [
@@ -1207,7 +1251,7 @@ def write_param_parser(self, param_names, param_file_name):
self.indent = 1

if self.framework == "pytorch":
self.wl(f"def process_framework_parameters(self, model):")
self.wl(f"def process_framework_parameters(self):")
self.indent += 1

self.wl("named_parameters = dict(model.named_parameters())")
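For context, here is a hand-written approximation of the tail of a module generated with these changes; the class name MyModel, the module-name argument, and the input shape are placeholders, and the generated class body is elided:

# Imports mirror what write_header now emits; MyModel stands in for the generated class
# defined earlier in the same file.
import torch
from forge import Tensor, compile
from forge.op.eval.common import compare_with_golden_pcc


def test_module():
    inputs = [
        Tensor.create_from_torch(torch.rand((1, 3, 224, 224))),  # one entry per input shape
    ]

    framework_model = MyModel("MyModel")
    framework_model.process_framework_parameters()  # loads tensors from the saved .pt files
    fw_out = framework_model(*inputs)

    compiled_model = compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    # Compiled outputs must match the framework's golden outputs to PCC >= 0.99.
    assert all(
        compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99)
        for fo, co in zip(fw_out, co_out)
    )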
