Skip to content

Commit

Permalink
feat: update wizard (#54)
Browse files Browse the repository at this point in the history
* feat: enable all modelType

* fix: remove unused import

* fix: valid types

* refactor: add empty array for readability

---------

Co-authored-by: Luca Tagliabue <[email protected]>
  • Loading branch information
lucataglia and Luca Tagliabue authored Jul 1, 2024
1 parent ccb6b79 commit 379ee62
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { DataTypeEnum, ModelTypeEnum } from '@State/models/constants';
import { DataTypeEnum } from '@State/models/constants';
import useFormbit from '@radicalbit/formbit';
import {
createContext,
Expand All @@ -18,7 +18,7 @@ function ModalContextProvider({ children }) {
const [isMaximize, setIsMaximize] = useState(false);

const useFormbitStepOne = useFormbit({
initialValues: { modelType: ModelTypeEnum.BINARY_CLASSIFICATION, dataType: DataTypeEnum.TABULAR },
initialValues: { dataType: DataTypeEnum.TABULAR },
yup: schemaStepOne,
});

Expand Down
25 changes: 18 additions & 7 deletions ui/src/components/modals/add-new-model/step-four/form-fields.jsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { ModelTypeEnum } from '@Src/store/state/models/constants';
import {
FormField,
Select,
Tooltip,
} from '@radicalbit/radicalbit-design-system';
import { ModelTypeEnum } from '@Src/store/state/models/constants';
import { useModalContext } from '../modal-context-provider';

function Target() {
Expand Down Expand Up @@ -196,15 +196,18 @@ function Prediction() {
}

function Probability() {
const { useFormbit } = useModalContext();
const { useFormbit, useFormbitStepOne } = useModalContext();

const {
form, error, write, remove,
} = useFormbit;

const probabilities = useGetProbabilities();
const predictionName = form?.prediction?.name;
const value = form?.predictionProba?.name;

const { form: formStepOne } = useFormbitStepOne;
const { modelType } = formStepOne;

const handleOnChange = (val) => {
if (val === undefined) {
remove('predictionProba');
Expand All @@ -218,6 +221,14 @@ function Probability() {
}
};

if (modelType === ModelTypeEnum.REGRESSION) {
return (
<FormField label="Probability" modifier="w-full">
<Select disabled readOnly value="Not available for Regression" />
</FormField>
);
}

return (
<FormField
label="Probability"
Expand Down Expand Up @@ -267,7 +278,7 @@ function Probability() {

const targetValidTypes = {
[ModelTypeEnum.BINARY_CLASSIFICATION]: ['int', 'float', 'double'],
[ModelTypeEnum.MULTI_CLASSIFICATION]: ['int', 'float', 'double'],
[ModelTypeEnum.MULTI_CLASSIFICATION]: ['int', 'float', 'double', 'string'],
[ModelTypeEnum.REGRESSION]: ['int', 'float', 'double'],
};
const useGetTargets = () => {
Expand Down Expand Up @@ -295,10 +306,10 @@ const useGetPredictions = () => {
return form.outputs.filter(({ type }) => predictionValidTypes[modelType].includes(type));
};

const binaryClassificationProbabilityValidTypes = {
const probabilityValidTypes = {
[ModelTypeEnum.BINARY_CLASSIFICATION]: ['float', 'double'],
[ModelTypeEnum.MULTI_CLASSIFICATION]: ['float', 'double'],
[ModelTypeEnum.REGRESSION]: ['float', 'double'],
[ModelTypeEnum.REGRESSION]: [],
};
const useGetProbabilities = () => {
const { useFormbitStepOne, useFormbit } = useModalContext();
Expand All @@ -307,7 +318,7 @@ const useGetProbabilities = () => {
const { form: formStepOne } = useFormbitStepOne;
const { modelType } = formStepOne;

return form.outputs.filter(({ type }) => binaryClassificationProbabilityValidTypes[modelType].includes(type));
return form.outputs.filter(({ type }) => probabilityValidTypes[modelType].includes(type));
};

const timestampValidTypes = {
Expand Down
15 changes: 14 additions & 1 deletion ui/src/components/modals/add-new-model/step-one/form-fields.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,22 @@ function Name() {
}

function ModelType() {
const { useFormbit } = useModalContext();
const { form, write } = useFormbit;

const handleOnChange = (value) => {
write('modelType', value);
};

return (
<FormField label="Model type" modifier="w-full" required>
{ModelTypeEnumLabel[ModelTypeEnum.BINARY_CLASSIFICATION]}
<Select onChange={handleOnChange} value={form.modelType}>
{Object.values(ModelTypeEnum).map((value) => (
<Select.Option key={value}>
{ModelTypeEnumLabel[value]}
</Select.Option>
))}
</Select>
</FormField>
);
}
Expand Down

0 comments on commit 379ee62

Please sign in to comment.