-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from nicolas-chaulet/cpuversion
Cpuversion
- Loading branch information
Showing
34 changed files
with
498 additions
and
470 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#------------------------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See https://go.microsoft.com/fwlink/?linkid=2090316 for license information. | ||
#------------------------------------------------------------------------------------------------------------- | ||
|
||
FROM ubuntu:bionic | ||
|
||
# Avoid warnings by switching to noninteractive | ||
ENV DEBIAN_FRONTEND=noninteractive | ||
|
||
# This Dockerfile adds a non-root user with sudo access. Use the "remoteUser" | ||
# property in devcontainer.json to use it. On Linux, the container user's GID/UIDs | ||
# will be updated to match your local UID/GID (when using the dockerFile property). | ||
# See https://aka.ms/vscode-remote/containers/non-root-user for details. | ||
ARG USERNAME=vscode | ||
ARG USER_UID=1000 | ||
ARG USER_GID=$USER_UID | ||
|
||
# Uncomment the following COPY line and the corresponding lines in the `RUN` command if you wish to | ||
# include your requirements in the image itself. It is suggested that you only do this if your | ||
# requirements rarely (if ever) change. | ||
|
||
RUN apt-get update \ | ||
&& apt-get install -y --fix-missing --no-install-recommends\ | ||
libffi-dev libssl-dev build-essential \ | ||
python3-pip python3-dev python3-venv python3-setuptools\ | ||
git iproute2 procps lsb-release clang-format \ | ||
&& apt-get clean \ | ||
&& rm -rf /var/lib/apt/lists/* | ||
|
||
RUN pip3 install -U pip | ||
RUN pip3 install torch numpy scikit-learn flake8 setuptools | ||
RUN pip3 install torch_cluster torch_sparse torch_scatter torch_geometric |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// For format details, see https://aka.ms/vscode-remote/devcontainer.json or this file's README at: | ||
// https://github.com/microsoft/vscode-dev-containers/tree/v0.101.1/containers/python-3 | ||
{ | ||
"name": "Python 3", | ||
"context": "..", | ||
"dockerFile": "Dockerfile", | ||
// Set *default* container specific settings.json values on container create. | ||
"settings": { | ||
"terminal.integrated.shell.linux": "/bin/bash", | ||
"python.pythonPath": "/usr/local/bin/python", | ||
"python.linting.enabled": true, | ||
"python.linting.pylintEnabled": true, | ||
"python.linting.pylintPath": "/usr/local/bin/pylint" | ||
}, | ||
// Add the IDs of extensions you want installed when the container is created. | ||
"extensions": [ | ||
"ms-python.python", | ||
"ms-vscode.cpptools" | ||
] | ||
// Use 'forwardPorts' to make a list of ports inside the container available locally. | ||
// "forwardPorts": [], | ||
// Use 'postCreateCommand' to run commands after the container is created. | ||
// "postCreateCommand": "pip install -r requirements.txt", | ||
// Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/containers/non-root. | ||
// "remoteUser": "vscode" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#pragma once | ||
#include <torch/extension.h> | ||
at::Tensor fps(at::Tensor points, const int nsamples, bool random = true); |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#pragma once | ||
#include <torch/extension.h> | ||
|
||
at::Tensor knn_interpolate(at::Tensor features, at::Tensor idx, at::Tensor weight); | ||
|
||
at::Tensor knn_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor weight, | ||
const int m); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#pragma once | ||
#include <torch/extension.h> | ||
std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor support, at::Tensor query, int k); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include <torch/extension.h> | ||
|
||
#include "compat.h" | ||
#include "utils.h" | ||
|
||
at::Tensor get_dist(at::Tensor x, ptrdiff_t index) | ||
{ | ||
return (x - x[index]).norm(2, 1); | ||
} | ||
|
||
at::Tensor fps(at::Tensor points, const int nsamples, bool random) | ||
{ | ||
CHECK_CONTIGUOUS(points); | ||
|
||
auto out_options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU); | ||
auto batch_size = points.size(0); | ||
auto out = torch::empty({batch_size, nsamples}, out_options); | ||
auto out_a = out.accessor<long, 2>(); | ||
|
||
for (ptrdiff_t b = 0; b < batch_size; b++) | ||
{ | ||
auto y = points[b]; | ||
ptrdiff_t start = 0; | ||
if (random) | ||
start = at::randperm(y.size(0), out_options).DATA_PTR<int64_t>()[0]; | ||
|
||
out_a[b][0] = start; | ||
auto dist = get_dist(y, start); | ||
for (ptrdiff_t i = 1; i < nsamples; i++) | ||
{ | ||
ptrdiff_t argmax = dist.argmax().DATA_PTR<int64_t>()[0]; | ||
out_a[b][i] = argmax; | ||
dist = at::min(dist, get_dist(y, argmax)); | ||
} | ||
} | ||
return out; | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#include "compat.h" | ||
#include "utils.h" | ||
#include <iostream> | ||
#include <torch/extension.h> | ||
|
||
at::Tensor knn_interpolate(at::Tensor features, at::Tensor idx, at::Tensor weight) | ||
{ | ||
CHECK_CONTIGUOUS(features); | ||
CHECK_CONTIGUOUS(idx); | ||
CHECK_CONTIGUOUS(weight); | ||
CHECK_CPU(idx); | ||
CHECK_CPU(features); | ||
CHECK_CPU(weight); | ||
|
||
at::Tensor output = torch::zeros({features.size(0), features.size(1), idx.size(1)}, | ||
at::device(features.device()).dtype(features.scalar_type())); | ||
|
||
AT_DISPATCH_ALL_TYPES(features.scalar_type(), "knn_interpolate", [&] { | ||
auto output_a = output.accessor<scalar_t, 3>(); | ||
auto features_a = features.accessor<scalar_t, 3>(); | ||
auto weight_a = weight.accessor<scalar_t, 3>(); | ||
auto idx_a = idx.accessor<long, 3>(); | ||
|
||
auto batch_size = idx.size(0); | ||
for (auto b = 0; b < batch_size; b++) | ||
{ | ||
for (auto p = 0; p < idx.size(1); p++) | ||
{ | ||
for (auto c = 0; c < features.size(1); c++) | ||
{ | ||
output_a[b][c][p] = 0; | ||
for (int i = 0; i < idx.size(2); i++) | ||
{ | ||
auto new_idx = idx_a[b][p][i]; | ||
output_a[b][c][p] += features_a[b][c][new_idx] * weight_a[b][p][i]; | ||
} | ||
} | ||
} | ||
} | ||
}); | ||
return output; | ||
} | ||
|
||
at::Tensor knn_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor weight, const int m) | ||
{ | ||
CHECK_CPU(grad_out); | ||
at::Tensor output = torch::zeros({grad_out.size(0), grad_out.size(1), m}, | ||
at::device(grad_out.device()).dtype(grad_out.scalar_type())); | ||
|
||
AT_DISPATCH_ALL_TYPES(grad_out.scalar_type(), "knn_interpolate_grad", [&] { | ||
auto output_a = output.accessor<scalar_t, 3>(); | ||
auto grad_out_a = grad_out.accessor<scalar_t, 3>(); | ||
auto weight_a = weight.accessor<scalar_t, 3>(); | ||
auto idx_a = idx.accessor<long, 3>(); | ||
|
||
auto batch_size = idx.size(0); | ||
for (auto b = 0; b < batch_size; b++) | ||
{ | ||
for (auto p = 0; p < idx.size(1); p++) | ||
{ | ||
for (auto c = 0; c < grad_out.size(1); c++) | ||
{ | ||
for (int i = 0; i < idx.size(2); i++) | ||
{ | ||
auto new_idx = idx_a[b][p][i]; | ||
output_a[b][c][new_idx] += grad_out_a[b][c][p] * weight_a[b][p][i]; | ||
} | ||
} | ||
} | ||
} | ||
}); | ||
return output; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#include "compat.h" | ||
#include "neighbors.cpp" | ||
#include "neighbors.h" | ||
#include "utils.h" | ||
#include <iostream> | ||
#include <torch/extension.h> | ||
|
||
std::pair<at::Tensor, at::Tensor> _single_batch_knn(at::Tensor support, at::Tensor query, int k) | ||
{ | ||
CHECK_CONTIGUOUS(support); | ||
CHECK_CONTIGUOUS(query); | ||
if (support.size(0) < k) | ||
TORCH_CHECK(false, | ||
"Not enough points in support to find " + std::to_string(k) + " neighboors") | ||
std::vector<long> neighbors_indices(query.size(0) * k, -1); | ||
std::vector<float> neighbors_dists(query.size(0) * k, -1); | ||
|
||
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU); | ||
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU); | ||
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn", [&] { | ||
auto data_q = query.DATA_PTR<scalar_t>(); | ||
auto data_s = support.DATA_PTR<scalar_t>(); | ||
std::vector<scalar_t> queries_stl = | ||
std::vector<scalar_t>(data_q, data_q + query.size(0) * query.size(1)); | ||
std::vector<scalar_t> supports_stl = | ||
std::vector<scalar_t>(data_s, data_s + support.size(0) * support.size(1)); | ||
|
||
nanoflann_knn_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices, | ||
neighbors_dists, k); | ||
}); | ||
auto neighbors_dists_ptr = neighbors_dists.data(); | ||
long* neighbors_indices_ptr = neighbors_indices.data(); | ||
auto out = torch::from_blob(neighbors_indices_ptr, {query.size(0), k}, options = options); | ||
auto out_dists = | ||
torch::from_blob(neighbors_dists_ptr, {query.size(0), k}, options = options_dist); | ||
|
||
return std::make_pair(out.clone(), out_dists.clone()); | ||
} | ||
|
||
std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor support, at::Tensor query, int k) | ||
{ | ||
CHECK_CONTIGUOUS(support); | ||
CHECK_CONTIGUOUS(query); | ||
CHECK_CPU(query); | ||
CHECK_CPU(support); | ||
|
||
int b = query.size(0); | ||
vector<at::Tensor> batch_idx; | ||
vector<at::Tensor> batch_dist; | ||
for (int i = 0; i < b; i++) | ||
{ | ||
auto out_pair = _single_batch_knn(support[i], query[i], k); | ||
batch_idx.push_back(out_pair.first); | ||
batch_dist.push_back(out_pair.second); | ||
} | ||
auto out_idx = torch::stack(batch_idx); | ||
auto out_dist = torch::stack(batch_dist); | ||
return std::make_pair(out_idx, out_dist); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.