Skip to content

Commit

Permalink
Merge branch 'main' of github.com:KastnerRG/fkeras
Browse files Browse the repository at this point in the history
  • Loading branch information
oliviaweng committed Nov 30, 2024
2 parents 6a624fe + 0204293 commit f472cae
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions fkeras/fmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@


class FModel:
def __init__(self, model, model_param_ber=0):
def __init__(self, model, model_param_ber=0, verbose=0):
self.model = model
self.model_param_ber = model_param_ber
self.verbose = verbose
self._set_layer_bit_ranges()
self._set_model_param_ber()
# self.layer_bit_ranges = {}
Expand Down Expand Up @@ -72,7 +73,8 @@ def _set_layer_bit_ranges(self):
raise NotImplementedError(
f"Injecting faults in {layer.__class__.__name__} layer not yet supported."
)
print(f"[fkeras.fmodel._set_layer_bit_ranges] {self.num_model_param_bits}")
if self.verbose > 0:
print(f"[fkeras.fmodel._set_layer_bit_ranges] {self.num_model_param_bits}")

def explicit_select_model_param_bitflip(self, bits_to_flip):
"""
Expand All @@ -87,9 +89,11 @@ def explicit_select_model_param_bitflip(self, bits_to_flip):
"""
bits_to_flip_per_layer = defaultdict(int)
# num_faults = int(self.num_model_param_bits * self.model_param_ber)
print(
f"[fkeras.fmodel.explicit_select_model_param_bitflip] num_faults = {len(bits_to_flip)}"
)

if self.verbose > 0:
print(
f"[fkeras.fmodel.explicit_select_model_param_bitflip] num_faults = {len(bits_to_flip)}"
)
# bits_to_flip = random.sample(list(range(self.num_model_param_bits)), num_faults)
bits_to_flip.sort()
# print(f"[fkeras.fmodel.explicit_select_model_param_bitflip] {bits_to_flip}")
Expand Down Expand Up @@ -127,7 +131,8 @@ def random_select_model_param_bitflip(self):
"""
bits_to_flip_per_layer = defaultdict(int)
num_faults = int(self.num_model_param_bits * self.model_param_ber)
print(f"num_faults = {num_faults}")
if self.verbose > 0:
print(f"num_faults = {num_faults}")
bits_to_flip = random.sample(list(range(self.num_model_param_bits)), num_faults)
bits_to_flip.sort()
for bit in bits_to_flip:
Expand Down

0 comments on commit f472cae

Please sign in to comment.