Skip to content

Commit

Permalink
Fix AI tests using PASSES_RANDOMLY (#5486)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pawkkie authored Oct 8, 2024
1 parent f8f4fc9 commit efad9a3
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 68 deletions.
14 changes: 11 additions & 3 deletions include/test/battle.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,15 @@ struct AILogLine
s16 score;
};

// Data which is updated by the test runner during a battle and needs to
// be reset between trials.
struct BattleTrialData
{
u8 lastActionTurn;
u8 queuedEvent;
u8 aiActionsPlayed[MAX_BATTLERS_COUNT];
};

struct BattleTestData
{
u8 stack[BATTLE_TEST_STACK_SIZE];
Expand Down Expand Up @@ -676,20 +685,19 @@ struct BattleTestData
u8 battleRecordSourceLineOffsets[MAX_BATTLERS_COUNT][BATTLER_RECORD_SIZE];
u16 recordIndexes[MAX_BATTLERS_COUNT];
struct BattlerTurn battleRecordTurns[MAX_TURNS][MAX_BATTLERS_COUNT];
u8 lastActionTurn;

u8 queuedEventsCount;
u8 queueGroupType;
u8 queueGroupStart;
u8 queuedEvent;
struct QueuedEvent queuedEvents[MAX_QUEUED_EVENTS];
u8 expectedAiActionIndex[MAX_BATTLERS_COUNT];
u8 aiActionsPlayed[MAX_BATTLERS_COUNT];
struct ExpectedAIAction expectedAiActions[MAX_BATTLERS_COUNT][MAX_EXPECTED_ACTIONS];
struct ExpectedAiScore expectedAiScores[MAX_BATTLERS_COUNT][MAX_TURNS][MAX_AI_SCORE_COMPARISION_PER_TURN]; // Max 4 comparisions per turn
struct AILogLine aiLogLines[MAX_BATTLERS_COUNT][MAX_MON_MOVES][MAX_AI_LOG_LINES];
u8 aiLogPrintedForMove[MAX_BATTLERS_COUNT]; // Marks ai score log as printed for move, so the same log isn't displayed multiple times.
u16 flagId;

struct BattleTrialData trial;
};

struct BattleTestRunnerState
Expand Down
14 changes: 2 additions & 12 deletions test/battle/ai/ai_switching.c
Original file line number Diff line number Diff line change
Expand Up @@ -378,28 +378,18 @@ AI_SINGLE_BATTLE_TEST("AI_FLAG_SMART_SWITCHING: AI will switch out if mon would

AI_SINGLE_BATTLE_TEST("Switch AI: AI will switch out if it can't deal damage to a mon with Wonder Guard 66% of the time")
{
u32 aiOmniscientFlag = 0;
PARAMETRIZE { aiOmniscientFlag = 0; }
PARAMETRIZE { aiOmniscientFlag = AI_FLAG_OMNISCIENT; }
PASSES_RANDOMLY(66, 100, RNG_AI_SWITCH_WONDER_GUARD);
GIVEN {
ASSUME(gSpeciesInfo[SPECIES_SHEDINJA].types[0] == TYPE_BUG);
ASSUME(gSpeciesInfo[SPECIES_SHEDINJA].types[1] == TYPE_GHOST);
ASSUME(gMovesInfo[MOVE_TACKLE].type == TYPE_NORMAL);
ASSUME(gMovesInfo[MOVE_SHADOW_BALL].type == TYPE_GHOST);
AI_FLAGS(AI_FLAG_CHECK_BAD_MOVE | AI_FLAG_CHECK_VIABILITY | AI_FLAG_TRY_TO_FAINT | aiOmniscientFlag);
AI_FLAGS(AI_FLAG_CHECK_BAD_MOVE | AI_FLAG_CHECK_VIABILITY | AI_FLAG_TRY_TO_FAINT);
PLAYER(SPECIES_SHEDINJA) { Moves(MOVE_TACKLE); }
OPPONENT(SPECIES_ZIGZAGOON) { Moves(MOVE_TACKLE); }
OPPONENT(SPECIES_ZIGZAGOON) { Moves(MOVE_SHADOW_BALL); }
} WHEN {
if(aiOmniscientFlag == 0) {
TURN { MOVE(player, MOVE_TACKLE) ; EXPECT_MOVE(opponent, MOVE_TACKLE); }
TURN { MOVE(player, MOVE_TACKLE) ; EXPECT_SWITCH(opponent, 1); }
}
else {
TURN { MOVE(player, MOVE_TACKLE) ; EXPECT_SWITCH(opponent, 1); }
}

TURN { MOVE(player, MOVE_TACKLE) ; EXPECT_SWITCH(opponent, 1); }
}
}

Expand Down
105 changes: 52 additions & 53 deletions test/test_runner_battle.c
Original file line number Diff line number Diff line change
Expand Up @@ -567,19 +567,19 @@ void TestRunner_Battle_RecordAbilityPopUp(u32 battlerId, u32 ability)
s32 match;
struct QueuedEvent *event;

if (DATA.queuedEvent == DATA.queuedEventsCount)
if (DATA.trial.queuedEvent == DATA.queuedEventsCount)
return;

event = &DATA.queuedEvents[DATA.queuedEvent];
event = &DATA.queuedEvents[DATA.trial.queuedEvent];
switch (event->groupType)
{
case QUEUE_GROUP_NONE:
case QUEUE_GROUP_ONE_OF:
if (TryAbilityPopUp(DATA.queuedEvent, event->groupSize, battlerId, ability) != -1)
DATA.queuedEvent += event->groupSize;
if (TryAbilityPopUp(DATA.trial.queuedEvent, event->groupSize, battlerId, ability) != -1)
DATA.trial.queuedEvent += event->groupSize;
break;
case QUEUE_GROUP_NONE_OF:
queuedEvent = DATA.queuedEvent;
queuedEvent = DATA.trial.queuedEvent;
do
{
if ((match = TryAbilityPopUp(queuedEvent, event->groupSize, battlerId, ability)) != -1)
Expand All @@ -598,7 +598,7 @@ void TestRunner_Battle_RecordAbilityPopUp(u32 battlerId, u32 ability)
continue;

if (TryAbilityPopUp(queuedEvent, event->groupSize, battlerId, ability) != -1)
DATA.queuedEvent = queuedEvent + event->groupSize;
DATA.trial.queuedEvent = queuedEvent + event->groupSize;
} while (FALSE);
break;
}
Expand Down Expand Up @@ -630,19 +630,19 @@ void TestRunner_Battle_RecordAnimation(u32 animType, u32 animId)
s32 match;
struct QueuedEvent *event;

if (DATA.queuedEvent == DATA.queuedEventsCount)
if (DATA.trial.queuedEvent == DATA.queuedEventsCount)
return;

event = &DATA.queuedEvents[DATA.queuedEvent];
event = &DATA.queuedEvents[DATA.trial.queuedEvent];
switch (event->groupType)
{
case QUEUE_GROUP_NONE:
case QUEUE_GROUP_ONE_OF:
if (TryAnimation(DATA.queuedEvent, event->groupSize, animType, animId) != -1)
DATA.queuedEvent += event->groupSize;
if (TryAnimation(DATA.trial.queuedEvent, event->groupSize, animType, animId) != -1)
DATA.trial.queuedEvent += event->groupSize;
break;
case QUEUE_GROUP_NONE_OF:
queuedEvent = DATA.queuedEvent;
queuedEvent = DATA.trial.queuedEvent;
do
{
if ((match = TryAnimation(queuedEvent, event->groupSize, animType, animId)) != -1)
Expand All @@ -661,7 +661,7 @@ void TestRunner_Battle_RecordAnimation(u32 animType, u32 animId)
continue;

if (TryAnimation(queuedEvent, event->groupSize, animType, animId) != -1)
DATA.queuedEvent = queuedEvent + event->groupSize;
DATA.trial.queuedEvent = queuedEvent + event->groupSize;
} while (FALSE);
break;
}
Expand Down Expand Up @@ -720,19 +720,19 @@ void TestRunner_Battle_RecordHP(u32 battlerId, u32 oldHP, u32 newHP)
s32 match;
struct QueuedEvent *event;

if (DATA.queuedEvent == DATA.queuedEventsCount)
if (DATA.trial.queuedEvent == DATA.queuedEventsCount)
return;

event = &DATA.queuedEvents[DATA.queuedEvent];
event = &DATA.queuedEvents[DATA.trial.queuedEvent];
switch (event->groupType)
{
case QUEUE_GROUP_NONE:
case QUEUE_GROUP_ONE_OF:
if (TryHP(DATA.queuedEvent, event->groupSize, battlerId, oldHP, newHP) != -1)
DATA.queuedEvent += event->groupSize;
if (TryHP(DATA.trial.queuedEvent, event->groupSize, battlerId, oldHP, newHP) != -1)
DATA.trial.queuedEvent += event->groupSize;
break;
case QUEUE_GROUP_NONE_OF:
queuedEvent = DATA.queuedEvent;
queuedEvent = DATA.trial.queuedEvent;
do
{
if ((match = TryHP(queuedEvent, event->groupSize, battlerId, oldHP, newHP)) != -1)
Expand All @@ -751,7 +751,7 @@ void TestRunner_Battle_RecordHP(u32 battlerId, u32 oldHP, u32 newHP)
continue;

if (TryHP(queuedEvent, event->groupSize, battlerId, oldHP, newHP) != -1)
DATA.queuedEvent = queuedEvent + event->groupSize;
DATA.trial.queuedEvent = queuedEvent + event->groupSize;
} while (FALSE);
break;
}
Expand Down Expand Up @@ -782,7 +782,7 @@ static u32 CountAiExpectMoves(struct ExpectedAIAction *expectedAction, u32 battl
void TestRunner_Battle_CheckChosenMove(u32 battlerId, u32 moveId, u32 target)
{
const char *filename = gTestRunnerState.test->filename;
u32 id = DATA.aiActionsPlayed[battlerId];
u32 id = DATA.trial.aiActionsPlayed[battlerId];
struct ExpectedAIAction *expectedAction = &DATA.expectedAiActions[battlerId][id];

if (!expectedAction->actionSet)
Expand Down Expand Up @@ -845,13 +845,13 @@ void TestRunner_Battle_CheckChosenMove(u32 battlerId, u32 moveId, u32 target)
}
// Turn passed, clear logs from the turn
ClearAiLog(battlerId);
DATA.aiActionsPlayed[battlerId]++;
DATA.trial.aiActionsPlayed[battlerId]++;
}

void TestRunner_Battle_CheckSwitch(u32 battlerId, u32 partyIndex)
{
const char *filename = gTestRunnerState.test->filename;
u32 id = DATA.aiActionsPlayed[battlerId];
u32 id = DATA.trial.aiActionsPlayed[battlerId];
struct ExpectedAIAction *expectedAction = &DATA.expectedAiActions[battlerId][id];

if (!expectedAction->actionSet)
Expand All @@ -865,7 +865,7 @@ void TestRunner_Battle_CheckSwitch(u32 battlerId, u32 partyIndex)
if (expectedAction->target != partyIndex)
Test_ExitWithResult(TEST_RESULT_FAIL, SourceLine(0), ":L%s:%d: Expected partyIndex %d, got %d", filename, expectedAction->sourceLine, expectedAction->target, partyIndex);
}
DATA.aiActionsPlayed[battlerId]++;
DATA.trial.aiActionsPlayed[battlerId]++;
}

void TestRunner_Battle_InvalidNoHPMon(u32 battlerId, u32 partyIndex)
Expand Down Expand Up @@ -1029,7 +1029,7 @@ void TestRunner_Battle_CheckAiMoveScores(u32 battlerId)
}

// We need to make sure that the expected move has the best score. We have to rule out a situation where the expected move is used, but it has the same number of points as some other moves.
aiAction = &DATA.expectedAiActions[battlerId][DATA.aiActionsPlayed[battlerId]];
aiAction = &DATA.expectedAiActions[battlerId][DATA.trial.aiActionsPlayed[battlerId]];
if (aiAction->actionSet && !aiAction->pass)
{
s32 target = aiAction->target;
Expand Down Expand Up @@ -1102,19 +1102,19 @@ void TestRunner_Battle_RecordExp(u32 battlerId, u32 oldExp, u32 newExp)
s32 match;
struct QueuedEvent *event;

if (DATA.queuedEvent == DATA.queuedEventsCount)
if (DATA.trial.queuedEvent == DATA.queuedEventsCount)
return;

event = &DATA.queuedEvents[DATA.queuedEvent];
event = &DATA.queuedEvents[DATA.trial.queuedEvent];
switch (event->groupType)
{
case QUEUE_GROUP_NONE:
case QUEUE_GROUP_ONE_OF:
if (TryExp(DATA.queuedEvent, event->groupSize, battlerId, oldExp, newExp) != -1)
DATA.queuedEvent += event->groupSize;
if (TryExp(DATA.trial.queuedEvent, event->groupSize, battlerId, oldExp, newExp) != -1)
DATA.trial.queuedEvent += event->groupSize;
break;
case QUEUE_GROUP_NONE_OF:
queuedEvent = DATA.queuedEvent;
queuedEvent = DATA.trial.queuedEvent;
do
{
if ((match = TryExp(queuedEvent, event->groupSize, battlerId, oldExp, newExp)) != -1)
Expand All @@ -1133,7 +1133,7 @@ void TestRunner_Battle_RecordExp(u32 battlerId, u32 oldExp, u32 newExp)
continue;

if (TryExp(queuedEvent, event->groupSize, battlerId, oldExp, newExp) != -1)
DATA.queuedEvent = queuedEvent + event->groupSize;
DATA.trial.queuedEvent = queuedEvent + event->groupSize;
} while (FALSE);
break;
}
Expand Down Expand Up @@ -1191,19 +1191,19 @@ void TestRunner_Battle_RecordMessage(const u8 *string)
s32 match;
struct QueuedEvent *event;

if (DATA.queuedEvent == DATA.queuedEventsCount)
if (DATA.trial.queuedEvent == DATA.queuedEventsCount)
return;

event = &DATA.queuedEvents[DATA.queuedEvent];
event = &DATA.queuedEvents[DATA.trial.queuedEvent];
switch (event->groupType)
{
case QUEUE_GROUP_NONE:
case QUEUE_GROUP_ONE_OF:
if (TryMessage(DATA.queuedEvent, event->groupSize, string) != -1)
DATA.queuedEvent += event->groupSize;
if (TryMessage(DATA.trial.queuedEvent, event->groupSize, string) != -1)
DATA.trial.queuedEvent += event->groupSize;
break;
case QUEUE_GROUP_NONE_OF:
queuedEvent = DATA.queuedEvent;
queuedEvent = DATA.trial.queuedEvent;
do
{
if ((match = TryMessage(queuedEvent, event->groupSize, string)) != -1)
Expand All @@ -1222,7 +1222,7 @@ void TestRunner_Battle_RecordMessage(const u8 *string)
continue;

if (TryMessage(queuedEvent, event->groupSize, string) != -1)
DATA.queuedEvent = queuedEvent + event->groupSize;
DATA.trial.queuedEvent = queuedEvent + event->groupSize;
} while (FALSE);
break;
}
Expand Down Expand Up @@ -1256,19 +1256,19 @@ void TestRunner_Battle_RecordStatus1(u32 battlerId, u32 status1)
s32 match;
struct QueuedEvent *event;

if (DATA.queuedEvent == DATA.queuedEventsCount)
if (DATA.trial.queuedEvent == DATA.queuedEventsCount)
return;

event = &DATA.queuedEvents[DATA.queuedEvent];
event = &DATA.queuedEvents[DATA.trial.queuedEvent];
switch (event->groupType)
{
case QUEUE_GROUP_NONE:
case QUEUE_GROUP_ONE_OF:
if (TryStatus(DATA.queuedEvent, event->groupSize, battlerId, status1) != -1)
DATA.queuedEvent += event->groupSize;
if (TryStatus(DATA.trial.queuedEvent, event->groupSize, battlerId, status1) != -1)
DATA.trial.queuedEvent += event->groupSize;
break;
case QUEUE_GROUP_NONE_OF:
queuedEvent = DATA.queuedEvent;
queuedEvent = DATA.trial.queuedEvent;
do
{
if ((match = TryStatus(queuedEvent, event->groupSize, battlerId, status1)) != -1)
Expand All @@ -1287,7 +1287,7 @@ void TestRunner_Battle_RecordStatus1(u32 battlerId, u32 status1)
continue;

if (TryStatus(queuedEvent, event->groupSize, battlerId, status1) != -1)
DATA.queuedEvent = queuedEvent + event->groupSize;
DATA.trial.queuedEvent = queuedEvent + event->groupSize;
} while (FALSE);
break;
}
Expand All @@ -1307,22 +1307,22 @@ void TestRunner_Battle_AfterLastTurn(void)
{
const struct BattleTest *test = GetBattleTest();

if (DATA.turns - 1 != DATA.lastActionTurn)
if (DATA.turns - 1 != DATA.trial.lastActionTurn)
{
const char *filename = gTestRunnerState.test->filename;
Test_ExitWithResult(TEST_RESULT_FAIL, SourceLine(0), ":L%s:%d: %d TURNs specified, but %d ran", filename, SourceLine(0), DATA.turns, DATA.lastActionTurn + 1);
Test_ExitWithResult(TEST_RESULT_FAIL, SourceLine(0), ":L%s:%d: %d TURNs specified, but %d ran", filename, SourceLine(0), DATA.turns, DATA.trial.lastActionTurn + 1);
}

while (DATA.queuedEvent < DATA.queuedEventsCount
&& DATA.queuedEvents[DATA.queuedEvent].groupType == QUEUE_GROUP_NONE_OF)
while (DATA.trial.queuedEvent < DATA.queuedEventsCount
&& DATA.queuedEvents[DATA.trial.queuedEvent].groupType == QUEUE_GROUP_NONE_OF)
{
DATA.queuedEvent += DATA.queuedEvents[DATA.queuedEvent].groupSize;
DATA.trial.queuedEvent += DATA.queuedEvents[DATA.trial.queuedEvent].groupSize;
}
if (DATA.queuedEvent != DATA.queuedEventsCount)
if (DATA.trial.queuedEvent != DATA.queuedEventsCount)
{
const char *filename = gTestRunnerState.test->filename;
u32 line = SourceLine(DATA.queuedEvents[DATA.queuedEvent].sourceLineOffset);
const char *macro = sEventTypeMacros[DATA.queuedEvents[DATA.queuedEvent].type];
u32 line = SourceLine(DATA.queuedEvents[DATA.trial.queuedEvent].sourceLineOffset);
const char *macro = sEventTypeMacros[DATA.queuedEvents[DATA.trial.queuedEvent].type];
Test_ExitWithResult(TEST_RESULT_FAIL, line, ":L%s:%d: Unmatched %s", filename, line, macro);
}

Expand Down Expand Up @@ -1395,8 +1395,7 @@ static void CB2_BattleTest_NextTrial(void)
PrintTestName();
gTestRunnerState.result = TEST_RESULT_PASS;
DATA.recordedBattle.rngSeed = MakeRngValue(STATE->runTrial);
DATA.queuedEvent = 0;
DATA.lastActionTurn = 0;
memset(&DATA.trial, 0, sizeof(DATA.trial));
SetVariablesForRecordedBattle(&DATA.recordedBattle);
SetMainCallback2(CB2_InitBattle);
}
Expand Down Expand Up @@ -1892,7 +1891,7 @@ void TestRunner_Battle_CheckBattleRecordActionType(u32 battlerId, u32 recordInde

if (DATA.battleRecordTypes[battlerId][recordIndex] != RECORDED_BYTE)
{
DATA.lastActionTurn = gBattleResults.battleTurnCounter;
DATA.trial.lastActionTurn = gBattleResults.battleTurnCounter;

if (actionType != DATA.battleRecordTypes[battlerId][recordIndex])
{
Expand Down Expand Up @@ -1926,7 +1925,7 @@ void TestRunner_Battle_CheckBattleRecordActionType(u32 battlerId, u32 recordInde
}
else
{
if (DATA.lastActionTurn == gBattleResults.battleTurnCounter)
if (DATA.trial.lastActionTurn == gBattleResults.battleTurnCounter)
{
const char *filename = gTestRunnerState.test->filename;
Test_ExitWithResult(TEST_RESULT_FAIL, SourceLine(0), ":L%s:%d: TURN %d incomplete", filename, SourceLine(0), gBattleResults.battleTurnCounter + 1);
Expand Down

0 comments on commit efad9a3

Please sign in to comment.