This executor implements the purely attention-based Audio Spectrogram Transformer (AST) model, which has achieved state-of-the-art (SOTA) results on several publicly available audio classification datasets. For each audio file, the executor returns an embedding and a prediction from the trained model.
This executor takes a DocumentArray on the /index request. Each Document should contain the full path of the audio file in its filename tag.
from jina import Document, DocumentArray
import os
# Build the full path to the audio file and store it under the 'filename' tag
filename = os.path.join(os.getcwd(), 'data/1-4211-A-12.wav')
doc = DocumentArray([Document(tags={'filename': filename})])
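When indexing many clips, each Document still needs the full path of its file in the filename tag. The sketch below uses a hypothetical helper, wav_tags (not part of the executor or Jina), that collects the absolute paths of all .wav files in a directory with the standard library only; each resulting dict can then be wrapped as Document(tags=...) exactly as in the example above.

```python
import os
from pathlib import Path


def wav_tags(directory):
    """Return one tags dict per .wav file in `directory`, using absolute paths.

    Hypothetical helper: each dict is meant to be passed as Document(tags=...)
    so that the executor receives the full path under the 'filename' key.
    """
    root = Path(directory).resolve()
    return [
        {'filename': str(path)}
        for path in sorted(root.glob('*.wav'))  # sorted for a stable order
        if path.is_file()
    ]


if __name__ == '__main__':
    # Assumed layout: audio clips live in ./data, as in the example above
    for tags in wav_tags(os.path.join(os.getcwd(), 'data')):
        print(tags)
```

Resolving paths up front means the executor does not depend on the working directory of the process that runs it.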
The inputs and outputs of each Document when using this executor:
- Input: filename, the path of the audio file.
- Output: embedding, generated by the transformer model.
- Output: prediction, made by the model.
The executor can be used with a pre-trained model or a fine-tuned model.
Use the executor via its Docker image:
from jina import Flow
f = Flow().add(uses='jinahub+docker://executor-audio-ASTransformer')
or via its source code:
from jina import Flow
f = Flow().add(uses='jinahub://executor-audio-ASTransformer')
- To override __init__ args & kwargs, use .add(..., uses_with={'key': 'value'})
- To override class metas, use .add(..., uses_metas={'key': 'value'})
Using the pre-trained model is the default setting of the executor and does not require any parameters to be passed. The default parameters are already calculated in the AST implementation in these lines, and the default parameters used by the executor can be found here.
import os
from jina import Document, DocumentArray, Flow
filename = os.path.join(os.getcwd(), 'data/1-4211-A-12.wav')
doc = DocumentArray([Document(tags={'filename': filename})])
# Pre-trained model: no uses_with parameters are needed
f = Flow().add(uses='jinahub://ASTransformer_encoder', install_requirements=True, force=True)
with f:
    responses = f.post(on='index', inputs=doc, return_results=True)
    print(responses)
The AST can also be fine-tuned on a specific dataset, and the resulting model can be used with the executor. In that case, the executor requires the following parameters:
- total_labels: Total number of classes in the dataset on which the model is fine-tuned.
- input_target_dim: Length of the input signal in spectrogram frames. It is approximately 100*t (one frame every 10 ms), where t is the duration of an audio clip in seconds; the ESC-50 example below uses 512 for its 5-second clips.
- dataset_mean_std: Mean and standard deviation of the dataset. The process to calculate both can be found here.
- model_path (str): Path to the fine-tuned model.
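Two of these values depend on your data. The sketch below is a minimal illustration, not the executor's own code: clip_duration_seconds reads t from a PCM WAV header with Python's standard wave module, and dataset_mean_std pools all log-Mel filterbank values of the training set into a single mean and standard deviation. Both function names are hypothetical, and the feature arrays are assumed to be computed with the same preprocessing the executor uses; the referenced procedure for dataset_mean_std may aggregate the statistics differently (for example, by averaging per-batch values), so results can differ slightly.

```python
import wave

import numpy as np


def clip_duration_seconds(path):
    """Duration t (in seconds) of a PCM WAV file, read from its header."""
    with wave.open(path, 'rb') as wav:
        return wav.getnframes() / wav.getframerate()


def dataset_mean_std(features):
    """Global mean and standard deviation over a list of feature arrays.

    Assumption: `features` holds one log-Mel filterbank array per clip
    (e.g. shape [frames, mel_bins]). All values are pooled together and
    the population standard deviation (ddof=0) is returned.
    """
    values = np.concatenate(
        [np.asarray(f, dtype=np.float64).ravel() for f in features]
    )
    return [float(values.mean()), float(values.std())]
```

The returned list has the same [mean, std] layout as the dataset_mean_std value passed through uses_with.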
Below is an example of using it with a model fine-tuned on the ESC-50 dataset.
import os
from jina import Document, DocumentArray, Flow
filename = os.path.join(os.getcwd(), 'data/1-4211-A-12.wav')
doc = DocumentArray([Document(tags={'filename': filename})])
# Parameters for a model fine-tuned on ESC-50
params = {'total_labels': 50,
          'input_target_dim': 512,
          'dataset_mean_std': [-6.6268077, 5.358466],
          # Replace with the path to your own fine-tuned checkpoint
          'model_path': '/home/dato6579/jobs/ast_exec/executor-audio-ASTransformer/pretrained_models/ast_esc50_best_model.pth'}
f = Flow().add(uses='jinahub://ASTransformer_encoder', uses_with=params, install_requirements=True, force=True)
with f:
    responses = f.post(on='index', inputs=doc, return_results=True)
    print(responses)
The executor stores the prediction in the tags of each Document under the key prediction, and the embedding of the audio in the Document's embedding field. Example output:
{'docs': [{'id': '9fbafd41-34a0-11ec-8d14-b06ebf2c5fd9',
'tags': {'filename': '<PATH>/data/1-4211-A-12.wav', 'prediction': 12.0},
'embedding': {'dense': {'buffer': 'mqpsPwh9EUBIb0E+Go4+PqzGV77ZN3s/CSAfPwnS673iX/C9THirPuZVjz/c8VY/WqtCv9mvJj6uKSDAdBUkv7oO8j48Oju/bdYKP4WxIb+ubGW/RPQ2PwwRFD8CNIo+MIOgPjdgZT+zVZ8/3P2Bv+0C774Sh7i/0Mn4vAezEb7om2i9ekrAPzU2sT9441k/7LVVP981jT9l9ne+063FvvFstD7ilrW/6e2evxSIXT/g04Y/He2Av9EEm73kVYQ/XN3PvuaCJMDE/M6+PFiIv6Akr73OX5u/bP6mP2Q6VL+6bVE/yEZLP+bGZ7+mCJ6+z/qRvgBxpb632cm+YEEnv9iE4b8Rn5m/Tez2PXjVlD5mte2/yhCKP4kIvT60+fE/L7DnvxJfQr8f3V4/Qr9Zv5qbVD6WHtm++CKqPfRiUj8jtnU/PLO1Pgvx875kfpy/Z2sEQPI+mD4+u3g+58DKPjK5EMCTUo4/ugAPvytlwb/mljY+p2wEvgKGAsBjfDK/O27QPgEXkj9Q6lK9n0p2QETCRj88STQ/NVYivlYvF8DvFqI+cGnAvxiNFcBkoSBAbrm9Pu3ojr+tSOQ+A0mSvwYzGL46KMC/GDCGPz5flr/WFuA/eRDbvVM+j78qpSS/zWF4v5nckT3g6Ig+7rBKP0hnA7ugPu6+gqOkPi0eYj634K4/1u2xP2iPOz1qYAc+qvc1QLhERj92Zem+9QbWPuvEQz6MCwFAQIaGvCEYLj6OLC8/MY2hvsb8bb+m7py/pZ9qv7yEAEAvA5A+9CbKv1MpGL/nIc4/ZGGYPjPeT76w7cg/pg1Jv8wYcD1KaXu/vHJ3vy/5iz8KnAK/ykQrv2ykvj8INh+/NOuTP1QyKMAa2Ko+O6MsQOySA0AYP7a9vI5yvlReRUDPiAFAEECGP/XqiD4slqW/vz0MQBakbL9keZo+/WRKPvMOBb/mrUW+sGkNwHYvv7+uvSk/jjiFv3DDFz4VgLk+pB8qwB+LXT/2xwQ+pj0PwA34Ez8yfjG/AEq7Pd85nb4Up4S/TrtnvzZjZb9rL3k/aD9NPxdm5L5K7pu/IMDRP1oYn74wJ1LA6bNvvu75ob74vv2+EnPKv6ZG/L/u5Rw+wGwzP/gS2z80SqA//oSjvzTSxD7I5r4/vFmhv+gxlT8agZQ/BMLzPhRy5j/8jbY/vIy8Pw0G3z45Niw/4BnoP3fWiD+k8YE/QnWlPzcNE78Yajm/Y4elv5+WHT8rih2+ReSxPuVo/j6aWk1APJqIP2HZxj+iLwm+aNjdv0B/1bzy/ZC/4wrJvjLUFj6MBrq9ZDqZvHCZND+b+wzAjCmRv0pp9T+vmO4/C19Rv8M7ZT4+l/O/zPKIv+hg5b0IJ4W+/Y+Lv+KJRb9YtPK/lZCqPhnC2b5wpRw/qhunv6eGBb+o/qU/xCsHv09TvL7QXbu+ztZXPqVfQ0BeXNy/aHqEP100gr/oftM+nG9iP4XcJj6k8nM/9CBSP0DFtr8Z184/6EK9v6BmrD90m8W+FyPBPwaSir/3DDk/7hEdvnzZz7/O3VK/5BcqPx2pSD/yOhW/es8pPzB1cb+aHSc+QLXJP1bXQT5G97S/RF3IP6K2gT3FeJm+EhfGPzjQ4T6UT/W+NIakv4Ys/b9+ay3AFUmzPYgsIr5t4IM+4F2zO5EB3j1+2Um/eNA1P2L3u76s7Iu/aKL5vX5Jbz+GLiO/BFCbvdLdlL8392U/GcCHvwI/F8DHnxS//2ThP0RkFkDeqle/voLXP1o4jj84/r8/yXpcv7ivaz0krgLAWuO7vTqCnr8EJM+/XUShP9psxj9C7cE+8gh4v8T4VL8SYBtAuO1PwBLj3r3n/3u/TM+Jv3xsqz51oiA/eokbwAU9DD8KFbI/bIyjvWgtVUCg+Y+/MH+Rv+AgBcD+Vhs/EO8QwIzpGUC6qQDAdljsPm9ZAcDSMZg/MegXQAKRqD9M9
fO/YHMmvyci9b+WfBXAuI0NP3qLIUDK3OI9IHXyPwLMcj9iGQw+mYATQOHdBr5YfZ2+qecUPsF11748To2/ZC6rP0rYqj5CFJg++MX5PvqHWL8N6ZI/TFXPv5Z3iz+27z49v8qqv265qz/IYBG/CnHFv3JfDsC568Y+LGO0PmWh0767xEVA/m4BQG+KJz4OR4S/0BD2vyjNtr/AaVu/URaUv+bQfD/xpp0//KXev+So5j/eqcY/NF/AvKwM5j6u0EO/3wIRvpQoab7gFaM/lFQPP94MnT+4EtY+WMVOPwjZGkCCiwm/4PraPHiaNj7scLK/pEDIv9jg0z0iXsM/mAfKPgH9gr9esYK+lNkEvm5cMD9me4Q+ll6LPo2/vD9UKwy+xQrYvhEyvT56hZk9kMjoPO7El77Exg296cAAQGv9zT/I2a8/ho7Cv8wxfT7FxWc+7GoRvhi10L+G+ARAQ32tvzVmkT0acaU/S8nevc6o2z/QoOQ9FvmTP7VohL1WsgM/Uoy0P4YKwL9P05I/MiEPv4hl8L/KLW8/g2/+v/xBij+BfgS/3Ze2PczRpD9gMxjARskUwEvJ6T/fIni/hz8swJQN6j6p9B6/hcGCP6TWpz21Sv+/+aM8P1heSL9cY/y/TErsP83JUj+tTpW/0gL3vRethr+6joI/92mMv5ZMqr6SiJq+3NWrvJq6OsBt/vY96lduvy9QkD8d6Go/OW4Fvt6F9T67YSpADet7v3milb/w1QHABFjbvd4zDb9WXdU/aiYaPxeYjD5AWWi9TKuoPxq6AcBEH7S/jC60P6YnxL+eFI69EtvOvzgViD6G0Ke/OE5Sv5hosz7+iTg9kyGePtbjZT+d/w3AuE0dv0j5IEBLBHg+dsfevxCEuL52/hQ+bn9mP5xY2T+4SVm/EnKZvvgT/L3K3HQ+6PwtPuw/qz9KDki/Gr2qPyi27TwSEgHAeVAEwHxazT9GdWK9sufqv2qcgz7NXLU/4k4VwLNc6j4A+Wi/72o1wPN8a77Io7c+NpiHv863179tlgbArn96PyBvPDyda5C+n6M0vn6Gmj6oFsO+VNTGv+N4Pb9G50K/BK5XwM51sD67T5m/gww0P1lQ5r9/+eI+YI4iv5xS3T/GCcG+pCtSvnCcb714GiG/ddrXv5DfuDy2bC5AOCaTv63HrL8YHOu/sgKpv/hk7T+DVtI/iapMvySVNL+Qbjk/fzZVQGUNnT4O1G4/f/0vv5DLq76EOi++DiS6v9i5mT/vSxi/PkdiPr4rCUBvUOU9f3XFvoEeDUDqQlu/2i0+vztKq76sIyy/bGx/P0af7L6I106+/EsdP38jE78HmqS+TEp8PQxMYL9BbgLAUnc9vredvb1WjhS/yjdtv1yDJ783pee+DS3jPi3iEEA+IV3AgtGuv/abyT9q+W8/4DD1PJyjRb8E9ns+1JORvwr81T+A2qe+pjdSQL/7N78b1Fi/y1j7P2M15T2fKBC/mproPsjxJz9KPGA9EPDrvkFP4759F6G/r9+kPwaBhj+EFec/zXEjPzCQSL8ybto/8JD5v/GqID7c86w/5QeXPsx1MkB8P1jAJxaeP8kaPr+kU8g9isu5vyRAUz86DXq/YNoDP7QuEj+u5jjAKfajPldknr2G0Ne/8OeOPyBiBr1sL5W/qqWNP3yEJEA4vq0/TAZ6v6H/Ij7sqbe/nlpjP/LMB8DGq8Q/zIN8P+K3yj9a988+JNzIPX1ypb+7yQi+tsX4P1lvZ74WpUk+RDS+v0BqEr+GTVK/GGuwPpOeOz8yBR2+QlYDQOpOrz4+DTg/zKwSPtY3JMBbOF8/hNg0PiAadL7ezW+/O4ltP6hfpL0PthG/7u92PqkDnL/bI3w/JZhPP1bZvb83cAfAkF61PnoOaL78w/6+dFCtv/uXDEBJUSxAGPs6Pj7DYD/COGi/MqHwvqCa6z+OKcs/7LZoP4yyfr/FWlc+tp+MvuxB/j5hSLs/nNYEv7zCJT9pT
ac/8pmXv9LtIb8eb8s/yDdkvwByAL3S590+IBKEvaTAvj9A4ew+/9w2PskC7z4QCKk/P1UWvkY1fb2i9na+kwNUv5kSyD/rU1I/alKAv9DTvz+MJZK/l++rP+IUHL+K/6s/', 'shape': [1, 768], 'dtype': '<f4'}}}]
}
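In this output the embedding is serialized as a base64-encoded buffer together with its shape and dtype ('<f4' is little-endian float32). A minimal sketch for turning it back into a NumPy array, assuming the response has already been converted to a plain dict shaped like the one above; the helper names are hypothetical and not part of Jina or the executor. Because ESC-50 file names follow the pattern {fold}-{clip_id}-{take}-{target}.wav, the sketch also reads the ground-truth class from the file name so it can be compared with the prediction: for 1-4211-A-12.wav the target is 12, matching the prediction shown.

```python
import base64
import os

import numpy as np


def decode_dense(dense):
    """Turn a serialized dense embedding dict into a NumPy array.

    Assumes `dense` has a base64 'buffer' plus 'shape' and 'dtype' keys,
    like the 'dense' entry in the response above.
    """
    raw = base64.b64decode(dense['buffer'])
    return np.frombuffer(raw, dtype=np.dtype(dense['dtype'])).reshape(dense['shape'])


def esc50_target(filename):
    """Class index encoded in an ESC-50 file name ({fold}-{clip_id}-{take}-{target}.wav)."""
    stem = os.path.splitext(os.path.basename(filename))[0]
    return int(stem.split('-')[-1])


def check_doc(doc):
    """Return (embedding, predicted_class, true_class) for one response document."""
    embedding = decode_dense(doc['embedding']['dense'])
    predicted = int(doc['tags']['prediction'])
    expected = esc50_target(doc['tags']['filename'])
    return embedding, predicted, expected
```

np.frombuffer returns a read-only view of the decoded bytes; call .copy() on the result if you need to modify the embedding in place.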