Commit 49fe73b

Made sure that there is a sensible label to use in the path for saving directions

ojh31 committed Oct 12, 2023
1 parent 3691693
Showing 5 changed files with 46 additions and 17 deletions.
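The thrust of the commit: thread an optional `label` through the dataset constructors so that saved direction paths are named after a stable label (defaulting to `prompt_type.value`) instead of padding the path with an `ALL` position placeholder. A minimal, self-contained sketch of the before/after path scheme; all values below are illustrative, not taken from the repo:

```python
# Minimal sketch of the path change; all values here are illustrative.
train_label = "simple_train"  # e.g. prompt_type.value, the new default
train_pos = None              # no token position selected
train_layer = 7

# Before: the position slot was always filled, defaulting to "ALL".
old_pos = train_pos if train_pos is not None else "ALL"
old_path = f"das_{train_label}_{old_pos}_layer{train_layer}"
print(old_path)  # das_simple_train_ALL_layer7

# After: the label names the run and the position suffix is optional.
pos_str = f"_{train_pos}" if train_pos is not None else ""
new_path = f"das_{train_label}{pos_str}_layer{train_layer}"
print(new_path)  # das_simple_train_layer7
```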
fit_directions.py (4 changes: 3 additions & 1 deletion)
@@ -245,13 +245,15 @@ def get_placeholder(prompt_type: PromptType) -> Optional[str]:
epochs = 1 if "treebank" in train_type.value else 64,
batch_size=BATCH_SIZES[model_name],
d_das=method.get_dimension(),
train_label=train_type.value,
)
print(f"Saving DAS direction to {das_path}")
torch.cuda.empty_cache()
print("Emptied CUDA cache")
else:
trainset = ResidualStreamDataset.get_dataset(
model, device, prompt_type=train_types, scaffold=SCAFFOLD
model, device, prompt_type=train_types, scaffold=SCAFFOLD,
label=train_type.value,
)
testset = ResidualStreamDataset.get_dataset(
model, device, prompt_type=test_type, scaffold=SCAFFOLD
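The explicit `label=train_type.value` matters here because `train_types` can be a list of prompt types: without a shared label, each concatenated sub-dataset would default its label to its own `prompt_type.value`, and the new label-equality assert in `CleanCorruptedDataset.__add__` (added below in utils/prompts.py) would fire. A simplified stand-in illustrating the failure mode; `Dataset` and the prompt-type values are hypothetical, not the repo's API:

```python
# Simplified stand-ins for the classes in this diff; not the real API.
from typing import Optional

class Dataset:
    def __init__(self, prompt_type_value: str, label: Optional[str] = None):
        # mirrors _get_dataset's new default: label falls back to prompt_type.value
        self.label = label if label is not None else prompt_type_value

    def __add__(self, other: "Dataset") -> "Dataset":
        # mirrors the new assert in CleanCorruptedDataset.__add__
        assert self.label == other.label
        return Dataset(self.label, self.label)

# Without an explicit label, mixed prompt types cannot be concatenated:
try:
    Dataset("simple_train") + Dataset("simple_adverb")
except AssertionError:
    print("labels differ -> concatenation rejected")

# Passing one shared label, as fit_directions.py now does, makes it work:
combined = Dataset("simple_train", label="simple") + Dataset("simple_adverb", label="simple")
print(combined.label)  # simple
```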
utils/classification.py (29 changes: 19 additions & 10 deletions)
@@ -86,6 +86,7 @@ def _fit(
method: ClassificationMethod = ClassificationMethod.KMEANS,
):
assert train_data is not None
assert train_layer is not None
if test_data is None:
test_data = train_data
if test_layer is None:
@@ -273,8 +274,12 @@ def _fit_logistic_regression(


def train_classifying_direction(
train_data: ResidualStreamDataset, train_pos: Union[str, None], train_layer: int,
test_data: Union[ResidualStreamDataset, None], test_pos: Union[str, None], test_layer: int,
train_data: ResidualStreamDataset,
train_pos: Union[str, None],
train_layer: int,
test_data: Union[ResidualStreamDataset, None],
test_pos: Union[str, None],
test_layer: int,
method: ClassificationMethod,
**kwargs,
):
@@ -298,28 +303,32 @@ def train_classifying_direction(
warnings.simplefilter("error", ConvergenceWarning) # Turn the warning into an error
try:
train_line, correct, total, accuracy = fitting_method(
train_data, train_pos, train_layer,
test_data, test_pos, test_layer,
train_data,
train_layer,
test_data,
test_layer,
**kwargs,
)
test_line, _, _, _ = fitting_method(
test_data, test_pos, test_layer,
test_data, test_pos, test_layer,
test_data,
test_layer,
test_data,
test_layer,
**kwargs,
)
except ConvergenceWarning:
print(
f"Convergence warning for {method.value}; "
f"train type:{train_data.prompt_type.value}, pos: {train_pos}, layer:{train_layer}, "
f"test type:{test_data.prompt_type.value}, pos: {test_pos}, layer:{test_layer}, "
f"train type:{train_data.prompt_type}, pos: {train_pos}, layer:{train_layer}, "
f"test type:{test_data.prompt_type}, pos: {test_pos}, layer:{test_layer}, "
f"kwargs: {kwargs}\n"
f"train str_labels:{train_data.str_labels}\n"
f"test str_labels:{test_data.str_labels}\n"
)
return
# write line to file
train_pos_str = train_pos if train_pos is not None else "ALL"
array_path = f"{method.value}_{train_data.prompt_type.value}_{train_pos_str}_layer{train_layer}"
train_pos_str = f"_{train_pos}" if train_pos is not None else ""
array_path = f"{method.value}_{train_data.label}{train_pos_str}_layer{train_layer}"
save_array(train_line, array_path, model)

cosine_sim = safe_cosine_sim(
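The net effect on `array_path`: the saved classifier direction is now keyed by `train_data.label` rather than `train_data.prompt_type.value`, and the position segment is appended only when a position is actually set, instead of padding with `ALL`. A self-contained sketch of the new scheme; the argument values (`kmeans`, `ADJ`, layer 6) are illustrative:

```python
from typing import Optional

def array_path(method_value: str, label: str, pos: Optional[str], layer: int) -> str:
    # mirrors the new path construction in train_classifying_direction
    pos_str = f"_{pos}" if pos is not None else ""
    return f"{method_value}_{label}{pos_str}_layer{layer}"

assert array_path("kmeans", "simple_train", "ADJ", 6) == "kmeans_simple_train_ADJ_layer6"
assert array_path("kmeans", "simple_train", None, 6) == "kmeans_simple_train_layer6"
```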
utils/das.py (10 changes: 6 additions & 4 deletions)
@@ -427,6 +427,7 @@ def get_das_dataset(
device: Optional[torch.device] = None,
requires_grad: bool = True,
verbose: bool = False,
label: Optional[str] = None,
):
"""
Wrapper for utils.prompts.get_dataset that returns a dataset in a useful form for DAS
@@ -437,7 +438,7 @@
return DataLoader([]), None, None
clean_corrupt_data = get_dataset(
model, device, prompt_type=prompt_type, scaffold=scaffold,
position=position
position=position, label=label,
)
if max_dataset_size is not None:
clean_corrupt_data = clean_corrupt_data.get_subset(
@@ -492,6 +493,7 @@ def train_das_subspace(
data_requires_grad: bool = False,
verbose: bool = False,
d_das: int = 1,
train_label: Optional[str] = None,
**config_arg,
) -> Tuple[Float[Tensor, "batch d_model"], str]:
"""
@@ -505,7 +507,7 @@
train_type, position=train_pos, layer=train_layer, model=model,
batch_size=batch_size, max_dataset_size=max_dataset_size,
scaffold=scaffold, device=device, requires_grad=data_requires_grad,
verbose=verbose,
verbose=verbose, label=train_label
)
if test_type != train_type or test_pos != train_pos or test_layer != train_layer:
testloader, loss_fn_val, test_position = get_das_dataset(
@@ -539,9 +541,9 @@ def train_das_subspace(
verbose=verbose,
**config,
)
train_pos = train_pos if train_pos is not None else 'ALL'
d_das_str = f'{d_das}d' if d_das > 1 else ''
save_path = f'das{d_das_str}_{train_type.value}_{train_pos}_layer{train_layer}'
train_pos_str = f'_{train_pos}' if train_pos is not None else ''
save_path = f'das{d_das_str}_{train_label}{train_pos_str}_layer{train_layer}'
save_array(
directions.detach().cpu().squeeze(1).numpy(),
save_path,
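Same pattern for the DAS save path, with the dimension prefix retained: `train_label` replaces `train_type.value`, and the position suffix becomes optional rather than defaulting to `ALL`. A self-contained sketch; argument values are illustrative:

```python
from typing import Optional

def das_save_path(d_das: int, train_label: str, train_pos: Optional[str], train_layer: int) -> str:
    # mirrors the new save_path construction in train_das_subspace
    d_das_str = f"{d_das}d" if d_das > 1 else ""
    pos_str = f"_{train_pos}" if train_pos is not None else ""
    return f"das{d_das_str}_{train_label}{pos_str}_layer{train_layer}"

assert das_save_path(1, "treebank_train", None, 0) == "das_treebank_train_layer0"
assert das_save_path(2, "simple_train", "ADJ", 9) == "das2d_simple_train_ADJ_layer9"
```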
utils/prompts.py (14 changes: 13 additions & 1 deletion)
@@ -498,6 +498,7 @@ def __init__(
position: Union[None, str, Int[Tensor, "batch"]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
prompt_type: Optional[PromptType] = None,
label: Optional[str] = None,
):
assert len(clean_tokens) == len(corrupted_tokens)
assert len(clean_tokens) == len(answer_tokens)
@@ -541,11 +542,13 @@
self.prompt_type = prompt_type
self.is_positive = is_positive
self.position = position
self.label = label

def __add__(self, other: "CleanCorruptedDataset"):
assert isinstance(other, CleanCorruptedDataset)
assert self.tokenizer is not None
assert self.tokenizer == other.tokenizer
assert self.label == other.label
seq_len1 = self.clean_tokens.shape[1]
seq_len2 = other.clean_tokens.shape[1]
max_seq_len = max(seq_len1, seq_len2)
@@ -568,6 +571,7 @@ def __add__(self, other: "CleanCorruptedDataset"):
torch.cat([self.position + offset1, other.position + offset2]),
self.tokenizer,
None,
self.label,
)

def get_subset(self, indices: List[int]):
@@ -580,6 +584,7 @@ def get_subset(self, indices: List[int]):
self.position[indices],
self.tokenizer,
self.prompt_type,
self.label,
)

def get_num_pad_tokens(self) -> Int[Tensor, "batch"]:
@@ -868,6 +873,7 @@ def get_dataset(
comparison: Tuple[str, str] = ("positive", "negative"),
scaffold: Optional[ReviewScaffold] = None,
position: Optional[Union[str, List[str], List[None]]] = None,
label: Optional[str] = None,
) -> CleanCorruptedDataset:
if isinstance(prompt_type, PromptType):
assert position is None or isinstance(position, str)
@@ -879,6 +885,7 @@
comparison=comparison,
scaffold=scaffold,
position=position,
label=label,
)
assert len(prompt_type) > 0
if position is None:
@@ -892,6 +899,7 @@
comparison=comparison,
scaffold=scaffold,
position=position[0],
label=label,
)
for pt, pos in zip(prompt_type[1:], position[1:]):
assert isinstance(pt, PromptType)
@@ -903,6 +911,7 @@
comparison=comparison,
scaffold=scaffold,
position=pos,
label=label,
)
return out

@@ -915,8 +924,10 @@ def _get_dataset(
comparison: Tuple[str, str] = ("positive", "negative"),
scaffold: Optional[ReviewScaffold] = None,
position: Optional[str] = None,
label: Optional[str] = None,
) -> CleanCorruptedDataset:
prompt_type = PromptType(prompt_type)
if label is None and isinstance(prompt_type, PromptType):
label = prompt_type.value
if prompt_type in (
PromptType.TREEBANK_TRAIN, PromptType.TREEBANK_TEST, PromptType.TREEBANK_DEV
):
@@ -976,6 +987,7 @@ def _get_dataset(
tokenizer=model.tokenizer,
position=position,
prompt_type=prompt_type,
label=label,
)


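utils/prompts.py is where the default comes from: `_get_dataset` falls back to `prompt_type.value` when no label is given, and the label then rides along through `__add__` and `get_subset`. A sketch of the defaulting rule with a stand-in enum (the real `PromptType` lives in utils/prompts.py and has more members):

```python
from enum import Enum
from typing import Optional

class PromptType(Enum):  # stand-in; the real enum lives in utils/prompts.py
    SIMPLE_TRAIN = "simple_train"

def resolve_label(prompt_type: PromptType, label: Optional[str]) -> str:
    # mirrors the new default in _get_dataset
    if label is None:
        label = prompt_type.value
    return label

assert resolve_label(PromptType.SIMPLE_TRAIN, None) == "simple_train"
assert resolve_label(PromptType.SIMPLE_TRAIN, "custom") == "custom"
```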
utils/residual_stream.py (6 changes: 5 additions & 1 deletion)
@@ -37,6 +37,7 @@ def __init__(
position: Int[Tensor, "batch"],
model: HookedTransformer,
prompt_type: Union[PromptType, List[PromptType]],
label: str,
) -> None:
assert len(prompt_strings) == len(prompt_tokens)
assert len(prompt_strings) == len(is_positive)
@@ -46,6 +47,7 @@
self.position = position
self.model = model
self.prompt_type = prompt_type
self.label = label
label_tensor = self.prompt_tokens[
torch.arange(len(self.prompt_tokens)), self.position
].cpu().detach()
@@ -169,6 +171,7 @@ def get_dataset(
prompt_type: Union[PromptType, List[PromptType]] = PromptType.SIMPLE_TRAIN,
scaffold: Optional[ReviewScaffold] = None,
position: Optional[Union[str, List[str]]] = None,
label: Optional[str] = None,

) -> Union[None, 'ResidualStreamDataset']:
"""
@@ -178,7 +181,7 @@
return None
clean_corrupt_data = get_dataset(
model, device, prompt_type=prompt_type, scaffold=scaffold,
position=position,
position=position, label=label,
)

assert len(clean_corrupt_data.all_prompts) == len(clean_corrupt_data.answer_tokens)
@@ -190,6 +193,7 @@
clean_corrupt_data.position,
model,
prompt_type,
label,
)


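Finally, `ResidualStreamDataset` stores the label forwarded through `get_dataset`, so callers like `train_classifying_direction` can read `train_data.label` when building save paths. A condensed stand-in of the flow; these classes are simplified, not the repo's real constructors:

```python
from typing import Optional

class CleanCorruptedDataset:  # stand-in for utils/prompts.py
    def __init__(self, label: Optional[str]):
        self.label = label

class ResidualStreamDataset:  # stand-in for utils/residual_stream.py
    def __init__(self, label: Optional[str]):
        self.label = label

    @classmethod
    def get_dataset(cls, prompt_type_value: str, label: Optional[str] = None) -> "ResidualStreamDataset":
        # label is forwarded to the underlying dataset, then stored here
        data = CleanCorruptedDataset(label if label is not None else prompt_type_value)
        return cls(data.label)

train_data = ResidualStreamDataset.get_dataset("simple_train", label="simple_train")
print(train_data.label)  # simple_train -> feeds array_path / save_path above
```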
