diff --git a/include/devices/cuda/fastllm-cuda.cuh b/include/devices/cuda/fastllm-cuda.cuh index 576de200..e8b0b901 100644 --- a/include/devices/cuda/fastllm-cuda.cuh +++ b/include/devices/cuda/fastllm-cuda.cuh @@ -15,6 +15,7 @@ void FastllmCudaDirectFree(void *ret); void FastllmCudaCopyFromHostToDevice(void *dst, void *src, size_t size); void FastllmCudaCopyFromDeviceToHost(void *dst, void *src, size_t size); void FastllmCudaCopyFromDeviceToDevice(void *dst, void *src, size_t size); +void FastllmCudaMemcpyBetweenDevices(int dstId, void *dst, int srcId, void *src, size_t size); void FastllmCudaMemcpy2DDeviceToDevice(void * dst, size_t dpitch, const void * src, size_t spitch, size_t width, size_t height); diff --git a/include/fastllm.h b/include/fastllm.h index 2010e728..27e90f47 100644 --- a/include/fastllm.h +++ b/include/fastllm.h @@ -249,7 +249,7 @@ namespace fastllm { std::string fileName; long long filePos; - std::shared_ptr m_file; + std::shared_ptr mapFile; bool directMemory = false; // 直接分配/释放Memory,不经过缓存 @@ -303,8 +303,8 @@ namespace fastllm { void ToDevice(void *device); - void set_file(std::shared_ptr file) { - m_file = file; + void SetMapFile(std::shared_ptr file) { + mapFile = file; } }; diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index f11a0d4e..a2e77705 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -10,6 +10,17 @@ #include "fastllm-cuda.cuh" #include "fastllm.h" + +#define checkCudaErrors(message, val) showError(val, message, __FILE__, __LINE__) + +void showError(cudaError_t result, char const* const message, const char* const file, + int const line) { + if (cudaSuccess != result) { + printf("%s\n CUDA error = %d, %s at %s:%d\n '%s'\n", + message, result, cudaGetErrorName(result), file, line, cudaGetErrorString(result)); + } +} + static std::map s_fastllmCublasHandleMap; cublasHandle_t getFastllmCublasHandle() { int id = -1; @@ -1064,7 +1075,7 @@ void *FastllmCudaPrepareInput(const fastllm::Data &input) { ret = (void*)(input.expansionBytes); auto state = cudaMemcpy(ret, input.cpuData, input.expansionBytes, cudaMemcpyHostToDevice); if (cudaSuccess != state) { - printf("Error: CUDA error when copy from memory to GPU! state %d", state); + checkCudaErrors("Error: CUDA error when copy from memory to GPU!", state); return nullptr; } } @@ -1090,8 +1101,7 @@ void *FastllmCudaPrepareOutput(fastllm::Data &output) { void FastllmCudaFinishOutput(fastllm::Data &output, void *data) { if (output.dataDevice != fastllm::DataDevice::CUDA) { auto state = cudaMemcpy(output.cpuData, data, output.expansionBytes, cudaMemcpyDeviceToHost); - if (cudaSuccess != state) - printf("Error: CUDA error when copy from GPU to memory! state %d", state); + checkCudaErrors("Error: CUDA error when copy from GPU to memory!", state); FastllmCudaFree(data); } @@ -1123,8 +1133,7 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh } else { state = cudaMemset(cudaBiasData, 0, k * sizeof(float)); } - if (cudaSuccess != state) - printf("Error: CUDA error when moving bias to device! state %d\n", state); + checkCudaErrors("Error: CUDA error when moving bias to device!", state); weight.extraCudaData.push_back((void*)cudaBiasData); } @@ -1218,8 +1227,7 @@ bool FastllmCudaMatMulFloatInt4(const fastllm::Data &input, fastllm::Data &weigh } else { state = cudaMemset(cudaBiasData, 0, k * sizeof(float)); } - if (cudaSuccess != state) - printf("Error: CUDA error when moving bias to device! state %d\n", state); + checkCudaErrors("Error: CUDA error when moving bias to device!", state); weight.extraCudaData.push_back((void*)cudaBiasData); } @@ -1269,8 +1277,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data } else { state = cudaMemset(cudaBiasData, 0, k * sizeof(float)); } - if (cudaSuccess != state) - printf("Error: CUDA error when moving bias to device! state %d\n", state); + checkCudaErrors("Error: CUDA error when moving bias to device!", state); weight.extraCudaData.push_back((void*)cudaBiasData); } @@ -1351,8 +1358,7 @@ bool FastllmCudaMatMulFloat32(const fastllm::Data &input, fastllm::Data &weight, } else { state = cudaMemset(cudaBiasData, 0, k * sizeof(float)); } - if (cudaSuccess != state) - printf("Error: CUDA error when moving bias to device! state %d\n", state); + checkCudaErrors("Error: CUDA error when moving bias to device!", state); weight.extraCudaData.push_back((void*)cudaBiasData); } @@ -1403,8 +1409,7 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, } else { state = cudaMemset(cudaBiasData, 0, k * sizeof(float)); } - if (cudaSuccess != state) - printf("Error: CUDA error when moving bias to device! state %d\n", state); + checkCudaErrors("Error: CUDA error when moving bias to device!", state); weight.extraCudaData.push_back((void*)cudaBiasData); } float *cudaBiasData = (float*)weight.extraCudaData[0]; @@ -1476,7 +1481,8 @@ void * FastllmCudaDirectMalloc(size_t size) { void * ret; cudaError_t state = cudaMalloc(&ret, size); if (cudaSuccess != state) { - printf("Error: CUDA error when allocating %d kB memory! state %d, maybe there's no enough memory left on device.\n", size >> 10, state); + printf("Error: CUDA error when allocating %d kB memory! maybe there's no enough memory left on device.", size >> 10); + checkCudaErrors("", state); return nullptr; } return ret; @@ -1484,19 +1490,14 @@ void * FastllmCudaDirectMalloc(size_t size) { void FastllmCudaDirectFree(void *ret) { cudaError_t state = cudaFree(ret); - if (cudaSuccess != state) { - printf("Error: CUDA error when release memory! state %d.\n", state); - } + checkCudaErrors("Error: CUDA error when release memory!", state); } void * FastllmCudaMalloc(size_t size) { int id = -1; cudaError_t state = cudaSuccess; state = cudaGetDevice(&id); - if (cudaSuccess != state) { - printf("Error: CUDA error when find device! state %d", state); - return nullptr; - } + checkCudaErrors("Error: CUDA error when find device!", state); if (size > 1024 * 1024) { auto &bigBuffers = bigBuffersMap[id]; int selId = -1; @@ -1516,7 +1517,8 @@ void * FastllmCudaMalloc(size_t size) { void * ret; state = cudaMalloc(&ret, size); if (cudaSuccess != state) { - printf("Error: CUDA error when allocating %d MB memory! state %d, maybe there's no enough memory left on device.\n", size >> 20, state); + printf("Error: CUDA error when allocating %d MB memory! maybe there's no enough memory left on device.", size >> 20); + checkCudaErrors("", state); return nullptr; } bigBuffers.push_back(CudaMemoryBuffer(ret, size, true)); @@ -1533,7 +1535,8 @@ void * FastllmCudaMalloc(size_t size) { void * ret; state = cudaMalloc(&ret, size); if (cudaSuccess != state) { - printf("Error: CUDA error when allocating %d KB memory! state %d, maybe there's no enough memory left on device.\n", size >> 10, state); + printf("Error: CUDA error when allocating %d KB memory! maybe there's no enough memory left on device.", size >> 10); + checkCudaErrors("", state); return nullptr; } cudaBuffers.push_back(CudaMemoryBuffer(ret, size, true)); @@ -1556,7 +1559,8 @@ void FastllmCudaFree(void *ret) { state = cudaSetDevice(it.first); state = cudaFree(cudaBuffers[i].data); if (cudaSuccess != state) - printf("Error: CUDA error when release memory on device %d! state %d.\n", it.first, state); + printf("Error: CUDA error when release memory on device %d!", it.first); + checkCudaErrors("", state); } else { temp.push_back(cudaBuffers[i]); } @@ -1585,10 +1589,7 @@ void FastllmCudaFree(void *ret) { } } state = cudaFree(ret); - if (cudaSuccess != state) { - printf("CUDA error when release memory! state %d.\n", state); - return; - } + checkCudaErrors("CUDA error when release memory!", state); } void FastllmCudaMallocBigBuffer(size_t size) { @@ -1598,9 +1599,9 @@ void FastllmCudaMallocBigBuffer(size_t size) { auto &bigBuffers = bigBuffersMap[id]; cudaMalloc(&ret, size); auto state = cudaMalloc(&ret, size); - if (cudaSuccess != state) { - printf("Error: CUDA error when allocating %d MB memory! state %d. maybe there's no enough memory left on device.\n", size >> 20, state); - } + if (cudaSuccess != state) + printf("Error: CUDA error when allocating %d MB memory! maybe there's no enough memory left on device.", size >> 20); + checkCudaErrors("", state); bigBuffers.push_back(CudaMemoryBuffer(ret, size, false)); } @@ -1618,7 +1619,8 @@ void FastllmCudaClearBigBuffer() { state = cudaSetDevice(it.first); state = cudaFree(bigBuffers[i].data); if (cudaSuccess != state) - printf("Error: CUDA error when release memory on device %d! state %d.\n", it.first, state); + printf("Error: CUDA error when release memory on device %d!", it.first); + checkCudaErrors("", state); } else { temp.push_back(bigBuffers[i]); } @@ -1630,23 +1632,38 @@ void FastllmCudaClearBigBuffer() { } void FastllmCudaCopyFromHostToDevice(void *dst, void *src, size_t size) { - auto state = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); - if (cudaSuccess != state) - printf("Error: CUDA error when copy from memory to GPU! state %d.\n", state); + cudaError_t state = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); + checkCudaErrors("Error: CUDA error when copy from memory to GPU!", state); //cudaDeviceSynchronize(); } void FastllmCudaCopyFromDeviceToHost(void *dst, void *src, size_t size) { - auto state = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); - if (cudaSuccess != state) - printf("Error: CUDA error when copy from GPU to memory! state %d.\n", state); + cudaError_t state = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); + checkCudaErrors("Error: CUDA error when copy from GPU to memory!", state); //cudaDeviceSynchronize(); } void FastllmCudaCopyFromDeviceToDevice(void *dst, void *src, size_t size) { - auto state = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); - if (cudaSuccess != state) - printf("Error: CUDA error when copy on GPU! state %d.\n", state); + cudaError_t state = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + checkCudaErrors("Error: CUDA error when copy on GPU!", state); + //cudaDeviceSynchronize(); +} + +void FastllmCudaMemcpyBetweenDevices(int dstId, void *dst, int srcId, void *src, size_t size) { + int canPeerAccess = 0; + cudaError_t state = cudaDeviceCanAccessPeer(&canPeerAccess, srcId, dstId); + if (canPeerAccess) { + state = cudaMemcpyPeer(dst, dstId, src, srcId, size); + } else { + uint8_t *cpuData = new uint8_t[size]; + state = cudaSetDevice(srcId); + state = cudaMemcpy(cpuData, src, size, cudaMemcpyDeviceToHost); + + state = cudaSetDevice(dstId); + state = cudaMemcpy(dst, cpuData, size, cudaMemcpyHostToDevice); + delete[] cpuData; + } + checkCudaErrors("Error: CUDA error when copy Between GPUs!", state); //cudaDeviceSynchronize(); } diff --git a/src/fastllm.cpp b/src/fastllm.cpp index 908fc29c..5e64a162 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -684,11 +684,20 @@ namespace fastllm { #ifdef USE_CUDA if (this->dataDevice == DataDevice::CPU) { if (device == DataDevice::CUDA) { + uint8_t *cpuData = this->cpuData; +#ifdef USE_MMAP + cpuData = new uint8_t[expansionBytes]; + memcpy(cpuData, this->cpuData, expansionBytes); +#endif FastllmCudaSetDevice(deviceIds.size() == 0 ? 0 : deviceIds[0]); this->cudaData = FastllmCudaMalloc(expansionBytes); - FastllmCudaCopyFromHostToDevice(this->cudaData, this->cpuData, expansionBytes); + FastllmCudaCopyFromHostToDevice(this->cudaData, cpuData, expansionBytes); +#ifdef USE_MMAP + delete[] cpuData; +#else delete[] this->cpuData; this->cpuData = nullptr; +#endif } } else if (this->dataDevice == DataDevice::CUDA) { if (device == DataDevice::CPU) { @@ -697,16 +706,16 @@ namespace fastllm { FastllmCudaFree(this->cudaData); this->cudaData = nullptr; } else if (device == DataDevice::CUDA) { - FastllmCudaSetDevice(this->dataDeviceIds.size() == 0 ? 0 : this->dataDeviceIds[0]); - uint8_t *cpuData = new uint8_t[expansionBytes]; - FastllmCudaCopyFromDeviceToHost(cpuData, this->cudaData, expansionBytes); - FastllmCudaFree(this->cudaData); - - FastllmCudaSetDevice(deviceIds.size() == 0 ? 0 : deviceIds[0]); - this->cudaData = FastllmCudaMalloc(expansionBytes); + int sourceDevice = this->dataDeviceIds.size() == 0 ? 0 : this->dataDeviceIds[0]; + int destDevice = deviceIds.size() == 0 ? 0 : deviceIds[0]; + FastllmCudaSetDevice(destDevice); + void *newCudaData = FastllmCudaMalloc(expansionBytes); - FastllmCudaCopyFromHostToDevice(this->cudaData, cpuData, expansionBytes); - delete[] cpuData; + FastllmCudaMemcpyBetweenDevices(destDevice, newCudaData, sourceDevice, this->cudaData, expansionBytes); + FastllmCudaSetDevice(sourceDevice); + FastllmCudaFree(this->cudaData); + this->cudaData = newCudaData; + FastllmCudaSetDevice(destDevice); } } #endif @@ -1374,7 +1383,8 @@ namespace fastllm { } } else { #ifdef USE_MMAP - weight[name].set_file(mapped_file); + weight[name].SetMapFile(mapped_file); + weight[name].expansionBytes = (weight[name].Count(0) * weight[name].unitSize - 1) / weight[name].unitSizeDiv + 1; #else weight[name].Allocate(); #endif