This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit: corrected review feedback

abhishekpratapa authored and Zach Nation committed Oct 3, 2019
1 parent fb8c383 commit e4a2893
Showing 3 changed files with 20 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/ml/neural_net/mps_layer_instance_norm_data_loader.mm
@@ -177,11 +177,11 @@ - (void) loadBeta:(float *)beta {

- (float *) beta {
NSUInteger previousStyle = _styleIndex;
-  _betaPlaceHolder = [NSMutableData data];
+  _betaPlaceHolder.length = 0;
for (NSUInteger index = 0; index < _styles; index++) {
_styleIndex = index;
[self checkpointWithCommandQueue:_cq];
-  float* betaWeights = (float *) [[[_style_props objectAtIndex: _styleIndex] betaBuffer] contents];
+  float* betaWeights = (float *) [_style_props[_styleIndex].betaBuffer contents];
[_betaPlaceHolder appendBytes:betaWeights length:sizeof(float)*_numberOfFeatureChannels];
}
_styleIndex = previousStyle;
@@ -196,12 +196,12 @@ - (void) loadGamma:(float *)gamma {

// TODO: refactor for multiple indices
- (float *) gamma {
-  _gammaPlaceHolder = [NSMutableData data];
NSUInteger previousStyle = _styleIndex;
+  _gammaPlaceHolder.length = 0;
for (NSUInteger index = 0; index < _styles; index++) {
_styleIndex = index;
[self checkpointWithCommandQueue:_cq];
-  float* gammaWeights = (float *) [[[_style_props objectAtIndex: _styleIndex] gammaBuffer] contents];
+  float* gammaWeights = (float *) [_style_props[_styleIndex].gammaBuffer contents];
[_gammaPlaceHolder appendBytes:gammaWeights length:sizeof(float)*_numberOfFeatureChannels];
}
_styleIndex = previousStyle;
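Both accessors now reset the existing NSMutableData buffer (length = 0) instead of replacing it with a fresh [NSMutableData data] instance, presumably so the same allocation is reused across calls. A minimal Python sketch of that reset-and-append pattern (the names are hypothetical, for illustration only, not the Objective-C API):

buf = bytearray()

def collect_weights(per_style_weights):
    # Analogous to `_betaPlaceHolder.length = 0;` -- keep the same
    # object (and its capacity), just drop the previous contents.
    del buf[:]
    for w in per_style_weights:  # one append per style, as in the loop above
        buf.extend(w)
    return bytes(buf)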
4 changes: 4 additions & 0 deletions src/ml/neural_net/style_transfer/mps_style_transfer.m
@@ -438,6 +438,10 @@ - (void) setLearningRate:(float)lr {
return [lossDict copy];
}

+ /**
+  * HACK: this somehow checkpoints the model for weight exports and updating the
+  * data loaders. Following up internally for a proper fix to this issue.
+  **/
- (void) checkpoint {
_inferenceGraph = [MPSNNGraph graphWithDevice:_dev
resultImage:_model.forwardPass
12 changes: 12 additions & 0 deletions src/python/turicreate/toolkits/style_transfer/style_transfer.py
@@ -409,6 +409,18 @@ def create(style_dataset, content_dataset, style_feature=None,

mps_mxnet_key_map = _MpsStyleGraphAPI.mps_mxnet_weight_dict()

+ # LOGIC
+ #
+ # The two layers we are concerned about are very different. The
+ # instance norm layer weights have a dimensionality of two, whereas
+ # the convolutional weights have a dimensionality of four. The
+ # convolutional weights are in a different format in MXNet than they
+ # are in MPS. Since the arrays come back flattened from MPS, a reshape
+ # has to occur. But this reshape happens before the transpose, so the
+ # shape itself has to be transposed. For the instance norm layer,
+ # however, the weights are passed back to MXNet in the correct format,
+ # so a simple reshape suffices.
+ #
for key in mps_weights:
if "transformer" in key:
weight = transformer.collect_params()[mps_mxnet_key_map[key]].data()
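A minimal numpy sketch of the logic described in the comment above. The shapes are hypothetical (64 output channels, 32 input channels, a 3x3 kernel, chosen only for illustration); MPS vends convolution weights flattened in (output, height, width, input) order, while MXNet stores them as (output, input, height, width):

import numpy as np

# Hypothetical shapes for illustration: 64 output channels, 32 input
# channels, a 3x3 kernel.
n_out, n_in, kh, kw = 64, 32, 3, 3

# Stand-in for a conv weight array as it comes back flattened from MPS,
# laid out in (output, height, width, input) order.
flat = np.arange(n_out * kh * kw * n_in, dtype=np.float32)

# Reshape first, into the *transposed* shape (output, height, width,
# input) that matches the MPS memory layout; only then transpose the
# axes into MXNet's (output, input, height, width) order. Reshaping
# straight to the MXNet shape would scramble the elements -- that is
# why "the shape itself has to be transposed".
conv_mxnet = flat.reshape(n_out, kh, kw, n_in).transpose(0, 3, 1, 2)
assert conv_mxnet.shape == (n_out, n_in, kh, kw)

# Instance-norm weights (gamma/beta) come back per style already in the
# right order, so a plain reshape to (styles, channels) suffices.
n_styles, channels = 8, 64
gamma = np.zeros(n_styles * channels, dtype=np.float32).reshape(n_styles, channels)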
