Skip to content

Commit

Permalink
[lora] Update load option to preload (#2623)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Dec 4, 2024
1 parent c38d659 commit 44b096a
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 37 deletions.
24 changes: 12 additions & 12 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,8 +531,8 @@ def register_adapter(inputs: Input):
adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
adapter_load = inputs.get_as_string(
"load").lower() == "true" if inputs.contains_key("load") else True
adapter_preload = inputs.get_as_string("preload").lower(
) == "true" if inputs.contains_key("preload") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False

Expand All @@ -543,10 +543,10 @@ def register_adapter(inputs: Input):
f"Only local LoRA models are supported. {adapter_path} is not a valid path"
)

if not adapter_load and adapter_pin:
raise ValueError("Can not set load to false and pin to true")
if not adapter_preload and adapter_pin:
raise ValueError("Can not set preload to false and pin to true")

if adapter_load:
if adapter_preload:
loaded = _service.add_lora(adapter_name, adapter_alias,
adapter_path)

Expand Down Expand Up @@ -578,16 +578,16 @@ def update_adapter(inputs: Input):
adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
adapter_load = inputs.get_as_string(
"load").lower() == "true" if inputs.contains_key("load") else True
adapter_preload = inputs.get_as_string("preload").lower(
) == "true" if inputs.contains_key("preload") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False

if adapter_name not in _service.adapter_registry:
raise ValueError(f"Adapter {adapter_alias} not registered.")

try:
if not adapter_load and adapter_pin:
if not adapter_preload and adapter_pin:
raise ValueError("Can not set load to false and pin to true")

old_adapter = _service.adapter_registry[adapter_name]
Expand All @@ -596,10 +596,10 @@ def update_adapter(inputs: Input):
raise NotImplementedError(
f"Updating adapter path is not supported.")

old_adapter_load = old_adapter.get_as_string("load").lower(
) == "true" if old_adapter.contains_key("load") else True
if adapter_load != old_adapter_load:
if adapter_load:
old_adapter_preload = old_adapter.get_as_string("preload").lower(
) == "true" if old_adapter.contains_key("preload") else True
if adapter_preload != old_adapter_preload:
if adapter_preload:
_service.add_lora(adapter_name, adapter_alias, adapter_path)
else:
_service.remove_lora(adapter_name, adapter_alias)
Expand Down
6 changes: 3 additions & 3 deletions serving/docs/adapters_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ This is an extension of the [Management API](management_api.md) and can be acces

* name: The adapter name.
* src: The adapter src. It currently requires a file, but eventually an id or URL can be supported depending on the model handler.
* load (optional): Whether to load the adapter weights, defaults to `true`. If this option is enabled, adapter weights will be loaded in GPU memory during registration.
* pin (optional): Whether to pin the adapter, defaults to `false`. If this option is enabled, adapter weights will be loaded, and the adapter is pinned during registration. This helps certain latency sensitive adapters to be present in GPU memory without being evicted.
* preload (optional): Whether to preload the adapter during initialization, defaults to `true`.
* pin (optional): Whether to pin the adapter, defaults to `false`. If this option is enabled, adapter will be preloaded, and the adapter is pinned during initialization. This helps certain latency sensitive adapters to be present in GPU memory without being evicted.
* All additional arguments will be treated as additional model-specific options and will be passed to the model during adapter registration

```bash
Expand All @@ -36,7 +36,7 @@ curl -X POST "http://localhost:8080/models/adaptecho/adapters?name=a1&src=/opt/m

`POST /models/{model_name}/adapters/{adapter_name}/update`

* load (optional): Whether to load the adapter weights.
* preload (optional): Whether to preload the adapter during initialization.
* pin (optional): Whether to pin the adapter. LoRA adapters can be pinned in GPU without being evicted from LRUCache. This helps certain latency sensitive adapters to be present in GPU memory without being evicted.
* All additional arguments will be treated as additional model-specific options and will be passed to the model during adapter registration

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ private void handleListAdapters(
for (int i = pagination.getPageToken(); i < pagination.getLast(); ++i) {
String adapterName = keys.get(i);
Adapter<Input, Output> adapter = modelInfo.getAdapter(adapterName);
list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isLoad(), adapter.isPin());
list.addAdapter(
adapter.getName(), adapter.getSrc(), adapter.isPreload(), adapter.isPin());
}

NettyUtils.sendJsonResponse(ctx, list);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
public class DescribeAdapterResponse {
private String name;
private String src;
private boolean load;
private boolean preload;
private boolean pin;

/**
Expand All @@ -31,7 +31,7 @@ public class DescribeAdapterResponse {
public DescribeAdapterResponse(Adapter<Input, Output> adapter) {
this.name = adapter.getName();
this.src = adapter.getSrc();
this.load = adapter.isLoad();
this.preload = adapter.isPreload();
this.pin = adapter.isPin();
}

Expand All @@ -54,12 +54,12 @@ public String getSrc() {
}

/**
* Returns whether to load the adapter weights.
* Returns whether to preload the adapter during initialization.
*
* @return whether to load the adapter weights
* @return whether to preload the adapter during initialization
*/
public boolean isLoad() {
return load;
public boolean isPreload() {
return preload;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,24 @@ public List<AdapterItem> getAdapters() {
*
* @param name the adapter name
* @param src the adapter source
* @param load whether to load the adapter weights
* @param preload whether to preload the adapter during initialization
* @param pin whether to pin the adapter
*/
public void addAdapter(String name, String src, boolean load, boolean pin) {
adapters.add(new AdapterItem(name, src, load, pin));
public void addAdapter(String name, String src, boolean preload, boolean pin) {
adapters.add(new AdapterItem(name, src, preload, pin));
}

/** A class that holds the adapter response. */
public static final class AdapterItem {
private String name;
private String src;
private boolean load;
private boolean preload;
private boolean pin;

private AdapterItem(String name, String src, boolean load, boolean pin) {
private AdapterItem(String name, String src, boolean preload, boolean pin) {
this.name = name;
this.src = src;
this.load = load;
this.preload = preload;
this.pin = pin;
}

Expand Down
10 changes: 5 additions & 5 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ private void testRegisterAdapter(Channel channel, boolean registerModel, boolean
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPreload());
assertFalse(resp.isPin());
}

Expand Down Expand Up @@ -1130,7 +1130,7 @@ private void testUpdateAdapter(Channel channel, boolean modelPrefix)
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPreload());
assertTrue(resp.isPin());
}

Expand Down Expand Up @@ -1190,7 +1190,7 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPreload());
assertFalse(resp.isPin());
}

Expand Down Expand Up @@ -1224,7 +1224,7 @@ private void testUpdateAdapterOom() throws InterruptedException {
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPreload());
assertFalse(resp.isPin());
}

Expand Down Expand Up @@ -1381,7 +1381,7 @@ private void testDescribeAdapter(Channel channel, boolean modelPrefix)
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), "adaptable");
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPreload());
assertTrue(resp.isPin());
}

Expand Down
8 changes: 4 additions & 4 deletions wlm/src/main/java/ai/djl/serving/wlm/Adapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ public void setOptions(Map<String, String> options) {
}

/**
* Returns whether to load the adapter weights.
* Returns whether to preload the adapter during initialization.
*
* @return whether to load the adapter weights
* @return whether to preload the adapter during initialization
*/
public boolean isLoad() {
return Boolean.parseBoolean(options.getOrDefault("load", "true"));
public boolean isPreload() {
return Boolean.parseBoolean(options.getOrDefault("preload", "true"));
}

/**
Expand Down

0 comments on commit 44b096a

Please sign in to comment.