Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove as_tensor argument of set_tensors_from_ndarray_1d #2615

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions botorch/optim/closures/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]],
parameters: dict[str, Tensor],
as_array: Callable[[Tensor], npt.NDArray] = None, # pyre-ignore [9]
as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor,
get_state: Callable[[], npt.NDArray] = None, # pyre-ignore [9]
set_state: Callable[[npt.NDArray], None] = None, # pyre-ignore [9]
fill_value: float = 0.0,
Expand All @@ -99,14 +98,13 @@ def __init__(
Expected to correspond with the first `len(parameters)` optional
gradient tensors returned by `closure`.
as_array: Callable used to convert tensors to ndarrays.
as_tensor: Callable used to convert ndarrays to tensors.
get_state: Callable that returns the closure's state as an ndarray. When
passed as `None`, defaults to calling `get_tensors_as_ndarray_1d`
on `closure.parameters` while passing `as_array` (if given by the user).
set_state: Callable that takes a 1-dimensional ndarray and sets the
closure's state. When passed as `None`, `set_state` defaults to
calling `set_tensors_from_ndarray_1d` with `closure.parameters` and
a given ndarray while passing `as_tensor`.
a given ndarray.
fill_value: Fill value for parameters whose gradients are None. In most
cases, `fill_value` should either be zero or NaN.
persistent: Boolean specifying whether an ndarray should be retained
Expand All @@ -128,15 +126,12 @@ def __init__(
as_array = partial(as_ndarray, dtype=np_float64)

if set_state is None:
set_state = partial(
set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor
)
set_state = partial(set_tensors_from_ndarray_1d, parameters)

self.closure = closure
self.parameters = parameters

self.as_array = as_ndarray
self.as_tensor = as_tensor
self._get_state = get_state
self._set_state = set_state

Expand Down
7 changes: 5 additions & 2 deletions botorch/optim/utils/numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def get_tensors_as_ndarray_1d(
def set_tensors_from_ndarray_1d(
tensors: Iterator[Tensor] | dict[str, Tensor],
array: npt.NDArray,
as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor,
) -> None:
r"""Sets the values of one more tensors based off of a vector of assignments."""
named_tensors_iter = (
Expand All @@ -125,7 +124,11 @@ def set_tensors_from_ndarray_1d(
try:
size = tnsr.numel()
vals = array[index : index + size] if tnsr.ndim else array[index]
tnsr.copy_(as_tensor(vals).to(tnsr).view(tnsr.shape).to(tnsr))
tnsr.copy_(
torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).view(
tnsr.shape
)
saitcakmak marked this conversation as resolved.
Show resolved Hide resolved
)
index += size
except Exception as e:
raise RuntimeError(
Expand Down
Loading