diff --git a/.gitignore b/.gitignore
index eb0eec5c..57fe8ce0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ python_env/
marimbabot_vision/scripts/data*
marimbabot_vision/scripts/test_data*
+marimbabot_vision/scripts/samples/*
marimbabot_vision/scripts/lightning_logs/
marimbabot_vision/scripts/model*
diff --git a/README.md b/README.md
index 08d1fba3..369f9e95 100644
--- a/README.md
+++ b/README.md
@@ -61,9 +61,9 @@ pip3 install wheel # firstly install the wheel for further package building if
pip3 install -r src/marimbabot/requirements.txt
#install the precise engine
-wget https://github.com/MycroftAI/mycroft-precise/releases/download/v0.3.0/precise-all_0.3.0_x86_64.tar.gz
-tar -zxvf precise-all_0.3.0_x86_64.tar.gz -C src/marimbabot/marimbabot_speech/utils/kws/
-rm precise-all_0.3.0_x86_64.tar.gz
+wget https://github.com/MycroftAI/mycroft-precise/releases/download/v0.3.0/precise-engine_0.3.0_x86_64.tar.gz
+tar -zxvf precise-engine_0.3.0_x86_64.tar.gz -C src/marimbabot/marimbabot_speech/utils/kws/
+rm precise-engine_0.3.0_x86_64.tar.gz
```
Now you are ready to go.
diff --git a/marimbabot_audio/README.assets/image-20231004154604276.png b/marimbabot_audio/README.assets/image-20231004154604276.png
new file mode 100644
index 00000000..4170c855
Binary files /dev/null and b/marimbabot_audio/README.assets/image-20231004154604276.png differ
diff --git a/marimbabot_audio/README.md b/marimbabot_audio/README.md
new file mode 100644
index 00000000..51b113e8
--- /dev/null
+++ b/marimbabot_audio/README.md
@@ -0,0 +1,154 @@
+# TAMS Master Project 2022/2023 - Music note detection
+
+## 1. Motivation
+
+The robot's marimba playing needs to be evaluated. This submodule therefore evaluates the music produced by the robot through audio feedback. It detects Western music notes from the raw audio input, visualizes them as a MIDI figure and a Constant-Q transform spectrum, and computes a final score that rates the performed motion. Moreover, it can synthesize music from LilyPond.
+
+## 2. Dependencies
+
+### 2.1 Python packages
+
+```
+abjad==3.4
+crepe==0.0.13
+librosa==0.9.2
+midi2audio==0.1.1
+numpy==1.23.3
+opencv-python==4.7.0.68
+pretty-midi==0.2.10
+tensorflow==2.13
+```
+
+### 2.2 ROS packages
+
+- fluidsynth
+- audio_capture
+- sound_play
+
+## 3. Overview
+
+### 3.1 Folder overview
+
+The following folder tree shows the important files; unimportant files are omitted.
+
+```bash
+├── launch
+│ ├── audio_feedback.launch # launches the music note detector; the parameters can be tuned here
+│ ├── marimbabot.launch # main launch file of this submodule
+│ └── open_rviz.launch # launch file that opens RViz
+└── src
+ ├── audio_from_lilypond.py # music synthesis: turns a test music sequence into audio
+ ├── eval_visualization.py # MIDI visualization for evaluation; mismatched notes are shown here
+ ├── onset_detection.py # onset detection and music note classification, plus spectrum visualization
+ ├── onsets_visualization.py # live MIDI visualization
+ └── sequence_evaluation.py # music evaluation: compares the ground truth with the robot's performance
+```
+
+### 3.2 Nodes overview
+
+![Nodes overview](README.assets/image-20231004154604276.png)
+
+- **/audio_node/node_audio_capture**: captures the music as raw audio data.
+- **/audio_node/onset_detector**: detects music notes from the raw audio data.
+- **/audio_node/seq_eval**: compares the detected note sequence of the robot's performance against the ground truth.
+- **/audio_node/onsetviz**: visualizes the onset notes as a live MIDI figure.
+- **/audio_node/evalviz**: visualizes the evaluation result as a MIDI figure.
+- **/audio_from_lilypond**: synthesizes audio from LilyPond and plays it.
+
+## 4. Pipeline of music note detection
+
+1. **Audio capture**:
+   The raw audio data is captured and published to `/audio_node/audio_stamped`.
+2. **Music note detection** (`src/onset_detection.py`), illustrated by the sketch after this list:
+   1. Continuously **receive chunks of raw data** until 1 second of data is gathered, then convert it to a float array.
+   2. Apply the **Constant-Q Transform (CQT)**, which transforms the data series into the frequency domain and whose design is well suited for musical representation.
+   3. Use **peak-picking based onset detection** to find the peak frames, i.e. the candidate music onset events.
+   4. Select a chunk of CQT data after each peak frame and feed it to the **music classification** model, [Crepe](https://github.com/marl/crepe), a DNN-based monophonic pitch tracker.
+   5. **Visualize** the detected music notes on the CQT spectrum published to `/audio_node/spectrogram_img`, and publish the detected notes to `/audio_node/onset_notes`.
+   6. **Live MIDI visualization** of the music notes: `src/onsets_visualization.py`.
+3. **Sequence evaluation** (`src/sequence_evaluation.py`):
+   1. It keeps listening for the detected notes produced by the robot and stores them in a cache list.
+   2. Once a ground truth is received on topic `/audio/hit_sequence`, it retrieves the related music notes from the cache list according to the time slot of the ground truth.
+   3. The two sequences are then compared and the match results are published to `/audio_node/match_result`.
+   4. Visualization of the evaluation results: `src/eval_visualization.py`.
+
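+A minimal, self-contained sketch of detection steps 2.2 to 2.4 with `librosa` and Crepe is shown below. It is illustrative only and not the actual node code: the hop length, number of CQT bins, the synthetic input buffer and the 0.7 confidence threshold are assumptions (the real node reads its confidence threshold and amplitude reference from ROS parameters).
+
+```python
+import crepe
+import librosa
+import numpy as np
+
+SR = 44100   # sample rate of the captured audio
+HOP = 512    # assumed hop length
+audio = np.random.uniform(-1.0, 1.0, SR).astype(np.float32)  # stand-in for 1 s of captured audio
+
+# Step 2.2: Constant-Q transform of the 1 s buffer
+cqt = np.abs(librosa.cqt(audio, sr=SR, hop_length=HOP, n_bins=60))
+
+# Step 2.3: peak-picking based onset detection on the CQT onset-strength envelope
+onset_env = librosa.onset.onset_strength(sr=SR, S=librosa.amplitude_to_db(cqt, ref=np.max))
+onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=SR, hop_length=HOP)
+
+# Step 2.4: classify the pitch of a ~0.1 s window after each onset with Crepe
+for frame in onset_frames:
+    start = frame * HOP
+    window = audio[start:start + int(0.1 * SR)]
+    if window.size == 0:
+        continue
+    _, frequency, confidence, _ = crepe.predict(window, SR, viterbi=True)
+    best = int(np.argmax(confidence))
+    if confidence[best] > 0.7:  # assumed threshold; configurable in the real node
+        print(librosa.hz_to_note(frequency[best]), confidence[best])
+```
+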
+## 5. Node list
+
+*Only the primary topics and nodes are covered here.*
+
+- #### /audio_node/note_audio_capture:
+
+  - **Description**: captures the audio as raw data.
+ - **Output**:
+ - Message type: [`AudioDataStamped.msg`](https://github.com/ros-drivers/audio_common/blob/master/audio_common_msgs/msg/AudioDataStamped.msg)
+    - Topic: `/audio_node/audio_stamped`
+    - Audio format: `wave` with a sample rate of `44100 Hz`, a bit depth of `16` and a `mono` channel, corresponding to the `S16LE` PCM sample format.
+
+- #### /audio_node/onset_detector
+
+  - **Description**: detects the music notes from the raw data, and publishes the detected notes as well as the spectrum
+ - **Input**: `uint8[]` from topic `/audio_node/audio_stamped`
+ - **Output_1**: Music note
+ - Topic: `/audio_node/onset_notes`
+ - Message type: [`NoteOnset.msg`](../marimbabot_msgs/msg/NoteOnset.msg)
+  - **Output_2**: normalized Constant-Q transform spectrum including the onset events (vertical white lines, i.e. the candidates) and the detected music notes (horizontal green lines).
+ - Topic: `/audio_node/spectrogram_img`
+ - Message type: [`sensor_msgs/Image.msg`](http://docs.ros.org/en/noetic/api/sensor_msgs/html/msg/Image.html)
+
+- #### /audio_node/seq_eval
+
+  - **Description**: Compares the ground-truth music sequence with the robot-played music, calculates the score and publishes the match result.
+ - **Input_1**: ground-truth music sequence
+ - Topic: `/audio_node/hit_sequence`
+ - Message type: [`HitSequence.msg`](../marimbabot_msgs/msg/HitSequence.msg)
+  - **Input_2**: robot-played music notes
+    - Topic: `/audio_node/onset_notes`
+    - Message type: [`NoteOnset.msg`](../marimbabot_msgs/msg/NoteOnset.msg)
+  - **Output**: match results
+ - Topic: `/audio_node/match_result`
+ - Message type: [SequenceMatchResult.msg](../marimbabot_msgs/msg/SequenceMatchResult.msg)
+
+- #### /audio_node/onsetviz
+
+  - **Description**: visualizes the music notes as a live MIDI figure
+ - **Input**: music notes
+ - Topic: `/audio_node/onset_notes`
+    - Message type: [`NoteOnset.msg`](../marimbabot_msgs/msg/NoteOnset.msg)
+ - **Output:** midi image
+ - Topic: `/audio_node/live_midi_img`
+ - Message type: [`sensor_msgs/Image.msg`](http://docs.ros.org/en/noetic/api/sensor_msgs/html/msg/Image.html)
+
+- #### /audio_node/evalviz
+
+  - **Description**: visualizes the match result as a MIDI figure
+ - **Input**: match result
+ - Topic: `/audio_node/match_result`
+ - Message type: [SequenceMatchResult.msg](../marimbabot_msgs/msg/SequenceMatchResult.msg)
+ - **Output:** midi image
+ - Topic: `/audio_node/feedback_img`
+ - Message type: [`sensor_msgs/Image.msg`](http://docs.ros.org/en/noetic/api/sensor_msgs/html/msg/Image.html)
+
+- #### /audio_from_lilypond
+
+  - **Description**: Synthesizes the music from LilyPond and plays it
+  - **Input**: LilyPond string
+ - Action name: `audio_from_lilypond`
+ - Action type: [`LilypondAudio.action`](../marimbabot_msgs/action/LilypondAudio.action)
+
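+A client can send a LilyPond string to this action server roughly as follows (a minimal sketch of a SimpleActionClient; error handling omitted):
+
+```python
+import actionlib
+import rospy
+from std_msgs.msg import String
+
+from marimbabot_msgs.msg import LilypondAudioAction, LilypondAudioGoal
+
+rospy.init_node('audio_from_lilypond_client', anonymous=True)
+client = actionlib.SimpleActionClient('audio_from_lilypond', LilypondAudioAction)
+client.wait_for_server()   # wait until the action server is listening for goals
+goal = LilypondAudioGoal(lilypond_string=String(data="c4 d4 e4"))
+client.send_goal(goal)
+client.wait_for_result()   # returns once the audio has been generated and played
+print(client.get_result())
+```
+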
+## 6. Topic list
+
+| Topic | Description | Message Type |
+| --------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
+| /audio_node/audio | publish audio data | [AudioData.msg](https://github.com/ros-drivers/audio_common/blob/master/audio_common_msgs/msg/AudioData.msg) |
+| /audio_node/audio_info | description data of audio | [AudioInfo.msg](http://docs.ros.org/en/noetic/api/audio_common_msgs/html/msg/AudioInfo.html) |
+| /audio_node/audio_stamped | publish audio data with time stamp | [AudioDataStamped.msg](https://github.com/ros-drivers/audio_common/blob/master/audio_common_msgs/msg/AudioDataStamped.msg) |
+| /audio_node/compute_time    | the computation time of music note detection for a 1 s data chunk | Float32 |
+| /audio_node/cqt | spectrum of constant-Q transform from raw data input | [CQTStamped](../marimbabot_msgs/msg/CQTStamped.msg) |
+| /audio_node/feedback_img | the MIDI figure of final evaluation | [Image.msg](http://docs.ros.org/en/noetic/api/sensor_msgs/html/msg/Image.html) |
+| /audio_node/live_midi_img   | live MIDI figure showing the detected music notes            | [Image.msg](http://docs.ros.org/en/noetic/api/sensor_msgs/html/msg/Image.html) |
+| /audio_node/match_result | the evaluation result between ground-truth and perceived music notes. | [SequenceMatchResult.msg](../marimbabot_msgs/msg/SequenceMatchResult.msg) |
+| /audio_node/onset_notes     | the detected music notes                                     | [NoteOnset.msg](../marimbabot_msgs/msg/NoteOnset.msg) |
+| /audio_node/spectrogram_img | normalized constant-Q transform spectrum figure | [Image.msg](http://docs.ros.org/en/noetic/api/sensor_msgs/html/msg/Image.html) |
+| /sound_play | play the music | [SoundRequest.msg](http://docs.ros.org/en/jade/api/sound_play/html/msg/SoundRequest.html) |
+| /audio_from_lilypond | music synthesis from lilypond input | [LilypondAudio.action](../marimbabot_msgs/action/LilypondAudio.action) |
+
diff --git a/marimbabot_audio/config/config.rviz b/marimbabot_audio/config/config.rviz
deleted file mode 100644
index 9cc25646..00000000
--- a/marimbabot_audio/config/config.rviz
+++ /dev/null
@@ -1,147 +0,0 @@
-Panels:
- - Class: rviz/Displays
- Help Height: 0
- Name: Displays
- Property Tree Widget:
- Expanded:
- - /Global Options1
- - /Status1
- - /Image1
- - /Image2
- Splitter Ratio: 0.5
- Tree Height: 140
- - Class: rviz/Selection
- Name: Selection
- - Class: rviz/Tool Properties
- Expanded:
- - /2D Pose Estimate1
- - /2D Nav Goal1
- - /Publish Point1
- Name: Tool Properties
- Splitter Ratio: 0.5886790156364441
- - Class: rviz/Views
- Expanded:
- - /Current View1
- Name: Views
- Splitter Ratio: 0.5
- - Class: rviz/Time
- Name: Time
- SyncMode: 0
- SyncSource: Image
-Preferences:
- PromptSaveOnExit: true
-Toolbars:
- toolButtonStyle: 2
-Visualization Manager:
- Class: ""
- Displays:
- - Alpha: 0.5
- Cell Size: 1
- Class: rviz/Grid
- Color: 160; 160; 164
- Enabled: true
- Line Style:
- Line Width: 0.029999999329447746
- Value: Lines
- Name: Grid
- Normal Cell Count: 0
- Offset:
- X: 0
- Y: 0
- Z: 0
- Plane: XY
- Plane Cell Count: 10
- Reference Frame:
- Value: true
- - Class: rviz/Image
- Enabled: true
- Image Topic: /audio/spectrogram_img
- Max Value: 1
- Median window: 5
- Min Value: 0
- Name: Image
- Normalize Range: true
- Queue Size: 2
- Transport Hint: raw
- Unreliable: false
- Value: true
- - Class: rviz/Image
- Enabled: true
- Image Topic: /audio/midi_img
- Max Value: 1
- Median window: 5
- Min Value: 0
- Name: Image
- Normalize Range: true
- Queue Size: 2
- Transport Hint: raw
- Unreliable: false
- Value: true
- Enabled: true
- Global Options:
- Background Color: 48; 48; 48
- Default Light: true
- Fixed Frame: map
- Frame Rate: 30
- Name: root
- Tools:
- - Class: rviz/Interact
- Hide Inactive Objects: true
- - Class: rviz/MoveCamera
- - Class: rviz/Select
- - Class: rviz/FocusCamera
- - Class: rviz/Measure
- - Class: rviz/SetInitialPose
- Theta std deviation: 0.2617993950843811
- Topic: /initialpose
- X std deviation: 0.5
- Y std deviation: 0.5
- - Class: rviz/SetGoal
- Topic: /move_base_simple/goal
- - Class: rviz/PublishPoint
- Single click: true
- Topic: /clicked_point
- Value: true
- Views:
- Current:
- Class: rviz/Orbit
- Distance: 12.543999671936035
- Enable Stereo Rendering:
- Stereo Eye Separation: 0.05999999865889549
- Stereo Focal Distance: 1
- Swap Stereo Eyes: false
- Value: false
- Field of View: 0.7853981852531433
- Focal Point:
- X: 0
- Y: 0
- Z: 0
- Focal Shape Fixed Size: true
- Focal Shape Size: 0.05000000074505806
- Invert Z Axis: false
- Name: Current View
- Near Clip Distance: 0.009999999776482582
- Pitch: 0.785398006439209
- Target Frame:
- Yaw: 0.785398006439209
- Saved: ~
-Window Geometry:
- Displays:
- collapsed: false
- Height: 1016
- Hide Left Dock: false
- Hide Right Dock: false
- Image:
- collapsed: false
- QMainWindow State: 000000ff00000000fd0000000400000000000006840000035afc020000000efb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000000c9000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261fb0000000a0049006d00610067006500000001f4000000d60000000000000000fb0000000a0049006d00610067006500000002d0000000c70000000000000000fb0000000a0049006d006100670065000000028b0000010c0000000000000000fb0000000a0049006d006100670065010000010c000001030000001600fffffffb0000000a0049006d0061006700650000000331000000660000000000000000fb0000000a0049006d0061006700650100000215000001820000001600ffffff000000010000010f000002b0fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d000002b0000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000007380000003efc0100000002fb0000000800540069006d0065010000000000000738000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000000ae0000035a00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
- Selection:
- collapsed: false
- Time:
- collapsed: false
- Tool Properties:
- collapsed: false
- Views:
- collapsed: false
- Width: 1848
- X: 72
- Y: 27
diff --git a/marimbabot_audio/examples/send_hit_sequence.py b/marimbabot_audio/examples/send_hit_sequence.py
deleted file mode 100644
index 42ca202d..00000000
--- a/marimbabot_audio/examples/send_hit_sequence.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import numpy as np
-import rospy
-
-from marimbabot_msgs.msg import HitSequence as HitSequenceMsg
-from marimbabot_msgs.msg import HitSequenceElement as HitSequenceElementMsg
-from marimbabot_msgs.msg import NoteOnset as NoteOnsetMsg
-
-
-def gererate_sequence_msg(seq:list,id=0,noisy=True):
- msgs_audio = []
- msgs_vision = HitSequenceMsg()
- msgs_vision.header.stamp = rospy.Time.now()
- msgs_vision.sequence_id = id
- msgs_vision.hit_sequence_elements = []
- time_start = rospy.Time.now()
- for idx, each in enumerate(seq):
- hit = HitSequenceElementMsg()
- '''
- string tone_name
- int32 octave
- time start_time
- duration tone_duration
- float32 loudness
- '''
- # build test msg for vision
- hit.tone_name = each[:-1]
- hit.octave = int(each[-1])
- interval = rospy.Duration(0.5)
- hit.start_time = rospy.Time(0.5*idx)
- hit.tone_duration = rospy.Duration(0.5)
- hit.loudness = 0.5
-
- if noisy:
- if idx > 0:
- hit.start_time += rospy.Duration(np.random.uniform(-0.5, 0.5))
- hit.loudness += np.random.uniform(-0.2, 0.2)
-
- msgs_vision.hit_sequence_elements.append(hit)
-
- # build test msg for audio
- msg_audio = NoteOnsetMsg()
- msg_audio.header.stamp = time_start + interval*idx
- msg_audio.loudness = 0.8
- msg_audio.confidence = 0.8
- msg_audio.duration = 0.5
- if noisy:
- msg_audio.header.stamp += rospy.Duration(np.random.uniform(-0.5,0.5))
- msg_audio.duration += np.random.uniform(-0.2,0.2)
-
- msg_audio.note = each
- msgs_audio.append(msg_audio)
-
- return msgs_audio, msgs_vision
-
-def send_hit_sequnce4test():
- seq = ['C4', 'C#4', 'D4', 'D#4', 'E4', 'F4', 'G4', 'G#4', 'A4', 'B4', 'C5']
- vision_pub = rospy.Publisher('/audio/hit_sequence', HitSequenceMsg, queue_size=10)
- audio_pub = rospy.Publisher('/audio/onset_notes', NoteOnsetMsg, queue_size=10)
-
- rate = rospy.Rate(10)
- n = 0
- while not rospy.is_shutdown():
- msgs_audio, msgs_vision = gererate_sequence_msg(seq,n)
- rospy.sleep(2)
- vision_pub.publish(msgs_vision)
- rospy.logdebug('publish vision msg')
- for each in msgs_audio:
- audio_pub.publish(each)
- rate.sleep()
- rospy.sleep(3)
- n += 1
-
-if __name__ == '__main__':
- rospy.init_node('send_hit_sequnce', log_level=rospy.DEBUG)
- send_hit_sequnce4test()
- # rospy.spin()
diff --git a/marimbabot_audio/launch/audio_feedback.launch b/marimbabot_audio/launch/audio_feedback.launch
index 82a33244..3f7b1a3e 100644
--- a/marimbabot_audio/launch/audio_feedback.launch
+++ b/marimbabot_audio/launch/audio_feedback.launch
@@ -12,7 +12,10 @@
-
+
+
+
+
diff --git a/marimbabot_audio/launch/open_rviz.launch b/marimbabot_audio/launch/open_rviz.launch
deleted file mode 100644
index 6dae1bca..00000000
--- a/marimbabot_audio/launch/open_rviz.launch
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
diff --git a/marimbabot_audio/scripts/dummy_client_audio_from_lilypond.py b/marimbabot_audio/scripts/dummy_client_audio_from_lilypond.py
deleted file mode 100644
index d5d812cb..00000000
--- a/marimbabot_audio/scripts/dummy_client_audio_from_lilypond.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import actionlib
-import rospy
-from std_msgs.msg import String
-
-from marimbabot_msgs.msg import LilypondAudioAction, LilypondAudioGoal
-
-
-def run_dummy():
- sentence = "c4 d4 e4"
- # send lilypond string to action server 'lilypond_audio_generation'
- client = actionlib.SimpleActionClient('audio_from_lilypond', LilypondAudioAction)
- # Waits until the action server has started up and started
- # listening for goals.
- client.wait_for_server()
-
- # Creates a goal to send to the action server.
- goal = LilypondAudioGoal(lilypond_string=String(data=sentence))
- # Sends the goal to the action server.
- client.send_goal(goal)
-
- # Waits for the server to finish performing the action.
- # Includes that the audio file is generated and was played
- client.wait_for_result()
-
- # Prints out the result of executing the action
- print(f"Result from audio_from_lilypond action server: {client.get_result()}")
-
-
-if __name__ == '__main__':
- rospy.init_node('dummy_audio_from_lilypond_client', anonymous=True)
- run_dummy()
diff --git a/marimbabot_audio/src/onset_detection.py b/marimbabot_audio/src/onset_detection.py
index a006c6ec..7fc45bf2 100644
--- a/marimbabot_audio/src/onset_detection.py
+++ b/marimbabot_audio/src/onset_detection.py
@@ -127,7 +127,8 @@ def init_detection_config(self):
confidence threshold for note classification(crepe)
"""
# For onset detection
- self.confidence_threshold = 0.7 # the threshold for note classification
+ self.confidence_threshold = rospy.get_param("~confidence_threshold")
+ self.amplitude_ref = rospy.get_param("~amplitude_ref", 10.0)
self.windows_for_classification = 0.1 # using 0.1 sec data after onset time for note classification
# preload model to not block the callback on first message
# capacities: 'tiny', 'small', 'medium', 'large', 'full'
@@ -262,7 +263,7 @@ def audio_process(self, msg):
cqt = self.cqt() # cqt ndarrary (60,173)
self.publish_cqt(cqt)
- onset_env_cqt = librosa.onset.onset_strength(sr=self.sr, S=librosa.amplitude_to_db(cqt, ref=np.max) )
+ onset_env_cqt = librosa.onset.onset_strength(sr=self.sr, S=librosa.amplitude_to_db(cqt, ref=self.amplitude_ref))
# detect when the onset(peak) happened within 2 sec cqt with shape (60,173)
'''
A sample n is selected as an peak if the corresponding x[n] fulfills the following three conditions:
diff --git a/marimbabot_behavior/src/marimbabot_behavior/behavior_node.py b/marimbabot_behavior/src/marimbabot_behavior/behavior_node.py
index de646512..8c3c51d8 100755
--- a/marimbabot_behavior/src/marimbabot_behavior/behavior_node.py
+++ b/marimbabot_behavior/src/marimbabot_behavior/behavior_node.py
@@ -149,44 +149,42 @@ def assign_volume(self, value='\\mp', override=False):
# changes the volume of the current sequence and updates the hit sequence
def change_volume(self, louder=True, value=1):
- # assign the default volume if there is no volume symbol in the sequence
- self.assign_volume('\\mf', override=False)
+ # assign a default volume if there is no volume symbol in the sequence
+ self.assign_volume('\\mf')
dynamics = ['\\ppp', '\\pp', '\\p', '\\mp', '\\mf', '\\f', '\\ff', '\\fff']
sequence_list = self.note_sequence.split(' ')
sequence_dynamics = [(i,x) for i, x in enumerate(sequence_list) if x in dynamics]
- # if there are already dynamic symbols in the sequence, swap them with the next louder/softer dynamic symbol
- if len(sequence_dynamics) > 0:
- # check if the volume can be increased/decreased by the specified value for all dynamic symbols
- if (louder and any(dynamics.index(x[1])+value > 7 for x in sequence_dynamics)) or (not louder and any(dynamics.index(x[1])-value < 0 for x in sequence_dynamics)):
- if louder:
- max_steps = min(7-dynamics.index(x[1]) for x in sequence_dynamics)
- if max_steps == 0:
- rospy.logwarn('Volume can not be increased any further.')
- self.response_pub.publish('Volume can not be increased any further.')
- else:
- rospy.logwarn('Volume can only be increased by {}.'.format(min(7-dynamics.index(x[1]) for x in sequence_dynamics)))
- self.response_pub.publish('Volume can only be increased by {}.'.format(min(7-dynamics.index(x[1]) for x in sequence_dynamics)))
+ # check if the volume can be increased/decreased by the specified value for all dynamic symbols
+ if (louder and any(dynamics.index(x[1])+value > 7 for x in sequence_dynamics)) or (not louder and any(dynamics.index(x[1])-value < 0 for x in sequence_dynamics)):
+ if louder:
+ max_steps = min(7-dynamics.index(x[1]) for x in sequence_dynamics)
+ if max_steps == 0:
+ rospy.logwarn('Volume can not be increased any further.')
+ self.response_pub.publish('Volume can not be increased any further.')
else:
- max_steps = min(dynamics.index(x[1]) for x in sequence_dynamics)
- if max_steps == 0:
- rospy.logwarn('Volume can not be decreased any further.')
- self.response_pub.publish('Volume can not be decreased any further.')
- else:
- rospy.logwarn('Volume can only be decreased by {}.'.format(max(dynamics.index(x[1]) for x in sequence_dynamics)))
- self.response_pub.publish('Volume can only be decreased by {}.'.format(max(dynamics.index(x[1]) for x in sequence_dynamics)))
- return 'fail'
-
- # change the volume of all dynamic symbols in the sequence
- for i, x in sequence_dynamics:
- new_dynamic = dynamics[min(dynamics.index(x)+value, 7)] if louder else dynamics[max(dynamics.index(x)-value, 0)]
- sequence_list[i] = new_dynamic
- self.note_sequence = ' '.join(sequence_list)
-
- rospy.logdebug(f"updated notes: {self.note_sequence}")
- self.update_hit_sequence()
- return 'success'
+ rospy.logwarn('Volume can only be increased by {}.'.format(min(7-dynamics.index(x[1]) for x in sequence_dynamics)))
+ self.response_pub.publish('Volume can only be increased by {}.'.format(min(7-dynamics.index(x[1]) for x in sequence_dynamics)))
+ else:
+ max_steps = min(dynamics.index(x[1]) for x in sequence_dynamics)
+ if max_steps == 0:
+ rospy.logwarn('Volume can not be decreased any further.')
+ self.response_pub.publish('Volume can not be decreased any further.')
+ else:
+ rospy.logwarn('Volume can only be decreased by {}.'.format(max(dynamics.index(x[1]) for x in sequence_dynamics)))
+ self.response_pub.publish('Volume can only be decreased by {}.'.format(max(dynamics.index(x[1]) for x in sequence_dynamics)))
+ return 'fail'
+
+ # change the volume of all dynamic symbols in the sequence
+ for i, x in sequence_dynamics:
+ new_dynamic = dynamics[min(dynamics.index(x)+value, 7)] if louder else dynamics[max(dynamics.index(x)-value, 0)]
+ sequence_list[i] = new_dynamic
+ self.note_sequence = ' '.join(sequence_list)
+
+ rospy.logdebug(f"updated notes: {self.note_sequence}")
+ self.update_hit_sequence()
+ return 'success'
def send_ground_truth_hit_sequence_to_audio(self, absolute_start_time: rospy.Time):
"""
diff --git a/marimbabot_bringup/launch/marimbabot.launch b/marimbabot_bringup/launch/marimbabot.launch
index d77a34d0..6729313b 100644
--- a/marimbabot_bringup/launch/marimbabot.launch
+++ b/marimbabot_bringup/launch/marimbabot.launch
@@ -1,7 +1,7 @@
-
-
+
+
diff --git a/marimbabot_description/urdf/marimba.urdf.xacro b/marimbabot_description/urdf/marimba.urdf.xacro
index 8985cc27..3db94bf3 100644
--- a/marimbabot_description/urdf/marimba.urdf.xacro
+++ b/marimbabot_description/urdf/marimba.urdf.xacro
@@ -16,11 +16,11 @@ From the point of view of the marimbist:
z: points towards the ceiling, according to right hand rule
Changelog:
+ 2023.01.04 First version created
+ 2023.01.19 Row of keys working
+ 2023.01.26 Complete model v1
2023.03.02 Named and correctly positioned bar frames
Added side boxes
- 2023.01.26 Complete model v1
- 2023.01.19 Row of keys working
- 2023.01.04 First version created
2023.06.14 Made side boxed width and height smaller
-->
@@ -101,7 +101,7 @@ Changelog:
-
+falsetrue
@@ -152,18 +152,36 @@ Changelog:
1
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -377,21 +395,21 @@ Changelog:
-
+
-
+
-
+
-
+
-
+
diff --git a/marimbabot_speech/README.assets/image-20231004163741787.png b/marimbabot_speech/README.assets/image-20231004163741787.png
new file mode 100644
index 00000000..a6e1039f
Binary files /dev/null and b/marimbabot_speech/README.assets/image-20231004163741787.png differ
diff --git a/marimbabot_speech/README.md b/marimbabot_speech/README.md
new file mode 100644
index 00000000..8a136ba0
--- /dev/null
+++ b/marimbabot_speech/README.md
@@ -0,0 +1,132 @@
+# TAMS Master Project 2022/2023 - Command recognition
+
+## 1. Motivation
+
+We need an interface for giving commands to the robot. Therefore, we built a command recognition system that extracts the command from human speech, and we use speech synthesis for interaction feedback.
+
+## 2. Dependencies
+
+### 2.1 Python packages
+
+```
+mycroft-mimic3-tts==0.2.4
+numpy==1.23.3
+pytorch_lightning==1.8.1
+torch==1.13.1
+openai-whisper==20230314
+webrtcvad==2.0.10
+playsound==1.3.0
+```
+
+### 2.2 ROS dependencies
+
+- audio_common_msgs
+- python3-precise-runner-pip
+- sound_play
+
+### 2.3 Precise_engine
+
+Download the Precise engine and extract the binary file to `marimbabot_speech/utils/kws/`:
+
+```bash
+#install the precise engine
+wget https://github.com/MycroftAI/mycroft-precise/releases/download/v0.3.0/precise-engine_0.3.0_x86_64.tar.gz
+tar -zxvf precise-*.tar.gz -C src/marimbabot/marimbabot_speech/utils/kws/
+rm precise-*.tar.gz
+```
+
+## 3. Overview
+
+### 3.1 Folder overview
+
+The following folder tree shows the important files; unimportant files are omitted.
+
+```bash
+├── launch
+│ ├── command_recognition.launch # launches the command recognition
+│ ├── marimbabot.launch
+│ └── test_speech_syntesis.launch # launches the speech synthesis
+├── src
+│ ├── speech2text.py # speech recognition from raw speech data
+│ ├── speech_extraction.py # extracts a whole spoken sentence as raw audio data
+│ ├── text2command.py # extracts the command from text
+│ └── text2speech.py # speech synthesis
+└── utils
+ └── kws # auxiliary files for keyword spotting
+ ├── hi-marimbabot.pb # the model file for keyword spotting
+ └── reminder.wav # a short sound played when the keyword is spotted
+```
+
+### 3.2 Nodes overview
+
+![Nodes overview](README.assets/image-20231004163741787.png)
+
+- **/speech_node/audio_capture**: captures the speech as raw audio data.
+- **/speech_node/speech_extraction_node**: extracts a whole spoken sentence as raw audio data. It has two primary parts: keyword spotting, based on [mycroft-precise](https://github.com/MycroftAI/mycroft-precise), and voice activity detection, based on [webrtcvad](https://github.com/wiseman/py-webrtcvad).
+- **/speech_node/speech2text_node**: speech transcription, based on [whisper](https://github.com/openai/whisper).
+- **/speech_node/text2command_node**: extracts the command using regular expressions.
+- **/speech_tts_node**: speech synthesis.
+
+## 4. Pipeline of command recognition
+
+1. **Audio capture**:
+   The raw audio data is captured and published to `/speech_node/audio_stamped`.
+2. **Speech extraction** (`src/speech_extraction.py`):
+   1. Listen for the keyword with the keyword spotting tool ([mycroft-precise](https://github.com/MycroftAI/mycroft-precise)); once the keyword is spotted, a tone is played as a reminder.
+   2. After the keyword is spotted, gather data chunks using voice activity detection ([webrtcvad](https://github.com/wiseman/py-webrtcvad)) until 1 second of consecutive silence is reached. Those chunks are treated as one whole sentence and forwarded to the next step (see the sketch after this list).
+3. **Speech transcription** (`src/speech2text.py`):
+   1. Use [whisper](https://github.com/openai/whisper) to transcribe the gathered data chunks into text.
+
+4. **Command extraction** (`src/text2command.py`):
+   1. Use regular expressions to extract the command; a sketch follows the command examples in the next section.
+
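+A simplified, stand-alone sketch of steps 2 and 3 is shown below: 30 ms frames are gated by [webrtcvad](https://github.com/wiseman/py-webrtcvad) until roughly 1 second of consecutive silence is reached, then the collected audio is transcribed with [whisper](https://github.com/openai/whisper). It is not the actual node code; the VAD aggressiveness, frame size and whisper model size are assumptions.
+
+```python
+import numpy as np
+import webrtcvad
+import whisper
+
+SAMPLE_RATE = 16000                                 # webrtcvad and whisper both accept 16 kHz audio
+FRAME_MS = 30                                       # webrtcvad supports 10/20/30 ms frames
+FRAME_BYTES = SAMPLE_RATE * FRAME_MS // 1000 * 2    # 16-bit mono PCM
+
+vad = webrtcvad.Vad(2)                              # aggressiveness 0-3 (assumed value)
+model = whisper.load_model("base")                  # model size is an assumption
+
+def transcribe_sentence(frame_source):
+    """frame_source yields raw 16-bit PCM frames of FRAME_BYTES bytes each."""
+    voiced, silence_ms = [], 0
+    for frame in frame_source:
+        if vad.is_speech(frame, SAMPLE_RATE):
+            voiced.append(frame)
+            silence_ms = 0
+        else:
+            silence_ms += FRAME_MS
+            if silence_ms >= 1000 and voiced:       # ~1 s of consecutive silence ends the sentence
+                break
+    pcm = np.frombuffer(b"".join(voiced), dtype=np.int16)
+    audio = pcm.astype(np.float32) / 32768.0        # convert to float32 in [-1, 1] for whisper
+    return model.transcribe(audio, fp16=False)["text"]
+```
+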
+## 5. Command examples
+
+1. For speed setup:
+
+ 1. "play faster by 20 bpm"
+ `{"behavior": "play", "action": "increase speed", "parameter": "20"}`.
+ 2. "play slower by 40 bpm"
+ `{"behavior": "play", "action": "decrease speed", "parameter": "40"}`
+ 3. "play in 20 bpm"
+ `{"behavior": "play", "action": "setup speed", "parameter": "20"}`
+
+2. For volume setup:
+
+ 1. "play louder by 1 step"
+ `{"behavior": "play", "action": "increase volume", "parameter": "1"}`
+ 2. "play softer by 2 steps"
+ `{"behavior": "play", "action": "decrease volume", "parameter": "2"}`
+
+3. "play in the loop":
+ `{"behavior": "play", "action": "loop", "parameter": ""}`.
+
+4. Music synthesis:
+   The behavior `play` can be replaced by `preview`, meaning the music is synthesized and played back, for example:
+
+ "preview faster by 20 bpm"
+ `{"behavior": "preview", "action": "setup speed", "parameter": "20"}`
+ or "preview in the loop"
+ `{"behavior": "preview", "action": "loop", "parameter": ""}`.
+
+5. Other:
+
+   1. "stop": stops the motion of the robot or the synthesized music.
+ `{"behavior": "stop", "action": "", "parameter": ""}`.
+   2. "read": reads the music notes from the whiteboard.
+ `{"behavior": "read", "action": "", "parameter": ""}`.
+
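+A regex-based extractor for the examples above can be sketched as follows. The patterns and the returned field names are illustrative assumptions; the actual rule set is defined in `src/text2command.py`.
+
+```python
+import re
+
+# Each rule is a (pattern, builder) pair; the first matching pattern wins.
+RULES = [
+    (re.compile(r"(play|preview) (faster|slower) by (\d+) bpm"),
+     lambda m: {"behavior": m.group(1),
+                "action": ("increase" if m.group(2) == "faster" else "decrease") + " speed",
+                "parameter": m.group(3)}),
+    (re.compile(r"(play|preview) in (\d+) bpm"),
+     lambda m: {"behavior": m.group(1), "action": "setup speed", "parameter": m.group(2)}),
+    (re.compile(r"(play|preview) (louder|softer) by (\d+) steps?"),
+     lambda m: {"behavior": m.group(1),
+                "action": ("increase" if m.group(2) == "louder" else "decrease") + " volume",
+                "parameter": m.group(3)}),
+    (re.compile(r"(play|preview) in the loop"),
+     lambda m: {"behavior": m.group(1), "action": "loop", "parameter": ""}),
+    (re.compile(r"\b(stop|read)\b"),
+     lambda m: {"behavior": m.group(1), "action": "", "parameter": ""}),
+]
+
+def extract(text: str) -> dict:
+    text = text.lower()
+    for pattern, build in RULES:
+        match = pattern.search(text)
+        if match:
+            return build(match)
+    return {"behavior": "", "action": "", "parameter": ""}
+
+print(extract("Play louder by 2 steps"))
+# -> {'behavior': 'play', 'action': 'increase volume', 'parameter': '2'}
+```
+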
+## 6. Topic list
+
+| Topic name | Description | Message type |
+| -------------------------- | ------------------------------------------------ | ------------------------------------------------------------ |
+| /robotsound | publish sound to audio node (sound_play package) | [SoundRequest.msg](http://docs.ros.org/en/jade/api/sound_play/html/msg/SoundRequest.html) |
+| /speech_node/audio | publish audio data | [AudioData.msg](https://github.com/ros-drivers/audio_common/blob/master/audio_common_msgs/msg/AudioData.msg) |
+| /speech_node/audio_info | description data of audio | [AudioInfo.msg](http://docs.ros.org/en/noetic/api/audio_common_msgs/html/msg/AudioInfo.html) |
+| /speech_node/audio_stamped | publish audio data with time stamp | [AudioDataStamped.msg](https://github.com/ros-drivers/audio_common/blob/master/audio_common_msgs/msg/AudioDataStamped.msg) |
+| /speech_node/command | publish command for robot | [Command.msg](../marimbabot_msgs/msg/Command.msg) |
+| /speech_node/speech | publish result of speech transcription | [Speech.msg](../marimbabot_msgs/msg/Speech.msg) |
+
+## 7. Training and using of keyword spotting model
+
+Please refer to the [kws_training](https://github.com/UHHRobotics22-23/marimbabot/tree/kws_training) branch.
\ No newline at end of file
diff --git a/marimbabot_speech/launch/command_recognition.launch b/marimbabot_speech/launch/command_recognition.launch
index 62c14c9d..ce532d5b 100644
--- a/marimbabot_speech/launch/command_recognition.launch
+++ b/marimbabot_speech/launch/command_recognition.launch
@@ -5,7 +5,7 @@
-
+
@@ -13,15 +13,17 @@
-
+
-
+
-
+
+
+
diff --git a/marimbabot_speech/src/speech2text.py b/marimbabot_speech/src/speech2text.py
index bb8ab5b7..dcc736de 100755
--- a/marimbabot_speech/src/speech2text.py
+++ b/marimbabot_speech/src/speech2text.py
@@ -60,7 +60,7 @@ def unpack_stream(self, data):
def generate_prompt(self):
prompt = "Marimbabot is a instrument playing robot arm. You are able to give it several common robot's commands," \
- "for example, play in 60 BPM, play louder by 2 step, play in a loop, review faster by 30 bpm ..."
+             "for example, play in 60 BPM, play louder by 2 steps, play in a loop, preview faster by 30 bpm and so on. So the detected text is:"
return prompt
def run(self):
diff --git a/marimbabot_speech/src/text2command.py b/marimbabot_speech/src/text2command.py
index e9761ed1..244ccbff 100755
--- a/marimbabot_speech/src/text2command.py
+++ b/marimbabot_speech/src/text2command.py
@@ -70,7 +70,7 @@ def __init__(self):
]
self.command_pub = rospy.Publisher('/speech_node/command', CommandMsg, queue_size=100, tcp_nodelay=True)
self.speech_sub = rospy.Subscriber('/speech_node/speech', SpeechMsg, self.speech_callback, queue_size=10, tcp_nodelay=True)
-
+ self.no_speech_prob_threshold = rospy.get_param('~no_speech_prob_threshold', 0.5)
def fill_into_command(self, behavior="", action="", parameters=""):
if behavior != "":
self.template["behavior"] = behavior
@@ -81,8 +81,8 @@ def fill_into_command(self, behavior="", action="", parameters=""):
def speech_callback(self, req):
text = req.text
- if req.no_speech_prob > 0.5:
- rospy.logwarn("even no_speech_prob > 0.5, but is passed the wad, something weird happened, ignore this speech")
+ if req.no_speech_prob > self.no_speech_prob_threshold:
+            rospy.logwarn(f"whisper's no_speech_prob > {self.no_speech_prob_threshold} even though it passed the VAD, so discarding this text: {text}")
return
command = self.extract(text)
rospy.logdebug(f"command extracted: {command}")
@@ -182,5 +182,4 @@ def extract(self, text: str):
if __name__ == '__main__':
rospy.init_node('command_extractor',log_level=rospy.DEBUG)
command_extractor = CommandExtraction()
- examples_test(command_extractor)
rospy.spin()
diff --git a/marimbabot_vision/scripts/README.md b/marimbabot_vision/scripts/README.md
index daa7fdc0..341d5123 100644
--- a/marimbabot_vision/scripts/README.md
+++ b/marimbabot_vision/scripts/README.md
@@ -42,10 +42,13 @@ This script is used when collection real world data, e.g. from a webcam that is
1. Run the generate_dataset_full.sh
2. Copy some samples into the folder `data_real`
-3. Run this script and use 'whitespace' and 'c' for starting the camera and collecting images.
+3. Run this script, then press the space bar to start the camera and 'c' to collect images.
### `train.py`
Trains a model on the a set of given `train_data_paths`.
+### `train_tokenizer.py`
+Trains the tokenizer on all text files defined by a glob expression.
+
### `detect.py`
This script is used for live detection of notes. A trained model can be used to initialize. The current model is stored at HuggingFace and its path/name is set by the `MODEL_PATH` parameter inside `config/vision_node.yaml` The detected notes are shown in a window.
\ No newline at end of file
diff --git a/marimbabot_vision/scripts/generate_augmented_data.py b/marimbabot_vision/scripts/generate_augmented_data.py
index a2cf2060..0c45a194 100644
--- a/marimbabot_vision/scripts/generate_augmented_data.py
+++ b/marimbabot_vision/scripts/generate_augmented_data.py
@@ -15,14 +15,14 @@
AUGMENT_OUTPUT_DIR = "data_augmented"
TRANSFORMATIONS = [
- A.Affine(translate_px={"y":10, "x":10}, scale=[0.5, 1.0], rotate=[-3,3], mode=1, always_apply=True),
+ A.Affine(translate_px={"y":7, "x":7}, scale=[0.7, 0.9], rotate=[-3,3], mode=1, always_apply=True),
A.Perspective(always_apply=True),
A.RandomBrightnessContrast(),
A.RandomShadow(shadow_roi=(0, 0, 1, 1), num_shadows_upper=4, shadow_dimension=8),
A.RandomSunFlare(flare_roi=(0, 0, 1, 1), src_radius=100),
A.PixelDropout(),
A.RGBShift(),
- A.MedianBlur(blur_limit=3,),
+ #A.MedianBlur(blur_limit=3,),
A.ZoomBlur(max_factor=1.03)
]
@@ -49,7 +49,8 @@ def augment_sample(i, args):
cv2.imwrite(os.path.join(augmented_path, f"staff_1.png"), new_img)
shutil.copy(orig_txt_path, os.path.join(augmented_path, f"staff_1.txt"))
- shutil.copy(orig_ly_path, os.path.join(augmented_path, f"staff_1.ly"))
+ if os.path.exists(orig_ly_path):
+ shutil.copy(orig_ly_path, os.path.join(augmented_path, f"staff_1.ly"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Augmented data generation of input data.")
@@ -62,4 +63,3 @@ def augment_sample(i, args):
# Call augment_sample on ids with tqdm and multiprocessing
with Pool(args.num_worker) as pool:
list(tqdm.tqdm(pool.imap(partial(augment_sample, args=args), os.listdir(args.input_dir)), total=len(os.listdir(args.input_dir))))
-
\ No newline at end of file
diff --git a/marimbabot_vision/scripts/generate_data.py b/marimbabot_vision/scripts/generate_data.py
index d9aa92af..60affe9a 100644
--- a/marimbabot_vision/scripts/generate_data.py
+++ b/marimbabot_vision/scripts/generate_data.py
@@ -58,13 +58,13 @@ def __init__(self, include_dynamics=True, include_slurs=True, include_articulati
"""
def note_sampler(self, duration, is_dotted = False):
first_note = random.choice(self.music_notes + self.rests)
-
+
octave = choice(["", "'", "''"], p=[0.0, 0.8, 0.2]) if first_note != 'r' else ''
note = first_note + random.choice(self.accidentals) if first_note != 'r' and random.random() < 0.2 else first_note
retNote = note + octave
- if self.chords and random.random() < 0.1 and note != 'r':
+ if self.chords and random.random() < 0.5 and note != 'r':
second_note = random.choice([n for n in self.music_notes if n != first_note])
second_note = second_note + random.choice(self.accidentals) if random.random() < 0.2 else second_note
retNote = "<" + note + octave + " " + second_note + octave + ">"
@@ -77,7 +77,7 @@ def sample_duration(durations, level=0):
"""Randomly subdivides a list of durations into smaller durations"""
prop = 1/3
new_durations = []
- for duration in durations:
+ for duration in durations:
if duration < self.min_duration and random.random() > prop:
new_durations.extend(sample_duration([duration * 2, duration * 2], level + 1))
else:
@@ -98,14 +98,14 @@ def sample_duration(durations, level=0):
new_durations.append((new_duration * 2, False))
else:
new_durations.append((duration, False))
-
+
# shuffle the durations
random.shuffle(new_durations)
return new_durations
return ' '.join([self.note_sampler(str(duration), is_dotted=is_dotted) for duration, is_dotted in sample_duration([1,])])
-
+
"""
creates a random number of articulations
"""
@@ -222,9 +222,12 @@ def generate_piece(self,num_bars=3,):
return string, staff_1
-
+
"""Generate a sample and save it to disk"""
def generate_sample(i, args):
+ # Skip if the sample already exists
+ if os.path.isfile(f"{args.output_dir}/{i}/staff_1.png") and os.path.isfile(f"{args.output_dir}/{i}/staff_1.txt"):
+ return
lilypondGenerator = args.lilypondGenerator
string, staff = lilypondGenerator.generate_piece()
os.makedirs(f"{args.output_dir}/{i}", exist_ok=True)
@@ -240,6 +243,10 @@ def generate_sample(i, args):
)
# save the lilypond file and rotate by 90 degrees
as_png(lilypond_file, f"{args.output_dir}/{i}/staff_1.png", resolution=200)
+ # check if the file exists (it can be case that lilypond renders two pages, we want to regenerate the sample in this case)
+ if not os.path.isfile(f"{args.output_dir}/{i}/staff_1.png"):
+ generate_sample(i, args)
+ return
# turn png by 90 degrees
im = Image.open(f"{args.output_dir}/{i}/staff_1.png")
im = im.rotate(-90, expand=True)
diff --git a/marimbabot_vision/scripts/generate_hw_data.py b/marimbabot_vision/scripts/generate_hw_data.py
index 1feeeb70..247c1b49 100644
--- a/marimbabot_vision/scripts/generate_hw_data.py
+++ b/marimbabot_vision/scripts/generate_hw_data.py
@@ -69,10 +69,10 @@ def check_bar(image, x_pos, y_offset, duration_counter, key, bar_accidentals, ar
def get_note_pose(note, octave):
# head_pos: y-position of the note head, y_pos: y-position of the note image
- head_pos = head_positions[note] - octave*35
+ head_pos = head_positions[note] - octave*35
is_flipped = head_pos < 70
y_pos = head_pos if is_flipped else head_pos-30
- return y_pos, is_flipped
+ return y_pos, is_flipped
def extend_staff(image, x_pos, y_pos, y_offset):
# draw extra lines above staff for notes higher than g''
@@ -148,6 +148,7 @@ def generate_sample_image(key, tempo, args):
return image
def draw_piece(string, sample_name, args):
+ # split string into list of enteties
piece = string.split()
# get key
@@ -170,7 +171,7 @@ def draw_piece(string, sample_name, args):
repeat = True
piece = piece[:repeat_index] + piece[repeat_index + 3:]
- # initialize variables
+ # initialize variables
x_pos = 70 + (key_flats_num[key] if key in key_flats_num.keys() else key_sharps_num[key])*20 + (30 if tempo else 0) + randint(args.min_symbol_dist, args.max_symbol_dist)
key_accidentals = get_key_accidentals(key)
bar_accidentals = {'sharps': [], 'flats': [], 'naturals': []}
@@ -181,7 +182,7 @@ def draw_piece(string, sample_name, args):
# generate sample image
sample_im = generate_sample_image(key, tempo, args)
# piece = [(n[0], n[1], n.count('\''), int((re.findall(r'\d+', n)[0])), n.count('.')) for n in string.split()]
-
+
while index < len(piece):
rule = piece[index]
@@ -279,7 +280,7 @@ def draw_piece(string, sample_name, args):
else:
accidentals[i] = ''
-
+
x_pos += 20 if 'f' in accidentals or 's' in accidentals or 'n' in accidentals else 0
# check for accents
@@ -309,7 +310,7 @@ def draw_piece(string, sample_name, args):
dynamic_y_pos += 15
draw_dynamics(sample_im, piece[index], x_pos-10, dynamic_y_pos, args)
index += 1
-
+
# draw note head(s)
for i in range(len(tones)):
# prevent overlapping note heads if notes are too close
@@ -317,7 +318,7 @@ def draw_piece(string, sample_name, args):
if i == 1 and y_head_poses[0] - y_head_poses[1] <= 5:
x_pos += 15
note_is_flipped = True
-
+
extend_staff(sample_im, x_pos, y_head_poses[i], y_offset)
if duration < 4:
draw_symbol(sample_im, f'{args.hw_symbols_dir}/head/empty', (x_pos, y_head_poses[i]), note_is_flipped)
@@ -349,7 +350,7 @@ def draw_piece(string, sample_name, args):
# handle single note
else:
attachment_point = (x_pos-10, y_head_poses[0]) if is_flipped else (x_pos+5, y_head_poses[0]-30)
-
+
# draw stem with correct amount of flags
if duration <=4:
draw_symbol(sample_im, f'{args.hw_symbols_dir}/stem/4', attachment_point, is_flipped, True)
@@ -363,7 +364,7 @@ def draw_piece(string, sample_name, args):
y_pos = y_head_poses[0] +5 if y_head_poses[0] % 10 else y_head_poses[0]
x_pos += 25 if duration > 4 and not is_flipped else 20
draw_symbol(sample_im, f'{args.hw_symbols_dir}/dot', (x_pos, y_pos))
-
+
if len(tones) > 1:
y_pos = y_head_poses[1] - 5 if y_head_poses[1] % 10 else y_head_poses[1]
x_pos += 5 if duration > 4 and not is_flipped else 0
@@ -385,22 +386,23 @@ def draw_piece(string, sample_name, args):
# check if current bar is full
if index < len(piece):
x_pos, y_offset, duration_counter, bar_accidentals = check_bar(sample_im, x_pos, y_offset, duration_counter, key, bar_accidentals, args)
-
- # draw final bar line
+
+ # draw final bar line
elif repeat == False:
draw_symbol(sample_im, f'{args.hw_symbols_dir}/big-bar', (x_pos,50+y_offset))
-
+
# draw repeat sign
else:
draw_symbol(sample_im, f'{args.hw_symbols_dir}/dot', (x_pos,60+y_offset))
draw_symbol(sample_im, f'{args.hw_symbols_dir}/dot', (x_pos,70+y_offset))
draw_symbol(sample_im, f'{args.hw_symbols_dir}/bar', (x_pos,50+y_offset))
draw_symbol(sample_im, f'{args.hw_symbols_dir}/big-bar', (x_pos+10,50+y_offset))
-
- os.makedirs(f'{args.output_dir}/{sample_name}', exist_ok=True)
+
+ os.makedirs(f'{args.output_dir}/{sample_name}', exist_ok=True)
sample_im.convert('RGB').save(f'{args.output_dir}/{sample_name}/{args.name_prefix}.png','PNG')
shutil.copyfile(f'{args.input_dir}/{sample_name}/{args.name_prefix}.txt', f'{args.output_dir}/{sample_name}/{args.name_prefix}.txt')
- shutil.copyfile(f'{args.input_dir}/{sample_name}/{args.name_prefix}.ly', f'{args.output_dir}/{sample_name}/{args.name_prefix}.ly')
+ if os.path.isfile(f'{args.input_dir}/{sample_name}/{args.name_prefix}.ly'):
+ shutil.copyfile(f'{args.input_dir}/{sample_name}/{args.name_prefix}.ly', f'{args.output_dir}/{sample_name}/{args.name_prefix}.ly')
def render(sample, args):
with open(f'{args.input_dir}/{sample}/staff_1.txt', 'r') as f:
@@ -416,7 +418,7 @@ def render(sample, args):
parser.add_argument("--min_symbol_dist", type=int, required=False, help="Minimum distance between notes.", default=MIN_SYMBOL_DIST)
parser.add_argument("--max_symbol_dist", type=int, required=False, help="Maximum distance between notes.", default=MAX_SYMBOL_DIST)
parser.add_argument("--min_line_dist", type=int, required=False, help="Minimum distance between staff lines.", default=MIN_LINE_DIST)
- parser.add_argument("--max_line_dist", type=int, required=False, help="Maximum distance between staff lines.", default=MAX_LINE_DIST)
+ parser.add_argument("--max_line_dist", type=int, required=False, help="Maximum distance between staff lines.", default=MAX_LINE_DIST)
parser.add_argument("--input_dir", type=str, required=False, help="Folder for the input data.", default=INPUT_DIR)
parser.add_argument("--output_dir", type=str, required=False, help="Folder for the output data.", default=OUTPUT_DIR)
diff --git a/marimbabot_vision/scripts/tokenizer.json b/marimbabot_vision/scripts/tokenizer.json
new file mode 100644
index 00000000..81e9bbc2
--- /dev/null
+++ b/marimbabot_vision/scripts/tokenizer.json
@@ -0,0 +1,97 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [
+ {
+ "id": 0,
+ "content": "",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ }
+ ],
+ "normalizer": {
+ "type": "BertNormalizer",
+ "clean_text": true,
+ "handle_chinese_chars": true,
+ "strip_accents": null,
+ "lowercase": false
+ },
+ "pre_tokenizer": {
+ "type": "WhitespaceSplit"
+ },
+ "post_processor": null,
+ "decoder": {
+ "type": "BPEDecoder",
+ "suffix": ""
+ },
+ "model": {
+ "type": "BPE",
+ "dropout": null,
+ "unk_token": null,
+ "continuing_subword_prefix": null,
+ "end_of_word_suffix": "",
+ "fuse_unk": false,
+ "vocab": {
+ "": 0,
+ "'": 1,
+ ".": 2,
+ "0": 3,
+ "1": 4,
+ "2": 5,
+ "4": 6,
+ "6": 7,
+ "8": 8,
+ "9": 9,
+ "<": 10,
+ "=": 11,
+ ">": 12,
+ "\\": 13,
+ "a": 14,
+ "b": 15,
+ "c": 16,
+ "d": 17,
+ "e": 18,
+ "f": 19,
+ "g": 20,
+ "i": 21,
+ "j": 22,
+ "k": 23,
+ "l": 24,
+ "m": 25,
+ "n": 26,
+ "o": 27,
+ "p": 28,
+ "r": 29,
+ "s": 30,
+ "t": 31,
+ "v": 32,
+ "y": 33,
+ ".": 34,
+ "2": 35,
+ "8": 36,
+ "'": 37,
+ "4": 38,
+ "6": 39,
+ "1": 40,
+ "c": 41,
+ "f": 42,
+ "b": 43,
+ "g": 44,
+ "r": 45,
+ "0": 46,
+ "a": 47,
+ "e": 48,
+ "p": 49,
+ "t": 50,
+ "d": 51,
+ "o": 52,
+ "y": 53
+ },
+ "merges": [
+ ]
+ }
+}
\ No newline at end of file
diff --git a/marimbabot_vision/scripts/train.py b/marimbabot_vision/scripts/train.py
index 64030249..34bc5988 100644
--- a/marimbabot_vision/scripts/train.py
+++ b/marimbabot_vision/scripts/train.py
@@ -9,32 +9,34 @@
import torch
from nltk import edit_distance
from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from tokenizers import Tokenizer
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from tqdm import tqdm
from transformers import (DonutProcessor, VisionEncoderDecoderConfig,
- VisionEncoderDecoderModel)
+ VisionEncoderDecoderModel, PreTrainedTokenizerFast)
# Config
config = {
- "max_epochs": 20,
+ "max_epochs": 10,
"check_val_every_n_epoch": 1,
"gradient_clip_val":1.0,
"lr":1e-5,
- "train_batch_sizes": [6],
- "val_batch_sizes": [6],
+ "train_batch_sizes": [12],
+ "val_batch_sizes": [12],
"num_nodes": 1,
- "warmup_steps": 300,
+ "warmup_steps": 1000,
"result_path": "./result",
"verbose": True,
- "train_data_paths": ["data_hw/", "data_augmented/", "data_wb/"],
- "val_data_paths": ["test_data/", ],
- "max_length": 40,
+ "train_data_paths": ["data_accord", "data_keys", "data_lilypond_dynamics", "data_negative", "data_wb_basic", "data_wb_extended", "data_lilypond_augmented", "data_hw_dynamics_augmented"],
+ "val_data_paths": ["test_data_hw_extended", "test_data_extended", "test_data_wb"],
+ "max_length": 60,
"image_size": [583, 409],
"start_token": "",
"num_workers": 24,
+ "tokenizer_path": "./tokenizer.json", # Or None if the tokenizer should not be changed
"base_model": "nielsr/donut-base",
- "output_model": "./model_1",
- "add_note_vocab": True
+ "output_model": "./model_extended_2"
}
# Load base model
@@ -54,27 +56,38 @@
ignore_mismatched_sizes=True,
config=ved_config)
-if config['add_note_vocab']:
- note_vocab = []
- notes = "cdefgab"
- octaves = ["'", "''"]
- durations = [1, 2, 4, 8, 16]
- # Add note tokens
- for note in notes:
- for octave in octaves:
- for duration in durations:
- note_vocab.append(f"{note}{octave}{duration} ")
- # Add rest tokens
- for duration in durations:
- note_vocab.append(f"r{duration} ")
-else:
- note_vocab = []
+# Swap tokenizer with our own character-level tokenizer if a path is provided
+if config['tokenizer_path'] is not None:
+ # Save the special tokens of the original tokenizer
+ eos_token = pre_processor.tokenizer.eos_token
+ pad_token = pre_processor.tokenizer.pad_token
+ # Create a new tokenizers tokenizer based on the provided tokenizer config
+ pre_processor.tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=Tokenizer.from_file(config['tokenizer_path']))
+ # Add the special tokens back
+ pre_processor.tokenizer.add_special_tokens({
+ "eos_token": eos_token,
+ "pad_token": pad_token,
+ "unk_token": "",
+ "bos_token": config['start_token']
+ })
+ # Update the special token properties
+ pre_processor.tokenizer.eos_token = eos_token
+ pre_processor.tokenizer.pad_token = pad_token
+ # Update the special token ids
+ pre_processor.tokenizer.unk_token_id = pre_processor.tokenizer.convert_tokens_to_ids([''])[0]
+ pre_processor.tokenizer.pad_token_id = pre_processor.tokenizer.convert_tokens_to_ids([''])[0]
+ pre_processor.tokenizer.bos_token_id = pre_processor.tokenizer.convert_tokens_to_ids([config['start_token']])[0]
+ # Resize the model embedding layer
+ model.decoder.resize_token_embeddings(len(pre_processor.tokenizer))
+
+# Set the special tokens of the model
model.config.pad_token_id = pre_processor.tokenizer.pad_token_id
-model.config.decoder_start_token_id = pre_processor.tokenizer.convert_tokens_to_ids([config['start_token']])[0]
+model.config.decoder.pad_token_id = pre_processor.tokenizer.pad_token_id
+model.config.decoder_start_token_id = pre_processor.tokenizer.bos_token_id
# Create dataset
-
class NoteDataset(Dataset):
def __init__(
self,
@@ -102,7 +115,27 @@ def __init__(
ground_truth + " " + pre_processor.tokenizer.eos_token
)
- self.add_tokens(note_vocab + [self.start_token])
+ self.add_tokens([
+ self.start_token,
+ "\\repeat ",
+ "volta ",
+ "\\key ",
+ "\\minor ",
+ "\\major ",
+ "\\tempo ",
+ "4=40 ",
+ "4=60 ",
+ "4=96 ",
+ "4=120 ",
+ "- \\marcato "
+ '\\ppp ',
+ '\\pp ',
+ '\\p ',
+ '\\mp ',
+ '\\mf ',
+ '\\f ',
+ '\\ff ',
+ '\\fff '])
def add_tokens(self, list_of_tokens: List[str]):
newly_added_num = pre_processor.tokenizer.add_tokens(list_of_tokens)
@@ -197,7 +230,7 @@ def validation_step(self, batch, batch_idx, dataset_idx=0):
scores = list()
for pred, answer in zip(predictions, answers):
- pred = pred.replace(self.pre_processor.tokenizer.eos_token, "")[3:]
+ pred = pred.replace(self.pre_processor.tokenizer.eos_token, "").replace(self.pre_processor.tokenizer.bos_token, "")
answer = answer.replace(self.pre_processor.tokenizer.eos_token, "")[:len(pred)]
score = edit_distance(pred, answer)
scores.append(score)
@@ -239,6 +272,13 @@ def val_dataloader(self):
# Instantiate pytorch lightning module
model_module = DonutModelPLModule(config, pre_processor, model, train_dataloader, val_dataloader)
+# Create callback to save the model after each epoch
+class SaveCallback(Callback):
+ def on_train_epoch_end(self, trainer, pl_module):
+ pl_module.model.save_pretrained(pl_module.config['output_model'])
+ pl_module.pre_processor.save_pretrained(pl_module.config['output_model'])
+ ved_config.save_pretrained(pl_module.config['output_model'])
+
# Instantiate pytorch lightning trainer
trainer = pl.Trainer(
accelerator="gpu",
@@ -248,12 +288,8 @@ def val_dataloader(self):
gradient_clip_val=config.get("gradient_clip_val"),
precision=16, # we'll use mixed precision
num_sanity_val_steps=0,
+ callbacks=[SaveCallback()],
)
# Train the model
trainer.fit(model_module)
-
-# Save the model and tokenizer
-model.save_pretrained(config['output_model'])
-pre_processor.save_pretrained(config['output_model'])
-ved_config.save_pretrained(config['output_model'])
diff --git a/marimbabot_vision/scripts/train_tokenizer.py b/marimbabot_vision/scripts/train_tokenizer.py
new file mode 100644
index 00000000..266b6e65
--- /dev/null
+++ b/marimbabot_vision/scripts/train_tokenizer.py
@@ -0,0 +1,19 @@
+import glob
+from tokenizers import CharBPETokenizer
+
+training_file_search_expr = "data_*/*/*.txt"
+vocab_size = 53
+output_path = "./tokenizer.json"
+
+
+# Search for all text / label files that match the glob expression
+training_files = glob.glob(training_file_search_expr)
+
+# Initialize a tokenizer
+tokenizer = CharBPETokenizer(split_on_whitespace_only=True)
+
+# Train it on the training files
+tokenizer.train(training_files, vocab_size=vocab_size, min_frequency=10)
+
+# Save it
+tokenizer.save("./tokenizer.json")
diff --git a/requirements.txt b/requirements.txt
index 5e040173..50b1a3a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,7 @@ rospkg==1.5.0
tensorflow
torch==1.13.1
tqdm==4.64.0
-transformers[sentencepiece]==4.26.0
+transformers[sentencepiece]==4.34.0
catkin_pkg==0.5.2
openai-whisper==20230314
webrtcvad==2.0.10