diff --git a/src/ml/neural_net/mps_layer_instance_norm_data_loader.h b/src/ml/neural_net/mps_layer_instance_norm_data_loader.h index d4aac15623..d2dea1a476 100644 --- a/src/ml/neural_net/mps_layer_instance_norm_data_loader.h +++ b/src/ml/neural_net/mps_layer_instance_norm_data_loader.h @@ -70,6 +70,7 @@ API_AVAILABLE(macos(10.14)) - (MPSCNNNormalizationGammaAndBetaState *)updateGammaAndBetaWithCommandBuffer:(id)commandBuffer instanceNormalizationStateBatch:(MPSCNNInstanceNormalizationGradientStateBatch *)instanceNormalizationStateBatch; +- (void)checkpoint; - (void)checkpointWithCommandQueue:(nonnull id)commandQueue; - (NSString*__nullable) label; diff --git a/src/ml/neural_net/mps_layer_instance_norm_data_loader.mm b/src/ml/neural_net/mps_layer_instance_norm_data_loader.mm index 613596baf8..d7ce5cd78f 100644 --- a/src/ml/neural_net/mps_layer_instance_norm_data_loader.mm +++ b/src/ml/neural_net/mps_layer_instance_norm_data_loader.mm @@ -45,6 +45,8 @@ @interface TCMPSInstanceNormDataLoader () { MPSVectorDescriptor *_vDesc; + MPSCNNInstanceNormalization* _instanceNormFilter; + id _cq; MPSNNOptimizerAdam *_adamGamma; MPSNNOptimizerAdam *_adamBeta; @@ -214,38 +216,44 @@ - (MPSCNNNormalizationGammaAndBetaState *)updateGammaAndBetaWithCommandBuffer:(i NSUInteger t1 = [_adamGamma timeStep]; NSUInteger t2 = [_adamBeta timeStep]; - for (MPSCNNInstanceNormalizationGradientState *instanceNormalizationState in instanceNormalizationStateBatch) { - MPSVector *gradientWeightsVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForGamma) - descriptor:_vDesc]; + _instanceNormFilter = instanceNormalizationStateBatch[0].instanceNormalization; + + for (MPSCNNInstanceNormalizationGradientState *instanceNormalizationState in instanceNormalizationStateBatch) { + MPSVector *gradientWeightsVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForGamma) + descriptor:_vDesc]; + _adamGamma.timeStep = t1; + [_adamGamma encodeToCommandBuffer:commandBuffer + inputGradientVector:gradientWeightsVector + inputValuesVector:[[_style_props objectAtIndex: _styleIndex] gammaVector] + inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] gammaMomentumVector] + inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] gammaVelocityVector] + resultValuesVector:[[_style_props objectAtIndex: _styleIndex] gammaVector]]; - MPSVector *inputWeightsVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gamma) + MPSVector *gradientBiasesVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForBeta) descriptor:_vDesc]; - _adamGamma.timeStep = t1; - [_adamGamma encodeToCommandBuffer:commandBuffer - inputGradientVector:gradientWeightsVector - inputValuesVector:inputWeightsVector - inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] gammaMomentumVector] - inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] gammaVelocityVector] - resultValuesVector:[[_style_props objectAtIndex: _styleIndex] gammaVector]]; - - MPSVector *gradientBiasesVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.gradientForBeta) - descriptor:_vDesc]; - - MPSVector *inputBiasesVector = [[MPSVector alloc] initWithBuffer:nonnull_cast(instanceNormalizationState.beta) - descriptor:_vDesc]; - _adamBeta.timeStep = t2; - [_adamBeta encodeToCommandBuffer:commandBuffer - inputGradientVector:gradientBiasesVector - inputValuesVector:inputBiasesVector - inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] betaMomentumVector] - inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] betaVelocityVector] - resultValuesVector:[[_style_props objectAtIndex: _styleIndex] betaVector]]; + + _adamBeta.timeStep = t2; + [_adamBeta encodeToCommandBuffer:commandBuffer + inputGradientVector:gradientBiasesVector + inputValuesVector:[[_style_props objectAtIndex: _styleIndex] betaVector] + inputMomentumVector:[[_style_props objectAtIndex: _styleIndex] betaMomentumVector] + inputVelocityVector:[[_style_props objectAtIndex: _styleIndex] betaVelocityVector] + resultValuesVector:[[_style_props objectAtIndex: _styleIndex] betaVector]]; } return [[_style_props objectAtIndex: _styleIndex] state]; } +- (void)checkpoint { + if (_instanceNormFilter) { + id cmdBuf = [_cq commandBuffer]; + [_instanceNormFilter reloadGammaAndBetaWithCommandBuffer: cmdBuf + gammaAndBetaState: [[_style_props objectAtIndex: _styleIndex] state]]; + [cmdBuf commit]; + } +} + - (void)checkpointWithCommandQueue:(nonnull id)commandQueue { id commandBuffer = [commandQueue commandBuffer]; id blit = commandBuffer.blitCommandEncoder; diff --git a/src/ml/neural_net/style_transfer/mps_style_transfer_decoding_node.mm b/src/ml/neural_net/style_transfer/mps_style_transfer_decoding_node.mm index 79117eb12b..abfa27919b 100644 --- a/src/ml/neural_net/style_transfer/mps_style_transfer_decoding_node.mm +++ b/src/ml/neural_net/style_transfer/mps_style_transfer_decoding_node.mm @@ -69,6 +69,7 @@ - (instancetype) initWithParameters:(NSString *)name - (void) setStyleIndex:(NSUInteger)styleIndex { _instNorm.tc_weightsData.styleIndex = styleIndex; + [_instNorm.tc_weightsData checkpoint]; } - (MPSNNImageNode *) backwardPass:(MPSNNImageNode *) inputNode { diff --git a/src/ml/neural_net/style_transfer/mps_style_transfer_encoding_node.mm b/src/ml/neural_net/style_transfer/mps_style_transfer_encoding_node.mm index 5e286c4ae1..36a2736f36 100644 --- a/src/ml/neural_net/style_transfer/mps_style_transfer_encoding_node.mm +++ b/src/ml/neural_net/style_transfer/mps_style_transfer_encoding_node.mm @@ -72,6 +72,7 @@ - (MPSNNImageNode *) backwardPass:(MPSNNImageNode *) inputNode { - (void) setStyleIndex:(NSUInteger)styleIndex { _instNorm.tc_weightsData.styleIndex = styleIndex; + [_instNorm.tc_weightsData checkpoint]; } - (void) setLearningRate:(float)lr { diff --git a/src/ml/neural_net/style_transfer/mps_style_transfer_residual_node.mm b/src/ml/neural_net/style_transfer/mps_style_transfer_residual_node.mm index 863a4d512e..3f3639cf19 100644 --- a/src/ml/neural_net/style_transfer/mps_style_transfer_residual_node.mm +++ b/src/ml/neural_net/style_transfer/mps_style_transfer_residual_node.mm @@ -97,7 +97,9 @@ - (instancetype) initWithParameters:(NSString *)name - (void) setStyleIndex:(NSUInteger)styleIndex { _instNorm1.tc_weightsData.styleIndex = styleIndex; + [_instNorm1.tc_weightsData checkpoint]; _instNorm2.tc_weightsData.styleIndex = styleIndex; + [_instNorm2.tc_weightsData checkpoint]; } - (MPSNNImageNode *) backwardPass:(MPSNNImageNode *) inputNode { diff --git a/src/ml/neural_net/style_transfer/mps_style_transfer_transformer_network.mm b/src/ml/neural_net/style_transfer/mps_style_transfer_transformer_network.mm index 14d51d20d0..fd3a48c84d 100644 --- a/src/ml/neural_net/style_transfer/mps_style_transfer_transformer_network.mm +++ b/src/ml/neural_net/style_transfer/mps_style_transfer_transformer_network.mm @@ -182,6 +182,7 @@ - (void) setStyleIndex:(NSUInteger)styleIndex { _decoding1.styleIndex = styleIndex; _decoding2.styleIndex = styleIndex; _instNorm.tc_weightsData.styleIndex = styleIndex; + [_instNorm.tc_weightsData checkpoint]; } - (void) setLearningRate:(float)lr {