Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): add TEXT_GENERATION model create #211

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions ui/src/components/modals/add-new-model/index.jsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import {
RbitModal, SectionTitle, Steps,
} from '@radicalbit/radicalbit-design-system';
import useModals from '@Src/hooks/use-modals';
import { RbitModal, SectionTitle, Steps } from '@radicalbit/radicalbit-design-system';
import { ModelTypeEnum } from '@Src/store/state/models/constants';
import ModalContextProvider, { useModalContext } from './modal-context-provider';
import ActionsStepFourth from './step-four/actions';
import BodyStepFour from './step-four/new-body';
Expand All @@ -9,6 +12,10 @@ import ActionsStepThree from './step-three/actions';
import BodyStepThree from './step-three/body';
import ActionsStepTwo from './step-two/actions';
import BodyStepTwo from './step-two/body';
import {
TextGenerationActionButton, TextGenerationBody,
TextGenerationHeader,
} from './text-generation';

const { Step } = Steps;

Expand Down Expand Up @@ -59,7 +66,9 @@ function Header() {
step,
} = useModalContext();

const { isFormInvalid: isFormInvalidStepOne } = useFormbitStepOne;
const { isFormInvalid: isFormInvalidStepOne, form } = useFormbitStepOne;
const modelType = form?.modelType;

const { isFormInvalid: isFormInvalidStepTwo } = useFormbitStepTwo;
const { isFormInvalid: isFormInvalidStepThree } = useFormbitStepThree;
const { isFormInvalid: isFormInvalidStepFour } = useFormbitStepFour;
Expand All @@ -69,21 +78,27 @@ function Header() {
|| (step === 2 && isFormInvalidStepThree())
|| (step === 3 && isFormInvalidStepFour()) ? 'error' : undefined;

return (
<div className="flex flex-col gap-4">
<SectionTitle title="New Model" />
switch (modelType) {
case ModelTypeEnum.TEXT_GENERATION:
return (<TextGenerationHeader />);

default:
return (
<div className="flex flex-col gap-4">
<SectionTitle title="New Model" />

<Steps className="w-3/4 self-center" current={step} direction="horizontal" status={stepStatus}>
<Step title="Registry" />
<Steps className="w-3/4 self-center" current={step} direction="horizontal" status={stepStatus}>
<Step title="Registry" />

<Step title="Schema" />
<Step title="Schema" />

<Step title="Fields" />
<Step title="Fields" />

<Step title="Target" />
</Steps>
</div>
);
<Step title="Target" />
</Steps>
</div>
);
}
}

function Subtitles() {
Expand Down Expand Up @@ -146,6 +161,21 @@ function Subtitles() {
}

function Body() {
const { useFormbitStepOne } = useModalContext();

const { form } = useFormbitStepOne;
const modelType = form?.modelType;

switch (modelType) {
case ModelTypeEnum.TEXT_GENERATION:
return <TextGenerationBody />;

default:
return <BodyInner />;
}
}

function BodyInner() {
const { step } = useModalContext();

switch (step) {
Expand All @@ -165,6 +195,21 @@ function Body() {
}

function Actions() {
const { useFormbitStepOne } = useModalContext();

const { form } = useFormbitStepOne;
const modelType = form?.modelType;

switch (modelType) {
case ModelTypeEnum.TEXT_GENERATION:
return (<TextGenerationActionButton />);

default:
return <ActionsInner />;
}
}

function ActionsInner() {
const { step } = useModalContext();

switch (step) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function ModelType() {
ModelTypeEnum.BINARY_CLASSIFICATION,
ModelTypeEnum.MULTI_CLASSIFICATION,
ModelTypeEnum.REGRESSION,
ModelTypeEnum.TEXT_GENERATION,
];

return (
Expand Down
148 changes: 148 additions & 0 deletions ui/src/components/modals/add-new-model/text-generation/form-fields.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import useAutoFocus from '@Src/hooks/use-auto-focus';
import {
DataTypeEnum, DataTypeEnumLabel, GranularityEnum, GranularityEnumLabel,
ModelTypeEnum,
ModelTypeEnumLabel,
} from '@State/models/constants';
import {
FormField,
Input,
Select,
} from '@radicalbit/radicalbit-design-system';
import { useRef } from 'react';
import { useModalContext } from '../modal-context-provider';
import useHandleOnSubmit from './use-handle-on-submit';

function Name() {
const ref = useRef(null);

const { handleOnSubmit } = useHandleOnSubmit();
const { useFormbit } = useModalContext();
const { form, error, write } = useFormbit;

const handleOnChange = ({ target: { value } }) => {
write('name', value);
};

useAutoFocus(ref);

return (
<FormField label="Name" message={error('name')} modifier="w-full" required>
<Input
onChange={handleOnChange}
onPressEnter={handleOnSubmit}
ref={ref}
value={form.name}
/>
</FormField>
);
}

function ModelType() {
const { useFormbit } = useModalContext();
const { form, write } = useFormbit;

const handleOnChange = (value) => {
write('modelType', value);
};

const modelTypeSelections = [
ModelTypeEnum.BINARY_CLASSIFICATION,
ModelTypeEnum.MULTI_CLASSIFICATION,
ModelTypeEnum.REGRESSION,
ModelTypeEnum.TEXT_GENERATION,
];

return (
<FormField label="Model type" modifier="w-full" required>
<Select onChange={handleOnChange} value={form.modelType}>
{Object.values(modelTypeSelections).map((value) => (
<Select.Option key={value}>
{ModelTypeEnumLabel[value]}
</Select.Option>
))}
</Select>
</FormField>
);
}

function DataType() {
const { useFormbit } = useModalContext();
const { error } = useFormbit;

return (
<FormField label="Data type" message={error('dataType')} modifier="w-full" required>
{DataTypeEnumLabel[DataTypeEnum.TEXT]}
</FormField>
);
}

function Granularity() {
const { useFormbit } = useModalContext();
const { form, error, write } = useFormbit;

const handleOnChange = (value) => {
write('granularity', value);
};

return (
<FormField label="Granularity" message={error('granularity')} modifier="w-full" required>
<Select onChange={handleOnChange} value={form.granularity}>
{Object.values(GranularityEnum).map((value) => (
<Select.Option key={value}>
{GranularityEnumLabel[value]}
</Select.Option>
))}
</Select>
</FormField>
);
}

function Framework() {
const { handleOnSubmit } = useHandleOnSubmit();
const { useFormbit } = useModalContext();
const { form, error, write } = useFormbit;

const handleOnChange = ({ target: { value } }) => {
write('frameworks', value);
};

return (
<FormField label="Framework" message={error('frameworks')} modifier="w-full">
<Input
onChange={handleOnChange}
onPressEnter={handleOnSubmit}
value={form.frameworks}
/>
</FormField>
);
}

function Algorithm() {
const { handleOnSubmit } = useHandleOnSubmit();
const { useFormbit } = useModalContext();
const { form, error, write } = useFormbit;

const handleOnChange = ({ target: { value } }) => {
write('algorithm', value);
};

return (
<FormField label="Algorithm" message={error('algorithm')} modifier="w-full">
<Input
onChange={handleOnChange}
onPressEnter={handleOnSubmit}
value={form.algorithm}
/>
</FormField>
);
}

export {
Algorithm,
DataType,
Framework,
Granularity,
ModelType,
Name,
};
59 changes: 59 additions & 0 deletions ui/src/components/modals/add-new-model/text-generation/index.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { Button, FormField, SectionTitle } from '@radicalbit/radicalbit-design-system';
import {
Algorithm, DataType, Framework, Granularity, ModelType, Name,
} from './form-fields';
import { useModalContext } from '../modal-context-provider';
import useHandleOnSubmit from './use-handle-on-submit';

function TextGenerationHeader() {
return <SectionTitle title="New Model" />;
}

function TextGenerationBody() {
const { useFormbit } = useModalContext();
const { error } = useFormbit;

return (
<div className="flex flex-row justify-center">
<div className="flex flex-col gap-4 w-full max-w-[400px] items-center">
<Name />

<div className="flex flex-row gap-4 w-full">
<ModelType />

<DataType />
</div>

<Granularity />

<Framework />

<Algorithm />

<FormField message={error('silent.backend')} />
</div>
</div>

);
}

function TextGenerationActionButton() {
const { handleOnSubmit, args, isSubmitDisabled } = useHandleOnSubmit();

return (
<>
<div />

<Button
disabled={isSubmitDisabled}
loading={args.isLoading}
onClick={handleOnSubmit}
type="primary"
>
Save Model
</Button>
</>
);
}

export { TextGenerationHeader, TextGenerationBody, TextGenerationActionButton };
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { globalConfigSliceActions } from '@State/global-configuration/slice';
import { modelsApiSlice } from '@State/models/api';
import { useDispatch } from 'react-redux';
import { useNavigate } from 'react-router';
import { useSearchParams } from 'react-router-dom';
import { DataTypeEnum } from '@Src/store/state/models/constants';
import { useModalContext } from '../modal-context-provider';

const { useAddNewModelMutation } = modelsApiSlice;

export default () => {
const navigate = useNavigate();
const [searchParams] = useSearchParams();
const dispatch = useDispatch();

const { useFormbit, useFormbitStepOne } = useModalContext();

const { isDirty, isFormInvalid, submitForm } = useFormbit;
const { form: formStepOne } = useFormbitStepOne;

const [triggerAddNewModel, args] = useAddNewModelMutation();

const isSubmitDisabled = !isDirty || isFormInvalid();

const handleOnSubmit = () => {
if (isSubmitDisabled || args.isLoading) {
return;
}

submitForm(async (_, setError) => {
const {
name, algorithm, frameworks, modelType, granularity,
} = formStepOne;

const response = await triggerAddNewModel({
name,
dataType: DataTypeEnum.TEXT,
modelType,
granularity,
algorithm,
frameworks,
});

if (response.error) {
console.error(response.error);
setError('silent.backed', response.error);
return;
}

const newModelUUID = response.data.uuid;

searchParams.delete('modal');
dispatch(globalConfigSliceActions.addModelToShowConfettiList(newModelUUID));

navigate({ pathname: `models/${response.data.uuid}`, search: searchParams.toString() });
});
};

return { handleOnSubmit, args, isSubmitDisabled };
};
Loading