Skip to content

Commit

Permalink
[Question Answering] Cache support for Model Overview Metrics (#2166)
Browse files Browse the repository at this point in the history
  • Loading branch information
Advitya17 authored Jul 7, 2023
1 parent 4506b0d commit d882424
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 94 deletions.
6 changes: 5 additions & 1 deletion apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
};
callBack.requestQuestionAnsweringMetrics = async (
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>,
abortSignal: AbortSignal
): Promise<any[]> => {
return callFlaskService(
this.props.config,
[selectionIndexes],
[selectionIndexes, questionAnsweringCache],
"/get_question_answering_metrics",
abortSignal
);
Expand Down
4 changes: 4 additions & 0 deletions libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ export interface IModelAssessmentContext {
requestQuestionAnsweringMetrics?:
| ((
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>,
abortSignal: AbortSignal
) => Promise<any[]>)
| undefined;
Expand Down
70 changes: 70 additions & 0 deletions libs/core-ui/src/lib/util/ImageStatisticsUtils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { localization } from "@responsible-ai/localization";

import {
ILabeledStatistic,
TotalCohortSamples
} from "../Interfaces/IStatistic";

export enum ImageClassificationMetrics {
Accuracy = "accuracy",
MacroF1 = "f1",
Expand Down Expand Up @@ -56,3 +63,66 @@ export const generateMicroMacroMetrics = (
microScore
};
};

/**
 * Builds the list of labeled statistics for an image classification cohort.
 *
 * Accuracy is the share of positions where the prediction equals the true
 * label. Precision comes from generateMicroMacroMetrics(predYs, trueYs) and
 * recall from the same helper with the arguments swapped; the F1 entries
 * combine the matching micro/macro precision and recall.
 *
 * @param trueYs - true labels, read by the same index as predYs
 * @param predYs - predicted labels; its length is reported as the sample count
 * @returns sample count, accuracy, then micro and macro precision/recall/F1
 */
export const generateImageStats: (
  trueYs: number[],
  predYs: number[]
) => ILabeledStatistic[] = (
  trueYs: number[],
  predYs: number[]
): ILabeledStatistic[] => {
  const sampleCount = predYs.length;
  const matches = predYs.reduce(
    (total, predicted, position) =>
      predicted === trueYs[position] ? total + 1 : total,
    0
  );
  const accuracy = matches / sampleCount;

  const { microScore: microP, macroScore: macroP } = generateMicroMacroMetrics(
    predYs,
    trueYs
  );
  const { microScore: microR, macroScore: macroR } = generateMicroMacroMetrics(
    trueYs,
    predYs
  );

  // `|| 0` replaces a NaN result (for example 0 / 0) with 0.
  const combine = (p: number, r: number): number =>
    2 * ((p * r) / (p + r)) || 0;
  const microF1 = combine(microP, microR);
  const macroF1 = combine(macroP, macroR);

  const labels = localization.Interpret.Statistics;
  const stats: ILabeledStatistic[] = [];
  stats.push({
    key: TotalCohortSamples,
    label: labels.samples,
    stat: sampleCount
  });
  stats.push({
    key: ImageClassificationMetrics.Accuracy,
    label: labels.accuracy,
    stat: accuracy
  });
  // Micro-averaged metrics.
  stats.push({
    key: ImageClassificationMetrics.MicroPrecision,
    label: labels.precision,
    stat: microP
  });
  stats.push({
    key: ImageClassificationMetrics.MicroRecall,
    label: labels.recall,
    stat: microR
  });
  stats.push({
    key: ImageClassificationMetrics.MicroF1,
    label: labels.f1Score,
    stat: microF1
  });
  // Macro-averaged metrics.
  stats.push({
    key: ImageClassificationMetrics.MacroPrecision,
    label: labels.precision,
    stat: macroP
  });
  stats.push({
    key: ImageClassificationMetrics.MacroRecall,
    label: labels.recall,
    stat: macroR
  });
  stats.push({
    key: ImageClassificationMetrics.MacroF1,
    label: labels.f1Score,
    stat: macroF1
  });
  return stats;
};
36 changes: 28 additions & 8 deletions libs/core-ui/src/lib/util/QuestionAnsweringStatisticsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,33 @@ export enum QuestionAnsweringMetrics {
}

export const generateQuestionAnsweringStats: (
selectionIndexes: number[][]
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>
) => ILabeledStatistic[][] = (
selectionIndexes: number[][]
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>
): ILabeledStatistic[][] => {
return selectionIndexes.map((selectionArray) => {
const count = selectionArray.length;

const value = questionAnsweringCache.get(selectionArray.toString());
const stat = value
? value
: [
Number.NaN,
Number.NaN,
Number.NaN,
Number.NaN,
Number.NaN,
Number.NaN
];

return [
{
key: TotalCohortSamples,
Expand All @@ -34,32 +54,32 @@ export const generateQuestionAnsweringStats: (
{
key: QuestionAnsweringMetrics.ExactMatchRatio,
label: localization.Interpret.Statistics.exactMatchRatio,
stat: Number.NaN
stat: stat[0]
},
{
key: QuestionAnsweringMetrics.F1Score,
label: localization.Interpret.Statistics.f1Score,
stat: Number.NaN
stat: stat[1]
},
{
key: QuestionAnsweringMetrics.MeteorScore,
label: localization.Interpret.Statistics.meteorScore,
stat: Number.NaN
stat: stat[2]
},
{
key: QuestionAnsweringMetrics.BleuScore,
label: localization.Interpret.Statistics.bleuScore,
stat: Number.NaN
stat: stat[3]
},
{
key: QuestionAnsweringMetrics.BertScore,
label: localization.Interpret.Statistics.bertScore,
stat: Number.NaN
stat: stat[4]
},
{
key: QuestionAnsweringMetrics.RougeScore,
label: localization.Interpret.Statistics.rougeScore,
stat: Number.NaN
stat: stat[5]
}
];
});
Expand Down
87 changes: 15 additions & 72 deletions libs/core-ui/src/lib/util/StatisticsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ import {
} from "../Interfaces/IStatistic";
import { IsBinary } from "../util/ExplanationUtils";

import {
generateMicroMacroMetrics,
ImageClassificationMetrics
} from "./ImageStatisticsUtils";
import { generateImageStats } from "./ImageStatisticsUtils";
import { JointDataset } from "./JointDataset";
import {
ClassificationEnum,
Expand All @@ -28,6 +25,11 @@ import {
RegressionMetrics
} from "./StatisticsUtilsEnums";

// Cache of question answering metrics per cohort. The key is the cohort's
// selection index array converted with toString(). The value is a six-number
// tuple; where the cache is filled and read, the positions are, in order:
// exact match ratio, F1 score, METEOR score, BLEU score, BERT score and
// ROUGE score.
type QuestionAnsweringCacheType = Map<
string,
[number, number, number, number, number, number]
>;

const generateBinaryStats: (outcomes: number[]) => ILabeledStatistic[] = (
outcomes: number[]
): ILabeledStatistic[] => {
Expand Down Expand Up @@ -166,91 +168,32 @@ const generateMulticlassStats: (outcomes: number[]) => ILabeledStatistic[] = (
];
};

/**
 * Computes the labeled statistics for an image classification cohort:
 * sample count, accuracy, and micro/macro precision, recall and F1.
 *
 * @param trueYs - true labels, compared index by index against predYs
 * @param predYs - predicted labels; its length is the reported sample count
 * @returns the statistics in display order
 */
const generateImageStats: (
  trueYs: number[],
  predYs: number[]
) => ILabeledStatistic[] = (
  trueYs: number[],
  predYs: number[]
): ILabeledStatistic[] => {
  let hits = 0;
  predYs.forEach((predicted, position) => {
    if (predicted === trueYs[position]) {
      hits += 1;
    }
  });
  const total = predYs.length;

  // Precision and recall use the same helper with the arguments swapped.
  const precision = generateMicroMacroMetrics(predYs, trueYs);
  const recall = generateMicroMacroMetrics(trueYs, predYs);

  // A NaN result (for example 0 / 0) is reported as 0 through `|| 0`.
  const f1 = (p: number, r: number): number => 2 * ((p * r) / (p + r)) || 0;

  const text = localization.Interpret.Statistics;
  const summary: ILabeledStatistic[] = [
    { key: TotalCohortSamples, label: text.samples, stat: total },
    {
      key: ImageClassificationMetrics.Accuracy,
      label: text.accuracy,
      stat: hits / total
    }
  ];
  const microAveraged: ILabeledStatistic[] = [
    {
      key: ImageClassificationMetrics.MicroPrecision,
      label: text.precision,
      stat: precision.microScore
    },
    {
      key: ImageClassificationMetrics.MicroRecall,
      label: text.recall,
      stat: recall.microScore
    },
    {
      key: ImageClassificationMetrics.MicroF1,
      label: text.f1Score,
      stat: f1(precision.microScore, recall.microScore)
    }
  ];
  const macroAveraged: ILabeledStatistic[] = [
    {
      key: ImageClassificationMetrics.MacroPrecision,
      label: text.precision,
      stat: precision.macroScore
    },
    {
      key: ImageClassificationMetrics.MacroRecall,
      label: text.recall,
      stat: recall.macroScore
    },
    {
      key: ImageClassificationMetrics.MacroF1,
      label: text.f1Score,
      stat: f1(precision.macroScore, recall.macroScore)
    }
  ];

  return [...summary, ...microAveraged, ...macroAveraged];
};

export const generateMetrics: (
jointDataset: JointDataset,
selectionIndexes: number[][],
modelType: ModelTypes,
objectDetectionCache?: Map<string, [number, number, number]>,
objectDetectionInputs?: [string, string, number]
objectDetectionInputs?: [string, string, number],
questionAnsweringCache?: QuestionAnsweringCacheType
) => ILabeledStatistic[][] = (
jointDataset: JointDataset,
selectionIndexes: number[][],
modelType: ModelTypes,
objectDetectionCache?: Map<string, [number, number, number]>,
objectDetectionInputs?: [string, string, number]
objectDetectionInputs?: [string, string, number],
questionAnsweringCache?: QuestionAnsweringCacheType
): ILabeledStatistic[][] => {
if (
modelType === ModelTypes.ImageMultilabel ||
modelType === ModelTypes.TextMultilabel
) {
return generateMultilabelStats(jointDataset, selectionIndexes);
}
if (modelType === ModelTypes.QuestionAnswering) {
return generateQuestionAnsweringStats(selectionIndexes);
if (modelType === ModelTypes.QuestionAnswering && questionAnsweringCache) {
return generateQuestionAnsweringStats(
selectionIndexes,
questionAnsweringCache
);
}
const trueYs = jointDataset.unwrap(JointDataset.TrueYLabel);
const predYs = jointDataset.unwrap(JointDataset.PredictedYLabel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ interface IModelOverviewProps {
objectDetectionCache: Map<string, [number, number, number]>
) => Promise<any[]>;
requestQuestionAnsweringMetrics?: (
selectionIndexes: number[][]
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>
) => Promise<any[]>;
}

Expand Down Expand Up @@ -95,6 +99,10 @@ export class ModelOverview extends React.Component<
IModelOverviewState
> {
public static contextType = ModelAssessmentContext;
public questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
> = new Map();
public objectDetectionCache: Map<string, [number, number, number]> =
new Map();
public context: React.ContextType<typeof ModelAssessmentContext> =
Expand Down Expand Up @@ -610,7 +618,8 @@ export class ModelOverview extends React.Component<
this.state.aggregateMethod,
this.state.className,
this.state.iouThreshold
]
],
this.questionAnsweringCache
);

this.setState({
Expand Down Expand Up @@ -715,6 +724,7 @@ export class ModelOverview extends React.Component<
this.context
.requestQuestionAnsweringMetrics(
selectionIndexes,
this.questionAnsweringCache,
new AbortController().signal
)
.then((result) => {
Expand All @@ -734,6 +744,24 @@ export class ModelOverview extends React.Component<
] of result.entries()) {
const count = selectionIndexes[cohortIndex].length;

if (
!this.questionAnsweringCache.has(
selectionIndexes[cohortIndex].toString()
)
) {
this.questionAnsweringCache.set(
selectionIndexes[cohortIndex].toString(),
[
exactMatchRatio,
f1Score,
meteorScore,
bleuScore,
bertScore,
rougeScore
]
);
}

const updatedCohortMetricStats = [
{
key: TotalCohortSamples,
Expand Down Expand Up @@ -813,7 +841,8 @@ export class ModelOverview extends React.Component<
this.state.aggregateMethod,
this.state.className,
this.state.iouThreshold
]
],
this.questionAnsweringCache
);

this.setState({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ export interface ITabsViewProps {
) => Promise<any[]>;
requestQuestionAnsweringMetrics?: (
selectionIndexes: number[][],
questionAnsweringCache: Map<
string,
[number, number, number, number, number, number]
>,
abortSignal: AbortSignal
) => Promise<any[]>;
requestDebugML?: (
Expand Down
Loading

0 comments on commit d882424

Please sign in to comment.