diff --git a/src/router/ProtectedRoute.tsx b/src/router/ProtectedRoute.tsx index 799184a43..974604683 100644 --- a/src/router/ProtectedRoute.tsx +++ b/src/router/ProtectedRoute.tsx @@ -2,6 +2,7 @@ import { useContext, useEffect } from 'react' import { Outlet, useLocation } from 'react-router-dom' import { AuthContext } from 'react-oauth2-code-pkce' import { AppLoader } from '../sections/shared/layout/app-loader/AppLoader' +import { encodeReturnToPathInStateQueryParam } from '@/sections/auth-callback/AuthCallback' /** * This component is responsible for protecting routes that require authentication. @@ -10,16 +11,18 @@ import { AppLoader } from '../sections/shared/layout/app-loader/AppLoader' */ export const ProtectedRoute = () => { - const { pathname } = useLocation() + const { pathname, search } = useLocation() const { token, loginInProgress, logIn: oidcLogin } = useContext(AuthContext) useEffect(() => { if (loginInProgress) return if (!token) { - oidcLogin(encodeURIComponent(pathname)) + const state = encodeReturnToPathInStateQueryParam(`${pathname}${search}`) + + oidcLogin(state) } - }, [token, oidcLogin, pathname, loginInProgress]) + }, [token, oidcLogin, pathname, loginInProgress, search]) if (loginInProgress) { return diff --git a/src/sections/auth-callback/AuthCallback.tsx b/src/sections/auth-callback/AuthCallback.tsx index 66f03cb7e..5f3b0862a 100644 --- a/src/sections/auth-callback/AuthCallback.tsx +++ b/src/sections/auth-callback/AuthCallback.tsx @@ -4,6 +4,8 @@ import { AuthContext } from 'react-oauth2-code-pkce' import { QueryParamKey } from '../Route.enum' import { AppLoader } from '../shared/layout/app-loader/AppLoader' +export type AuthStateQueryParamValue = { returnTo: string } + /** * This component will we rendered as redirectUri page after the OIDC login is complete. * It will redirect the user to the intended page before the OIDC login was initiated. @@ -25,8 +27,33 @@ export const AuthCallback = () => { return } - navigate(decodeURIComponent(stateQueryParam), { replace: true }) + const returnToPath = decodeReturnToPathFromStateQueryParam(stateQueryParam) + + navigate(returnToPath, { replace: true }) }, [stateQueryParam, navigate, loginInProgress]) return } + +export const encodeReturnToPathInStateQueryParam = (returnToPath: string): string => { + const returnToObject: AuthStateQueryParamValue = { returnTo: returnToPath } + + return encodeURIComponent(JSON.stringify(returnToObject)) +} + +export const decodeReturnToPathFromStateQueryParam = (stateQueryParam: string): string => { + const decodedStateQueryParam = decodeURIComponent(stateQueryParam) + const parsedStateQueryParam = JSON.parse(decodedStateQueryParam) as unknown + + if (isReturnToObject(parsedStateQueryParam)) { + return parsedStateQueryParam.returnTo + } + + return '/' +} + +function isReturnToObject(obj: unknown): obj is AuthStateQueryParamValue { + return ( + obj !== null && typeof obj === 'object' && 'returnTo' in obj && typeof obj.returnTo === 'string' + ) +} diff --git a/src/sections/layout/header/Header.tsx b/src/sections/layout/header/Header.tsx index 7538a7b60..e39a50a12 100644 --- a/src/sections/layout/header/Header.tsx +++ b/src/sections/layout/header/Header.tsx @@ -8,6 +8,7 @@ import { Route } from '@/sections/Route.enum' import { useSession } from '@/sections/session/SessionContext' import { LoggedInHeaderActions } from './LoggedInHeaderActions' import { CollectionJSDataverseRepository } from '@/collection/infrastructure/repositories/CollectionJSDataverseRepository' +import { encodeReturnToPathInStateQueryParam } from '@/sections/auth-callback/AuthCallback' import styles from './Header.module.scss' const collectionRepository = new CollectionJSDataverseRepository() @@ -15,12 +16,14 @@ const collectionRepository = new CollectionJSDataverseRepository() export function Header() { const { t } = useTranslation('header') const { user } = useSession() - const { pathname } = useLocation() + const { pathname, search } = useLocation() const { logIn: oidcLogin } = useContext(AuthContext) const handleOidcLogIn = () => { - oidcLogin(encodeURIComponent(pathname)) + const state = encodeReturnToPathInStateQueryParam(`${pathname}${search}`) + + oidcLogin(state) } return (