From 57ce5fb2dda0c17f54f0ebecc8fc13901973f56a Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:43:39 -0700 Subject: [PATCH 1/9] Add mapper and custom launcher --- .../aind_behavior_vr_foraging/data_mappers.py | 231 ++++++++++++++++++ .../aind_behavior_vr_foraging/launcher.py | 183 ++++++++++++++ tests/test_aind_data_mapper.py | 24 ++ 3 files changed, 438 insertions(+) create mode 100644 src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py create mode 100644 src/DataSchemas/aind_behavior_vr_foraging/launcher.py create mode 100644 tests/test_aind_data_mapper.py diff --git a/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py b/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py new file mode 100644 index 0000000..e0fdd7b --- /dev/null +++ b/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py @@ -0,0 +1,231 @@ +import datetime +import logging +import os +from pathlib import Path +from typing import Dict, Optional, Type, TypeVar, Union + +import aind_behavior_services.rig as AbsRig +import aind_data_schema +import aind_data_schema.components.devices +import aind_data_schema.core.session +import git +import pydantic +from aind_behavior_experiment_launcher.data_mappers import data_mapper_service +from aind_behavior_experiment_launcher.records.subject_info import SubjectInfo +from aind_behavior_services.calibration import Calibration +from aind_behavior_services.session import AindBehaviorSessionModel +from aind_behavior_services.utils import model_from_json_file, utcnow + +from aind_behavior_vr_foraging.rig import AindVrForagingRig +from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +class VrForagingToAindDataSchemaDataMapper(data_mapper_service.DataMapperService): + def validate(self, *args, **kwargs): + return True + + @classmethod + def map( + cls, + *args, + schema_root: os.PathLike, + session_model: Type[AindBehaviorSessionModel], + rig_model: Type[AindVrForagingRig], + task_logic_model: Type[AindVrForagingTaskLogic], + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[SubjectInfo] = None, + session_directory: Optional[os.PathLike] = None, + **kwargs, + ) -> Optional[aind_data_schema.core.session.Session]: + logger.info("Mapping to aind-data-schema Session") + try: + ads_session = cls.map_from_session_root( + schema_root=schema_root, + session_model=session_model, + rig_model=rig_model, + task_logic_model=task_logic_model, + repository=repository, + script_path=script_path, + session_end_time=session_end_time, + output_parameters=output_parameters, + subject_info=subject_info, + **kwargs, + ) + if session_directory is not None: + logger.info("Writing session.json to %s", session_directory) + ads_session.write_standard_file(session_directory) + logger.info("Mapping successful.") + except (pydantic.ValidationError, ValueError, IOError) as e: + logger.error("Failed to map to aind-data-schema Session. %s", e) + raise e + else: + return ads_session + + @classmethod + def map_from_session_root( + cls, + schema_root: os.PathLike, + session_model: Type[AindBehaviorSessionModel], + rig_model: Type[AindVrForagingRig], + task_logic_model: Type[AindVrForagingTaskLogic], + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[SubjectInfo] = None, + **kwargs, + ) -> aind_data_schema.core.session.Session: + return cls._map( + session_model=model_from_json_file(Path(schema_root) / "session_input.json", session_model), + rig_model=model_from_json_file(Path(schema_root) / "rig_input.json", rig_model), + task_logic_model=model_from_json_file(Path(schema_root) / "tasklogic_input.json", task_logic_model), + repository=repository, + script_path=script_path, + session_end_time=session_end_time if session_end_time else utcnow(), + output_parameters=output_parameters, + subject_info=subject_info, + **kwargs, + ) + + @classmethod + def _map( + cls, + session_model: AindBehaviorSessionModel, + rig_model: AindVrForagingRig, + task_logic_model: AindVrForagingTaskLogic, + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[SubjectInfo] = None, + **kwargs, + ) -> aind_data_schema.core.session.Session: + # Normalize repository + if isinstance(repository, os.PathLike | str): + repository = git.Repo(Path(repository)) + repository_remote_url = repository.remote().url + repository_sha = repository.head.commit.hexsha + repository_relative_script_path = Path(script_path).resolve().relative_to(repository.working_dir) + + # Populate calibrations: + calibrations = [ + cls._mapper_calibration(_calibration_model[1]) + for _calibration_model in data_mapper_service.get_fields_of_type(rig_model, Calibration) + ] + # Populate cameras + cameras = data_mapper_service.get_cameras(rig_model, exclude_without_video_writer=True) + # populate devices + devices = [ + device[0] for device in data_mapper_service.get_fields_of_type(rig_model, AbsRig.Device) if device[0] + ] + # Populate modalities + modalities: list[aind_data_schema.core.session.Modality] = [ + getattr(aind_data_schema.core.session.Modality, "BEHAVIOR") + ] + if len(cameras) > 0: + modalities.append(getattr(aind_data_schema.core.session.Modality, "BEHAVIOR_VIDEOS")) + modalities = list(set(modalities)) + # Populate stimulus modalities + stimulus_modalities: list[aind_data_schema.core.session.StimulusModality] = [] + + if data_mapper_service.get_fields_of_type(rig_model, AbsRig.Screen): + stimulus_modalities.extend( + [ + aind_data_schema.core.session.StimulusModality.VISUAL, + aind_data_schema.core.session.StimulusModality.VIRTUAL_REALITY, + ] + ) + if data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpOlfactometer): + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.OLFACTORY) + if data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpTreadmill): + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.WHEEL_FRICTION) + + # Mouse platform + mouse_platform: str + if isinstance(rig_model.harp_treadmill, AbsRig.HarpTreadmill): + mouse_platform = "Treadmill" + active_mouse_platform = True + else: + raise ValueError("Mouse platform is of unexpected type.") + + # Reward delivery + reward_delivery_config = aind_data_schema.core.session.RewardDeliveryConfig( + reward_solution=aind_data_schema.core.session.RewardSolution.WATER, reward_spouts=[] + ) + + # Construct aind-data-schema session + aind_data_schema_session = aind_data_schema.core.session.Session( + animal_weight_post=subject_info.animal_weight_post if subject_info else None, + animal_weight_prior=subject_info.animal_weight_prior if subject_info else None, + reward_consumed_total=subject_info.reward_consumed_total if subject_info else None, + reward_delivery=reward_delivery_config, + experimenter_full_name=session_model.experimenter, + session_start_time=session_model.date, + session_end_time=session_end_time, + session_type=session_model.experiment, + rig_id=rig_model.rig_name, + subject_id=session_model.subject, + notes=session_model.notes, + data_streams=[ + aind_data_schema.core.session.Stream( + daq_names=devices, + stream_modalities=modalities, + stream_start_time=session_model.date, + stream_end_time=session_end_time if session_end_time else session_model.date, + camera_names=list(cameras.keys()), + ), + ], + calibrations=calibrations, + mouse_platform_name=mouse_platform, + active_mouse_platform=active_mouse_platform, + stimulus_epochs=[ + aind_data_schema.core.session.StimulusEpoch( + stimulus_name=session_model.experiment, + stimulus_start_time=session_model.date, + stimulus_end_time=session_end_time if session_end_time else session_model.date, + stimulus_modalities=stimulus_modalities, + software=[ + aind_data_schema.core.session.Software( + name="Bonsai", + version=f"{repository_remote_url}/blob/{repository_sha}/bonsai/Bonsai.config", + url=f"{repository_remote_url}/blob/{repository_sha}/bonsai", + parameters=data_mapper_service.snapshot_bonsai_environment( + config_file=kwargs.get("bonsai_config_path", Path("./bonsai/bonsai.config")) + ), + ), + aind_data_schema.core.session.Software( + name="Python", + version=f"{repository_remote_url}/blob/{repository_sha}/pyproject.toml", + url=f"{repository_remote_url}/blob/{repository_sha}", + parameters=data_mapper_service.snapshot_python_environment(), + ), + ], + script=aind_data_schema.core.session.Software( + name=Path(script_path).stem, + version=session_model.commit_hash if session_model.commit_hash else repository_sha, + url=f"{repository_remote_url}/blob/{repository_sha}/{repository_relative_script_path}", + parameters=task_logic_model.model_dump(), + ), + output_parameters=output_parameters if output_parameters else {}, + ) # type: ignore + ], + ) # type: ignore + return aind_data_schema_session + + @staticmethod + def _mapper_calibration(calibration: Calibration) -> aind_data_schema.components.devices.Calibration: + return aind_data_schema.components.devices.Calibration( + device_name=calibration.device_name, + input=calibration.input.model_dump() if calibration.input else {}, + output=calibration.output.model_dump() if calibration.output else {}, + calibration_date=calibration.date if calibration.date else utcnow(), + description=calibration.description if calibration.description else "", + notes=calibration.notes, + ) diff --git a/src/DataSchemas/aind_behavior_vr_foraging/launcher.py b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py new file mode 100644 index 0000000..8c744f6 --- /dev/null +++ b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py @@ -0,0 +1,183 @@ +import logging +from pathlib import Path +from typing import List, Optional, Self, Type + +from aind_behavior_experiment_launcher.apps import app_service +from aind_behavior_experiment_launcher.data_transfer import robocopy_service, watchdog_service +from aind_behavior_experiment_launcher.launcher import Launcher as BaseLauncher +from aind_behavior_experiment_launcher.resource_monitor import resource_monitor_service +from aind_behavior_experiment_launcher.services import Services +from aind_behavior_services.session import AindBehaviorSessionModel +from aind_behavior_services.utils import utcnow + +from aind_behavior_vr_foraging.rig import AindVrForagingRig +from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic + +from .data_mappers import VrForagingToAindDataSchemaDataMapper + +logger = logging.getLogger(__name__) + + +class VrForagingLauncher(BaseLauncher): + rig_schema_model: Type[AindVrForagingRig] + session_schema_model: Type[AindBehaviorSessionModel] + task_logic_schema_model: Type[AindVrForagingTaskLogic] + + def __init__( + self, + rig_schema_model: Type[AindVrForagingRig], + session_schema_model: Type[AindBehaviorSessionModel], + task_logic_schema_model: Type[AindVrForagingTaskLogic], + data_dir, + config_library_dir, + temp_dir=..., + repository_dir=..., + allow_dirty=..., + skip_hardware_validation=..., + debug_mode=..., + logger=..., + group_by_subject_log=..., + services=..., + validate_init=..., + ): + super().__init__( + rig_schema_model, + session_schema_model, + task_logic_schema_model, + data_dir, + config_library_dir, + temp_dir, + repository_dir, + allow_dirty, + skip_hardware_validation, + debug_mode, + logger, + group_by_subject_log, + services, + validate_init, + ) + + def _prompt_session_input(self, directory=...) -> AindBehaviorSessionModel: + session: AindBehaviorSessionModel = super()._prompt_session_input(directory) + session.experimenter = self._prompt_experimenter() + return session + + @staticmethod + def _prompt_experimenter() -> List[str]: + experimenter: Optional[List[str]] = None + while experimenter is None: + _user_input = input("Experimenter name: ") + if _user_input: + if not _user_input == "": + experimenter = _user_input.replace(",", " ").split() + else: + logger.error("Experimenter name is not valid.") + return experimenter + + @staticmethod + def validate_services(services: Services, *args, **kwargs) -> None: + # Validate services + # Bonsai app is required + if services.app is None: + raise ValueError("Bonsai app not set.") + else: + if not isinstance(services.app, app_service.BonsaiApp): + raise ValueError("Bonsai app is not an instance of BonsaiApp.") + if not services.app.validate(): + raise ValueError("Bonsai app failed to validate.") + else: + logger.info("Bonsai app validated.") + + # data_transfer_service is optional + if services.data_transfer is None: + logger.warning("Data transfer service not set.") + else: + if not services.data_transfer.validate(): + raise ValueError("Data transfer service failed to validate.") + else: + logger.info("Data transfer service validated.") + + # Resource monitor service is optional + if services.resource_monitor is None: + logger.warning("Resource monitor service not set.") + else: + if not services.resource_monitor.validate(): + raise ValueError("Resource monitor service failed to validate.") + else: + logger.info("Resource monitor service validated.") + + # Data mapper service is optional + if services.data_mapper is None: + logger.warning("Data mapper service not set.") + else: + if not isinstance(services.data_mapper, VrForagingToAindDataSchemaDataMapper): + raise ValueError("Data mapper service is not an instance of VrForagingToAindDataSchemaDataMapper.") + if not services.data_mapper.validate(): + raise ValueError("Data mapper service failed to validate.") + else: + logger.info("Data mapper service validated.") + + def _post_run_hook(self, *args, **kwargs) -> Self: + self.logger.info("Post-run hook started.") + if self.services.app is None: + raise ValueError("Bonsai app not set.") + self._subject_info = self.subject_info.prompt_field("animal_weight_post", None) + self._subject_info = self.subject_info.prompt_field("reward_consumed_total", None) + try: + self.logger.info("Subject Info: %s", self.subject_info.model_dump_json(indent=4)) + except Exception as e: + self.logger.error("Failed to log subject info. %s", e) + + mapped = None + if self.services.data_mapper is not None: + if not isinstance(self.services.data_mapper, VrForagingToAindDataSchemaDataMapper): + raise ValueError("Data mapper service is not an instance of VrForagingToAindDataSchemaDataMapper.") + try: + mapped = self.services.data_mapper.map( + schema_root=self.session_directory / "Behavior" / "Logs", + session_model=self.session_schema_model, + rig_model=self.rig_schema_model, + task_logic_model=self.task_logic_schema_model, + repository=self.repository, + script_path=Path(self.services.app.workflow).resolve(), + session_end_time=utcnow(), + subject_info=self.subject_info, + session_directory=self.session_directory, + ) + self.logger.info("Mapping successful.") + except Exception as e: + self.logger.error("Data mapper service has failed: %s", e) + + if self.services.data_transfer is not None: + try: + if isinstance(self.services.data_transfer, robocopy_service.RobocopyService): + self.services.data_transfer.transfer( + source=self.session_directory, + destination=self.services.data_transfer.destination / self.session_schema.session_name, + overwrite=False, + force_dir=True, + ) + elif isinstance(self.services.data_transfer, watchdog_service.WatchdogDataTransferService): + self.services.data_transfer.transfer( + session_schema=self.session_schema, + ads_session=mapped, + session_directory=self.session_directory, + ) + else: + raise ValueError( + "Data transfer service is not an instance of RobocopyService or WatchdogDataTransferService." + ) + except Exception as e: + self.logger.error("Data transfer service has failed: %s", e) + return self + + +def default_services_factory(): + return Services( + app=app_service.BonsaiApp(Path(r"./src/vr-foraging.bonsai")), + data_transfer=robocopy_service.RobocopyService(destination=Path(r"\\allen\aind\scratch\vr-foraging\data")), + resource_monitor=resource_monitor_service.ResourceMonitor( + constrains=[resource_monitor_service.available_storage_constraint_factory(drive=r"C:\\", min_bytes=2e11)] + ), + data_mapper=None, + ) diff --git a/tests/test_aind_data_mapper.py b/tests/test_aind_data_mapper.py new file mode 100644 index 0000000..92bcb1e --- /dev/null +++ b/tests/test_aind_data_mapper.py @@ -0,0 +1,24 @@ +import sys + +sys.path.append(".") +import datetime +import unittest +from pathlib import Path + +from aind_behavior_vr_foraging.data_mappers import VrForagingToAindDataSchemaDataMapper + +from examples.examples import mock_rig, mock_session, mock_task_logic + + +class AindServicesTests(unittest.TestCase): + def test_session_mapper(self): + VrForagingToAindDataSchemaDataMapper.map( + schema_root=None, + session_model=mock_session(), + rig_model=mock_rig(), + task_logic_model=mock_task_logic(), + session_end_time=datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc), + repository=Path("./"), + script_path=Path("./src/unit_test.bonsai"), + bonsai_config_path=Path("./tests/assets/bonsai.config").resolve(), + ) From 4899488ab3cec8a15df1dda8f0cff58f66c88afd Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:43:59 -0700 Subject: [PATCH 2/9] Temporary fix for testing purposes --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4d7ff89..8368514 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ linters = [ 'codespell' ] -launcher = ["aind_behavior_experiment_launcher[aind-services]<0.2.0"] +launcher = ["aind_behavior_experiment_launcher[aind-services]@git+https://github.com/AllenNeuralDynamics/Aind.Behavior.ExperimentLauncher@main"] docs = [ 'Sphinx<7.3', From 645ed27275946a46a5e329b03de35f0806930978 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:33:38 -0700 Subject: [PATCH 3/9] Rely on base launcher experimenter prompt method --- .../aind_behavior_vr_foraging/launcher.py | 53 +++++++------------ 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/src/DataSchemas/aind_behavior_vr_foraging/launcher.py b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py index 8c744f6..4dd1956 100644 --- a/src/DataSchemas/aind_behavior_vr_foraging/launcher.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py @@ -1,4 +1,5 @@ import logging +from os import PathLike from pathlib import Path from typing import List, Optional, Self, Type @@ -25,43 +26,29 @@ class VrForagingLauncher(BaseLauncher): def __init__( self, - rig_schema_model: Type[AindVrForagingRig], - session_schema_model: Type[AindBehaviorSessionModel], - task_logic_schema_model: Type[AindVrForagingTaskLogic], - data_dir, - config_library_dir, - temp_dir=..., - repository_dir=..., - allow_dirty=..., - skip_hardware_validation=..., - debug_mode=..., - logger=..., - group_by_subject_log=..., - services=..., - validate_init=..., + data_dir: PathLike, + config_library_dir: PathLike, + allow_dirty: bool = False, + skip_hardware_validation: bool = False, + debug_mode: bool = False, + services: Services | None = None, + group_by_subject_log: bool = True, + **kwargs, ): super().__init__( - rig_schema_model, - session_schema_model, - task_logic_schema_model, - data_dir, - config_library_dir, - temp_dir, - repository_dir, - allow_dirty, - skip_hardware_validation, - debug_mode, - logger, - group_by_subject_log, - services, - validate_init, + rig_schema_model=AindVrForagingRig, + session_schema_model=AindBehaviorSessionModel, + task_logic_schema_model=AindVrForagingTaskLogic, + data_dir=data_dir, + config_library_dir=config_library_dir, + allow_dirty=allow_dirty, + skip_hardware_validation=skip_hardware_validation, + debug_mode=debug_mode, + services=services, + group_by_subject_log=group_by_subject_log, + **kwargs, ) - def _prompt_session_input(self, directory=...) -> AindBehaviorSessionModel: - session: AindBehaviorSessionModel = super()._prompt_session_input(directory) - session.experimenter = self._prompt_experimenter() - return session - @staticmethod def _prompt_experimenter() -> List[str]: experimenter: Optional[List[str]] = None From 594539b1b7cef80eb16d7fd48822a12691e41ddf Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Sat, 26 Oct 2024 14:15:53 -0700 Subject: [PATCH 4/9] Make regenerate and launcher entry-points of package --- pyproject.toml | 4 + scripts/regenerate.cmd | 6 - scripts/regenerate.ps1 | 4 - .../aind_behavior_vr_foraging/launcher.py | 199 ++++-------------- .../aind_behavior_vr_foraging}/regenerate.py | 5 +- 5 files changed, 47 insertions(+), 171 deletions(-) delete mode 100644 scripts/regenerate.cmd delete mode 100644 scripts/regenerate.ps1 rename {scripts => src/DataSchemas/aind_behavior_vr_foraging}/regenerate.py (99%) diff --git a/pyproject.toml b/pyproject.toml index 8368514..7bc35c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,10 @@ docs = [ 'sphinx-jsonschema' ] +[project.scripts] +clabe = "aind_behavior_vr_foraging.launcher:main" +regenerate = "aind_behavior_vr_foraging.regenerate:main" + [tool.setuptools.packages.find] where = ["src/DataSchemas"] diff --git a/scripts/regenerate.cmd b/scripts/regenerate.cmd deleted file mode 100644 index 7179544..0000000 --- a/scripts/regenerate.cmd +++ /dev/null @@ -1,6 +0,0 @@ -@echo off -setlocal -set "scriptPath=%~dp0" -set "pythonScriptPath=%scriptPath%regenerate.ps1" -powershell -ExecutionPolicy Bypass -File "%pythonScriptPath%" -endlocal diff --git a/scripts/regenerate.ps1 b/scripts/regenerate.ps1 deleted file mode 100644 index 732f62b..0000000 --- a/scripts/regenerate.ps1 +++ /dev/null @@ -1,4 +0,0 @@ -$scriptPath = Split-Path -Parent $MyInvocation.MyCommand.Path -Set-Location -Path (Split-Path -Parent $scriptPath) -.\.venv\Scripts\Activate.ps1 -& python .\scripts\regenerate.py \ No newline at end of file diff --git a/src/DataSchemas/aind_behavior_vr_foraging/launcher.py b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py index 4dd1956..5e247e0 100644 --- a/src/DataSchemas/aind_behavior_vr_foraging/launcher.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py @@ -1,170 +1,51 @@ -import logging -from os import PathLike -from pathlib import Path -from typing import List, Optional, Self, Type - -from aind_behavior_experiment_launcher.apps import app_service -from aind_behavior_experiment_launcher.data_transfer import robocopy_service, watchdog_service -from aind_behavior_experiment_launcher.launcher import Launcher as BaseLauncher -from aind_behavior_experiment_launcher.resource_monitor import resource_monitor_service -from aind_behavior_experiment_launcher.services import Services +import aind_behavior_experiment_launcher.launcher.behavior_launcher as behavior_launcher +from aind_behavior_experiment_launcher.apps.app_service import BonsaiApp +from aind_behavior_experiment_launcher.resource_monitor.resource_monitor_service import ( + ResourceMonitor, + available_storage_constraint_factory, + remote_dir_exists_constraint_factory, +) from aind_behavior_services.session import AindBehaviorSessionModel -from aind_behavior_services.utils import utcnow from aind_behavior_vr_foraging.rig import AindVrForagingRig from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic -from .data_mappers import VrForagingToAindDataSchemaDataMapper - -logger = logging.getLogger(__name__) - - -class VrForagingLauncher(BaseLauncher): - rig_schema_model: Type[AindVrForagingRig] - session_schema_model: Type[AindBehaviorSessionModel] - task_logic_schema_model: Type[AindVrForagingTaskLogic] - - def __init__( - self, - data_dir: PathLike, - config_library_dir: PathLike, - allow_dirty: bool = False, - skip_hardware_validation: bool = False, - debug_mode: bool = False, - services: Services | None = None, - group_by_subject_log: bool = True, - **kwargs, - ): - super().__init__( - rig_schema_model=AindVrForagingRig, - session_schema_model=AindBehaviorSessionModel, - task_logic_schema_model=AindVrForagingTaskLogic, - data_dir=data_dir, - config_library_dir=config_library_dir, - allow_dirty=allow_dirty, - skip_hardware_validation=skip_hardware_validation, - debug_mode=debug_mode, - services=services, - group_by_subject_log=group_by_subject_log, - **kwargs, - ) - - @staticmethod - def _prompt_experimenter() -> List[str]: - experimenter: Optional[List[str]] = None - while experimenter is None: - _user_input = input("Experimenter name: ") - if _user_input: - if not _user_input == "": - experimenter = _user_input.replace(",", " ").split() - else: - logger.error("Experimenter name is not valid.") - return experimenter - @staticmethod - def validate_services(services: Services, *args, **kwargs) -> None: - # Validate services - # Bonsai app is required - if services.app is None: - raise ValueError("Bonsai app not set.") - else: - if not isinstance(services.app, app_service.BonsaiApp): - raise ValueError("Bonsai app is not an instance of BonsaiApp.") - if not services.app.validate(): - raise ValueError("Bonsai app failed to validate.") - else: - logger.info("Bonsai app validated.") - - # data_transfer_service is optional - if services.data_transfer is None: - logger.warning("Data transfer service not set.") - else: - if not services.data_transfer.validate(): - raise ValueError("Data transfer service failed to validate.") - else: - logger.info("Data transfer service validated.") - - # Resource monitor service is optional - if services.resource_monitor is None: - logger.warning("Resource monitor service not set.") - else: - if not services.resource_monitor.validate(): - raise ValueError("Resource monitor service failed to validate.") - else: - logger.info("Resource monitor service validated.") - - # Data mapper service is optional - if services.data_mapper is None: - logger.warning("Data mapper service not set.") - else: - if not isinstance(services.data_mapper, VrForagingToAindDataSchemaDataMapper): - raise ValueError("Data mapper service is not an instance of VrForagingToAindDataSchemaDataMapper.") - if not services.data_mapper.validate(): - raise ValueError("Data mapper service failed to validate.") - else: - logger.info("Data mapper service validated.") +def make_launcher() -> behavior_launcher.BehaviorLauncher: + data_dir = r"C:/Data" + remote_dir = r"\\allen\aind\scratch\vr-foraging\data" + srv = behavior_launcher.BehaviorServicesFactoryManager() + srv.bonsai_app = BonsaiApp(r"./src/vr-foraging.bonsai") + srv.data_transfer = behavior_launcher.robocopy_data_transfer_factory(remote_dir) + srv.resource_monitor = ResourceMonitor( + constrains=[ + available_storage_constraint_factory(data_dir, 2e11), + remote_dir_exists_constraint_factory(remote_dir), + ] + ) - def _post_run_hook(self, *args, **kwargs) -> Self: - self.logger.info("Post-run hook started.") - if self.services.app is None: - raise ValueError("Bonsai app not set.") - self._subject_info = self.subject_info.prompt_field("animal_weight_post", None) - self._subject_info = self.subject_info.prompt_field("reward_consumed_total", None) - try: - self.logger.info("Subject Info: %s", self.subject_info.model_dump_json(indent=4)) - except Exception as e: - self.logger.error("Failed to log subject info. %s", e) + return behavior_launcher.BehaviorLauncher( + rig_schema_model=AindVrForagingRig, + session_schema_model=AindBehaviorSessionModel, + task_logic_schema_model=AindVrForagingTaskLogic, + data_dir=data_dir, + config_library_dir=r"\\allen\aind\scratch\AindBehavior.db\AindVrForaging", + temp_dir=r"./local/.temp", + repository_dir=None, + allow_dirty=False, + skip_hardware_validation=False, + debug_mode=False, + group_by_subject_log=True, + services=srv, + validate_init=True, + ) - mapped = None - if self.services.data_mapper is not None: - if not isinstance(self.services.data_mapper, VrForagingToAindDataSchemaDataMapper): - raise ValueError("Data mapper service is not an instance of VrForagingToAindDataSchemaDataMapper.") - try: - mapped = self.services.data_mapper.map( - schema_root=self.session_directory / "Behavior" / "Logs", - session_model=self.session_schema_model, - rig_model=self.rig_schema_model, - task_logic_model=self.task_logic_schema_model, - repository=self.repository, - script_path=Path(self.services.app.workflow).resolve(), - session_end_time=utcnow(), - subject_info=self.subject_info, - session_directory=self.session_directory, - ) - self.logger.info("Mapping successful.") - except Exception as e: - self.logger.error("Data mapper service has failed: %s", e) - if self.services.data_transfer is not None: - try: - if isinstance(self.services.data_transfer, robocopy_service.RobocopyService): - self.services.data_transfer.transfer( - source=self.session_directory, - destination=self.services.data_transfer.destination / self.session_schema.session_name, - overwrite=False, - force_dir=True, - ) - elif isinstance(self.services.data_transfer, watchdog_service.WatchdogDataTransferService): - self.services.data_transfer.transfer( - session_schema=self.session_schema, - ads_session=mapped, - session_directory=self.session_directory, - ) - else: - raise ValueError( - "Data transfer service is not an instance of RobocopyService or WatchdogDataTransferService." - ) - except Exception as e: - self.logger.error("Data transfer service has failed: %s", e) - return self +def main(): + launcher = make_launcher() + launcher.main() + return None -def default_services_factory(): - return Services( - app=app_service.BonsaiApp(Path(r"./src/vr-foraging.bonsai")), - data_transfer=robocopy_service.RobocopyService(destination=Path(r"\\allen\aind\scratch\vr-foraging\data")), - resource_monitor=resource_monitor_service.ResourceMonitor( - constrains=[resource_monitor_service.available_storage_constraint_factory(drive=r"C:\\", min_bytes=2e11)] - ), - data_mapper=None, - ) +if __name__ == "__main__": + main() diff --git a/scripts/regenerate.py b/src/DataSchemas/aind_behavior_vr_foraging/regenerate.py similarity index 99% rename from scripts/regenerate.py rename to src/DataSchemas/aind_behavior_vr_foraging/regenerate.py index b9db9e2..279045f 100644 --- a/scripts/regenerate.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/regenerate.py @@ -1,8 +1,6 @@ import inspect from pathlib import Path -import aind_behavior_vr_foraging.rig -import aind_behavior_vr_foraging.task_logic from aind_behavior_services.session import AindBehaviorSessionModel from aind_behavior_services.utils import ( convert_pydantic_to_bonsai, @@ -10,6 +8,9 @@ snake_to_pascal_case, ) +import aind_behavior_vr_foraging.rig +import aind_behavior_vr_foraging.task_logic + SCHEMA_ROOT = Path("./src/DataSchemas/") EXTENSIONS_ROOT = Path("./src/Extensions/") NAMESPACE_PREFIX = "AindVrForagingDataSchema" From b1d4a29d884b6767d03fe4c6e91ce6ad9277c93d Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Sat, 26 Oct 2024 14:16:00 -0700 Subject: [PATCH 5/9] Regenerate schemas --- src/DataSchemas/aind_behavior_session_model.json | 2 +- src/DataSchemas/aind_vr_foraging_rig.json | 2 +- src/DataSchemas/aind_vr_foraging_task_logic.json | 2 +- src/Extensions/AindBehaviorSessionModel.cs | 2 +- src/Extensions/AindVrForagingRig.cs | 2 +- src/Extensions/AindVrForagingTaskLogic.cs | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/DataSchemas/aind_behavior_session_model.json b/src/DataSchemas/aind_behavior_session_model.json index e900c6a..fd53b1d 100644 --- a/src/DataSchemas/aind_behavior_session_model.json +++ b/src/DataSchemas/aind_behavior_session_model.json @@ -1,7 +1,7 @@ { "properties": { "aind_behavior_services_pkg_version": { - "default": "0.8.1", + "default": "0.8.2", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", "title": "aind_behavior_services package version", "type": "string" diff --git a/src/DataSchemas/aind_vr_foraging_rig.json b/src/DataSchemas/aind_vr_foraging_rig.json index 35c4fb5..829b0f9 100644 --- a/src/DataSchemas/aind_vr_foraging_rig.json +++ b/src/DataSchemas/aind_vr_foraging_rig.json @@ -2107,7 +2107,7 @@ }, "properties": { "aind_behavior_services_pkg_version": { - "default": "0.8.1", + "default": "0.8.2", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", "title": "aind_behavior_services package version", "type": "string" diff --git a/src/DataSchemas/aind_vr_foraging_task_logic.json b/src/DataSchemas/aind_vr_foraging_task_logic.json index 7b6fe21..aa02256 100644 --- a/src/DataSchemas/aind_vr_foraging_task_logic.json +++ b/src/DataSchemas/aind_vr_foraging_task_logic.json @@ -17,7 +17,7 @@ "title": "Rng Seed" }, "aind_behavior_services_pkg_version": { - "default": "0.8.1", + "default": "0.8.2", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", "title": "aind_behavior_services package version", "type": "string" diff --git a/src/Extensions/AindBehaviorSessionModel.cs b/src/Extensions/AindBehaviorSessionModel.cs index d7d5891..3e1c7e6 100644 --- a/src/Extensions/AindBehaviorSessionModel.cs +++ b/src/Extensions/AindBehaviorSessionModel.cs @@ -15,7 +15,7 @@ namespace AindVrForagingDataSchema.Session public partial class AindBehaviorSessionModel { - private string _aindBehaviorServicesPkgVersion = "0.8.1"; + private string _aindBehaviorServicesPkgVersion = "0.8.2"; private string _version = "0.3.0"; diff --git a/src/Extensions/AindVrForagingRig.cs b/src/Extensions/AindVrForagingRig.cs index 5ce1d97..645830c 100644 --- a/src/Extensions/AindVrForagingRig.cs +++ b/src/Extensions/AindVrForagingRig.cs @@ -5385,7 +5385,7 @@ public override string ToString() public partial class AindVrForagingRig { - private string _aindBehaviorServicesPkgVersion = "0.8.1"; + private string _aindBehaviorServicesPkgVersion = "0.8.2"; private string _version = "0.4.0"; diff --git a/src/Extensions/AindVrForagingTaskLogic.cs b/src/Extensions/AindVrForagingTaskLogic.cs index 46e6350..d4e2b4a 100644 --- a/src/Extensions/AindVrForagingTaskLogic.cs +++ b/src/Extensions/AindVrForagingTaskLogic.cs @@ -17,7 +17,7 @@ public partial class AindVrForagingTaskParameters private double? _rngSeed; - private string _aindBehaviorServicesPkgVersion = "0.8.1"; + private string _aindBehaviorServicesPkgVersion = "0.8.2"; private System.Collections.Generic.IDictionary _updaters; From 0d77e6291a8dea5b2928ed0af58d688183ae9a0f Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 31 Oct 2024 22:10:29 -0700 Subject: [PATCH 6/9] Add aind rig and session mapper --- .../aind_behavior_vr_foraging/data_mappers.py | 326 ++++++++++++++---- 1 file changed, 263 insertions(+), 63 deletions(-) diff --git a/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py b/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py index e0fdd7b..47fa6b1 100644 --- a/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py @@ -2,65 +2,124 @@ import logging import os from pathlib import Path -from typing import Dict, Optional, Type, TypeVar, Union +from typing import Dict, List, Optional, Self, Type, TypeVar, Union import aind_behavior_services.rig as AbsRig import aind_data_schema +import aind_data_schema.base +import aind_data_schema.components.coordinates import aind_data_schema.components.devices +import aind_data_schema.components.stimulus import aind_data_schema.core.session import git import pydantic from aind_behavior_experiment_launcher.data_mappers import data_mapper_service -from aind_behavior_experiment_launcher.records.subject_info import SubjectInfo +from aind_behavior_experiment_launcher.launcher.behavior_launcher import BehaviorLauncher +from aind_behavior_experiment_launcher.records.subject import WaterLogResult from aind_behavior_services.calibration import Calibration +from aind_behavior_services.calibration.olfactometer import OlfactometerChannelType from aind_behavior_services.session import AindBehaviorSessionModel from aind_behavior_services.utils import model_from_json_file, utcnow +from aind_data_schema.core.rig import Rig +from pydantic import BaseModel from aind_behavior_vr_foraging.rig import AindVrForagingRig from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic +TFrom = TypeVar("TFrom", bound=Union[BaseModel, dict]) +TTo = TypeVar("TTo", bound=BaseModel) + T = TypeVar("T") logger = logging.getLogger(__name__) +_DATABASE_DIR = "AindDataSchemaRig" -class VrForagingToAindDataSchemaDataMapper(data_mapper_service.DataMapperService): - def validate(self, *args, **kwargs): - return True - @classmethod - def map( - cls, - *args, - schema_root: os.PathLike, - session_model: Type[AindBehaviorSessionModel], - rig_model: Type[AindVrForagingRig], - task_logic_model: Type[AindVrForagingTaskLogic], +class AindRigDataMapper(data_mapper_service.DataMapperService): + def __init__( + self, + *, + rig_schema_filename: str, + db_root: os.PathLike, + destination_dir: os.PathLike, + db_suffix: Optional[str] = None, + ): + super().__init__() + self.filename = rig_schema_filename + self.db_root = db_root + self.db_dir = db_suffix if db_suffix else f"{_DATABASE_DIR}/{os.environ['COMPUTERNAME']}" + self.target_file = Path(self.db_root) / self.db_dir / self.filename + self.destination_dir = destination_dir + self._mapped: Optional[Rig] = None + + def validate(self): + file_exists = self.target_file.exists() + if not file_exists: + raise FileNotFoundError(f"File {self.target_file} does not exist.") + return file_exists + + def map(self) -> Rig: + self._mapped = model_from_json_file(self.target_file, Rig) + return self.mapped + + @property + def mapped(self) -> Rig: + if self._mapped is None: + raise ValueError("Data has not been mapped yet.") + return self._mapped + + def write_standard_file(self) -> None: + self.mapped.write_standard_file(self.destination_dir) + + +class AindSessionDataMapper(data_mapper_service.DataMapperService): + def __init__( + self, + session_model: AindBehaviorSessionModel, + rig_model: AindVrForagingRig, + task_logic_model: AindVrForagingTaskLogic, repository: Union[os.PathLike, git.Repo], script_path: os.PathLike, session_end_time: Optional[datetime.datetime] = None, output_parameters: Optional[Dict] = None, - subject_info: Optional[SubjectInfo] = None, + subject_info: Optional[WaterLogResult] = None, session_directory: Optional[os.PathLike] = None, - **kwargs, - ) -> Optional[aind_data_schema.core.session.Session]: + ): + self.session_model = session_model + self.rig_model = rig_model + self.task_logic_model = task_logic_model + self.session_directory = session_directory + self.repository = repository + self.script_path = script_path + self.session_end_time = session_end_time + self.output_parameters = output_parameters + self.subject_info = subject_info + self.mapped: Optional[aind_data_schema.core.session.Session] = None + + def validate(self, *args, **kwargs) -> bool: + return True + + def is_mapped(self) -> bool: + return self.mapped is not None + + def map(self) -> Optional[aind_data_schema.core.session.Session]: logger.info("Mapping to aind-data-schema Session") try: - ads_session = cls.map_from_session_root( - schema_root=schema_root, - session_model=session_model, - rig_model=rig_model, - task_logic_model=task_logic_model, - repository=repository, - script_path=script_path, - session_end_time=session_end_time, - output_parameters=output_parameters, - subject_info=subject_info, - **kwargs, + ads_session = self._map( + session_model=self.session_model, + rig_model=self.rig_model, + task_logic_model=self.task_logic_model, + repository=self.repository, + script_path=self.script_path, + session_end_time=self.session_end_time, + output_parameters=self.output_parameters, + subject_info=self.subject_info, ) - if session_directory is not None: - logger.info("Writing session.json to %s", session_directory) - ads_session.write_standard_file(session_directory) + self.mapped = ads_session + if self.session_directory is not None: + logger.info("Writing session.json to %s", self.session_directory) + ads_session.write_standard_file(self.session_directory) logger.info("Mapping successful.") except (pydantic.ValidationError, ValueError, IOError) as e: logger.error("Failed to map to aind-data-schema Session. %s", e) @@ -79,13 +138,42 @@ def map_from_session_root( script_path: os.PathLike, session_end_time: Optional[datetime.datetime] = None, output_parameters: Optional[Dict] = None, - subject_info: Optional[SubjectInfo] = None, - **kwargs, - ) -> aind_data_schema.core.session.Session: - return cls._map( + subject_info: Optional[WaterLogResult] = None, + ) -> Self: + return cls( session_model=model_from_json_file(Path(schema_root) / "session_input.json", session_model), rig_model=model_from_json_file(Path(schema_root) / "rig_input.json", rig_model), task_logic_model=model_from_json_file(Path(schema_root) / "tasklogic_input.json", task_logic_model), + session_directory=schema_root, + repository=repository, + script_path=script_path, + session_end_time=session_end_time if session_end_time else utcnow(), + output_parameters=output_parameters, + subject_info=subject_info, + ) + + @classmethod + def map_from_json_files( + cls, + session_json: os.PathLike, + rig_json: os.PathLike, + task_logic_json: os.PathLike, + session_model: Type[AindBehaviorSessionModel], + rig_model: Type[AindVrForagingRig], + task_logic_model: Type[AindVrForagingTaskLogic], + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime], + session_directory: Optional[os.PathLike] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[WaterLogResult] = None, + **kwargs, + ) -> Self: + return cls( + session_model=model_from_json_file(session_json, session_model), + rig_model=model_from_json_file(rig_json, rig_model), + task_logic_model=model_from_json_file(task_logic_json, task_logic_model), + session_directory=session_directory, repository=repository, script_path=script_path, session_end_time=session_end_time if session_end_time else utcnow(), @@ -104,7 +192,7 @@ def _map( script_path: os.PathLike, session_end_time: Optional[datetime.datetime] = None, output_parameters: Optional[Dict] = None, - subject_info: Optional[SubjectInfo] = None, + subject_info: Optional[WaterLogResult] = None, **kwargs, ) -> aind_data_schema.core.session.Session: # Normalize repository @@ -115,15 +203,14 @@ def _map( repository_relative_script_path = Path(script_path).resolve().relative_to(repository.working_dir) # Populate calibrations: - calibrations = [ - cls._mapper_calibration(_calibration_model[1]) - for _calibration_model in data_mapper_service.get_fields_of_type(rig_model, Calibration) - ] + calibrations = [cls._mapper_calibration(rig_model.calibration.water_valve)] # Populate cameras cameras = data_mapper_service.get_cameras(rig_model, exclude_without_video_writer=True) # populate devices devices = [ - device[0] for device in data_mapper_service.get_fields_of_type(rig_model, AbsRig.Device) if device[0] + device[0] + for device in data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpDeviceGeneric) + if device[0] ] # Populate modalities modalities: list[aind_data_schema.core.session.Modality] = [ @@ -134,37 +221,108 @@ def _map( modalities = list(set(modalities)) # Populate stimulus modalities stimulus_modalities: list[aind_data_schema.core.session.StimulusModality] = [] + stimulation_parameters: List[ + aind_data_schema.core.session.AuditoryStimulation + | aind_data_schema.core.session.OlfactoryStimulation + | aind_data_schema.core.session.VisualStimulation + ] = [] + stimulation_devices: List[str] = [] + # Olfactory Stimulation + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.OLFACTORY) + olfactory_stimulus_channel_config: List[aind_data_schema.components.stimulus.OlfactometerChannelConfig] = [] + for _, channel in rig_model.harp_olfactometer.calibration.input.channel_config.items(): + if channel.channel_type == OlfactometerChannelType.ODOR: + olfactory_stimulus_channel_config.append( + coerce_to_aind_data_schema(channel, aind_data_schema.components.stimulus.OlfactometerChannelConfig) + ) + stimulation_parameters.append( + aind_data_schema.core.session.OlfactoryStimulation( + stimulus_name="Olfactory", channels=olfactory_stimulus_channel_config + ) + ) + + _olfactory_device = data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpOlfactometer) + if len(_olfactory_device) > 0: + if _olfactory_device[0][0]: + stimulation_devices.append(_olfactory_device[0][0]) + else: + logger.error("Olfactometer device not found in rig model.") + raise ValueError("Olfactometer device not found in rig model.") + + # Auditory Stimulation + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.AUDITORY) + + stimulation_parameters.append( + aind_data_schema.core.session.AuditoryStimulation(sitmulus_name="Beep", sample_frequency=0) + ) + speaker_config = aind_data_schema.core.session.SpeakerConfig(name="Speaker", volume=60) + stimulation_devices.append("speaker") + # Visual/VR Stimulation + stimulus_modalities.extend( + [ + aind_data_schema.core.session.StimulusModality.VISUAL, + aind_data_schema.core.session.StimulusModality.VIRTUAL_REALITY, + ] + ) - if data_mapper_service.get_fields_of_type(rig_model, AbsRig.Screen): - stimulus_modalities.extend( - [ - aind_data_schema.core.session.StimulusModality.VISUAL, - aind_data_schema.core.session.StimulusModality.VIRTUAL_REALITY, - ] + stimulation_parameters.append( + aind_data_schema.core.session.VisualStimulation( + stimulus_name="VrScreen", + stimulus_parameters={}, ) - if data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpOlfactometer): - stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.OLFACTORY) - if data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpTreadmill): - stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.WHEEL_FRICTION) + ) + _screen_device = data_mapper_service.get_fields_of_type(rig_model, AbsRig.Screen) + if len(_screen_device) > 0: + if _screen_device[0][0]: + stimulation_devices.append(_screen_device[0][0]) + else: + logger.error("Screen device not found in rig model.") + raise ValueError("Screen device not found in rig model.") + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.WHEEL_FRICTION) # Mouse platform - mouse_platform: str - if isinstance(rig_model.harp_treadmill, AbsRig.HarpTreadmill): - mouse_platform = "Treadmill" - active_mouse_platform = True - else: - raise ValueError("Mouse platform is of unexpected type.") + mouse_platform: str = "wheel" # Reward delivery + if rig_model.manipulator.calibration is None: + logger.error("Manipulator calibration is not set.") + raise ValueError("Manipulator calibration is not set.") + initial_position = rig_model.manipulator.calibration.input.initial_position reward_delivery_config = aind_data_schema.core.session.RewardDeliveryConfig( - reward_solution=aind_data_schema.core.session.RewardSolution.WATER, reward_spouts=[] + reward_solution=aind_data_schema.core.session.RewardSolution.WATER, + reward_spouts=[ + aind_data_schema.core.session.RewardSpoutConfig( + side=aind_data_schema.components.devices.SpoutSide.CENTER, + variable_position=True, + starting_position=aind_data_schema.components.devices.RelativePosition( + device_position_transformations=[ + aind_data_schema.components.coordinates.Translation3dTransform( + translation=[initial_position.x, initial_position.y2, initial_position.z] + ) + ], + device_origin="Manipulator home", + device_axes=[ + aind_data_schema.components.coordinates.Axis( + name=aind_data_schema.components.coordinates.AxisName.X, direction="Left" + ), + aind_data_schema.components.coordinates.Axis( + name=aind_data_schema.components.coordinates.AxisName.Y, direction="Front" + ), + aind_data_schema.components.coordinates.Axis( + name=aind_data_schema.components.coordinates.AxisName.Z, direction="Top" + ), + ], + ), + ) + ], ) + end_time = datetime.datetime.now() + # Construct aind-data-schema session aind_data_schema_session = aind_data_schema.core.session.Session( - animal_weight_post=subject_info.animal_weight_post if subject_info else None, - animal_weight_prior=subject_info.animal_weight_prior if subject_info else None, - reward_consumed_total=subject_info.reward_consumed_total if subject_info else None, + animal_weight_post=subject_info.weight_g if subject_info else None, + reward_consumed_total=subject_info.water_earned_ml if subject_info else None, reward_delivery=reward_delivery_config, experimenter_full_name=session_model.experimenter, session_start_time=session_model.date, @@ -178,19 +336,20 @@ def _map( daq_names=devices, stream_modalities=modalities, stream_start_time=session_model.date, - stream_end_time=session_end_time if session_end_time else session_model.date, + stream_end_time=session_end_time if session_end_time else end_time, camera_names=list(cameras.keys()), ), ], calibrations=calibrations, mouse_platform_name=mouse_platform, - active_mouse_platform=active_mouse_platform, + active_mouse_platform=True, stimulus_epochs=[ aind_data_schema.core.session.StimulusEpoch( stimulus_name=session_model.experiment, stimulus_start_time=session_model.date, - stimulus_end_time=session_end_time if session_end_time else session_model.date, + stimulus_end_time=session_end_time if session_end_time else end_time, stimulus_modalities=stimulus_modalities, + stimulus_parameters=stimulation_parameters, software=[ aind_data_schema.core.session.Software( name="Bonsai", @@ -214,6 +373,9 @@ def _map( parameters=task_logic_model.model_dump(), ), output_parameters=output_parameters if output_parameters else {}, + speaker_config=speaker_config, + reward_consumed_during_epoch=subject_info.total_water_ml if subject_info else None, + stimulus_device_names=stimulation_devices, ) # type: ignore ], ) # type: ignore @@ -229,3 +391,41 @@ def _mapper_calibration(calibration: Calibration) -> aind_data_schema.components description=calibration.description if calibration.description else "", notes=calibration.notes, ) + + +def coerce_to_aind_data_schema(value: TFrom, target_type: Type[TTo]) -> TTo: + _normalized_input: dict + if isinstance(value, BaseModel): + _normalized_input = value.model_dump() + elif isinstance(value, dict): + _normalized_input = value + else: + raise ValueError(f"Expected value to be a BaseModel or a dict, got {type(value)}") + target_fields = target_type.model_fields + _normalized_input = {k: v for k, v in _normalized_input.items() if k in target_fields} + return target_type(**_normalized_input) + + +def aind_session_data_mapper_factory(launcher: BehaviorLauncher) -> AindSessionDataMapper: + now = utcnow() + return AindSessionDataMapper( + session_model=launcher.session_schema, + rig_model=launcher.rig_schema, + task_logic_model=launcher.task_logic_schema, + repository=launcher.repository, + script_path=launcher.services_factory_manager.bonsai_app.workflow, + session_directory=launcher.session_directory, + session_end_time=now, + ) + + +def aind_rig_data_mapper_factory( + launcher: BehaviorLauncher[AindVrForagingRig, AindBehaviorSessionModel, AindVrForagingTaskLogic], +) -> AindRigDataMapper: + rig_schema: AindVrForagingRig = launcher.rig_schema + return AindRigDataMapper( + rig_schema_filename=rig_schema.rig_name, + db_suffix=f"{_DATABASE_DIR}/{launcher.computer_name}", + db_root=launcher.config_library_dir, + destination_dir=launcher.session_directory, + ) From 549a57801d64105d087b57bd23c38976130ddcd2 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 31 Oct 2024 22:11:58 -0700 Subject: [PATCH 7/9] Make olfactometer calibration required --- src/DataSchemas/aind_behavior_vr_foraging/rig.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DataSchemas/aind_behavior_vr_foraging/rig.py b/src/DataSchemas/aind_behavior_vr_foraging/rig.py index 346068f..a392122 100644 --- a/src/DataSchemas/aind_behavior_vr_foraging/rig.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/rig.py @@ -39,7 +39,7 @@ class AindManipulatorDevice(aind_manipulator.AindManipulatorDevice): class HarpOlfactometer(rig.HarpOlfactometer): """Overrides the default settings for the olfactometer calibration""" - calibration: Optional[oc.OlfactometerCalibration] = Field(default=None, description="Olfactometer calibration") + calibration: oc.OlfactometerCalibration = Field(default=None, description="Olfactometer calibration") class RigCalibration(BaseModel): From 832392db4c70d64dd59f8b998c56f53e67a648da Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 31 Oct 2024 22:18:23 -0700 Subject: [PATCH 8/9] Add tests for aind_data_mapper --- tests/__init__.py | 5 ++ tests/test_aind_data_mapper.py | 108 ++++++++++++++++++++++++++++----- tests/test_bonsai.py | 13 ++-- 3 files changed, 105 insertions(+), 21 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 2bc553f..f7760d5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,11 +1,16 @@ import glob import importlib.util +import logging from pathlib import Path from types import ModuleType EXAMPLES_DIR = Path(__file__).parents[1] / "examples" JSON_ROOT = Path("./local").resolve() +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) +logging.disable(logging.CRITICAL) + def build_example(script_path: str) -> ModuleType: module_name = Path(script_path).stem diff --git a/tests/test_aind_data_mapper.py b/tests/test_aind_data_mapper.py index 92bcb1e..300ff6a 100644 --- a/tests/test_aind_data_mapper.py +++ b/tests/test_aind_data_mapper.py @@ -1,24 +1,100 @@ import sys - -sys.path.append(".") -import datetime import unittest +from datetime import datetime from pathlib import Path +from unittest.mock import MagicMock, patch -from aind_behavior_vr_foraging.data_mappers import VrForagingToAindDataSchemaDataMapper +from aind_behavior_vr_foraging.data_mappers import ( + AindBehaviorSessionModel, + AindRigDataMapper, + AindSessionDataMapper, + AindVrForagingRig, + AindVrForagingTaskLogic, +) +from aind_data_schema.core.rig import Rig +from git import Repo + +sys.path.append(".") +from examples.examples import mock_rig, mock_session, mock_task_logic # isort:skip # pylint: disable=wrong-import-position -from examples.examples import mock_rig, mock_session, mock_task_logic +class TestAindSessionDataMapper(unittest.TestCase): + def setUp(self): + self.session_model = mock_session() + self.rig_model = mock_rig() + self.task_logic_model = mock_task_logic() + self.repository = Repo(Path("./")) + self.script_path = Path("./src/vr-foraging.bonsai") + self.session_end_time = datetime.now() + self.session_directory = Path("./") -class AindServicesTests(unittest.TestCase): - def test_session_mapper(self): - VrForagingToAindDataSchemaDataMapper.map( - schema_root=None, - session_model=mock_session(), - rig_model=mock_rig(), - task_logic_model=mock_task_logic(), - session_end_time=datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc), - repository=Path("./"), - script_path=Path("./src/unit_test.bonsai"), - bonsai_config_path=Path("./tests/assets/bonsai.config").resolve(), + self.mapper = AindSessionDataMapper( + session_model=self.session_model, + rig_model=self.rig_model, + task_logic_model=self.task_logic_model, + repository=self.repository, + script_path=self.script_path, + session_end_time=self.session_end_time, + session_directory=self.session_directory, ) + + def test_validate(self): + self.assertTrue(self.mapper.validate()) + + @patch("aind_behavior_vr_foraging.data_mappers.logger") + @patch("aind_behavior_vr_foraging.data_mappers.AindSessionDataMapper._map") + def test_mock_map(self, mock_map, mock_logger): + mock_map.return_value = MagicMock() + result = self.mapper.map() + self.assertIsNotNone(result) + self.assertTrue(self.mapper.is_mapped()) + mock_logger.info.assert_called_with("Mapping successful.") + + def test_map(self): + mapped = self.mapper.map() + self.assertIsNotNone(mapped) + + @patch("aind_behavior_vr_foraging.data_mappers.model_from_json_file") + def test_map_from_json_files(self, mock_model_from_json_file): + mock_model_from_json_file.side_effect = [self.session_model, self.rig_model, self.task_logic_model] + session_json = MagicMock() + rig_json = MagicMock() + task_logic_json = MagicMock() + mapper = AindSessionDataMapper.map_from_json_files( + session_json=session_json, + rig_json=rig_json, + task_logic_json=task_logic_json, + session_model=AindBehaviorSessionModel, + rig_model=AindVrForagingRig, + task_logic_model=AindVrForagingTaskLogic, + repository=self.repository, + script_path=self.script_path, + session_end_time=self.session_end_time, + ) + self.assertIsInstance(mapper, AindSessionDataMapper) + + +class TestAindRigDataMapper(unittest.TestCase): + def setUp(self): + self.rig_schema_filename = "rig_schema.json" + self.db_root = MagicMock() + self.destination_dir = MagicMock() + self.db_suffix = "test_suffix" + self.mapper = AindRigDataMapper( + rig_schema_filename=self.rig_schema_filename, + db_root=self.db_root, + destination_dir=self.destination_dir, + db_suffix=self.db_suffix, + ) + + @patch("aind_behavior_vr_foraging.data_mappers.model_from_json_file") + def test_mock_map(self, mock_model_from_json_file): + mock_model_from_json_file.return_value = MagicMock(spec=Rig) + result = self.mapper.map() + self.assertIsNotNone(result) + self.assertTrue(self.mapper.mapped) + self.assertIsInstance(result, Rig) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bonsai.py b/tests/test_bonsai.py index c67b93e..74d716d 100644 --- a/tests/test_bonsai.py +++ b/tests/test_bonsai.py @@ -1,6 +1,7 @@ import os import sys import unittest +import warnings from pathlib import Path from typing import Generic, List, Optional, TypeVar, Union @@ -43,11 +44,13 @@ def test_deserialization(self): stdout = completed_proc.stdout.decode().split("\n") stdout = [line for line in stdout if (line or line != "")] - for model in models_to_test: - try: - model.try_deserialization(stdout) - except ValueError: - self.fail(f"Could not find a match for {model.input_model.__class__.__name__}.") + with warnings.catch_warnings(): # suppress the warnings relative to the coercion of version across schemas + warnings.simplefilter("ignore") + for model in models_to_test: + try: + model.try_deserialization(stdout) + except ValueError: + self.fail(f"Could not find a match for {model.input_model.__class__.__name__}.") class TestModel(Generic[TModel]): From fba987705577a3e7990367f6953cc027ef3cadca Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 31 Oct 2024 22:23:07 -0700 Subject: [PATCH 9/9] Get launcher from pypi --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7bc35c5..1cb978c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aind-behavior-vr-foraging" -description = "A library that defines AIND data schema for the Aind Behavior VR Foraing experiment." +description = "A library that defines AIND data schema for the Aind Behavior VR Foraging experiment." authors = [ {name = "Bruno Cruz", email = "bruno.cruz@alleninstitute.org"}] license = {text = "MIT"} requires-python = ">=3.11" @@ -15,7 +15,7 @@ readme = "README.md" dynamic = ["version"] dependencies = [ - "aind_behavior_services>=0.8.0", + "aind_behavior_services>=0.8, <0.9", ] [project.optional-dependencies] @@ -25,7 +25,7 @@ linters = [ 'codespell' ] -launcher = ["aind_behavior_experiment_launcher[aind-services]@git+https://github.com/AllenNeuralDynamics/Aind.Behavior.ExperimentLauncher@main"] +launcher = ["aind_behavior_experiment_launcher[aind-services]>=0.2.0rc4"] docs = [ 'Sphinx<7.3',