diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile new file mode 100644 index 0000000000000..db5ba2f25ea67 --- /dev/null +++ b/.devops/llama-cli-cann.Dockerfile @@ -0,0 +1,44 @@ +ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8 + +FROM cosdt/cann:$ASCEND_VERSION AS build + +WORKDIR /app + +COPY . . + +RUN yum install -y gcc g++ cmake make +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +# find libascend_hal.so, because the drive hasn`t been mounted. +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH + +RUN echo "Building with static libs" && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \ + cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \ + cmake --build build --config Release --target llama-cli + +# TODO: use image with NNRT +FROM cosdt/cann:$ASCEND_VERSION AS runtime +COPY --from=build /app/build/bin/llama-cli /llama-cli + +ENV LC_ALL=C.utf8 + +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +ENTRYPOINT ["/llama-cli" ] diff --git a/.devops/llama-server-cuda.Dockerfile b/.devops/llama-server-cuda.Dockerfile index 67328cf1c1788..1842489841f8c 100644 --- a/.devops/llama-server-cuda.Dockerfile +++ b/.devops/llama-server-cuda.Dockerfile @@ -24,6 +24,8 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} ENV GGML_CUDA=1 # Enable cURL ENV LLAMA_CURL=1 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 RUN make -j$(nproc) llama-server diff --git a/.devops/llama-server-intel.Dockerfile b/.devops/llama-server-intel.Dockerfile index f525658dddfe5..9c355b664f15e 100644 --- a/.devops/llama-server-intel.Dockerfile +++ b/.devops/llama-server-intel.Dockerfile @@ -26,6 +26,8 @@ RUN apt-get update && \ COPY --from=build /app/build/bin/llama-server /llama-server ENV LC_ALL=C.utf8 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile index 763b4cd3f1c2e..fd0e19ad6e49c 100644 --- a/.devops/llama-server-rocm.Dockerfile +++ b/.devops/llama-server-rocm.Dockerfile @@ -39,6 +39,8 @@ ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} ENV GGML_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang ENV CXX=/opt/rocm/llvm/bin/clang++ +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 # Enable cURL ENV LLAMA_CURL=1 diff --git a/.devops/llama-server-vulkan.Dockerfile b/.devops/llama-server-vulkan.Dockerfile index 13a61ffd8454b..93c5e0c26e691 100644 --- a/.devops/llama-server-vulkan.Dockerfile +++ b/.devops/llama-server-vulkan.Dockerfile @@ -23,6 +23,8 @@ RUN cp /app/build/bin/llama-server /llama-server && \ rm -rf /app ENV LC_ALL=C.utf8 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.devops/llama-server.Dockerfile b/.devops/llama-server.Dockerfile index ff558604ebde2..02accc85e1368 100644 --- a/.devops/llama-server.Dockerfile +++ b/.devops/llama-server.Dockerfile @@ -21,6 +21,8 @@ RUN apt-get update && \ COPY --from=build /app/llama-server /llama-server ENV LC_ALL=C.utf8 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.ecrc b/.ecrc index a3351f4e6442d..c68877ec211f1 100644 --- a/.ecrc +++ b/.ecrc @@ -1,5 +1,5 @@ { - "Exclude": ["^\\.gitmodules$"], + "Exclude": ["^\\.gitmodules$", "stb_image\\.h"], "Disable": { "IndentSize": true } diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml.disabled similarity index 97% rename from .github/workflows/bench.yml rename to .github/workflows/bench.yml.disabled index eb69b82c47e64..bfdbb4ef5e385 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml.disabled @@ -1,3 +1,6 @@ +# TODO: there have been some issues with the workflow, so disabling for now +# https://github.com/ggerganov/llama.cpp/issues/7893 +# # Benchmark name: Benchmark @@ -129,6 +132,8 @@ jobs: - name: Server bench id: server_bench + env: + HEAD_REF: ${{ github.head_ref || github.ref_name }} run: | set -eux @@ -137,7 +142,7 @@ jobs: python bench.py \ --runner-label ${{ env.RUNNER_LABEL }} \ --name ${{ github.job }} \ - --branch ${{ github.head_ref || github.ref_name }} \ + --branch $HEAD_REF \ --commit ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha }} \ --scenario script.js \ --duration ${{ github.event.inputs.duration || env.DURATION }} \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b9246659a6ef0..74b5d4f69d790 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -47,7 +47,7 @@ jobs: sysctl -a mkdir build cd build - cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DBUILD_SHARED_LIBS=OFF .. + cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF .. cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - name: Test @@ -105,7 +105,7 @@ jobs: sysctl -a # Metal is disabled due to intermittent failures with Github runners not having a GPU: # https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 - cmake -B build -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF -DLLAMA_CURL=ON -DBUILD_SHARED_LIBS=OFF + cmake -B build -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - name: Test @@ -222,7 +222,7 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DBUILD_SHARED_LIBS=OFF + cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF cmake --build . --config Release -j $(nproc) - name: Test @@ -696,22 +696,20 @@ jobs: strategy: matrix: include: - - build: 'rpc-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=ON' - build: 'noavx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DBUILD_SHARED_LIBS=ON' - build: 'avx2-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=ON' - build: 'avx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_AVX2=OFF -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX2=OFF -DBUILD_SHARED_LIBS=ON' - build: 'avx512-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_AVX512=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX512=ON -DBUILD_SHARED_LIBS=ON' - build: 'openblas-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_BLAS=ON -DBUILD_SHARED_LIBS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BLAS=ON -DBUILD_SHARED_LIBS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - build: 'kompute-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON -DBUILD_SHARED_LIBS=ON' - build: 'vulkan-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_VULKAN=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_VULKAN=ON -DBUILD_SHARED_LIBS=ON' - build: 'llvm-arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' - build: 'msvc-arm64' diff --git a/.github/workflows/python-check-requirements.yml b/.github/workflows/python-check-requirements.yml index 4e0374fc63d95..46e80aecd0a0c 100644 --- a/.github/workflows/python-check-requirements.yml +++ b/.github/workflows/python-check-requirements.yml @@ -6,15 +6,13 @@ on: - '.github/workflows/python-check-requirements.yml' - 'scripts/check-requirements.sh' - 'convert*.py' - - 'requirements.txt' - - 'requirements/*.txt' + - '**/requirements*.txt' pull_request: paths: - '.github/workflows/python-check-requirements.yml' - 'scripts/check-requirements.sh' - 'convert*.py' - - 'requirements.txt' - - 'requirements/*.txt' + - '**/requirements*.txt' concurrency: group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} diff --git a/.gitignore b/.gitignore index 04bbc037e672c..2880d19cac0e2 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,6 @@ poetry.toml # Scripts !/scripts/install-oneapi.bat + +# Test models for lora adapters +/lora-tests diff --git a/CMakePresets.json b/CMakePresets.json index bdad38952d3cb..ce627b4d39e0c 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -28,6 +28,7 @@ { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, { "name": "arm64-windows-msvc", "hidden": true, @@ -60,6 +61,8 @@ { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, - { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] } + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] } ] } diff --git a/Makefile b/Makefile index 649671ed6a72e..332496cfc39c1 100644 --- a/Makefile +++ b/Makefile @@ -763,6 +763,10 @@ ifdef GGML_VULKAN_MEMORY_DEBUG MK_CPPFLAGS += -DGGML_VULKAN_MEMORY_DEBUG endif +ifdef GGML_VULKAN_PERF + MK_CPPFLAGS += -DGGML_VULKAN_PERF +endif + ifdef GGML_VULKAN_VALIDATE MK_CPPFLAGS += -DGGML_VULKAN_VALIDATE endif diff --git a/README.md b/README.md index 1283f6805874e..bb2b93a35021f 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,8 @@ Typically finetunes of the base models below are supported as well. - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca) - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) +- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) +- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md)) @@ -186,10 +188,12 @@ Unless otherwise noted these projects are open-source with permissive licensing: - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML - [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption +- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage **Infrastructure:** - [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp +- [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs **Games:** - [Lucy's Labyrinth](https://github.com/MorganRO8/Lucys_Labyrinth) - A simple maze game where agents controlled by an AI model will try to trick you. @@ -422,6 +426,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md) | [CUDA](./docs/build.md#cuda) | Nvidia GPU | | [hipBLAS](./docs/build.md#hipblas) | AMD GPU | | [Vulkan](./docs/build.md#vulkan) | GPU | +| [CANN](./docs/build.md#cann) | Ascend NPU | ## Tools diff --git a/ci/run.sh b/ci/run.sh index 58022c7dc37e0..751bb0a021dce 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -13,6 +13,9 @@ # # with SYCL support # GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # +# # with VULKAN support +# GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# if [ -z "$2" ]; then echo "usage: $0 " @@ -40,7 +43,7 @@ if [ ! -z ${GG_BUILD_METAL} ]; then fi if [ ! -z ${GG_BUILD_CUDA} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=1" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native" fi if [ ! -z ${GG_BUILD_SYCL} ]; then @@ -52,6 +55,10 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" fi + +if [ ! -z ${GG_BUILD_VULKAN} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1" +fi ## helpers # download a file if it does not exist or if it is outdated @@ -107,7 +114,7 @@ function gg_run_ctest_debug { gg_check_build_requirements (time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log (time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log @@ -138,7 +145,7 @@ function gg_run_ctest_release { gg_check_build_requirements (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log if [ -z ${GG_BUILD_LOW_PERF} ]; then (time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log @@ -266,7 +273,6 @@ function gg_sum_ctest_with_model_release { } # open_llama_7b_v2 -# requires: GG_BUILD_CUDA function gg_run_open_llama_7b_v2 { cd ${SRC} @@ -290,8 +296,8 @@ function gg_run_open_llama_7b_v2 { set -e - (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DGGML_CUDA=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../examples/convert_legacy_llama.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -425,7 +431,7 @@ function gg_run_pythia_1_4b { set -e (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -535,7 +541,6 @@ function gg_sum_pythia_1_4b { } # pythia_2_8b -# requires: GG_BUILD_CUDA function gg_run_pythia_2_8b { cd ${SRC} @@ -556,8 +561,8 @@ function gg_run_pythia_2_8b { set -e - (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DGGML_CUDA=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -692,7 +697,7 @@ function gg_run_embd_bge_small { set -e (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -761,7 +766,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then fi if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then - if [ -z ${GG_BUILD_CUDA} ]; then + if [ -z ${GG_BUILD_CUDA} ] && [ -z ${GG_BUILD_VULKAN} ]; then test $ret -eq 0 && gg_run pythia_1_4b else test $ret -eq 0 && gg_run pythia_2_8b diff --git a/common/common.cpp b/common/common.cpp index 7c93f6ba89ba5..29f49293666c5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -77,6 +77,41 @@ using json = nlohmann::ordered_json; +// +// Environment variable utils +// + +template +static typename std::enable_if::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + target = value ? std::string(value) : target; +} + +template +static typename std::enable_if::value && std::is_integral::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + target = value ? std::stoi(value) : target; +} + +template +static typename std::enable_if::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + target = value ? std::stof(value) : target; +} + +template +static typename std::enable_if::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + if (value) { + std::string val(value); + target = val == "1" || val == "true"; + } +} + // // CPU utils // @@ -110,8 +145,34 @@ int32_t cpu_get_num_physical_cores() { if (result == 0) { return num_physical_cores; } -#elif defined(_WIN32) - //TODO: Implement +#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later + // TODO: windows + arm64 + mingw64 + unsigned int n_threads_win = std::thread::hardware_concurrency(); + unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4; + + DWORD buffer_size = 0; + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) { + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) { + return default_threads; + } + } + + std::vector buffer(buffer_size); + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast(buffer.data()), &buffer_size)) { + return default_threads; + } + + int32_t num_physical_cores = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast(buffer.data()); + while (buffer_size > 0) { + if (info->Relationship == RelationProcessorCore) { + num_physical_cores += info->Processor.GroupCount; + } + buffer_size -= info->Size; + info = reinterpret_cast(reinterpret_cast(info) + info->Size); + } + + return num_physical_cores > 0 ? num_physical_cores : default_threads; #endif unsigned int n_threads = std::thread::hardware_concurrency(); return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; @@ -194,12 +255,6 @@ int32_t cpu_get_num_math() { // CLI argument parsing // -void gpt_params_handle_hf_token(gpt_params & params) { - if (params.hf_token.empty() && std::getenv("HF_TOKEN")) { - params.hf_token = std::getenv("HF_TOKEN"); - } -} - void gpt_params_handle_model_default(gpt_params & params) { if (!params.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model @@ -247,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { gpt_params_handle_model_default(params); - gpt_params_handle_hf_token(params); + if (params.hf_token.empty()) { + get_env("HF_TOKEN", params.hf_token); + } if (params.escape) { string_process_escapes(params.prompt); @@ -267,6 +324,32 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { return true; } +void gpt_params_parse_from_env(gpt_params & params) { + // we only care about server-related params for now + get_env("LLAMA_ARG_MODEL", params.model); + get_env("LLAMA_ARG_MODEL_URL", params.model_url); + get_env("LLAMA_ARG_MODEL_ALIAS", params.model_alias); + get_env("LLAMA_ARG_HF_REPO", params.hf_repo); + get_env("LLAMA_ARG_HF_FILE", params.hf_file); + get_env("LLAMA_ARG_THREADS", params.n_threads); + get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx); + get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel); + get_env("LLAMA_ARG_BATCH", params.n_batch); + get_env("LLAMA_ARG_UBATCH", params.n_ubatch); + get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers); + get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http); + get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template); + get_env("LLAMA_ARG_N_PREDICT", params.n_predict); + get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics); + get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots); + get_env("LLAMA_ARG_EMBEDDINGS", params.embedding); + get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn); + get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold); + get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching); + get_env("LLAMA_ARG_HOST", params.hostname); + get_env("LLAMA_ARG_PORT", params.port); +} + bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { const auto params_org = params; // the example can modify the default params @@ -849,7 +932,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } - if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") { + if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") { CHECK_ARG params.n_gpu_layers_draft = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { @@ -1754,7 +1837,13 @@ std::string gpt_params_get_system_info(const gpt_params & params) { if (params.n_threads_batch != -1) { os << " (n_threads_batch = " << params.n_threads_batch << ")"; } +#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later + // TODO: windows + arm64 + mingw64 + DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS); + os << " / " << logicalProcessorCount << " | " << llama_print_system_info(); +#else os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); +#endif return os.str(); } @@ -1806,13 +1895,19 @@ std::string string_get_sortable_timestamp() { void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); - } + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } void string_process_escapes(std::string & input) { @@ -2729,12 +2824,6 @@ std::string llama_detokenize(llama_context * ctx, const std::vector return text; } -bool llama_should_add_bos_token(const llama_model * model) { - const int add_bos = llama_add_bos_token(model); - - return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); -} - // // Chat template utils // diff --git a/common/common.h b/common/common.h index bbc33a499afcd..f603ba2be1d35 100644 --- a/common/common.h +++ b/common/common.h @@ -267,7 +267,7 @@ struct gpt_params { std::string lora_outfile = "ggml-lora-merged-f16.gguf"; }; -void gpt_params_handle_hf_token(gpt_params & params); +void gpt_params_parse_from_env(gpt_params & params); void gpt_params_handle_model_default(gpt_params & params); bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); @@ -380,10 +380,6 @@ std::string llama_detokenize( const std::vector & tokens, bool special = true); -// Uses the value from the model metadata if possible, otherwise -// defaults to true when model type is SPM, otherwise false. -bool llama_should_add_bos_token(const llama_model * model); - // // Chat template utils // diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index a518b766dc33e..438452eab570f 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -369,6 +369,9 @@ namespace grammar_parser { } // Validate the state to ensure that all rules are defined for (const auto & rule : state.rules) { + if (rule.empty()) { + throw std::runtime_error("Undefined rule"); + } for (const auto & elem : rule) { if (elem.type == LLAMA_GRETYPE_RULE_REF) { // Ensure that the rule at that location exists diff --git a/common/stb_image.h b/common/stb_image.h index 4766d7e6754e5..9eedabedc45b3 100644 --- a/common/stb_image.h +++ b/common/stb_image.h @@ -1,4 +1,4 @@ -/* stb_image - v2.28 - public domain image loader - http://nothings.org/stb +/* stb_image - v2.30 - public domain image loader - http://nothings.org/stb no warranty implied; use at your own risk Do this: @@ -48,6 +48,8 @@ LICENSE RECENT REVISION HISTORY: + 2.30 (2024-05-31) avoid erroneous gcc warning + 2.29 (2023-05-xx) optimizations 2.28 (2023-01-29) many error fixes, security errors, just tons of stuff 2.27 (2021-07-11) document stbi_info better, 16-bit PNM support, bug fixes 2.26 (2020-07-13) many minor fixes @@ -371,13 +373,14 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum { - STBI_default = 0, // only used for desired_channels +enum +{ + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -405,11 +408,11 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct { - int (*read)(void * user, char * data, - int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip)(void * user, int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof)(void * user); // returns nonzero if we are at end of file/data +typedef struct +{ + int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read + void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative + int (*eof) (void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -417,24 +420,21 @@ typedef struct { // 8-bits-per-channel interface // -STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); // for stbi_load_from_file, file pointer is left pointing immediately after image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, - int * comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); #endif //////////////////////////////////// @@ -442,14 +442,12 @@ STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wc // 16-bits-per-channel interface // -STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); #endif //////////////////////////////////// @@ -457,55 +455,56 @@ STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * chann // float-per-channel interface // #ifndef STBI_NO_LINEAR -STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * channels_in_file, - int desired_channels); + STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); + STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); -#ifndef STBI_NO_STDIO -STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); -#endif + #ifndef STBI_NO_STDIO + STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); + STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); + #endif #endif #ifndef STBI_NO_HDR -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); -STBIDEF void stbi_hdr_to_ldr_scale(float scale); + STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); + STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); -STBIDEF void stbi_ldr_to_hdr_scale(float scale); + STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); + STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr(char const * filename); -STBIDEF int stbi_is_hdr_from_file(FILE * f); +STBIDEF int stbi_is_hdr (char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO + // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char * stbi_failure_reason(void); +STBIDEF const char *stbi_failure_reason (void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free(void * retval_from_stbi_load); +STBIDEF void stbi_image_free (void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * clbk, void * user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp); -STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp); -STBIDEF int stbi_is_16_bit(char const * filename); -STBIDEF int stbi_is_16_bit_from_file(FILE * f); +STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit (char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif + + // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. @@ -527,14 +526,14 @@ STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_fli // ZLIB client - used by PNG, available for other purposes -STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen); -STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, - int parse_header); -STBIDEF char * stbi_zlib_decode_malloc(const char * buffer, int len, int * outlen); -STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); + +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); -STBIDEF char * stbi_zlib_decode_noheader_malloc(const char * buffer, int len, int * outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); #ifdef __cplusplus } @@ -547,50 +546,52 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ - defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ - defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) -#ifndef STBI_ONLY_JPEG -#define STBI_NO_JPEG -#endif -#ifndef STBI_ONLY_PNG -#define STBI_NO_PNG -#endif -#ifndef STBI_ONLY_BMP -#define STBI_NO_BMP -#endif -#ifndef STBI_ONLY_PSD -#define STBI_NO_PSD -#endif -#ifndef STBI_ONLY_TGA -#define STBI_NO_TGA -#endif -#ifndef STBI_ONLY_GIF -#define STBI_NO_GIF -#endif -#ifndef STBI_ONLY_HDR -#define STBI_NO_HDR -#endif -#ifndef STBI_ONLY_PIC -#define STBI_NO_PIC -#endif -#ifndef STBI_ONLY_PNM -#define STBI_NO_PNM -#endif +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ + || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ + || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ + || defined(STBI_ONLY_ZLIB) + #ifndef STBI_ONLY_JPEG + #define STBI_NO_JPEG + #endif + #ifndef STBI_ONLY_PNG + #define STBI_NO_PNG + #endif + #ifndef STBI_ONLY_BMP + #define STBI_NO_BMP + #endif + #ifndef STBI_ONLY_PSD + #define STBI_NO_PSD + #endif + #ifndef STBI_ONLY_TGA + #define STBI_NO_TGA + #endif + #ifndef STBI_ONLY_GIF + #define STBI_NO_GIF + #endif + #ifndef STBI_ONLY_HDR + #define STBI_NO_HDR + #endif + #ifndef STBI_ONLY_PIC + #define STBI_NO_PIC + #endif + #ifndef STBI_ONLY_PNM + #define STBI_NO_PNM + #endif #endif #if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) #define STBI_NO_ZLIB #endif -#include + #include #include // ptrdiff_t on osx #include #include +#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -608,54 +609,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha #define STBI_EXTERN extern #endif + #ifndef _MSC_VER -#ifdef __cplusplus -#define stbi_inline inline -#else -#define stbi_inline -#endif + #ifdef __cplusplus + #define stbi_inline inline + #else + #define stbi_inline + #endif #else -#define stbi_inline __forceinline + #define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS -#if defined(__cplusplus) && __cplusplus >= 201103L -#define STBI_THREAD_LOCAL thread_local -#elif defined(__GNUC__) && __GNUC__ < 5 -#define STBI_THREAD_LOCAL __thread -#elif defined(_MSC_VER) -#define STBI_THREAD_LOCAL __declspec(thread) -#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) -#define STBI_THREAD_LOCAL _Thread_local -#endif - -#ifndef STBI_THREAD_LOCAL -#if defined(__GNUC__) -#define STBI_THREAD_LOCAL __thread -#endif -#endif + #if defined(__cplusplus) && __cplusplus >= 201103L + #define STBI_THREAD_LOCAL thread_local + #elif defined(__GNUC__) && __GNUC__ < 5 + #define STBI_THREAD_LOCAL __thread + #elif defined(_MSC_VER) + #define STBI_THREAD_LOCAL __declspec(thread) + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) + #define STBI_THREAD_LOCAL _Thread_local + #endif + + #ifndef STBI_THREAD_LOCAL + #if defined(__GNUC__) + #define STBI_THREAD_LOCAL __thread + #endif + #endif #endif #if defined(_MSC_VER) || defined(__SYMBIAN32__) typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -663,9 +665,9 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL -#define stbi_lrot(x, y) _lrotl(x, y) + #define stbi_lrot(x,y) _lrotl(x,y) #else -#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (-(y)&31))) + #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (-(y) & 31))) #endif #if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) @@ -677,13 +679,13 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p, newsz) realloc(p, newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p,newsz) realloc(p,newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) +#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) #endif // x86/x64 detection @@ -725,31 +727,34 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) { - int info[4]; - __cpuid(info, 1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) +{ + int info[4]; + __cpuid(info,1); + return info[3]; } #else -static int stbi__cpuid3(void) { - int res; - __asm { +static int stbi__cpuid3(void) +{ + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) { - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) +{ + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -757,11 +762,12 @@ static int stbi__sse2_available(void) { #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) { - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) +{ + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -796,162 +802,190 @@ static int stbi__sse2_available(void) { // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct { - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct +{ + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void * io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; -static void stbi__refill_buffer(stbi__context * s); + +static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context * s, stbi_uc const * buffer, int len) { - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) +{ + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context * s, stbi_io_callbacks * c, void * user) { - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) +{ + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void * user, char * data, int size) { return (int)fread(data, 1, size, (FILE *)user); } +static int stbi__stdio_read(void *user, char *data, int size) +{ + return (int) fread(data,1,size,(FILE*) user); +} -static void stbi__stdio_skip(void * user, int n) { - int ch; - fseek((FILE *)user, n, SEEK_CUR); - ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ - } +static void stbi__stdio_skip(void *user, int n) +{ + int ch; + fseek((FILE*) user, n, SEEK_CUR); + ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ + } } -static int stbi__stdio_eof(void * user) { return feof((FILE *)user) || ferror((FILE *)user); } +static int stbi__stdio_eof(void *user) +{ + return feof((FILE*) user) || ferror((FILE *) user); +} -static stbi_io_callbacks stbi__stdio_callbacks = { - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static stbi_io_callbacks stbi__stdio_callbacks = +{ + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context * s, FILE * f) { stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } +static void stbi__start_file(stbi__context *s, FILE *f) +{ + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +} -// static void stop_file(stbi__context *s) { } +//static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context * s) { - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) +{ + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; +enum +{ + STBI_ORDER_RGB, + STBI_ORDER_BGR +}; -typedef struct { - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct +{ + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context * s); -static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context * s); -static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__png_is16(stbi__context * s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context * s); -static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context * s); -static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context * s); -static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc); -static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__psd_is16(stbi__context * s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context * s); -static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context * s); -static void * stbi__pic_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context * s); -static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp); -static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context * s); -static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__pnm_is16(stbi__context * s); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_is16(stbi__context *s); #endif static #ifdef STBI_THREAD_LOCAL - STBI_THREAD_LOCAL +STBI_THREAD_LOCAL #endif - const char * stbi__g_failure_reason; +const char *stbi__g_failure_reason; -STBIDEF const char * stbi_failure_reason(void) { return stbi__g_failure_reason; } +STBIDEF const char *stbi_failure_reason(void) +{ + return stbi__g_failure_reason; +} #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char * str) { - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) +{ + stbi__g_failure_reason = str; + return 0; } #endif -static void * stbi__malloc(size_t size) { return STBI_MALLOC(size); } +static void *stbi__malloc(size_t size) +{ + return STBI_MALLOC(size); +} // stb_image uses ints pervasively, including for offset calculations. // therefore the largest decoded image size we can support with the @@ -965,88 +999,88 @@ static void * stbi__malloc(size_t size) { return STBI_MALLOC(size); } // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) { - if (b < 0) - return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) +{ + if (b < 0) return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) { - if (a < 0 || b < 0) - return 0; - if (b == 0) - return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX / b; +static int stbi__mul2sizes_valid(int a, int b) +{ + if (a < 0 || b < 0) return 0; + if (b == 0) return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX/b; } #if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) { - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) { - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && stbi__addsizes_valid(a * b * c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__addsizes_valid(a*b*c, add); } // returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && stbi__mul2sizes_valid(a * b * c, d) && - stbi__addsizes_valid(a * b * c * d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); } #endif #if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void * stbi__malloc_mad2(int a, int b, int add) { - if (!stbi__mad2sizes_valid(a, b, add)) - return NULL; - return stbi__malloc(a * b + add); +static void *stbi__malloc_mad2(int a, int b, int add) +{ + if (!stbi__mad2sizes_valid(a, b, add)) return NULL; + return stbi__malloc(a*b + add); } #endif -static void * stbi__malloc_mad3(int a, int b, int c, int add) { - if (!stbi__mad3sizes_valid(a, b, c, add)) - return NULL; - return stbi__malloc(a * b * c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) +{ + if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; + return stbi__malloc(a*b*c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) -static void * stbi__malloc_mad4(int a, int b, int c, int d, int add) { - if (!stbi__mad4sizes_valid(a, b, c, d, add)) - return NULL; - return stbi__malloc(a * b * c * d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) +{ + if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; + return stbi__malloc(a*b*c*d + add); } #endif // returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 inclusive), 0 on overflow. -static int stbi__addints_valid(int a, int b) { - if ((a >= 0) != (b >= 0)) - return 1; // a and b have different signs, so no overflow - if (a < 0 && b < 0) - return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. - return a <= INT_MAX - b; +static int stbi__addints_valid(int a, int b) +{ + if ((a >= 0) != (b >= 0)) return 1; // a and b have different signs, so no overflow + if (a < 0 && b < 0) return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. + return a <= INT_MAX - b; } -// returns 1 if the product of two signed shorts is valid, 0 on overflow. -static int stbi__mul2shorts_valid(short a, short b) { - if (b == 0 || b == -1) - return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow - if ((a >= 0) == (b >= 0)) - return a <= SHRT_MAX / b; // product is positive, so similar to mul2sizes_valid - if (b < 0) - return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN - return a >= SHRT_MIN / b; +// returns 1 if the product of two ints fits in a signed short, 0 on overflow. +static int stbi__mul2shorts_valid(int a, int b) +{ + if (b == 0 || b == -1) return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow + if ((a >= 0) == (b >= 0)) return a <= SHRT_MAX/b; // product is positive, so similar to mul2sizes_valid + if (b < 0) return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN + return a >= SHRT_MIN / b; } // stbi__err - error @@ -1054,411 +1088,423 @@ static int stbi__mul2shorts_valid(short a, short b) { // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS -#define stbi__err(x, y) 0 + #define stbi__err(x,y) 0 #elif defined(STBI_FAILURE_USERMSG) -#define stbi__err(x, y) stbi__err(y) + #define stbi__err(x,y) stbi__err(y) #else -#define stbi__err(x, y) stbi__err(x) + #define stbi__err(x,y) stbi__err(x) #endif -#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -#define stbi__errpuc(x, y) ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) -STBIDEF void stbi_image_free(void * retval_from_stbi_load) { STBI_FREE(retval_from_stbi_load); } +STBIDEF void stbi_image_free(void *retval_from_stbi_load) +{ + STBI_FREE(retval_from_stbi_load); +} #ifndef STBI_NO_LINEAR -static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load \ - (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ + ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void * stbi__load_main(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; - -// test the formats with a very explicit header first (at least a FOURCC -// or distinctive magic number first) -#ifndef STBI_NO_PNG - if (stbi__png_test(s)) - return stbi__png_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) - return stbi__bmp_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_GIF - if (stbi__gif_test(s)) - return stbi__gif_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_PSD - if (stbi__psd_test(s)) - return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); -#else - STBI_NOTUSED(bpc); -#endif -#ifndef STBI_NO_PIC - if (stbi__pic_test(s)) - return stbi__pic_load(s, x, y, comp, req_comp, ri); -#endif - -// then the formats that can end up attempting to load with just 1 or 2 -// bytes matching expectations; these are prone to false positives, so -// try them later -#ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) - return stbi__jpeg_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) - return stbi__pnm_load(s, x, y, comp, req_comp, ri); -#endif - -#ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float * hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } -#endif - -#ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s, x, y, comp, req_comp, ri); -#endif - - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); -} - -static stbi_uc * stbi__convert_16_to_8(stbi__uint16 * orig, int w, int h, int channels) { - int i; - int img_len = w * h * channels; - stbi_uc * reduced; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed + ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order + ri->num_channels = 0; + + // test the formats with a very explicit header first (at least a FOURCC + // or distinctive magic number first) + #ifndef STBI_NO_PNG + if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_GIF + if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PSD + if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); + #else + STBI_NOTUSED(bpc); + #endif + #ifndef STBI_NO_PIC + if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); + #endif + + // then the formats that can end up attempting to load with just 1 or 2 + // bytes matching expectations; these are prone to false positives, so + // try them later + #ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); + #endif + + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } + #endif + + #ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s,x,y,comp,req_comp, ri); + #endif + + return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi_uc *reduced; - reduced = (stbi_uc *)stbi__malloc(img_len); - if (reduced == NULL) - return stbi__errpuc("outofmem", "Out of memory"); + reduced = (stbi_uc *) stbi__malloc(img_len); + if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling + for (i = 0; i < img_len; ++i) + reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - STBI_FREE(orig); - return reduced; + STBI_FREE(orig); + return reduced; } -static stbi__uint16 * stbi__convert_8_to_16(stbi_uc * orig, int w, int h, int channels) { - int i; - int img_len = w * h * channels; - stbi__uint16 * enlarged; +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; - enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); - if (enlarged == NULL) - return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); + if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff - STBI_FREE(orig); - return enlarged; + STBI_FREE(orig); + return enlarged; } -static void stbi__vertical_flip(void * image, int w, int h, int bytes_per_pixel) { - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc * bytes = (stbi_uc *)image; - - for (row = 0; row < (h >> 1); row++) { - stbi_uc * row0 = bytes + row * bytes_per_row; - stbi_uc * row1 = bytes + (h - row - 1) * bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } +static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) +{ + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h>>1); row++) { + stbi_uc *row0 = bytes + row*bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void * image, int w, int h, int z, int bytes_per_pixel) { - int slice; - int slice_size = w * h * bytes_per_pixel; - - stbi_uc * bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) +{ + int slice; + int slice_size = w * h * bytes_per_pixel; + + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char * stbi__load_and_postprocess_8bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - stbi__result_info ri; - void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *)result; + return (unsigned char *) result; } -static stbi__uint16 * stbi__load_and_postprocess_16bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - stbi__result_info ri; - void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *)result; + return (stbi__uint16 *) result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float * result, int * x, int * y, int * comp, int req_comp) { - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) +{ + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char * str, - int cbmb, wchar_t * widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, - const wchar_t * widestr, int cchwide, char * str, int cbmb, - const char * defchar, int * used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input) { - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int)bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) +{ + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); } #endif -static FILE * stbi__fopen(char const * filename, char const * mode) { - FILE * f; +static FILE *stbi__fopen(char const *filename, char const *mode) +{ + FILE *f; #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename) / sizeof(*wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode) / sizeof(*wMode))) - return 0; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) + return 0; #if defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f = 0; + if (0 != fopen_s(&f, filename, mode)) + f=0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; + return f; } -STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * comp, int req_comp) { - FILE * f = stbi__fopen(filename, "rb"); - unsigned char * result; - if (!f) - return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f, x, y, comp, req_comp); - fclose(f); - return result; -} -STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { - unsigned char * result; - stbi__context s; - stbi__start_file(&s, f); - result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; } -STBIDEF stbi__uint16 * stbi_load_from_file_16(FILE * f, int * x, int * y, int * comp, int req_comp) { - stbi__uint16 * result; - stbi__context s; - stbi__start_file(&s, f); - result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f,x,y,comp,req_comp); + fclose(f); + return result; } -STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * comp, int req_comp) { - FILE * f = stbi__fopen(filename, "rb"); - stbi__uint16 * result; - if (!f) - return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f, x, y, comp, req_comp); - fclose(f); - return result; -} -#endif //! STBI_NO_STDIO +#endif //!STBI_NO_STDIO -STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); } -STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); } -STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); } -STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, - int req_comp) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, - int * comp, int req_comp) { - unsigned char * result; - stbi__context s; - stbi__start_mem(&s, buffer, len); - - result = (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices(result, *x, *y, *z, *comp); - } +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_mem(&s,buffer,len); - return result; + result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); + } + + return result; } #endif #ifndef STBI_NO_LINEAR -static float * stbi__loadf_main(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - unsigned char * data; -#ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float * hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data, x, y, comp, req_comp); - return hdr_data; - } -#endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); -} - -STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__loadf_main(&s, x, y, comp, req_comp); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *data; + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data,x,y,comp,req_comp); + return hdr_data; + } + #endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); +} + +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__loadf_main(&s,x,y,comp,req_comp); } -STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, - int req_comp) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__loadf_main(&s, x, y, comp, req_comp); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__loadf_main(&s,x,y,comp,req_comp); } #ifndef STBI_NO_STDIO -STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * comp, int req_comp) { - float * result; - FILE * f = stbi__fopen(filename, "rb"); - if (!f) - return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f, x, y, comp, req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; } -STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { - stbi__context s; - stbi__start_file(&s, f); - return stbi__loadf_main(&s, x, y, comp, req_comp); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_file(&s,f); + return stbi__loadf_main(&s,x,y,comp,req_comp); } #endif // !STBI_NO_STDIO @@ -1468,208 +1514,222 @@ STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * comp, int // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len) { -#ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__hdr_test(&s); -#else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; -#endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; + #endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr(char const * filename) { - FILE * f = stbi__fopen(filename, "rb"); - int result = 0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr (char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); + int result=0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE * f) { -#ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s, f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; -#else - STBI_NOTUSED(f); - return 0; -#endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) +{ + #ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s,f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; + #else + STBI_NOTUSED(f); + return 0; + #endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user) { -#ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__hdr_test(&s); -#else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; -#endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; + #endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; +static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } #endif -static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; + +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1 / gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1 / scale; } ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; - -static void stbi__refill_buffer(stbi__context * s) { - int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); - s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + 1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } -} +enum +{ + STBI__SCAN_load=0, + STBI__SCAN_type, + STBI__SCAN_header +}; -stbi_inline static stbi_uc stbi__get8(stbi__context * s) { - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; +static void stbi__refill_buffer(stbi__context *s) +{ + int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); + s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start+1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) +{ + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; } #if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context * s) { - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) - return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) - return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) +{ + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ - defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context * s, int n) { - if (n == 0) - return; // already there! - if (n < 0) { - s->img_buffer = s->img_buffer_end; - return; - } - if (s->io.read) { - int blen = (int)(s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; +static void stbi__skip(stbi__context *s, int n) +{ + if (n == 0) return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); + return; + } + } + s->img_buffer += n; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context * s, stbi_uc * buffer, int n) { - if (s->io.read) { - int blen = (int)(s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; - - memcpy(buffer, s->img_buffer, blen); - - count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); - res = (count == (n - blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) +{ + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; + + memcpy(buffer, s->img_buffer, blen); - if (s->img_buffer + n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); + res = (count == (n-blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } + + if (s->img_buffer+n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif #if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context * s) { - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) +{ + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context * s) { - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) +{ + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context * s) { - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) +{ + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context * s) { - stbi__uint32 z = stbi__get16le(s); - z += (stbi__uint32)stbi__get16le(s) << 16; - return z; +static stbi__uint32 stbi__get32le(stbi__context *s) +{ + stbi__uint32 z = stbi__get16le(s); + z += (stbi__uint32)stbi__get16le(s) << 16; + return z; } #endif -#define STBI__BYTECAST(x) ((stbi_uc)((x)&255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ - defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1683,264 +1743,169 @@ static stbi__uint32 stbi__get32le(stbi__context * s) { // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) { return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +static stbi_uc stbi__compute_y(int r, int g, int b) +{ + return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +} #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ - defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else -static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int req_comp, unsigned int x, unsigned int y) { - int i, j; - unsigned char * good; - - if (req_comp == img_n) - return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - - for (j = 0; j < (int)y; ++j) { - unsigned char * src = data + j * x * img_n; - unsigned char * dest = good + j * x * req_comp; - -#define STBI__COMBO(a, b) ((a)*8 + (b)) -#define STBI__CASE(a, b) \ - case STBI__COMBO(a, b): \ - for (i = x - 1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1, 2) { - dest[0] = src[0]; - dest[1] = 255; - } - break; - STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(1, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = 255; - } - break; - STBI__CASE(2, 1) { dest[0] = src[0]; } - break; - STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(2, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = src[1]; - } - break; - STBI__CASE(3, 4) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - dest[3] = 255; - } - break; - STBI__CASE(3, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - break; - STBI__CASE(3, 2) { - dest[0] = stbi__compute_y(src[0], src[1], src[2]); - dest[1] = 255; - } - break; - STBI__CASE(4, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - break; - STBI__CASE(4, 2) { - dest[0] = stbi__compute_y(src[0], src[1], src[2]); - dest[1] = src[3]; - } - break; - STBI__CASE(4, 3) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - } - break; - default: - STBI_ASSERT(0); - STBI_FREE(data); - STBI_FREE(good); - return stbi__errpuc("unsupported", "Unsupported format conversion"); - } -#undef STBI__CASE - } - - STBI_FREE(data); - return good; +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + unsigned char *good; + + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j=0; j < (int) y; ++j) { + unsigned char *src = data + j * x * img_n ; + unsigned char *dest = good + j * x * req_comp; + + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } + + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) +{ + return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +} #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int req_comp, unsigned int x, unsigned int y) { - int i, j; - stbi__uint16 * good; - - if (req_comp == img_n) - return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); - } - - for (j = 0; j < (int)y; ++j) { - stbi__uint16 * src = data + j * x * img_n; - stbi__uint16 * dest = good + j * x * req_comp; - -#define STBI__COMBO(a, b) ((a)*8 + (b)) -#define STBI__CASE(a, b) \ - case STBI__COMBO(a, b): \ - for (i = x - 1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1, 2) { - dest[0] = src[0]; - dest[1] = 0xffff; - } - break; - STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(1, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = 0xffff; - } - break; - STBI__CASE(2, 1) { dest[0] = src[0]; } - break; - STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(2, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = src[1]; - } - break; - STBI__CASE(3, 4) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - dest[3] = 0xffff; - } - break; - STBI__CASE(3, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - break; - STBI__CASE(3, 2) { - dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); - dest[1] = 0xffff; - } - break; - STBI__CASE(4, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - break; - STBI__CASE(4, 2) { - dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); - dest[1] = src[3]; - } - break; - STBI__CASE(4, 3) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - } - break; - default: - STBI_ASSERT(0); - STBI_FREE(data); - STBI_FREE(good); - return (stbi__uint16 *)stbi__errpuc("unsupported", "Unsupported format conversion"); - } -#undef STBI__CASE - } - - STBI_FREE(data); - return good; +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + stbi__uint16 *good; + + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); + } + + for (j=0; j < (int) y; ++j) { + stbi__uint16 *src = data + j * x * img_n ; + stbi__uint16 *dest = good + j * x * req_comp; + + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } + + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp) { - int i, k, n; - float * output; - if (!data) - return NULL; - output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { - STBI_FREE(data); - return stbi__errpf("outofmem", "Out of memory"); - } - // compute number of non-alpha components - if (comp & 1) - n = comp; - else - n = comp - 1; - for (i = 0; i < x * y; ++i) { - for (k = 0; k < n; ++k) { - output[i * comp + k] = (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i = 0; i < x * y; ++i) { - output[i * comp + n] = data[i * comp + n] / 255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) +{ + int i,k,n; + float *output; + if (!data) return NULL; + output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); + } + } + if (n < comp) { + for (i=0; i < x*y; ++i) { + output[i*comp + n] = data[i*comp + n]/255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int)(x)) -static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { - int i, k, n; - stbi_uc * output; - if (!data) - return NULL; - output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - // compute number of non-alpha components - if (comp & 1) - n = comp; - else - n = comp - 1; - for (i = 0; i < x * y; ++i) { - for (k = 0; k < n; ++k) { - float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) - z = 0; - if (z > 255) - z = 255; - output[i * comp + k] = (stbi_uc)stbi__float2int(z); - } - if (k < comp) { - float z = data[i * comp + k] * 255 + 0.5f; - if (z < 0) - z = 0; - if (z > 255) - z = 255; - output[i * comp + k] = (stbi_uc)stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int) (x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) +{ + int i,k,n; + stbi_uc *output; + if (!data) return NULL; + output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + if (k < comp) { + float z = data[i*comp+k] * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1968,783 +1933,763 @@ static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache - -typedef struct { - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct +{ + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct { - stbi__context * s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; - - // sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; - - // definition of jpeg image component - struct { - int id; - int h, v; - int tq; - int hd, ha; - int dc_pred; - - int x, y, w2, h2; - stbi_uc * data; - void *raw_data, *raw_coeff; - stbi_uc * linebuf; - short * coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; - - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop - - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; - - int scan_n, order[4]; - int restart_interval, todo; - - // kernels - void (*idct_block_kernel)(stbi_uc * out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, - int step); - stbi_uc * (*resample_row_hv_2_kernel)(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs); +typedef struct +{ + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + +// sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + +// definition of jpeg image component + struct + { + int id; + int h,v; + int tq; + int hd,ha; + int dc_pred; + + int x,y,w2,h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + +// kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman * h, int * count) { - int i, j, k = 0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i = 0; i < 16; ++i) { - for (j = 0; j < count[i]; ++j) { - h->size[k++] = (stbi_uc)(i + 1); - if (k >= 257) - return stbi__err("bad size list", "Corrupt JPEG"); - } - } - h->size[k] = 0; - - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for (j = 1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16)(code++); - if (code - 1 >= (1u << j)) - return stbi__err("bad code lengths", "Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16 - j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; - - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i = 0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS - s); - int m = 1 << (FAST_BITS - s); - for (j = 0; j < m; ++j) { - h->fast[c + j] = (stbi_uc)i; - } - } - } - return 1; +static int stbi__build_huffman(stbi__huffman *h, int *count) +{ + int i,j,k=0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i=0; i < 16; ++i) { + for (j=0; j < count[i]; ++j) { + h->size[k++] = (stbi_uc) (i+1); + if(k >= 257) return stbi__err("bad size list","Corrupt JPEG"); + } + } + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for(j=1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16) (code++); + if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16-j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i=0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS-s); + int m = 1 << (FAST_BITS-s); + for (j=0; j < m; ++j) { + h->fast[c+j] = (stbi_uc) i; + } + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 * fast_ac, stbi__huffman * h) { - int i; - for (i = 0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; - - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) - k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); - } - } - } -} - -static void stbi__grow_buffer_unsafe(stbi__jpeg * j) { - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) - c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char)c; - j->nomore = 1; - return; - } - } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) +{ + int i; + for (i=0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); + } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) +{ + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char) c; + j->nomore = 1; + return; + } + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17] = {0, 1, 3, 7, 15, 31, 63, 127, 255, - 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; +static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg * j, stbi__huffman * h) { - unsigned int temp; - int c, k; - - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } - - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k = FAST_BITS + 1;; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } - - if (k > j->code_bits) - return -1; - - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - if (c < 0 || c >= 256) // symbol id out of bounds! - return -1; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) +{ + unsigned int temp; + int c,k; + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) + return -1; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k=FAST_BITS+1 ; ; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + if(c < 0 || c >= 256) // symbol id out of bounds! + return -1; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) - stbi__grow_buffer_unsafe(j); - if (j->code_bits < n) - return 0; // ran out of bits from stream, return 0s intead of continuing - - sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & (sgn - 1)); +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) +{ + unsigned int k; + int sgn; + if (j->code_bits < n) stbi__grow_buffer_unsafe(j); + if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s intead of continuing + + sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & (sgn - 1)); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg * j, int n) { - unsigned int k; - if (j->code_bits < n) - stbi__grow_buffer_unsafe(j); - if (j->code_bits < n) - return 0; // ran out of bits from stream, return 0s intead of continuing - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; -} - -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg * j) { - unsigned int k; - if (j->code_bits < 1) - stbi__grow_buffer_unsafe(j); - if (j->code_bits < 1) - return 0; // ran out of bits from stream, return 0s intead of continuing - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) +{ + unsigned int k; + if (j->code_bits < n) stbi__grow_buffer_unsafe(j); + if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s intead of continuing + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) +{ + unsigned int k; + if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); + if (j->code_bits < 1) return 0; // ran out of bits from stream, return 0s intead of continuing + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { - 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, - 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; +static const stbi_uc stbi__jpeg_dezigzag[64+15] = +{ + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, + 63, 63, 63, 63, 63, 63, 63 +}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman * hdc, stbi__huffman * hac, stbi__int16 * fac, - int b, stbi__uint16 * dequant) { - int diff, dc, k; - int t; - - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0 || t > 15) - return stbi__err("bad huffman code", "Corrupt JPEG"); - - // 0 all the ac values now so we can do it 32-bits at a time - memset(data, 0, 64 * sizeof(data[0])); - - diff = t ? stbi__extend_receive(j, t) : 0; - if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) - return stbi__err("bad delta", "Corrupt JPEG"); - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - if (!stbi__mul2shorts_valid(dc, dequant[0])) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - data[0] = (short)(dc * dequant[0]); - - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c, r, s; - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); - r = fac[c]; - if (r) { // fast-AC path +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) +{ + int diff,dc,k; + int t; + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("bad huffman code","Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data,0,64*sizeof(data[0])); + + diff = t ? stbi__extend_receive(j, t) : 0; + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta","Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, dequant[0])) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) break; // end block + k += 16; + } else { + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); + } + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) +{ + int diff,dc; + int t; + if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data,0,64*sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; + + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta", "Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * (1 << j->succ_low)); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short) (1 << j->succ_low); + } + return 1; +} + +// @OPTIMIZE: store non-zigzagged during the decode passes, +// and only de-zigzag when dequantizing +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) +{ + int k; + if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } + + k = j->spec_start; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path k += (r >> 4) & 15; // run - s = r & 15; // combined length - if (s > j->code_bits) - return stbi__err("bad huffman code", "Combined length longer than code bits available"); + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); j->code_buffer <<= s; j->code_bits -= s; - // decode into unzigzag'd location zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)((r >> 8) * dequant[zig]); - } else { + data[zig] = (short) ((r >> 8) * (1 << shift)); + } else { int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) - return stbi__err("bad huffman code", "Corrupt JPEG"); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); s = rs & 15; r = rs >> 4; if (s == 0) { - if (rs != 0xf0) - break; // end block - k += 16; + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * (1 << shift)); } - } - } while (k < 64); - return 1; -} - -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg * j, short data[64], stbi__huffman * hdc, int b) { - int diff, dc; - int t; - if (j->spec_end != 0) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0 || t > 15) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; - - if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) - return stbi__err("bad delta", "Corrupt JPEG"); - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - data[0] = (short)(dc * (1 << j->succ_low)); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short)(1 << j->succ_low); - } - return 1; -} - -// @OPTIMIZE: store non-zigzagged during the decode passes, -// and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi__huffman * hac, stbi__int16 * fac) { - int k; - if (j->spec_start == 0) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->succ_high == 0) { - int shift = j->succ_low; - - if (j->eob_run) { - --j->eob_run; - return 1; - } - - k = j->spec_start; - do { - unsigned int zig; - int c, r, s; - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - if (s > j->code_bits) - return stbi__err("bad huffman code", "Combined length longer than code bits available"); - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)((r >> 8) * (1 << shift)); + } + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short) (1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { + k = j->spec_start; + do { + int r,s; + int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) - return stbi__err("bad huffman code", "Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)(stbi__extend_receive(j, s) * (1 << shift)); - } + if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients - - short bit = (short)(1 << j->succ_low); - - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short * p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit) == 0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } + + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short) s; + break; + } + --r; + } } - } else { - k = j->spec_start; - do { - int r, s; - int rs = stbi__jpeg_huff_decode( - j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) - return stbi__err("bad huffman code", "Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) - return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } - - // advance by r - while (k <= j->spec_end) { - short * p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit) == 0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short)s; - break; - } - --r; - } - } - } while (k <= j->spec_end); - } - } - return 1; + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) { - // trick to use a single test to catch both cases - if ((unsigned int)x > 255) { - if (x < 0) - return 0; - if (x > 255) - return 255; - } - return (stbi_uc)x; +stbi_inline static stbi_uc stbi__clamp(int x) +{ + // trick to use a single test to catch both cases + if ((unsigned int) x > 255) { + if (x < 0) return 0; + if (x > 255) return 255; + } + return (stbi_uc) x; } -#define stbi__f2f(x) ((int)(((x)*4096 + 0.5))) -#define stbi__fsh(x) ((x)*4096) +#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ - int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ - t3 = p1 + p2 * stbi__f2f(0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2 + p3); \ - t1 = stbi__fsh(p2 - p3); \ - x0 = t0 + t3; \ - x3 = t0 - t3; \ - x1 = t1 + t2; \ - x2 = t1 - t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0 + t2; \ - p4 = t1 + t3; \ - p1 = t0 + t3; \ - p2 = t1 + t2; \ - p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ - t0 = t0 * stbi__f2f(0.298631336f); \ - t1 = t1 * stbi__f2f(2.053119869f); \ - t2 = t2 * stbi__f2f(3.072711026f); \ - t3 = t3 * stbi__f2f(1.501321110f); \ - p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ - p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ - p3 = p3 * stbi__f2f(-1.961570560f); \ - p4 = p4 * stbi__f2f(-0.390180644f); \ - t3 += p1 + p4; \ - t2 += p2 + p3; \ - t1 += p2 + p4; \ - t0 += p1 + p3; - -static void stbi__idct_block(stbi_uc * out, int out_stride, short data[64]) { - int i, val[64], *v = val; - stbi_uc * o; - short * d = data; - - // columns - for (i = 0; i < 8; ++i, ++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && d[48] == 0 && d[56] == 0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0] * 4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; - x1 += 512; - x2 += 512; - x3 += 512; - v[0] = (x0 + t3) >> 10; - v[56] = (x0 - t3) >> 10; - v[8] = (x1 + t2) >> 10; - v[48] = (x1 - t2) >> 10; - v[16] = (x2 + t1) >> 10; - v[40] = (x2 - t1) >> 10; - v[24] = (x3 + t0) >> 10; - v[32] = (x3 - t0) >> 10; - } - } - - for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128 << 17); - x1 += 65536 + (128 << 17); - x2 += 65536 + (128 << 17); - x3 += 65536 + (128 << 17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0 + t3) >> 17); - o[7] = stbi__clamp((x0 - t3) >> 17); - o[1] = stbi__clamp((x1 + t2) >> 17); - o[6] = stbi__clamp((x1 - t2) >> 17); - o[2] = stbi__clamp((x2 + t1) >> 17); - o[5] = stbi__clamp((x2 - t1) >> 17); - o[3] = stbi__clamp((x3 + t0) >> 17); - o[4] = stbi__clamp((x3 - t0) >> 17); - } +#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ + int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2+p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3*stbi__f2f(-1.847759065f); \ + t3 = p1 + p2*stbi__f2f( 0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2+p3); \ + t1 = stbi__fsh(p2-p3); \ + x0 = t0+t3; \ + x3 = t0-t3; \ + x1 = t1+t2; \ + x2 = t1-t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0+t2; \ + p4 = t1+t3; \ + p1 = t0+t3; \ + p2 = t1+t2; \ + p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ + t0 = t0*stbi__f2f( 0.298631336f); \ + t1 = t1*stbi__f2f( 2.053119869f); \ + t2 = t2*stbi__f2f( 3.072711026f); \ + t3 = t3*stbi__f2f( 1.501321110f); \ + p1 = p5 + p1*stbi__f2f(-0.899976223f); \ + p2 = p5 + p2*stbi__f2f(-2.562915447f); \ + p3 = p3*stbi__f2f(-1.961570560f); \ + p4 = p4*stbi__f2f(-0.390180644f); \ + t3 += p1+p4; \ + t2 += p2+p3; \ + t1 += p2+p4; \ + t0 += p1+p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) +{ + int i,val[64],*v=val; + stbi_uc *o; + short *d = data; + + // columns + for (i=0; i < 8; ++i,++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 + && d[40]==0 && d[48]==0 && d[56]==0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0]*4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; x1 += 512; x2 += 512; x3 += 512; + v[ 0] = (x0+t3) >> 10; + v[56] = (x0-t3) >> 10; + v[ 8] = (x1+t2) >> 10; + v[48] = (x1-t2) >> 10; + v[16] = (x2+t1) >> 10; + v[40] = (x2-t1) >> 10; + v[24] = (x3+t0) >> 10; + v[32] = (x3-t0) >> 10; + } + } + + for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128<<17); + x1 += 65536 + (128<<17); + x2 += 65536 + (128<<17); + x3 += 65536 + (128<<17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0+t3) >> 17); + o[7] = stbi__clamp((x0-t3) >> 17); + o[1] = stbi__clamp((x1+t2) >> 17); + o[6] = stbi__clamp((x1-t2) >> 17); + o[2] = stbi__clamp((x2+t1) >> 17); + o[5] = stbi__clamp((x2-t1) >> 17); + o[3] = stbi__clamp((x3+t0) >> 17); + o[4] = stbi__clamp((x3-t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; - -// dot product constant: even elems=x, odd elems=y -#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) - -// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) -// out(1) = c1[even]*x + c1[odd]*y -#define dct_rot(out0, out1, x, y, c0, c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) - -// out = in << 12 (in 16-bit, out 32-bit) -#define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - -// wide add -#define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - -// wide sub -#define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) - -// butterfly a/b, add bias, then shift by "s" and pack -#define dct_bfly32o(out0, out1, a, b, bias, s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } - -// 8-bit interleave step (for transposes) -#define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) - -// 16-bit interleave step (for transposes) -#define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) - -#define dct_pass(bias, shift) \ - { \ - /* even part */ \ - dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ - dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0, row7, x0, x7, bias, shift); \ - dct_bfly32o(row1, row6, x1, x6, bias, shift); \ - dct_bfly32o(row2, row5, x2, x5, bias, shift); \ - dct_bfly32o(row3, row4, x3, x4, bias, shift); \ - } - - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); - - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); - - // load - row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); - row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); - row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); - row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); - row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); - row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); - row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); - row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); - - // column pass - dct_pass(bias_0, 10); - - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); - - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); - - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } - - // row pass - dct_pass(bias_1, 17); - - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); - - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... - - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... - - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... - - // store - _mm_storel_epi64((__m128i *)out, p0); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); - out += out_stride; - _mm_storel_epi64((__m128i *)out, p2); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); - out += out_stride; - _mm_storel_epi64((__m128i *)out, p1); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); - out += out_stride; - _mm_storel_epi64((__m128i *)out, p3); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); - } +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + + // dot product constant: even elems=x, odd elems=y + #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) + + // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) + // out(1) = c1[even]*x + c1[odd]*y + #define dct_rot(out0,out1, x,y,c0,c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + + // out = in << 12 (in 16-bit, out 32-bit) + #define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) + + // wide add + #define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) + + // wide sub + #define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + + // butterfly a/b, add bias, then shift by "s" and pack + #define dct_bfly32o(out0, out1, a,b,bias,s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + + // 8-bit interleave step (for transposes) + #define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + + // 16-bit interleave step (for transposes) + #define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + + #define dct_pass(bias,shift) \ + { \ + /* even part */ \ + dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ + dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0,row7, x0,x7,bias,shift); \ + dct_bfly32o(row1,row6, x1,x6,bias,shift); \ + dct_bfly32o(row2,row5, x2,x5,bias,shift); \ + dct_bfly32o(row3,row4, x3,x4,bias,shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); + __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); + __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); + __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); + + // load + row0 = _mm_load_si128((const __m128i *) (data + 0*8)); + row1 = _mm_load_si128((const __m128i *) (data + 1*8)); + row2 = _mm_load_si128((const __m128i *) (data + 2*8)); + row3 = _mm_load_si128((const __m128i *) (data + 3*8)); + row4 = _mm_load_si128((const __m128i *) (data + 4*8)); + row5 = _mm_load_si128((const __m128i *) (data + 5*8)); + row6 = _mm_load_si128((const __m128i *) (data + 6*8)); + row7 = _mm_load_si128((const __m128i *) (data + 7*8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *) out, p0); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p2); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p1); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p3); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2763,235 +2708,198 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); - -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) - -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) - -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ - dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ - dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ - dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ - } - - // load - row0 = vld1q_s16(data + 0 * 8); - row1 = vld1q_s16(data + 1 * 8); - row2 = vld1q_s16(data + 2 * 8); - row3 = vld1q_s16(data + 3 * 8); - row4 = vld1q_s16(data + 4 * 8); - row5 = vld1q_s16(data + 5 * 8); - row6 = vld1q_s16(data + 6 * 8); - row7 = vld1q_s16(data + 7 * 8); - - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - - // column pass - dct_pass(vrshrn_n_s32, 10); - - // 16bit 8x8 transpose - { +#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ + dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ + dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ + dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ + } + + // load + row0 = vld1q_s16(data + 0*8); + row1 = vld1q_s16(data + 1*8); + row2 = vld1q_s16(data + 2*8); + row3 = vld1q_s16(data + 3*8); + row4 = vld1q_s16(data + 4*8); + row5 = vld1q_s16(data + 5*8); + row6 = vld1q_s16(data + 6*8); + row7 = vld1q_s16(data + 7*8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) \ - { \ - int16x8x2_t t = vtrnq_s16(x, y); \ - x = t.val[0]; \ - y = t.val[1]; \ - } -#define dct_trn32(x, y) \ - { \ - int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ - x = vreinterpretq_s16_s32(t.val[0]); \ - y = vreinterpretq_s16_s32(t.val[1]); \ - } -#define dct_trn64(x, y) \ - { \ - int16x8_t x0 = x; \ - int16x8_t y0 = y; \ - x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ - y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ - } - - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); - - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); - - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); +#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } +#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } - - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); - - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) \ - { \ - uint8x8x2_t t = vtrn_u8(x, y); \ - x = t.val[0]; \ - y = t.val[1]; \ - } -#define dct_trn8_16(x, y) \ - { \ - uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ - x = vreinterpret_u8_u16(t.val[0]); \ - y = vreinterpret_u8_u16(t.val[1]); \ - } -#define dct_trn8_32(x, y) \ - { \ - uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ - x = vreinterpret_u8_u32(t.val[0]); \ - y = vreinterpret_u8_u32(t.val[1]); \ - } - - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! - - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); - - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); - - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); - - // store - vst1_u8(out, p0); - out += out_stride; - vst1_u8(out, p1); - out += out_stride; - vst1_u8(out, p2); - out += out_stride; - vst1_u8(out, p3); - out += out_stride; - vst1_u8(out, p4); - out += out_stride; - vst1_u8(out, p5); - out += out_stride; - vst1_u8(out, p6); - out += out_stride; - vst1_u8(out, p7); + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } +#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); out += out_stride; + vst1_u8(out, p1); out += out_stride; + vst1_u8(out, p2); out += out_stride; + vst1_u8(out, p3); out += out_stride; + vst1_u8(out, p4); out += out_stride; + vst1_u8(out, p5); out += out_stride; + vst1_u8(out, p6); out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -3004,1267 +2912,1169 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg * j) { - stbi_uc x; - if (j->marker != STBI__MARKER_none) { - x = j->marker; - j->marker = STBI__MARKER_none; - return x; - } - x = stbi__get8(j->s); - if (x != 0xff) - return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) +{ + stbi_uc x; + if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } + x = stbi__get8(j->s); + if (x != 0xff) return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg * j) { - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels -} - -static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i, j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x + 7) >> 3; - int h = (z->img_comp[n].y + 7) >> 3; - for (j = 0; j < h; ++j) { - for (i = 0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, z->fast_ac[ha], n, - z->dequant[z->img_comp[n].tq])) - return 0; - z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i, j, k, x, y; - STBI_SIMD_ALIGN(short, data[64]); - for (j = 0; j < z->img_mcu_y; ++j) { - for (i = 0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k = 0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y = 0; y < z->img_comp[n].v; ++y) { - for (x = 0; x < z->img_comp[n].h; ++x) { - int x2 = (i * z->img_comp[n].h + x) * 8; - int y2 = (j * z->img_comp[n].v + y) * 8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, - z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) - return 0; - z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * y2 + x2, z->img_comp[n].w2, - data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } +static void stbi__jpeg_reset(stbi__jpeg *j) +{ + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) +{ + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i,j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - return 1; - } - } else { - if (z->scan_n == 1) { - int i, j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x + 7) >> 3; - int h = (z->img_comp[n].y + 7) >> 3; - for (j = 0; j < h; ++j) { - for (i = 0; i < w; ++i) { - short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { + } + return 1; + } else { // interleaved + int i,j,k,x,y; + STBI_SIMD_ALIGN(short, data[64]); + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x)*8; + int y2 = (j*z->img_comp[n].v + y)*8; int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - return 1; - } else { // interleaved - int i, j, k, x, y; - for (j = 0; j < z->img_mcu_y; ++j) { - for (i = 0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k = 0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y = 0; y < z->img_comp[n].v; ++y) { - for (x = 0; x < z->img_comp[n].h; ++x) { - int x2 = (i * z->img_comp[n].h + x); - int y2 = (j * z->img_comp[n].v + y); - short * data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } + } + return 1; + } + } else { + if (z->scan_n == 1) { + int i,j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - return 1; - } - } -} - -static void stbi__jpeg_dequantize(short * data, stbi__uint16 * dequant) { - int i; - for (i = 0; i < 64; ++i) - data[i] *= dequant[i]; -} - -static void stbi__jpeg_finish(stbi__jpeg * z) { - if (z->progressive) { - // dequantize and idct the data - int i, j, n; - for (n = 0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x + 7) >> 3; - int h = (z->img_comp[n].y + 7) >> 3; - for (j = 0; j < h; ++j) { - for (i = 0; i < w; ++i) { - short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); - } + } + return 1; + } else { // interleaved + int i,j,k,x,y; + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x); + int y2 = (j*z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - } - } + } + return 1; + } + } } -static int stbi__process_marker(stbi__jpeg * z, int m) { - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker", "Corrupt JPEG"); +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) +{ + int i; + for (i=0; i < 64; ++i) + data[i] *= dequant[i]; +} - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) - return stbi__err("bad DRI len", "Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; +static void stbi__jpeg_finish(stbi__jpeg *z) +{ + if (z->progressive) { + // dequantize and idct the data + int i,j,n; + for (n=0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); + } + } + } + } +} - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s) - 2; - while (L > 0) { +static int stbi__process_marker(stbi__jpeg *z, int m) +{ + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker","Corrupt JPEG"); + + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; + + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s)-2; + while (L > 0) { int q = stbi__get8(z->s); int p = q >> 4, sixteen = (p != 0); - int t = q & 15, i; - if (p != 0 && p != 1) - return stbi__err("bad DQT type", "Corrupt JPEG"); - if (t > 3) - return stbi__err("bad DQT table", "Corrupt JPEG"); - - for (i = 0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); - L -= (sixteen ? 129 : 65); - } - return L == 0; + int t = q & 15,i; + if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); + if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s) - 2; - while (L > 0) { - stbi_uc * v; - int sizes[16], i, n = 0; + for (i=0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L==0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s)-2; + while (L > 0) { + stbi_uc *v; + int sizes[16],i,n=0; int q = stbi__get8(z->s); int tc = q >> 4; int th = q & 15; - if (tc > 1 || th > 3) - return stbi__err("bad DHT header", "Corrupt JPEG"); - for (i = 0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; + if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); + for (i=0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; } - if (n > 256) - return stbi__err("bad DHT header", "Corrupt JPEG"); // Loop over i < n would write past end of values! + if(n > 256) return stbi__err("bad DHT header","Corrupt JPEG"); // Loop over i < n would write past end of values! L -= 17; if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc + th, sizes)) - return 0; - v = z->huff_dc[th].values; + if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; + v = z->huff_dc[th].values; } else { - if (!stbi__build_huffman(z->huff_ac + th, sizes)) - return 0; - v = z->huff_ac[th].values; + if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; + v = z->huff_ac[th].values; } - for (i = 0; i < n; ++i) - v[i] = stbi__get8(z->s); + for (i=0; i < n; ++i) + v[i] = stbi__get8(z->s); if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); L -= n; - } - return L == 0; - } - - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len", "Corrupt JPEG"); - else - return stbi__err("bad APP len", "Corrupt JPEG"); - } - L -= 2; - - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; - int ok = 1; - int i; - for (i = 0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; - int ok = 1; - int i; - for (i = 0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; + } + return L==0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len","Corrupt JPEG"); + else + return stbi__err("bad APP len","Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J','F','I','F','\0'}; + int ok = 1; + int i; + for (i=0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; + int ok = 1; + int i; + for (i=0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } - } + } + } - stbi__skip(z->s, L); - return 1; - } + stbi__skip(z->s, L); + return 1; + } - return stbi__err("unknown marker", "Corrupt JPEG"); + return stbi__err("unknown marker","Corrupt JPEG"); } // after we see SOS -static int stbi__process_scan_header(stbi__jpeg * z) { - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) - return stbi__err("bad SOS component count", "Corrupt JPEG"); - if (Ls != 6 + 2 * z->scan_n) - return stbi__err("bad SOS len", "Corrupt JPEG"); - for (i = 0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; - int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) - return 0; // no match - z->img_comp[which].hd = q >> 4; - if (z->img_comp[which].hd > 3) - return stbi__err("bad DC huff", "Corrupt JPEG"); - z->img_comp[which].ha = q & 15; - if (z->img_comp[which].ha > 3) - return stbi__err("bad AC huff", "Corrupt JPEG"); - z->order[i] = which; - } - - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); - } else { - if (z->spec_start != 0) - return stbi__err("bad SOS", "Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) - return stbi__err("bad SOS", "Corrupt JPEG"); - z->spec_end = 63; - } - } - - return 1; +static int stbi__process_scan_header(stbi__jpeg *z) +{ + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); + if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); + for (i=0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) return 0; // no match + z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); + z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; +} + +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) +{ + int i; + for (i=0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; + z->img_comp[i].data = NULL; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); + z->img_comp[i].linebuf = NULL; + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) +{ + stbi__context *s = z->s; + int Lf,p,i,q, h_max=1,v_max=1,c; + Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG + p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); + s->img_n = c; + for (i=0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); + + z->rgb = 0; + for (i=0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = { 'R', 'G', 'B' }; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); + z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); + + for (i=0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; + } + + // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios + // and I've never seen a non-corrupted JPEG file actually use them + for (i=0; i < s->img_n; ++i) { + if (h_max % z->img_comp[i].h != 0) return stbi__err("bad H","Corrupt JPEG"); + if (v_max % z->img_comp[i].v != 0) return stbi__err("bad V","Corrupt JPEG"); + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; + + for (i=0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) + // so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); + } + } + + return 1; } -static int stbi__free_jpeg_components(stbi__jpeg * z, int ncomp, int why) { - int i; - for (i = 0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; -} - -static int stbi__process_frame_header(stbi__jpeg * z, int scan) { - stbi__context * s = z->s; - int Lf, p, i, q, h_max = 1, v_max = 1, c; - Lf = stbi__get16be(s); - if (Lf < 11) - return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG - p = stbi__get8(s); - if (p != 8) - return stbi__err("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); - if (s->img_y == 0) - return stbi__err("no header height", - "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); - if (s->img_x == 0) - return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) - return stbi__err("bad component count", "Corrupt JPEG"); - s->img_n = c; - for (i = 0; i < c; ++i) { - z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } - - if (Lf != 8 + 3 * s->img_n) - return stbi__err("bad SOF len", "Corrupt JPEG"); - - z->rgb = 0; - for (i = 0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = {'R', 'G', 'B'}; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); - if (!z->img_comp[i].h || z->img_comp[i].h > 4) - return stbi__err("bad H", "Corrupt JPEG"); - z->img_comp[i].v = q & 15; - if (!z->img_comp[i].v || z->img_comp[i].v > 4) - return stbi__err("bad V", "Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); - if (z->img_comp[i].tq > 3) - return stbi__err("bad TQ", "Corrupt JPEG"); - } - - if (scan != STBI__SCAN_load) - return 1; - - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) - return stbi__err("too large", "Image too large to decode"); - - for (i = 0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) - h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) - v_max = z->img_comp[i].v; - } - - // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios - // and I've never seen a non-corrupted JPEG file actually use them - for (i = 0; i < s->img_n; ++i) { - if (h_max % z->img_comp[i].h != 0) - return stbi__err("bad H", "Corrupt JPEG"); - if (v_max % z->img_comp[i].v != 0) - return stbi__err("bad V", "Corrupt JPEG"); - } - - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; - - for (i = 0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i + 1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i + 1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); - } - } +// use comparisons since in some cases we handle more than one case (e.g. SOF) +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) - return 1; -} +#define stbi__SOF_progressive(x) ((x) == 0xc2) -// use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) - -#define stbi__SOF_progressive(x) ((x) == 0xc2) - -static int stbi__decode_jpeg_header(stbi__jpeg * z, int scan) { - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) - return stbi__err("no SOI", "Corrupt JPEG"); - if (scan == STBI__SCAN_type) - return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z, m)) - return 0; - m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) - return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) - return 0; - return 1; -} - -static int stbi__skip_jpeg_junk_at_end(stbi__jpeg * j) { - // some JPEGs have junk at end, skip over it but if we find what looks - // like a valid marker, resume there - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - while (x == 255) { // might be a marker - if (stbi__at_eof(j->s)) - return STBI__MARKER_none; - x = stbi__get8(j->s); - if (x != 0x00 && x != 0xff) { - // not a stuffed zero or lead-in to another marker, looks - // like an actual marker, return it - return x; - } - // stuffed zero has x=0 now which ends the loop, meaning we go - // back to regular scan loop. - // repeated 0xff keeps trying to read the next byte of the marker. - } - } - return STBI__MARKER_none; +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) +{ + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); + if (scan == STBI__SCAN_type) return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z,m)) return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); + m = stbi__get_marker(z); + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) return 0; + return 1; +} + +static stbi_uc stbi__skip_jpeg_junk_at_end(stbi__jpeg *j) +{ + // some JPEGs have junk at end, skip over it but if we find what looks + // like a valid marker, resume there + while (!stbi__at_eof(j->s)) { + stbi_uc x = stbi__get8(j->s); + while (x == 0xff) { // might be a marker + if (stbi__at_eof(j->s)) return STBI__MARKER_none; + x = stbi__get8(j->s); + if (x != 0x00 && x != 0xff) { + // not a stuffed zero or lead-in to another marker, looks + // like an actual marker, return it + return x; + } + // stuffed zero has x=0 now which ends the loop, meaning we go + // back to regular scan loop. + // repeated 0xff keeps trying to read the next byte of the marker. + } + } + return STBI__MARKER_none; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg * j) { - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) - return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) - return 0; - if (!stbi__parse_entropy_coded_data(j)) - return 0; - if (j->marker == STBI__MARKER_none) { - j->marker = stbi__skip_jpeg_junk_at_end(j); - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } - m = stbi__get_marker(j); - if (STBI__RESTART(m)) - m = stbi__get_marker(j); - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) - return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) - return stbi__err("bad DNL height", "Corrupt JPEG"); - m = stbi__get_marker(j); - } else { - if (!stbi__process_marker(j, m)) - return 1; +static int stbi__decode_jpeg_image(stbi__jpeg *j) +{ + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) return 0; + if (!stbi__parse_entropy_coded_data(j)) return 0; + if (j->marker == STBI__MARKER_none ) { + j->marker = stbi__skip_jpeg_junk_at_end(j); + // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 + } + m = stbi__get_marker(j); + if (STBI__RESTART(m)) m = stbi__get_marker(j); - } - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); + m = stbi__get_marker(j); + } else { + if (!stbi__process_marker(j, m)) return 1; + m = stbi__get_marker(j); + } + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) -typedef stbi_uc * (*resample_row_func)(stbi_uc * out, stbi_uc * in0, stbi_uc * in1, int w, int hs); +typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, + int w, int hs); -#define stbi__div4(x) ((stbi_uc)((x) >> 2)) +#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) -static stbi_uc * resample_row_1(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; } -static stbi_uc * stbi__resample_row_v_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i = 0; i < w; ++i) - out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); - return out; +static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i=0; i < w; ++i) + out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); + return out; } -static stbi_uc * stbi__resample_row_h_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate two samples horizontally for every one in input - int i; - stbi_uc * input = in_near; +static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } - out[0] = input[0]; - out[1] = stbi__div4(input[0] * 3 + input[1] + 2); - for (i = 1; i < w - 1; ++i) { - int n = 3 * input[i] + 2; - out[i * 2 + 0] = stbi__div4(n + input[i - 1]); - out[i * 2 + 1] = stbi__div4(n + input[i + 1]); - } - out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); - out[i * 2 + 1] = input[w - 1]; + out[0] = input[0]; + out[1] = stbi__div4(input[0]*3 + input[1] + 2); + for (i=1; i < w-1; ++i) { + int n = 3*input[i]+2; + out[i*2+0] = stbi__div4(n+input[i-1]); + out[i*2+1] = stbi__div4(n+input[i+1]); + } + out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); + out[i*2+1] = input[w-1]; - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); - return out; + return out; } -#define stbi__div16(x) ((stbi_uc)((x) >> 4)) +#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) -static stbi_uc * stbi__resample_row_hv_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate 2x2 samples for every one in input - int i, t0, t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); - return out; - } +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i,t0,t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); + return out; + } - t1 = 3 * in_near[0] + in_far[0]; - out[0] = stbi__div4(t1 + 2); - for (i = 1; i < w; ++i) { - t0 = t1; - t1 = 3 * in_near[i] + in_far[i]; - out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); - out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - } - out[w * 2 - 1] = stbi__div4(t1 + 2); + t1 = 3*in_near[0] + in_far[0]; + out[0] = stbi__div4(t1+2); + for (i=1; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } #if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate 2x2 samples for every one in input - int i = 0, t0, t1; - - if (w == 1) { - out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); - return out; - } - - t1 = 3 * in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w - 1) & ~7); i += 8) { +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i=0,t0,t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3*in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w-1) & ~7); i += 8) { #if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); - - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); - - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *)(out + i * 2), outv); + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *) (out + i*2), outv); #elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); - - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i * 2, o); -#endif - - // "previous" value for next iter - t1 = 3 * in_near[i + 7] + in_far[i + 7]; - } - - t0 = t1; - t1 = 3 * in_near[i] + in_far[i]; - out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3 * in_near[i] + in_far[i]; - out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); - out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - } - out[w * 2 - 1] = stbi__div4(t1 + 2); - - STBI_NOTUSED(hs); - - return out; -} -#endif - -static stbi_uc * stbi__resample_row_generic(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // resample with nearest-neighbor - int i, j; - STBI_NOTUSED(in_far); - for (i = 0; i < w; ++i) - for (j = 0; j < hs; ++j) - out[i * hs + j] = in_near[i]; - return out; + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i*2, o); +#endif + + // "previous" value for next iter + t1 = 3*in_near[i+7] + in_far[i+7]; + } + + t0 = t1; + t1 = 3*in_near[i] + in_far[i]; + out[i*2] = stbi__div16(3*t1 + t0 + 8); + + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); + + STBI_NOTUSED(hs); + + return out; +} +#endif + +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // resample with nearest-neighbor + int i,j; + STBI_NOTUSED(in_far); + for (i=0; i < w; ++i) + for (j=0; j < hs; ++j) + out[i*hs+j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int)((x)*4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, - int step) { - int i; - for (i = 0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1 << 19); // rounding - int r, g, b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr * stbi__float2fixed(1.40200f); - g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb * stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned)r > 255) { - if (r < 0) - r = 0; - else - r = 255; - } - if ((unsigned)g > 255) { - if (g < 0) - g = 0; - else - g = 255; - } - if ((unsigned)b > 255) { - if (b < 0) - b = 0; - else - b = 255; - } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) +{ + int i; + for (i=0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc const * pcb, stbi_uc const * pcr, int count, - int step) { - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) +{ + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); - __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); - __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); - __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); - __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); - __m128i xw = _mm_set1_epi16(255); // alpha channel - - for (; i + 7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); - - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); - - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); - - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); - - // store - _mm_storeu_si128((__m128i *)(out + 0), o0); - _mm_storeu_si128((__m128i *)(out + 16), o1); - out += 32; - } - } + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); + __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); + __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); + __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); + __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i+7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *) (out + 0), o0); + _mm_storeu_si128((__m128i *) (out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); - int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); - int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); - int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); - - for (; i + 7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); - - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); - - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); - - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8 * 4; - } - } -#endif - - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1 << 19); // rounding - int r, g, b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr * stbi__float2fixed(1.40200f); - g = y_fixed + cr * -stbi__float2fixed(0.71414f) + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb * stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned)r > 255) { - if (r < 0) - r = 0; - else - r = 255; - } - if ((unsigned)g > 255) { - if (g < 0) - g = 0; - else - g = 255; - } - if ((unsigned)b > 255) { - if (b < 0) - b = 0; - else - b = 255; - } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); + int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); + int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); + int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); + + for (; i+7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8*4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg * j) { - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) +{ + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg * j) { stbi__free_jpeg_components(j, j->s->img_n, 0); } - -typedef struct { - resample_row_func resample; - stbi_uc *line0, *line1; - int hs, vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +static void stbi__cleanup_jpeg(stbi__jpeg *j) +{ + stbi__free_jpeg_components(j, j->s->img_n, 0); +} + +typedef struct +{ + resample_row_func resample; + stbi_uc *line0,*line1; + int hs,vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { - unsigned int t = x * y + 128; - return (stbi_uc)((t + (t >> 8)) >> 8); +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) +{ + unsigned int t = x*y + 128; + return (stbi_uc) ((t + (t >>8)) >> 8); } -static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * comp, int req_comp) { - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe - - // validate req_comp - if (req_comp < 0 || req_comp > 4) - return stbi__errpuc("bad req_comp", "Internal error"); - - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { - stbi__cleanup_jpeg(z); - return NULL; - } - - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; - - // nothing to do if no components requested; check this now to avoid - // accessing uninitialized coutput[0] later - if (decode_n <= 0) { - stbi__cleanup_jpeg(z); - return NULL; - } - - // resample and color-convert - { - int k; - unsigned int i, j; - stbi_uc * output; - stbi_uc * coutput[4] = {NULL, NULL, NULL, NULL}; - - stbi__resample res_comp[4]; - - for (k = 0; k < decode_n; ++k) { - stbi__resample * r = &res_comp[k]; - - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { - stbi__cleanup_jpeg(z); - return stbi__errpuc("outofmem", "Out of memory"); - } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) - r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) - r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) - r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) - r->resample = z->resample_row_hv_2_kernel; - else - r->resample = stbi__resample_row_generic; - } - - // can't error after this so, this is safe - output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { - stbi__cleanup_jpeg(z); - return stbi__errpuc("outofmem", "Out of memory"); - } - - // now go ahead and resample - for (j = 0; j < z->s->img_y; ++j) { - stbi_uc * out = output + n * z->s->img_x * j; - for (k = 0; k < decode_n; ++k) { - stbi__resample * r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; - } +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) +{ + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // nothing to do if no components requested; check this now to avoid + // accessing uninitialized coutput[0] later + if (decode_n <= 0) { stbi__cleanup_jpeg(z); return NULL; } + + // resample and color-convert + { + int k; + unsigned int i,j; + stbi_uc *output; + stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; + + stbi__resample res_comp[4]; + + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs-1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; + else r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + // now go ahead and resample + for (j=0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = r->resample(z->img_comp[k].linebuf, + y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, + r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; } - if (n >= 3) { - stbi_uc * y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i = 0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i = 0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i = 0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i = 0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; + } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; + } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else + for (i=0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i=0; i < z->s->img_x; ++i) + *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i=0; i < z->s->img_x; ++i, out += 2) { + out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } } else { - if (is_rgb) { - if (n == 1) - for (i = 0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i = 0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i = 0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i = 0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc * y = coutput[0]; - if (n == 1) - for (i = 0; i < z->s->img_x; ++i) - out[i] = y[i]; - else - for (i = 0; i < z->s->img_x; ++i) { - *out++ = y[i]; - *out++ = 255; - } - } + stbi_uc *y = coutput[0]; + if (n == 1) + for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; + else + for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } } - } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) - *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } + } + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } } -static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - unsigned char * result; - stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); - if (!j) - return stbi__errpuc("outofmem", "Out of memory"); - memset(j, 0, sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x, y, comp, req_comp); - STBI_FREE(j); - return result; -} - -static int stbi__jpeg_test(stbi__context * s) { - int r; - stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); - if (!j) - return stbi__err("outofmem", "Out of memory"); - memset(j, 0, sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; -} - -static int stbi__jpeg_info_raw(stbi__jpeg * j, int * x, int * y, int * comp) { - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind(j->s); - return 0; - } - if (x) - *x = j->s->img_x; - if (y) - *y = j->s->img_y; - if (comp) - *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + unsigned char* result; + stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__errpuc("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x,y,comp,req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) +{ + int r; + stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) +{ + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind( j->s ); + return 0; + } + if (x) *x = j->s->img_x; + if (y) *y = j->s->img_y; + if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; } -static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { - int result; - stbi__jpeg * j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); - if (!j) - return stbi__err("outofmem", "Out of memory"); - memset(j, 0, sizeof(stbi__jpeg)); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) +{ + int result; + stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -4278,381 +4088,383 @@ static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) #define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct { - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[STBI__ZNSYMS]; - stbi__uint16 value[STBI__ZNSYMS]; +typedef struct +{ + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[STBI__ZNSYMS]; + stbi__uint16 value[STBI__ZNSYMS]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) { - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); - return n; -} - -stbi_inline static int stbi__bit_reverse(int v, int bits) { - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16 - bits); -} - -static int stbi__zbuild_huffman(stbi__zhuffman * z, const stbi_uc * sizelist, int num) { - int i, k = 0; - int code, next_code[16], sizes[17]; - - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i = 0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i = 1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i = 1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16)code; - z->firstsymbol[i] = (stbi__uint16)k; - code = (code + sizes[i]); - if (sizes[i]) - if (code - 1 >= (1 << i)) - return stbi__err("bad codelengths", "Corrupt PNG"); - z->maxcode[i] = code << (16 - i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i = 0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); - z->size[c] = (stbi_uc)s; - z->value[c] = (stbi__uint16)i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s], s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } - } - ++next_code[s]; - } - } - return 1; +stbi_inline static int stbi__bitreverse16(int n) +{ + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); + return n; } -// zlib-from-memory implementation for PNG reading -// because PNG allows splitting the zlib stream arbitrarily, -// and it's annoying structurally to have PNG call ZLIB call PNG, -// we require PNG read all the IDATs and combine them into a single -// memory buffer - -typedef struct { - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; - - char * zout; - char * zout_start; - char * zout_end; - int z_expandable; - - stbi__zhuffman z_length, z_distance; -} stbi__zbuf; - -stbi_inline static int stbi__zeof(stbi__zbuf * z) { return (z->zbuffer >= z->zbuffer_end); } - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf * z) { return stbi__zeof(z) ? 0 : *z->zbuffer++; } +stbi_inline static int stbi__bit_reverse(int v, int bits) +{ + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16-bits); +} -static void stbi__fill_bits(stbi__zbuf * z) { - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); -} - -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf * z, int n) { - unsigned int k; - if (z->num_bits < n) - stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; -} - -static int stbi__zhuffman_decode_slowpath(stbi__zbuf * a, stbi__zhuffman * z) { - int b, s, k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s = STBI__ZFAST_BITS + 1;; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) - return -1; // invalid code! - // code size is s, so: - b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= STBI__ZNSYMS) - return -1; // some data was corrupt somewhere! - if (z->size[b] != s) - return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf * a, stbi__zhuffman * z) { - int b, s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf * z, char * zout, int n) // need to make room for n bytes -{ - char * q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) - return stbi__err("output buffer limit", "Corrupt PNG"); - cur = (unsigned int)(z->zout - z->zout_start); - limit = old_limit = (unsigned)(z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned)n) - return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if (limit > UINT_MAX / 2) - return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) - return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; -} - -static const int stbi__zlength_base[31] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, - 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; - -static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, - 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0}; - -static const int stbi__zdist_base[32] = {1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, - 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, - 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; - -static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, - 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; - -static int stbi__parse_huffman_block(stbi__zbuf * a) { - char * zout = a->zout; - for (;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) - return stbi__err("bad huffman code", "Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) - return 0; - zout = a->zout; - } - *zout++ = (char)z; - } else { - stbi_uc * p; - int len, dist; - if (z == 256) { - a->zout = zout; - return 1; - } - if (z >= 286) - return stbi__err("bad huffman code", - "Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) - len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0 || z >= 30) - return stbi__err("bad huffman code", - "Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) - dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) - return stbi__err("bad dist", "Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) - return 0; - zout = a->zout; - } - p = (stbi_uc *)(zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { - do - *zout++ = v; - while (--len); - } - } else { - if (len) { - do - *zout++ = *p++; - while (--len); - } +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) +{ + int i,k=0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i=0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i=1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i=1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16) code; + z->firstsymbol[i] = (stbi__uint16) k; + code = (code + sizes[i]); + if (sizes[i]) + if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); + z->maxcode[i] = code << (16-i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i=0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); + z->size [c] = (stbi_uc ) s; + z->value[c] = (stbi__uint16) i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s],s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); } - } - } + } + ++next_code[s]; + } + } + return 1; +} + +// zlib-from-memory implementation for PNG reading +// because PNG allows splitting the zlib stream arbitrarily, +// and it's annoying structurally to have PNG call ZLIB call PNG, +// we require PNG read all the IDATs and combine them into a single +// memory buffer + +typedef struct +{ + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + int hit_zeof_once; + stbi__uint32 code_buffer; + + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; + + stbi__zhuffman z_length, z_distance; +} stbi__zbuf; + +stbi_inline static int stbi__zeof(stbi__zbuf *z) +{ + return (z->zbuffer >= z->zbuffer_end); } -static int stbi__compute_huffman_codes(stbi__zbuf * a) { - static const stbi_uc length_dezigzag[19] = {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op - stbi_uc codelength_sizes[19]; - int i, n; +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) +{ + return stbi__zeof(z) ? 0 : *z->zbuffer++; +} - int hlit = stbi__zreceive(a, 5) + 257; - int hdist = stbi__zreceive(a, 5) + 1; - int hclen = stbi__zreceive(a, 4) + 4; - int ntot = hlit + hdist; +static void stbi__fill_bits(stbi__zbuf *z) +{ + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i = 0; i < hclen; ++i) { - int s = stbi__zreceive(a, 3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) - return 0; +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) +{ + unsigned int k; + if (z->num_bits < n) stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) - return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc)c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a, 2) + 3; - if (n == 0) - return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n - 1]; - } else if (c == 17) { - c = stbi__zreceive(a, 3) + 3; - } else if (c == 18) { - c = stbi__zreceive(a, 7) + 11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s,k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s=STBI__ZFAST_BITS+1; ; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) return -1; // invalid code! + // code size is s, so: + b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= STBI__ZNSYMS) return -1; // some data was corrupt somewhere! + if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + if (!a->hit_zeof_once) { + // This is the first time we hit eof, insert 16 extra padding btis + // to allow us to keep going; if we actually consume any of them + // though, that is invalid data. This is caught later. + a->hit_zeof_once = 1; + a->num_bits += 16; // add 16 implicit zero bits + } else { + // We already inserted our extra 16 padding bits and are again + // out, this stream is actually prematurely terminated. + return -1; + } + } else { + stbi__fill_bits(a); + } + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); + cur = (unsigned int) (z->zout - z->zout_start); + limit = old_limit = (unsigned) (z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; +} + +static const int stbi__zlength_base[31] = { + 3,4,5,6,7,8,9,10,11,13, + 15,17,19,23,27,31,35,43,51,59, + 67,83,99,115,131,163,195,227,258,0,0 }; + +static const int stbi__zlength_extra[31]= +{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; + +static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, +257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; + +static const int stbi__zdist_extra[32] = +{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) +{ + char *zout = a->zout; + for(;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) return 0; + zout = a->zout; + } + *zout++ = (char) z; + } else { + stbi_uc *p; + int len,dist; + if (z == 256) { + a->zout = zout; + if (a->hit_zeof_once && a->num_bits < 16) { + // The first time we hit zeof, we inserted 16 extra zero bits into our bit + // buffer so the decoder can just do its speculative decoding. But if we + // actually consumed any of those bits (which is the case when num_bits < 16), + // the stream actually read past the end so it is malformed. + return stbi__err("unexpected end","Corrupt PNG"); } - if (ntot - n < c) - return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes + n, fill, c); - n += c; - } - } - if (n != ntot) - return stbi__err("bad codelengths", "Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) - return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) - return 0; - return 1; -} - -static int stbi__parse_uncompressed_block(stbi__zbuf * a) { - stbi_uc header[4]; - int len, nlen, k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) - return stbi__err("zlib corrupt", "Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) - return stbi__err("zlib corrupt", "Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) - return stbi__err("read past buffer", "Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) - return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; -} - -static int stbi__parse_zlib_header(stbi__zbuf * a) { - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) - return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec - if ((cmf * 256 + flg) % 31 != 0) - return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec - if (flg & 32) - return stbi__err("no preset dict", "Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) - return stbi__err("bad compression", "Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; -} - -static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = { - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; -static const stbi_uc stbi__zdefault_distance[32] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; + return 1; + } + if (z >= 286) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0 || z >= 30) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); + if (len > a->zout_end - zout) { + if (!stbi__zexpand(a, zout, len)) return 0; + zout = a->zout; + } + p = (stbi_uc *) (zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { do *zout++ = v; while (--len); } + } else { + if (len) { do *zout++ = *p++; while (--len); } + } + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) +{ + static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286+32+137];//padding for maximum single op + stbi_uc codelength_sizes[19]; + int i,n; + + int hlit = stbi__zreceive(a,5) + 257; + int hdist = stbi__zreceive(a,5) + 1; + int hclen = stbi__zreceive(a,4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i=0; i < hclen; ++i) { + int s = stbi__zreceive(a,3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc) c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a,2)+3; + if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n-1]; + } else if (c == 17) { + c = stbi__zreceive(a,3)+3; + } else if (c == 18) { + c = stbi__zreceive(a,7)+11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); + } + if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes+n, fill, c); + n += c; + } + } + if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) +{ + stbi_uc header[4]; + int len,nlen,k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) +{ + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = +{ + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 +}; +static const stbi_uc stbi__zdefault_distance[32] = +{ + 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 +}; /* Init algorithm: { @@ -4666,122 +4478,118 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf * a, int parse_header) { - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) - return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a, 1); - type = stbi__zreceive(a, 2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) - return 0; - } else if (type == 3) { - return 0; - } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, STBI__ZNSYMS)) - return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) - return 0; - } else { - if (!stbi__compute_huffman_codes(a)) - return 0; - } - if (!stbi__parse_huffman_block(a)) - return 0; - } - } while (!final); - return 1; -} - -static int stbi__do_zlib(stbi__zbuf * a, char * obuf, int olen, int exp, int parse_header) { - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; - - return stbi__parse_zlib(a, parse_header); -} - -STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen) { - stbi__zbuf a; - char * p = (char *)stbi__malloc(initial_size); - if (p == NULL) - return NULL; - a.zbuffer = (stbi_uc *)buffer; - a.zbuffer_end = (stbi_uc *)buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) - *outlen = (int)(a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) +{ + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) return 0; + a->num_bits = 0; + a->code_buffer = 0; + a->hit_zeof_once = 0; + do { + final = stbi__zreceive(a,1); + type = stbi__zreceive(a,2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , STBI__ZNSYMS)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; + } else { + if (!stbi__compute_huffman_codes(a)) return 0; + } + if (!stbi__parse_huffman_block(a)) return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) +{ + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; -STBIDEF char * stbi_zlib_decode_malloc(char const * buffer, int len, int * outlen) { - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); + return stbi__parse_zlib(a, parse_header); } -STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, - int parse_header) { - stbi__zbuf a; - char * p = (char *)stbi__malloc(initial_size); - if (p == NULL) - return NULL; - a.zbuffer = (stbi_uc *)buffer; - a.zbuffer_end = (stbi_uc *)buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) - *outlen = (int)(a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) +{ + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); } -STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, char const * ibuffer, int ilen) { - stbi__zbuf a; - a.zbuffer = (stbi_uc *)ibuffer; - a.zbuffer_end = (stbi_uc *)ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int)(a.zout - a.zout_start); - else - return -1; -} - -STBIDEF char * stbi_zlib_decode_noheader_malloc(char const * buffer, int len, int * outlen) { - stbi__zbuf a; - char * p = (char *)stbi__malloc(16384); - if (p == NULL) - return NULL; - a.zbuffer = (stbi_uc *)buffer; - a.zbuffer_end = (stbi_uc *)buffer + len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) - *outlen = (int)(a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int) (a.zout - a.zout_start); + else + return -1; } -STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen) { - stbi__zbuf a; - a.zbuffer = (stbi_uc *)ibuffer; - a.zbuffer_end = (stbi_uc *)ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int)(a.zout - a.zout_start); - else - return -1; +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(16384); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer+len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int) (a.zout - a.zout_start); + else + return -1; } #endif @@ -4796,1303 +4604,1131 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct { - stbi__uint32 length; - stbi__uint32 type; +typedef struct +{ + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context * s) { - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) +{ + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context * s) { - static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; - int i; - for (i = 0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) - return stbi__err("bad png sig", "Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) +{ + static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; + int i; + for (i=0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); + return 1; } -typedef struct { - stbi__context * s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct +{ + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; + enum { - STBI__F_none = 0, - STBI__F_sub = 1, - STBI__F_up = 2, - STBI__F_avg = 3, - STBI__F_paeth = 4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none=0, + STBI__F_sub=1, + STBI__F_up=2, + STBI__F_avg=3, + STBI__F_paeth=4, + // synthetic filter used for first scanline to avoid needing a dummy row of 0s + STBI__F_avg_first }; -static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, STBI__F_avg_first, STBI__F_paeth_first}; +static stbi_uc first_row_filter[5] = +{ + STBI__F_none, + STBI__F_sub, + STBI__F_none, + STBI__F_avg_first, + STBI__F_sub // Paeth with b=c=0 turns out to be equivalent to sub +}; -static int stbi__paeth(int a, int b, int c) { - int p = a + b - c; - int pa = abs(p - a); - int pb = abs(p - b); - int pc = abs(p - c); - if (pa <= pb && pa <= pc) - return a; - if (pb <= pc) - return b; - return c; +static int stbi__paeth(int a, int b, int c) +{ + // This formulation looks very different from the reference in the PNG spec, but is + // actually equivalent and has favorable data dependencies and admits straightforward + // generation of branch-free code, which helps performance significantly. + int thresh = c*3 - (a + b); + int lo = a < b ? a : b; + int hi = a < b ? b : a; + int t0 = (hi <= thresh) ? lo : c; + int t1 = (thresh <= lo) ? hi : t0; + return t1; +} + +static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; + +// adds an extra all-255 alpha channel +// dest == src is legal +// img_n must be 1 or 3 +static void stbi__create_png_alpha_expand8(stbi_uc *dest, stbi_uc *src, stbi__uint32 x, int img_n) +{ + int i; + // must process data backwards since we allow dest==src + if (img_n == 1) { + for (i=x-1; i >= 0; --i) { + dest[i*2+1] = 255; + dest[i*2+0] = src[i]; + } + } else { + STBI_ASSERT(img_n == 3); + for (i=x-1; i >= 0; --i) { + dest[i*4+3] = 255; + dest[i*4+2] = src[i*3+2]; + dest[i*4+1] = src[i*3+1]; + dest[i*4+0] = src[i*3+0]; + } + } } -static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, 0, 0, 0, 0x01}; - // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, - stbi__uint32 y, int depth, int color) { - int bytes = (depth == 16 ? 2 : 1); - stbi__context * s = a->s; - stbi__uint32 i, j, stride = x * out_n * bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later - - int output_bytes = out_n * bytes; - int filter_bytes = img_n * bytes; - int width = x; - - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); - a->out = (stbi_uc *)stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) - return stbi__err("outofmem", "Out of memory"); - - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) - return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; - - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) - return stbi__err("not enough pixels", "Corrupt PNG"); - - for (j = 0; j < y; ++j) { - stbi_uc * cur = a->out + stride * j; - stbi_uc * prior; - int filter = *raw++; - - if (filter > 4) - return stbi__err("invalid filter", "Corrupt PNG"); - - if (depth < 8) { - if (img_width_bytes > x) - return stbi__err("invalid width", "Corrupt PNG"); - cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - - // if first row, use special filter that doesn't sample previous row - if (j == 0) - filter = first_row_filter[filter]; - - // handle first byte explicitly - for (k = 0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none: - cur[k] = raw[k]; - break; - case STBI__F_sub: - cur[k] = raw[k]; - break; - case STBI__F_up: - cur[k] = STBI__BYTECAST(raw[k] + prior[k]); - break; - case STBI__F_avg: - cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); - break; - case STBI__F_paeth: - cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); - break; - case STBI__F_avg_first: - cur[k] = raw[k]; - break; - case STBI__F_paeth_first: - cur[k] = raw[k]; - break; - } - } - - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes + 1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; - } - - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1) * filter_bytes; -#define STBI__CASE(f) \ - case f: \ - for (k = 0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: - memcpy(cur, raw, nk); - break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); } - break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } - break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } - break; - STBI__CASE(STBI__F_paeth) { - cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); - } - break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); } - break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); } - break; - } -#undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n + 1 == out_n); -#define STBI__CASE(f) \ - case f: \ - for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, cur += output_bytes, prior += output_bytes) \ - for (k = 0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } - break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); } - break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } - break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } - break; - STBI__CASE(STBI__F_paeth) { - cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); - } - break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); } - break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); } - break; - } -#undef STBI__CASE - - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride * j; // start at the beginning of the row again - for (i = 0; i < x; ++i, cur += output_bytes) { - cur[filter_bytes + 1] = 255; - } - } - } - } - - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j = 0; j < y; ++j) { - stbi_uc * cur = a->out + stride * j; - stbi_uc * in = a->out + stride * j + x * out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for - // 1/2/4-bit png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that - // will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k = x * img_n; k >= 2; k -= 2, ++in) { - *cur++ = scale * ((*in >> 4)); - *cur++ = scale * ((*in) & 0x0f); - } - if (k > 0) - *cur++ = scale * ((*in >> 4)); - } else if (depth == 2) { - for (k = x * img_n; k >= 4; k -= 4, ++in) { - *cur++ = scale * ((*in >> 6)); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in) & 0x03); - } - if (k > 0) - *cur++ = scale * ((*in >> 6)); - if (k > 1) - *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) - *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k = x * img_n; k >= 8; k -= 8, ++in) { - *cur++ = scale * ((*in >> 7)); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in) & 0x01); - } - if (k > 0) - *cur++ = scale * ((*in >> 7)); - if (k > 1) - *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) - *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) - *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) - *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) - *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) - *cur++ = scale * ((*in >> 1) & 0x01); +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) +{ + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i,j,stride = x*out_n*bytes; + stbi__uint32 img_len, img_width_bytes; + stbi_uc *filter_buf; + int all_ok = 1; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n*bytes; + int filter_bytes = img_n*bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); + a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) return stbi__err("outofmem", "Out of memory"); + + // note: error exits here don't need to clean up a->out individually, + // stbi__do_png always does on error. + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + if (!stbi__mad2sizes_valid(img_width_bytes, y, img_width_bytes)) return stbi__err("too large", "Corrupt PNG"); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, + // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), + // so just check for raw_len < img_len always. + if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); + + // Allocate two scan lines worth of filter workspace buffer. + filter_buf = (stbi_uc *) stbi__malloc_mad2(img_width_bytes, 2, 0); + if (!filter_buf) return stbi__err("outofmem", "Out of memory"); + + // Filtering for low-bit-depth images + if (depth < 8) { + filter_bytes = 1; + width = img_width_bytes; + } + + for (j=0; j < y; ++j) { + // cur/prior filter buffers alternate + stbi_uc *cur = filter_buf + (j & 1)*img_width_bytes; + stbi_uc *prior = filter_buf + (~j & 1)*img_width_bytes; + stbi_uc *dest = a->out + stride*j; + int nk = width * filter_bytes; + int filter = *raw++; + + // check filter type + if (filter > 4) { + all_ok = stbi__err("invalid filter","Corrupt PNG"); + break; + } + + // if first row, use special filter that doesn't sample previous row + if (j == 0) filter = first_row_filter[filter]; + + // perform actual filtering + switch (filter) { + case STBI__F_none: + memcpy(cur, raw, nk); + break; + case STBI__F_sub: + memcpy(cur, raw, filter_bytes); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); + break; + case STBI__F_up: + for (k = 0; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + for (k = 0; k < filter_bytes; ++k) + cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); + break; + case STBI__F_paeth: + for (k = 0; k < filter_bytes; ++k) + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); // prior[k] == stbi__paeth(0,prior[k],0) + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes], prior[k], prior[k-filter_bytes])); + break; + case STBI__F_avg_first: + memcpy(cur, raw, filter_bytes); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); + break; + } + + raw += nk; + + // expand decoded bits in cur to dest, also adding an extra alpha channel if desired + if (depth < 8) { + stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range + stbi_uc *in = cur; + stbi_uc *out = dest; + stbi_uc inb = 0; + stbi__uint32 nsmp = x*img_n; + + // expand bits to bytes first + if (depth == 4) { + for (i=0; i < nsmp; ++i) { + if ((i & 1) == 0) inb = *in++; + *out++ = scale * (inb >> 4); + inb <<= 4; } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride * j; - if (img_n == 1) { - for (q = x - 1; q >= 0; --q) { - cur[q * 2 + 1] = 255; - cur[q * 2 + 0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q = x - 1; q >= 0; --q) { - cur[q * 4 + 3] = 255; - cur[q * 4 + 2] = cur[q * 3 + 2]; - cur[q * 4 + 1] = cur[q * 3 + 1]; - cur[q * 4 + 0] = cur[q * 3 + 0]; - } - } + } else if (depth == 2) { + for (i=0; i < nsmp; ++i) { + if ((i & 3) == 0) inb = *in++; + *out++ = scale * (inb >> 6); + inb <<= 2; } - } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc * cur = a->out; - stbi__uint16 * cur16 = (stbi__uint16 *)cur; - - for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { - *cur16 = (cur[0] << 8) | cur[1]; - } - } - - return 1; -} - -static int stbi__create_png_image(stbi__png * a, stbi_uc * image_data, stbi__uint32 image_data_len, int out_n, int depth, - int color, int interlaced) { - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc * final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - - // de-interlacing - final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - if (!final) - return stbi__err("outofmem", "Out of memory"); - for (p = 0; p < 7; ++p) { - int xorig[] = {0, 4, 0, 2, 0, 1, 0}; - int yorig[] = {0, 0, 4, 0, 2, 0, 1}; - int xspc[] = {8, 8, 4, 4, 2, 2, 1}; - int yspc[] = {8, 8, 8, 4, 4, 2, 2}; - int i, j, x, y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; + } else { + STBI_ASSERT(depth == 1); + for (i=0; i < nsmp; ++i) { + if ((i & 7) == 0) inb = *in++; + *out++ = scale * (inb >> 7); + inb <<= 1; } - for (j = 0; j < y; ++j) { - for (i = 0; i < x; ++i) { - int out_y = j * yspc[p] + yorig[p]; - int out_x = i * xspc[p] + xorig[p]; - memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, a->out + (j * x + i) * out_bytes, - out_bytes); - } + } + + // insert alpha=255 values if desired + if (img_n != out_n) + stbi__create_png_alpha_expand8(dest, dest, x, img_n); + } else if (depth == 8) { + if (img_n == out_n) + memcpy(dest, cur, x*img_n); + else + stbi__create_png_alpha_expand8(dest, cur, x, img_n); + } else if (depth == 16) { + // convert the image data from big-endian to platform-native + stbi__uint16 *dest16 = (stbi__uint16*)dest; + stbi__uint32 nsmp = x*img_n; + + if (img_n == out_n) { + for (i = 0; i < nsmp; ++i, ++dest16, cur += 2) + *dest16 = (cur[0] << 8) | cur[1]; + } else { + STBI_ASSERT(img_n+1 == out_n); + if (img_n == 1) { + for (i = 0; i < x; ++i, dest16 += 2, cur += 2) { + dest16[0] = (cur[0] << 8) | cur[1]; + dest16[1] = 0xffff; + } + } else { + STBI_ASSERT(img_n == 3); + for (i = 0; i < x; ++i, dest16 += 4, cur += 6) { + dest16[0] = (cur[0] << 8) | cur[1]; + dest16[1] = (cur[2] << 8) | cur[3]; + dest16[2] = (cur[4] << 8) | cur[5]; + dest16[3] = 0xffff; + } } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; + } + } + } - return 1; -} - -static int stbi__compute_transparency(stbi__png * z, stbi_uc tc[3], int out_n) { - stbi__context * s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc * p = z->out; + STBI_FREE(filter_buf); + if (!all_ok) return 0; - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); - - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + return 1; } -static int stbi__compute_transparency16(stbi__png * z, stbi__uint16 tc[3], int out_n) { - stbi__context * s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 * p = (stbi__uint16 *)z->out; - - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) +{ + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + if (!final) return stbi__err("outofmem", "Out of memory"); + for (p=0; p < 7; ++p) { + int xorig[] = { 0,4,0,2,0,1,0 }; + int yorig[] = { 0,0,4,0,2,0,1 }; + int xspc[] = { 8,8,4,4,2,2,1 }; + int yspc[] = { 8,8,8,4,4,2,2 }; + int i,j,x,y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j=0; j < y; ++j) { + for (i=0; i < x; ++i) { + int out_y = j*yspc[p]+yorig[p]; + int out_x = i*xspc[p]+xorig[p]; + memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, + a->out + (j*x+i)*out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + return 1; } -static int stbi__expand_png_palette(stbi__png * a, stbi_uc * palette, int len, int pal_img_n) { - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; - - p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) - return stbi__err("outofmem", "Out of memory"); - - // between here and free(out) below, exitting would leak - temp_out = p; - - if (pal_img_n == 3) { - for (i = 0; i < pixel_count; ++i) { - int n = orig[i] * 4; - p[0] = palette[n]; - p[1] = palette[n + 1]; - p[2] = palette[n + 2]; - p += 3; - } - } else { - for (i = 0; i < pixel_count; ++i) { - int n = orig[i] * 4; - p[0] = palette[n]; - p[1] = palette[n + 1]; - p[2] = palette[n + 2]; - p[3] = palette[n + 3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; - - STBI_NOTUSED(len); - - return 1; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i=0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i=0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16*) z->out; + + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) +{ + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) return stbi__err("outofmem", "Out of memory"); + + // between here and free(out) below, exitting would leak + temp_out = p; + + if (pal_img_n == 3) { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p += 3; + } + } else { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p[3] = palette[n+3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; } static int stbi__unpremultiply_on_load_global = 0; static int stbi__de_iphone_flag_global = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { - stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { - stbi__de_iphone_flag_global = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_global = flag_true_if_should_convert; } #ifndef STBI_THREAD_LOCAL -#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global -#define stbi__de_iphone_flag stbi__de_iphone_flag_global +#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global +#define stbi__de_iphone_flag stbi__de_iphone_flag_global #else static STBI_THREAD_LOCAL int stbi__unpremultiply_on_load_local, stbi__unpremultiply_on_load_set; static STBI_THREAD_LOCAL int stbi__de_iphone_flag_local, stbi__de_iphone_flag_set; -STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) { - stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; - stbi__unpremultiply_on_load_set = 1; +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; + stbi__unpremultiply_on_load_set = 1; } -STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) { - stbi__de_iphone_flag_local = flag_true_if_should_convert; - stbi__de_iphone_flag_set = 1; +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_local = flag_true_if_should_convert; + stbi__de_iphone_flag_set = 1; } -#define stbi__unpremultiply_on_load \ - (stbi__unpremultiply_on_load_set ? stbi__unpremultiply_on_load_local : stbi__unpremultiply_on_load_global) -#define stbi__de_iphone_flag (stbi__de_iphone_flag_set ? stbi__de_iphone_flag_local : stbi__de_iphone_flag_global) +#define stbi__unpremultiply_on_load (stbi__unpremultiply_on_load_set \ + ? stbi__unpremultiply_on_load_local \ + : stbi__unpremultiply_on_load_global) +#define stbi__de_iphone_flag (stbi__de_iphone_flag_set \ + ? stbi__de_iphone_flag_local \ + : stbi__de_iphone_flag_global) #endif // STBI_THREAD_LOCAL -static void stbi__de_iphone(stbi__png * z) { - stbi__context * s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc * p = z->out; - - if (s->img_out_n == 3) { // convert bgr to rgb - for (i = 0; i < pixel_count; ++i) { +static void stbi__de_iphone(stbi__png *z) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i=0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = ( t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; + } + } else { + // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { stbi_uc t = p[0]; p[0] = p[2]; p[2] = t; - p += 3; - } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i = 0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = (t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } - } else { - // convert bgr to rgb - for (i = 0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } - } - } + p += 4; + } + } + } } -#define STBI__PNG_TYPE(a, b, c, d) (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + (unsigned)(d)) +#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) -static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { - stbi_uc palette[1024], pal_img_n = 0; - stbi_uc has_trans = 0, tc[3] = {0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; - int first = 1, k, interlace = 0, color = 0, is_iphone = 0; - stbi__context * s = z->s; +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) +{ + stbi_uc palette[1024], pal_img_n=0; + stbi_uc has_trans=0, tc[3]={0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; + int first=1,k,interlace=0, color=0, is_iphone=0; + stbi__context *s = z->s; - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; - if (!stbi__check_png_header(s)) - return 0; + if (!stbi__check_png_header(s)) return 0; - if (scan == STBI__SCAN_type) - return 1; + if (scan == STBI__SCAN_type) return 1; - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C','g','B','I'): is_iphone = 1; stbi__skip(s, c.length); break; - case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { - int comp, filter; - if (!first) - return stbi__err("multiple IHDR", "Corrupt PNG"); + case STBI__PNG_TYPE('I','H','D','R'): { + int comp,filter; + if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); first = 0; - if (c.length != 13) - return stbi__err("bad IHDR len", "Corrupt PNG"); + if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); s->img_x = stbi__get32be(s); s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - z->depth = stbi__get8(s); - if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) - return stbi__err("1/2/4/8/16-bit only", "PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); - if (color > 6) - return stbi__err("bad ctype", "Corrupt PNG"); - if (color == 3 && z->depth == 16) - return stbi__err("bad ctype", "Corrupt PNG"); - if (color == 3) - pal_img_n = 3; - else if (color & 1) - return stbi__err("bad ctype", "Corrupt PNG"); - comp = stbi__get8(s); - if (comp) - return stbi__err("bad comp method", "Corrupt PNG"); - filter = stbi__get8(s); - if (filter) - return stbi__err("bad filter method", "Corrupt PNG"); - interlace = stbi__get8(s); - if (interlace > 1) - return stbi__err("bad interlace method", "Corrupt PNG"); - if (!s->img_x || !s->img_y) - return stbi__err("0-pixel image", "Corrupt PNG"); + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); + comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); + filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); + interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); + if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) - return stbi__err("too large", "Image too large to decode"); + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) - return stbi__err("too large", "Corrupt PNG"); + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); } // even with SCAN_header, have to scan to see if we have a tRNS break; - } + } - case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256 * 3) - return stbi__err("invalid PLTE", "Corrupt PNG"); + case STBI__PNG_TYPE('P','L','T','E'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); pal_len = c.length / 3; - if (pal_len * 3 != c.length) - return stbi__err("invalid PLTE", "Corrupt PNG"); - for (i = 0; i < pal_len; ++i) { - palette[i * 4 + 0] = stbi__get8(s); - palette[i * 4 + 1] = stbi__get8(s); - palette[i * 4 + 2] = stbi__get8(s); - palette[i * 4 + 3] = 255; + if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); + for (i=0; i < pal_len; ++i) { + palette[i*4+0] = stbi__get8(s); + palette[i*4+1] = stbi__get8(s); + palette[i*4+2] = stbi__get8(s); + palette[i*4+3] = 255; } break; - } + } - case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) - return stbi__err("tRNS after IDAT", "Corrupt PNG"); + case STBI__PNG_TYPE('t','R','N','S'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); if (pal_img_n) { - if (scan == STBI__SCAN_header) { - s->img_n = 4; - return 1; - } - if (pal_len == 0) - return stbi__err("tRNS before PLTE", "Corrupt PNG"); - if (c.length > pal_len) - return stbi__err("bad tRNS len", "Corrupt PNG"); - pal_img_n = 4; - for (i = 0; i < c.length; ++i) - palette[i * 4 + 3] = stbi__get8(s); + if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } + if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); + if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); + pal_img_n = 4; + for (i=0; i < c.length; ++i) + palette[i*4+3] = stbi__get8(s); } else { - if (!(s->img_n & 1)) - return stbi__err("tRNS with alpha", "Corrupt PNG"); - if (c.length != (stbi__uint32)s->img_n * 2) - return stbi__err("bad tRNS len", "Corrupt PNG"); - has_trans = 1; - // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. - if (scan == STBI__SCAN_header) { - ++s->img_n; - return 1; - } - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) - tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) - tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * - stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } + if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); + if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); + has_trans = 1; + // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. + if (scan == STBI__SCAN_header) { ++s->img_n; return 1; } + if (z->depth == 16) { + for (k = 0; k < s->img_n && k < 3; ++k) // extra loop test to suppress false GCC warning + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n && k < 3; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger + } } break; - } + } - case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) - return stbi__err("no PLTE", "Corrupt PNG"); + case STBI__PNG_TYPE('I','D','A','T'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); if (scan == STBI__SCAN_header) { - // header scan definitely stops at first IDAT - if (pal_img_n) - s->img_n = pal_img_n; - return 1; + // header scan definitely stops at first IDAT + if (pal_img_n) + s->img_n = pal_img_n; + return 1; } - if (c.length > (1u << 30)) - return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); - if ((int)(ioff + c.length) < (int)ioff) - return 0; + if (c.length > (1u << 30)) return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); + if ((int)(ioff + c.length) < (int)ioff) return 0; if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc * p; - if (idata_limit == 0) - idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); - if (p == NULL) - return stbi__err("outofmem", "Out of memory"); - z->idata = p; + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); + z->idata = p; } - if (!stbi__getn(s, z->idata + ioff, c.length)) - return stbi__err("outofdata", "Corrupt PNG"); + if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); ioff += c.length; break; - } + } - case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + case STBI__PNG_TYPE('I','E','N','D'): { stbi__uint32 raw_len, bpl; - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) - return 1; - if (z->idata == NULL) - return stbi__err("no IDAT", "Corrupt PNG"); + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) return 1; + if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); // initial guess for decoded data size to avoid unnecessary reallocs bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag((char *)z->idata, ioff, raw_len, - (int *)&raw_len, !is_iphone); - if (z->expanded == NULL) - return 0; // zlib should set error - STBI_FREE(z->idata); - z->idata = NULL; - if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n + 1; + z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); + if (z->expanded == NULL) return 0; // zlib should set error + STBI_FREE(z->idata); z->idata = NULL; + if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) + s->img_out_n = s->img_n+1; else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) - return 0; + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) - return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) - return 0; - } + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; + } } if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); + stbi__de_iphone(z); if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) - s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; } - STBI_FREE(z->expanded); - z->expanded = NULL; + STBI_FREE(z->expanded); z->expanded = NULL; // end of PNG chunk, read and skip CRC stbi__get32be(s); return 1; - } + } - default: + default: // if critical, fail - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); if ((c.type & (1 << 29)) == 0) { -#ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); -#endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); + #ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); + #endif + return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); } stbi__skip(s, c.length); break; - } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } } -static void * stbi__do_png(stbi__png * p, int * x, int * y, int * n, int req_comp, stbi__result_info * ri) { - void * result = NULL; - if (req_comp < 0 || req_comp > 4) - return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) - return result; - } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) - *n = p->s->img_n; - } - STBI_FREE(p->out); - p->out = NULL; - STBI_FREE(p->expanded); - p->expanded = NULL; - STBI_FREE(p->idata); - p->idata = NULL; +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) +{ + void *result=NULL; + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) *n = p->s->img_n; + } + STBI_FREE(p->out); p->out = NULL; + STBI_FREE(p->expanded); p->expanded = NULL; + STBI_FREE(p->idata); p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi__png p; + p.s = s; + return stbi__do_png(&p, x,y,comp,req_comp, ri); +} - return result; +static int stbi__png_test(stbi__context *s) +{ + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; } -static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi__png p; - p.s = s; - return stbi__do_png(&p, x, y, comp, req_comp, ri); +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) +{ + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind( p->s ); + return 0; + } + if (x) *x = p->s->img_x; + if (y) *y = p->s->img_y; + if (comp) *comp = p->s->img_n; + return 1; } -static int stbi__png_test(stbi__context * s) { - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); } -static int stbi__png_info_raw(stbi__png * p, int * x, int * y, int * comp) { - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind(p->s); - return 0; - } - if (x) - *x = p->s->img_x; - if (y) - *y = p->s->img_y; - if (comp) - *comp = p->s->img_n; - return 1; -} - -static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp) { - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); -} - -static int stbi__png_is16(stbi__context * s) { - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; +static int stbi__png_is16(stbi__context *s) +{ + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context * s) { - int r; - int sz; - if (stbi__get8(s) != 'B') - return 0; - if (stbi__get8(s) != 'M') - return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; -} - -static int stbi__bmp_test(stbi__context * s) { - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test_raw(stbi__context *s) +{ + int r; + int sz; + if (stbi__get8(s) != 'B') return 0; + if (stbi__get8(s) != 'M') return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) +{ + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } + // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) { - int n = 0; - if (z == 0) - return -1; - if (z >= 0x10000) { - n += 16; - z >>= 16; - } - if (z >= 0x00100) { - n += 8; - z >>= 8; - } - if (z >= 0x00010) { - n += 4; - z >>= 4; - } - if (z >= 0x00004) { - n += 2; - z >>= 2; - } - if (z >= 0x00002) { - n += 1; /* >>= 1;*/ - } - return n; +static int stbi__high_bit(unsigned int z) +{ + int n=0; + if (z == 0) return -1; + if (z >= 0x10000) { n += 16; z >>= 16; } + if (z >= 0x00100) { n += 8; z >>= 8; } + if (z >= 0x00010) { n += 4; z >>= 4; } + if (z >= 0x00004) { n += 2; z >>= 2; } + if (z >= 0x00002) { n += 1;/* >>= 1;*/ } + return n; } -static int stbi__bitcount(unsigned int a) { - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) +{ + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) { - static unsigned int mul_table[9] = { - 0, - 0xff /*0b11111111*/, - 0x55 /*0b01010101*/, - 0x49 /*0b01001001*/, - 0x11 /*0b00010001*/, - 0x21 /*0b00100001*/, - 0x41 /*0b01000001*/, - 0x81 /*0b10000001*/, - 0x01 /*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0, 0, 1, 0, 2, 4, 6, 0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8 - bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; -} - -typedef struct { - int bpp, offset, hsz; - unsigned int mr, mg, mb, ma, all_a; - int extra_read; +static int stbi__shiftsigned(unsigned int v, int shift, int bits) +{ + static unsigned int mul_table[9] = { + 0, + 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, + 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0,0,1,0,2,4,6,0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8-bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct +{ + int bpp, offset, hsz; + unsigned int mr,mg,mb,ma, all_a; + int extra_read; } stbi__bmp_data; -static int stbi__bmp_set_mask_defaults(stbi__bmp_data * info, int compress) { - // BI_BITFIELDS specifies masks explicitly, don't override - if (compress == 3) - return 1; - - if (compress == 0) { - if (info->bpp == 16) { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } else if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - // otherwise, use defaults, which is all-0 - info->mr = info->mg = info->mb = info->ma = 0; - } - return 1; - } - return 0; // error -} - -static void * stbi__bmp_parse_header(stbi__context * s, stbi__bmp_data * info) { - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') - return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; - - if (info->offset < 0) - return stbi__errpuc("bad BMP", "bad BMP"); - - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) - return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) - return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) - return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - if (compress >= 4) - return stbi__errpuc("BMP JPEG/PNG", - "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes - if (compress == 3 && info->bpp != 16 && info->bpp != 32) - return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - stbi__bmp_set_mask_defaults(info, compress); - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - // V4/V5 header - int i; - if (hsz != 108 && hsz != 124) - return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs - stbi__bmp_set_mask_defaults(info, compress); - stbi__get32le(s); // discard color space - for (i = 0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } - } - } - return (void *)1; -} - -static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * out; - unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; - stbi_uc pal[256][4]; - int psize = 0, i, j, width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set - - flip_vertically = ((int)s->img_y) > 0; - s->img_y = abs((int)s->img_y); - - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; - - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - // accept some number of extra bytes after the header, but if the offset points either to before - // the header ends or implies a large amount of extra data, reject the file as malformed - int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); - int header_limit = 1024; // max we actually read is below 256 bytes currently. - int extra_data_limit = 256 * 4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. - if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { - return stbi__errpuc("bad header", "Corrupt BMP"); - } - // we established that bytes_read_so_far is positive and sensible. - // the first half of this test rejects offsets that are either too small positives, or - // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn - // ensures the number computed in the second half of the test can't overflow. - if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { - return stbi__errpuc("bad offset", "Corrupt BMP"); - } else { - stbi__skip(s, info.offset - bytes_read_so_far); - } - } - - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert - - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); - - out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) - return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z = 0; - if (psize == 0 || psize > 256) { - STBI_FREE(out); - return stbi__errpuc("invalid", "Corrupt BMP"); - } - for (i = 0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) - stbi__get8(s); - pal[i][3] = 255; - } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) - width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) - width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) - width = s->img_x; - else { - STBI_FREE(out); - return stbi__errpuc("bad bpp", "Corrupt BMP"); - } - pad = (-width) & 3; - if (info.bpp == 1) { - for (j = 0; j < (int)s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i = 0; i < (int)s->img_x; ++i) { - int color = (v >> bit_offset) & 0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) - out[z++] = 255; - if (i + 1 == (int)s->img_x) - break; - if ((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j = 0; j < (int)s->img_y; ++j) { - for (i = 0; i < (int)s->img_x; i += 2) { - int v = stbi__get8(s), v2 = 0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) - out[z++] = 255; - if (i + 1 == (int)s->img_x) - break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) - out[z++] = 255; - } - stbi__skip(s, pad); - } - } - } else { - int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, bcount = 0, acount = 0; - int z = 0; - int easy = 0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) - width = 3 * s->img_x; - else if (info.bpp == 16) - width = 2 * s->img_x; - else /* bpp = 32 and pad = 0 */ - width = 0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; - } - if (!easy) { - if (!mr || !mg || !mb) { - STBI_FREE(out); - return stbi__errpuc("bad masks", "Corrupt BMP"); - } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr) - 7; - rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg) - 7; - gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb) - 7; - bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma) - 7; - acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { - STBI_FREE(out); - return stbi__errpuc("bad masks", "Corrupt BMP"); +static int stbi__bmp_set_mask_defaults(stbi__bmp_data *info, int compress) +{ + // BI_BITFIELDS specifies masks explicitly, don't override + if (compress == 3) + return 1; + + if (compress == 0) { + if (info->bpp == 16) { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } else if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 + } else { + // otherwise, use defaults, which is all-0 + info->mr = info->mg = info->mb = info->ma = 0; + } + return 1; + } + return 0; // error +} + +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) +{ + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + if (compress >= 4) return stbi__errpuc("BMP JPEG/PNG", "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes + if (compress == 3 && info->bpp != 16 && info->bpp != 32) return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + stbi__bmp_set_mask_defaults(info, compress); + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else { + // V4/V5 header + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs + stbi__bmp_set_mask_defaults(info, compress); + stbi__get32le(s); // discard color space + for (i=0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved + } + } + } + return (void *) 1; +} + + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + unsigned int mr=0,mg=0,mb=0,ma=0, all_a; + stbi_uc pal[256][4]; + int psize=0,i,j,width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int) s->img_y) > 0; + s->img_y = abs((int) s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + // accept some number of extra bytes after the header, but if the offset points either to before + // the header ends or implies a large amount of extra data, reject the file as malformed + int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); + int header_limit = 1024; // max we actually read is below 256 bytes currently. + int extra_data_limit = 256*4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. + if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { + return stbi__errpuc("bad header", "Corrupt BMP"); + } + // we established that bytes_read_so_far is positive and sensible. + // the first half of this test rejects offsets that are either too small positives, or + // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn + // ensures the number computed in the second half of the test can't overflow. + if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + stbi__skip(s, info.offset - bytes_read_so_far); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z=0; + if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } + for (i=0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) width = s->img_x; + else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } + pad = (-width)&3; + if (info.bpp == 1) { + for (j=0; j < (int) s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i=0; i < (int) s->img_x; ++i) { + int color = (v>>bit_offset)&0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + if((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } } - } - for (j = 0; j < (int)s->img_y; ++j) { - if (easy) { - for (i = 0; i < (int)s->img_x; ++i) { - unsigned char a; - out[z + 2] = stbi__get8(s); - out[z + 1] = stbi__get8(s); - out[z + 0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) - out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i = 0; i < (int)s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) - out[z++] = STBI__BYTECAST(a); - } + stbi__skip(s, pad); + } + } else { + for (j=0; j < (int) s->img_y; ++j) { + for (i=0; i < (int) s->img_x; i += 2) { + int v=stbi__get8(s),v2=0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; } stbi__skip(s, pad); - } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j = 0; j < (int)s->img_y >> 1; ++j) { - stbi_uc * p1 = out + j * s->img_x * target; - stbi_uc * p2 = out + (s->img_y - 1 - j) * s->img_x * target; - for (i = 0; i < (int)s->img_x * target; ++i) { - t = p1[i]; - p1[i] = p2[i]; - p2[i] = t; + } + } + } else { + int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; + int z = 0; + int easy=0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) width = 3 * s->img_x; + else if (info.bpp == 16) width = 2*s->img_x; + else /* bpp = 32 and pad = 0 */ width=0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + } + for (j=0; j < (int) s->img_y; ++j) { + if (easy) { + for (i=0; i < (int) s->img_x; ++i) { + unsigned char a; + out[z+2] = stbi__get8(s); + out[z+1] = stbi__get8(s); + out[z+0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) out[z++] = a; } - } - } - - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) - return out; // stbi__convert_format frees input on failure - } - - *x = s->img_x; - *y = s->img_y; - if (comp) - *comp = s->img_n; - return out; + } else { + int bpp = info.bpp; + for (i=0; i < (int) s->img_x; ++i) { + stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) out[z++] = STBI__BYTECAST(a); + } + } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j=0; j < (int) s->img_y>>1; ++j) { + stbi_uc *p1 = out + j *s->img_x*target; + stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; + for (i=0; i < (int) s->img_x*target; ++i) { + t = p1[i]; p1[i] = p2[i]; p2[i] = t; + } + } + } + + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + return out; } #endif @@ -6100,74 +5736,68 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int * is_rgb16) { - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) - *is_rgb16 = 0; - switch (bits_per_pixel) { - case 8: - return STBI_grey; - case 16: - if (is_grey) - return STBI_grey_alpha; - // fallthrough - case 15: - if (is_rgb16) - *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: - return bits_per_pixel / 8; - default: - return 0; - } -} - -static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) +{ + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) *is_rgb16 = 0; + switch(bits_per_pixel) { + case 8: return STBI_grey; + case 16: if(is_grey) return STBI_grey_alpha; + // fallthrough + case 15: if(is_rgb16) *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: return bits_per_pixel/8; + default: return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) +{ int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; int sz, tga_colormap_type; - stbi__get8(s); // discard Offset + stbi__get8(s); // discard Offset tga_colormap_type = stbi__get8(s); // colormap type - if (tga_colormap_type > 1) { + if( tga_colormap_type > 1 ) { stbi__rewind(s); - return 0; // only RGB or indexed allowed + return 0; // only RGB or indexed allowed } tga_image_type = stbi__get8(s); // image type - if (tga_colormap_type == 1) { // colormapped (paletted) image + if ( tga_colormap_type == 1 ) { // colormapped (paletted) image if (tga_image_type != 1 && tga_image_type != 9) { stbi__rewind(s); return 0; } - stbi__skip(s, 4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { stbi__rewind(s); return 0; } - stbi__skip(s, 4); // skip image x and y origin + stbi__skip(s,4); // skip image x and y origin tga_colormap_bpp = sz; } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ((tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11)) { + if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { stbi__rewind(s); return 0; // only RGB or grey allowed, +/- RLE } - stbi__skip(s, 9); // skip colormap specification and image x/y origin + stbi__skip(s,9); // skip colormap specification and image x/y origin tga_colormap_bpp = 0; } tga_w = stbi__get16le(s); - if (tga_w < 1) { + if( tga_w < 1 ) { stbi__rewind(s); - return 0; // test width + return 0; // test width } tga_h = stbi__get16le(s); - if (tga_h < 1) { + if( tga_h < 1 ) { stbi__rewind(s); - return 0; // test height + return 0; // test height } tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits + stbi__get8(s); // ignore alpha bits if (tga_colormap_bpp != 0) { - if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { // when using a colormap, tga_bits_per_pixel is the size of the indexes // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); @@ -6177,522 +5807,521 @@ static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { } else { tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); } - if (!tga_comp) { - stbi__rewind(s); - return 0; - } - if (x) - *x = tga_w; - if (y) - *y = tga_h; - if (comp) - *comp = tga_comp; - return 1; // seems to have passed everything -} - -static int stbi__tga_test(stbi__context * s) { - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if (tga_color_type > 1) - goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if (tga_color_type == 1) { // colormapped (paletted) image - if (sz != 1 && sz != 9) - goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s, 4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) - goto errorEnd; - stbi__skip(s, 4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) - goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s, 9); // skip colormap specification and image x/y origin + if(!tga_comp) { + stbi__rewind(s); + return 0; } - if (stbi__get16le(s) < 1) - goto errorEnd; // test width - if (stbi__get16le(s) < 1) - goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) - goto errorEnd; // for colormapped images, bpp is size of an index - if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) - goto errorEnd; - - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + if (x) *x = tga_w; + if (y) *y = tga_h; + if (comp) *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) +{ + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if ( tga_color_type == 1 ) { // colormapped (paletted) image + if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + stbi__skip(s,4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s,9); // skip colormap specification and image x/y origin + } + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context * s, stbi_uc * out) { - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255) / 31); - out[1] = (stbi_uc)((g * 255) / 31); - out[2] = (stbi_uc)((b * 255) / 31); - - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. -} - -static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16 = 0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char * tga_data; - unsigned char * tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO - - if (tga_height > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - - // do a tiny bit of precessing - if (tga_image_type >= 8) { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); - - // If I'm paletted, then I'll use the number of bits from the palette - if (tga_indexed) - tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - - if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - - // tga info - *x = tga_width; - *y = tga_height; - if (comp) - *comp = tga_comp; - - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); - - tga_data = (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) - return stbi__errpuc("outofmem", "Out of memory"); - - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset); - - if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { - for (i = 0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height - i - 1 : i; - stbi_uc * tga_row = tga_data + row * tga_width * tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if (tga_indexed) { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start); - // load the palette - tga_palette = (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) +{ + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later + out[0] = (stbi_uc)((r * 255)/31); + out[1] = (stbi_uc)((g * 255)/31); + out[2] = (stbi_uc)((b * 255)/31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16=0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + // do a tiny bit of precessing + if ( tga_image_type >= 8 ) + { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); + + if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset ); + + if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { + for (i=0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height -i - 1 : i; + stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if ( tga_indexed) + { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start ); + // load the palette + tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); + } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i=0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; } - if (tga_rgb16) { - stbi_uc * pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i = 0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + } + // load the data + for (i=0; i < tga_width * tga_height; ++i) + { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if ( tga_is_RLE ) + { + if ( RLE_count == 0 ) + { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if ( !RLE_repeating ) + { + read_next_pixel = 1; } - } - // load the data - for (i = 0; i < tga_width * tga_height; ++i) { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if (tga_is_RLE) { - if (RLE_count == 0) { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if (!RLE_repeating) { - read_next_pixel = 1; - } + } else + { + read_next_pixel = 1; + } + // OK, if I need to read a pixel, do it now + if ( read_next_pixel ) + { + // load however much data we did have + if ( tga_indexed ) + { + // read in index, then perform the lookup + int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if ( pal_idx >= tga_palette_len ) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx+j]; + } + } else if(tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); } else { - read_next_pixel = 1; + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } } - // OK, if I need to read a pixel, do it now - if (read_next_pixel) { - // load however much data we did have - if (tga_indexed) { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if (pal_idx >= tga_palette_len) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx + j]; - } - } else if (tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel - - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i * tga_comp + j] = raw_data[j]; - - // in case we're in RLE mode, keep counting down - --RLE_count; - } - // do I need to invert the image? - if (tga_inverted) { - for (j = 0; j * 2 < tga_height; ++j) { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel + + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i*tga_comp+j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if ( tga_inverted ) + { + for (j = 0; j*2 < tga_height; ++j) + { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) + { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; } - } - // clear my palette, if I had one - if (tga_palette != NULL) { - STBI_FREE(tga_palette); - } - } + } + } + // clear my palette, if I had one + if ( tga_palette != NULL ) + { + STBI_FREE( tga_palette ); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) + { + unsigned char* tga_pixel = tga_data; + for (i=0; i < tga_width * tga_height; ++i) + { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } + + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = + tga_x_origin = tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; +} +#endif - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) { - unsigned char * tga_pixel = tga_data; - for (i = 0; i < tga_width * tga_height; ++i) { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; - } - } +// ************************************************************************************************* +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context *s) +{ + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) +{ + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; + } + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + return 1; } -#endif - -// ************************************************************************************************* -// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB -#ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context * s) { - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; -} - -static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) { - int count, nleft, len; - - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) - return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) - return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w,h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6 ); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) + stbi__skip(s,stbi__get32be(s) ); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s) ); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s) ); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *) stbi__malloc(4 * w*h); + + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w*h; + + // Initialize the data to zero. + //memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. + // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. + // Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, + // which we're going to just skip. + stbi__skip(s, h * channelCount * 2 ); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out+channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); } - } - } - - return 1; -} - -static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w, h; - stbi_uc * out; - STBI_NOTUSED(ri); - - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); - - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - - // Skip 6 reserved bytes. - stbi__skip(s, 6); - - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); - - if (h > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s, stbi__get32be(s)); - - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s)); - - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s)); - - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); - - // Create the destination image. - - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *)stbi__malloc(4 * w * h); - - if (!out) - return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w * h; - - // Initialize the data to zero. - // memset( out, 0, pixelCount * 4 ); - - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop - - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2); - - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc * p; - - p = out + channel; - if (channel >= channelCount) { - // Fill this channel with default data. - for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); + } + } + + } else { + // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) + // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } + stbi_uc *p = out+channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; } - } - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 * q = ((stbi__uint16 *)out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc * p = out + channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16) stbi__get16be(s); } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 * q = ((stbi__uint16 *)out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16)stbi__get16be(s); - } else { - stbi_uc * p = out + channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc)(stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } + stbi_uc *p = out+channel; + if (bitdepth == 16) { // input bpc + for (i = 0; i < pixelCount; i++, p += 4) + *p = (stbi_uc) (stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } } - } - } - - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i = 0; i < w * h; ++i) { - stbi__uint16 * pixel = (stbi__uint16 *)out + 4 * i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); - pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); - pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); - } + } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i=0; i < w*h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); + pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); + pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); } - } else { - for (i = 0; i < w * h; ++i) { - unsigned char * pixel = out + 4 * i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); - pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); - pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); - } + } + } else { + for (i=0; i < w*h; ++i) { + unsigned char *pixel = out + 4*i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); + pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); + pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); } - } - } + } + } + } - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) - return out; // stbi__convert_format frees input on failure - } + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } - if (comp) - *comp = 4; - *y = h; - *x = w; + if (comp) *comp = 4; + *y = h; + *x = w; - return out; + return out; } #endif @@ -6704,221 +6333,216 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context * s, const char * str) { - int i; - for (i = 0; i < 4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s,const char *str) +{ + int i; + for (i=0; i<4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context * s) { - int i; +static int stbi__pic_test_core(stbi__context *s) +{ + int i; - if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) + return 0; - for (i = 0; i < 84; ++i) - stbi__get8(s); + for(i=0;i<84;++i) + stbi__get8(s); - if (!stbi__pic_is4(s, "PICT")) - return 0; + if (!stbi__pic_is4(s,"PICT")) + return 0; - return 1; + return 1; } -typedef struct { - stbi_uc size, type, channel; +typedef struct +{ + stbi_uc size,type,channel; } stbi__pic_packet; -static stbi_uc * stbi__readval(stbi__context * s, int channel, stbi_uc * dest) { - int mask = 0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) +{ + int mask=0x80, i; - for (i = 0; i < 4; ++i, mask >>= 1) { - if (channel & mask) { - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "PIC file too short"); - dest[i] = stbi__get8(s); - } - } + for (i=0; i<4; ++i, mask>>=1) { + if (channel & mask) { + if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); + dest[i]=stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel, stbi_uc * dest, const stbi_uc * src) { - int mask = 0x80, i; +static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) +{ + int mask=0x80,i; - for (i = 0; i < 4; ++i, mask >>= 1) - if (channel & mask) - dest[i] = src[i]; + for (i=0;i<4; ++i, mask>>=1) + if (channel&mask) + dest[i]=src[i]; } -static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, int * comp, stbi_uc * result) { - int act_comp = 0, num_packets = 0, y, chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) +{ + int act_comp=0,num_packets=0,y,chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data + // this will (should...) cater for even some bizarre stuff like having data // for the same channel in multiple packets. - do { - stbi__pic_packet * packet; + do { + stbi__pic_packet *packet; - if (num_packets == sizeof(packets) / sizeof(packets[0])) - return stbi__errpuc("bad format", "too many packets"); + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return stbi__errpuc("bad format","too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (reading packets)"); - if (packet->size != 8) - return stbi__errpuc("bad format", "packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); + if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for (y = 0; y < height; ++y) { - int packet_idx; + for(y=0; ytype) { + switch (packet->type) { default: - return stbi__errpuc("bad format", "packet has bad compression type"); + return stbi__errpuc("bad format","packet has bad compression type"); - case 0: { // uncompressed - int x; + case 0: {//uncompressed + int x; - for (x = 0; x < width; ++x, dest += 4) - if (!stbi__readval(s, packet->channel, dest)) - return 0; - break; + for(x=0;xchannel,dest)) + return 0; + break; } - case 1: // Pure RLE - { - int left = width, i; + case 1://Pure RLE + { + int left=width, i; - while (left > 0) { - stbi_uc count, value[4]; + while (left>0) { + stbi_uc count,value[4]; - count = stbi__get8(s); - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (pure read count)"); + count=stbi__get8(s); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - if (count > left) - count = (stbi_uc)left; + if (count > left) + count = (stbi_uc) left; - if (!stbi__readval(s, packet->channel, value)) - return 0; + if (!stbi__readval(s,packet->channel,value)) return 0; - for (i = 0; i < count; ++i, dest += 4) - stbi__copyval(packet->channel, dest, value); - left -= count; - } - } break; + for(i=0; ichannel,dest,value); + left -= count; + } + } + break; - case 2: { // Mixed RLE - int left = width; - while (left > 0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (mixed read count)"); + case 2: {//Mixed RLE + int left=width; + while (left>0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); - if (count >= 128) { // Repeated - stbi_uc value[4]; + if (count >= 128) { // Repeated + stbi_uc value[4]; - if (count == 128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file", "scanline overrun"); - - if (!stbi__readval(s, packet->channel, value)) - return 0; - - for (i = 0; i < count; ++i, dest += 4) - stbi__copyval(packet->channel, dest, value); - } else { // Raw - ++count; - if (count > left) - return stbi__errpuc("bad file", "scanline overrun"); - - for (i = 0; i < count; ++i, dest += 4) - if (!stbi__readval(s, packet->channel, dest)) - return 0; - } - left -= count; - } - break; - } + if (count==128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file","scanline overrun"); + + if (!stbi__readval(s,packet->channel,value)) + return 0; + + for(i=0;ichannel,dest,value); + } else { // Raw + ++count; + if (count>left) return stbi__errpuc("bad file","scanline overrun"); + + for(i=0;ichannel,dest)) + return 0; + } + left-=count; + } + break; } - } - } + } + } + } - return result; + return result; } -static void * stbi__pic_load(stbi__context * s, int * px, int * py, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * result; - int i, x, y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) +{ + stbi_uc *result; + int i, x,y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) - comp = &internal_comp; + if (!comp) comp = &internal_comp; - for (i = 0; i < 92; ++i) - stbi__get8(s); + for (i=0; i<92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) - return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); // skip `ratio' - stbi__get16be(s); // skip `fields' - stbi__get16be(s); // skip `pad' + stbi__get32be(s); //skip `ratio' + stbi__get16be(s); //skip `fields' + stbi__get16be(s); //skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); - if (!result) - return stbi__errpuc("outofmem", "Out of memory"); - memset(result, 0xff, x * y * 4); + // intermediate buffer is RGBA + result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); + if (!result) return stbi__errpuc("outofmem", "Out of memory"); + memset(result, 0xff, x*y*4); - if (!stbi__pic_load_core(s, x, y, comp, result)) { - STBI_FREE(result); - result = 0; - } - *px = x; - *py = y; - if (req_comp == 0) - req_comp = *comp; - result = stbi__convert_format(result, 4, req_comp, x, y); + if (!stbi__pic_load_core(s,x,y,comp, result)) { + STBI_FREE(result); + result=0; + } + *px = x; + *py = y; + if (req_comp == 0) req_comp = *comp; + result=stbi__convert_format(result,4,req_comp,x,y); - return result; + return result; } -static int stbi__pic_test(stbi__context * s) { - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) +{ + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6926,968 +6550,931 @@ static int stbi__pic_test(stbi__context * s) { // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct { - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct +{ + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct { - int w, h; - stbi_uc * out; // output buffer (always 4 components) - stbi_uc * background; // The current "background" as far as a gif is concerned - stbi_uc * history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc * color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct +{ + int w,h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context * s) { - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') - return 0; - if (stbi__get8(s) != 'a') - return 0; - return 1; -} - -static int stbi__gif_test(stbi__context * s) { - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; -} - -static void stbi__gif_parse_colortable(stbi__context * s, stbi_uc pal[256][4], int num_entries, int transp) { - int i; - for (i = 0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } +static int stbi__gif_test_raw(stbi__context *s) +{ + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') return 0; + if (stbi__get8(s) != 'a') return 0; + return 1; } -static int stbi__gif_header(stbi__context * s, stbi__gif * g, int * comp, int is_info) { - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); - - version = stbi__get8(s); - if (version != '7' && version != '9') - return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') - return stbi__err("not GIF", "Corrupt GIF"); - - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; - - if (g->w > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); +static int stbi__gif_test(stbi__context *s) +{ + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} - if (comp != 0) - *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) +{ + int i; + for (i=0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} - if (is_info) - return 1; +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) +{ + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); - if (g->flags & 0x80) - stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + version = stbi__get8(s); + if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); - return 1; -} + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; -static int stbi__gif_info_raw(stbi__context * s, int * x, int * y, int * comp) { - stbi__gif * g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); - if (!g) - return stbi__err("outofmem", "Out of memory"); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind(s); - return 0; - } - if (x) - *x = g->w; - if (y) - *y = g->h; - STBI_FREE(g); - return 1; -} + if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); -static void stbi__out_gif_code(stbi__gif * g, stbi__uint16 code) { - stbi_uc *p, *c; - int idx; + if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); + if (is_info) return 1; - if (g->cur_y >= g->max_y) - return; + if (g->flags & 0x80) + stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; + return 1; +} - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); + if (!g) return stbi__err("outofmem", "Out of memory"); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind( s ); + return 0; + } + if (x) *x = g->w; + if (y) *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) +{ + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) +{ + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc) init_code; + g->codes[init_code].suffix = (stbi_uc) init_code; + } + + // support no starting clear code + avail = clear+2; + oldcode = -1; + + len = 0; + for(;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32) stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s,len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; - } - } -} + p->prefix = (stbi__int16) oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); -static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw * p; - - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) - return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc)init_code; - g->codes[init_code].suffix = (stbi_uc)init_code; - } + stbi__out_gif_code(g, (stbi__uint16) code); - // support no starting clear code - avail = clear + 2; - oldcode = -1; - - len = 0; - for (;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; - } - --len; - bits |= (stbi__int32)stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s, len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } - - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } - - p->prefix = (stbi__int16)oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - - stbi__out_gif_code(g, (stbi__uint16)code); - - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; - } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; } - } - } + + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } } // this function is designed to support animated gifs, although stb_image doesn't support it // two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * comp, int req_comp, stbi_uc * two_back) { - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); - - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp, 0)) - return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *)stbi__malloc(4 * pcount); - g->background = (stbi_uc *)stbi__malloc(4 * pcount); - g->history = (stbi_uc *)stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); - - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; - - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background - } - - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); - } +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) +{ + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *) stbi__malloc(4 * pcount); + g->background = (stbi_uc *) stbi__malloc(4 * pcount); + g->history = (stbi_uc *) stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); + + // image is treated as "transparent" at the start - ie, nothing overwrites the current background; + // background colour is only used for pixels that are not rendered first frame, after that "background" + // color refers to the color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the old background + } + + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); - } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } - - // background is what out is after the undoing of the previou frame; - memcpy(g->background, g->out, 4 * g->w * g->h); - } - - // clear my history; - memset(g->history, 0x00, g->w * g->h); // pixels that were affected previous frame - - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } + + // background is what out is after the undoing of the previou frame; + memcpy( g->background, g->out, 4 * g->w * g->h ); + } + + // clear my history; + memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { stbi__int32 x, y, w, h; - stbi_uc * o; + stbi_uc *o; x = stbi__get16le(s); y = stbi__get16le(s); w = stbi__get16le(s); h = stbi__get16le(s); if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); g->line_size = g->w * 4; g->start_x = x * 4; g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; // if the width of the specified rectangle is 0, that means // we may not see *any* pixels or the image is malformed; // to make sure this is caught, move the current y down to // max_y (which is what out_gif_code checks). if (w == 0) - g->cur_y = g->max_y; + g->cur_y = g->max_y; g->lflags = stbi__get8(s); if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; } else { - g->step = g->line_size; - g->parse = 0; + g->step = g->line_size; + g->parse = 0; } if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *)g->lpal; + stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *) g->lpal; } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *)g->pal; + g->color_table = (stbi_uc *) g->pal; } else - return stbi__errpuc("missing color table", "Corrupt GIF"); + return stbi__errpuc("missing color table", "Corrupt GIF"); o = stbi__process_gif_raster(s, g); - if (!o) - return NULL; + if (!o) return NULL; // if this was the first frame, pcount = g->w * g->h; if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = - 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); - } - } + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; + memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); + } + } } return o; - } + } - case 0x21: // Comment Extension. - { + case 0x21: // Comment Extension. + { int len; int ext = stbi__get8(s); if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. + + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; + } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); + stbi__skip(s, len); } break; - } + } - case 0x3B: // gif stream termination code - return (stbi_uc *)s; // using '1' causes warning on some compilers + case 0x3B: // gif stream termination code + return (stbi_uc *) s; // using '1' causes warning on some compilers - default: + default: return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } + } + } } -static void * stbi__load_gif_main_outofmem(stbi__gif * g, stbi_uc * out, int ** delays) { - STBI_FREE(g->out); - STBI_FREE(g->history); - STBI_FREE(g->background); +static void *stbi__load_gif_main_outofmem(stbi__gif *g, stbi_uc *out, int **delays) +{ + STBI_FREE(g->out); + STBI_FREE(g->history); + STBI_FREE(g->background); - if (out) - STBI_FREE(out); - if (delays && *delays) - STBI_FREE(*delays); - return stbi__errpuc("outofmem", "Out of memory"); + if (out) STBI_FREE(out); + if (delays && *delays) STBI_FREE(*delays); + return stbi__errpuc("outofmem", "Out of memory"); } -static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp) { - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc * u = 0; - stbi_uc * out = 0; - stbi_uc * two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; - - STBI_NOTUSED(out_size); - STBI_NOTUSED(delays_size); - - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + + STBI_NOTUSED(out_size); + STBI_NOTUSED(delays_size); + + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } + + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *) s) u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); + if (!tmp) + return stbi__load_gif_main_outofmem(&g, out, delays); + else { + out = (stbi_uc*) tmp; + out_size = layers * stride; + } + + if (delays) { + int *new_delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); + if (!new_delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + *delays = new_delays; + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc*)stbi__malloc( layers * stride ); + if (!out) + return stbi__load_gif_main_outofmem(&g, out, delays); + out_size = layers * stride; + if (delays) { + *delays = (int*) stbi__malloc( layers * sizeof(int) ); + if (!*delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + delays_size = layers * sizeof(int); + } + } + memcpy( out + ((layers - 1) * stride), u, stride ); + if (layers >= 2) { + two_back = out - 2 * stride; + } - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *)s) - u = 0; // end of animated gif marker - - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; - - if (out) { - void * tmp = (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); - if (!tmp) - return stbi__load_gif_main_outofmem(&g, out, delays); - else { - out = (stbi_uc *)tmp; - out_size = layers * stride; - } - - if (delays) { - int * new_delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, sizeof(int) * layers); - if (!new_delays) - return stbi__load_gif_main_outofmem(&g, out, delays); - *delays = new_delays; - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc *)stbi__malloc(layers * stride); - if (!out) - return stbi__load_gif_main_outofmem(&g, out, delays); - out_size = layers * stride; - if (delays) { - *delays = (int *)stbi__malloc(layers * sizeof(int)); - if (!*delays) - return stbi__load_gif_main_outofmem(&g, out, delays); - delays_size = layers * sizeof(int); - } - } - memcpy(out + ((layers - 1) * stride), u, stride); - if (layers >= 2) { - two_back = out - 2 * stride; - } - - if (delays) { - (*delays)[layers - 1U] = g.delay; - } + if (delays) { + (*delays)[layers - 1U] = g.delay; } - } while (u != 0); + } + } while (u != 0); - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } } -static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); - - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *)s) - u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } - - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); - - return u; +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); + + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *) s) u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; + + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } + + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); + + return u; +} + +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) +{ + return stbi__gif_info_raw(s,x,y,comp); } - -static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp) { return stbi__gif_info_raw(s, x, y, comp); } #endif // ************************************************************************************************* // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context * s, const char * signature) { - int i; - for (i = 0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) +{ + int i; + for (i=0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; } -static int stbi__hdr_test(stbi__context * s) { - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if (!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; +static int stbi__hdr_test(stbi__context* s) +{ + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if(!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; } -#define STBI__HDR_BUFLEN 1024 -static char * stbi__hdr_gettoken(stbi__context * z, char * buffer) { - int len = 0; - char c = '\0'; +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) +{ + int len=0; + char c = '\0'; - c = (char)stbi__get8(z); + c = (char) stbi__get8(z); - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN - 1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; - } - c = (char)stbi__get8(z); - } + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN-1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char) stbi__get8(z); + } - buffer[len] = 0; - return buffer; -} - -static void stbi__hdr_convert(float * output, stbi_uc * input, int req_comp) { - if (input[3] != 0) { - float f1; - // Exponent - f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; - } - if (req_comp == 2) - output[1] = 1; - if (req_comp == 4) - output[3] = 1; - } else { - switch (req_comp) { - case 4: - output[3] = 1; /* fallthrough */ - case 3: - output[0] = output[1] = output[2] = 0; - break; - case 2: - output[1] = 1; /* fallthrough */ - case 1: - output[0] = 0; - break; - } - } + buffer[len] = 0; + return buffer; } -static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - char buffer[STBI__HDR_BUFLEN]; - char * token; - int valid = 0; - int width, height; - stbi_uc * scanline; - float * hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1, c2, z; - const char * headerToken; - STBI_NOTUSED(ri); - - // Check identifier - headerToken = stbi__hdr_gettoken(s, buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); - - // Parse header - for (;;) { - token = stbi__hdr_gettoken(s, buffer); - if (token[0] == 0) - break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) - valid = 1; - } - - if (!valid) - return stbi__errpf("unsupported format", "Unsupported HDR format"); - - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s, buffer); - if (strncmp(token, "-Y ", 3)) - return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int)strtol(token, &token, 10); - while (*token == ' ') - ++token; - if (strncmp(token, "+X ", 3)) - return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int)strtol(token, NULL, 10); - - if (height > STBI_MAX_DIMENSIONS) - return stbi__errpf("too large", "Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) - return stbi__errpf("too large", "Very large image (corrupt?)"); - - *x = width; - *y = height; - - if (comp) - *comp = 3; - if (req_comp == 0) - req_comp = 3; - - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); - - // Read data - hdr_data = (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); - - // Load image data - // image data is stored as some number of sca - if (width < 8 || width >= 32768) { - // Read flat data - for (j = 0; j < height; ++j) { - for (i = 0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } - } - } else { - // Read RLE-encoded data - scanline = NULL; - - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc)c1; - rgbe[1] = (stbi_uc)c2; - rgbe[2] = (stbi_uc)len; - rgbe[3] = (stbi_uc)stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { - STBI_FREE(hdr_data); - STBI_FREE(scanline); - return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); - } - if (scanline == NULL) { - scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); - } +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) +{ + if ( input[3] != 0 ) { + float f1; + // Exponent + f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) output[1] = 1; + if (req_comp == 4) output[3] = 1; + } else { + switch (req_comp) { + case 4: output[3] = 1; /* fallthrough */ + case 3: output[0] = output[1] = output[2] = 0; + break; + case 2: output[1] = 1; /* fallthrough */ + case 1: output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1,c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s,buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } + + if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int) strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) *comp = 3; + if (req_comp == 0) req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if ( width < 8 || width >= 32768) { + // Read flat data + for (j=0; j < height; ++j) { + for (i=0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); + } + } + } else { + // Read RLE-encoded data + scanline = NULL; + + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a decoded + // pixel (note this can't be a valid pixel--one of RGB must be >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc) c1; + rgbe[1] = (stbi_uc) c2; + rgbe[2] = (stbi_uc) len; + rgbe[3] = (stbi_uc) stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense + } + len <<= 8; + len |= stbi__get8(s); + if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } + if (scanline == NULL) { + scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if ((count == 0) || (count > nleft)) { - STBI_FREE(hdr_data); - STBI_FREE(scanline); - return stbi__errpf("corrupt", "bad RLE data in HDR"); - } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if ((count == 0) || (count > nleft)) { - STBI_FREE(hdr_data); - STBI_FREE(scanline); - return stbi__errpf("corrupt", "bad RLE data in HDR"); - } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } - } + } + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } } - for (i = 0; i < width; ++i) - stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, scanline + i * 4, req_comp); - } - if (scanline) - STBI_FREE(scanline); - } + } + for (i=0; i < width; ++i) + stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } - return hdr_data; + return hdr_data; } -static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp) { - char buffer[STBI__HDR_BUFLEN]; - char * token; - int valid = 0; - int dummy; - - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; - - if (stbi__hdr_test(s) == 0) { - stbi__rewind(s); - return 0; - } - - for (;;) { - token = stbi__hdr_gettoken(s, buffer); - if (token[0] == 0) - break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) - valid = 1; - } - - if (!valid) { - stbi__rewind(s); - return 0; - } - token = stbi__hdr_gettoken(s, buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind(s); - return 0; - } - token += 3; - *y = (int)strtol(token, &token, 10); - while (*token == ' ') - ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind(s); - return 0; - } - token += 3; - *x = (int)strtol(token, NULL, 10); - *comp = 3; - return 1; +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind( s ); + return 0; + } + + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } + + if (!valid) { + stbi__rewind( s ); + return 0; + } + token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *y = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *x = (int) strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp) { - void * p; - stbi__bmp_data info; - - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - if (p == NULL) { - stbi__rewind(s); - return 0; - } - if (x) - *x = s->img_x; - if (y) - *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) +{ + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + if (p == NULL) { + stbi__rewind( s ); + return 0; + } + if (x) *x = s->img_x; + if (y) *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp) { - int channelCount, dummy, depth; - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind(s); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind(s); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind(s); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind(s); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind(s); - return 0; - } - *comp = 4; - return 1; -} - -static int stbi__psd_is16(stbi__context * s) { - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind(s); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind(s); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind(s); - return 0; - } - STBI_NOTUSED(stbi__get32be(s)); - STBI_NOTUSED(stbi__get32be(s)); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind(s); - return 0; - } - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) +{ + int channelCount, dummy, depth; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind( s ); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) +{ + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + STBI_NOTUSED(stbi__get32be(s)); + STBI_NOTUSED(stbi__get32be(s)); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind( s ); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { - int act_comp = 0, num_packets = 0, chained, dummy; - stbi__pic_packet packets[10]; - - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; - - if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { - stbi__rewind(s); - return 0; - } +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) +{ + int act_comp=0,num_packets=0,chained,dummy; + stbi__pic_packet packets[10]; - stbi__skip(s, 88); + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind(s); - return 0; - } - if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind(s); - return 0; - } + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } - stbi__skip(s, 8); + stbi__skip(s, 88); - do { - stbi__pic_packet * packet; + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind( s); + return 0; + } + if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind( s ); + return 0; + } - if (num_packets == sizeof(packets) / sizeof(packets[0])) - return 0; + stbi__skip(s, 8); - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; + do { + stbi__pic_packet *packet; - if (stbi__at_eof(s)) { - stbi__rewind(s); - return 0; - } - if (packet->size != 8) { - stbi__rewind(s); - return 0; - } - } while (chained); + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return 0; + + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - *comp = (act_comp & 0x10 ? 4 : 3); + if (stbi__at_eof(s)) { + stbi__rewind( s ); + return 0; + } + if (packet->size != 8) { + stbi__rewind( s ); + return 0; + } + } while (chained); - return 1; + *comp = (act_comp & 0x10 ? 4 : 3); + + return 1; } #endif @@ -7904,271 +7491,272 @@ static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context * s) { - char p, t; - p = (char)stbi__get8(s); - t = (char)stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) +{ + char p, t; + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind( s ); + return 0; + } + return 1; } -static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * out; - STBI_NOTUSED(ri); - - ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); - if (ri->bits_per_channel == 0) - return 0; - - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - - *x = s->img_x; - *y = s->img_y; - if (comp) - *comp = s->img_n; - - if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) - return stbi__errpuc("too large", "PNM too large"); - - out = (stbi_uc *)stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); - if (!out) - return stbi__errpuc("outofmem", "Out of memory"); - if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { - STBI_FREE(out); - return stbi__errpuc("bad PNM", "PNM file truncated"); - } - - if (req_comp && req_comp != s->img_n) { - if (ri->bits_per_channel == 16) { - out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, s->img_n, req_comp, s->img_x, s->img_y); - } else { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - } - if (out == NULL) - return out; // stbi__convert_format frees input on failure - } - return out; +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + STBI_NOTUSED(ri); + + ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); + if (ri->bits_per_channel == 0) + return 0; + + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + + if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) + return stbi__errpuc("too large", "PNM too large"); + + out = (stbi_uc *) stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { + STBI_FREE(out); + return stbi__errpuc("bad PNM", "PNM file truncated"); + } + + if (req_comp && req_comp != s->img_n) { + if (ri->bits_per_channel == 16) { + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, s->img_n, req_comp, s->img_x, s->img_y); + } else { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + } + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + return out; +} + +static int stbi__pnm_isspace(char c) +{ + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; } -static int stbi__pnm_isspace(char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; } - -static void stbi__pnm_skip_whitespace(stbi__context * s, char * c) { - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char)stbi__get8(s); +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) +{ + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char) stbi__get8(s); - if (stbi__at_eof(s) || *c != '#') - break; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') - *c = (char)stbi__get8(s); - } + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) + *c = (char) stbi__get8(s); + } } -static int stbi__pnm_isdigit(char c) { return c >= '0' && c <= '9'; } +static int stbi__pnm_isdigit(char c) +{ + return c >= '0' && c <= '9'; +} -static int stbi__pnm_getinteger(stbi__context * s, char * c) { - int value = 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) +{ + int value = 0; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value * 10 + (*c - '0'); - *c = (char)stbi__get8(s); - if ((value > 214748364) || (value == 214748364 && *c > '7')) - return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); - } + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value*10 + (*c - '0'); + *c = (char) stbi__get8(s); + if((value > 214748364) || (value == 214748364 && *c > '7')) + return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); + } - return value; + return value; } -static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp) { - int maxv, dummy; - char c, p, t; +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) +{ + int maxv, dummy; + char c, p, t; - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; - stbi__rewind(s); + stbi__rewind(s); - // Get identifier - p = (char)stbi__get8(s); - t = (char)stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } + // Get identifier + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - c = (char)stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); + c = (char) stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width - if (*x == 0) - return stbi__err("invalid width", "PPM image header had zero or overflowing width"); - stbi__pnm_skip_whitespace(s, &c); + *x = stbi__pnm_getinteger(s, &c); // read width + if(*x == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height - if (*y == 0) - return stbi__err("invalid width", "PPM image header had zero or overflowing width"); - stbi__pnm_skip_whitespace(s, &c); + *y = stbi__pnm_getinteger(s, &c); // read height + if (*y == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 65535) - return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); - else if (maxv > 255) - return 16; - else - return 8; + maxv = stbi__pnm_getinteger(s, &c); // read max value + if (maxv > 65535) + return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); + else if (maxv > 255) + return 16; + else + return 8; } -static int stbi__pnm_is16(stbi__context * s) { - if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) - return 1; - return 0; +static int stbi__pnm_is16(stbi__context *s) +{ + if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) + return 1; + return 0; } #endif -static int stbi__info_main(stbi__context * s, int * x, int * y, int * comp) { -#ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) - return 1; -#endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) +{ + #ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) return 1; + #endif -// test tga last because it's a crappy test! -#ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; -#endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); + // test tga last because it's a crappy test! + #ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; + #endif + return stbi__err("unknown image type", "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context * s) { -#ifndef STBI_NO_PNG - if (stbi__png_is16(s)) - return 1; -#endif +static int stbi__is_16_main(stbi__context *s) +{ + #ifndef STBI_NO_PNG + if (stbi__png_is16(s)) return 1; + #endif -#ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) - return 1; -#endif + #ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) return 1; + #endif -#ifndef STBI_NO_PNM - if (stbi__pnm_is16(s)) - return 1; -#endif - return 0; + #ifndef STBI_NO_PNM + if (stbi__pnm_is16(s)) return 1; + #endif + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) +{ + FILE *f = stbi__fopen(filename, "rb"); int result; - if (!f) - return stbi__err("can't fopen", "Unable to open file"); + if (!f) return stbi__err("can't fopen", "Unable to open file"); result = stbi_info_from_file(f, x, y, comp); fclose(f); return result; } -STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp) { - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s, x, y, comp); - fseek(f, pos, SEEK_SET); - return r; +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) +{ + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s,x,y,comp); + fseek(f,pos,SEEK_SET); + return r; } -STBIDEF int stbi_is_16_bit(char const * filename) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_is_16_bit(char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); int result; - if (!f) - return stbi__err("can't fopen", "Unable to open file"); + if (!f) return stbi__err("can't fopen", "Unable to open file"); result = stbi_is_16_bit_from_file(f); fclose(f); return result; } -STBIDEF int stbi_is_16_bit_from_file(FILE * f) { - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f, pos, SEEK_SET); - return r; +STBIDEF int stbi_is_16_bit_from_file(FILE *f) +{ + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f,pos,SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__info_main(&s, x, y, comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__info_main(&s,x,y,comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * c, void * user, int * x, int * y, int * comp) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); - return stbi__info_main(&s, x, y, comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__info_main(&s,x,y,comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * user) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION @@ -8279,9 +7867,12 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * us 1.30 (2011-06-11) added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks - anyway error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in - decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from Aurelien Pocheville 1.28 (2010-08-01) + removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway + error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) + fix inefficiency in decoding 32-bit BMP (David Woo) + 1.29 (2010-08-16) + various warning fixes from Aurelien Pocheville + 1.28 (2010-08-01) fix bug in GIF palette transparency (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings @@ -8353,6 +7944,7 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * us first released version */ + /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a3d9c9491b02d..eb8d08acaebba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -63,6 +63,7 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path + is_lora: bool # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -70,7 +71,7 @@ class Model: def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -92,6 +93,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py + self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -295,6 +297,7 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.FFN_GATE_INP, gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.TOKEN_TYPES, + gguf.MODEL_TENSOR.SSM_CONV1D, ) ) or not name.endswith(".weight") @@ -587,6 +590,15 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249": # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M res = "smollm" + if chkhsh == "3c30d3ad1d6b64202cd222813e7736c2db6e1bd6d67197090fc1211fbc612ae7": + # ref: https://huggingface.co/bigscience/bloom + res = "bloom" + if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21": + # ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small + res = "gpt3-finnish" + if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae": + # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct + res = "exaone" if res is None: logger.warning("\n") @@ -890,7 +902,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return tensors -@Model.register("BloomForCausalLM") +@Model.register("BloomForCausalLM", "BloomModel") class BloomModel(Model): model_arch = gguf.MODEL_ARCH.BLOOM @@ -1557,7 +1569,7 @@ def prepare_tensors(self): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -1580,7 +1592,8 @@ def prepare_tensors(self): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + if not self.is_lora: + self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) super().prepare_tensors() @@ -2127,8 +2140,9 @@ def set_gguf_parameters(self): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + if not self.is_lora: + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) @Model.register("PlamoForCausalLM") @@ -2699,7 +2713,7 @@ class StarCoder2Model(Model): model_arch = gguf.MODEL_ARCH.STARCODER2 -@Model.register("MambaForCausalLM", "MambaLMHeadModel") +@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(Model): model_arch = gguf.MODEL_ARCH.MAMBA @@ -2730,7 +2744,10 @@ def set_gguf_parameters(self): # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 - + use_dt_b_c_norm = False + # For falconmamba we do apply RMS norm on B / DT and C layers + if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",): + use_dt_b_c_norm = True # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model @@ -2738,12 +2755,13 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_state_size(d_state) self.gguf_writer.add_ssm_time_step_rank(dt_rank) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers self.gguf_writer.add_file_type(self.ftype) _tok_embd = None @@ -2770,23 +2788,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] - def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: - if bid is not None and new_name in ( - self.format_tensor_name( - n, bid, ".weight" if name.endswith(".weight") else "" - ) - for n in [ - gguf.MODEL_TENSOR.SSM_CONV1D, - gguf.MODEL_TENSOR.SSM_X, - gguf.MODEL_TENSOR.SSM_DT, - gguf.MODEL_TENSOR.SSM_A, - gguf.MODEL_TENSOR.SSM_D, - ] - ): - return gguf.GGMLQuantizationType.F32 - - return super().tensor_force_quant(name, new_name, bid, n_dims) - @Model.register("CohereForCausalLM") class CommandR2Model(Model): @@ -3731,8 +3732,121 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] -###### CONVERSION LOGIC ###### +@Model.register("NemotronForCausalLM") +class NemotronModel(Model): + model_arch = gguf.MODEL_ARCH.NEMOTRON + + def set_vocab(self): + self._set_vocab_sentencepiece() + self.gguf_writer.add_pad_token_id(0) + self.gguf_writer.add_unk_token_id(1) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"]) + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + + # * Partial RoPE + rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"]) + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + + # * RopeScaling for Nemotron + if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side + # model.layers.{l}.input_layernorm.weight + # model.layers.{l}.post_attention_layernorm.weight + # model.norm.weight + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + +@Model.register("ExaoneForCausalLM") +class ExaoneModel(Model): + model_arch = gguf.MODEL_ARCH.EXAONE + + def set_gguf_parameters(self): + hparams = self.hparams + + assert (hparams["activation_function"] == "silu") + + max_position_embeddings = hparams["max_position_embeddings"] + embed_dim = hparams["hidden_size"] + num_heads = hparams["num_attention_heads"] + num_kv_heads = hparams.get("num_key_value_heads", num_heads) + layer_norm_eps = hparams["layer_norm_epsilon"] + intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim + num_layers = hparams["num_layers"] + # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 + # attention_dropout_rate = hparams["attention_dropout"] + # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 + # embed_dropout_rate = hparams["embed_dropout"] + self.gguf_writer.add_embedding_length(embed_dim) + self.gguf_writer.add_head_count(num_heads) + self.gguf_writer.add_head_count_kv(num_kv_heads) + self.gguf_writer.add_context_length(max_position_embeddings) + self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_block_count(num_layers) + self.gguf_writer.add_file_type(self.ftype) + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) + rotary_factor = rotary_factor if rotary_factor is not None else 1.0 + self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]: + if hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + + def prepare_tensors(self): + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + if not self.is_lora: + self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + + super().prepare_tensors() + + +###### CONVERSION LOGIC ###### # tree of lazy tensors class LazyTorchTensor(gguf.LazyBase): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index d5a2d925eaef5..ff4955f9c614e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -94,6 +94,9 @@ class TOKENIZER_TYPE(IntEnum): {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, + {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, + {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, + {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, ] diff --git a/convert_llama_ggml_to_gguf.py b/convert_llama_ggml_to_gguf.py index 7b00b4398178b..29b14e98dd237 100755 --- a/convert_llama_ggml_to_gguf.py +++ b/convert_llama_ggml_to_gguf.py @@ -116,7 +116,7 @@ def load(self, data, offset): assert quant is not None, 'Unknown tensor type' (blksize, tysize) = quant offset += 12 - self.dtype= dtype + self.dtype= gguf.GGMLQuantizationType(dtype) self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) offset += 4 * n_dims self.name = bytes(data[offset:offset + name_len]) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index a88d0d4a978a9..ddd347a2abd2a 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -386,6 +386,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter dry_run=args.dry_run, dir_lora_model=dir_lora, lora_alpha=alpha, + is_lora=True, ) logger.info("Exporting model...") diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md new file mode 100644 index 0000000000000..6bdd9d2daab90 --- /dev/null +++ b/docs/backend/CANN.md @@ -0,0 +1,259 @@ +# llama.cpp for CANN + + - [Background](#background) + - [News](#news) + - [OS](#os) + - [Hardware](#hardware) + - [Model Supports](#model-supports) + - [DataType Supports](#datatype-supports) + - [Docker](#docker) + - [Linux](#linux) + - [TODO](#todo) + + +## Background + +**Ascend NPU** is a range of AI processors using Neural Processing Unit. It will efficiently handle matrix-matrix multiplication, dot-product and scalars. + +**CANN** (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for AI scenarios, providing support for multiple AI frameworks on the top and serving AI processors and programming at the bottom. It plays a crucial role in bridging the gap between upper and lower layers, and is a key platform for improving the computing efficiency of Ascend AI processors. Meanwhile, it offers a highly efficient and easy-to-use programming interface for diverse application scenarios, allowing users to rapidly build AI applications and services based on the Ascend platform. + +**Llama.cpp + CANN** + +The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the ability of AscendC and ACLNN which are intergrated to CANN Toolkit and kernels to using Ascend NPU directly. + +## News + +- 2024.8 + - Support `Q4_0` and `Q8_0` data type for Ascend NPU. +- 2024.7 + - Create CANN backend for Ascend NPU. + +## OS + +| OS | Status | Verified | +|:-------:|:-------:|:----------------------------------------------:| +| Linux | Support | Ubuntu 22.04, OpenEuler22.03 | + + +## Hardware + +### Ascend NPU + +**Verified devices** +| Ascend NPU | Status | +|:-----------------------------:|:-------:| +| Atlas 300T A2 | Support | + +*Notes:* + +- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag. +- If you run successfully with your Ascend NPU device, please help update the upper table. + + +## Model Supports + +| Model Name | FP16 | Q8_0 | Q4_0 | +|:----------------------------|:-----:|:----:|:----:| +| AquilaChat2-7B | √ | √ | √ | +| Baichuan-7b | √ | √ | √ | +| Baichuan2-7B-Chat | √ | √ | √ | +| bitnet_b1_58-large | √ | √ | √ | +| bloom-560m | √ | x | √ | +| bloomz-alpaca-560m | √ | x | √ | +| c4ai-command-r-35B-v01 | x | x | x | +| chatglm3-6B | x | x | x | +| chinese-alpaca-2-1.3b | √ | √ | √ | +| CodeShell-7B | √ | √ | √ | +| deepseek-ai_deepseek-coder-1.3B-base | x | x | x | +| deepseek-ai_DeepSeek-V2-Lite | x | x | x | +| deepseek-coder-6.7B-instruct | x | x | x | +| DeepSeek-V2-Lite-64x1.5B | x | x | x | +| falcon-7b-instruct | √ | √ | √ | +| flan-t5-large | √ | √ | √ | +| gemma-2-9b-it | √ | √ | √ | +| glm-4-9B | x | x | x | +| gpt2 | √ | √ | √ | +| Gpt2-163M | √ | √ | √ | +| granite-3B-code-instruct | √ | √ | √ | +| GritLM-7B | √ | √ | √ | +| internlm2_5-7b-chat | √ | √ | √ | +| koala-7B-HF | √ | √ | √ | +| Llama-2-7b-chat-hf | √ | √ | √ | +| Llama-3-Smaug-8B | √ | √ | √ | +| Llama2-Chinese-7b-Chat | √ | √ | √ | +| Llama3-8B | √ | √ | √ | +| Llama3-8b-chinese | √ | √ | √ | +| mamba-130m-hf | √ | √ | √ | +| Mistral-7B-Instruct-v0.2 | √ | √ | √ | +| Mixtral-8x7B-Instruct-v0.1 | x | √ | √ | +| mpt-7B | √ | √ | √ | +| OLMo-1B-hf | √ | √ | √ | +| OpenELM-3B-Instruct | √ | √ | √ | +| Orion-14b-base | √ | √ | √ | +| phi1 | x | x | x | +| phi2 | x | x | x | +| Phi-3-mini-4k-instruct | √ | √ | √ | +| plamo-13b | √ | √ | √ | +| pythia-70M | x | x | x | +| Qwen-7B | √ | √ | √ | +| Qwen2-1.5B-Instruct | √ | x | √ | +| Refact-1_6B-fim | √ | √ | √ | +| SmolLM-135M | √ | √ | √ | +| stablelm-zephyr | x | x | x | +| stablelm-2-zephyr-1_6b | x | x | x | +| starcoderbase-1b | √ | √ | √ | +| starcoder2-3b | √ | √ | √ | +| vigogne-7b-chat | √ | √ | √ | +| xverse-7b-chat | √ | √ | √ | +| Yi-6b-Chat | √ | √ | √ | + + + +## DataType Supports + +| DataType | Status | +|:----------------------:|:-------:| +| FP16 | Support | +| Q8_0 | Support | +| Q4_0 | Support | + +## Docker + +### Build Images +You can get a image with llama.cpp in one command. +```sh +docker build -t llama-cpp-cann -f .devops/llama-cli-cann.Dockerfile . +``` + +### Run container + +```sh +# Find all cards. +npu-smi info + +# Select the cards that you want to use, make sure these cards are not used by someone. +# Following using cards of device0. +docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/dcmi:/usr/local/dcmi -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info -v /PATH_TO_YOUR_MODELS/:/app/models -it llama-cpp-cann -m /app/models/MODEL_PATH -ngl 32 -p "Building a website can be done in 10 simple steps:" +``` + +*Notes:* + +- You may need to install Ascend Driver and firmware on the **host** machine *(Please refer to the [Linux configuration](#linux) for details)*. + +## Linux + +### I. Setup Environment + +1. **Install Ascend Driver and firmware** + + ```sh + # create driver running user. + sudo groupadd -g HwHiAiUser + sudo useradd -g HwHiAiUser -d /home/HwHiAiUser -m HwHiAiUser -s /bin/bash + sudo usermod -aG HwHiAiUser $USER + + # download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system + # and install driver. + sudo sh Ascend-hdk-910b-npu-driver_x.x.x_linux-{arch}.run --full --install-for-all + ``` + + Once installed, run `npu-smi info` to check whether driver is installed successfully. + ```sh + +-------------------------------------------------------------------------------------------+ + | npu-smi 24.1.rc2 Version: 24.1.rc2 | + +----------------------+---------------+----------------------------------------------------+ + | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page)| + | Chip | Bus-Id | AICore(%) Memory-Usage(MB) HBM-Usage(MB) | + +======================+===============+====================================================+ + | 2 xxx | OK | 64.4 51 15 / 15 | + | 0 | 0000:01:00.0 | 0 1873 / 15077 0 / 32768 | + +======================+===============+====================================================+ + | 5 xxx | OK | 64.0 52 15 / 15 | + | 0 | 0000:81:00.0 | 0 1874 / 15077 0 / 32768 | + +======================+===============+====================================================+ + | No running processes found in NPU 2 | + +======================+===============+====================================================+ + | No running processes found in NPU 5 | + +======================+===============+====================================================+ + ``` + +2. **Install Ascend Firmware** + ```sh + # download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system + # and install driver. + sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full + ``` + If the following messaage appers, firmware is installed successfully. + ```sh + Firmware package installed successfully! + ``` + + +3. **Install CANN toolkit and kernels** + + CANN toolkit and kernels can be obtained from the official [CANN Toolkit](https://www.hiascend.com/zh/developer/download/community/result?module=cann) page. + + Please download the corresponding version that satified your system. The minimum version required is 8.0.RC2.alpha002 and here is the install command. + ```sh + pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions + sh Ascend-cann-toolkit_8.0.RC2.alpha002_linux-aarch64.run --install + sh Ascend-cann-kernels-910b_8.0.RC2.alpha002_linux.run --install + ``` + + Set Ascend Variables: + ```sh + echo "source ~/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc + source ~/.bashrc + ``` + +Upon a successful installation, CANN is enabled for the available ascend devices. + +### II. Build llama.cpp + +```sh +cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release +cmake --build build --config release +``` + +### III. Run the inference + +1. **Retrieve and prepare model** + + You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration. + + **Notes**: + + - CANN backend only supports FP16/Q4_0/Q8_0 models currently. + +2. **Launch inference** + + There are two device selection modes: + + - Single device: Use one device target specified by the user. + - Multiple devices: Automatically choose the devices with the same backend. + + | Device selection | Parameter | + |:----------------:|:--------------------------------------:| + | Single device | --split-mode none --main-gpu DEVICE_ID | + | Multiple devices | --split-mode layer (default) | + + Examples: + + - Use device 0: + + ```sh + ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0 + ``` + + - Use multiple devices: + + ```sh + ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer + ``` + +### **GitHub contribution**: +Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. + + +## TODO +- Support more models and data types. diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 59a39fbb67395..e838b2be6b11c 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -20,7 +20,7 @@ **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. -- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL - Math Kernel Library)*. +- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*. - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs. - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. @@ -28,10 +28,6 @@ The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*). -When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneMKL](README.md#intel-onemkl) backend. - -It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose. - ## Recommended Release The SYCL backend would be broken by some PRs due to no online CI. @@ -45,6 +41,10 @@ The following release is verified with good quality: ## News + +- 2024.8 + - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. + - 2024.5 - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. - Arch Linux is verified successfully. @@ -196,7 +196,7 @@ Please follow the instructions for downloading and installing the Toolkit for Li Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable. -Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI MKL for intel GPUs. +Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs. - **Adding support to Nvidia GPUs** @@ -255,8 +255,6 @@ or # Export relevant ENV variables source /opt/intel/oneapi/setvars.sh -# Build LLAMA with MKL BLAS acceleration for intel GPU - # Option 1: Use FP32 (recommended for better performance in most cases) cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx diff --git a/docs/build.md b/docs/build.md index 8b16d1a358518..152d46d6f31af 100644 --- a/docs/build.md +++ b/docs/build.md @@ -352,6 +352,31 @@ cmake --build build --config Release # ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32 ``` +### CANN +This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU. + +For more information about Ascend NPU in [Ascend Community](https://www.hiascend.com/en/). + +Make sure to have the CANN toolkit installed. You can download it from here: [CANN Toolkit](https://www.hiascend.com/developer/download/community/result?module=cann) + +Go to `llama.cpp` directory and build using CMake. +```bash +cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release +cmake --build build --config release +``` + +You can test with: + +`./build/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -ngl 32` + +If the fllowing info is output on screen, you are using `llama.cpp by CANN backend`: +```bash +llm_load_tensors: CANN buffer size = 13313.00 MiB +llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB +``` + +For detailed info, such as model/device supports, CANN install, please refer to [llama.cpp for CANN](./backend/CANN.md). + ### Android To read documentation for how to build on Android, [click here](./android.md) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index a12e90d828275..8fa492571aa44 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -271,7 +271,7 @@ struct tokenized_prompt { size_t max_seq_len; tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index ef35ba2c03942..5e89988e2beda 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -127,7 +127,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } static bool run(llama_context * ctx, const gpt_params & params) { - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); diff --git a/examples/export-lora/README.md b/examples/export-lora/README.md index 91c33c34acaa9..7dce99c9a9e61 100644 --- a/examples/export-lora/README.md +++ b/examples/export-lora/README.md @@ -17,9 +17,9 @@ For example: ```bash ./bin/llama-export-lora \ - -m open-llama-3b-v2-q8_0.gguf \ - -o open-llama-3b-v2-q8_0-english2tokipona-chat.gguf \ - --lora lora-open-llama-3b-v2-q8_0-english2tokipona-chat-LATEST.gguf + -m open-llama-3b-v2.gguf \ + -o open-llama-3b-v2-english2tokipona-chat.gguf \ + --lora lora-open-llama-3b-v2-english2tokipona-chat-LATEST.gguf ``` Multiple LORA adapters can be applied by passing multiple `--lora FNAME` or `--lora-scaled FNAME S` command line parameters: diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 3176d6e26ef8b..c7e5ca78845ee 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -10,6 +10,12 @@ static bool g_verbose = false; +struct tensor_transformation { + struct ggml_tensor * in; + struct ggml_tensor * out; + bool is_copy; +}; + static std::string get_kv_str(struct gguf_context * ctx_gguf, const std::string & key){ int id = gguf_find_key(ctx_gguf, key.c_str()); return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id)); @@ -198,8 +204,7 @@ struct lora_merge_ctx { } // mapping base tensor to out tensor (same shape with base, but different type) - // if out_tensor == nullptr, we only copy it - std::vector> base_to_out_tensors; + std::vector trans; for (auto & it : base_model.tensors) { bool t_a = true; bool t_b = true; @@ -212,14 +217,22 @@ struct lora_merge_ctx { // only copy struct ggml_tensor * cpy_tensor = ggml_dup_tensor(ctx_out_ggml, base_tensor); ggml_set_name(cpy_tensor, base_tensor->name); - base_to_out_tensors.push_back(std::make_pair(cpy_tensor, nullptr)); + trans.push_back({ + cpy_tensor, + cpy_tensor, + true, + }); gguf_add_tensor(ctx_out, cpy_tensor); } else if (t_a && t_b) { // need merging struct ggml_tensor * out_tensor = ggml_new_tensor( ctx_out_ggml, get_out_tensor_type(base_tensor), GGML_MAX_DIMS, base_tensor->ne); ggml_set_name(out_tensor, base_tensor->name); - base_to_out_tensors.push_back(std::make_pair(base_tensor, out_tensor)); + trans.push_back({ + base_tensor, + out_tensor, + false, + }); gguf_add_tensor(ctx_out, out_tensor); } else { throw std::runtime_error("tensor " + it.first + " missing either lora_a or lora_b"); @@ -234,12 +247,12 @@ struct lora_merge_ctx { // process base model tensors size_t n_merged = 0; - for (auto & it : base_to_out_tensors) { - if (it.second != nullptr) { - merge_tensor(it.first, it.second); + for (auto & it : trans) { + if (!it.is_copy) { + merge_tensor(it.in, it.out); n_merged++; } else { - copy_tensor(it.first); + copy_tensor(it.in); } } @@ -252,7 +265,7 @@ struct lora_merge_ctx { } printf("%s : merged %ld tensors with lora adapters\n", __func__, n_merged); - printf("%s : wrote %ld tensors to output file\n", __func__, base_to_out_tensors.size()); + printf("%s : wrote %ld tensors to output file\n", __func__, trans.size()); } void copy_tensor(struct ggml_tensor * base) { @@ -285,6 +298,10 @@ struct lora_merge_ctx { for (size_t i = 0; i < adapters.size(); ++i) { auto t_a = adapters[i]->get_tensor(name_lora_a); auto t_b = adapters[i]->get_tensor(name_lora_b); + // TODO: add support for quantized lora + if (ggml_is_quantized(t_a->type) || ggml_is_quantized(t_b->type)) { + throw std::runtime_error("quantized LoRA adapters is not supported, please retry with f16 or f32"); + } inp_a[i] = ggml_dup_tensor(ctx, t_a); inp_b[i] = ggml_dup_tensor(ctx, t_b); } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 58814b96e7d49..83b85d72b043a 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -433,8 +433,8 @@ static void process_logits( } static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const int n_ctx = llama_n_ctx(ctx); auto tim1 = std::chrono::high_resolution_clock::now(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 92d630b15fdf1..05700c1d591d9 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -203,8 +203,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str()); } - const bool add_bos = llama_should_add_bos_token(model); - GGML_ASSERT(llama_add_eos_token(model) != 1); + const bool add_bos = llama_add_bos_token(model); + GGML_ASSERT(!llama_add_eos_token(model)); LOG("add_bos: %d\n", add_bos); std::vector embd_inp; diff --git a/examples/llava/README-minicpmv2.5.md b/examples/llava/README-minicpmv2.5.md index 4affc1d0f26ff..1c8498ff9e151 100644 --- a/examples/llava/README-minicpmv2.5.md +++ b/examples/llava/README-minicpmv2.5.md @@ -15,9 +15,9 @@ cd llama.cpp Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us) ```bash -python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 -python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 -python ./convert-hf-to-gguf.py ../MiniCPM-Llama3-V-2_5/model +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2 +python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model # quantize int4 version ./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M diff --git a/examples/llava/README-minicpmv2.6.md b/examples/llava/README-minicpmv2.6.md new file mode 100644 index 0000000000000..c4be5e5dd6484 --- /dev/null +++ b/examples/llava/README-minicpmv2.6.md @@ -0,0 +1,107 @@ +## MiniCPM-V 2.6 + +### Prepare models and code + +Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from huggingface to "MiniCPM-V-2_6" folder. + +Clone llama.cpp: +```bash +git clone git@github.com:OpenBMB/llama.cpp.git +cd llama.cpp +git checkout minicpmv-main +``` + +### Usage of MiniCPM-V 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us) + +```bash +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3 +python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model + +# quantize int4 version +./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + +Build for Linux or Mac + +```bash +make +make llama-minicpmv-cli +``` + +Inference on Linux or Mac +``` +# run f16 version +./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run quantized int4 version +./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# or run in interactive mode +./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` + +### Video +Install FFmpeg +``` +brew install ffmpeg +brew install pkg-config +``` + +### Android + +#### Build on Android device using Termux +We found that build on Android device would bring better runtime performance, so we recommend to build on device. + +[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required). + +Install tools in Termux: +``` +apt update && apt upgrade -y +apt install git make cmake +``` + +It's recommended to move your model inside the `~/` directory for best performance: +``` +cd storage/downloads +mv model.gguf ~/ +``` + +#### Building the Project using Android NDK +Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake. + +Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux: + +```bash +mkdir build-android +cd build-android +export NDK=/your_ndk_path +cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod .. +make +``` + +Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice). + +Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission: + +(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`) +``` +$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/ +$cd /data/data/com.termux/files/home/bin +$chmod +x ./* +``` + +Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/` + +``` +$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/ +$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/ +``` + +Now, you can start chatting: +``` +$cd /data/data/com.termux/files/home/bin +$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 54aa822c90d29..10e8765b4cd19 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -20,6 +20,10 @@ #include "ggml-cann.h" #endif +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -81,6 +85,7 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_USE_GELU "clip.use_gelu" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" @@ -211,13 +216,19 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); - } + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { @@ -526,6 +537,7 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + int minicpmv_version = 2; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -641,7 +653,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if (ctx->has_minicpmv_projector) { int pos_w = image_size_width/patch_size; int pos_h = image_size_height/patch_size; - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + if (ctx->minicpmv_version == 2) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + } + else if (ctx->minicpmv_version == 3) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -768,8 +785,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - - } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + } + else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -949,10 +966,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } { // attention - const int hidden_size = 4096; + int hidden_size = 4096; const int d_head = 128; - const int n_head = hidden_size/d_head; - const int num_query = 96; + int n_head = hidden_size/d_head; + int num_query = 96; + if (ctx->minicpmv_version == 2) { + hidden_size = 4096; + n_head = hidden_size/d_head; + num_query = 96; + } + else if (ctx->minicpmv_version == 3) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -1091,7 +1118,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } } - clip_ctx * new_clip = new clip_ctx; + clip_ctx * new_clip = new clip_ctx{}; // update projector type { @@ -1125,6 +1152,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_TEE("%s: CLIP using CANN backend\n", __func__); #endif +#ifdef GGML_USE_VULKAN + new_clip->backend = ggml_backend_vk_init(0); + LOG_TEE("%s: CLIP using Vulkan backend\n", __func__); +#endif if (!new_clip->backend) { new_clip->backend = ggml_backend_cpu_init(); @@ -1149,6 +1180,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx); } + idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION); + if (idx != -1) { + new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); + } + // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search GGML_ASSERT(new_clip->has_vision_encoder); @@ -1910,10 +1946,12 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { - if (clip_is_minicpmv(ctx)) { - std::vector> imgs = uhd_slice_image(img); + + if(clip_is_minicpmv(ctx)){ + int max_slice_nums = 9; + std::vector> imgs = uhd_slice_image(img, max_slice_nums); res_imgs->size = 0; - for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t i = 0; i < imgs.size(); ++i){ res_imgs->size += imgs[i].size(); } res_imgs->data = new clip_image_f32[res_imgs->size]; @@ -2146,7 +2184,12 @@ int clip_n_patches(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { n_patches /= 4; } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - n_patches = 96; + if (ctx->minicpmv_version == 2) { + n_patches = 96; + } + else if (ctx->minicpmv_version == 3) { + n_patches = 64; + } } return n_patches; @@ -2282,6 +2325,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + if(ctx->load_image_size==nullptr){ + ctx->load_image_size= clip_image_size_init(); + } + const int pos_w = ctx->load_image_size->width/patch_size; + const int pos_h = ctx->load_image_size->height/patch_size; { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); @@ -2316,8 +2364,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - for (int i = 0; i < num_positions; i++) { - positions_data[i] = std::floor(70.0*i/num_positions); + int bucket_coords_h[70]; + int bucket_coords_w[70]; + for (int i = 0; i < pos_h; i++){ + bucket_coords_h[i] = std::floor(70.0*i/pos_h); + } + for (int i = 0; i < pos_w; i++){ + bucket_coords_w[i] = std::floor(70.0*i/pos_w); + } + for (int i = 0, id = 0; i < pos_h; i++){ + for (int j = 0; j < pos_w; j++){ + positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + } } ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); @@ -2328,12 +2386,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/Qwen/Qwen-VL/tree/main // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); - if(ctx->load_image_size==nullptr){ - ctx->load_image_size= clip_image_size_init(); - } - int pos_w = ctx->load_image_size->width/patch_size; - int pos_h = ctx->load_image_size->height/patch_size; int embed_dim = 4096; + if (ctx->minicpmv_version == 2) { + embed_dim = 4096; + } + else if (ctx->minicpmv_version == 3) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2346,7 +2405,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed)); free(pos_embed_data); } - } else { + } + else{ { if (ctx->has_class_embedding) { struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); @@ -2548,13 +2608,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->vision_model.mm_3_b->ne[0]; } if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - return 4096; + if (ctx->minicpmv_version == 2) { + return 4096; + } + else if (ctx->minicpmv_version == 3) { + return 3584; + } } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); } -bool clip_is_minicpmv(const struct clip_ctx * ctx) { - return ctx->has_minicpmv_projector; +int clip_is_minicpmv(const struct clip_ctx * ctx) { + if (ctx->has_minicpmv_projector) { + return ctx->minicpmv_version; + } + return 0; } diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 2ff4d39929dc3..78588bdf1745c 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); -CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); #ifdef __cplusplus } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 916d9dc401dc4..851af0f004a69 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -256,7 +256,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli load_image_size->width = img_res_v.data[i].nx; load_image_size->height = img_res_v.data[i].ny; clip_add_load_image_size(ctx_clip, load_image_size); - const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + bool encoded = false; + int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); + if (has_minicpmv_projector == 2) { + encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + } + else if (has_minicpmv_projector == 3) { + encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); + } if (!encoded) { LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); return false; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index f951b57b29158..379fc295f1101 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e std::string system_prompt; int idx = 0; int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip); - system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); + if (has_minicpmv_projector == 2) { + system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + } + else if (has_minicpmv_projector == 3) { + system_prompt = "<|im_start|>user\n"; + } LOG_TEE("%s: image token past: %d\n", __func__, n_past); eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); @@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; - if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; + int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); + if (!is_first) { + if (has_minicpmv_projector == 2) { + user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; + } + else if (has_minicpmv_projector == 3) { + user_prompt = "<|im_start|>user\n" + prompt; + } + } eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); - eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + if (has_minicpmv_projector == 2) { + eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + } + else if (has_minicpmv_projector == 3) { + eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); + } + // generate the response LOG_TEE("\n"); diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py index 12cdd1281d2ff..ea773742a832b 100644 --- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -1,9 +1,416 @@ -import argparse +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Siglip model. """ +# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes + + import os +import math +import warnings + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import ( + logging, +) +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +): + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + denom = fan_in + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.self_attn = ( + SiglipAttention(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SiglipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + +class SiglipVisionTransformer(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + +import argparse import json import re -import torch import numpy as np from gguf import * from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig @@ -94,6 +501,7 @@ def bytes_to_unicode(): default_image_std = [0.26862954, 0.26130258, 0.27577711] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) # with proper args = ap.parse_args() @@ -135,6 +543,15 @@ def bytes_to_unicode(): # model = CLIPModel.from_pretrained(dir_model) # processor = CLIPProcessor.from_pretrained(dir_model) +minicpmv_version = args.minicpmv_version +emb_dim = 4096 +if minicpmv_version == 1: + emb_dim = 2304 +elif minicpmv_version == 2: + emb_dim = 4096 +elif minicpmv_version == 3: + emb_dim = 3584 + default_vision_config = { "hidden_size": 1152, "image_size": 980, @@ -144,8 +561,12 @@ def bytes_to_unicode(): "num_hidden_layers": 27, "patch_size": 14, } + vision_config = Idefics2VisionConfig(**default_vision_config) model = Idefics2VisionTransformer(vision_config) +if minicpmv_version == 3: + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -158,6 +579,7 @@ def bytes_to_unicode(): has_text_encoder = True has_vision_encoder = True has_minicpmv_projector = False + if args.text_only: fname_middle = "text-" has_vision_encoder = False @@ -165,6 +587,7 @@ def bytes_to_unicode(): fname_middle = "mmproj-" has_text_encoder = False has_minicpmv_projector = True + minicpmv_version = 3 elif args.vision_only: fname_middle = "vision-" has_text_encoder = False @@ -189,6 +612,7 @@ def bytes_to_unicode(): fout.add_description("image encoder for MiniCPM-V") # add projector type fout.add_string("clip.projector_type", "resampler") + fout.add_int32("clip.minicpmv_version", minicpmv_version) else: fout.add_description("two-tower CLIP model") @@ -274,11 +698,11 @@ def _replace_name_resampler(s, v): if re.match("resampler.pos_embed", s): return { s: v, - re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), + re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))), } if re.match("resampler.proj", s): return { - re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), + re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))), re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(), } if re.match("resampler.attn.in_proj_.*", s): diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py index 2b6bce7cfebe9..748ff5c57824e 100644 --- a/examples/llava/minicpmv-surgery.py +++ b/examples/llava/minicpmv-surgery.py @@ -4,7 +4,7 @@ from transformers import AutoModel, AutoTokenizer ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model") +ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") args = ap.parse_args() # find the model part that includes the the multimodal projector weights @@ -29,7 +29,6 @@ f.write("{}\n") config = model.llm.config -config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5" config.auto_map = { "AutoConfig": "configuration_minicpm.MiniCPMConfig", "AutoModel": "modeling_minicpm.MiniCPMModel", @@ -40,7 +39,6 @@ model.llm.save_pretrained(f"{args.model}/model") tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) tok.save_pretrained(f"{args.model}/model") -# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py") print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") diff --git a/examples/llava/requirements.txt b/examples/llava/requirements.txt index dfe5fbe62cea6..cbcbf26c9b4e9 100644 --- a/examples/llava/requirements.txt +++ b/examples/llava/requirements.txt @@ -2,4 +2,4 @@ --extra-index-url https://download.pytorch.org/whl/cpu pillow~=10.2.0 torch~=2.2.1 -torchvision==0.17.1 +torchvision~=0.17.1 diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 82b22062dbfa0..96eb45660c55d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -269,9 +269,9 @@ int main(int argc, char ** argv) { } } - const bool add_bos = llama_should_add_bos_token(model); + const bool add_bos = llama_add_bos_token(model); if (!llama_model_has_encoder(model)) { - GGML_ASSERT(llama_add_eos_token(model) != 1); + GGML_ASSERT(!llama_add_eos_token(model)); } LOG("add_bos: %d\n", add_bos); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 372684f092de2..484dd589109c7 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -340,8 +340,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); @@ -480,8 +480,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); std::ofstream logits_stream; if (!params.logits_file.empty()) { @@ -1733,8 +1733,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { const int n_batch = params.n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int nv = 2*((n_vocab + 1)/2) + 4; - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); diff --git a/examples/quantize/README.md b/examples/quantize/README.md index 553c2701bced3..5d1e11c67b13f 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -34,7 +34,7 @@ Run the quantized model: ```bash # start inference on a gguf model -./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -n 128 +./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant" ``` When running the larger models, make sure you have enough disk space to store all the intermediate files. diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7312309aeef98..2023463100c8b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -104,7 +104,7 @@ static void usage(const char * executable) { printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); - printf(" --keep-split: will generate quatized model in the same shards as input"); + printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); printf("Note: --include-weights and --exclude-weights cannot be used together\n"); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 65b19ce71cbe3..aab9d81058af9 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -253,6 +253,8 @@ int main(int argc, char ** argv) { chunks[i].tokens.clear(); } + struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); + // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; while (true) { @@ -260,7 +262,6 @@ int main(int argc, char ** argv) { std::getline(std::cin, query); std::vector query_tokens = llama_tokenize(ctx, query, true); - struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); batch_add_seq(query_batch, query_tokens, 0); std::vector query_emb(n_embd, 0); @@ -293,6 +294,7 @@ int main(int argc, char ** argv) { } // clean up + llama_batch_free(query_batch); llama_print_timings(ctx); llama_free(ctx); llama_free_model(model); diff --git a/examples/server/README.md b/examples/server/README.md index e17595fe87f25..805e05b4a5114 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -247,6 +247,51 @@ logging: --log-append Don't truncate the old log file. ``` +Available environment variables (if specified, these variables will override parameters specified in arguments): + +- `LLAMA_CACHE`: cache directory, used by `--hf-repo` +- `HF_TOKEN`: Hugging Face access token, used when accessing a gated model with `--hf-repo` +- `LLAMA_ARG_MODEL`: equivalent to `-m` +- `LLAMA_ARG_MODEL_URL`: equivalent to `-mu` +- `LLAMA_ARG_MODEL_ALIAS`: equivalent to `-a` +- `LLAMA_ARG_HF_REPO`: equivalent to `--hf-repo` +- `LLAMA_ARG_HF_FILE`: equivalent to `--hf-file` +- `LLAMA_ARG_THREADS`: equivalent to `-t` +- `LLAMA_ARG_CTX_SIZE`: equivalent to `-c` +- `LLAMA_ARG_N_PARALLEL`: equivalent to `-np` +- `LLAMA_ARG_BATCH`: equivalent to `-b` +- `LLAMA_ARG_UBATCH`: equivalent to `-ub` +- `LLAMA_ARG_N_GPU_LAYERS`: equivalent to `-ngl` +- `LLAMA_ARG_THREADS_HTTP`: equivalent to `--threads-http` +- `LLAMA_ARG_CHAT_TEMPLATE`: equivalent to `--chat-template` +- `LLAMA_ARG_N_PREDICT`: equivalent to `-n` +- `LLAMA_ARG_ENDPOINT_METRICS`: if set to `1`, it will enable metrics endpoint (equivalent to `--metrics`) +- `LLAMA_ARG_ENDPOINT_SLOTS`: if set to `0`, it will **disable** slots endpoint (equivalent to `--no-slots`). This feature is enabled by default. +- `LLAMA_ARG_EMBEDDINGS`: if set to `1`, it will enable embeddings endpoint (equivalent to `--embeddings`) +- `LLAMA_ARG_FLASH_ATTN`: if set to `1`, it will enable flash attention (equivalent to `-fa`) +- `LLAMA_ARG_CONT_BATCHING`: if set to `0`, it will **disable** continuous batching (equivalent to `--no-cont-batching`). This feature is enabled by default. +- `LLAMA_ARG_DEFRAG_THOLD`: equivalent to `-dt` +- `LLAMA_ARG_HOST`: equivalent to `--host` +- `LLAMA_ARG_PORT`: equivalent to `--port` + +Example usage of docker compose with environment variables: + +```yml +services: + llamacpp-server: + image: ghcr.io/ggerganov/llama.cpp:server + ports: + - 8080:8080 + volumes: + - ./models:/models + environment: + # alternatively, you can use "LLAMA_ARG_MODEL_URL" to download the model + LLAMA_ARG_MODEL: /models/my_model.gguf + LLAMA_ARG_CTX_SIZE: 4096 + LLAMA_ARG_N_PARALLEL: 2 + LLAMA_ARG_ENDPOINT_METRICS: 1 # to disable, either remove or set to 0 + LLAMA_ARG_PORT: 8080 +``` ## Build @@ -368,15 +413,16 @@ node index.js ## API Endpoints -### GET `/health`: Returns the current state of the server +### GET `/health`: Returns heath check result - - 503 -> `{"status": "loading model"}` if the model is still being loaded. - - 500 -> `{"status": "error"}` if the model failed to load. - - 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below. - - 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available. - - 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available. +**Response format** - If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set. +- HTTP status code 503 + - Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}` + - Explanation: the model is still being loaded. +- HTTP status code 200 + - Body: `{"status": "ok" }` + - Explanation: the model is successfully loaded and the server is ready. ### POST `/completion`: Given a `prompt`, it returns the predicted completion. @@ -639,10 +685,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte }' ``` -### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`. +### GET `/slots`: Returns the current slots processing state + +This endpoint can be disabled with `--no-slots` + +If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots. **Response format** +Example: + ```json [ { @@ -702,7 +754,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte ] ``` -### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled: +Possible values for `slot[i].state` are: +- `0`: SLOT_STATE_IDLE +- `1`: SLOT_STATE_PROCESSING + +### GET `/metrics`: Prometheus compatible metrics exporter + +This endpoint is only accessible if `--metrics` is set. Available metrics: - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. @@ -767,6 +825,10 @@ Available metrics: ### GET `/lora-adapters`: Get list of all LoRA adapters +This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` + +By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` + If an adapter is disabled, the scale will be set to 0. **Response format** diff --git a/examples/server/public/index.js b/examples/server/public/index.js index 6709609399e4a..fe615ca25cd67 100644 --- a/examples/server/public/index.js +++ b/examples/server/public/index.js @@ -1 +1 @@ -const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function s(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function f(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}f.prototype.brand=t;f.prototype.h=function(){return!0};f.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};f.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};f.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};f.prototype.valueOf=function(){return this.value};f.prototype.toString=function(){return this.value+""};f.prototype.toJSON=function(){return this.value};f.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(f.prototype,"value",{get(){const t=s(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new f(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){f.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new f).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}f.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){f.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=s(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,D,T,M={},F=[],A=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,V=Array.isArray;function W(t,n){for(var e in n)t[e]=n[e];return t}function L(t){var n=t.parentNode;n&&n.removeChild(t)}function O(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return R(t,r,_,i,null)}function R(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function I(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,s,f){var c,h,a,p,d,v=_&&_.__k||F,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?R(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i)?(i.__=t,i.__b=t.__b+1,u=Z(i,e,r,f),i.__i=u,o=null,-1!==u&&(f--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u===r+1?c++:u>r?f>l-r?c+=u-r:c--:u(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),R(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+T++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=[],(_={})[n]=this,this.getChildContext=function(){return _},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.some((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.push(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e.splice(e.indexOf(t),1),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=F.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=W({},this.state),"function"==typeof t&&(t=t(W({},e),this.props)),t&&W(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),D=et(!0),T=0;var at,pt,dt,vt,yt=0,mt=[],gt=[],bt=S,kt=bt.__b,wt=bt.__r,St=bt.diffed,xt=bt.__c,Ct=bt.unmount,Ut=bt.__;function Et(t,n){bt.__h&&bt.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({__V:gt}),e.__[t]}function Ht(t){return yt=1,Pt(zt,t)}function Pt(t,n,e){var _=Et(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):zt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Nt(t,n){var e=Et(at++,3);!bt.__s&&Bt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function $t(t,n){var e=Et(at++,4);!bt.__s&&Bt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function Dt(t){return yt=5,Mt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,$t((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Mt(t,n){var e=Et(at++,7);return Bt(e.__H,n)?(e.__V=t(),e.i=n,e.__h=t,e.__V):e.__}function Ft(t,n){return yt=8,Mt((function(){return t}),n)}function At(t){var n=pt.context[t.__c],e=Et(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function Vt(t,n){bt.useDebugValue&&bt.useDebugValue(n?n(t):t)}function Wt(t){var n=Et(at++,10),e=Ht();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Et(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(jt),t.__H.__h.forEach(qt),t.__H.__h=[]}catch(n){t.__H.__h=[],bt.__e(n,t.__v)}}bt.__b=function(t){pt=null,kt&&kt(t)},bt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ut&&Ut(t,n)},bt.__r=function(t){wt&&wt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.__V=gt,t.__N=t.i=void 0}))):(n.__h.forEach(jt),n.__h.forEach(qt),n.__h=[],at=0)),dt=pt},bt.diffed=function(t){St&&St(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===bt.requestAnimationFrame||((vt=bt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.__V!==gt&&(t.__=t.__V),t.i=void 0,t.__V=gt}))),dt=pt=null},bt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(jt),t.__h=t.__h.filter((function(t){return!t.__||qt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],bt.__e(r,t.__v)}})),xt&&xt(t,n)},bt.unmount=function(t){Ct&&Ct(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{jt(t)}catch(t){n=t}})),e.__H=void 0,n&&bt.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function jt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function qt(t){var n=pt;t.__c=t.__(),pt=n}function Bt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function zt(t,n){return"function"==typeof n?n(t):n}function Gt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Jt,Kt;function Qt(t){if(Kt)Kt();Kt=t&&t.S()}function Xt({data:t}){const n=Zt(t);n.value=t;const e=Mt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Xt.displayName="_st";Object.defineProperties(f.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Xt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});Gt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof f){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});Gt("__r",(t,n)=>{Qt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Jt=_;Qt(e);t(n)});Gt("__e",(t,n,e,_)=>{Qt();Jt=void 0;t(n,e,_)});Gt("diffed",(t,n)=>{Qt();Jt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Yt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Yt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}Gt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});Gt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Zt(t){return Mt(()=>c(t),[])}function tn(t){const n=Dt(t);n.current=t;Jt.__$f|=4;return Mt(()=>v(()=>n.current()),[])}function nn(t){const n=Dt(t);n.current=t;Nt(()=>k(()=>n.current()),[])}var en=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][s+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var rn=on.bind(O);export{q as Component,j as Fragment,f as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,O as createElement,I as createRef,k as effect,O as h,rn as html,ft as hydrate,C as isValidElement,S as options,st as render,c as signal,Y as toChildArray,o as untracked,Ft as useCallback,tn as useComputed,At as useContext,Vt as useDebugValue,Nt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,$t as useLayoutEffect,Mt as useMemo,Pt as useReducer,Dt as useRef,Zt as useSignal,nn as useSignalEffect,Ht as useState}; +const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function f(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function s(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}s.prototype.brand=t;s.prototype.h=function(){return!0};s.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};s.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};s.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};s.prototype.valueOf=function(){return this.value};s.prototype.toString=function(){return this.value+""};s.prototype.toJSON=function(){return this.value};s.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(s.prototype,"value",{get(){const t=f(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new s(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){s.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new s).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}s.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){s.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=f(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,T,D,M={},F=[],A=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,W=Array.isArray;function L(t,n){for(var e in n)t[e]=n[e];return t}function O(t){var n=t.parentNode;n&&n.removeChild(t)}function R(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return I(t,r,_,i,null)}function I(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function V(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,f,s){var c,h,a,p,d,v=_&&_.__k||F,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?I(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i)?(i.__=t,i.__b=t.__b+1,u=Z(i,e,r,s),i.__i=u,o=null,-1!==u&&(s--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u==r-1?c--:u==r+1?c++:u>r?s>l-r?c+=u-r:c--:u(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),I(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+D++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=[],(_={})[n]=this,this.getChildContext=function(){return _},this.componentWillUnmount=function(){e=null},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.some((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.push(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e&&e.splice(e.indexOf(t),1),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=F.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=L({},this.state),"function"==typeof t&&(t=t(L({},e),this.props)),t&&L(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),T=et(!0),D=0;var at,pt,dt,vt,yt=0,mt=[],gt=S,bt=gt.__b,kt=gt.__r,wt=gt.diffed,St=gt.__c,xt=gt.unmount,Ct=gt.__;function Ut(t,n){gt.__h&>.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({}),e.__[t]}function Et(t){return yt=1,Ht(Bt,t)}function Ht(t,n,e){var _=Ut(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):Bt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Pt(t,n){var e=Ut(at++,3);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function Nt(t,n){var e=Ut(at++,4);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function $t(t){return yt=5,Dt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,Nt((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Dt(t,n){var e=Ut(at++,7);return qt(e.__H,n)&&(e.__=t(),e.__H=n,e.__h=t),e.__}function Mt(t,n){return yt=8,Dt((function(){return t}),n)}function Ft(t){var n=pt.context[t.__c],e=Ut(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function At(t,n){gt.useDebugValue&>.useDebugValue(n?n(t):t)}function Wt(t){var n=Ut(at++,10),e=Et();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Ut(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(Vt),t.__H.__h.forEach(jt),t.__H.__h=[]}catch(n){t.__H.__h=[],gt.__e(n,t.__v)}}gt.__b=function(t){pt=null,bt&&bt(t)},gt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ct&&Ct(t,n)},gt.__r=function(t){kt&&kt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.i=t.__N=void 0}))):(n.__h.forEach(Vt),n.__h.forEach(jt),n.__h=[],at=0)),dt=pt},gt.diffed=function(t){wt&&wt(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===gt.requestAnimationFrame||((vt=gt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.i=void 0}))),dt=pt=null},gt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(Vt),t.__h=t.__h.filter((function(t){return!t.__||jt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],gt.__e(r,t.__v)}})),St&&St(t,n)},gt.unmount=function(t){xt&&xt(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{Vt(t)}catch(t){n=t}})),e.__H=void 0,n&>.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function Vt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function jt(t){var n=pt;t.__c=t.__(),pt=n}function qt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function Bt(t,n){return"function"==typeof n?n(t):n}function zt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Gt,Jt;function Kt(t){if(Jt)Jt();Jt=t&&t.S()}function Qt({data:t}){const n=Yt(t);n.value=t;const e=Dt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Qt.displayName="_st";Object.defineProperties(s.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Qt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});zt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof s){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});zt("__r",(t,n)=>{Kt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Gt=_;Kt(e);t(n)});zt("__e",(t,n,e,_)=>{Kt();Gt=void 0;t(n,e,_)});zt("diffed",(t,n)=>{Kt();Gt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Xt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Xt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}zt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});zt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Yt(t){return Dt(()=>c(t),[])}function Zt(t){const n=$t(t);n.current=t;Gt.__$f|=4;return Dt(()=>v(()=>n.current()),[])}function tn(t){const n=$t(t);n.current=t;Pt(()=>k(()=>n.current()),[])}var nn=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][f+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var on=_n.bind(R);export{q as Component,j as Fragment,s as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,R as createElement,V as createRef,k as effect,R as h,on as html,st as hydrate,C as isValidElement,S as options,ft as render,c as signal,Y as toChildArray,o as untracked,Mt as useCallback,Zt as useComputed,Ft as useContext,At as useDebugValue,Pt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,Nt as useLayoutEffect,Dt as useMemo,Ht as useReducer,$t as useRef,Yt as useSignal,tn as useSignalEffect,Et as useState}; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 360f571e42867..e79e7aa2cb846 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -15,6 +15,8 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" // auto generated files (update with ./deps.sh) #include "colorthemes.css.hpp" @@ -67,7 +69,6 @@ enum slot_command { enum server_state { SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed }; enum server_task_type { @@ -631,6 +632,7 @@ struct server_context { bool clean_kv_cache = true; bool add_bos_token = true; + bool has_eos_token = false; int32_t n_ctx; // total context for all clients / slots @@ -692,8 +694,8 @@ struct server_context { n_ctx = llama_n_ctx(ctx); - add_bos_token = llama_should_add_bos_token(model); - GGML_ASSERT(llama_add_eos_token(model) != 1); + add_bos_token = llama_add_bos_token(model); + has_eos_token = !llama_add_eos_token(model); return true; } @@ -753,13 +755,13 @@ struct server_context { default_generation_settings_for_props = get_formated_generation(slots.front()); default_generation_settings_for_props["seed"] = -1; - // the update_slots() logic will always submit a maximum of n_batch tokens + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch = llama_batch_init(n_batch, 0, 1); + batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1); } metrics.init(); @@ -1031,7 +1033,7 @@ struct server_context { { slot.sparams.logit_bias.clear(); - if (json_value(data, "ignore_eos", false)) { + if (json_value(data, "ignore_eos", false) && has_eos_token) { slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; } @@ -1136,28 +1138,19 @@ struct server_context { if (!system_prompt.empty()) { system_tokens = ::llama_tokenize(ctx, system_prompt, true); - llama_batch_clear(batch); + const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_tokens_prompt = system_tokens.size(); - for (int i = 0; i < (int)system_tokens.size(); ++i) { - llama_batch_add(batch, system_tokens[i], i, { 0 }, false); - } + for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i); - const int32_t n_batch = llama_n_batch(ctx); + llama_batch_clear(batch); - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; + for (int32_t j = 0; j < n_tokens; ++j) { + llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); + } - if (llama_decode(ctx, batch_view) != 0) { + if (llama_decode(ctx, batch) != 0) { LOG_ERROR("llama_decode() failed", {}); return; } @@ -1330,7 +1323,7 @@ struct server_context { return json { {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, + {"n_predict", slot.n_predict}, // Server configured n_predict {"model", params.model_alias}, {"seed", slot.sparams.seed}, {"temperature", slot.sparams.temp}, @@ -1352,7 +1345,7 @@ struct server_context { {"mirostat_eta", slot.sparams.mirostat_eta}, {"penalize_nl", slot.sparams.penalize_nl}, {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict + {"max_tokens", slot.params.n_predict}, // User configured n_predict {"n_keep", slot.params.n_keep}, {"n_discard", slot.params.n_discard}, {"ignore_eos", ignore_eos}, @@ -1860,6 +1853,8 @@ struct server_context { llama_lora_adapters_apply(ctx, lora_adapters); server_task_result result; result.id = task.id; + result.stop = true; + result.error = false; result.data = json{{ "success", true }}; queue_results.send(result); } break; @@ -2044,7 +2039,7 @@ struct server_context { slot.t_start_generation = 0; if (slot.infill) { - const bool add_bos = llama_should_add_bos_token(model); + const bool add_bos = llama_add_bos_token(model); bool suff_rm_leading_spc = true; if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { params.input_suffix.erase(0, 1); @@ -2512,6 +2507,9 @@ int main(int argc, char ** argv) { return 1; } + // parse arguments from environment variables + gpt_params_parse_from_env(params); + // TODO: not great to use extern vars server_log_json = params.log_json; server_verbose = params.verbosity > 0; @@ -2562,19 +2560,19 @@ int main(int argc, char ** argv) { svr->set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) { + // Access-Control-Allow-Origin is already set by middleware res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); - return res.set_content("", "application/json; charset=utf-8"); + return res.set_content("", "text/html"); // blank response, no data }); svr->set_logger(log_server_request); auto res_error = [](httplib::Response & res, json error_data) { json final_response {{"error", error_data}}; - res.set_content(final_response.dump(), "application/json; charset=utf-8"); + res.set_content(final_response.dump(), MIMETYPE_JSON); res.status = json_value(error_data, "code", 500); }; @@ -2604,11 +2602,6 @@ int main(int argc, char ** argv) { svr->set_read_timeout (params.timeout_read); svr->set_write_timeout(params.timeout_write); - if (!svr->bind_to_port(params.hostname, params.port)) { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); - return 1; - } - std::unordered_map log_data; log_data["hostname"] = params.hostname; @@ -2624,35 +2617,6 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - // load the model - if (!ctx_server.load_model(params)) { - state.store(SERVER_STATE_ERROR); - return 1; - } else { - ctx_server.init(); - state.store(SERVER_STATE_READY); - } - - LOG_INFO("model loaded", {}); - - const auto model_meta = ctx_server.model_meta(); - - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) { - if (!ctx_server.validate_model_chat_template()) { - LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - params.chat_template = "chatml"; - } - } - - // print sample chat example to make it clear which template is used - { - LOG_INFO("chat template", { - {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, - }); - } - // // Middlewares // @@ -2696,8 +2660,6 @@ int main(int argc, char ** argv) { } // API key is invalid or not provided - // TODO: make another middleware for CORS related logic - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); LOG_WARNING("Unauthorized: Invalid API Key", {}); @@ -2705,8 +2667,21 @@ int main(int argc, char ** argv) { return false; }; + auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_LOADING_MODEL) { + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + return false; + } + return true; + }; + // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } if (!middleware_validate_api_key(req, res)) { return httplib::Server::HandlerResponse::Handled; } @@ -2717,62 +2692,15 @@ int main(int argc, char ** argv) { // Route handlers (or controllers) // - const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - switch (current_state) { - case SERVER_STATE_READY: - { - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - task.id_target = -1; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(task); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - const int n_idle_slots = result.data.at("idle"); - const int n_processing_slots = result.data.at("processing"); - - json health = { - {"status", "ok"}, - {"slots_idle", n_idle_slots}, - {"slots_processing", n_processing_slots} - }; - - res.status = 200; // HTTP OK - if (params.endpoint_slots && req.has_param("include_slots")) { - health["slots"] = result.data.at("slots"); - } - - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable - } - } - - res.set_content(health.dump(), "application/json"); - break; - } - case SERVER_STATE_LOADING_MODEL: - { - res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } break; - case SERVER_STATE_ERROR: - { - res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); - } break; - } + const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + // error and loading states are handled by middleware + json health = {{"status", "ok"}}; + res.set_content(health.dump(), "application/json"); }; - const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { + const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2790,13 +2718,22 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); - res.set_content(result.data.at("slots").dump(), "application/json"); + // optionally return "fail_on_no_slot" error + const int n_idle_slots = result.data.at("idle"); + if (req.has_param("fail_on_no_slot")) { + if (n_idle_slots == 0) { + res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return; + } + } + + res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { if (!params.endpoint_metrics) { - res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2921,7 +2858,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; @@ -2951,7 +2888,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; @@ -2971,13 +2908,11 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - std::string id_slot_str = req.path_params.at("id_slot"); int id_slot; @@ -3001,7 +2936,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -3010,7 +2945,6 @@ int main(int argc, char ** argv) { curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); } } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, @@ -3018,7 +2952,7 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() } }; - res.set_content(data.dump(), "application/json; charset=utf-8"); + res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { @@ -3027,8 +2961,6 @@ int main(int argc, char ** argv) { return; } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = json::parse(req.body); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3039,7 +2971,7 @@ int main(int argc, char ** argv) { if (!json_value(data, "stream", false)) { server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3102,9 +3034,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - + const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, {"data", { @@ -3113,12 +3043,12 @@ int main(int argc, char ** argv) { {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", model_meta} + {"meta", ctx_server.model_meta()} }, }} }; - res.set_content(models.dump(), "application/json; charset=utf-8"); + res.set_content(models.dump(), MIMETYPE_JSON); }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { @@ -3126,8 +3056,6 @@ int main(int argc, char ** argv) { res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3142,7 +3070,7 @@ int main(int argc, char ** argv) { if (!result.error && result.stop) { json result_oai = format_final_response_oaicompat(data, result.data, completion_id); - res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3204,8 +3132,6 @@ int main(int argc, char ** argv) { return; } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = json::parse(req.body); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3216,7 +3142,7 @@ int main(int argc, char ** argv) { if (!json_value(data, "stream", false)) { server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3264,7 +3190,6 @@ int main(int argc, char ** argv) { }; const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; @@ -3273,11 +3198,10 @@ int main(int argc, char ** argv) { tokens = ctx_server.tokenize(body.at("content"), add_special); } const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); + return res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; @@ -3287,12 +3211,10 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); + return res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); bool is_openai = false; @@ -3338,11 +3260,10 @@ int main(int argc, char ** argv) { json root = is_openai ? format_embeddings_response_oaicompat(body, responses) : responses[0]; - return res.set_content(root.dump(), "application/json; charset=utf-8"); + return res.set_content(root.dump(), MIMETYPE_JSON); }; - const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { auto & la = ctx_server.lora_adapters[i]; @@ -3352,13 +3273,11 @@ int main(int argc, char ** argv) { {"scale", la.scale}, }); } - res.set_content(result.dump(), "application/json"); + res.set_content(result.dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const std::vector body = json::parse(req.body); int max_idx = ctx_server.lora_adapters.size(); @@ -3386,7 +3305,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; @@ -3462,35 +3381,75 @@ int main(int argc, char ** argv) { log_data["n_threads_http"] = std::to_string(params.n_threads_http); svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - LOG_INFO("HTTP server listening", log_data); + // clean up function, to be called before exit + auto clean_up = [&svr]() { + svr->stop(); + llama_backend_free(); + }; + + // bind HTTP listen port, run the HTTP server in a thread + if (!svr->bind_to_port(params.hostname, params.port)) { + LOG_ERROR("couldn't bind HTTP server socket", { + {"hostname", params.hostname}, + {"port", params.port}, + }); + clean_up(); + LOG_ERROR("exiting due to HTTP server error", {}); + return 1; + } + std::thread t([&]() { svr->listen_after_bind(); }); + svr->wait_until_ready(); + + LOG_INFO("HTTP server is listening", log_data); - // run the HTTP server in a thread - see comment below - std::thread t([&]() { - if (!svr->listen_after_bind()) { - state.store(SERVER_STATE_ERROR); - return 1; + // load the model + LOG_INFO("loading model", log_data); + if (!ctx_server.load_model(params)) { + clean_up(); + t.join(); + LOG_ERROR("exiting due to model loading error", {}); + return 1; + } else { + ctx_server.init(); + state.store(SERVER_STATE_READY); + + LOG_INFO("model loaded", {}); + + // if a custom chat template is not supplied, we will use the one that comes with the model (if any) + if (params.chat_template.empty()) { + if (!ctx_server.validate_model_chat_template()) { + LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); + params.chat_template = "chatml"; + } } - return 0; - }); + // print sample chat example to make it clear which template is used + { + LOG_INFO("chat template", { + {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, + {"built_in", params.chat_template.empty()}, + }); + } - ctx_server.queue_tasks.on_new_task(std::bind( - &server_context::process_single_task, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_finish_multitask(std::bind( - &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, &ctx_server)); - ctx_server.queue_results.on_multitask_update(std::bind( - &server_queue::update_multitask, - &ctx_server.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); - - shutdown_handler = [&](int) { - ctx_server.queue_tasks.terminate(); - }; + ctx_server.queue_tasks.on_new_task(std::bind( + &server_context::process_single_task, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_finish_multitask(std::bind( + &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_update_slots(std::bind( + &server_context::update_slots, &ctx_server)); + ctx_server.queue_results.on_multitask_update(std::bind( + &server_queue::update_multitask, + &ctx_server.queue_tasks, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); + + shutdown_handler = [&](int) { + ctx_server.queue_tasks.terminate(); + }; + ctx_server.queue_tasks.start_loop(); + } #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -3506,12 +3465,8 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - ctx_server.queue_tasks.start_loop(); - - svr->stop(); + clean_up(); t.join(); - llama_backend_free(); - return 0; } diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 6705a34fc4696..1ba7b60b69c46 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -205,27 +205,20 @@ def step_start_server(context): async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): match expecting_status: case 'healthy': - await wait_for_health_status(context, context.base_url, 200, 'ok', - timeout=30) + await wait_for_slots_status(context, context.base_url, 200, + timeout=30) case 'ready' | 'idle': - await wait_for_health_status(context, context.base_url, 200, 'ok', - timeout=30, - params={'fail_on_no_slot': 0, 'include_slots': 0}, - slots_idle=context.n_slots, - slots_processing=0, - expected_slots=[{'id': slot_id, 'state': 0} - for slot_id in - range(context.n_slots if context.n_slots else 1)]) + await wait_for_slots_status(context, context.base_url, 200, + timeout=30, + params={'fail_on_no_slot': 1}, + slots_idle=context.n_slots, + slots_processing=0) case 'busy': - await wait_for_health_status(context, context.base_url, 503, - 'no slot available', - params={'fail_on_no_slot': 0, 'include_slots': 0}, - slots_idle=0, - slots_processing=context.n_slots, - expected_slots=[{'id': slot_id, 'state': 1} - for slot_id in - range(context.n_slots if context.n_slots else 1)]) + await wait_for_slots_status(context, context.base_url, 503, + params={'fail_on_no_slot': 1}, + slots_idle=0, + slots_processing=context.n_slots) case _: assert False, "unknown status" @@ -1187,17 +1180,15 @@ async def gather_tasks_results(context): return n_completions -async def wait_for_health_status(context, - base_url, - expected_http_status_code, - expected_health_status, - timeout=3, - params=None, - slots_idle=None, - slots_processing=None, - expected_slots=None): +async def wait_for_slots_status(context, + base_url, + expected_http_status_code, + timeout=3, + params=None, + slots_idle=None, + slots_processing=None): if context.debug: - print(f"Starting checking for health for expected_health_status={expected_health_status}") + print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}") interval = 0.5 counter = 0 if 'GITHUB_ACTIONS' in os.environ: @@ -1205,26 +1196,19 @@ async def wait_for_health_status(context, async with aiohttp.ClientSession() as session: while True: - async with await session.get(f'{base_url}/health', params=params) as health_response: - status_code = health_response.status - health = await health_response.json() + async with await session.get(f'{base_url}/slots', params=params) as slots_response: + status_code = slots_response.status + slots = await slots_response.json() if context.debug: - print(f"HEALTH - response for expected health status='{expected_health_status}' on " - f"'{base_url}/health'?{params} is {health}\n") - if (status_code == expected_http_status_code - and health['status'] == expected_health_status - and (slots_idle is None or health['slots_idle'] == slots_idle) - and (slots_processing is None or health['slots_processing'] == slots_processing)): - if expected_slots is not None: - assert_slots_status(health['slots'], expected_slots) - return - if (status_code == expected_http_status_code - and health['status'] == expected_health_status - and (slots_idle is None or health['slots_idle'] == slots_idle) - and (slots_processing is None or health['slots_processing'] == slots_processing)): - if expected_slots is not None: - assert_slots_status(health['slots'], expected_slots) + print(f"slots responses {slots}\n") + if status_code == 503 and status_code == expected_http_status_code: return + if status_code == 200 and status_code == expected_http_status_code: + n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots) + n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots) + if ((slots_idle is None or slots_idle == n_slots_idle) + and (slots_processing is None or slots_processing == n_slots_processing)): + return await asyncio.sleep(interval) counter += interval @@ -1238,7 +1222,7 @@ async def wait_for_health_status(context, if n_completions > 0: return - assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}' + assert False, f'slots check timeout exceeded {counter}s>={timeout}' def assert_embeddings(embeddings): diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index 17f5e496153a7..c817be566cf54 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -362,7 +362,7 @@ int main(int raw_argc, char ** raw_argv) { prompt = stdin_buffer.str(); } - const bool model_wants_add_bos = llama_should_add_bos_token(model); + const bool model_wants_add_bos = llama_add_bos_token(model); const bool add_bos = model_wants_add_bos && !no_bos; const bool parse_special = !no_parse_special; diff --git a/flake.lock b/flake.lock index f9e1548a2aca5..8e6c3e467947a 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1723175592, - "narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=", + "lastModified": 1723637854, + "narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5e0ca22929f3342b19569b21b2f3462f053e497b", + "rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9", "type": "github" }, "original": { diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 7fe1661bb96b4..cc16858849783 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -129,13 +129,13 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF) -option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF) option(GGML_HIPBLAS "ggml: use hipBLAS" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF) option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF) +option(GGML_VULKAN_PERF "ggml: enable Vulkan perf output" OFF) option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_KOMPUTE "ggml: use Kompute" OFF) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 5f3f1e286990e..e73b9a7452fed 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -63,6 +63,7 @@ extern "C" { GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // "offset" refers to the offset of the tensor data for setting/getting data GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 15602a96df7ad..b11d047aeda7d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -220,7 +220,7 @@ #include #define GGML_FILE_MAGIC 0x67676d6c // "ggml" -#define GGML_FILE_VERSION 1 +#define GGML_FILE_VERSION 2 #define GGML_QNT_VERSION 2 // bump this on quantization format changes #define GGML_QNT_VERSION_FACTOR 1000 // do not change this @@ -244,6 +244,8 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +#define GGML_ROPE_TYPE_NEOX 2 + #define GGUF_MAGIC "GGUF" #define GGUF_VERSION 3 @@ -451,6 +453,8 @@ extern "C" { GGML_OP_SQR, GGML_OP_SQRT, GGML_OP_LOG, + GGML_OP_SIN, + GGML_OP_COS, GGML_OP_SUM, GGML_OP_SUM_ROWS, GGML_OP_MEAN, @@ -488,9 +492,11 @@ extern "C" { GGML_OP_CLAMP, GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, + GGML_OP_IM2COL_BACK, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, + GGML_OP_POOL_2D_BACK, GGML_OP_UPSCALE, // nearest interpolate GGML_OP_PAD, GGML_OP_ARANGE, @@ -967,6 +973,22 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sin( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sin_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_cos( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_cos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // return scalar GGML_API struct ggml_tensor * ggml_sum( struct ggml_context * ctx, @@ -1453,8 +1475,8 @@ extern "C" { struct ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED) - // if mode & 2 == 1, GPT-NeoX style + // if (mode & 1) - skip n_past elements (NOT SUPPORTED) + // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style // // b is an int32 vector with size a->ne[2], it contains the positions GGML_API struct ggml_tensor * ggml_rope( @@ -1564,34 +1586,49 @@ extern "C" { float min, float max); + // im2col + // converts data into a format that effectively results in a convolution when combined with matrix multiplication GGML_API struct ggml_tensor * ggml_im2col( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum ggml_type dst_type); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D, + enum ggml_type dst_type); + + GGML_API struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // gradient of im2col output + int64_t * ne, // shape of im2col input + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D); GGML_API struct ggml_tensor * ggml_conv_depthwise_2d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data int s0, // stride int p0, // padding int d0); // dilation @@ -1600,29 +1637,29 @@ extern "C" { // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) GGML_API struct ggml_tensor* ggml_conv_1d_ph( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s, - int d); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s, // stride + int d); // dilation GGML_API struct ggml_tensor * ggml_conv_transpose_1d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation GGML_API struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 // kernel size is a->ne[0] x a->ne[1] @@ -1684,6 +1721,18 @@ extern "C" { float p0, float p1); + GGML_API struct ggml_tensor * ggml_pool_2d_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * af, // "a"/input used in forward pass + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + // nearest interpolate // multiplies ne0 and ne1 by scale factor // used in stable-diffusion @@ -1758,7 +1807,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias); + float max_bias, + float logit_softcap); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, @@ -1775,10 +1825,8 @@ extern "C" { GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq); + struct ggml_tensor * sx, + struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, @@ -1787,8 +1835,7 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq); + struct ggml_tensor * C); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index b7e149c604969..abf04a97fa0ac 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -549,6 +549,13 @@ if (GGML_SYCL) file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp") list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") + find_package(DNNL) + message("-- DNNL found:" ${DNNL_FOUND}) + if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND}) + else() + add_compile_definitions(GGML_SYCL_DNNL=0) + endif() if (WIN32) find_package(IntelSYCL REQUIRED) find_package(MKL REQUIRED) @@ -561,6 +568,9 @@ if (GGML_SYCL) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl) endif() endif() + if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + list(APPEND GGML_EXTRA_LIBS DNNL::dnnl) + endif() endif() if (GGML_RPC) @@ -602,6 +612,10 @@ if (GGML_VULKAN) add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) endif() + if (GGML_VULKAN_PERF) + add_compile_definitions(GGML_VULKAN_PERF) + endif() + if (GGML_VULKAN_VALIDATE) add_compile_definitions(GGML_VULKAN_VALIDATE) endif() diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 9e6bf26e3fa43..3c6652b8619a6 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -337,33 +337,18 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds } size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); } size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); } size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); } void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index e1651cc645c42..8856967c91104 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -1018,10 +1018,6 @@ static bool ggml_is_view_op(enum ggml_op op) { #define GGML_SCHED_MAX_BACKENDS 16 #endif -#ifndef GGML_SCHED_MAX_SPLITS -#define GGML_SCHED_MAX_SPLITS 2048 -#endif - #ifndef GGML_SCHED_MAX_SPLIT_INPUTS #define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC #endif @@ -1125,7 +1121,8 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co } #if 0 -static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only +#define GGML_SCHED_MAX_SPLITS_DEBUG 4096 +static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) #define GET_CAUSE(node) causes[hash_id(node)] #else @@ -1549,7 +1546,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); GGML_ASSERT(sched->splits != NULL); } - GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS); split = &sched->splits[i_split]; split->backend_id = node_backend_id; split->i_start = i; @@ -1865,13 +1861,14 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); - const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; + const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph + const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2; sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); - sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); + sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); sched->context_buffer = malloc(sched->context_buffer_size); const int initial_splits_capacity = 16; diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 8c4132f5bb7ad..a4ec8418e2ab3 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2881,7 +2881,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; // init cos/sin cache ggml_cann_pool_alloc sin_allocator( diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 682c30d45bcf4..8a844b02a27a5 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -9,8 +9,10 @@ #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" +#include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/convert.cuh" #include "ggml-cuda/cpy.cuh" +#include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" @@ -29,7 +31,6 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" -#include "ggml-cuda/conv-transpose-1d.cuh" #include #include @@ -2181,6 +2182,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ADD: ggml_cuda_op_add(ctx, dst); break; + case GGML_OP_SUB: + ggml_cuda_op_sub(ctx, dst); + break; case GGML_OP_ACC: ggml_cuda_op_acc(ctx, dst); break; @@ -2267,6 +2271,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SQRT: ggml_cuda_op_sqrt(ctx, dst); break; + case GGML_OP_SIN: + ggml_cuda_op_sin(ctx, dst); + break; + case GGML_OP_COS: + ggml_cuda_op_cos(ctx, dst); + break; case GGML_OP_CLAMP: ggml_cuda_op_clamp(ctx, dst); break; @@ -2303,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; + case GGML_OP_CROSS_ENTROPY_LOSS: + ggml_cuda_cross_entropy_loss(ctx, dst); + break; default: return false; } @@ -2610,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); } } @@ -2853,12 +2867,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: @@ -2890,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + case GGML_OP_CROSS_ENTROPY_LOSS: + return true; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 34bc67acdd890..e1390a0414559 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) { return a + b; } +static __device__ __forceinline__ float op_sub(const float a, const float b) { + return a - b; +} + static __device__ __forceinline__ float op_mul(const float a, const float b) { return a * b; } @@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} + void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 4f63d6372eb50..198c9ef6fd8ea 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -2,5 +2,6 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu new file mode 100644 index 0000000000000..a14043e70451a --- /dev/null +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -0,0 +1,106 @@ +#include "common.cuh" +#include "cross-entropy-loss.cuh" +#include "sumrows.cuh" + +#include +#include + +static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; + + const int ne_tmp = WARP_SIZE*nclasses; + + extern __shared__ float tmp_all[]; + float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; + float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; + + // Each warp first loads ne_tmp logits/labels into shared memory: + for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { + const int ig = i0*nclasses + i; // ig == i global + + tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; + tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; + } + + // Each thread in the warp then calculates the cross entropy loss for a single row. + // TODO: pad in order to avoid shared memory bank conflicts. + + // Find maximum for softmax: + float max = -INFINITY; + for (int i = 0; i < nclasses; ++i) { + max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + } + + // Calculate log(softmax(logits)) which is just logits - max: + float sum = 0.0f; + for (int i = 0; i < nclasses; ++i) { + float val = tmp_logits[lane_id*nclasses + i] - max; + sum += expf(val); + tmp_logits[lane_id*nclasses + i] = val; + } + sum = logf(sum); + + // log(exp(logits - max) / sum) = (logits - max) - log(sum) + float loss = 0.0f; + for (int i = 0; i < nclasses; ++i) { + loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + } + loss = -warp_reduce_sum(loss) / (float)k; + + __syncthreads(); + + if (lane_id == 0) { + tmp_all[warp_id] = loss; + } + + __syncthreads(); + + if (warp_id != 0) { + return; + } + + loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; + loss = warp_reduce_sum(loss); + + if (lane_id != 0) { + return; + } + + dst[blockIdx.x] = loss; +} + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); + const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); + const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + + ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); + + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + + // Combine results from individual blocks: + sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream); +} diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/ggml/src/ggml-cuda/cross-entropy-loss.cuh new file mode 100644 index 0000000000000..9d7b8b0f0082b --- /dev/null +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256 + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 950fd93dfe1ee..1fb5c09c3b179 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -657,11 +658,17 @@ void launch_fattn( const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale = 1.0f; - float max_bias = 0.0f; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } const uint32_t n_head = Q->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -675,7 +682,7 @@ void launch_fattn( V_data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 1b2fd500b746c..342f2eb665312 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; - half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + half sum; + if (use_logit_softcap) { + const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + sum = logit_softcap * tanhf(tmp.x + tmp.y); + } else { + sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); @@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index f3e68dbfa6a6a..827437ca0ad1f 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; + if (use_logit_softcap) { + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); @@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32( } } -template +template void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 02a4ad072bbd6..448a9a9054cca 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); @@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { @@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * KQV = dst; - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 11a5e355fd40b..bf5125902534e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); @@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32( } } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index ae23222426097..b10d19d9327c9 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -6,7 +6,7 @@ #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. @@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16( const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } } float KQ_max_new = KQ_max_f[j0/nwarps]; @@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + } } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; @@ -427,7 +449,8 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); template void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; @@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 29f608b0ff98d..f87f33b3e574b 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; if (precision != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 99ec1dd98ca9c..88f586d689cfd 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -226,7 +226,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const int32_t * pos = (const int32_t *) src1_d; diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 82e8e875f9be3..38dbf1b5e1fa9 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc } } -static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); @@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); - const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index e7545f83c496b..191db1c13167e 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -1,3 +1,5 @@ #include "common.cuh" +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); + void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index f9e208011e2a8..89abfc21d8a56 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) { dst[i] = sqrtf(x[i]); } +static __global__ void sin_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = sinf(x[i]); +} + +static __global__ void cos_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = cosf(x[i]); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_ sqrt_f32<<>>(x, dst, k); } +static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE; + sin_f32<<>>(x, dst, k); +} + +static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE; + cos_f32<<>>(x, dst, k); +} + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 4cfb0479e7169..c610e996abeb6 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -9,6 +9,8 @@ #define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_SQRT_BLOCK_SIZE 256 +#define CUDA_SIN_BLOCK_SIZE 256 +#define CUDA_COS_BLOCK_SIZE 256 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index aad189430ab0b..91b5e61b23ead 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -31,6 +31,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD_ROW, + GGML_METAL_KERNEL_TYPE_SUB, + GGML_METAL_KERNEL_TYPE_SUB_ROW, GGML_METAL_KERNEL_TYPE_MUL, GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_DIV, @@ -82,6 +84,8 @@ GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -205,6 +209,9 @@ GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, + GGML_METAL_KERNEL_TYPE_SQRT, + GGML_METAL_KERNEL_TYPE_SIN, + GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_COUNT @@ -491,6 +498,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); @@ -542,6 +551,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -665,6 +676,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } @@ -765,15 +779,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_OP_PERMUTE: case GGML_OP_CONCAT: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: + return true; case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + return ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: - return true; case GGML_OP_SOFT_MAX: case GGML_OP_RMS_NORM: case GGML_OP_GROUP_NORM: @@ -803,6 +822,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx return false; } return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && @@ -1050,6 +1072,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: { @@ -1073,6 +1096,7 @@ static enum ggml_status ggml_metal_graph_compute( nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; default: GGML_ABORT("fatal error"); @@ -1082,6 +1106,7 @@ static enum ggml_status ggml_metal_graph_compute( } else { switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; default: GGML_ABORT("fatal error"); @@ -1409,6 +1434,48 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQRT: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SIN: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_COS: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_SUM_ROWS: @@ -1538,6 +1605,121 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } } break; + case GGML_OP_SSM_CONV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SSM_SCAN: + { + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + struct ggml_tensor * src4 = gf->nodes[i]->src[4]; + struct ggml_tensor * src5 = gf->nodes[i]->src[5]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; + id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); + + const uint64_t nb30 = src3->nb[0]; + const uint64_t nb31 = src3->nb[1]; + + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + + const uint64_t nb40 = src4->nb[0]; + const uint64_t nb41 = src4->nb[1]; + const uint64_t nb42 = src4->nb[2]; + + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + + const uint64_t nb50 = src5->nb[0]; + const uint64_t nb51 = src5->nb[1]; + const uint64_t nb52 = src5->nb[2]; + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_seq_tokens = ne11; + const int64_t n_seqs = ne02; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_MUL_MAT: { GGML_ASSERT(ne00 == ne10); @@ -2313,7 +2495,7 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; id pipeline = nil; @@ -2624,9 +2806,14 @@ static enum ggml_status ggml_metal_graph_compute( float scale; float max_bias; + float logit_softcap; + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -2677,30 +2864,31 @@ static enum ggml_status ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3bb37d32aced0..f323ab5f447d5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -17,7 +17,7 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, multiplication and division of two tensors +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( @@ -70,6 +70,56 @@ kernel void kernel_add( } } +kernel void kernel_sub( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + } +} + kernel void kernel_mul( device const char * src0, device const char * src1, @@ -226,6 +276,15 @@ kernel void kernel_add_row( dst[tpig] = src0[tpig] + src1[tpig % nb]; } +kernel void kernel_sub_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, @@ -358,6 +417,27 @@ kernel void kernel_sqr( dst[tpig] = src0[tpig] * src0[tpig]; } +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, @@ -667,6 +747,127 @@ kernel void kernel_diag_mask_inf_8( } } +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +// TODO: optimize +kernel void kernel_ssm_conv_f32( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = ne10; + const int64_t ncs = ne00; + const int64_t nr = ne01; + const int64_t n_t = ne1; + const int64_t n_s = ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); + device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// TODO: optimize +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); + device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + kernel void kernel_norm( device const void * src0, device float * dst, @@ -1976,6 +2177,7 @@ typedef void (flash_attn_ext_f16_t)( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2014,6 +2216,7 @@ kernel void kernel_flash_attn_ext_f16( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2138,19 +2341,6 @@ kernel void kernel_flash_attn_ext_f16( } simdgroup_store(mqk, ss + 8*cc, TF, 0, false); - - const short tx = tiisg%4; - const short ty = tiisg/4; - - if (mask != q) { - // mqk = mqk*scale + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } else { - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; - } } } @@ -2162,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16( float ms[Q]; for (short j = 0; j < Q; ++j) { - const short p = tiisg; - const float m = M[j]; - const float s = ss[j*TF + p]; + + // scale and apply the logitcap / mask + float s = ss[j*TF + tiisg]*scale; + + if (logit_softcap != 0.0f) { + s = logit_softcap*precise::tanh(s); + } + + if (mask != q) { + // mqk = mqk + mask*slope + s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; + } smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); @@ -2176,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; + ss[j*TF + tiisg] = vs; } // create a QxQ diagonal matrix for rescaling the output @@ -2345,6 +2544,7 @@ kernel void kernel_flash_attn_ext_vec_f16( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2479,7 +2679,13 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask*slope if (tiisg == 0) { - mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + mqk *= scale; + + if (logit_softcap != 0.0f) { + mqk = logit_softcap*precise::tanh(mqk); + } + + mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; ss4[cc] = mqk; } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 1158444191dd5..8bc55fbad6e67 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3651,7 +3651,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q8_K_ref(x, y, k); } -//===================================== Dot ptoducts ================================= +//===================================== Dot products ================================= // // Helper functions diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp index 7757615f5a24b..8f9d0a4601969 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc.cpp @@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of // RPC commands enum rpc_cmd { - ALLOC_BUFFER = 0, - GET_ALIGNMENT, - GET_MAX_SIZE, - BUFFER_GET_BASE, - FREE_BUFFER, - BUFFER_CLEAR, - SET_TENSOR, - GET_TENSOR, - COPY_TENSOR, - GRAPH_COMPUTE, - GET_DEVICE_MEMORY, + RPC_CMD_ALLOC_BUFFER = 0, + RPC_CMD_GET_ALIGNMENT, + RPC_CMD_GET_MAX_SIZE, + RPC_CMD_BUFFER_GET_BASE, + RPC_CMD_FREE_BUFFER, + RPC_CMD_BUFFER_CLEAR, + RPC_CMD_SET_TENSOR, + RPC_CMD_GET_TENSOR, + RPC_CMD_COPY_TENSOR, + RPC_CMD_GRAPH_COMPUTE, + RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_COUNT, }; // RPC data structures @@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t uint64_t remote_ptr = ctx->remote_ptr; memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.empty()); delete ctx; @@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b uint64_t remote_ptr = ctx->remote_ptr; memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | base_ptr (8 bytes) | @@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); std::vector output; - bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output); GGML_ASSERT(status); } @@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == size); // output serialization format: | data (size bytes) | @@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b memcpy(input.data(), &rpc_src, sizeof(rpc_src)); memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output); GGML_ASSERT(status); // output serialization format: | result (1 byte) | GGML_ASSERT(output.size() == 1); @@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr)); memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output); GGML_ASSERT(status); } @@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer memcpy(input.data(), &size, sizeof(size)); std::vector output; auto sock = get_socket(buft_ctx->endpoint); - bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | @@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr & sock) { // input serialization format: | 0 bytes | std::vector input; std::vector output; - bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | alignment (8 bytes) | @@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr & sock) { // input serialization format: | 0 bytes | std::vector input; std::vector output; - bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | max_size (8 bytes) | @@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t serialize_graph(cgraph, input); std::vector output; auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 1); return (enum ggml_status)output[0]; @@ -636,7 +637,7 @@ GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const } GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { + if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { return false; } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; @@ -678,6 +679,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const } auto sock = get_socket(endpoint); if (sock == nullptr) { + fprintf(stderr, "Failed to connect to %s\n", endpoint); return nullptr; } size_t alignment = get_alignment(sock); @@ -719,7 +721,7 @@ static void get_device_memory(const std::shared_ptr & sock, size_t * f // input serialization format: | 0 bytes | std::vector input; std::vector output; - bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | free (8 bytes) | total (8 bytes) | @@ -1098,59 +1100,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre if (!recv_data(sockfd, &cmd, 1)) { break; } + if (cmd >= RPC_CMD_COUNT) { + // fail fast if the command is invalid + fprintf(stderr, "Unknown command: %d\n", cmd); + break; + } std::vector input; std::vector output; uint64_t input_size; if (!recv_data(sockfd, &input_size, sizeof(input_size))) { break; } - input.resize(input_size); + try { + input.resize(input_size); + } catch (const std::bad_alloc & e) { + fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size); + break; + } if (!recv_data(sockfd, input.data(), input_size)) { break; } bool ok = true; switch (cmd) { - case ALLOC_BUFFER: { + case RPC_CMD_ALLOC_BUFFER: { ok = server.alloc_buffer(input, output); break; } - case GET_ALIGNMENT: { + case RPC_CMD_GET_ALIGNMENT: { server.get_alignment(output); break; } - case GET_MAX_SIZE: { + case RPC_CMD_GET_MAX_SIZE: { server.get_max_size(output); break; } - case BUFFER_GET_BASE: { + case RPC_CMD_BUFFER_GET_BASE: { ok = server.buffer_get_base(input, output); break; } - case FREE_BUFFER: { + case RPC_CMD_FREE_BUFFER: { ok = server.free_buffer(input); break; } - case BUFFER_CLEAR: { + case RPC_CMD_BUFFER_CLEAR: { ok = server.buffer_clear(input); break; } - case SET_TENSOR: { + case RPC_CMD_SET_TENSOR: { ok = server.set_tensor(input); break; } - case GET_TENSOR: { + case RPC_CMD_GET_TENSOR: { ok = server.get_tensor(input, output); break; } - case COPY_TENSOR: { + case RPC_CMD_COPY_TENSOR: { ok = server.copy_tensor(input, output); break; } - case GRAPH_COMPUTE: { + case RPC_CMD_GRAPH_COMPUTE: { ok = server.graph_compute(input, output); break; } - case GET_DEVICE_MEMORY: { + case RPC_CMD_GET_DEVICE_MEMORY: { // output serialization format: | free (8 bytes) | total (8 bytes) | output.resize(2*sizeof(uint64_t), 0); memcpy(output.data(), &free_mem, sizeof(free_mem)); @@ -1203,8 +1215,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free return; } printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + fflush(stdout); rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); printf("Client connection closed\n"); + fflush(stdout); } #ifdef _WIN32 WSACleanup(); diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index d8eb86c2c1862..0d884f89a4e7b 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -38,6 +38,7 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/presets.hpp" +#include "ggml-sycl/gemm.hpp" bool ggml_sycl_loaded(void); void ggml_sycl_free_data(struct ggml_tensor * tensor); @@ -893,43 +894,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -template -static void im2col_kernel(const float *x, T *dst, int offset_delta, - int IW, int IH, int OW, int KW, int KH, - int pelements, int CHW, int s0, int s1, int p0, - int p1, int d0, int d1, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); - if (i >= pelements) { - return; - } - - const int ksize = OW * (KH > 1 ? KW : 1); - const int kx = i / ksize; - const int kd = kx * ksize; - const int ky = (i - kd) / OW; - const int ix = i % OW; - - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1; - - const int64_t offset_dst = - (item_ct1.get_group(1) * OW + ix) * CHW + - (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = - sycl::vec(0.0f) - .convert()[0]; - } else { - const int64_t offset_src = item_ct1.get_group(0) * offset_delta; - dst[offset_dst] = - sycl::vec(x[offset_src + iih * IW + iiw]) - .convert()[0]; - } -} - template static void pool2d_nchw_kernel( const int ih, const int iw, const int oh, const int ow, @@ -1742,32 +1706,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, }); } -template -static void im2col_sycl(const float *x, T *dst, int IW, int IH, - int OW, int OH, int KW, int KH, int IC, - int offset_delta, int s0, int s1, int p0, - int p1, int d0, int d1, - queue_ptr stream) { - const int parallel_elements = OW * KW * KH; - const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; - sycl::range<3> block_nums(IC, OH, num_blocks); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * - sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH, - parallel_elements, (IC * KH * KW), s0, s1, p0, - p1, d0, d1, item_ct1); - }); - } -} - - static bool g_sycl_loaded = false; bool ggml_sycl_loaded(void) { @@ -2545,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; +#if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, @@ -2554,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl( dpct::library_data_t::real_half))); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); +#else + auto dnnl_stream = ctx.stream_dnnl(stream); + DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), + src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt()); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); +#endif } else { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); @@ -2576,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl( const float alpha = 1.0f; const float beta = 0.0f; - +#if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +#else + auto dnnl_stream = ctx.stream_dnnl(stream); + DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt(), + src0_ddf_i, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt()); +#endif } (void) dst; (void) src1_ddq_i; @@ -2636,47 +2587,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens (void) src1_dd; } -inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; - - const int64_t IC = src1->ne[is_2D ? 2 : 1]; - const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; - - const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; - - const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; - - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - - if (dst->type == GGML_TYPE_F16) { - im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); - } else { - im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); - } - - (void) src0; - (void) src0_dd; -} - inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -3581,7 +3491,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE + && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE); bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 58dd9c9a60e7d..d21b5f8dd2627 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -25,5 +25,6 @@ #include "norm.hpp" #include "softmax.hpp" #include "tsembd.hpp" +#include "im2col.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index e878f4f50f09e..cf5291b31fe91 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try { << ", line:" << __LINE__ << std::endl; std::exit(1); } + +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { + const int64_t max_range = std::numeric_limits::max(); + int64_t sycl_down_blk_size = block_size; + int64_t global_range = accumulate_block_num * sycl_down_blk_size; + while(global_range > max_range) { + sycl_down_blk_size /= 2; + global_range = accumulate_block_num * sycl_down_blk_size; + } + return sycl_down_blk_size; +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 86d8b40e8b013..05947ccb746f2 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -19,6 +19,10 @@ #include "dpct/helper.hpp" #include "ggml-sycl.h" #include "presets.hpp" +#if GGML_SYCL_DNNL +#include "dnnl.hpp" +#include "dnnl_sycl.hpp" +#endif #define GGML_COMMON_DECL_SYCL #define GGML_COMMON_IMPL_SYCL @@ -130,6 +134,7 @@ typedef sycl::float2 dfloat2; #endif // GGML_SYCL_F16 #define MMVQ_MAX_BATCH_SIZE 8 +#define MMVQ_MIN_BATCH_SIZE 4 static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -276,6 +281,52 @@ struct ggml_backend_sycl_context { return stream(device, 0); } +#if GGML_SYCL_DNNL + dnnl::engine make_engine(sycl::queue* q) { + // Get the device associated with the queue + sycl::device dev = q->get_device(); + // Get the context associated with the queue + sycl::context ctx = q->get_context(); + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); + return eng; + } + + std::unordered_map stream_map; + std::unordered_map engine_map; + dnnl::stream stream_dnnl(int device, int _stream) { + auto q = stream(device, _stream); + return stream_dnnl(q); + } + dnnl::engine engine_dnnl(sycl::queue* qptr) { + auto it = engine_map.find(qptr); + if (it == engine_map.end()) { + auto eng = make_engine(qptr); + engine_map[qptr] = eng; + return eng; + } + else + { + return it->second; + } + } + dnnl::stream stream_dnnl(sycl::queue* qptr) { + auto it = stream_map.find(qptr); + if (it == stream_map.end()) { + auto eng = engine_dnnl(qptr); + auto stream = dnnl::sycl_interop::make_stream(eng, *qptr); + stream_map[qptr] = stream; + return stream; + } + else + { + return it->second; + } + } + dnnl::stream stream_dnnl() { + return stream_dnnl(device, 0); + } +#endif + // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; @@ -352,4 +403,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { return acc.template get_multi_ptr().get(); } +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 39c28753cf85c..5fd15e6cdccab 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -3,19 +3,19 @@ #include "presets.hpp" template -static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, +static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { - const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) + + const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)); if (i >= k) { return; } - const int ib = i/qk; // block index - const int iqs = (i%qk)/qr; // quant index - const int iybs = i - i%qk; // y block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int64_t ib = i/qk; // block index + const int64_t iqs = (i%qk)/qr; // quant index + const int64_t iybs = i - i%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; @@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ template static void dequantize_block_sycl(const void *__restrict__ vx, - dst_t *__restrict__ y, const int k, + dst_t *__restrict__ y, const int64_t k, dpct::queue_ptr stream) { - const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); + const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx, } template -static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb32 = k / 32; - const int nb = (k + 255) / 256; + const int64_t nb32 = k / 32; + const int64_t nb = (k + 255) / 256; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb32 = k / 32; - const int nb = (k + 255) / 256; + const int64_t nb32 = k / 32; + const int64_t nb = (k + 255) / 256; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k, template -static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k, template -static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = (k + QK_K - 1) / QK_K; + const int64_t nb = (k + QK_K - 1) / QK_K; #if QK_K == 64 dequantize_row_iq4_nl_sycl(vx, y, k, stream); #else @@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = (k + QK_K - 1) / QK_K; + const int64_t nb = (k + QK_K - 1) / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k, } template -static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, +static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } + const int64_t work_group_size = item_ct1.get_local_range(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + // make each work-item deal with more elements since sycl global range can not exceed max int const src_t * x = (src_t *) vx; - - y[i] = x[i]; + for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) { + y[i] = x[i]; + } } template static void convert_unary_sycl(const void *__restrict__ vx, - dst_t *__restrict__ y, const int k, + dst_t *__restrict__ y, const int64_t k, dpct::queue_ptr stream) { - const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; + const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; + + // decrease global range when it exceeds the max int + int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE); + sycl::range<3> block_nums(1, 1, num_blocks); + sycl::range<3> local_range(1, 1, local_size); { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); stream->parallel_for( - sycl::nd_range<3>( - sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { convert_unary(vx, y, k, item_ct1); }); diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index b1f10d6355535..0ce2874aaaef9 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -17,7 +17,7 @@ template using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, - int k, dpct::queue_ptr stream); + int64_t k, dpct::queue_ptr stream); typedef to_t_sycl_t to_fp32_sycl_t; typedef to_t_sycl_t to_fp16_sycl_t; diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index ed8ad098bcb2f..8f4041fffce33 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -15,9 +15,9 @@ #include "common.hpp" -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); -static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_1 * x = (const block_q4_1 *) vx; @@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; @@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_1 * x = (const block_q5_1 *) vx; @@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, } template -static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, +static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; - const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri } template -static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, +static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; - const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -203,14 +203,14 @@ template static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_q2_K * x = (const block_q2_K *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int n = tid/32; - const int l = tid - 32*n; - const int is = 8*n + l/16; + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; const uint8_t q = x[i].qs[32*n + l]; dst_t * y = yy + i*QK_K + 128*n; @@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); #else - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); dst_t * y = yy + i*QK_K + 16*is + il; @@ -239,19 +239,19 @@ template static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_q3_K * x = (const block_q3_K *) vx; #if QK_K == 256 - const int r = item_ct1.get_local_id(2) / 4; - const int tid = r/2; - const int is0 = r%2; - const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); - const int n = tid / 4; - const int j = tid - 4*n; + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j + is0; + int64_t is = 8*n + 2*j + is0; int shift = 2*j; int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : @@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); #else - const int tid = item_ct1.get_local_id(2); - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const int im = il/8; // 0...1 - const int in = il%8; // 0...7 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 + const int64_t im = il/8; // 0...1 + const int64_t in = il%8; // 0...7 dst_t * y = yy + i*QK_K + 16*is + il; @@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { const block_q4_K * x = (const block_q4_K *) vx; - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); #if QK_K == 256 // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; - const int ir = tid%8; - const int is = 2*il; - const int n = 4; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t is = 2*il; + const int64_t n = 4; dst_t * y = yy + i*QK_K + 64*il + n*ir; @@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri y[l +32] = d2 * (q_vec[l] >> 4) - m2; } #else - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); const uint8_t * q = x[i].qs; dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; @@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri const sycl::nd_item<3> &item_ct1) { const block_q5_K * x = (const block_q5_K *) vx; - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = item_ct1.get_local_id(2); - const int il = tid/16; // il is in 0...3 - const int ir = tid%16; // ir is in 0...15 - const int is = 2*il; // is is in 0...6 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/16; // il is in 0...3 + const int64_t ir = tid%16; // ir is in 0...15 + const int64_t is = 2*il; // is is in 0...6 dst_t * y = yy + i*QK_K + 64*il + 2*ir; @@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; #else - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); const uint8_t q = x[i].qs[tid]; - const int im = tid/8; // 0...3 - const int in = tid%8; // 0...7 - const int is = tid/16; // 0 or 1 + const int64_t im = tid/8; // 0...3 + const int64_t in = tid%8; // 0...7 + const int64_t is = tid/16; // 0 or 1 const uint8_t h = x[i].qh[in] >> im; const float d = x[i].d; dst_t * y = yy + i*QK_K + tid; @@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri const sycl::nd_item<3> &item_ct1) { const block_q6_K * x = (const block_q6_K *) vx; - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = item_ct1.get_local_id(2); - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; dst_t * y = yy + i*QK_K + 128*ip + il; @@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri #else // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 dst_t * y = yy + i*QK_K + 16*ip + il; @@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res const uint8_t *ksigns_iq2xs_ptr, const uint8_t *kmask_iq2xs_ptr) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq2_xxs * x = (const block_iq2_xxs *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * aux8 = (const uint8_t *)q2; @@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest const uint8_t *ksigns_iq2xs, const uint8_t *kmask_iq2xs) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq2_xs * x = (const block_iq2_xs *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); @@ -504,13 +504,13 @@ __dpct_inline__ static void dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq2_s * x = (const block_iq2_s *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; @@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res const uint8_t *ksigns_iq2xs, const uint8_t *kmask_iq2xs) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq3_xxs * x = (const block_iq3_xxs *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * q3 = x[i].qs + 8*ib; const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; @@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1, const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq3_s * x = (const block_iq3_s *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); @@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1, const uint32_t *iq1s_grid_gpu) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); @@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1, const uint32_t *iq1s_grid_gpu) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq1_m * x = (const block_iq1_m *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * sc = (const uint16_t *)x[i].scales; iq1m_scale_t scale; @@ -656,12 +656,12 @@ __dpct_inline__ static void dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[ib].qs + 4*il; const float d = (float)x[ib].d; @@ -678,12 +678,12 @@ template __dpct_inline__ static void dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq4_xs * x = (const block_iq4_xs *)vx; - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index ae45630e1173d..5c343822f390f 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -4,7 +4,7 @@ #include "presets.hpp" -static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; // automatic half -> float type cast if dfloat == float @@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v.y() = x[ib + iqs + 1]; } -static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; // automatic half -> float type cast if dfloat == float diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp new file mode 100644 index 0000000000000..2ad9b36f419ce --- /dev/null +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -0,0 +1,101 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_GEMM_HPP +#define GGML_SYCL_GEMM_HPP + +#include +#include + +#include "ggml-sycl.h" + +#if GGML_SYCL_DNNL + +#include "dnnl.hpp" +#include "dnnl_sycl.hpp" + +class DnnlGemmWrapper { +public: + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + + template + static constexpr dt to_dt() { + if constexpr (std::is_same_v) return dt::f32; + else if constexpr (std::is_same_v) return dt::f16; + else static_assert(0); + } + + static inline void row_gemm(sycl::queue& q, bool a_trans, + bool b_trans, int m, int n, int k, + const void* a, dt at, const void* b, dt bt, void* c, dt ct) + { + // Get the device associated with the queue + sycl::device dev = q.get_device(); + // Get the context associated with the queue + sycl::context ctx = q.get_context(); + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); + const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); + dnnl::memory::dims a_dims = { m, k }; + dnnl::memory::dims b_dims = { k, n }; + dnnl::memory::dims c_dims = { m, n }; + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); + auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); + + // Create the primitive. + auto matmul_prim = dnnl::matmul(matmul_pd); + // Primitive arguments. + std::unordered_map matmul_args; + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); + + matmul_prim.execute(stream, matmul_args); + } + + + static inline void row_gemm(const dnnl::stream& stream, bool a_trans, + bool b_trans, int m, int n, int k, + const void* a, dt at, const void* b, dt bt, void* c, dt ct) + { + auto const eng = stream.get_engine(); + dnnl::memory::dims a_dims = { m, k }; + dnnl::memory::dims b_dims = { k, n }; + dnnl::memory::dims c_dims = { m, n }; + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); + auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); + + // Create the primitive. + auto matmul_prim = dnnl::matmul(matmul_pd); + // Primitive arguments. + std::unordered_map matmul_args; + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); + + matmul_prim.execute(stream, matmul_args); + } +}; + +#endif + +#endif // GGML_SYCL_GEMM_HPP diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp new file mode 100644 index 0000000000000..6a0a0fcd08c68 --- /dev/null +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -0,0 +1,125 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "im2col.hpp" + +template +static void im2col_kernel( + const float *x, T *dst, int64_t batch_offset, int64_t offset_delta, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1, + const sycl::nd_item<3> &item_ct1) { + const int64_t work_group_size = item_ct1.get_local_range(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + + // make each work-item deal with more elements since sycl global range can not exceed max int + for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) { + + const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; + + const int64_t oh = item_ct1.get_group(1); + const int64_t batch = item_ct1.get_group(0) / IC; + const int64_t ic = item_ct1.get_group(0) % IC; + + const int64_t iiw = ix * s0 + kx * d0 - p0; + const int64_t iih = oh * s1 + ky * d1 - p1; + + const int64_t offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = + sycl::vec(0.0f) + .convert()[0]; + } else { + const int64_t offset_src = ic * offset_delta + batch * batch_offset; + dst[offset_dst] = + sycl::vec(x[offset_src + iih * IW + iiw]) + .convert()[0]; + } + } +} + +template +static void im2col_sycl( + const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, + int s0, int s1, int p0, int p1, int d0, int d1, + queue_ptr stream) { + const int64_t parallel_elements = OW * KW * KH; + const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + + // decrease global range when it exceeds the max int + int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + sycl::range<3> block_nums(batch * IC, OH, num_blocks); + sycl::range<3> local_range(1, 1, local_size); + + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * local_range, local_range), + [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, + parallel_elements, (IC * KH * KW), s0, s1, p0, + p1, d0, d1, item_ct1); + }); + } +} + +void ggml_sycl_op_im2col( + ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const int64_t IC = src1->ne[is_2D ? 2 : 1]; + const int64_t IH = is_2D ? src1->ne[1] : 1; + const int64_t IW = src1->ne[0]; + + const int64_t KH = is_2D ? src0->ne[1] : 1; + const int64_t KW = src0->ne[0]; + + const int64_t OH = is_2D ? dst->ne[2] : 1; + const int64_t OW = dst->ne[1]; + + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t batch = src1->ne[3]; + const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + + if (dst->type == GGML_TYPE_F16) { + im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } else { + im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } + + (void) src0; + (void) src0_dd; +} diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp new file mode 100644 index 0000000000000..7db144fbbe524 --- /dev/null +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -0,0 +1,23 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_IM2COL_HPP +#define GGML_SYCL_IM2COL_HPP + +#include "common.hpp" + +void ggml_sycl_op_im2col( + ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream); + +#endif // GGML_SYCL_IM2COL_HPP diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index c7545bcc1a8a9..1f06f78fa3d91 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -226,7 +226,7 @@ void ggml_sycl_op_rope( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const int32_t * pos = (const int32_t *) src1_dd; diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 86732837254f0..ca4f44cf75615 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -1,6 +1,6 @@ #include "ggml-vulkan.h" #include -#ifdef GGML_VULKAN_RUN_TESTS +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) #include #endif @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -34,9 +35,7 @@ #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de -#define VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN 0 -#define VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI 1 -#define VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE 2 +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 #define GGML_VK_MAX_NODES 8192 @@ -74,6 +73,8 @@ struct vk_queue { std::vector cmd_buffers; vk::PipelineStageFlags stage_flags; + + bool transfer_only; }; struct vk_pipeline_struct { @@ -133,6 +134,9 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { #ifdef GGML_VULKAN_MEMORY_DEBUG class vk_memory_logger; #endif +#ifdef GGML_VULKAN_PERF +class vk_perf_logger; +#endif static void ggml_vk_destroy_buffer(vk_buffer& buf); struct vk_device_struct { @@ -148,7 +152,6 @@ struct vk_device_struct { vk_queue compute_queue; vk_queue transfer_queue; bool single_queue; - uint32_t descriptor_set_mode; uint32_t subgroup_size; bool uma; @@ -177,6 +180,7 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16; vk_pipeline pipeline_mul_f32; vk_pipeline pipeline_div_f32; @@ -184,8 +188,11 @@ struct vk_device_struct { vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_repeat_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; @@ -205,7 +212,8 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; - std::vector pipelines; + std::unordered_map pipelines; + std::unordered_map pipeline_descriptor_set_requirements; std::vector> pinned_memory; @@ -217,6 +225,9 @@ struct vk_device_struct { #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif +#ifdef GGML_VULKAN_PERF + std::unique_ptr perf_logger; +#endif ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -231,11 +242,11 @@ struct vk_device_struct { } for (auto& pipeline : pipelines) { - if (pipeline.expired()) { + if (pipeline.second.expired()) { continue; } - vk_pipeline pl = pipeline.lock(); + vk_pipeline pl = pipeline.second.lock(); ggml_vk_destroy_pipeline(device, pl); } pipelines.clear(); @@ -479,6 +490,48 @@ class vk_memory_logger { #define VK_LOG_MEMORY(msg) ((void) 0) #endif // GGML_VULKAN_MEMORY_DEBUG +#if defined(GGML_VULKAN_PERF) + +class vk_perf_logger { +public: + void print_timings() { + std::cerr << "----------------\nVulkan Timings:" << std::endl; + for (const auto& t : timings) { + uint64_t total = 0; + for (const auto& time : t.second) { + total += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; + } + + timings.clear(); + } + + void log_timing(const ggml_tensor * node, uint64_t time) { + if (node->op == GGML_OP_UNARY) { + timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + return; + } + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); + if (n == 1) { + name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); + } else { + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + } + timings[name].push_back(time); + return; + } + timings[ggml_op_name(node->op)].push_back(time); + } +private: + std::map> timings; +}; +#endif // GGML_VULKAN_PERF + struct ggml_backend_vk_context { std::string name; @@ -489,9 +542,6 @@ struct ggml_backend_vk_context { size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; vk_buffer prealloc_x, prealloc_y, prealloc_split_k; vk::Fence fence; - vk_buffer staging; - size_t staging_size; - size_t staging_offset; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -595,35 +645,9 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co descriptor_set_layout_create_info.setPNext(&dslbfci); pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); - // Check if device supports multiple descriptors per pool - if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) { - const uint32_t alloc_count = 2; - - // Try allocating multiple sets from one pool - // This fails on AMD for some reason, so add a fall back to allocating one pool per set - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size); - vk::DescriptorPool pool = device->device.createDescriptorPool(descriptor_pool_create_info); - - std::vector layouts(alloc_count); - for (uint32_t i = 0; i < alloc_count; i++) { - layouts[i] = pipeline->dsl; - } - try { - vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data()); - std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); - } catch(vk::OutOfPoolMemoryError const&) { - device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE; - } - - device->device.destroyDescriptorPool(pool); - } - - if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) { - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - } + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); pipeline->descriptor_set_idx = 0; @@ -657,7 +681,7 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co pipeline->layout); pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; - device->pipelines.push_back(pipeline); + device->pipelines.insert({ pipeline->name, pipeline }); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -678,34 +702,49 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) device.destroyPipeline(pipeline->pipeline); } -static void ggml_pipeline_allocate_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { - VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); - if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { - // Enough descriptors are available - return; - } +static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { + VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); + device->pipeline_descriptor_set_requirements[pipeline->name] += n; +} +static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { std::lock_guard guard(device->mutex); - if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) { - const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); + for (auto& pair : device->pipeline_descriptor_set_requirements) { + vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); + const uint64_t n = pair.second; + + VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); - std::vector layouts(alloc_count); - for (uint32_t i = 0; i < alloc_count; i++) { - layouts[i] = pipeline->dsl; + if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { + // Enough descriptors are available + continue; } - vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data()); - std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); - pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); - } else { - for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) { - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl); + uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= pipeline->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + } + + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = pipeline->dsl; + } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); - pipeline->descriptor_sets.push_back(sets[0]); + pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; } } } @@ -866,11 +905,12 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector guard(device->mutex); q.queue_family_index = queue_family_index; + q.transfer_only = transfer_only; vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); q.pool = device->device.createCommandPool(command_pool_create_info_compute); @@ -1067,13 +1107,16 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { static void ggml_vk_sync_buffers(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); + + const bool transfer_queue = ctx->q->transfer_only; + ctx->s->buffer.pipelineBarrier( ctx->q->stage_flags, ctx->q->stage_flags, {}, { { - {vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite}, - {vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite} + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } } }, {}, {} @@ -1647,6 +1690,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -1659,11 +1704,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -1703,6 +1752,9 @@ static vk_device ggml_vk_get_device(size_t idx) { #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger = std::unique_ptr(new vk_memory_logger()); #endif +#ifdef GGML_VULKAN_PERF + device->perf_logger = std::unique_ptr(new vk_perf_logger()); +#endif size_t dev_num = vk_instance.device_indices[idx]; @@ -1833,17 +1885,15 @@ static vk_device ggml_vk_get_device(size_t idx) { device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); - device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN; - // Queues - ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }); + ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); // Shaders ggml_vk_load_shaders(device); if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; - ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }); + ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue = device->compute_queue; @@ -2130,9 +2180,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->fence = ctx->device->device.createFence({}); - ctx->staging_size = 0; - ctx->staging_offset = 0; - #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); @@ -2565,23 +2612,15 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont return; } - // Staging buffer required - vk_buffer staging = ctx->staging; - size_t staging_offset = ctx->staging_offset; - const size_t copy_size = ts*ne/bs; - if (ctx->staging->size < ctx->staging_offset + copy_size) { - if (sync_staging) { - // Create temporary larger buffer - ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); - - staging = ctx->device->sync_staging; - staging_offset = 0; - } else { - GGML_ABORT("fatal error"); - } + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); } - VkBufferCopy buf_copy{ staging_offset, offset, copy_size }; + // Staging buffer required + vk_buffer& staging = ctx->device->sync_staging; + const uint64_t copy_size = ts*ne/bs; + ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); + VkBufferCopy buf_copy{ 0, offset, copy_size }; ggml_vk_sync_buffers(subctx); vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy); @@ -2590,14 +2629,14 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont for (uint64_t i2 = 0; i2 < ne2; i2++) { // Find longest contiguous slice if (ne1*nb1 == dstnb2) { - deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); } else { for (uint64_t i1 = 0; i1 < ne1; i1++) { if (ne0*nb0/bs == dstnb1) { - deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); } else { const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; - const uint64_t d_off = staging_offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; for (uint64_t i0 = 0; i0 < ne0; i0++) { deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); } @@ -2608,7 +2647,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); // Buffer is already mapped if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { @@ -2643,21 +2682,18 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } VK_LOG_DEBUG("STAGING"); + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + // Staging buffer required const size_t copy_size = width*height; - if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) { - if (sync_staging) { - ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); - staging_buffer = dst->device->sync_staging; - staging_offset = 0; - } else { - GGML_ABORT("fatal error"); - } - } + vk_buffer& staging_buffer = dst->device->sync_staging; VkBufferCopy buf_copy = { - staging_offset, + 0, offset, copy_size}; @@ -2665,17 +2701,17 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy); if (width == spitch) { - deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset, src, width * height, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); } else { for (size_t i = 0; i < height; i++) { - deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); } } } -static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); - return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); } static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { @@ -2690,7 +2726,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * } else { vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); ggml_vk_ctx_begin(dst->device, subctx); - ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true); + ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); ggml_vk_ctx_end(subctx); for (auto& cpy : subctx->in_memcpys) { @@ -2708,7 +2744,7 @@ static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); } -static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); GGML_ASSERT(width > 0); GGML_ASSERT(height > 0); @@ -2745,18 +2781,15 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size } VK_LOG_DEBUG("STAGING"); + if (!sync_staging) { + GGML_ABORT("Asynchronous read from non-pinned memory not supported"); + } + // Fall back to staging buffer const size_t copy_size = dpitch * height; - if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) { - if (sync_staging) { - // Create temporary larger buffer - ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); - staging_buffer = src->device->sync_staging; - } else { - GGML_ABORT("fatal error"); - } - } + vk_buffer& staging_buffer = src->device->sync_staging; ggml_vk_sync_buffers(subctx); subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); @@ -2764,8 +2797,8 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); } -static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { - return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging); +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { + return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); } static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { @@ -2777,7 +2810,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ } else { vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); ggml_vk_ctx_begin(src->device, subctx); - ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true); + ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); @@ -2978,10 +3011,11 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 }); } -static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT @@ -3055,6 +3089,56 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; const uint64_t d_sz = sizeof(float) * d_ne; + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || + (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + return; + } + vk_buffer d_D = extra->buffer_gpu.lock(); const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); @@ -3090,34 +3174,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(qy_sz == y_sz); } - vk_pipeline to_fp16_vk_0 = nullptr; - vk_pipeline to_fp16_vk_1 = nullptr; - - if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); - } else { - to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); - } - if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); - } else { - to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); - } - GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT - GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT - - // Allocate descriptor sets - ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1); - if (qx_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1); - } - if (qy_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1); - } - if (split_k > 1) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); - } - if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { @@ -3151,10 +3207,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub ); // NOLINT } -static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT @@ -3218,6 +3275,47 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; const uint64_t d_sz = sizeof(float) * d_ne; + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + return; + } + vk_buffer d_D = extra->buffer_gpu.lock(); const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); @@ -3250,30 +3348,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(qy_sz == y_sz); } - vk_pipeline to_fp16_vk_0 = nullptr; - vk_pipeline to_fp16_vk_1 = nullptr; - if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); - } - if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); - } else { - to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); - } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type); - GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT - GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT - GGML_ASSERT(dmmv != nullptr); - - // Allocate descriptor sets - if (qx_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1); - } - if (qy_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13); - } - ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13); - if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); @@ -3316,10 +3390,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); } -static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT @@ -3360,6 +3435,12 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t d_sz = sizeof(float) * d_ne; + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + return; + } + vk_buffer d_D = extra->buffer_gpu.lock(); const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); @@ -3372,9 +3453,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c GGML_ASSERT(d_Qx != nullptr); } - // Allocate descriptor sets - ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); - const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; @@ -3387,10 +3465,11 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); } -static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); @@ -3435,6 +3514,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t qy_sz = ggml_nbytes(src1); const uint64_t d_sz = sizeof(float) * d_ne; + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + return; + } + vk_buffer d_D = extra->buffer_gpu.lock(); const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); @@ -3447,9 +3532,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con GGML_ASSERT(d_Qx != nullptr); } - // Allocate descriptor sets - ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); - const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; @@ -3463,20 +3545,20 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); } -static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) { - ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst); + ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) { - ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst); + ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { - ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst); + ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { - ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst); + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); } } -static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; @@ -3566,6 +3648,48 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t ids_sz = nbi2; const uint64_t d_sz = sizeof(float) * d_ne; + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + return; + } + vk_buffer d_D = extra->buffer_gpu.lock(); const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); @@ -3605,31 +3729,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(qy_sz == y_sz); } - vk_pipeline to_fp16_vk_0 = nullptr; - vk_pipeline to_fp16_vk_1 = nullptr; - - if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); - } else { - to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); - } - if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); - } else { - to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); - } - GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT - GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT - - // Allocate descriptor sets - ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1); - if (qx_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1); - } - if (qy_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1); - } - if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { @@ -3664,11 +3763,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ); // NOLINT } -static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -3742,6 +3842,47 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t ids_sz = nbi2; const uint64_t d_sz = sizeof(float) * d_ne; + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + return; + } + vk_buffer d_D = extra->buffer_gpu.lock(); const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); @@ -3779,30 +3920,6 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte GGML_ASSERT(qy_sz == y_sz); } - vk_pipeline to_fp16_vk_0 = nullptr; - vk_pipeline to_fp16_vk_1 = nullptr; - if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); - } - if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); - } else { - to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); - } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); - GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT - GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT - GGML_ASSERT(dmmv != nullptr); - - // Allocate descriptor sets - if (qx_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1); - } - if (qy_needs_dequant) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13); - } - ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13); - if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); @@ -3841,85 +3958,15 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); } -static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { - ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst); + ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } else { - ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst); + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } } -static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - VK_LOG_DEBUG("ggml_vk_op_repeat(" << src0 << ", " << src1 << ", " << dst << ")"); - const uint64_t ne0 = dst->ne[0]; - const uint64_t ne1 = dst->ne[1]; - const uint64_t ne2 = dst->ne[2]; - const uint64_t ne3 = dst->ne[3]; - - const uint64_t ne00 = src0->ne[0]; - const uint64_t ne01 = src0->ne[1]; - const uint64_t ne02 = src0->ne[2]; - const uint64_t ne03 = src0->ne[3]; - - const uint64_t nb0 = dst->nb[0]; - const uint64_t nb1 = dst->nb[1]; - const uint64_t nb2 = dst->nb[2]; - const uint64_t nb3 = dst->nb[3]; - - const uint64_t nb00 = src0->nb[0]; - const uint64_t nb01 = src0->nb[1]; - const uint64_t nb02 = src0->nb[2]; - const uint64_t nb03 = src0->nb[3]; - - // guaranteed to be an integer due to the check in ggml_can_repeat - const uint64_t nr0 = ne0/ne00; - const uint64_t nr1 = ne1/ne01; - const uint64_t nr2 = ne2/ne02; - const uint64_t nr3 = ne3/ne03; - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - - const vk_buffer src_buf = extra_src0->buffer_gpu.lock(); - const uint64_t src_offset = extra_src0->offset + src0->view_offs; - vk_buffer dst_buf = extra->buffer_gpu.lock(); - const uint64_t dst_offset = extra->offset + dst->view_offs; - - std::vector copies; - - for (uint64_t i3 = 0; i3 < nr3; i3++) { - for (uint64_t k3 = 0; k3 < ne03; k3++) { - for (uint64_t i2 = 0; i2 < nr2; i2++) { - for (uint64_t k2 = 0; k2 < ne02; k2++) { - for (uint64_t i1 = 0; i1 < nr1; i1++) { - for (uint64_t k1 = 0; k1 < ne01; k1++) { - for (uint64_t i0 = 0; i0 < nr0; i0++) { - copies.push_back({ - src_offset + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, - dst_offset + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, - ne00*nb0, - }); - } - } - } - } - } - } - } - - ggml_vk_sync_buffers(subctx); - subctx->s->buffer.copyBuffer(src_buf->buffer, dst_buf->buffer, copies); - - GGML_UNUSED(ctx); - GGML_UNUSED(src1); -} - - static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: @@ -3931,6 +3978,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_get_rows_f32[src0->type]; } return nullptr; + case GGML_OP_ACC: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; case GGML_OP_ADD: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_add_f32; @@ -3975,6 +4027,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sqr_f32; } return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -3985,6 +4047,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_pad_f32; } return nullptr; + case GGML_OP_REPEAT: + if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { + return ctx->device->pipeline_repeat_f32; + } + return nullptr; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: @@ -4053,7 +4120,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_ROPE: { const int mode = ((const int32_t *) dst->op_params)[2]; - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { @@ -4107,15 +4174,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const GGML_UNUSED(src2); } -static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) { - switch(op) { - case GGML_OP_REPEAT: - return ggml_vk_op_repeat; - default: - return nullptr; - } -} - static bool ggml_vk_op_supports_incontiguous(ggml_op op) { switch (op) { case GGML_OP_CPY: @@ -4127,8 +4185,11 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: + case GGML_OP_REPEAT: return true; default: return false; @@ -4136,7 +4197,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { } template -static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) { +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; if (src1 != nullptr) { std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -4144,7 +4205,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co if (src2 != nullptr) { std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; } - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")"); + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->extra != nullptr); @@ -4176,20 +4238,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint64_t ned = ned0 * ned1; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); - ggml_vk_func_t op_func; if (pipeline == nullptr) { - op_func = ggml_vk_op_get_func(op); - if (op_func == nullptr) { - std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); - if (src1 != nullptr) { - std::cerr << " and " << ggml_type_name(src1->type); - } - std::cerr << " to " << ggml_type_name(dst->type) << std::endl; - GGML_ABORT("fatal error"); + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); + if (src1 != nullptr) { + std::cerr << " and " << ggml_type_name(src1->type); } + std::cerr << " to " << ggml_type_name(dst->type) << std::endl; + GGML_ABORT("fatal error"); + } - op_func(ctx, subctx, src0, src1, dst); + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); return; } @@ -4278,188 +4338,143 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co std::array elements; // Single call if dimension 2 is contiguous - if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1); - - switch (op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: - case GGML_OP_SUM_ROWS: - { - const uint32_t nr = ggml_nrows(src0); - if (nr > 262144) { - elements = { 512, 512, CEIL_DIV(nr, 262144) }; - } else if (nr > 512) { - elements = { 512, CEIL_DIV(nr, 512), 1 }; - } else { - elements = { nr, 1, 1 }; - } - } break; - case GGML_OP_GROUP_NORM: - { - const uint32_t num_groups = dst->op_params[0]; - elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; - } break; - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_ROPE: - elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; - break; - case GGML_OP_GET_ROWS: - elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; - break; - case GGML_OP_ARGSORT: - elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; - break; - case GGML_OP_IM2COL: - { - const bool is_2D = dst->op_params[6] == 1; + GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); - const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + switch (op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_SUM_ROWS: + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ROPE: + elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; + break; + case GGML_OP_GET_ROWS: + elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; - const uint32_t KH = is_2D ? src0->ne[1] : 1; - const uint32_t KW = src0->ne[0]; + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; - const uint32_t OH = is_2D ? dst->ne[2] : 1; - const uint32_t OW = dst->ne[1]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; - const uint32_t batch = src1->ne[3]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; - elements = { OW * KW * KH, OH, batch * IC }; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - const uint32_t dim = dst->op_params[0]; - uint32_t half_ceil = (dim + 1) / 2; - elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; - } break; - case GGML_OP_ADD: - case GGML_OP_DIV: - case GGML_OP_MUL: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_CPY: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_UNARY: - { - const uint32_t ne = ggml_nelements(dst); - if (ne > 262144) { - elements = { 512, 512, CEIL_DIV(ne, 262144) }; - } else if (ne > 512) { - elements = { 512, CEIL_DIV(ne, 512), 1 }; - } else { - elements = { ne, 1, 1 }; - } - } break; - default: - elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; - break; - } - - if (!op_supports_incontiguous) { - if (x_sz != VK_WHOLE_SIZE) { - x_sz *= ne02 * ne03; - } - if (use_src1 && y_sz != VK_WHOLE_SIZE) { - y_sz *= ne12 * ne13; - } - if (use_src2 && z_sz != VK_WHOLE_SIZE) { - z_sz *= ne22 * ne23; - } - if (d_sz != VK_WHOLE_SIZE) { - d_sz *= ned2 * ned3; - } - } + const uint32_t batch = src1->ne[3]; - if (op == GGML_OP_SOFT_MAX) { - // Empty src1 is possible in soft_max, but the shader needs a buffer - vk_subbuffer subbuf_y; - if (use_src1) { - subbuf_y = { d_Y, y_buf_offset, y_sz }; + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + { + const uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; } else { - subbuf_y = { d_X, 0, x_sz }; + elements = { ne, 1, 1 }; } + } break; + default: + elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; + break; + } - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); - } else if (op == GGML_OP_ROPE) { - // Empty src2 is possible in rope, but the shader needs a buffer - vk_subbuffer subbuf_z; - if (use_src2) { - subbuf_z = { d_Z, z_buf_offset, z_sz }; - } else { - subbuf_z = { d_X, 0, x_sz }; - } + if (!op_supports_incontiguous) { + if (x_sz != VK_WHOLE_SIZE) { + x_sz *= ne02 * ne03; + } + if (use_src1 && y_sz != VK_WHOLE_SIZE) { + y_sz *= ne12 * ne13; + } + if (use_src2 && z_sz != VK_WHOLE_SIZE) { + z_sz *= ne22 * ne23; + } + if (d_sz != VK_WHOLE_SIZE) { + d_sz *= ned2 * ned3; + } + } - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); - } else if (op == GGML_OP_IM2COL) { - // im2col uses only src1 and dst buffers - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); - } else if (use_src2) { - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); - } else if (use_src1) { - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + if (op == GGML_OP_SOFT_MAX) { + // Empty src1 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; } else { - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + subbuf_y = { d_X, 0, x_sz }; } - } else { - GGML_ASSERT(op != GGML_OP_SOFT_MAX); - GGML_ASSERT(op != GGML_OP_ARGSORT); - GGML_ASSERT(!use_src2); - - ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03); - switch (op) { - case GGML_OP_NORM: - case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: - elements = { (uint32_t)ne01, 1, 1 }; - break; - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_ROPE: - elements = { (uint32_t)ne01, (uint32_t)ne00, 1 }; - break; - case GGML_OP_GET_ROWS: - elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; - break; - default: - elements = { (uint32_t)ne0, 1, 1 }; - break; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_ROPE) { + // Empty src2 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; } - for (uint64_t i03 = 0; i03 < ne03; i03++) { - for (uint64_t i02 = 0; i02 < ne02; i02++) { - const uint32_t it_idx0 = (i03 * ne02 + i02); - const uint32_t it_idx1 = use_src1 ? ((i03 % ne13) * ne12 + (i02 % ne12)) : 0; - const uint32_t x_offset = x_sz * it_idx0; - const uint32_t y_offset = y_sz * it_idx1; - const uint32_t d_offset = d_sz * it_idx0; - - if (use_src1) { - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset + x_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset + y_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements); - } else { - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset + x_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements); - } - } - } + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_IM2COL) { + // im2col uses only src1 and dst buffers + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (use_src2) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (use_src1) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); } } -static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {}); -} - -static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4471,10 +4486,32 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, 0, - }); + }, dryrun); } -static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + d_offset, + 0.0f, 0.0f, offset, + }, dryrun); +} + +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4486,10 +4523,10 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, 0, - }); + }, dryrun); } -static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4501,10 +4538,10 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, 0, - }); + }, dryrun); } -static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4516,10 +4553,10 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, 0, - }); + }, dryrun); } -static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); @@ -4533,10 +4570,10 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, op_params[0], - }); + }, dryrun); } -static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const float sf0 = (float)dst->ne[0] / src0->ne[0]; @@ -4549,10 +4586,10 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], sf0, sf1, sf2, sf3, - }); + }, dryrun); } -static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4563,10 +4600,10 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, op_params[0], 0.0f - }); + }, dryrun); } -static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4576,10 +4613,36 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, + }, dryrun); +} + +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + }); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, }); } -static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4590,10 +4653,10 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, op_params[0], op_params[1], - }); + }, dryrun); } -static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4603,10 +4666,23 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, - }); + }, dryrun); } -static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + }, dryrun); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4618,40 +4694,41 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, d_offset, 0.0f, 0.0f, - }); + }, dryrun); } -static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } -static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - int * op_params = (int *)dst->op_params; +static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; - uint32_t num_groups = op_params[0]; - uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - static const float eps = 1e-6f; + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } -static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } -static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); } -static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; float scale = op_params[0]; @@ -4673,10 +4750,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, scale, max_bias, m0, m1, n_head_log2, - }); + }, dryrun); } -static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { const int n_dims = ((int32_t *) dst->op_params)[1]; // const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -4697,10 +4774,10 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, src2 != nullptr, - }); + }, dryrun); } -static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; @@ -4716,14 +4793,14 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c ncols, ncols_pad, op_params[0], - }); + }, dryrun); } -static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }); +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); } -static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; const int32_t p0 = dst->op_params[2]; @@ -4754,22 +4831,22 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co pelements, IC * KH * KW, s0, s1, p0, p1, d0, d1, - }); + }, dryrun); } -static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t dim = dst->op_params[0]; const uint32_t max_period = dst->op_params[1]; const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { nb1, dim, max_period, - }); + }, dryrun); } -static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); } #ifdef GGML_VULKAN_RUN_TESTS @@ -4915,9 +4992,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_allocate_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); if (split_k > 1) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -5164,7 +5241,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_quantize_data(x, qx, ne, quant); ggml_vk_dequantize_data(qx, x_ref, ne, quant); - ggml_pipeline_allocate_descriptor_sets(ctx->device, p, 1); + ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); @@ -5285,9 +5362,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, y[i] = (i % k == i / k) ? 1.0f : 0.0f; } - ggml_pipeline_allocate_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); if (split_k > 1) { - ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -5415,135 +5492,6 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) return extra; } -static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){ - VK_LOG_DEBUG("ggml_vk_preallocate_buffers_graph(" << node << ")"); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra; - - if (extra == nullptr) { - return; - } - - ggml_tensor * src0 = node->src[0]; - ggml_tensor * src1 = node->src[1]; - - const bool use_src0 = src0 != nullptr; - const int64_t ne00 = use_src0 ? src0->ne[0] : 0; - const int64_t ne01 = use_src0 ? src0->ne[1] : 0; - const int64_t ne02 = use_src0 ? src0->ne[2] : 0; - const int64_t ne03 = use_src0 ? src0->ne[3] : 0; - const bool use_src1 = src1 != nullptr && node->op != GGML_OP_CPY && node->op != GGML_OP_CONT && node->op != GGML_OP_DUP; - const int64_t ne10 = use_src1 ? src1->ne[0] : 0; - const int64_t ne11 = use_src1 ? src1->ne[1] : 0; - const int64_t ne12 = use_src1 ? src1->ne[2] : 0; - const int64_t ne13 = use_src1 ? src1->ne[3] : 0; - const int64_t ne20 = node->ne[0]; - const int64_t ne21 = node->ne[1]; - const int64_t ne22 = node->ne[2]; - const int64_t ne23 = node->ne[3]; - - const ggml_type src0_type = (use_src0 && src0->type == GGML_TYPE_F32) ? src0->type : GGML_TYPE_F16; - const ggml_type src1_type = (use_src1 && src1->type == GGML_TYPE_F32) ? src1->type : GGML_TYPE_F16; - - const bool x_non_contig = use_src0 && !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = use_src1 && !ggml_vk_dim01_contiguous(src1); - - const bool y_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32 && !y_non_contig; - - bool mmp = (use_src0 && use_src1 && (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID)) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false; - - const bool qx_needs_dequant = use_src0 && (!mmp || x_non_contig); - const bool qy_needs_dequant = use_src1 && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig); - - int split_k; - if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { - split_k = ggml_vk_guess_split_k(ne01, ne11, ne10); - } else { - split_k = 1; - } - const uint32_t x_ne = ne00 * ne01; - const uint32_t y_ne = ne10 * ne11; - const uint32_t d_ne = ne20 * ne21; - - const uint64_t x_sz = (use_src0 && qx_needs_dequant) ? ggml_vk_align_size(sizeof(src0_type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0; - const uint64_t y_sz = (use_src1 && qy_needs_dequant) ? ggml_vk_align_size(sizeof(src1_type) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0; - uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23; - const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0; - - if (extra->buffer_gpu.expired()) { - // Workaround for CPU backend BLAS matmul calls - extra->buffer_gpu = ggml_vk_create_buffer_temp(ctx, d_sz); - } - - switch (node->op) { - case GGML_OP_REPEAT: - case GGML_OP_GET_ROWS: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_CPY: - case GGML_OP_CONT: - case GGML_OP_DUP: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_NORM: - case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_ROPE: - case GGML_OP_ARGSORT: - case GGML_OP_SUM_ROWS: - case GGML_OP_IM2COL: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_LEAKY_RELU: - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_TANH: - break; - default: - return; - } - break; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - if ( - x_sz > ctx->device->max_memory_allocation_size || - y_sz > ctx->device->max_memory_allocation_size || - d_sz > ctx->device->max_memory_allocation_size || - split_k_size > ctx->device->max_memory_allocation_size) { - GGML_ABORT("Requested preallocation size is too large"); - } - if (ctx->prealloc_size_x < x_sz) { - ctx->prealloc_size_x = x_sz; - } - if (ctx->prealloc_size_y < y_sz) { - ctx->prealloc_size_y = y_sz; - } - if (ctx->prealloc_size_split_k < split_k_size) { - ctx->prealloc_size_split_k = split_k_size; - } - if (ctx->staging_size < x_sz + y_sz) { - ctx->staging_size = x_sz + y_sz; - } - break; - default: - return; - } -} - static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { #if defined(GGML_VULKAN_RUN_TESTS) ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul, @@ -5708,19 +5656,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); } - if (ctx->staging == nullptr || (ctx->staging_size > 0 && ctx->staging->size < ctx->staging_size)) { - VK_LOG_MEMORY("ggml_vk_preallocate_buffers(staging_size: " << ctx->staging_size << ")"); - // Resize buffer - if (ctx->staging != nullptr) { - ggml_vk_destroy_buffer(ctx->staging); - } - ctx->staging = ggml_vk_create_buffer_check(ctx->device, ctx->staging_size, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); - } } -static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node){ +static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node, bool dryrun){ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra; if (ggml_is_empty(node) || extra == nullptr) { @@ -5729,7 +5667,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; - ctx->staging_offset = 0; const ggml_tensor * src0 = node->src[0]; const ggml_tensor * src1 = node->src[1]; @@ -5758,12 +5695,15 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_REPEAT: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CPY: @@ -5791,75 +5731,89 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod vk_context compute_ctx; - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); + if (!dryrun) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } } switch (node->op) { case GGML_OP_REPEAT: - ggml_vk_repeat(ctx, compute_ctx, src0, node); + ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_GET_ROWS: - ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node); + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_ADD: - ggml_vk_add(ctx, compute_ctx, src0, src1, node); + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_DIV: - ggml_vk_div(ctx, compute_ctx, src0, src1, node); + ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_CONCAT: - ggml_vk_concat(ctx, compute_ctx, src0, src1, node); + ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_UPSCALE: - ggml_vk_upscale(ctx, compute_ctx, src0, node); + ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_SCALE: - ggml_vk_scale(ctx, compute_ctx, src0, node); + ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_SQR: - ggml_vk_sqr(ctx, compute_ctx, src0, node); + ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node); break; case GGML_OP_CLAMP: - ggml_vk_clamp(ctx, compute_ctx, src0, node); + ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_PAD: - ggml_vk_pad(ctx, compute_ctx, src0, node); + ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: - ggml_vk_cpy(ctx, compute_ctx, src0, node); + ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_NORM: - ggml_vk_norm(ctx, compute_ctx, src0, node); + ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_GROUP_NORM: - ggml_vk_group_norm(ctx, compute_ctx, src0, node); + ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_RMS_NORM: - ggml_vk_rms_norm(ctx, compute_ctx, src0, node); + ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_UNARY: @@ -5869,59 +5823,63 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: - ggml_vk_unary(ctx, compute_ctx, src0, node); + ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); break; default: return; } break; case GGML_OP_DIAG_MASK_INF: - ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node); + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_SOFT_MAX: - ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node); + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_ROPE: - ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node); + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); break; case GGML_OP_ARGSORT: - ggml_vk_argsort(ctx, compute_ctx, src0, node); + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_SUM_ROWS: - ggml_vk_sum_rows(ctx, compute_ctx, src0, node); + ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_IM2COL: - ggml_vk_im2col(ctx, compute_ctx, src0, src1, node); + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_TIMESTEP_EMBEDDING: - ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node); + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_LEAKY_RELU: - ggml_vk_leaky_relu(ctx, compute_ctx, src0, node); + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_MUL_MAT: - ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node); + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); break; case GGML_OP_MUL_MAT_ID: - ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node); + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); break; default: return; } + if (dryrun) { + return; + } + ctx->tensor_ctxs[node_idx] = compute_ctx; -#ifdef GGML_VULKAN_CHECK_RESULTS +#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) // Force context reset on each node so that each tensor ends up in its own context // and can be run and compared to its CPU equivalent separately last_node = true; @@ -5939,6 +5897,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * switch (tensor->op) { case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_GET_ROWS: case GGML_OP_MUL: case GGML_OP_DIV: @@ -5946,6 +5905,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CPY: @@ -6005,6 +5966,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); +#ifdef GGML_VULKAN_PERF + std::chrono::steady_clock::time_point start; +#endif // GGML_VULKAN_PERF + // Only run if ctx hasn't been submitted yet if (!subctx->seqs.empty()) { // Do staging buffer copies @@ -6012,11 +5977,21 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * memcpy(cpy.dst, cpy.src, cpy.n); } +#ifdef GGML_VULKAN_PERF + start = std::chrono::steady_clock::now(); +#endif // GGML_VULKAN_PERF + ggml_vk_submit(subctx, ctx->fence); } if (tensor_idx == subctx->exit_tensor_idx) { VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); + +#ifdef GGML_VULKAN_PERF + auto duration = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + ctx->device->perf_logger->log_timing(tensor, duration.count()); +#endif // GGML_VULKAN_PERF + ctx->device->device.resetFences({ ctx->fence }); // Do staging buffer copies @@ -6038,12 +6013,14 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { } ctx->gc.temp_buffers.clear(); - for (auto& pipeline : ctx->device->pipelines) { - if (pipeline.expired()) { + for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { + vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; + + if (plr.expired()) { continue; } - vk_pipeline pl = pipeline.lock(); + vk_pipeline pl = plr.lock(); ggml_pipeline_cleanup(pl); } @@ -6067,10 +6044,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->device->device.resetEvent(event); } - ctx->staging_offset = 0; - ctx->tensor_ctxs.clear(); ctx->gc.contexts.clear(); + ctx->device->pipeline_descriptor_set_requirements.clear(); } // Clean up on backend free @@ -6081,7 +6057,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->prealloc_x); ggml_vk_destroy_buffer(ctx->prealloc_y); ggml_vk_destroy_buffer(ctx->prealloc_split_k); - ggml_vk_destroy_buffer(ctx->staging); for (auto& buffer : ctx->buffer_pool) { ggml_vk_destroy_buffer(buffer); @@ -6090,7 +6065,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; - ctx->staging_size = 0; for (auto& event : ctx->gc.events) { ctx->device->device.destroyEvent(event); @@ -6419,7 +6393,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset); + ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -6442,7 +6416,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset); + ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { @@ -6508,9 +6482,10 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_preallocate_buffers_graph(ctx, cgraph->nodes[i]); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, 0, true); } ggml_vk_preallocate_buffers(ctx); + ggml_pipeline_allocate_descriptor_sets(ctx->device); int last_node = cgraph->n_nodes - 1; @@ -6523,7 +6498,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen ctx->tensor_ctxs.resize(cgraph->n_nodes); for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node, false); } for (int i = 0; i < cgraph->n_nodes; i++) { @@ -6549,6 +6524,10 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen GGML_ASSERT(ok); } +#ifdef GGML_VULKAN_PERF + ctx->device->perf_logger->print_timings(); +#endif + ggml_vk_graph_cleanup(ctx); return GGML_STATUS_SUCCESS; @@ -6640,10 +6619,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const return false; } break; case GGML_OP_REPEAT: - { - ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; - } break; + return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); case GGML_OP_ROPE: return ggml_is_contiguous(op->src[0]); case GGML_OP_NONE: @@ -6655,12 +6631,15 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CONT: @@ -7103,16 +7082,24 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_CLAMP) { tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_PAD) { tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_ADD) { tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params); + tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_SOFT_MAX) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1146726759358..3fedfef5c099f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2329,7 +2329,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } +inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } @@ -2688,6 +2690,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, return sum; } +static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { + // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + + int i = 0; + ggml_float sum = 0; + for (; i < n; ++i) { + float val = x[i] - max; + y[i] = val; + sum += (ggml_float)expf(val); + } + return sum = (ggml_float)logf(sum); +} + inline static float ggml_silu_backward_f32(float x, float dy) { const float s = 1.0f/(1.0f + expf(-x)); return dy*s*(1.0f + x*(1.0f - s)); @@ -2779,6 +2794,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SQR", "SQRT", "LOG", + "SIN", + "COS", "SUM", "SUM_ROWS", "MEAN", @@ -2816,9 +2833,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CLAMP", "CONV_TRANSPOSE_1D", "IM2COL", + "IM2COL_BACK", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", + "POOL_2D_BACK", "UPSCALE", "PAD", "ARANGE", @@ -2852,7 +2871,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2867,6 +2886,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "x^2", "√x", "log(x)", + "sin(x)", + "cos(x)", "Σx", "Σx_k", "Σx/n", @@ -2904,9 +2925,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "clamp(x)", "conv_transpose_1d(x)", "im2col(x)", + "im2col_back(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", + "pool_2d_back(x)", "upscale(x)", "pad(x)", "arange(start, stop, step)", @@ -2940,7 +2963,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -3786,6 +3809,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( } struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); + GGML_ASSERT(obj_new); // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here @@ -4505,8 +4529,6 @@ static struct ggml_tensor * ggml_add_impl( bool is_node = false; if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -4680,11 +4702,13 @@ static struct ggml_tensor * ggml_sub_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_can_repeat(b, a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -4899,6 +4923,72 @@ struct ggml_tensor * ggml_log_inplace( return ggml_log_impl(ctx, a, true); } +// ggml_sin + +static struct ggml_tensor * ggml_sin_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SIN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sin( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sin_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sin_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sin_impl(ctx, a, true); +} + +// ggml_cos + +static struct ggml_tensor * ggml_cos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_COS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cos( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cos_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cos_impl(ctx, a, true); +} + // ggml_sum struct ggml_tensor * ggml_sum( @@ -6746,17 +6836,20 @@ struct ggml_tensor * ggml_im2col( GGML_ASSERT(a->ne[2] == b->ne[2]); } else { GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(b->ne[3] == 1); } bool is_node = false; - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward + if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data is_node = true; } const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + const int64_t ne[4] = { is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], OW, @@ -6776,6 +6869,37 @@ struct ggml_tensor * ggml_im2col( return result; } +struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + + bool is_node = false; + + if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OC, OH, OW] @@ -6789,7 +6913,7 @@ struct ggml_tensor * ggml_conv_2d( int p1, int d0, int d1) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -6915,17 +7039,17 @@ struct ggml_tensor * ggml_pool_2d( bool is_node = false; if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } struct ggml_tensor * result; - const int64_t ne[3] = { + const int64_t ne[4] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), a->ne[2], + a->ne[3], }; - result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); @@ -6936,6 +7060,37 @@ struct ggml_tensor * ggml_pool_2d( return result; } +struct ggml_tensor * ggml_pool_2d_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * af, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result; + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne); + + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_2D_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = af; + return result; +} + // ggml_upscale static struct ggml_tensor * ggml_upscale_impl( @@ -7114,7 +7269,8 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias) { + float max_bias, + float logit_softcap) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -7141,7 +7297,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale, max_bias }; + float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -7161,7 +7317,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second + ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_attn_back @@ -7248,43 +7404,34 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq) { - GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_matrix(x)); + struct ggml_tensor * sx, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_3d(sx)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_matrix(sq)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); - const int64_t d_conv = c->ne[0]; - const int64_t d_inner = c->ne[1]; - const int64_t n_tokens = x->ne[1]; - const int64_t n_kv = s->ne[2]; + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; - GGML_ASSERT( s->ne[0] == d_conv - 1); - GGML_ASSERT( s->ne[1] == d_inner); - GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_kv); - GGML_ASSERT(sq->ne[1] == n_tokens); + // TODO: maybe support other strides than 1? + GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(sx->ne[1] == d_inner); + GGML_ASSERT(n_t >= 0); bool is_node = false; - if (s->grad || x->grad || c->grad || sq->grad) { + if (sx->grad || c->grad) { GGML_ABORT("fatal error"); // TODO: implement is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = s; - result->src[1] = x; - result->src[2] = c; - result->src[3] = sq; + result->src[0] = sx; + result->src[1] = c; return result; } @@ -7298,39 +7445,42 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq) { + struct ggml_tensor * C) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(A)); + GGML_ASSERT(ggml_is_3d(B)); + GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_are_same_shape(B, C)); { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_tokens = x->ne[1]; + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; + GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_tokens); - GGML_ASSERT(C->ne[0] == d_state); - GGML_ASSERT(C->ne[1] == n_tokens); + GGML_ASSERT(B->ne[1] == n_seq_tokens); + GGML_ASSERT(B->ne[2] == n_seqs); } bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { GGML_ABORT("fatal error"); // TODO: implement is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + // concatenated y + ssm_states struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; @@ -7341,7 +7491,6 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = sq; return result; } @@ -10123,11 +10272,10 @@ static void ggml_compute_forward_sub_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } + assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + const int ith = params->ith; + const int nth = params->nth; const int nr = ggml_nrows(src0); @@ -10136,40 +10284,55 @@ static void ggml_compute_forward_sub_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + for (int64_t r = 0; r < nr0; ++r) { #ifdef GGML_USE_ACCELERATE - vDSP_vsub( - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); #else - ggml_vec_sub_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif - // } - // } + } } } else { // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; } @@ -10515,6 +10678,96 @@ static void ggml_compute_forward_log( } } +// ggml_compute_forward_sin + +static void ggml_compute_forward_sin_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sin_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sin( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sin_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_cos + +static void ggml_compute_forward_cos_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_cos_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_cos( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cos_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum static void ggml_compute_forward_sum_f32( @@ -11014,11 +11267,6 @@ static void ggml_compute_forward_concat_f32( GGML_TENSOR_BINARY_OP_LOCALS - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - const int32_t dim = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(dim >= 0 && dim < 4); @@ -14113,7 +14361,7 @@ static void ggml_compute_forward_rope_f32( float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const float * freq_factors = NULL; if (src2 != NULL) { @@ -14238,7 +14486,7 @@ static void ggml_compute_forward_rope_f16( float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & 2; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const float * freq_factors = NULL; if (src2 != NULL) { @@ -14555,6 +14803,7 @@ static void ggml_compute_forward_conv_transpose_1d( } } +// ggml_compute_forward_im2col_f32 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -14565,7 +14814,6 @@ static void ggml_compute_forward_im2col_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -14596,7 +14844,6 @@ static void ggml_compute_forward_im2col_f32( int ofs0 = is_2D ? nb13 : nb12; int ofs1 = is_2D ? nb12 : nb11; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] @@ -14632,6 +14879,7 @@ static void ggml_compute_forward_im2col_f32( } +// ggml_compute_forward_im2col_f16 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -14727,6 +14975,99 @@ static void ggml_compute_forward_im2col( } } +// ggml_compute_forward_im2col_back_f32 + +static void ggml_compute_forward_im2col_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne3 : ne2; + const int64_t IC = is_2D ? ne2 : ne1; + const int64_t IH = is_2D ? ne1 : 1; + const int64_t IW = ne0; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne12 : 1; + const int64_t OW = ne11; + + int ofs0 = is_2D ? nb3 : nb2; + int ofs1 = is_2D ? nb2 : nb1; + + GGML_ASSERT(nb0 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + for (int64_t iih = 0; iih < IH; iih++) { + for (int64_t iiw = 0; iiw < IW; iiw++) { + + // micro kernel + float grad = 0.0f; + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + // For s0 > 1 some values were skipped over in the forward pass. + // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. + const int64_t tmpw = (iiw + p0 - ikw*d0); + if (tmpw % s0 != 0) { + continue; + } + const int64_t iow = tmpw / s0; + + // Equivalent logic as above except for s1. + int64_t ioh; + if (is_2D) { + const int64_t tmph = iih + p1 - ikh*d1; + + if (tmph % s1 != 0) { + continue; + } + + ioh = tmph / s1; + } else { + ioh = 0; + } + + if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { + continue; + } + + const float * const src_data = (const float *) src1->data + + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + } + } + float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] + dst_data[iih*IW + iiw] = grad; + } + } + } + } + } +} // ggml_compute_forward_conv_transpose_2d @@ -14969,6 +15310,128 @@ static void ggml_compute_forward_pool_2d( } } +// ggml_compute_forward_pool_2d_back + +static void ggml_compute_forward_pool_2d_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst + + assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + char * cdata = (char *) dst->data; + const char * cdataf = (const char *) dstf->data; + const char * const data_end = cdata + ggml_nbytes(dst); + + GGML_ASSERT(params->ith == 0); + memset(cdata, 0, ggml_nbytes(dst)); + + const int64_t px = src->ne[0]; + const int64_t py = src->ne[1]; + const int64_t pa = px * py; + + const float * splane = (const float *) src->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + const float * const srow = splane + oy * px; + for (int ox = 0; ox < px; ++ox) { + const float grad0 = srow[ox]; + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + if (op == GGML_OP_POOL_MAX) { + float maxval = -FLT_MAX; + int kxmax = -1; + int kymax = -1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + const float val = dst->type == GGML_TYPE_F32 ? + ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]); + if (val <= maxval) { + continue; + } + + maxval = val; + kxmax = kx; + kymax = ky; + } + } + + if (kxmax == -1 || kymax == -1) { + continue; + } + + void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); + const int j = ix + kxmax; + if (dst->type == GGML_TYPE_F32) { + ((float *) drow)[j] += grad0; + } else { + ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j])); + } + } else if (op == GGML_OP_POOL_AVG) { + const float grad = grad0 / ka; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + if (dst->type == GGML_TYPE_F32) { + ((float *) drow)[j] += grad; + } else { + ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad); + } + } + } + } else { + GGML_ASSERT(false); + } + } + } + + cdata += dst->nb[2]; + cdataf += dst->nb[2]; + splane += pa; + } +} + // ggml_compute_forward_upscale static void ggml_compute_forward_upscale_f32( @@ -15302,11 +15765,17 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float scale = 1.0f; - float max_bias = 0.0f; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); @@ -15370,7 +15839,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - s = s*scale + mv; // scale KQ value and apply mask + s = s*scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap*tanhf(s); + } + + s += mv; // apply mask const float Mold = M; @@ -15379,7 +15854,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - if (v->type== GGML_TYPE_F16) { + if (v->type == GGML_TYPE_F16) { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; @@ -15446,7 +15921,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[2]) { + switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { @@ -15801,27 +16276,22 @@ static void ggml_compute_forward_flash_attn_back( static void ggml_compute_forward_ssm_conv_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // conv_state - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight - const struct ggml_tensor * src3 = dst->src[3]; // state_seq + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv - const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // n_tokens - const int n_kv = src0->ne[2]; // max number of sequences in the batch + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + const int nr = src0->ne[1]; // d_inner + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT( dst->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src2->nb[0] == sizeof(float)); - GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // for use with the destination state offset between sequences - GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -15831,74 +16301,27 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { - // multiple sequences means it's hard to know when it's the first time a state is read, - // so copy them all over to the destination, just to be sure. - for (int i3 = 0; i3 < n_kv; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); - // can't use memcpy because of d_conv vs d_conv - 1 - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - // copy s0 to last (d_conv - 1) columns of s - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } - } - } - } - - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} - float * s0; // {d_conv - 1, d_inner, n_kv} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - int ne0s0; - - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} - ne0s0 = src0->ne[0]; - } else { - // the source is the last (d_conv - 1) columns of the destination - s0 = s + 1; - ne0s0 = nc; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // shift state left - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; - } - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; - } + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } + // TODO: transpose the output for smaller strides for big batches? + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; - // it seems a little faster when this is separate from the state shift - for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - float sumf = 0.0f; - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + // d_conv + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + } + x[i1] = sumf; } - x[i1] = sumf; } } } @@ -15929,15 +16352,14 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - const struct ggml_tensor * src6 = dst->src[6]; // sq const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15946,12 +16368,12 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C, and when copying the states + // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // required for per-sequence offsets for states GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[2]) - GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -15961,64 +16383,36 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { - // it's hard to know if the source states have already been copied - // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_kv; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); - memcpy(s, s0, nc*ir*sizeof(float)); - } - } - - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} - } else { - // otherwise the source is the same as the destination - s0 = s; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; - } - y[i1] = sumf; - } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; } } } @@ -16580,8 +16974,6 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } ggml_barrier(params->shared); - const double eps = 1e-9; - // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16602,20 +16994,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } #endif - // soft_max float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max); - assert(sum > 0.0); - sum = (1.0 - eps) / sum; + ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum >= 0.0); - // avoid log(0) by rescaling from [0..1] to [eps..1] - ggml_vec_scale_f32(nc, st, sum); - ggml_vec_add1_f32(nc, st, st, eps); - ggml_vec_log_f32(nc, st, st); + ggml_vec_add1_f32(nc, st, st, -sum); ggml_vec_mul_f32(nc, st, st, s1); - float st_sum = 0; + float st_sum = 0.0f; ggml_vec_sum_f32(nc, &st_sum, st); sums[ith] += st_sum; @@ -16672,8 +17059,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - const double eps = 1e-9; - // TODO: handle transposed/permuted matrices const int64_t nc = src0->ne[0]; const int64_t nr = ggml_nrows(src0); @@ -16705,11 +17090,9 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( ggml_vec_max_f32(nc, &max, s0); ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); assert(sum > 0.0); - sum = (1.0 - eps) / sum; + ggml_vec_scale_f32(nc, ds0, 1.0/sum); // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - ggml_vec_scale_f32(nc, ds0, sum); - ggml_vec_add1_f32(nc, ds0, ds0, eps); ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); @@ -16790,6 +17173,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_log(params, tensor); } break; + case GGML_OP_SIN: + { + ggml_compute_forward_sin(params, tensor); + } break; + case GGML_OP_COS: + { + ggml_compute_forward_cos(params, tensor); + } break; case GGML_OP_SUM: { ggml_compute_forward_sum(params, tensor); @@ -16930,6 +17321,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col(params, tensor); } break; + case GGML_OP_IM2COL_BACK: + { + ggml_compute_forward_im2col_back_f32(params, tensor); + } break; case GGML_OP_CONV_TRANSPOSE_2D: { ggml_compute_forward_conv_transpose_2d(params, tensor); @@ -16942,6 +17337,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_pool_2d(params, tensor); } break; + case GGML_OP_POOL_2D_BACK: + { + ggml_compute_forward_pool_2d_back(params, tensor); + } break; case GGML_OP_UPSCALE: { ggml_compute_forward_upscale(params, tensor); @@ -17310,7 +17709,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + if (ggml_are_same_shape(src0, src1)) { + src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + } else { + src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table); + } } } break; case GGML_OP_ADD1: @@ -17436,6 +17839,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_SIN: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, + tensor->grad, + ggml_cos(ctx, src0)), + zero_table); + } + } break; + case GGML_OP_COS: + { + if (src0->grad) { + src0->grad = + ggml_sub_or_set(ctx, + src0->grad, + ggml_mul(ctx, + tensor->grad, + ggml_sin(ctx, src0)), + zero_table); + } + } break; case GGML_OP_SUM: { if (src0->grad) { @@ -17883,6 +18310,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ABORT("fatal error"); // TODO: not implemented } case GGML_OP_IM2COL: + { + if (src1->grad) { + const int32_t s0 = ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = ggml_get_op_params_i32(tensor, 5); + const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + + src1->grad = ggml_add_or_set(ctx, + src1->grad, + ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), + zero_table); + } + } break; + case GGML_OP_IM2COL_BACK: { GGML_ABORT("fatal error"); // TODO: not implemented } @@ -17895,6 +18339,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ABORT("fatal error"); // TODO: not implemented } case GGML_OP_POOL_2D: + { + if (src0->grad) { + const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), + zero_table); + } + } break; + case GGML_OP_POOL_2D_BACK: { GGML_ABORT("fatal error"); // TODO: not implemented } @@ -18184,6 +18645,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { GGML_ASSERT(gf->n_nodes > 0); + GGML_ASSERT(gf->grads); // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph if (keep) { @@ -18523,6 +18985,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -18609,6 +19073,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); } break; case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { @@ -18616,6 +19081,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_POOL_1D: case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: { n_tasks = 1; } break; @@ -19129,9 +19595,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { const uint32_t type = tensor->type; const uint32_t op = tensor->op; + const int32_t flags = tensor->flags; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&flags, sizeof(int32_t), 1, fout); for (int j = 0; j < GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -19161,9 +19629,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { const uint32_t type = tensor->type; const uint32_t op = tensor->op; + const int32_t flags = tensor->flags; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&flags, sizeof(int32_t), 1, fout); for (int j = 0; j < GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -19222,6 +19692,14 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { } } } + + // dump the data + // TODO: pad this to 32 byte boundary + if ((flags & GGML_TENSOR_FLAG_PARAM)) { + const size_t size = ggml_nbytes(tensor); + + fwrite(tensor->data, sizeof(char), size, fout); + } } } @@ -19335,10 +19813,12 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * { uint32_t type; uint32_t op; + int32_t flags; for (uint32_t i = 0; i < n_leafs; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); + flags = *(const int32_t *) ptr; ptr += sizeof(flags); int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; @@ -19356,20 +19836,19 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne); - tensor->op = (enum ggml_op) op; + tensor->op = (enum ggml_op) op; + tensor->flags = flags; memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; - tensor->data = (void *) ptr; - for (int j = 0; j < GGML_MAX_DIMS; ++j) { tensor->nb[j] = nb[j]; } - result->leafs[i] = tensor; + tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor); - ptr += ggml_nbytes(tensor); + result->leafs[i] = tensor; fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); } @@ -19381,10 +19860,12 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * { uint32_t type; uint32_t op; + int32_t flags; for (uint32_t i = 0; i < n_nodes; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); + flags = *(const int32_t *) ptr; ptr += sizeof(flags); enum ggml_op eop = (enum ggml_op) op; @@ -19474,6 +19955,11 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * result->nodes[i] = tensor; + // TODO tensor data is be duplicated due to ggml_new_tensor call above + if (flags & GGML_TENSOR_FLAG_PARAM) { + tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor); + } + fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); } } @@ -19742,6 +20228,7 @@ static enum ggml_opt_result ggml_opt_adam( ggml_opt_callback callback, void * callback_data) { GGML_ASSERT(ggml_is_scalar(f)); + GGML_ASSERT(f->type == GGML_TYPE_F32); // these will store the parameters we want to optimize struct ggml_tensor * ps[GGML_MAX_PARAMS]; @@ -20508,6 +20995,8 @@ enum ggml_opt_result ggml_opt( struct ggml_context * ctx, struct ggml_opt_params params, struct ggml_tensor * f) { + GGML_ASSERT(f->grad && "ggml_set_param called for at least one parent tensor."); + bool free_ctx = false; if (ctx == NULL) { struct ggml_init_params params_ctx = { @@ -20562,6 +21051,8 @@ enum ggml_opt_result ggml_opt_resume_g( ggml_opt_callback callback, void * callback_data) { + GGML_ASSERT(f->grad && "ggml_set_param must be called for at least one ancestor"); + // build forward + backward compute graphs enum ggml_opt_result result = GGML_OPT_RESULT_OK; @@ -21148,7 +21639,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p (int64_t) info->ne[2] * (int64_t) info->ne[3]; - if (ne % ggml_blck_size(info->type) != 0) { + if (ggml_blck_size(info->type) == 0 || ne % ggml_blck_size(info->type) != 0) { fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); fclose(file); @@ -21649,6 +22140,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { void gguf_add_tensor( struct gguf_context * ctx, const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); if (gguf_find_tensor(ctx, tensor->name) != -1) { GGML_ABORT("duplicated tensor name"); } diff --git a/ggml/src/kompute-shaders/op_rope_f16.comp b/ggml/src/kompute-shaders/op_rope_f16.comp index 1a4058b3f1f10..0ecfb2eab527c 100644 --- a/ggml/src/kompute-shaders/op_rope_f16.comp +++ b/ggml/src/kompute-shaders/op_rope_f16.comp @@ -11,7 +11,7 @@ void main() { const uint i2 = gl_WorkGroupID.y; const uint i1 = gl_WorkGroupID.x; - const bool is_neox = (pcs.mode & 2) != 0; + const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0; float corr_dims[2]; rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); diff --git a/ggml/src/kompute-shaders/op_rope_f32.comp b/ggml/src/kompute-shaders/op_rope_f32.comp index 65e03827a2660..cec0fd9a5d10c 100644 --- a/ggml/src/kompute-shaders/op_rope_f32.comp +++ b/ggml/src/kompute-shaders/op_rope_f32.comp @@ -11,7 +11,7 @@ void main() { const uint i2 = gl_WorkGroupID.y; const uint i1 = gl_WorkGroupID.x; - const bool is_neox = (pcs.mode & 2) != 0; + const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0; float corr_dims[2]; rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); diff --git a/ggml/src/kompute-shaders/rope_common.comp b/ggml/src/kompute-shaders/rope_common.comp index 7b9394cb2fffc..df4702896d46f 100644 --- a/ggml/src/kompute-shaders/rope_common.comp +++ b/ggml/src/kompute-shaders/rope_common.comp @@ -1,5 +1,7 @@ #include "common.comp" +#define GGML_ROPE_TYPE_NEOX 2 + // TODO: use a local size of 32 or more (Metal uses 1024) layout(local_size_x = 1) in; diff --git a/ggml/src/vulkan-shaders/acc.comp b/ggml/src/vulkan-shaders/acc.comp new file mode 100644 index 0000000000000..4c8739efee227 --- /dev/null +++ b/ggml/src/vulkan-shaders/acc.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)])); + } +} + diff --git a/ggml/src/vulkan-shaders/concat.comp b/ggml/src/vulkan-shaders/concat.comp index 08ab5514bfb49..c23b6eb1b0cd5 100644 --- a/ggml/src/vulkan-shaders/concat.comp +++ b/ggml/src/vulkan-shaders/concat.comp @@ -30,6 +30,10 @@ void main() { #ifndef OPTIMIZATION_ERROR_WORKAROUND data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]); #else - data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx]; + if (is_src0) { + data_d[p.d_offset + dst_idx] = data_a[src0_idx]; + } else { + data_d[p.d_offset + dst_idx] = data_b[src1_idx]; + } #endif } diff --git a/ggml/src/vulkan-shaders/cos.comp b/ggml/src/vulkan-shaders/cos.comp new file mode 100644 index 0000000000000..f9a858cbf16ce --- /dev/null +++ b/ggml/src/vulkan-shaders/cos.comp @@ -0,0 +1,15 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec.comp b/ggml/src/vulkan-shaders/mul_mat_vec.comp index 46a6369bcfd20..d3ccba7fcb3fd 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec.comp @@ -39,8 +39,7 @@ void main() { vec2 v = dequantize(ib, iqs, a_offset / QUANT_K); // matrix multiplication - tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) + - FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]); + tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid])); } // sum up partial sums and write back result diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp index cb3f3c0df0801..1cc4996d393a2 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp @@ -53,7 +53,7 @@ void main() { const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - tmp[tid] += xi * FLOAT_TYPE(data_b[iy]); + tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); } // sum up partial sums and write back result diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp index 4b1871caaf4c2..9b443807d8781 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp @@ -52,7 +52,7 @@ void main() { // y is not transposed but permuted const uint iy = channel*nrows_y + row_y; - tmp[tid] += xi * FLOAT_TYPE(data_b[iy]); + tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); } // dst is not transposed and not permuted diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp index 4cd97799df196..ec8eadcd5828a 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -39,24 +39,25 @@ void main() { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3); - sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF); + sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1)))))))); + sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2)))))))); } - tmp[16 * ix + tid] += dall * sum1 - dmin * sum2; + const uint tmp_idx = 16 * ix + tid; + tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx])); } // sum up partial sums and write back result diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp index a6e430ea08c5d..3ca4ad85a5ca0 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -40,16 +40,17 @@ void main() { FLOAT_TYPE sum = FLOAT_TYPE(0.0); for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)); + sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); } - tmp[16 * ix + tid] += d * sum; + const uint tmp_idx = 16 * ix + tid; + tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]); } // sum up partial sums and write back result diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp index 75569363c64a9..d91e00e10061a 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -67,17 +67,17 @@ void main() { const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4); const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3); - const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7); - const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11); - const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15); - const FLOAT_TYPE smin = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7 - ); - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); + const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7, + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7, + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7))))))))))))))); + const uint tmp_idx = 16 * ix + tid; + tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); #else const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); @@ -88,16 +88,19 @@ void main() { const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); - const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); - const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); - const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); - const FLOAT_TYPE smin = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7 - ); - - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); + const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, + + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7))))))); + + tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); + const uint tmp_idx = 16 * ix + tid; + tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f), + fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx])); #endif } diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp index 9be3645bdea0e..2306785af4226 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -66,35 +66,33 @@ void main() { const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4); const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx ]) * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE sy = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE sz = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y2_idx ]) * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE sw = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE smin = FLOAT_TYPE( - (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3 - + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7 - ); - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)), + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)), + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)), + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)), + fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)), + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2, + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3, + fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6, + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7))); + const uint tmp_idx = 16 * ix + tid; + tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); } // sum up partial sums and write back result diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp index d610cf0306b0a..95c286eeb17e1 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -44,22 +44,22 @@ void main() { const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); #if K_QUANTS_PER_ITERATION == 1 - FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); - tmp[16 * ix + tid] += sum; + const uint tmp_idx = 16 * ix + tid; + tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx])))))))); #else FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 4; ++l) { - sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); + sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32), + fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum)))); } tmp[16 * ix + tid] += sum; #endif diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp index 5fe9d5241381e..fffdd18189d55 100644 --- a/ggml/src/vulkan-shaders/mul_mm.comp +++ b/ggml/src/vulkan-shaders/mul_mm.comp @@ -326,10 +326,10 @@ void main() { mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); } const float d = loadd.x * sc; - const float m = loadd.y * mbyte; + const float m = -loadd.y * mbyte; - buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m); - buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m); + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; @@ -357,10 +357,10 @@ void main() { mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); } const float d = loadd.x * sc; - const float m = loadd.y * mbyte; + const float m = -loadd.y * mbyte; - buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m); - buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m); + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; @@ -463,7 +463,8 @@ void main() { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { - sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]); + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]); } } } diff --git a/ggml/src/vulkan-shaders/repeat.comp b/ggml/src/vulkan-shaders/repeat.comp new file mode 100644 index 0000000000000..a86af87e7b7f9 --- /dev/null +++ b/ggml/src/vulkan-shaders/repeat.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +uint src0_idx_mod(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; +} + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]); +} diff --git a/ggml/src/vulkan-shaders/sin.comp b/ggml/src/vulkan-shaders/sin.comp new file mode 100644 index 0000000000000..7faf9be9362bf --- /dev/null +++ b/ggml/src/vulkan-shaders/sin.comp @@ -0,0 +1,15 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index a792e203b273a..0c5b7b2794ad0 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -368,6 +368,10 @@ void process_shaders(std::vector>& tasks) { string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); })); @@ -380,6 +384,10 @@ void process_shaders(std::vector>& tasks) { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); @@ -388,6 +396,14 @@ void process_shaders(std::vector>& tasks) { string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + })); + + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f63ec450a4e09..b55effa9907b1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -130,6 +130,7 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class Tokenizer: MODEL = "tokenizer.ggml.model" @@ -219,6 +220,8 @@ class MODEL_ARCH(IntEnum): T5 = auto() T5ENCODER = auto() JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() class MODEL_TENSOR(IntEnum): @@ -347,6 +350,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1065,6 +1070,37 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.EXAONE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -1105,6 +1141,10 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.CHATGLM: [ MODEL_TENSOR.ROPE_FREQS, ], + MODEL_ARCH.NEMOTRON: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # @@ -1333,6 +1373,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 76385a82872c9..af3b98c679b0b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: + self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 9aa2209e2f69f..a4f185c0658a3 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -10,10 +10,10 @@ class TensorNameMap: # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais + "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf + "model.embed_tokens", # llama-hf nemotron "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -52,7 +52,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -62,7 +62,7 @@ class TensorNameMap: # Output norm MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox - "transformer.ln_f", # gpt2 gpt-j falcon jais + "transformer.ln_f", # gpt2 gpt-j falcon jais exaone "model.norm", # llama-hf baichuan internlm2 "norm", # llama-pth "transformer.norm_f", # mpt dbrx @@ -75,6 +75,7 @@ class TensorNameMap: "transformer.rms_norm", # Grok "encoder.final_layernorm", # chatglm "transformer.norm", # openelm + "model.norm", # nemotron ), # Rope frequencies @@ -88,12 +89,12 @@ class TensorNameMap: # Attention norm MODEL_TENSOR.ATTN_NORM: ( "gpt_neox.layers.{bid}.input_layernorm", # gptneox - "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone "transformer.blocks.{bid}.norm_1", # mpt "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf + "model.layers.{bid}.input_layernorm", # llama-hf nemotron "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -135,18 +136,19 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok + "transformer.h.{bid}.attn.attention.q_proj", # exaone ), # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert "transformer.h.{bid}.attn.k_proj", # gpt-j @@ -154,18 +156,20 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok + "transformer.h.{bid}.attn.attention.k_proj", # exaone ), # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v", # refact "model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.{bid}.attention.wv", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok + "transformer.h.{bid}.attn.attention.v_proj", # exaone ), # Attention output @@ -175,7 +179,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j @@ -190,6 +194,7 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm + "transformer.h.{bid}.attn.attention.out_proj", # exaone ), # Attention output norm @@ -215,10 +220,10 @@ class TensorNameMap: # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox - "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais + "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi @@ -258,7 +263,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon "h.{bid}.mlp.dense_h_to_4h", # bloom - "model.layers.{bid}.mlp.up_proj", # llama-hf refact + "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j @@ -277,6 +282,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm + "transformer.h.{bid}.mlp.c_fc_1", # exaone ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -308,6 +314,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic + "transformer.h.{bid}.mlp.c_fc_0", # exaone ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -329,7 +336,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "h.{bid}.mlp.dense_4h_to_h", # bloom - "model.layers.{bid}.mlp.down_proj", # llama-hf + "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j @@ -347,6 +354,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm + "model.layers.h.{bid}.mlp.c_proj", # exaone ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 19f6761e2f912..eea381e5a6b92 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.9.1" +version = "0.10.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/include/llama.h b/include/llama.h index be0b7d4e83e48..20d49383eb4bf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -93,15 +93,15 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, }; - // note: these values should be synchronized with ggml_rope - // TODO: maybe move this enum to ggml.h (ggml_rope_type) enum llama_rope_type { LLAMA_ROPE_TYPE_NONE = -1, - LLAMA_ROPE_TYPE_NORM = 0, - LLAMA_ROPE_TYPE_NEOX = 2, - LLAMA_ROPE_TYPE_GLM = 4, + LLAMA_ROPE_TYPE_NORM = 0, + LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, }; enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file @@ -511,6 +511,9 @@ extern "C" { // to the decoder to start generating output sequence. For other models, it returns -1. LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); + // Returns true if the model is recurrent (like Mamba, RWKV, etc.) + LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -915,11 +918,8 @@ extern "C" { LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding - // Returns -1 if unknown, 1 for true or 0 for false. - LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model); - - // Returns -1 if unknown, 1 for true or 0 for false. - LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); + LLAMA_API bool llama_add_bos_token(const struct llama_model * model); + LLAMA_API bool llama_add_eos_token(const struct llama_model * model); // Codellama infill tokens LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index eef6768b149db..1e6db754fe68a 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -797faa25af14126eb30134d4033139ae3c5428ed +28b7633d733bbeef0026570fbc61c79c5e9aa5ae diff --git a/src/llama-impl.h b/src/llama-impl.h index 399b134a7f9bc..9527740961da6 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -31,11 +31,17 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d41218c702360..843a470c7b99f 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -85,14 +85,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra constexpr float bucket_low = -10.0f; constexpr float bucket_high = 10.0f; constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); - constexpr float bucker_inter = -bucket_low * bucket_scale; + constexpr float bucket_inter = -bucket_low * bucket_scale; std::vector bucket_idx(candidates->size); std::vector histo(nbuckets, 0); for (int i = 0; i < (int)candidates->size; ++i) { const float val = candidates->data[i].logit; - int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); ib = std::max(0, std::min(nbuckets-1, ib)); bucket_idx[i] = ib; ++histo[ib]; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 749f8571829df..323660ef54cb0 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -321,6 +321,21 @@ struct llm_tokenizer_spm { // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused +template, typename Compare = std::less> +class llama_priority_queue : public std::priority_queue { +public: + using std::priority_queue::priority_queue; + + T pop_move() { + T item = std::move(this->c.front()); + std::pop_heap(this->c.begin(), this->c.end(), this->comp); + this->c.pop_back(); + return item; + } + + void pop() = delete; +}; + struct llm_bigram_bpe { struct comparator { bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { @@ -329,7 +344,7 @@ struct llm_bigram_bpe { }; using queue_storage = std::vector; - using queue = std::priority_queue; + using queue = llama_priority_queue; llm_symbol::index left; llm_symbol::index right; std::string text; @@ -388,6 +403,7 @@ struct llm_tokenizer_bpe { case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_SMOLLM: case LLAMA_VOCAB_PRE_TYPE_CODESHELL: + case LLAMA_VOCAB_PRE_TYPE_EXAONE: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -410,6 +426,8 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_PORO: + case LLAMA_VOCAB_PRE_TYPE_BLOOM: + case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: regex_exprs = { " ?[^(\\s|.,!?…。,、।۔،)]+", }; @@ -517,8 +535,7 @@ struct llm_tokenizer_bpe { // build token(s) while (!work_queue.empty()) { - auto bigram = work_queue.top(); - work_queue.pop(); + auto bigram = work_queue.pop_move(); auto & left_symbol = symbols[bigram.left]; auto & right_symbol = symbols[bigram.right]; @@ -1466,11 +1483,11 @@ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) { return vocab.special_pad_id; } -int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) { +bool llama_add_bos_token_impl(const struct llama_vocab & vocab) { return vocab.tokenizer_add_bos; } -int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) { +bool llama_add_eos_token_impl(const struct llama_vocab & vocab) { return vocab.tokenizer_add_eos; } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 7adfc16da3af3..6e8f30be43ba1 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -95,8 +95,8 @@ llama_token llama_token_sep_impl(const struct llama_vocab & vocab); llama_token llama_token_nl_impl (const struct llama_vocab & vocab); llama_token llama_token_pad_impl(const struct llama_vocab & vocab); -int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab); -int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab); +bool llama_add_bos_token_impl(const struct llama_vocab & vocab); +bool llama_add_eos_token_impl(const struct llama_vocab & vocab); llama_token llama_token_prefix_impl(const struct llama_vocab & vocab); llama_token llama_token_middle_impl(const struct llama_vocab & vocab); diff --git a/src/llama.cpp b/src/llama.cpp index 5878c0271f245..85751dc1ba123 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -210,6 +210,8 @@ enum llm_arch { LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON, + LLM_ARCH_EXAONE, LLM_ARCH_UNKNOWN, }; @@ -255,6 +257,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -324,6 +328,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_DT_B_C_RMS, LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_PRE, @@ -422,6 +427,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -1296,6 +1302,43 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_NEMOTRON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_EXAONE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2196,6 +2239,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -2245,6 +2289,7 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -2471,10 +2516,29 @@ struct llama_layer { struct ggml_tensor * ffn_down_scale; }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + // TODO: whole_seqs for embeddings? + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = 0; // used by recurrent state models to copy states + int32_t src = -1; // used by recurrent state models to copy states + int32_t tail = -1; std::set seq_id; @@ -2495,7 +2559,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool do_copy = false; bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token bool v_trans = true; // the value tensor is transposed @@ -2658,6 +2721,340 @@ struct llama_model { } }; +struct llama_sbatch_seq { + int32_t n_seq_id; + llama_seq_id * seq_id; + size_t offset; + size_t length; + + // helper for smoother batch API transition -- can be deprecated in the future + llama_seq_id all_seq_id; // used if seq_id == NULL +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + bool logits_all; // TODO: remove once lctx.logits_all is removed too + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + std::vector out_ids; + std::vector seq; + const llama_batch * batch = nullptr; + + // buffers for the ubatch + std::vector ubatch_token; + std::vector ubatch_embd; + std::vector ubatch_pos; + std::vector ubatch_n_seq_id; + std::vector ubatch_seq_id; + std::vector ubatch_output; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + llama_ubatch ubatch = { + /*equal_seqs =*/ true, + /*n_tokens =*/ 0, + /*n_seq_tokens =*/ 0, + /*n_seqs =*/ 0, + /*token =*/ !has_embd ? ubatch_token.data() : nullptr, + /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, + /*pos =*/ ubatch_pos.data(), + /*n_seq_id =*/ ubatch_n_seq_id.data(), + /*seq_id =*/ ubatch_seq_id.data(), + /*output =*/ ubatch_output.data(), + }; + return ubatch; + } + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + (n_embd * seq.offset); + } + } else { + ubatch.embd = nullptr; + } + // from here on, the else branches are deprecated; + // they are helpers for smoother batch API transition + if (batch->pos) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; + } + } else { + for (size_t i = 0; i < length; ++i) { + llama_pos bi = ids[seq.offset + i]; + ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + } + } + if (ubatch.equal_seqs) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } else { + GGML_ASSERT(seq.n_seq_id == 1); + ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; + } + } else { + // simple split + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id = batch->n_seq_id + seq.offset; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } + } + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id = batch->seq_id + seq.offset; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; + } + } + } + if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else if (batch->logits) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; + } + ubatch.n_tokens += length; + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); + } + + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; + } + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + this->logits_all = logits_all; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + if (simple_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + s.all_seq_id = batch.all_seq_id; + return; + } + std::sort(ids.begin(), ids.end(), + [&batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id (assuming batch.all_pos_1 is positive) + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + // init seq + llama_sbatch_seq * last_seq = nullptr; + + if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; + seq.push_back(new_seq); + last_seq = &seq.back(); + } + } else { + llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + seq.push_back(new_seq); + } + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); + } +}; + struct llama_context { llama_context(const llama_model & model) : model(model) @@ -2679,6 +3076,7 @@ struct llama_context { struct llama_cparams cparams; struct llama_sampling sampling; + struct llama_sbatch sbatch; struct llama_kv_cache kv_self; struct llama_control_vector cvec; @@ -2939,8 +3337,7 @@ static bool llama_kv_cache_init( cache.has_shift = false; - // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.recurrent = llama_model_is_recurrent(&model); cache.v_trans = !cache.recurrent && !cparams.flash_attn; cache.head = 0; @@ -2953,13 +3350,6 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(kv_size); - if (cache.recurrent) { - // init state copy sources - for (uint32_t i = 0; i < cache.size; ++i) { - cache.cells[i].src = i; - } - } - // count used buffer types std::map buft_layer_count; if (offload) { @@ -3027,46 +3417,162 @@ static bool llama_kv_cache_init( // to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_batch & batch) { + const struct llama_ubatch & batch) { const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_seqs = batch.n_seqs; + const uint32_t n_seq_tokens = batch.n_seq_tokens; if (cache.recurrent) { // For recurrent state architectures (like Mamba), - // each KV cache cell can store the state for a whole sequence. - - llama_seq_id min = cache.size - 1; - llama_seq_id max = 0; - - for (uint32_t i = 0; i < n_tokens; ++i) { - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; - // make sure it's a valid seq_id - if ((uint32_t) seq_id < cache.size) { - if (seq_id > max) { - max = seq_id; - } - if (seq_id < min) { - min = seq_id; + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(batch.equal_seqs); + + int32_t min = cache.size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = batch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + return false; + } + if (j > 0) { + llama_kv_cell & seq = cache.cells[seq_id]; + if (seq.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + cache.used -= 1; + } } - // Assuming the tokens are in-order - if (batch.pos[i] != cache.cells[seq_id].pos + 1) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(cache.size, -1); + for (uint32_t i = 0; i < cache.size; ++i) { + llama_kv_cell & cell = cache.cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); } - if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.used += 1; + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < cache.size; ++i) { + if (tails_verif[i] != cache.cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = cache.head; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + llama_kv_cell & seq_meta = cache.cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_kv_cell & empty_cell = cache.cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_kv_cell & orig_cell = cache.cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; } - cache.cells[seq_id].pos = batch.pos[i]; - // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set - } else { - // too big seq_id - // TODO: would it be possible to resize the KV cache size instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); - return false; } } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_kv_cell & dst_cell = cache.cells[dst_id]; + llama_kv_cell & src_cell = cache.cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cache.cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cache.cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_kv_cell & cell = cache.cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cache.cells[seq_id].tail = cell_id; + } } // allow getting the range of used cells, from head to head + n @@ -3074,7 +3580,7 @@ static bool llama_kv_cache_find_slot( cache.n = max - min + 1; // sanity check - return max >= min; + return cache.n >= n_seqs; } // otherwise, one cell per token. @@ -3112,11 +3618,14 @@ static bool llama_kv_cache_find_slot( } } - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.cells[cache.head + k].pos = batch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + } } } @@ -3142,6 +3651,8 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { for (int32_t i = 0; i < (int32_t) cache.size; ++i) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + cache.cells[i].src = -1; + cache.cells[i].tail = -1; } cache.head = 0; cache.used = 0; @@ -3168,9 +3679,16 @@ static bool llama_kv_cache_seq_rm( return false; } if (0 <= seq_id) { - // partial intersection is invalid - if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { - return false; + int32_t & tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + const llama_kv_cell & cell = cache.cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + if (p0 <= cell.pos && p1 < cell.pos) { + tail_id = -1; + } } } else { // seq_id is negative, then the range should include everything or nothing @@ -3194,6 +3712,7 @@ static bool llama_kv_cache_seq_rm( if (cache.cells[i].pos >= 0) cache.used--; cache.cells[i].pos = -1; + cache.cells[i].src = -1; if (new_head == cache.size) new_head = i; } } @@ -3216,23 +3735,29 @@ static void llama_kv_cache_seq_cp( if (cache.recurrent) { if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - seq_id_src = cache.cells[seq_id_src].src; - GGML_ASSERT((uint32_t) seq_id_src < cache.size); - // intent to "copy from" - // supports copy chains thanks to taking the source of the source - cache.cells[seq_id_dst].src = seq_id_src; - - // preserve the "keep or clear" status of the copied sequence - if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { - cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); - } else { - cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + llama_kv_cell & tail_src = cache.cells[seq_id_src]; + llama_kv_cell & tail_dst = cache.cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + llama_kv_cell & cell_dst = cache.cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.delta = -1; + cell_dst.src = -1; + cache.used -= 1; + } } + if (tail_src.tail >= 0) { + llama_kv_cell & cell_src = cache.cells[tail_src.tail]; - cache.do_copy = true; - - cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } } + return; } // otherwise, this is the KV cache of a Transformer-like model @@ -3250,9 +3775,13 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id uint32_t new_head = cache.size; for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.recurrent && (llama_seq_id) i != seq_id) { + cache.cells[i].tail = -1; + } if (!cache.cells[i].has_seq_id(seq_id)) { if (cache.cells[i].pos >= 0) cache.used--; cache.cells[i].pos = -1; + cache.cells[i].src = -1; cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; } else { @@ -3281,9 +3810,12 @@ static void llama_kv_cache_seq_add( if (cache.recurrent) { // for Mamba-like models, only the pos needs to be shifted if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } } } return; @@ -3327,9 +3859,12 @@ static void llama_kv_cache_seq_div( if (cache.recurrent) { // for Mamba-like models, only the pos needs to be changed if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } } } return; @@ -3361,7 +3896,9 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama } static void llama_kv_cache_defrag(struct llama_kv_cache & cache) { - cache.do_defrag = true; + if (!cache.recurrent) { + cache.do_defrag = true; + } } static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) { @@ -3575,13 +4112,8 @@ namespace GGUFMeta { using llama_buf_map = std::unordered_map; -// TODO: update when needed or think of some clever automatic way to do this -static size_t llama_model_max_nodes(const llama_model & /*model*/) { - //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B - // return 32768; - //} - - return 8192; +static size_t llama_model_max_nodes(const llama_model & model) { + return std::max(8192, model.tensors_by_name.size()*5); } struct llama_model_loader { @@ -5016,6 +5548,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5240,6 +5773,23 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_EXAONE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5472,6 +6022,15 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "codeshell") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "bloom") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM; + } else if ( + tokenizer_pre == "gpt3-finnish") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; + } else if ( + tokenizer_pre == "exaone") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -5845,6 +6404,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); @@ -6045,6 +6605,7 @@ static bool llm_load_tensors( const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab_type = hparams.n_vocab_type; + const int64_t n_rot = hparams.n_rot; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -6102,7 +6663,7 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); if (n_expert == 0) { layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); @@ -6110,9 +6671,9 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional MLP bias - layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); @@ -6436,7 +6997,7 @@ static bool llm_load_tensors( layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens - layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); @@ -7555,8 +8116,8 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); @@ -7567,6 +8128,78 @@ static bool llm_load_tensors( layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); } } break; + case LLM_ARCH_NEMOTRON: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_EXAONE: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7808,7 +8441,7 @@ static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, struct llama_context & lctx, const llama_hparams & hparams, - const llama_batch & batch, + const llama_ubatch & batch, struct ggml_tensor * tok_embd, const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; @@ -8242,9 +8875,10 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } @@ -8253,7 +8887,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -8357,19 +8991,187 @@ static struct ggml_tensor * llm_build_kv( return cur; } -struct llm_build_context { - const llama_model & model; - llama_context & lctx; - const llama_hparams & hparams; - const llama_cparams & cparams; - const llama_batch & batch; - const llama_kv_cache & kv_self; +static struct ggml_tensor * llm_build_copy_mask_state( + struct ggml_context * ctx, + struct ggml_cgraph * graph, + struct ggml_tensor * s, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t n_state, + int32_t kv_size, + int32_t kv_head, + int32_t n_kv, + int32_t n_seqs) { + struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); + + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, state_copy); + + // clear states of sequences which are starting at the beginning of this batch + // FIXME: zero-out NANs? + states = ggml_mul(ctx, states, state_mask); + + // copy states which won't be changed further (between n_seqs and n_rs) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + + // the part of the states that will be used and modified + return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); +} - const int64_t n_embd; - const int64_t n_layer; - const int64_t n_rot; - const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) - const int64_t n_head; +// TODO: split +static struct ggml_tensor * llm_build_mamba( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_ubatch & batch, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t kv_head, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = model.hparams; + const llama_kv_cache & kv = lctx.kv_self; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = batch.n_seqs; + // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) + const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; + // Use the same RMS norm as the final layer norm + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = kv.k_l[il]; + struct ggml_tensor * ssm_states_all = kv.v_l[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, + graph, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); + struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, + graph, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + + // bias + x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); + // split + struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + + // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms) { + dt = ggml_rms_norm(ctx, dt, norm_rms_eps); + B = ggml_rms_norm(ctx, B, norm_rms_eps); + C = ggml_rms_norm(ctx, C, norm_rms_eps); + } + + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); + + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + + // TODO: skip computing output earlier for unused tokens + + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; +} + +struct llm_build_context { + const llama_model & model; + llama_context & lctx; + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_ubatch & batch; + const llama_kv_cache & kv_self; + + const int64_t n_embd; + const int64_t n_layer; + const int64_t n_rot; + const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) + const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; const int64_t n_embd_k_gqa; @@ -8408,7 +9210,7 @@ struct llm_build_context { // TODO: consider making the entire interface noexcept llm_build_context( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, const llm_build_cb & cb, bool worst_case) : model (lctx.model), @@ -8515,29 +9317,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - - GGML_ASSERT(kv_self.recurrent); - - struct ggml_tensor * state_copy = build_inp_s_copy(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - - // TODO: name the intermediate tensors with cb() - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); - } - - return gf; - } - struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -8672,7 +9451,7 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; @@ -8685,13 +9464,6 @@ struct llm_build_context { return lctx.inp_s_mask; } - struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); - cb(lctx.inp_s_seq, "inp_s_seq", -1); - ggml_set_input(lctx.inp_s_seq); - return lctx.inp_s_seq; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -12021,125 +12793,31 @@ struct llm_build_context { struct ggml_cgraph * build_mamba() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - const int64_t d_model = n_embd; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - GGML_ASSERT(2 * d_model == d_inner); - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - struct ggml_tensor * cur; struct ggml_tensor * inpL; // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - // (ab)using the KV cache to store the states - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - // clear states of sequences which are starting at the beginning of this batch - { - conv_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), - state_mask); - ssm_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), - state_mask); - } - - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); - // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur); - // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); - - // conv - { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); - - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - - // extract x from x_conv - x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); - - // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); - - x = ggml_silu(ctx0, x); - } - - // ssm - { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x); - // split - struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); - - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} - dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, - // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); - - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states)))); - - struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - x = ggml_get_rows(ctx0, x, inp_out_ids); - y = ggml_get_rows(ctx0, y, inp_out_ids); - z = ggml_get_rows(ctx0, z, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } // residual @@ -13754,28 +14432,259 @@ struct llm_build_context { return gf; } -}; -static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { - llama_batch dummy; - dummy.n_tokens = 0; + struct ggml_cgraph * build_nemotron() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); - struct llm_build_context llm(lctx, dummy, cb, false); + struct ggml_tensor * cur; + struct ggml_tensor * inpL; - llm.init(); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - struct ggml_cgraph * result = llm.build_defrag(ids); + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); - llm.free(); + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - return result; -} + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; -static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_exaone() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } +}; + +static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -13783,16 +14692,16 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { llm.init(); - struct ggml_cgraph * result = llm.build_k_shift(); + struct ggml_cgraph * result = llm.build_defrag(ids); llm.free(); return result; } -static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; +static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -13800,7 +14709,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { llm.init(); - struct ggml_cgraph * result = llm.build_s_copy(); + struct ggml_cgraph * result = llm.build_k_shift(); llm.free(); @@ -13809,7 +14718,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, bool worst_case) { const auto & model = lctx.model; @@ -14009,6 +14918,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_jais(); } break; + case LLM_ARCH_NEMOTRON: + { + result = llm.build_nemotron(); + } break; + case LLM_ARCH_EXAONE: + { + result = llm.build_exaone(); + } break; default: GGML_ABORT("fatal error"); } @@ -14071,7 +14988,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t return relative_bucket; } -static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { +static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // // set input data // @@ -14110,10 +15027,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } - } else if (batch.logits) { + } else if (batch.output) { int32_t n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - if (batch.logits[i]) { + if (batch.output[i]) { data[n_outputs++] = i; } } @@ -14137,8 +15054,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn && !lctx.is_encoding) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; float * data = nullptr; @@ -14158,32 +15077,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; - for (int i = 0; i < n_kv; ++i) { - float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(lctx.kv_self.cells[i].pos - pos); + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = batch.pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -std::abs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } - } - if (data) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; - } + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } - data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } } @@ -14205,8 +15127,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } else { + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -14214,27 +15138,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; - - for (int i = 0; i < n_tokens; ++i) { - float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(batch.pos[i] - batch.pos[j]); - } else { - f = 0.0f; + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = batch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < batch.n_seq_id[s0]; ++s) { + if (batch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(batch.pos[ti] - batch.pos[tj]); + } else { + f = 0.0f; + } + break; + } } - break; + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; } } - data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } } } } @@ -14242,7 +15174,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_mean); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -14251,12 +15185,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); std::vector sum(n_tokens, 0); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - sum[seq_id] += 1; + sum[seq_id] += batch.n_seq_tokens; } std::vector div(n_tokens, 0.0f); @@ -14267,14 +15203,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - data[seq_id*n_tokens + i] = div[seq_id]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } } } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -14282,20 +15223,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { uint32_t * data = (uint32_t *) lctx.inp_cls->data; memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); - if (pos == 0) { - data[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } } } } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -14306,15 +15253,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { std::vector last_pos(n_tokens, -1); std::vector last_row(n_tokens, -1); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); - if (pos >= last_pos[seq_id]) { - last_pos[seq_id] = pos; - last_row[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } } } @@ -14332,41 +15283,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left untouched + // clear unused states for (int i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - data[i] = (float) has_self_seq; + data[i] = (float) (kv_cell.src >= 0); - // ensure current sequences will be kept - if (!has_self_seq && kv_cell.pos >= 0) { - kv_cell.seq_id.insert(seq_id); + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; } } } - // For Mamba (and other recurrent architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Like with the KQ_mask, if a token in the batch has multiple sequences, - // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). - if (lctx.inp_s_seq) { - const int64_t n_tokens = batch.n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); - int32_t * data = (int32_t *) lctx.inp_s_seq->data; + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; - for (int j = 0; j < n_tokens; ++j) { - const int32_t n_seq = batch.n_seq_id[j]; - GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - for (int i = 0; i < n_kv; ++i) { - if (i < n_seq) { - // for this type of model, the head is the minimum seq_id of the batch - data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; - } else { - data[j*n_kv + i] = -1; - } + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + kv_cell.src = cell_id; + } + + data[i] = kv_cell.src; + + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; } } } @@ -14376,6 +15325,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; @@ -14411,6 +15361,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); + GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing float * data = (float *) lctx.inp_KQ_mask_cross->data; @@ -14504,6 +15455,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { return n_outputs_max; } +// make the outputs have the same order they had in the user-provided batch +static void llama_output_reorder(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (ctx->logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } + } + if (ctx->embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} static void llama_graph_compute( llama_context & lctx, @@ -14576,15 +15564,11 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; - // TODO: simplify or deprecate - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + lctx.embd_seq.clear(); + // count outputs if (batch_all.logits && !embd_pooled) { for (uint32_t i = 0; i < n_tokens_all; ++i) { @@ -14597,55 +15581,42 @@ static int llama_decode_internal( n_outputs = 1; } + lctx.sbatch.from_batch(batch_all, n_embd, + /* simple_split */ !kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); + // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); return -2; }; - // set output mappings - if (batch_all.logits) { - int32_t i_logits = 0; - for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch_all.logits[i]) { - lctx.output_ids[i] = i_logits++; + while (lctx.sbatch.n_tokens > 0) { + llama_ubatch ubatch; + if (kv_self.recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = lctx.sbatch.split_seq(n_ubatch); + } else { + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = lctx.sbatch.split_equal(n_ubatch); } + } else { + ubatch = lctx.sbatch.split_simple(n_ubatch); } - } else { - for (uint32_t i = 0; i < n_outputs; ++i) { - lctx.output_ids[i] = i; - } - } - - for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { - const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - llama_batch u_batch = { - /* .n_tokens = */ (int32_t) n_tokens, - /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, - /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, - /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, - /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, - /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, - /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, - /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, - /* .all_pos_1 = */ batch_all.all_pos_1, - /* .all_seq_id = */ batch_all.all_seq_id, - }; + const uint32_t n_tokens = ubatch.n_tokens; // count the outputs in this u_batch { int32_t n_outputs_new = 0; - if (u_batch.logits && !embd_pooled) { - for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.logits[i] != 0; - } - } else if (n_outputs == n_tokens_all) { + if (n_outputs == n_tokens_all) { n_outputs_new = n_tokens; } else { - // keep last output only - if (cur_token + n_tokens >= n_tokens_all) { - n_outputs_new = 1; + GGML_ASSERT(ubatch.output); + for (uint32_t i = 0; i < n_tokens; i++) { + n_outputs_new += (int32_t) (ubatch.output[i] != 0); } } @@ -14656,32 +15627,6 @@ static int llama_decode_internal( int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT(n_threads > 0); - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - if (u_batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1; - } - - u_batch.pos = pos.data(); - } - - if (u_batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = u_batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - - u_batch.n_seq_id = n_seq_id.data(); - u_batch.seq_id = seq_id_arr.data(); - } - // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); @@ -14692,7 +15637,7 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_kv_cache_find_slot(kv_self, ubatch)) { return 1; } @@ -14711,7 +15656,7 @@ static int llama_decode_internal( ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); - ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false); + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false); // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; @@ -14739,7 +15684,7 @@ static int llama_decode_internal( ggml_backend_sched_alloc_graph(lctx.sched, gf); - llama_set_inputs(lctx, u_batch); + llama_set_inputs(lctx, ubatch); llama_graph_compute(lctx, gf, n_threads); @@ -14797,12 +15742,11 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: { - // extract sequence embeddings + // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; - embd_seq_out.clear(); - for (uint32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = u_batch.seq_id[i][0]; + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; } @@ -14819,6 +15763,25 @@ static int llama_decode_internal( n_outputs_prev += lctx.n_outputs; } + // set output mappings + { + bool sorted_output = true; + + GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs); + + for (size_t i = 0; i < n_outputs; ++i) { + size_t out_id = lctx.sbatch.out_ids[i]; + lctx.output_ids[out_id] = i; + if (out_id != i) { + sorted_output = false; + } + } + + if (sorted_output) { + lctx.sbatch.out_ids.clear(); + } + } + // set to total number of outputs in the batch, for use in llama_get_logits_ith lctx.n_outputs = n_outputs; @@ -14883,11 +15846,9 @@ static int llama_encode_internal( const int64_t n_embd = hparams.n_embd; - // TODO: simplify or deprecate - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; + lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + + const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); // reserve output buffer if (llama_output_reserve(lctx, n_tokens) < n_tokens) { @@ -14905,36 +15866,10 @@ static int llama_encode_internal( const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT(n_threads > 0); - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - if (batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = batch.all_pos_0 + i*batch.all_pos_1; - } - - batch.pos = pos.data(); - } - - if (batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - - batch.n_seq_id = n_seq_id.data(); - batch.seq_id = seq_id_arr.data(); - } - ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); - ggml_cgraph * gf = llama_build_graph(lctx, batch, false); + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false); // the output embeddings after the final encoder normalization struct ggml_tensor * embd = nullptr; @@ -14958,7 +15893,7 @@ static int llama_encode_internal( ggml_backend_sched_alloc_graph(lctx.sched, gf); - llama_set_inputs(lctx, batch); + llama_set_inputs(lctx, ubatch); llama_graph_compute(lctx, gf, n_threads); @@ -14972,12 +15907,13 @@ static int llama_encode_internal( float * embd_out = lctx.embd_enc.data(); ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits // remember the sequence ids used during the encoding - needed for cross attention later lctx.seq_ids_enc.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { - for (int s = 0; s < batch.n_seq_id[i]; s++) { - llama_seq_id seq_id = batch.seq_id[i][s]; + for (int s = 0; s < ubatch.n_seq_id[i]; s++) { + llama_seq_id seq_id = ubatch.seq_id[i][s]; lctx.seq_ids_enc[i].insert(seq_id); } } @@ -15002,8 +15938,10 @@ static int llama_encode_internal( auto & embd_seq_out = lctx.embd_seq; embd_seq_out.clear(); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + for (uint32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_seq_id seq_id = ubatch.seq_id[i][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; } @@ -15281,32 +16219,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } - if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { - { - ggml_backend_sched_reset(lctx.sched); - - ggml_cgraph * gf = llama_build_graph_s_copy(lctx); - - ggml_backend_sched_alloc_graph(lctx.sched, gf); - - llama_set_s_copy(lctx); - - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); - - need_reserve = true; - } - - { - auto & kv_self = lctx.kv_self; - - kv_self.do_copy = false; - - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].src = i; - } - } - } - // defragment the KV cache if needed if (lctx.kv_self.do_defrag) { llama_kv_cache_defrag_internal(lctx); @@ -15320,10 +16232,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph - int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); - int n_past = lctx.cparams.n_ctx - n_tokens; + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph ggml_backend_sched_reset(lctx.sched); @@ -15715,6 +16628,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); ++qs.n_fallback; } @@ -15906,7 +16822,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // TODO: avoid hardcoded tensor names - use the TN_* constants if (name.find("attn_v.weight") != std::string::npos || - name.find("attn_qkv.weight") != std::string::npos) { + name.find("attn_qkv.weight") != std::string::npos || + name.find("attn_kv_b.weight")!= std::string::npos) { ++qs.n_attention_wv; } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { qs.has_output = true; @@ -15916,12 +16833,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks - // - // - qs.n_attention_wv == 0 for Mamba models - // - qs.n_attention_wv == model.hparams.n_layer for Transformer models - // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models - // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + { + const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); + // attention layers have a non-zero number of kv heads + int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + if (llama_model_has_encoder(&model)) { + n_attn_layer *= 3; + } + GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + } size_t total_size_org = 0; size_t total_size_new = 0; @@ -16043,8 +16963,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // do not quantize Mamba's small yet 2D weights // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ssm_conv1d.weight") == std::string::npos; - quantize &= name.find("ssm_x.weight") == std::string::npos; - quantize &= name.find("ssm_dt.weight") == std::string::npos; // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; @@ -16618,12 +17536,6 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.attn_soft_cap) { - LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__); - params.flash_attn = false; - } - - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false; @@ -16732,7 +17644,7 @@ struct llama_context * llama_new_context_with_model( ggml_type type_v = params.type_v; // Mamba only needs a constant number of KV cache cells per sequence - if (model->arch == LLM_ARCH_MAMBA) { + if (llama_model_is_recurrent(model)) { // Mamba needs at least as many KV cells as there are sequences kept at any time kv_size = std::max((uint32_t) 1, params.n_seq_max); // it's probably best to keep as much precision as possible for the states @@ -16964,10 +17876,11 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); - int n_past = cparams.n_ctx - n_tokens; + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph if (!ggml_backend_sched_reserve(ctx->sched, gf)) { @@ -17079,6 +17992,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: + case LLM_ARCH_NEMOTRON: + case LLM_ARCH_EXAONE: return LLAMA_ROPE_TYPE_NEOX; // all model arches should be listed explicitly here @@ -17205,6 +18120,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { return model->hparams.dec_start_token_id; } +bool llama_model_is_recurrent(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_MAMBA: return true; + default: return false; + } +} + uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, @@ -17526,7 +18448,9 @@ struct llama_data_write { write_string(rng_str); } - void write_output_ids(const struct llama_context * ctx) { + void write_output_ids(struct llama_context * ctx) { + llama_output_reorder(ctx); + const uint32_t n_outputs = ctx->n_outputs; std::vector output_pos; @@ -17814,8 +18738,11 @@ struct llama_data_read { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - llama_batch batch = llama_batch_init(cell_count, 0, 1); + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; uint32_t n_seq_id; @@ -17829,11 +18756,10 @@ struct llama_data_read { } batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = dest_seq_id; } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; if (!llama_kv_cache_find_slot(kv_self, batch)) { - llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } @@ -17845,9 +18771,6 @@ struct llama_data_read { GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - - // Cleanup - llama_batch_free(batch); } else { // whole KV cache restore @@ -17879,6 +18802,15 @@ struct llama_data_read { } cell.seq_id.insert(seq_id); + + if (kv_self.recurrent) { + int32_t & tail = kv_self.cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } } } @@ -17886,6 +18818,14 @@ struct llama_data_read { kv_self.used = cell_count; } + if (kv_self.recurrent) { + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = kv_self.head + i; + // make sure the recurrent states will keep their restored state + kv_self.cells[cell_id].src = cell_id; + } + } + return true; } @@ -18473,7 +19413,18 @@ struct llama_batch llama_batch_get_one( } struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ 0, + /*all_pos_1 =*/ 0, + /*all_seq_id =*/ 0, + }; if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); @@ -18559,6 +19510,10 @@ void llama_synchronize(struct llama_context * ctx) { float * llama_get_logits(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(ctx); + return ctx->logits; } @@ -18603,6 +19558,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(ctx); + return ctx->embd; } @@ -18704,11 +19663,11 @@ llama_token llama_token_pad(const struct llama_model * model) { return llama_token_pad_impl(model->vocab); } -int32_t llama_add_bos_token(const struct llama_model * model) { +bool llama_add_bos_token(const struct llama_model * model) { return llama_add_bos_token_impl(model->vocab); } -int32_t llama_add_eos_token(const struct llama_model * model) { +bool llama_add_eos_token(const struct llama_model * model) { return llama_add_eos_token_impl(model->vocab); } @@ -19009,6 +19968,22 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "Assistant:"; } + } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } } else { // template not supported return -1; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2f4117a627de0..c832bc9569bbf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -949,6 +949,58 @@ struct test_rms_norm : public test_case { } }; +// GGML_OP_SSM_CONV +struct test_ssm_conv : public test_case { + const ggml_type type; + const std::array ne_a; + const std::array ne_b; + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, ne_b); + } + + test_ssm_conv(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {10, 10, 10, 1}, + std::array ne_b = {3, 3, 1, 1}) + : type(type), ne_a(ne_a), ne_b(ne_b) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_tensor * out = ggml_ssm_conv(ctx, a, b); + return out; + } +}; + +// GGML_OP_SSM_SCAN +struct test_ssm_scan : public test_case { + const ggml_type type; + + const int64_t d_state; + const int64_t d_inner; + const int64_t n_seq_tokens; + const int64_t n_seqs; + + std::string vars() override { + return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + } + + test_ssm_scan(ggml_type type = GGML_TYPE_F32, + int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); + ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); + ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + return out; + } +}; + // GGML_OP_MUL_MAT struct test_mul_mat : public test_case { const ggml_type type_a; @@ -1108,6 +1160,58 @@ struct test_sqrt : public test_case { } }; +// GGML_OP_SIN +struct test_sin : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_sin(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_sin(ctx, a); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -100.0f, 100.0f); + } + } +}; + +// GGML_OP_COS +struct test_cos : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cos(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_cos(ctx, a); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -100.0f, 100.0f); + } + } +}; + // GGML_OP_CLAMP struct test_clamp : public test_case { const ggml_type type; @@ -1652,19 +1756,20 @@ struct test_flash_attn_ext : public test_case { const bool mask; // use mask const float max_bias; // ALiBi + const float logit_softcap; // Gemma 2 const ggml_type type_KV; std::string vars() override { - return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV); + return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16) - : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16) + : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); @@ -1673,7 +1778,28 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr; - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap); + return out; + } +}; + +// GGML_OP_CROSS_ENTROPY_LOSS +struct test_cross_entropy_loss : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels); return out; } }; @@ -2145,6 +2271,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); + // sycl backend will limit task global_range < MAX_INT + // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) + // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) + // these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend) + // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true)); + // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_conv_transpose_1d()); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); @@ -2232,6 +2365,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); + + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { @@ -2287,6 +2426,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1})); + // sycl backend will limit task global_range < MAX_INT + // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion) + // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.) + // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend) + // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1})); + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (int n_mats : {4, 8}) { @@ -2321,6 +2466,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_sqr()); test_cases.emplace_back(new test_sqrt()); + test_cases.emplace_back(new test_sin()); + test_cases.emplace_back(new test_cos()); test_cases.emplace_back(new test_clamp()); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -2424,11 +2571,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (bool mask : { true, false } ) { for (float max_bias : { 0.0f, 8.0f }) { if (!mask && max_bias > 0.0f) continue; - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV)); + for (float logit_softcap : {0.0f, 10.0f}) { + if (hs != 128 && logit_softcap != 0.0f) continue; + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); + } } } } @@ -2437,6 +2587,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_cross_entropy_loss()); + // these tests are disabled to save execution time, but they can be handy for debugging #if 0 test_cases.emplace_back(new test_llama(1)); @@ -2470,7 +2622,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } GGML_ABORT("fatal error"); - return false; } static void usage(char ** argv) { diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index a353276459b2d..1834c11d894b4 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1,10 +1,14 @@ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #include "ggml.h" +#include #include +#include #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -217,7 +221,8 @@ static bool check_gradient( int nargs, float eps, float max_error_abs, - float max_error_rel) { + float max_error_rel, + std::vector expected_vals) { static int n_threads = -1; if (n_threads < 0) { @@ -248,9 +253,10 @@ static bool check_gradient( // ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot"); for (int i = 0; i < nargs; ++i) { + bool all_g0_bad = true; const int nelements = ggml_nelements(x[i]); for (int k = 0; k < nelements; ++k) { - // compute gradient using finite differences + // Calculate gradient numerically: const float x0 = ggml_get_f32_1d(x[i], k); const float xm = x0 - eps; const float xp = x0 + eps; @@ -267,6 +273,28 @@ static bool check_gradient( const double f1 = ggml_get_f32_1d(f, 0); const double g0 = (f0 - f1)/(2.0*(double) eps); + // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU). + // In such cases, provide a vector of expected values and skip the comparison for failed calculations. + if (!expected_vals.empty()) { + bool matches_any = false; + for (const double & ev : expected_vals) { + const double error_abs = std::fabs(g0 - ev); + if (error_abs > max_error_abs) { + continue; + } + const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0; + if (error_rel > max_error_rel) { + continue; + } + matches_any = true; + break; + } + if (!matches_any) { + continue; + } + } + all_g0_bad = false; + ggml_set_f32_1d(x[i], k, x0); // compute gradient using backward graph @@ -278,7 +306,7 @@ static bool check_gradient( const double g1 = ggml_get_f32_1d(x[i]->grad, k); const double error_abs = fabs(g0 - g1); - const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0; + const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0; if (error_abs > max_error_abs || error_rel > max_error_rel) { printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n", @@ -287,6 +315,10 @@ static bool check_gradient( return false; } } + if (all_g0_bad) { + printf("%s: numerical calculation of the gradient failed for all values\n", op_name); + return false; + } } return true; @@ -404,7 +436,7 @@ int main(int argc, const char ** argv) { seed_iter = rand(); unsigned seed = rand(); - printf("test-grad0: iter:%d/%d\n", iter, niter); + printf("test-grad0: iter:%d/%d\n", (iter+1), niter); struct ggml_context * ctx0 = ggml_init(params); get_random_dims(ne, 4); @@ -424,7 +456,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f); + check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {}); } } @@ -441,7 +473,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f); + check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {}); } } @@ -458,7 +490,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1])); - check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -475,7 +507,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1])); - check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -492,7 +524,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1])); - check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f); + check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {}); } } @@ -509,7 +541,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0])); - check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -526,7 +558,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0])); - check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f); + check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {}); } } @@ -543,7 +575,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0])); - check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f); + check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {}); } } @@ -560,7 +592,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, x[0]); - check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -578,7 +610,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0]))); - check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); } } @@ -596,7 +628,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0])); - check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -614,7 +646,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0])); - check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -637,7 +669,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1])))); - check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); } } @@ -660,25 +692,25 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0])))); - check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); } } - // abs (finite differences do not work) - //{ - // const int nargs = 1; + // abs + { + const int nargs = 1; - // for (int ndims = 1; ndims <= 2; ++ndims) { - // for (int i = 0; i < nargs; ++i) { - // x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - // ggml_set_param(ctx0, x[i]); - // } + for (int ndims = 1; ndims <= 4; ++ndims) { + for (int i = 0; i < nargs; ++i) { + x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); + ggml_set_param(ctx0, x[i]); + } - // struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0])); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0])); - // check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f); - // } - //} + check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0}); + } + } // sgn { @@ -693,7 +725,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0])); - check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); } } @@ -710,7 +742,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0])); - check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -727,7 +759,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0])); - check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); } } @@ -745,7 +777,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0])); - check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -776,7 +808,7 @@ int main(int argc, const char ** argv) { GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims); - check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); if (ndims == 2) { // check_mat_mul does not support ndims > 2 check_mat_mul(m, x[1], x[0]); @@ -800,7 +832,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0])); - check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -817,7 +849,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0])); - check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0}); } } @@ -835,7 +867,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0])); - check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -854,9 +886,9 @@ int main(int argc, const char ** argv) { #ifdef GGML_SILU_FP16 // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds. - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY); + check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {}); #else - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); #endif } } @@ -874,7 +906,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f)); - check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY); + check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {}); } } @@ -892,7 +924,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s)); - check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -910,7 +942,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -928,7 +960,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); + check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); } } @@ -952,7 +984,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -976,7 +1008,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1004,7 +1036,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1037,7 +1069,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1072,7 +1104,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1109,7 +1141,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1137,7 +1169,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset)); - check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1170,7 +1202,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset)); - check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1194,7 +1226,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset)); - check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1225,7 +1257,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset)); - check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1257,7 +1289,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset)); - check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1291,7 +1323,7 @@ int main(int argc, const char ** argv) { // sum requires contiguous tensor rows struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3))); - check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1319,7 +1351,7 @@ int main(int argc, const char ** argv) { // sum requires contiguous tensor rows struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0]))); - check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1337,7 +1369,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1])); - check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } // diag_mask_inf @@ -1353,7 +1385,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past)); - check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } // diag_mask_zero @@ -1369,7 +1401,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past)); - check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } // softmax @@ -1395,7 +1427,7 @@ int main(int argc, const char ** argv) { 1.0f - eps), ggml_new_f32(ctx0, eps)))); - check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY); + check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {}); // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf. // this may result in different gradients too finite differences. // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause. @@ -1412,7 +1444,7 @@ int main(int argc, const char ** argv) { get_random_dims(ne2, 4); for (int ndims = 1; ndims <= 4; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f); + x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f); // the second argument to cross_entropy_loss must sum up to 1 for each row int nr = ggml_nrows(x[1]); @@ -1430,7 +1462,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]); - check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY); + check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1468,7 +1500,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); + check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); } } } @@ -1508,12 +1540,93 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); + check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); } } } } + // im2col f32 + { + srand(seed); + const int nargs = 1; + const int ndims = 4; + + for (const bool is_2D : {false, true}) { + int64_t ne0[ndims]; + int64_t ne1[ndims]; + get_random_dims(ne0, ndims); + get_random_dims(ne1, ndims); + + // // Ensure that the output is not zero-sized: + ne1[0] += 8; + ne1[1] += 8; + + if (is_2D) { + ne1[2] = ne0[2]; + } else { + ne1[1] = ne0[1]; + ne0[3] = 1; + ne1[3] = 1; + } + + // The order of arguments is swapped because the first tensor is only used for its shape. + x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f); + x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f); + + ggml_set_param(ctx0, x[0]); + + const int s0 = 1 + irand(2); + const int s1 = is_2D ? 1 + irand(2) : 0; + const int p0 = 0 + irand(2); + const int p1 = is_2D ? 0 + irand(2) : 0; + const int d0 = 1 + irand(2); + const int d1 = is_2D ? 1 + irand(2) : 0; + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32)); + + GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1); + check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); + } + } + + // pool_2d f32 + { + srand(seed); + const int nargs = 1; + const int ndims = 4; + + for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { + int64_t ne0[ndims]; + get_random_dims(ne0, ndims); + + ne0[0] += 8; + ne0[1] += 8; + + x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f); + + ggml_set_param(ctx0, x[0]); + + const int k0 = 2 + irand(2); + const int k1 = 2 + irand(2); + const int s0 = 2 + irand(2); + const int s1 = 2 + irand(2); + const int p0 = 0 + irand(2); + const int p1 = 0 + irand(2); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1)); + + GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n", + op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1); + std::vector expected_vals; + if (op == GGML_OP_POOL_MAX) { + expected_vals.push_back(0.0); + expected_vals.push_back(1.0); + } + check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals); + } + } + // flash_attn f32 // TODO: adapt to ggml_flash_attn_ext() changes //{ @@ -1553,7 +1666,7 @@ int main(int argc, const char ** argv) { // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); + // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {}); // } // } // } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 68f971bfe2cf3..9c4e7d18e37b2 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -503,7 +503,7 @@ static void test_special_chars() { "aaaaabcccc", "aaaabccc", "aaaabccccc", - "🔵🟠✅❌abc❌✅🟠🔵" + "🔵🟠✅❌abc❌✅🟠🔵", "🔵🟠abc🟠🔵" } ); diff --git a/tests/test-lora-conversion-inference.sh b/tests/test-lora-conversion-inference.sh new file mode 100755 index 0000000000000..fe90ce0d1b801 --- /dev/null +++ b/tests/test-lora-conversion-inference.sh @@ -0,0 +1,139 @@ +#!/bin/bash +set -e + +# Array of models to iterate over +declare -a params=( + "Gemma2ForCausalLM 64" + "LlamaForCausalLM 64" + "Phi3ForCausalLM 64" +) + +MODELS_REPO=lora-tests +MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO + +# Clone the Hugging Face repository if the directory does not exist +if [ ! -d "$MODELS_REPO" ]; then + echo "Cloning the Hugging Face repository..." + git clone $MODELS_REPO_URL --depth 1 +else + echo "Repository already exists. Skipping clone." +fi + +# Array to store results to print +results=() + +trim_leading_whitespace() { + local input_string="$1" + echo "${input_string#"${input_string%%[![:space:]]*}"}" +} + +extract_starting_substring() { + local reference_string="$1" + local target_string="$2" + + local target_length=${#target_string} + echo "${reference_string:0:$target_length}" +} + +get_first_word() { + local input_string="$1" + read -r first_word _ <<< "$input_string" + echo "$first_word" +} + +# Load the expected strings +EXPECTED_BASE_FULL=$(cat $MODELS_REPO/data/pale_blue_dot.txt) +EXPECTED_LORA_FULL=$(cat $MODELS_REPO/data/bohemian_rhapsody.txt) +EXPECTED_BASE_FIRST_WORD=$(get_first_word "$EXPECTED_BASE_FULL") +EXPECTED_LORA_FIRST_WORD=$(get_first_word "$EXPECTED_LORA_FULL") + +run_conversion_and_inference_lora() { + local model_name=$1 + local hidden_size=$2 + + echo -e "\n\n-------- RUNNING TEST FOR MODEL $model_name --------\n\n" + + # Convert safetensors to gguf + echo "Running convert_hf_to_gguf.py for $model_name with hidden_size $hidden_size..." + python convert_hf_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \ + --outfile $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + --outtype f32 + + echo -e "\n\n---------------------------\n\n" + echo "Running convert_lora_to_gguf.py for $model_name with hidden_size $hidden_size..." + python3 convert_lora_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora \ + --base $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \ + --outtype f32 + + echo -e "\n\n---------------------------\n\n" + echo "Running llama-export-lora with lora for $model_name with hidden_size $hidden_size..." + ./llama-export-lora \ + -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + -o $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \ + --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf + + # Run inference + echo -e "\n\n---------------------------\n\n" + echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..." + OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + -p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0) + + echo -e "\n\n---------------------------\n\n" + echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..." + OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \ + -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0) + + echo -e "\n\n---------------------------\n\n" + echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..." + OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \ + -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0) + + # Remove any initial white space + OUTPUT_BASE=$(trim_leading_whitespace "$OUTPUT_BASE") + OUTPUT_LORA_HOT=$(trim_leading_whitespace "$OUTPUT_LORA_HOT") + OUTPUT_LORA_MERGED=$(trim_leading_whitespace "$OUTPUT_LORA_MERGED") + # Extract the corresponding substring from full string + EXPECTED_BASE=$(extract_starting_substring "$EXPECTED_BASE_FULL" "$OUTPUT_BASE") + EXPECTED_LORA=$(extract_starting_substring "$EXPECTED_LORA_FULL" "$OUTPUT_LORA_HOT") + + # Assert output equals the expected output + if [[ "$OUTPUT_BASE" != "$EXPECTED_BASE" ]]; then + echo "Error: $model_name OUTPUT_BASE does not start with the expected string." + echo -e "Out=$OUTPUT_BASE\n\nExp=$EXPECTED_BASE" + exit 1 + fi + if [[ "$OUTPUT_LORA_HOT" != "$EXPECTED_LORA" ]]; then + echo "Error: $model_name OUTPUT_LORA_HOT does not start with the expected string." + echo -e "Out=$OUTPUT_LORA_HOT\n\nExp=$EXPECTED_LORA" + exit 1 + fi + if [[ "$OUTPUT_LORA_MERGED" != "$EXPECTED_LORA" ]]; then + echo "Error: $model_name OUTPUT_LORA_MERGED does not start with the expected string." + echo -e "Out=$OUTPUT_LORA_MERGED\n\nExp=$EXPECTED_LORA" + exit 1 + fi + + # Store the results + results+=(" + \n\033[1mResults for $model_name with hidden_size $hidden_size:\033[0m + \n\033[32m • Base:\n$OUTPUT_BASE + \n\033[34m • Lora hot:\n$OUTPUT_LORA_HOT + \n\033[36m • Lora merged:\n$OUTPUT_LORA_MERGED + \n \033[0m + ") + + echo "All tests passed for $model_name with hidden_size $hidden_size!" +} + +# Run test for each model +for param in "${params[@]}"; do + run_conversion_and_inference_lora $param +done + +# Print results +echo -e "\n\n---------------------------\n\n" +echo -e "\n\033[1mSummary of All Results:\033[0m" +for result in "${results[@]}"; do + echo -e "$result" +done diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index de858bd3b87e3..6c2a5db9accf2 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -166,12 +166,12 @@ static void test_sampler_queue( for (auto s : samplers_sequence) { switch (s){ case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; - case 'f': GGML_ABORT("tail_free test not implemented"); break; - case 'y': GGML_ABORT("typical test not implemented"); break; + case 'f': GGML_ABORT("tail_free test not implemented"); + case 'y': GGML_ABORT("typical test not implemented"); case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; - case 't': GGML_ABORT("temperature test not implemented"); break; - default : GGML_ABORT("Unknown sampler"); break; + case 't': GGML_ABORT("temperature test not implemented"); + default : GGML_ABORT("Unknown sampler"); } llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests