Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Refactor Object Detection inference to use new Model Trainer type (#3034)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjong authored Mar 7, 2020
1 parent 5393ebb commit 038db92
Show file tree
Hide file tree
Showing 12 changed files with 584 additions and 191 deletions.
1 change: 1 addition & 0 deletions src/ml/neural_net/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ if(APPLE AND HAS_MPS AND NOT TC_BUILD_IOS)
mps_weight.mm
mps_device_manager.m
mps_descriptor_utils.m
mps_od_backend.mm
style_transfer/mps_style_transfer.m
style_transfer/mps_style_transfer_backend.mm
style_transfer/mps_style_transfer_utils.m
Expand Down
3 changes: 1 addition & 2 deletions src/ml/neural_net/mps_compute_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class mps_compute_context: public compute_context {
std::function<float(float lower, float upper)> rng);

private:

std::unique_ptr<mps_command_queue> command_queue_;
std::shared_ptr<mps_command_queue> command_queue_;
};

} // namespace neural_net
Expand Down
27 changes: 17 additions & 10 deletions src/ml/neural_net/mps_compute_context.mm
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

#include <core/logging/logger.hpp>
#include <core/storage/fileio/fileio_constants.hpp>
#include <ml/neural_net/mps_image_augmentation.hpp>
#include <ml/neural_net/mps_od_backend.hpp>
#include <ml/neural_net/style_transfer/mps_style_transfer_backend.hpp>

#include <ml/neural_net/mps_cnnmodule.h>
#include <ml/neural_net/mps_graph_cnnmodule.h>
#include <ml/neural_net/mps_image_augmentation.hpp>

#import <ml/neural_net/style_transfer/mps_style_transfer_backend.hpp>

namespace turi {
namespace neural_net {
Expand Down Expand Up @@ -125,18 +127,23 @@ float_array_map multiply_mps_od_loss_multiplier(float_array_map config,
std::unique_ptr<model_backend> mps_compute_context::create_object_detector(
int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out,
const float_array_map& config, const float_array_map& weights) {
float_array_map updated_config;
mps_od_backend::parameters params;
params.command_queue = command_queue_;
params.n = n;
params.c_in = c_in;
params.h_in = h_in;
params.w_in = w_in;
params.c_out = c_out;
params.h_out = h_out;
params.w_out = w_out;
params.weights = weights;

std::vector<std::string> update_keys = {
"learning_rate", "od_scale_class", "od_scale_no_object", "od_scale_object",
"od_scale_wh", "od_scale_xy", "gradient_clipping"};
updated_config = multiply_mps_od_loss_multiplier(config, update_keys);
std::unique_ptr<mps_graph_cnn_module> result(
new mps_graph_cnn_module(*command_queue_));

result->init(/* network_id */ kODGraphNet, n, c_in, h_in, w_in, c_out, h_out,
w_out, updated_config, weights);
params.config = multiply_mps_od_loss_multiplier(config, update_keys);

return result;
return std::unique_ptr<mps_od_backend>(new mps_od_backend(std::move(params)));
}

std::unique_ptr<model_backend> mps_compute_context::create_activity_classifier(
Expand Down
63 changes: 63 additions & 0 deletions src/ml/neural_net/mps_od_backend.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright © 2020 Apple Inc. All rights reserved.
*
* Use of this source code is governed by a BSD-3-clause license that can
* be found in the LICENSE.txt file or at
* https://opensource.org/licenses/BSD-3-Clause
*/

#ifndef MPS_OD_BACKEND_HPP_
#define MPS_OD_BACKEND_HPP_

#include <memory>

#include <ml/neural_net/model_backend.hpp>
#include <ml/neural_net/mps_graph_cnnmodule.h>

namespace turi {
namespace neural_net {

/**
* Model backend for object detection that uses a separate mps_graph_cnnmodule
* for training and for inference, since mps_graph_cnnmodule doesn't currently
* support doing both.
*/
/**
 * Model backend for object detection that uses a separate mps_graph_cnnmodule
 * for training and for inference, since mps_graph_cnnmodule doesn't currently
 * support doing both.
 */
class mps_od_backend : public model_backend {
public:
  /**
   * Everything needed to construct the underlying mps_graph_cnn_module
   * instances. The dimension fields (n, c_in, h_in, w_in, c_out, h_out, w_out)
   * are forwarded verbatim to mps_graph_cnn_module::init.
   */
  struct parameters {
    /** Command queue shared with the compute context that created the backend. */
    std::shared_ptr<mps_command_queue> command_queue;
    int n;
    int c_in;
    int h_in;
    int w_in;
    int c_out;
    int h_out;
    int w_out;
    /** Must contain a "mode" entry; 0 selects eager training-module setup. */
    float_array_map config;
    /** Initial network weights; cleared once a training module absorbs them. */
    float_array_map weights;
  };

  /**
   * Eagerly instantiates either the training or the prediction module,
   * depending on params.config["mode"]. Marked explicit so a parameters
   * struct cannot silently convert into a backend.
   */
  explicit mps_od_backend(parameters params);

  // Training
  void set_learning_rate(float lr) override;
  float_array_map train(const float_array_map& inputs) override;

  // Inference
  float_array_map predict(const float_array_map& inputs) const override;

  float_array_map export_weights() const override;

private:
  // Lazily create the training/prediction modules on first use.
  void ensure_training_module();
  void ensure_prediction_module() const;

  parameters params_;

  std::unique_ptr<mps_graph_cnn_module> training_module_;

  // Cleared whenever the training module is updated.
  mutable std::unique_ptr<mps_graph_cnn_module> prediction_module_;
};

} // namespace neural_net
} // namespace turi

#endif // MPS_OD_BACKEND_HPP_
86 changes: 86 additions & 0 deletions src/ml/neural_net/mps_od_backend.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/* Copyright © 2020 Apple Inc. All rights reserved.
*
* Use of this source code is governed by a BSD-3-clause license that can
* be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
*/

#include <ml/neural_net/mps_od_backend.hpp>

namespace turi {
namespace neural_net {

void mps_od_backend::ensure_training_module() {
if (training_module_) return;

training_module_.reset(new mps_graph_cnn_module(*params_.command_queue));
training_module_->init(/* network_id */ kODGraphNet, params_.n, params_.c_in, params_.h_in,
params_.w_in, params_.c_out, params_.h_out, params_.w_out, params_.config,
params_.weights);

// Clear params_.weights to free up memory, since they are now superceded by
// whatever the training module contains.
params_.weights.clear();
}

void mps_od_backend::ensure_prediction_module() const {
if (prediction_module_) return;

// Adjust configuration for prediction.
float_array_map config = params_.config;
config["mode"] = shared_float_array::wrap(2.0f);
config["od_include_loss"] = shared_float_array::wrap(0.0f);

// Take weights from training module if present, else from original weights.
float_array_map weights;
if (training_module_) {
weights = training_module_->export_weights();
} else {
weights = params_.weights;
}

prediction_module_.reset(new mps_graph_cnn_module(*params_.command_queue));
prediction_module_->init(/* network_id */ kODGraphNet, params_.n, params_.c_in, params_.h_in,
params_.w_in, params_.c_out, params_.h_out, params_.w_out, config,
weights);
}

mps_od_backend::mps_od_backend(parameters params) : params_(std::move(params)) {
  // Eagerly instantiate at least one module, since at present we can't
  // guarantee that the weights will remain valid after we return.
  // TODO: Remove this eager construction once we stop putting weak pointers in
  // float_array_map.
  const bool training_mode = (params_.config.at("mode").data()[0] == 0.f);
  if (training_mode) {
    ensure_training_module();
  } else {
    ensure_prediction_module();
  }
}

void mps_od_backend::set_learning_rate(float lr) {
  // The learning rate only affects training, so make sure the training module
  // exists, then forward the new value to it.
  ensure_training_module();
  training_module_->set_learning_rate(lr);
}

float_array_map mps_od_backend::train(const float_array_map& inputs) {
  // Any cached prediction module would hold stale weights after this update,
  // so drop it; the next predict() call rebuilds it from fresh weights.
  prediction_module_.reset();

  ensure_training_module();
  return training_module_->train(inputs);
}

float_array_map mps_od_backend::predict(const float_array_map& inputs) const {
  // Build the inference module on demand (prediction_module_ is mutable, so
  // this is legal in a const method), then delegate to it.
  ensure_prediction_module();
  return prediction_module_->predict(inputs);
}

float_array_map mps_od_backend::export_weights() const {
  // When a training module exists it owns the authoritative (possibly updated)
  // weights; otherwise the weights we were constructed with are still current.
  return training_module_ ? training_module_->export_weights() : params_.weights;
}

} // namespace neural_net
} // namespace turi
Loading

0 comments on commit 038db92

Please sign in to comment.