diff --git a/example/Win32Demo/fastllm.vcxproj b/example/Win32Demo/fastllm.vcxproj index e5e33c3a..0e57dfdc 100644 --- a/example/Win32Demo/fastllm.vcxproj +++ b/example/Win32Demo/fastllm.vcxproj @@ -163,7 +163,6 @@ /arch:AVX /source-charset:utf-8 %(AdditionalOptions) - cudart.lib;cublas.lib;%(AdditionalDependencies) Windows true true @@ -181,7 +180,7 @@ - + diff --git a/include/models/basellm.h b/include/models/basellm.h index 156a1d6e..61302542 100644 --- a/include/models/basellm.h +++ b/include/models/basellm.h @@ -1,4 +1,7 @@ -#pragma once + +#ifndef FASTLLM_BASELLM_H +#define FASTLLM_BASELLM_H + #include "fastllm.h" #include @@ -50,9 +53,9 @@ namespace fastllm { this->weight.ReleaseWeight(); }; - virtual void LoadFromFile(const std::string &fileName); // 从文件读取 + virtual void LoadFromFile(const std::string &fileName); // 从文件读取 - virtual void InitParams(); // 初始化参数信息 + virtual void InitParams(); // 初始化参数信息 // 推理 virtual int Forward( @@ -85,12 +88,12 @@ namespace fastllm { const LastTokensManager &lastTokens = LastTokensManager(), std::vector *> *logits = nullptr); - // 根据输入的tokens生成LLM推理的输入 + // 根据输入的tokens生成LLM推理的输入 virtual void FillLLMInputs(std::vector > &inputTokens, const std::map ¶ms, Data &inputIds, Data &attentionMask, Data &positionIds); - // 根据输入的tokens生成LLM推理的输入 + // 根据输入的tokens生成LLM推理的输入 virtual void FillLLMInputsBatch(std::vector > &inputTokens, const std::vector > ¶ms, Data &inputIds, Data &attentionMask, Data &positionIds); @@ -102,16 +105,16 @@ namespace fastllm { virtual void ResponseBatch(const std::vector &inputs, std::vector &outputs, RuntimeResultBatch retCb = nullptr, - const GenerationConfig &generationConfig = GenerationConfig()); // 批量根据给出的内容回复 + const GenerationConfig &generationConfig = GenerationConfig()); // 批量根据给出的内容回复 virtual int LaunchResponseTokens(const std::vector &inputTokens, const GenerationConfig &generationConfig = GenerationConfig()); // 启动一个response任务,返回分配的handleId - virtual int FetchResponseTokens(int handleId); // 获取指定handle的输出, -1代表输出结束了 + virtual int FetchResponseTokens(int handleId); // 获取指定handle的输出, -1代表输出结束了 virtual int FetchResponseLogits(int handleId, std::vector &logits); // 获取指定handle的输出Logits - virtual void SaveLowBitModel(const std::string &fileName, int bit); // 存储成量化模型 + virtual void SaveLowBitModel(const std::string &fileName, int bit); // 存储成量化模型 virtual void SaveModel(const std::string &fileName); // 直接导出 @@ -158,3 +161,5 @@ namespace fastllm { int tokensLimit = -1; }; } + +#endif //FASTLLM_BASELLM_H \ No newline at end of file