Skip to content

Commit

Permalink
Clean up TOD estimator.
Browse files Browse the repository at this point in the history
  • Loading branch information
astamm committed Nov 15, 2023
1 parent f7f8a70 commit d1a361a
Show file tree
Hide file tree
Showing 4 changed files with 547 additions and 859 deletions.
45 changes: 18 additions & 27 deletions Anima/diffusion/odf/tod_estimator/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,39 +1,30 @@
if(BUILD_TOOLS AND USE_NLOPT)

project(animaTODEstimator)
project(animaTODEstimator)

## #############################################################################
## List Sources
## #############################################################################
# ############################################################################
# List Sources
# ############################################################################

list_source_files(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}
)
list_source_files(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR})

# ############################################################################
# add executable
# ############################################################################

## #############################################################################
## add executable
## #############################################################################
add_executable(${PROJECT_NAME} ${${PROJECT_NAME}_CFILES})

add_executable(${PROJECT_NAME}
${${PROJECT_NAME}_CFILES}
)
# ############################################################################
# Link
# ############################################################################

## #############################################################################
## Link
## #############################################################################
target_link_libraries(${PROJECT_NAME} ${ITKIO_LIBRARIES} AnimaDataIO
AnimaSHTools)

target_link_libraries(${PROJECT_NAME}
${VTK_LIBRARIES}
${ITKIO_LIBRARIES}
AnimaSHTools
AnimaDataIO
)
# ############################################################################
# install
# ############################################################################

## #############################################################################
## install
## #############################################################################

set_exe_install_rules(${PROJECT_NAME})
set_exe_install_rules(${PROJECT_NAME})

endif()
71 changes: 38 additions & 33 deletions Anima/diffusion/odf/tod_estimator/animaTODEstimator.cxx
Original file line number Diff line number Diff line change
@@ -1,54 +1,61 @@
#include <animaReadWriteFunctions.h>
#include <animaTODEstimatorImageFilter.h>

#include <itkTimeProbe.h>
#include <tclap/CmdLine.h>

#include <animaReadWriteFunctions.h>
#include <tclap/CmdLine.h>

void eventCallback (itk::Object* caller, const itk::EventObject& event, void* clientData)
void eventCallback(itk::Object *caller, const itk::EventObject &event, void *clientData)
{
itk::ProcessObject * processObject = (itk::ProcessObject*) caller;
std::cout<<"\033[K\rProgression: "<<(int)(processObject->GetProgress() * 100)<<"%"<<std::flush;
itk::ProcessObject *processObject = (itk::ProcessObject *)caller;
std::cout << "\033[K\rProgression: " << (int)(processObject->GetProgress() * 100) << "%" << std::flush;
}

int main(int argc, char **argv)
{
TCLAP::CmdLine cmd("INRIA / IRISA - VisAGeS/Empenn Team", ' ',ANIMA_VERSION);

TCLAP::ValueArg<std::string> inArg("i","input","Input tractography image (.vtk or .vtp)",true,"","input tractography image",cmd);
TCLAP::ValueArg<std::string> resArg("o","outputfile","Result TOD image",true,"","result TOD image",cmd);
TCLAP::ValueArg<std::string> refArg("g","geometry","Output image geometry",true,"","output geometry",cmd);

TCLAP::SwitchArg normArg("N", "Normalize", "Normalize TOD", cmd, false);

TCLAP::ValueArg<unsigned int> orderArg("k","order","Order of spherical harmonics basis (default 4)",false,4,"Order of SH basis",cmd);

TCLAP::ValueArg<unsigned int> nbpArg("p","numberofthreads","Number of threads to run on (default: all cores)",false,itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads(),"number of threads",cmd);
TCLAP::CmdLine cmd("INRIA / IRISA - VisAGeS/Empenn Team", ' ', ANIMA_VERSION);

TCLAP::ValueArg<std::string> inArg(
"i", "input-file",
"A string specifying the name of a file storing the input tractography image. Supported formats are `.vtk`, `.vtp` or `.fds`.",
true, "", "input tractography image", cmd);
TCLAP::ValueArg<std::string> outArg(
"o", "output-file",
"A string specifying the name of a file storing the output TOD image.",
true, "", "output TOD image", cmd);
TCLAP::ValueArg<std::string> refArg(
"g", "geometry-file",
"A string specifying the name of a file storing the reference geometry image.",
true, "", "reference geometry image", cmd);

TCLAP::SwitchArg normArg(
"N", "normalize-tod",
"A switch to turn on TOD normalization.",
cmd, false);

TCLAP::ValueArg<unsigned int> nbpArg(
"T", "nb-threads",
"An integer value specifying the number of threads to run on (default: all cores).",
false, itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads(), "number of threads", cmd);

try
{
cmd.parse(argc,argv);
cmd.parse(argc, argv);
}
catch (TCLAP::ArgException& e)
catch (TCLAP::ArgException &e)
{
std::cerr << "Error: " << e.error() << "for argument " << e.argId() << std::endl;
return EXIT_FAILURE;
}

typedef anima::TODEstimatorImageFilter FilterType;
using FilterType = anima::TODEstimatorImageFilter<float>;
using InputImageType = FilterType::InputImageType;

FilterType::Pointer mainFilter = FilterType::New();

// if (orderArg.getValue() % 2 == 0)
// mainFilter->SetLOrder(orderArg.getValue());
// else
// mainFilter->SetLOrder(orderArg.getValue() - 1);
typedef FilterType::InputImageType InputImageType;
mainFilter->SetInput(anima::readImage<InputImageType>(refArg.getValue()));

mainFilter->SetLOrder(orderArg.getValue());
mainFilter->SetInputFileName(inArg.getValue());
mainFilter->SetRefFileName(refArg.getValue());
mainFilter->SetNormalize(normArg.getValue());
mainFilter->SetReferenceFileName(refArg.getValue());
mainFilter->SetUseNormalization(normArg.getValue());
mainFilter->SetNumberOfWorkUnits(nbpArg.getValue());

itk::CStyleCommand::Pointer callback = itk::CStyleCommand::New();
Expand All @@ -57,14 +64,12 @@ int main(int argc, char **argv)

itk::TimeProbe tmpTime;
tmpTime.Start();

mainFilter->Update();

tmpTime.Stop();

std::cout << std::endl << "Execution Time: " << tmpTime.GetTotal() << std::endl;
std::cout << "\nExecution Time: " << tmpTime.GetTotal() << "s" << std::endl;

anima::writeImage <FilterType::TOutputImage> (resArg.getValue(),mainFilter->GetOutput());
anima::writeImage<FilterType::OutputImageType>(outArg.getValue(), mainFilter->GetOutput());

return EXIT_SUCCESS;
}
176 changes: 75 additions & 101 deletions Anima/diffusion/odf/tod_estimator/animaTODEstimatorImageFilter.h
Original file line number Diff line number Diff line change
@@ -1,138 +1,112 @@
#pragma once

#include <iostream>
#include <itkImageSource.h>
#include <itkVectorImage.h>
#include <itkImage.h>
#include <vector>

#include <animaNumberedThreadImageToImageFilter.h>
#include <animaODFSphericalHarmonicBasis.h>
#include <animaVectorOperations.h>
#include <animaReadWriteFunctions.h>

#include <vtkPoints.h>
#include <vtkVector.h>

namespace anima
{

//template <typename TOutputPixelType>
class TODEstimatorImageFilter :
public anima::NumberedThreadImageToImageFilter < itk::Image<double,3>, itk::VectorImage<double,3> >
{
public:

typedef TODEstimatorImageFilter Self;
// typedef itk::Vector<float, 3> pointType;
typedef itk::Point<float, 3> PointType;
// typedef PointType DirType;
typedef itk::Vector<double, 3> DirType;
typedef std::vector<DirType> DirVectorType;
typedef std::vector<PointType> FiberType;

typedef double MathScalarType;

typedef itk::Matrix <MathScalarType,3,3> Matrix3DType;
typedef itk::Vector <MathScalarType,3> Vector3DType;
typedef itk::VariableLengthVector <MathScalarType> VectorType;

typedef anima::ODFSphericalHarmonicBasis baseSH;
typedef std::complex <double> complexType;

typedef itk::VectorImage<double, 3> TOutputImage;
typedef itk::Image<int, 3> TRefImage;

typedef anima::NumberedThreadImageToImageFilter <InputImageType, OutputImageType> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;


itkNewMacro(Self)

itkTypeMacro(TODEstimatorImageFilter, anima::NumberedThreadImageToImageFilter);

typedef typename TOutputImage::Pointer OutputImagePointer;

typedef typename Superclass::InputImageRegionType InputImageRegionType;
typedef typename Superclass::OutputImageRegionType OutputImageRegionType;

itkSetMacro(InputFileName,std::string);
itkSetMacro(RefFileName,std::string);
itkSetMacro(LOrder,unsigned int);
itkSetMacro(Normalize, bool);


protected:
TODEstimatorImageFilter()
: Superclass()
template <typename ScalarType>
class TODEstimatorImageFilter : public anima::NumberedThreadImageToImageFilter<itk::Image<ScalarType, 3>, itk::VectorImage<ScalarType, 3>>
{
}
public:
/** Standard class typedefs. */
using Self = TODEstimatorImageFilter;
using InputImageType = itk::Image<ScalarType, 3>;
using OutputImageType = itk::VectorImage<ScalarType, 3>;
using ReferenceImageType = itk::Image<unsigned int, 3>;
using Superclass = anima::NumberedThreadImageToImageFilter<InputImageType, OutputImageType>;
using Pointer = itk::SmartPointer<Self>;
using ConstPointer = itk::SmartPointer<const Self>;

virtual ~TODEstimatorImageFilter()
{
/** Method for creation through the object factory. */
itkNewMacro(Self);

/** Run-time type information (and related methods) */
itkTypeMacro(TODEstimatorImageFilter, anima::NumberedThreadImageToImageFilter);

}
/** Superclass typedefs. */
using InputImageRegionType = typename Superclass::InputImageRegionType;
using OutputImageRegionType = typename Superclass::OutputImageRegionType;

// void GenerateData() ITK_OVERRIDE;
using InputImagePointerType = typename InputImageType::Pointer;
using OutputImagePointerType = typename OutputImageType::Pointer;
using ReferenceImagePointerType = typename ReferenceImageType::Pointer;
using InputImagePixelType = typename InputImageType::PixelType;
using OutputImagePixelType = typename OutputImageType::PixelType;
using ReferenceImagePixelType = ReferenceImageType::PixelType;

void BeforeThreadedGenerateData() ITK_OVERRIDE;
void DynamicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread) ITK_OVERRIDE;
void AfterThreadedGenerateData() ITK_OVERRIDE;
using PointType = itk::Point<ScalarType, 3>;
using DirType = itk::Vector<double, 3>;
using DirVectorType = std::vector<DirType>;
using FiberType = std::vector<PointType>;
using Matrix3DType = itk::Matrix<double, 3, 3>;
using Vector3DType = itk::Vector<double, 3>;
using MatrixType = vnl_matrix<double>;
using BasisType = anima::ODFSphericalHarmonicBasis;
using ComplexType = std::complex<double>;

itkSetMacro(InputFileName, std::string);
itkSetMacro(ReferenceFileName, std::string);
itkSetMacro(LOrder, unsigned int);
itkSetMacro(UseNormalization, bool);

FiberType readFiber(vtkIdType numberOfPoints, const vtkIdType *indices, vtkPoints *points);
PointType getCenterVoxel(int index, FiberType &fiber);
DirType getFiberDirection(int index, FiberType &fiber);
void getSHCoefs(DirType dir, VectorType &resSH, baseSH &basis);
void ComputeCoefs();
void processFiber(FiberType &fiber, baseSH &basis);
protected:
TODEstimatorImageFilter() {}

void getMainDirections(DirVectorType inDirs, DirVectorType &mainDirs);
double getEuclideanDistance(DirType dir1, DirType dir2);
DirType getNewClusterAverage(int numCluster, DirVectorType &dirs, std::vector<int> &cluster);
virtual ~TODEstimatorImageFilter() {}

void getSHCoef(DirType dir, VectorType &coefs);
void BeforeThreadedGenerateData() ITK_OVERRIDE;
void DynamicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread) ITK_OVERRIDE;

void precomputeSH();
void discretizeODF(VectorType ODFCoefs, std::vector<double> &ODFDiscret);
VectorType getSquareRootODF(std::vector<double> ODFDiscret);
VectorType getSquareODF(std::vector<double> ODFDiscret);
void getAverageCoefs(std::vector<VectorType> &vecCoefs, VectorType &avgCoef);

void averageODFs(std::vector<VectorType> &vecCoefs, VectorType &resOdf);
FiberType ReadFiber(vtkIdType numberOfPoints, const vtkIdType *indices, vtkPoints *points);
PointType GetCenterVoxel(int index, FiberType &fiber);
DirType GetFiberDirection(int index, FiberType &fiber);
void GetSHCoefs(DirType dir, OutputImagePixelType &resSH, BasisType &basis);
void ComputeCoefs();
void ProcessFiber(FiberType &fiber, BasisType &basis);

vnl_matrix <double> GetRotationMatrix(DirType dir1, DirType dir2);
void GetMainDirections(DirVectorType inDirs, DirVectorType &mainDirs);
double GetEuclideanDistance(DirType dir1, DirType dir2);
DirType GetNewClusterAverage(int numCluster, DirVectorType &dirs, std::vector<int> &cluster);

// void GenerateOutputInformation() ITK_OVERRIDE;
void GetSHCoef(DirType dir, OutputImagePixelType &coefs);

private:
ITK_DISALLOW_COPY_AND_ASSIGN(TODEstimatorImageFilter);
void PrecomputeSH();
void DiscretizeODF(OutputImagePixelType ODFCoefs, std::vector<double> &ODFDiscret);
OutputImagePixelType GetSquareRootODF(std::vector<double> ODFDiscret);
OutputImagePixelType GetSquareODF(std::vector<double> ODFDiscret);
void GetAverageCoefs(std::vector<OutputImagePixelType> &vecCoefs, OutputImagePixelType &avgCoef);

std::string m_InputFileName;
std::string m_RefFileName;
void AverageODFs(std::vector<OutputImagePixelType> &vecCoefs, OutputImagePixelType &resOdf);

DirType m_CstDir;
MatrixType GetRotationMatrix(DirType dir1, DirType dir2);

unsigned int m_LOrder;
int m_VectorLength;
private:
ITK_DISALLOW_COPY_AND_ASSIGN(TODEstimatorImageFilter);

double m_NbSample;
std::string m_InputFileName;
std::string m_ReferenceFileName;

bool m_Normalize;
DirType m_CstDir;

unsigned int m_LOrder;
int m_VectorLength;

VectorType m_GaussCoefs;
double m_NbSample;

std::vector<std::vector<double>> m_SphereSampl;
vnl_matrix<double> m_SpherHarm;
bool m_UseNormalization;

anima::ODFSphericalHarmonicBasis *m_ODFSHBasis;
std::vector<std::vector<DirType>> m_ImgDir;
OutputImagePixelType m_GaussCoefs;

itk::Image<int, 3>::Pointer test;
std::vector<std::vector<double>> m_SphereSampl;
MatrixType m_SpherHarm;

};
anima::ODFSphericalHarmonicBasis *m_ODFSHBasis;
std::vector<DirVectorType> m_ImgDir;
};
} // end namespace anima

#include "animaTODEstimatorImageFilter.hxx"
Loading

0 comments on commit d1a361a

Please sign in to comment.