update model's compile depth (#676)
kamalrajkannan78 authored Nov 13, 2024
1 parent bb22c1b commit 33d674b
Showing 5 changed files with 12 additions and 12 deletions.
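Every changed file applies the same one-line edit: the test's compile depth moves from an early stage (INIT_COMPILE, or GENERATE_INITIAL_GRAPH in one file) to SPLIT_GRAPH, so compilation presumably runs further and stops after the graph-splitting stage. A minimal sketch of the shared pattern; the test name and the trailing steps are illustrative placeholders, not code from this commit:

import forge

def test_some_model(test_device):
    # The compiler config is process-global, so this setting applies to
    # every compile the test triggers afterwards.
    compiler_cfg = forge.config._get_global_compiler_config()

    # Previously INIT_COMPILE (or GENERATE_INITIAL_GRAPH); now the
    # compile is capped at the graph-splitting stage.
    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

    # ... load the model and run the depth-limited compile, as each
    # hunk below does.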
File 1 of 5
@@ -24,7 +24,7 @@
 def test_blazepose_detector_pytorch(test_device):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load BlazePose Detector
     pose_detector = BlazePose()
@@ -46,7 +46,7 @@ def test_blazepose_detector_pytorch(test_device):
 def test_blazepose_regressor_pytorch(test_device):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load BlazePose Landmark Regressor
     pose_regressor = BlazePoseLandmark()
@@ -61,7 +61,7 @@ def test_blaze_palm_pytorch(test_device):
 
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load BlazePalm Detector
     palm_detector = BlazePalm()
@@ -85,7 +85,7 @@ def test_blaze_hand_pytorch(test_device):
 
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load BlazePalm Detector
     hand_regressor = BlazeHandLandmark()
File 2 of 5
@@ -167,7 +167,7 @@ def test_mobilenetv1_basic(test_device):
 def generate_model_mobilenetv1_imgcls_hf_pytorch(test_device, variant):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()  # load global compiler config object
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Create Forge module from PyTorch model
     preprocessor = download_model(AutoImageProcessor.from_pretrained, variant)
@@ -196,7 +196,7 @@ def test_mobilenetv1_192(test_device):
 def generate_model_mobilenetV1I224_imgcls_hf_pytorch(test_device, variant):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Create Forge module from PyTorch model
     preprocessor = download_model(AutoImageProcessor.from_pretrained, variant)
File 3 of 5
@@ -47,7 +47,7 @@ def test_mobilenetv2_basic(test_device):
 def generate_model_mobilenetV2I96_imgcls_hf_pytorch(test_device, variant):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     preprocessor = download_model(AutoImageProcessor.from_pretrained, variant)
     model = download_model(AutoModelForImageClassification.from_pretrained, variant)
@@ -73,7 +73,7 @@ def test_mobilenetv2_96(test_device):
 def generate_model_mobilenetV2I160_imgcls_hf_pytorch(test_device, variant):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     preprocessor = download_model(AutoImageProcessor.from_pretrained, variant)
     model = download_model(AutoModelForImageClassification.from_pretrained, variant)
@@ -99,7 +99,7 @@ def test_mobilenetv2_160(test_device):
 def generate_model_mobilenetV2I244_imgcls_hf_pytorch(test_device, variant):
     # Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Create Forge module from PyTorch model
     preprocessor = download_model(AutoImageProcessor.from_pretrained, variant)
@@ -169,7 +169,7 @@ def generate_model_mobilenetV2_semseg_hf_pytorch(test_device, variant):
 
     # Configurations
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load model
     framework_model = download_model(MobileNetV2ForSemanticSegmentation.from_pretrained, variant)
File 4 of 5
@@ -20,7 +20,7 @@ def test_tri_basic_2_sematic_segmentation_pytorch(test_device):
 
     # Set PyBuda configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.GENERATE_INITIAL_GRAPH
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Sample Input
     image_w = 800
File 5 of 5
@@ -20,7 +20,7 @@
 def generate_model_yolotinyV3_imgcls_holli_pytorch(test_device, variant):
     # STEP 1: Set Forge configuration parameters
     compiler_cfg = forge.config._get_global_compiler_config()  # load global compiler config object
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     model = Yolov3Tiny(num_classes=80, use_wrong_previous_anchors=True)
     model.load_state_dict(torch.load("weights/yolov3_tiny_coco_01.h5"))
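After the configuration line, each of these tests builds or downloads its model and compiles it, so the capped depth means every run exercises the front-end passes up through graph splitting rather than a full build. A hypothetical end-to-end sketch; the forge.compile entry point and its signature, the stand-in model, and the input shape are all assumptions rather than code from this commit:

import torch
import forge

# Cap compilation at the graph-splitting stage, as the updated tests do.
compiler_cfg = forge.config._get_global_compiler_config()
compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

# Arbitrary stand-in for BlazePose, MobileNet, YOLOv3-tiny, etc.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.ReLU())
sample_input = torch.rand(1, 3, 224, 224)  # hypothetical input shape

# Assumed compile entry point; the real tests wrap this in pytest cases.
compiled = forge.compile(model, sample_inputs=[sample_input])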
