Skip to content

Commit

Permalink
adds return data to completions + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
p-ferreira committed Sep 14, 2023
1 parent 1806039 commit a02b864
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
4 changes: 4 additions & 0 deletions openvalidators/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
class EventSchema:
completions: List[str] # List of completions received for a given prompt
completion_times: List[float] # List of completion times for a given prompt
completion_return_messages: List[str] # List of completion return messages for a given prompt
completion_return_codes: List[str] # List of completion return codes for a given prompt
name: str # Prompt type, e.g. 'followup', 'answer'
block: float # Current block at given step
gating_loss: float # Gating model loss for given step
Expand Down Expand Up @@ -95,6 +97,8 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> 'EventSchema':
return EventSchema(
completions=event_dict['completions'],
completion_times=event_dict['completion_times'],
completion_return_messages=event_dict['completion_return_messages'],
completion_return_codes=event_dict['completion_return_codes'],
name=event_dict['name'],
block=event_dict['block'],
gating_loss=event_dict['gating_loss'],
Expand Down
4 changes: 4 additions & 0 deletions openvalidators/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ async def run_step(self, prompt: str, k: int, timeout: float, name: str, exclude

# Get completion times
completion_times: List[float] = [comp.elapsed_time for comp in responses]
completion_return_messages: List[str] = [comp.return_message for comp in responses]
completion_return_codes: List[str] = [comp.return_code for comp in responses]

# Compute forward pass rewards, assumes followup_uids and answer_uids are mutually exclusive.
# shape: [ metagraph.n ]
Expand All @@ -133,6 +135,8 @@ async def run_step(self, prompt: str, k: int, timeout: float, name: str, exclude
"uids": uids.tolist(),
"completions": completions,
"completion_times": completion_times,
"completion_return_messages": completion_return_messages,
"completion_return_codes": completion_return_codes,
"rewards": rewards.tolist(),
"gating_loss": gating_loss.item(),
"best": best,
Expand Down
2 changes: 2 additions & 0 deletions openvalidators/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def reward(

class MockDendriteResponse:
completion = ""
return_message = "Success"
return_code = "1"
elapsed_time = 0
is_success = True
firewall_prompt = FirewallPrompt()
Expand Down
10 changes: 8 additions & 2 deletions tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def test_event_from_dict_all_forward_columns_match(self):
event_dict = {
'completions': ['test'],
'completion_times': [0.123],
'completion_return_messages': ['Success'],
'completion_return_codes': ['1'],
'name': 'test-name',
'block': 1.0,
'gating_loss': 1.0,
Expand Down Expand Up @@ -85,7 +87,9 @@ def test_event_from_dict_forward_no_reward_logging(self):
# Arrange: create a dictionary with all non-reward-related columns
event_dict = {
'completions': ['test'],
'completion_times': [0.123],
'completion_times': [0.123],
'completion_return_messages': ['Success'],
'completion_return_codes': ['1'],
'name': 'test-name',
'block': 1.0,
'gating_loss': 1.0,
Expand Down Expand Up @@ -134,7 +138,9 @@ def test_event_from_dict_forward_reward_logging_mismatch(self):
# Arrange: create a dictionary with all non-reward-related columns
event_dict = {
'completions': ['test'],
'completion_times': [0.123],
'completion_times': [0.123],
'completion_return_messages': ['Success'],
'completion_return_codes': ['1'],
'name': 'test-name',
'block': 1.0,
'gating_loss': 1.0,
Expand Down

0 comments on commit a02b864

Please sign in to comment.