diff --git a/include/fastllm.h b/include/fastllm.h
index a1ea663f..22fde50f 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -415,6 +415,10 @@ namespace fastllm {
         void AddWeight(const std::string &key, const std::vector <int> &dims,
                        DataType dataType, WeightType weightType, DataType oriDataType, uint8_t *oriData); // 插入一个权重
 
+        // Release the data buffers held by every stored weight. Map entries are
+        // kept; their freed data pointers are reset to nullptr.
+        void ReleaseWeight();
+
         void AddQLinearWeight(const std::string &key, const std::vector <int> &dims,
                               int bit, float *scales, uint8_t *oriData); // 插入一个Qlinear层的权重,量化规则为float value = scales * oriData
 
diff --git a/include/models/basellm.h b/include/models/basellm.h
index dc12b994..161aa976 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -46,7 +46,12 @@ namespace fastllm {
     public:
         basellm() {};
 
-        ~basellm() {};
+        // Virtual because basellm declares virtual members and is used as a
+        // polymorphic base: destroying a derived model through a basellm pointer
+        // would otherwise be undefined behavior. Releases the weight buffers.
+        virtual ~basellm() {
+            this->weight.ReleaseWeight();
+        };
 
         virtual void LoadFromFile(const std::string &fileName); // 从文件读取
 
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 4f0db830..908fc29c 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -1722,6 +1722,25 @@ namespace fastllm {
         }
     }
 
+    // Release the buffers held by every weight in the map. Entries stay in the
+    // map; freed pointers are reset to nullptr, so a second call frees nothing twice.
+    void WeightMap::ReleaseWeight() {
+        for (auto &entry : this->weight) {
+            auto &data = entry.second;
+#ifndef USE_MMAP
+            // Skipped under USE_MMAP; presumably cpuData is then not a new[] allocation.
+            delete[] data.cpuData;
+            data.cpuData = nullptr;
+#endif
+#ifdef USE_CUDA
+            if (data.cudaData != nullptr) {
+                FastllmCudaDirectFree(data.cudaData);
+                data.cudaData = nullptr;
+            }
+#endif
+        }
+    }
+
     Data &WeightMap::operator[](const std::string &key) {
         return weight[key];
     }
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 7ebdc297..ca42e516 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -345,3 +345,7 @@ def set_adapter(self, name: str):
 
     def disable_adapter(self):
         fastllm_lib.disable_adapter(self.model)
+
+    def release_memory(self):
+        """Ask the native library to release this model's weight buffers."""
+        fastllm_lib.release_memory(self.model)
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index c490e94f..933c0bf6 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -169,6 +169,12 @@ extern "C" {
         return;
     }
 
+    // Release the weight buffers of the model registered under modelId.
+    DLL_EXPORT void release_memory(int modelId) {
+        auto model = models.GetModel(modelId);
+        model->weight.ReleaseWeight();
+    }
+
     DLL_EXPORT void init_params_llm_model(int modelId) {
         auto model = models.GetModel(modelId);
         model->InitParams();