Skip to content

Commit

Permalink
Adjust WebGPU device registration
Browse files Browse the repository at this point in the history
wgpu::Device is already a refcounted type. So, there is no need to add an extra layer of shared_ptr.
PiperOrigin-RevId: 673890893
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 12, 2024
1 parent c16a0dc commit 6d6e93c
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 35 deletions.
5 changes: 3 additions & 2 deletions mediapipe/gpu/gpu_buffer_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... {
// Exposing this as a function allows dependent initializers to call this to
// ensure proper ordering.
static GpuBufferStorageRegistry::RegistryToken RegisterOnce() {
static auto registration = GpuBufferStorageRegistry::Get().Register<T>();
return registration;
static auto ordered_registration =
GpuBufferStorageRegistry::Get().Register<T>();
return ordered_registration;
}

private:
Expand Down
24 changes: 0 additions & 24 deletions mediapipe/gpu/webgpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,27 +135,3 @@ cc_library(
"//mediapipe/framework/formats:tensor",
],
)

cc_test(
name = "webgpu_device_registration_test",
srcs = select({
"//mediapipe:emscripten": [],
"//conditions:default": [
"webgpu_check.cc",
"webgpu_device_registration_test.cc",
],
}),
defines = ["MEDIAPIPE_USE_WEBGPU"],
deps = select({
"//mediapipe:emscripten": [],
"//conditions:default": [
":webgpu_check",
":webgpu_device_registration",
"//mediapipe/framework/deps:no_destructor",
"//third_party/dawn:dawncpp_headers",
"//third_party/dawn:libdawn_proc",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest_main",
],
}),
)
3 changes: 1 addition & 2 deletions mediapipe/gpu/webgpu/webgpu_device_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ WebGpuDeviceRegistration& WebGpuDeviceRegistration::GetInstance() {
return *instance;
}

void WebGpuDeviceRegistration::RegisterWebGpuDevice(
std::shared_ptr<wgpu::Device> device) {
void WebGpuDeviceRegistration::RegisterWebGpuDevice(wgpu::Device device) {
device_ = std::move(device);
}

Expand Down
8 changes: 3 additions & 5 deletions mediapipe/gpu/webgpu/webgpu_device_registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#ifndef MEDIAPIPE_GPU_WEBGPU_WEBGPU_DEVICE_REGISTRATION_H_
#define MEDIAPIPE_GPU_WEBGPU_WEBGPU_DEVICE_REGISTRATION_H_

#include <memory>

#include "mediapipe/framework/deps/no_destructor.h"
#include "third_party/dawn/include/webgpu/webgpu_cpp.h"

Expand All @@ -30,19 +28,19 @@ class WebGpuDeviceRegistration {
WebGpuDeviceRegistration(const WebGpuDeviceRegistration&) = delete;
WebGpuDeviceRegistration& operator=(const WebGpuDeviceRegistration&) = delete;

void RegisterWebGpuDevice(std::shared_ptr<wgpu::Device> device);
void RegisterWebGpuDevice(wgpu::Device device);

void UnRegisterWebGpuDevice();

std::shared_ptr<wgpu::Device> GetWebGpuDevice() const { return device_; }
wgpu::Device GetWebGpuDevice() const { return device_; }

private:
friend class NoDestructor<WebGpuDeviceRegistration>;

WebGpuDeviceRegistration();
~WebGpuDeviceRegistration();

std::shared_ptr<wgpu::Device> device_;
wgpu::Device device_;
};

} // namespace mediapipe
Expand Down
2 changes: 1 addition & 1 deletion mediapipe/gpu/webgpu/webgpu_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ WebGpuService::WebGpuService()
#else
WebGpuService::WebGpuService()
: canvas_selector_(""),
device_(*WebGpuDeviceRegistration::GetInstance().GetWebGpuDevice()) {}
device_(WebGpuDeviceRegistration::GetInstance().GetWebGpuDevice()) {}
#endif // __EMSCRIPTEN__

ABSL_CONST_INIT const GraphService<WebGpuService> kWebGpuService(
Expand Down
2 changes: 1 addition & 1 deletion mediapipe/gpu/webgpu/webgpu_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ static WebGpuDeviceAttachmentManager& GetEmscriptenDeviceAttachmentManager() {
#else
static WebGpuDeviceAttachmentManager& GetNativeDeviceAttachmentManager() {
static mediapipe::NoDestructor<WebGpuDeviceAttachmentManager> manager(
wgpu::Device(*WebGpuDeviceRegistration::GetInstance().GetWebGpuDevice()));
wgpu::Device(WebGpuDeviceRegistration::GetInstance().GetWebGpuDevice()));
return *manager;
}
#endif // __EMSCRIPTEN__
Expand Down

0 comments on commit 6d6e93c

Please sign in to comment.