Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
fixes for Multiple Styles using the checkpointing api in MPS. (Who k
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekpratapa authored and Zach Nation committed Oct 3, 2019
1 parent e3e8647 commit 7849861
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 24 deletions.
1 change: 1 addition & 0 deletions src/ml/neural_net/mps_layer_instance_norm_data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ API_AVAILABLE(macos(10.14))
- (MPSCNNNormalizationGammaAndBetaState *)updateGammaAndBetaWithCommandBuffer:(id<MTLCommandBuffer>)commandBuffer
instanceNormalizationStateBatch:(MPSCNNInstanceNormalizationGradientStateBatch *)instanceNormalizationStateBatch;

- (void)checkpoint;
- (void)checkpointWithCommandQueue:(nonnull id<MTLCommandQueue>)commandQueue;

- (NSString*__nullable) label;
Expand Down
56 changes: 32 additions & 24 deletions src/ml/neural_net/mps_layer_instance_norm_data_loader.mm
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ @interface TCMPSInstanceNormDataLoader () {

MPSVectorDescriptor *_vDesc;

MPSCNNInstanceNormalization* _instanceNormFilter;

id<MTLCommandQueue> _cq;
MPSNNOptimizerAdam *_adamGamma;
MPSNNOptimizerAdam *_adamBeta;
Expand Down Expand Up @@ -214,38 +216,44 @@ - (MPSCNNNormalizationGammaAndBetaState *)updateGammaAndBetaWithCommandBuffer:(i
NSUInteger t1 = [_adamGamma timeStep];
NSUInteger t2 = [_adamBeta timeStep];

for (MPSCNNInstanceNormalizationGradientState *instanceNormalizationState in instanceNormalizationStateBatch) {
MPSVector *gradientWeightsVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForGamma)
descriptor:_vDesc];
_instanceNormFilter = instanceNormalizationStateBatch[0].instanceNormalization;

for (MPSCNNInstanceNormalizationGradientState *instanceNormalizationState in instanceNormalizationStateBatch) {
MPSVector *gradientWeightsVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForGamma)
descriptor:_vDesc];
_adamGamma.timeStep = t1;
[_adamGamma encodeToCommandBuffer:commandBuffer
inputGradientVector:gradientWeightsVector
inputValuesVector:[[_style_props objectAtIndex: _styleIndex] gammaVector]
inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] gammaMomentumVector]
inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] gammaVelocityVector]
resultValuesVector:[[_style_props objectAtIndex: _styleIndex] gammaVector]];

MPSVector *inputWeightsVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gamma)
MPSVector *gradientBiasesVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForBeta)
descriptor:_vDesc];
_adamGamma.timeStep = t1;
[_adamGamma encodeToCommandBuffer:commandBuffer
inputGradientVector:gradientWeightsVector
inputValuesVector:inputWeightsVector
inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] gammaMomentumVector]
inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] gammaVelocityVector]
resultValuesVector:[[_style_props objectAtIndex: _styleIndex] gammaVector]];

MPSVector *gradientBiasesVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForBeta)
descriptor:_vDesc];

MPSVector *inputBiasesVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.beta)
descriptor:_vDesc];
_adamBeta.timeStep = t2;
[_adamBeta encodeToCommandBuffer:commandBuffer
inputGradientVector:gradientBiasesVector
inputValuesVector:inputBiasesVector
inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] betaMomentumVector]
inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] betaVelocityVector]
resultValuesVector:[[_style_props objectAtIndex: _styleIndex] betaVector]];

_adamBeta.timeStep = t2;
[_adamBeta encodeToCommandBuffer:commandBuffer
inputGradientVector:gradientBiasesVector
inputValuesVector:[[_style_props objectAtIndex: _styleIndex] betaVector]
inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] betaMomentumVector]
inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] betaVelocityVector]
resultValuesVector:[[_style_props objectAtIndex: _styleIndex] betaVector]];

}

return [[_style_props objectAtIndex: _styleIndex] state];
}

- (void)checkpoint {
if (_instanceNormFilter) {
id <MTLCommandBuffer> cmdBuf = [_cq commandBuffer];
[_instanceNormFilter reloadGammaAndBetaWithCommandBuffer: cmdBuf
gammaAndBetaState: [[_style_props objectAtIndex: _styleIndex] state]];
[cmdBuf commit];
}
}

- (void)checkpointWithCommandQueue:(nonnull id<MTLCommandQueue>)commandQueue {
id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
id<MTLBlitCommandEncoder> blit = commandBuffer.blitCommandEncoder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ - (instancetype) initWithParameters:(NSString *)name

- (void) setStyleIndex:(NSUInteger)styleIndex {
_instNorm.tc_weightsData.styleIndex = styleIndex;
[_instNorm.tc_weightsData checkpoint];
}

- (MPSNNImageNode *) backwardPass:(MPSNNImageNode *) inputNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ - (MPSNNImageNode *) backwardPass:(MPSNNImageNode *) inputNode {

- (void) setStyleIndex:(NSUInteger)styleIndex {
_instNorm.tc_weightsData.styleIndex = styleIndex;
[_instNorm.tc_weightsData checkpoint];
}

- (void) setLearningRate:(float)lr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ - (instancetype) initWithParameters:(NSString *)name

- (void) setStyleIndex:(NSUInteger)styleIndex {
_instNorm1.tc_weightsData.styleIndex = styleIndex;
[_instNorm1.tc_weightsData checkpoint];
_instNorm2.tc_weightsData.styleIndex = styleIndex;
[_instNorm2.tc_weightsData checkpoint];
}

- (MPSNNImageNode *) backwardPass:(MPSNNImageNode *) inputNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ - (void) setStyleIndex:(NSUInteger)styleIndex {
_decoding1.styleIndex = styleIndex;
_decoding2.styleIndex = styleIndex;
_instNorm.tc_weightsData.styleIndex = styleIndex;
[_instNorm.tc_weightsData checkpoint];
}

- (void) setLearningRate:(float)lr {
Expand Down

0 comments on commit 7849861

Please sign in to comment.