Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change to_array routines to fix pointer issues #175

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/6_Autograd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ if(CMAKE_BUILD_TESTS)
COMMAND autograd
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
set_tests_properties(fautograd PROPERTIES PASS_REGULAR_EXPRESSION
"2.00000000 3.00000000")
"test completed successfully")
endif()
43 changes: 34 additions & 9 deletions examples/6_Autograd/autograd.f90
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,48 @@ program example
integer, parameter :: wp = sp

! Set up Fortran data structures
real(wp), dimension(2), target :: in_data
real(wp), dimension(:), pointer :: out_data
integer :: tensor_layout(1) = [1]
integer, parameter :: n=2, m=5
real(wp), dimension(n,m), target :: in_data
real(wp), dimension(:,:), pointer :: out_data
integer :: tensor_layout(2) = [1, 2]
integer :: i, j

! Set up Torch data structures
type(torch_tensor) :: a
type(torch_tensor) :: tensor

! initialize in_data with some fake data
do j = 1, m
do i = 1, n
in_data(i,j) = ((i-1)*m + j) * 1.0_wp
end do
end do

! Construct a Torch Tensor from a Fortran array
in_data(:) = [2.0, 3.0]
call torch_tensor_from_array(a, in_data, tensor_layout, torch_kCPU)
call torch_tensor_from_array(tensor, in_data, tensor_layout, torch_kCPU)

! check tensor rank and shape match those of in_data
if (tensor%get_rank() /= 2) then
print *, "Error :: rank should be 2"
stop 1
end if
if (any(tensor%get_shape() /= [2, 5])) then
print *, "Error :: shape should be (2, 5)"
stop 1
end if

! Extract a Fortran array from a Torch tensor
call torch_tensor_to_array(a, out_data, shape(in_data))
write (*,*) "a = ", out_data(:)
call torch_tensor_to_array(tensor, out_data, shape(in_data))

! check that the data match
if (any(in_data /= out_data)) then
print *, "Error :: in_data does not match out_data"
stop 1
end if
jwallwork23 marked this conversation as resolved.
Show resolved Hide resolved

! Cleanup
nullify(out_data)
call torch_tensor_delete(a)
call torch_tensor_delete(tensor)

write (*,*) "test completed successfully"

end program example
12 changes: 12 additions & 0 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ int torch_tensor_get_device_index(const torch_tensor_t tensor)
return t->device().index();
}

int torch_tensor_get_rank(const torch_tensor_t tensor)
{
auto t = reinterpret_cast<torch::Tensor*>(tensor);
return t->sizes().size();
}

const long int* torch_tensor_get_sizes(const torch_tensor_t tensor)
{
auto t = reinterpret_cast<torch::Tensor*>(tensor);
return t->sizes().data();
}

void torch_tensor_delete(torch_tensor_t tensor)
{
auto t = reinterpret_cast<torch::Tensor*>(tensor);
Expand Down
14 changes: 14 additions & 0 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ EXPORT_C void torch_tensor_print(const torch_tensor_t tensor);
*/
EXPORT_C int torch_tensor_get_device_index(const torch_tensor_t tensor);

/**
* Function to determine the rank of a Torch Tensor
* @param Torch Tensor to determine the rank of
* @return rank of the Torch Tensor
*/
EXPORT_C int torch_tensor_get_rank(const torch_tensor_t tensor);

/**
* Function to determine the sizes (shape) of a Torch Tensor
* @param Torch Tensor to determine the rank of
* @return pointer to the sizes array of the Torch Tensor
*/
EXPORT_C const long int* torch_tensor_get_sizes(const torch_tensor_t tensor);

/**
* Function to delete a Torch Tensor to clean up
* @param Torch Tensor to delete
Expand Down
Loading
Loading