Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Added checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekpratapa authored and Zach Nation committed Oct 3, 2019
1 parent f10ccb2 commit fb8c383
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/ml/neural_net/mps_layer_instance_norm_data_loader.mm
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ - (instancetype) initWithParams:(NSString *)name
size_t offset = index * _numberOfFeatureChannels * sizeof(float);

style_property.gammaBuffer = [dev newBufferWithBytes:(char *) _gamma_weights.bytes + offset
length:sizeof(float) * _numberOfFeatureChannels
options:MTLResourceStorageModeManaged];

style_property.betaBuffer = [dev newBufferWithBytes:(char *) _beta_weights.bytes + offset
length:sizeof(float) * _numberOfFeatureChannels
options:MTLResourceStorageModeManaged];

style_property.betaBuffer = [dev newBufferWithBytes:(char *) _beta_weights.bytes + offset
length:sizeof(float) * _numberOfFeatureChannels
options:MTLResourceStorageModeManaged];

style_property.gammaMomentumBuffer = [dev newBufferWithBytes:zeros_ptr
length:sizeof(float) * _numberOfFeatureChannels
options:MTLResourceStorageModeManaged];
Expand Down Expand Up @@ -196,8 +196,8 @@ - (void) loadGamma:(float *)gamma {

// TODO: refactor for multiple indicies
- (float *) gamma {
NSUInteger previousStyle = _styleIndex;
_gammaPlaceHolder = [NSMutableData data];
NSUInteger previousStyle = _styleIndex;
for (NSUInteger index = 0; index < _styles; index++) {
_styleIndex = index;
[self checkpointWithCommandQueue:_cq];
Expand Down Expand Up @@ -243,7 +243,7 @@ - (MPSCNNNormalizationGammaAndBetaState *)updateGammaAndBetaWithCommandBuffer:(i

}

return [[_style_props objectAtIndex: _styleIndex] state];
return [[_style_props objectAtIndex: _styleIndex] state];
}

- (void)checkpointWithCommandQueue:(nonnull id<MTLCommandQueue>)commandQueue {
Expand Down
1 change: 1 addition & 0 deletions src/ml/neural_net/style_transfer/mps_style_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ API_AVAILABLE(macos(10.15))
- (NSDictionary<NSString *, NSData *> *) predict:(NSDictionary<NSString *, NSData *> *)inputs;
- (void) setLearningRate:(float)lr;
- (NSDictionary<NSString *, NSData *> *) train:(NSDictionary<NSString *, NSData *> *)inputs;
- (void) checkpoint;

@end

Expand Down
9 changes: 9 additions & 0 deletions src/ml/neural_net/style_transfer/mps_style_transfer.m
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ - (instancetype) initWithDev:(id<MTLDevice>) dev
}

- (NSDictionary<NSString *, NSData *> *) exportWeights {
[self checkpoint];
return [_model exportWeights:@"transformer_"];
}

Expand Down Expand Up @@ -437,6 +438,14 @@ - (void) setLearningRate:(float)lr {
return [lossDict copy];
}

- (void) checkpoint {
_inferenceGraph = [MPSNNGraph graphWithDevice:_dev
resultImage:_model.forwardPass
resultImageIsNeeded:YES];

_inferenceGraph.format = MPSImageFeatureChannelFormatFloat32;
}

@end

#endif // #ifdef HAS_MACOS_10_15

0 comments on commit fb8c383

Please sign in to comment.