From 8742a071cb9f09e110cf0f0d09f0aabfc29033cd Mon Sep 17 00:00:00 2001
From: siemon
Date: Tue, 17 Oct 2023 11:56:06 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=87=8A=E6=94=BE=E6=9D=83?=
 =?UTF-8?q?=E9=87=8D=E6=89=80=E5=8D=A0=E5=86=85=E5=AD=98=E7=A9=BA=E9=97=B4?=
 =?UTF-8?q?=E7=9A=84=E6=8E=A5=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/fastllm.h            |  2 ++
 include/models/basellm.h     |  4 +++-
 src/fastllm.cpp              | 15 +++++++++++++++
 tools/fastllm_pytools/llm.py |  3 +++
 tools/src/pytools.cpp        |  6 ++++++
 5 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/include/fastllm.h b/include/fastllm.h
index a1ea663f..22fde50f 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -415,6 +415,8 @@ namespace fastllm {
         void AddWeight(const std::string &key, const std::vector <int> &dims,
                        DataType dataType, WeightType weightType, DataType oriDataType, uint8_t *oriData); // 插入一个权重

+        void ReleaseWeight(); // 释放所有权重占用的空间
+
         void AddQLinearWeight(const std::string &key, const std::vector <int> &dims,
                               int bit, float *scales, uint8_t *oriData); // 插入一个Qlinear层的权重,量化规则为float value = scales * oriData

diff --git a/include/models/basellm.h b/include/models/basellm.h
index dc12b994..161aa976 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -46,7 +46,9 @@ namespace fastllm {
     public:
         basellm() {};

-        ~basellm() {};
+        ~basellm() {
+            this->weight.ReleaseWeight();
+        };

         virtual void LoadFromFile(const std::string &fileName); // 从文件读取

diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 4f0db830..908fc29c 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -1722,6 +1722,21 @@ namespace fastllm {
         }
     }

+    void WeightMap::ReleaseWeight() {
+        for (auto &w : this->weight) {
+#ifndef USE_MMAP
+            delete[] w.second.cpuData;
+            w.second.cpuData = nullptr;
+#endif
+#ifdef USE_CUDA
+            if (w.second.cudaData != nullptr) {
+                FastllmCudaDirectFree(w.second.cudaData);
+                w.second.cudaData = nullptr;
+            }
+#endif
+        }
+    }
+
     Data &WeightMap::operator[](const std::string &key) {
         return weight[key];
     }
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 7ebdc297..ca42e516 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -345,3 +345,6 @@ def set_adapter(self, name: str):

     def disable_adapter(self):
         fastllm_lib.disable_adapter(self.model)
+
+    def release_memory(self):
+        fastllm_lib.release_memory(self.model)
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index c490e94f..933c0bf6 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -169,6 +169,12 @@ extern "C" {
         return;
     }

+    DLL_EXPORT void release_memory(int modelId) {
+        auto model = models.GetModel(modelId);
+        model->weight.ReleaseWeight();
+        return;
+    }
+
     DLL_EXPORT void init_params_llm_model(int modelId) {
         auto model = models.GetModel(modelId);
         model->InitParams();