- numpy
- tensorflow
- imblearn
- tensorflow-gpu
- sklearn
- matplotlib
This project judges the quality of flash welding from the waveforms of voltage, current and electrode position. There are 2000 good-quality and 50 bad-quality samples in `./data/`; as shown below, each sample is a multi-dimensional time series. Because bad samples are so rare, I use data augmentation to increase their number. In experiments, the CNN was more effective than a BP network (a fully connected network trained with backpropagation), and Dropout also helped. My intuition is that convolution can capture the relative positional relationships between the dimensions of the time series, which reduces over-fitting of the model to some extent.
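To illustrate the intuition about convolution, here is a minimal NumPy sketch of a single multi-channel 1D convolution. It is not the project's network (the project builds its CNN with TensorFlow); it assumes each sample is stored as an array of shape (channels, time), with one row per signal such as voltage, current or electrode position. Each kernel spans all channels at once, so every output value depends on how the channels relate to each other within a local time window. The function names are hypothetical.

```python
import numpy as np


def conv1d_multichannel(x: np.ndarray, kernels: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Valid 1D cross-correlation, which is what deep-learning 'conv' layers compute.

    x:       (channels, time), one sample, e.g. voltage/current/position rows
    kernels: (filters, channels, width)
    bias:    (filters,)
    returns: (filters, time - width + 1)
    """
    n_filters, n_channels, width = kernels.shape
    assert x.shape[0] == n_channels, "each kernel must span every input channel"
    out_len = x.shape[1] - width + 1
    out = np.empty((n_filters, out_len))
    for t in range(out_len):
        window = x[:, t:t + width]  # all channels, same time window
        # Each filter mixes every channel inside this window into one number.
        out[:, t] = np.tensordot(kernels, window, axes=([1, 2], [0, 1])) + bias
    return out


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)


# Hypothetical toy sample: 3 channels x 8 time steps, 4 filters of width 3.
rng = np.random.default_rng(0)
sample = rng.normal(size=(3, 8))
kernels = rng.normal(size=(4, 3, 3))
features = relu(conv1d_multichannel(sample, kernels, np.zeros(4)))
print(features.shape)  # (4, 6)
```

Because the same kernel slides along the time axis, a pattern such as "current rises shortly before the electrode moves" produces the same response wherever it occurs, with far fewer weights than a fully connected layer over the whole series.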
- `git clone git@github.com:wzx140/welding_prediction.git`
- `cd welding_prediction`
- `conda install --yes --file requirements.txt` to import the dependencies into Anaconda
- Change the variables in `resource/config.py` as needed
- `python main.py train` to train the model and save it in `resource/model`. A trained model is already included in this folder
- `tensorboard --logdir resource/tsb` to see the visualization of the data after training
- `python main.py predict path-to-model path-to-sample` to predict the quality of the welding
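The commands in this README use a `train` subcommand (optionally with `--debug`) and a `predict` subcommand that takes a model path and a sample path. The sketch below shows one way such an interface could be parsed with the standard `argparse` module. It is a hypothetical reconstruction: `build_parser` and `main` are assumed names, and the project's actual `main.py` may parse its arguments differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented commands, not the project's code.
    parser = argparse.ArgumentParser(prog="main.py", description="Flash-welding quality prediction")
    sub = parser.add_subparsers(dest="command", required=True)

    train = sub.add_parser("train", help="train the model and save it in resource/model")
    train.add_argument("--debug", action="store_true", help="run training under tfdbg")

    predict = sub.add_parser("predict", help="predict the welding quality of one sample")
    predict.add_argument("model_path", help="path to a saved model")
    predict.add_argument("sample_path", help="path to a sample file")
    return parser


def main(argv=None) -> str:
    # argv=None makes argparse read sys.argv; pass a list to parse explicitly.
    args = build_parser().parse_args(argv)
    if args.command == "train":
        return f"train (debug={args.debug})"
    return f"predict {args.sample_path} using {args.model_path}"
```

For example, `main(["predict", "model", "sample.csv"])` parses the same arguments as `python main.py predict model sample.csv`.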
Dropout is implemented:
- Set `keep_prob` in `resource/config.py` to a value in the range (0, 1]. A value of 1 disables Dropout.
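`keep_prob` is the probability of keeping each unit. A minimal NumPy sketch of inverted dropout, the scheme TensorFlow 1.x `tf.nn.dropout` uses: each activation is kept with probability `keep_prob` and the survivors are scaled by `1 / keep_prob`, so the expected activation is unchanged; with `keep_prob = 1` the input passes through untouched. This is an illustration rather than the project's code, and the function name is hypothetical.

```python
import numpy as np


def dropout(x: np.ndarray, keep_prob: float, rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout: keep each element with probability keep_prob, scale by 1/keep_prob."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    if keep_prob == 1.0:
        return x  # Dropout disabled
    mask = rng.random(x.shape) < keep_prob  # True with probability keep_prob
    return np.where(mask, x / keep_prob, 0.0)


rng = np.random.default_rng(42)
activations = np.ones((1000, 100))
dropped = dropout(activations, keep_prob=0.5, rng=rng)
print(np.mean(dropped == 0))  # roughly 0.5 of the units are zeroed
print(dropped.mean())         # roughly 1.0: the scaling preserves the expected value
```

Dropout is normally applied only during training; at prediction time `keep_prob` is set to 1 so the full network is used.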
Since we only have 50 bad samples, we use ADASYN (Adaptive Synthetic Sampling) to generate additional synthetic bad samples.
For more information, see my blog post about ADASYN.
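`imblearn`, listed in the dependencies, ships a ready-made implementation (`imblearn.over_sampling.ADASYN`). To show what the algorithm does, here is a from-scratch NumPy sketch, not the project's code. Multi-dimensional time series must first be flattened to one feature vector per sample, e.g. (n, channels, time) to (n, channels * time). For each minority sample, the share of majority samples among its k nearest neighbours decides how many synthetic samples it receives, so hard-to-learn regions get more; each synthetic sample is a random interpolation towards one of its k nearest minority neighbours. The function names and defaults (`k=5`, `beta=1.0`) are assumptions.

```python
import numpy as np


def pairwise_sq_dist(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, avoiding a 3-D intermediate array.
    d = (A ** 2).sum(axis=1)[:, None] + (B ** 2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.maximum(d, 0.0)  # clip tiny negatives caused by rounding


def adasyn(X, y, minority_label, k=5, beta=1.0, rng=None):
    """Append ADASYN-generated minority samples to (X, y).

    X: (n_samples, n_features); flatten time series before calling.
    beta: 1.0 aims for a fully balanced data set.
    """
    rng = np.random.default_rng() if rng is None else rng
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    min_idx = np.flatnonzero(y == minority_label)
    minority = X[min_idx]
    n_min = len(min_idx)
    n_maj = len(y) - n_min
    G = int((n_maj - n_min) * beta)  # total synthetic samples to generate
    if G <= 0 or n_min < 2:
        return X, y

    # 1) Difficulty ratio r_i: share of majority samples among the k nearest
    #    neighbours of each minority sample, searched over the whole data set.
    d_all = pairwise_sq_dist(minority, X)
    d_all[np.arange(n_min), min_idx] = np.inf  # a sample is not its own neighbour
    k_all = min(k, len(X) - 1)
    nn_all = np.argsort(d_all, axis=1)[:, :k_all]
    r = np.sum(y[nn_all] != minority_label, axis=1) / k_all
    if r.sum() == 0:
        r = np.ones(n_min)  # assumption: no majority neighbours, spread evenly
    g = np.rint(r / r.sum() * G).astype(int)  # synthetic samples per minority sample

    # 2) Interpolate towards random minority-class neighbours.
    d_min = pairwise_sq_dist(minority, minority)
    np.fill_diagonal(d_min, np.inf)
    nn_min = np.argsort(d_min, axis=1)[:, :min(k, n_min - 1)]
    synthetic = []
    for i, count in enumerate(g):
        for _ in range(count):
            z = minority[rng.choice(nn_min[i])]
            synthetic.append(minority[i] + rng.random() * (z - minority[i]))
    if not synthetic:
        return X, y
    X_syn = np.array(synthetic)
    y_syn = np.full(len(X_syn), minority_label, dtype=y.dtype)
    return np.vstack([X, X_syn]), np.concatenate([y, y_syn])


# Hypothetical toy data: 40 good and 8 bad samples, each 2 channels x 3 time steps.
rng = np.random.default_rng(0)
series = np.concatenate([rng.normal(0.0, 1.0, (40, 2, 3)),
                         rng.normal(1.5, 1.0, (8, 2, 3))])
labels = np.array([0] * 40 + [1] * 8)
X_flat = series.reshape(len(series), -1)    # (48, 6): one vector per sample
X_res, y_res = adasyn(X_flat, labels, minority_label=1, rng=rng)
series_res = X_res.reshape(-1, 2, 3)        # back to (samples, channels, time)
print(np.bincount(y_res))                   # the bad class is now close to 40
```

The brute-force distance matrix is fine for a few thousand samples like this data set; much larger data would call for a proper nearest-neighbour index.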
After training, the TensorBoard data is stored in `resource/tsb`. To view it, run `tensorboard --logdir resource/tsb`.
To use tfdbg:
- install pyreadline with pip
- set `enable_debug` to `True` in `resource/config.py`
- run `python main.py train --debug` in the project directory

For more information, see the official TensorFlow Debugger (tfdbg) documentation.