Skip to content

Commit

Permalink
bugfix: fix for windows and arm long long int
Browse files Browse the repository at this point in the history
fixes [#183](#183)

There is an issue when building on mac (arm_64) or windows. The version
of `libtorch` exposes a torch tensors shape (`t->sizes().data()`) as a
`const long long int*` instead of just a `const long int*` like on linux
and mac (x86).

This commit adds preprocessor macro to switch between implementations
automatically detecting the correct version at CMake build stage.
  • Loading branch information
TomMelt committed Nov 19, 2024
1 parent d64358c commit 23c8c34
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/fypp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ jobs:

- name: Check ftorch.fypp matches ftorch.f90
run: |
fypp src/ftorch.fypp src/temp.f90_temp
if ! diff -q src/ftorch.f90 src/temp.f90_temp; then
echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp."
fypp src/ftorch.fypp src/temp.F90_temp
if ! diff -q src/ftorch.F90 src/temp.F90_temp; then
echo "Error: The code in ftorch.F90 does not match that expected from ftorch.fypp."
echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit."
exit 1
else
Expand Down
2 changes: 1 addition & 1 deletion examples/n_c_and_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir})
find_package(Torch REQUIRED)

# Library with C and Fortran bindings
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90)
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90)
add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME})
set_target_properties(${LIB_NAME} PROPERTIES
PUBLIC_HEADER "ctorch.h"
Expand Down
10 changes: 9 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir})
find_package(Torch REQUIRED)

# Library with C and Fortran bindings
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90 ftorch_test_utils.f90)
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90 ftorch_test_utils.f90)

if(UNIX)
message(STATUS "CMAKE_SYSTEM_PROCESSOR = ${CMAKE_SYSTEM_PROCESSOR}")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
target_compile_definitions(${LIB_NAME} PRIVATE UNIX)
endif()
endif()

# Add an alias FTorch::ftorch for the library
add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME})
set_target_properties(${LIB_NAME} PROPERTIES
Expand Down
8 changes: 8 additions & 0 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,19 @@ int torch_tensor_get_rank(const torch_tensor_t tensor)
return t->sizes().size();
}

#ifdef UNIX
const long int* torch_tensor_get_sizes(const torch_tensor_t tensor)
{
auto t = reinterpret_cast<torch::Tensor*>(tensor);
return t->sizes().data();
}
#else
const long long int* torch_tensor_get_sizes(const torch_tensor_t tensor)
{
auto t = reinterpret_cast<torch::Tensor*>(tensor);
return t->sizes().data();
}
#endif

void torch_tensor_delete(torch_tensor_t tensor)
{
Expand Down
4 changes: 4 additions & 0 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ EXPORT_C int torch_tensor_get_rank(const torch_tensor_t tensor);
* @param Torch Tensor to determine the rank of
* @return pointer to the sizes array of the Torch Tensor
*/
#ifdef UNIX
EXPORT_C const long int* torch_tensor_get_sizes(const torch_tensor_t tensor);
#else
EXPORT_C const long long int* torch_tensor_get_sizes(const torch_tensor_t tensor);
#endif

/**
* Function to delete a Torch Tensor to clean up
Expand Down
6 changes: 5 additions & 1 deletion src/ftorch.f90 → src/ftorch.F90
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,13 @@ end function get_rank

!> Determines the shape of a tensor.
function get_shape(self) result(sizes)
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr
class(torch_tensor), intent(in) :: self
#ifdef UNIX
integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data
#else
integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data
#endif
integer(kind=int32) :: ndims(1)
type(c_ptr) :: cptr

Expand Down
6 changes: 5 additions & 1 deletion src/ftorch.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,13 @@ contains

!> Determines the shape of a tensor.
function get_shape(self) result(sizes)
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr
class(torch_tensor), intent(in) :: self
#ifdef UNIX
integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data
#else
integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data
#endif
integer(kind=int32) :: ndims(1)
type(c_ptr) :: cptr

Expand Down

0 comments on commit 23c8c34

Please sign in to comment.