Skip to content

Commit

Permalink
Apply formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Nov 7, 2024
1 parent e346f14 commit 75d8398
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions ddlitlab2024/dataset/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@


class DDLITLab2024Dataset(Dataset):
def __init__(self,
data_base_path: str,
sample_rate_imu: int = 100,
num_samples_imu: int = 100,
sample_rate_joint_states: int = 100,
num_samples_joint_states: int = 100,
sample_rate_joint_trajectory: int = 100,
num_samples_joint_trajectory: int = 100,
num_samples_joint_trajectory_future: int = 10,
max_fps_video: int = 10,
num_frames_video: int = 50,
trajectory_stride: int = 10,
):
def __init__(
self,
data_base_path: str,
sample_rate_imu: int = 100,
num_samples_imu: int = 100,
sample_rate_joint_states: int = 100,
num_samples_joint_states: int = 100,
sample_rate_joint_trajectory: int = 100,
num_samples_joint_trajectory: int = 100,
num_samples_joint_trajectory_future: int = 10,
max_fps_video: int = 10,
num_frames_video: int = 50,
trajectory_stride: int = 10,
):
# Store the parameters
self.sample_rate_imu = sample_rate_imu
self.num_samples_imu = num_samples_imu
Expand All @@ -32,7 +33,7 @@ def __init__(self,
self.max_fps_video = max_fps_video
self.num_frames_video = num_frames_video
self.trajectory_stride = trajectory_stride

# The Data exists in a sqlite database
self.data_base_path = data_base_path

Expand Down Expand Up @@ -66,29 +67,28 @@ def __init__(self,
self.num_samples += int((recording_length - self.sample_length_s) / self.stride_s)
# Store the boundaries of the samples for later retrieval
self.sample_boundaries.append((total_samples_before, self.num_samples, recording_id, start_timestamp))


def __len__(self):
return self.num_samples

def __getitem__(self, idx):
# Find the recording that contains the sample
boundary = None
boundary = None
for start_sample, end_sample, recording_id, st in self.sample_boundaries:
if idx >= start_sample and idx < end_sample:
boundary = (recording_id, st)
break
assert boundary is not None, "Could not find the recording that contains the sample"
recording_id, start_timestamp = boundary

# Calculate the timestamp of the sample
sample_timestamp = start_timestamp + (idx - boundary[0]) * self.stride_s
sample_timestamp_future = sample_timestamp + self.num_samples_joint_trajectory_future / self.sample_rate_joint_trajectory
sample_timestamp_future = (
sample_timestamp + self.num_samples_joint_trajectory_future / self.sample_rate_joint_trajectory
)

# Get the joint command
raw_joint_command = pd.read_sql_query(f"SELECT * FROM joint_command WHERE recording_id = {recording_id} AND stamp >= {sample_timestamp} AND stamp < {sample_timestamp_future}", self.db_connection)





raw_joint_command = pd.read_sql_query(
f"SELECT * FROM joint_command WHERE recording_id = {recording_id} AND stamp >= {sample_timestamp} AND stamp < {sample_timestamp_future}",
self.db_connection,
)

0 comments on commit 75d8398

Please sign in to comment.