diff --git a/package.json b/package.json index 88e8f3a9..4e1ef815 100644 --- a/package.json +++ b/package.json @@ -76,6 +76,7 @@ "react-virtualized-auto-sizer": "^1.0.7", "react-window": "^1.8.8", "recharts": "^2.4.3", + "rouge": "^1.0.3", "seedrandom": "^3.0.5", "shallowequal": "^1.1.0", "short-uuid": "^4.2.2", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9847ebd2..c0407c55 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,5 +1,9 @@ lockfileVersion: '6.0' +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + dependencies: '@floating-ui/react': specifier: ^0.19.2 @@ -148,6 +152,9 @@ dependencies: recharts: specifier: ^2.4.3 version: 2.4.3(prop-types@15.8.1)(react-dom@18.2.0)(react@18.2.0) + rouge: + specifier: ^1.0.3 + version: 1.0.3 seedrandom: specifier: ^3.0.5 version: 3.0.5 @@ -5376,6 +5383,11 @@ packages: p-locate: 5.0.0 dev: true + /lodash-node@2.4.1: + resolution: {integrity: sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg==} + deprecated: This package is discontinued. Use lodash@^4.0.0. + dev: false + /lodash.castarray@4.4.0: resolution: {integrity: sha512-aVx8ztPv7/2ULbArGJ2Y42bG1mEQ5mGjpdvrbJcJFU3TbYybe+QlLS4pst9zV52ymy2in1KpFPiZnAOATxD4+Q==} dev: false @@ -6565,6 +6577,12 @@ packages: fsevents: 2.3.3 dev: true + /rouge@1.0.3: + resolution: {integrity: sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==} + dependencies: + lodash-node: 2.4.1 + dev: false + /run-parallel@1.2.0: resolution: {integrity: sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==} dependencies: @@ -7653,7 +7671,3 @@ packages: react: 18.2.0 use-sync-external-store: 1.2.0(react@18.2.0) dev: false - -settings: - autoInstallPeers: true - excludeLinksFromLockfile: false diff --git a/src/components/AppBar.tsx b/src/components/AppBar.tsx index 05780480..2b2799fe 100644 --- a/src/components/AppBar.tsx +++ b/src/components/AppBar.tsx @@ -19,7 +19,6 @@ import MainWalkthrough, { Handle as MainWalkthroughRef, } from './walkthrough/MainWalkthrough'; import { useColors } from '../stores/colors'; -import type { ColorsState } from '../stores/colors'; import ColorPaletteSelect from './ui/ColorPaletteSelect'; import { categoricalPalettes, continuousPalettes } from '../palettes'; @@ -137,16 +136,8 @@ const HelpMenu = (): JSX.Element => { ); }; -const useRobustColorScalesSelector = (c: ColorsState) => ({ - useRobustColorScales: c.useRobustColorScales, - setUseRobustColorScales: c.setUseRobustColorScales, -}); - const ColorMenu = () => { const colors = useColors(); - const { useRobustColorScales, setUseRobustColorScales } = useColors( - useRobustColorScalesSelector - ); const content = (
@@ -167,13 +158,27 @@ const ColorMenu = () => { onChangeColorPalette={colors.setCategoricalPalette} /> - Robust Coloring + + Continuous Ints Enable + + Continuous Categories + + Enable + + + Robust Coloring + + Enable +
); diff --git a/src/components/ScalarValue.tsx b/src/components/ScalarValue.tsx index 416d037e..ff5e7351 100644 --- a/src/components/ScalarValue.tsx +++ b/src/components/ScalarValue.tsx @@ -41,7 +41,7 @@ const ScalarValue: FunctionComponent = ({ // eslint-disable-next-line react-hooks/exhaustive-deps const colorTransferFunctionSelector = useCallback( (d: Dataset) => - d.colorTransferFunctions[column.key]?.[filtered ? 'filtered' : 'full'][0], + d.colorTransferFunctions[column.key]?.[filtered ? 'filtered' : 'full'], [column.key, filtered] ); diff --git a/src/hooks/useColorTransferFunction.ts b/src/hooks/useColorTransferFunction.ts index 2c992799..1939342c 100644 --- a/src/hooks/useColorTransferFunction.ts +++ b/src/hooks/useColorTransferFunction.ts @@ -109,36 +109,40 @@ export const createConstantTransferFunction = ( return tf as ConstantTransferFunction; }; -const createColorTransferFunction = ( +export const createColorTransferFunction = ( data: ColumnData | undefined, dType: DataType | undefined, + robust = false, + continuousInts = false, + continuousCategories = false, classBreaks?: number[] ): TransferFunction => { - const robustColoring = useColors.getState().useRobustColorScales; - if (dType === undefined) return createConstantTransferFunction(unknownDataType); if (data === undefined) return createConstantTransferFunction(dType); - if (['int', 'bool', 'Category', 'str'].includes(dType.kind)) { - const uniqValues = _.uniq(data); + if (['bool', 'str'].includes(dType.kind)) { + return createCategoricalTransferFunction(_.uniq(data), dType); + } + if (dType.kind === 'int' && !continuousInts) { + const uniqValues = _.uniq(data); const tooManyInts = dType.kind === 'int' && uniqValues.length > MAX_VALUES_FOR_INT_CATEGORY; if (!tooManyInts) { - const transferFunction = createCategoricalTransferFunction( - uniqValues, - dType - ); - return transferFunction; + return createCategoricalTransferFunction(uniqValues, dType); } } - if (['int', 'float'].includes(dType.kind)) { + if (dType.kind === 'Category' && !continuousCategories) { + return createCategoricalTransferFunction(_.uniq(data), dType); + } + + if (['int', 'float', 'Category'].includes(dType.kind)) { const stats = makeStats(dType, data); return createContinuousTransferFunction( - (robustColoring ? stats?.p5 : stats?.min) ?? 0, - (robustColoring ? stats?.p95 : stats?.max) ?? 1, + (robust ? stats?.p5 : stats?.min) ?? 0, + (robust ? stats?.p95 : stats?.max) ?? 1, dType, classBreaks ); @@ -148,7 +152,19 @@ const createColorTransferFunction = ( }; // eslint-disable-next-line @typescript-eslint/no-explicit-any -export const useColorTransferFunction = (data: any[], dtype: DataType) => - useMemo(() => createColorTransferFunction(data, dtype), [dtype, data]); +export const useColorTransferFunction = (data: any[], dtype: DataType) => { + const colors = useColors(); + return useMemo( + () => + createColorTransferFunction( + data, + dtype, + colors.robust, + colors.continuousInts, + colors.continuousCategories + ), + [dtype, data, colors.robust, colors.continuousInts, colors.continuousCategories] + ); +}; export default useColorTransferFunction; diff --git a/src/lenses/RougeScoreLens.tsx b/src/lenses/RougeScoreLens.tsx new file mode 100644 index 00000000..9ee86dbf --- /dev/null +++ b/src/lenses/RougeScoreLens.tsx @@ -0,0 +1,39 @@ +import { Lens } from '../types'; +import 'twin.macro'; +import rouge from 'rouge'; +import { formatNumber } from '../dataformat'; + +const RougeScoreLens: Lens = ({ values }) => { + const rouge1 = rouge.n(values[0], values[1], 1); + const rouge2 = rouge.n(values[0], values[1], 2); + return ( +
+
+ Rouge 1: {formatNumber(rouge1)} +
+
+ Rouge 2: {formatNumber(rouge2)} +
+
+ ); +}; + +RougeScoreLens.key = 'RougeScoreView'; +RougeScoreLens.dataTypes = ['str']; +RougeScoreLens.defaultHeight = 50; +RougeScoreLens.minHeight = 50; +RougeScoreLens.maxHeight = 100; +RougeScoreLens.multi = true; +RougeScoreLens.displayName = 'ROUGE Score'; +RougeScoreLens.filterAllowedColumns = (allColumns, selectedColumns) => { + if (selectedColumns.length === 2) return []; + const selectedKeys = selectedColumns.map((selectedCol) => selectedCol.key); + return allColumns.filter(({ type, key }) => { + return type.kind === 'str' && !selectedKeys.includes(key); + }); +}; +RougeScoreLens.isSatisfied = (columns) => { + if (columns.length === 2) return true; + return false; +}; +export default RougeScoreLens; diff --git a/src/lenses/index.ts b/src/lenses/index.ts index 0a170cb5..35051f2d 100644 --- a/src/lenses/index.ts +++ b/src/lenses/index.ts @@ -10,6 +10,7 @@ import SequenceLens from './SequenceLens'; import SpectrogramLens from './SpectrogramLens'; import TextLens from './TextLens'; import VideoLens from './VideoLens'; +import RougeScoreLens from './RougeScoreLens'; export const ALL_LENSES = [ ArrayLens, @@ -24,4 +25,5 @@ export const ALL_LENSES = [ HtmlLens, MarkdownLens, ScalarLens, + RougeScoreLens, ]; diff --git a/src/rougeScore.d.ts b/src/rougeScore.d.ts new file mode 100644 index 00000000..a294f4d3 --- /dev/null +++ b/src/rougeScore.d.ts @@ -0,0 +1 @@ +declare module 'rouge'; diff --git a/src/stores/colors.ts b/src/stores/colors.ts index cf576ad5..d850f649 100644 --- a/src/stores/colors.ts +++ b/src/stores/colors.ts @@ -17,11 +17,15 @@ export interface ColorsState { constantPalette: ConstantPalette; categoricalPalette: CategoricalPalette; continuousPalette: ContinuousPalette; - useRobustColorScales: boolean; + robust: boolean; + continuousInts: boolean; + continuousCategories: boolean; setConstantPalette: (palette?: ConstantPalette) => void; setCategoricalPalette: (palette?: CategoricalPalette) => void; setContinuousPalette: (palette?: ContinuousPalette) => void; - setUseRobustColorScales: (useRobust: boolean) => void; + setRobust: (robust: boolean) => void; + setContinuousInts: (continuous: boolean) => void; + setContinuousCategories: (continuous: boolean) => void; } export const useColors = create()( @@ -30,35 +34,26 @@ export const useColors = create()( constantPalette: defaultConstantPalette, categoricalPalette: defaultCategoricalPalette, continuousPalette: defaultContinuousPalette, - useRobustColorScales: false, + robust: false, + continuousInts: false, + continuousCategories: false, setConstantPalette: (palette) => { - set((state) => { - return { - ...state, - constantPalette: palette ?? defaultConstantPalette, - }; - }); + set({ constantPalette: palette ?? defaultConstantPalette }); }, setCategoricalPalette: (palette) => { - set((state) => { - return { - ...state, - categoricalPalette: palette ?? defaultCategoricalPalette, - }; - }); + set({ categoricalPalette: palette ?? defaultCategoricalPalette }); }, setContinuousPalette: (palette) => { - set((state) => { - return { - ...state, - continuousPalette: palette ?? defaultContinuousPalette, - }; - }); + set({ continuousPalette: palette ?? defaultContinuousPalette }); + }, + setRobust: (robust: boolean) => { + set({ robust }); + }, + setContinuousInts: (continuousInts: boolean) => { + set({ continuousInts }); }, - setUseRobustColorScales: (useRobustColorScales: boolean) => { - set((state) => { - return { ...state, useRobustColorScales }; - }); + setContinuousCategories: (continuousCategories: boolean) => { + set({ continuousCategories }); }, }), { diff --git a/src/stores/dataset/colorTransferFunctionFactory.tsx b/src/stores/dataset/colorTransferFunctionFactory.tsx index f918ec43..a868bfa2 100644 --- a/src/stores/dataset/colorTransferFunctionFactory.tsx +++ b/src/stores/dataset/colorTransferFunctionFactory.tsx @@ -1,83 +1,39 @@ import { - DataType, - isCategorical, - isFloat, - isNumerical, - isScalar, -} from '../../datatypes'; -import { - createCategoricalTransferFunction, - createConstantTransferFunction, - createContinuousTransferFunction, + createColorTransferFunction, TransferFunction, } from '../../hooks/useColorTransferFunction'; -import _ from 'lodash'; -import { useColors } from '../../stores/colors'; -import { - ColumnData, - DataColumn, - DataStatistics, - isCategoricalColumn, - isScalarColumn, - TableData, -} from '../../types'; -import { Dataset } from './dataset'; - -export const makeApplicableColorTransferFunctions = ( - type: DataType, - data: ColumnData, - stats?: DataStatistics -): TransferFunction[] => { - const transferFunctions: TransferFunction[] = []; - - if ((isCategorical(type) || isScalar(type)) && !isFloat(type)) { - const uniqueValues = _.uniq(data); - const transFn = createCategoricalTransferFunction(uniqueValues, type); - transferFunctions.push(transFn); - } - - if (isNumerical(type)) { - const useRobustColoring = useColors.getState().useRobustColorScales; - - const min = useRobustColoring ? stats?.p5 : stats?.min; - const max = useRobustColoring ? stats?.p95 : stats?.max; - - transferFunctions.push( - createContinuousTransferFunction(min || 0, max || 1, type) - ); - } - - transferFunctions.push(createConstantTransferFunction(type)); - - return transferFunctions; -}; +import { DataColumn, TableData } from '../../types'; +import { useColors } from '../colors'; type ColumnsTransferFunctions = Record< string, - { full: TransferFunction[]; filtered: TransferFunction[] } + { full: TransferFunction; filtered: TransferFunction } >; export const makeColumnsColorTransferFunctions = ( columns: DataColumn[], data: TableData, - stats: Dataset['columnStats'], - filteredMask: boolean[] + filteredIndices: Int32Array ): ColumnsTransferFunctions => { - return columns - .filter((column) => isScalarColumn(column) || isCategoricalColumn(column)) - .reduce((a, column) => { - a[column.key] = { - full: makeApplicableColorTransferFunctions( - column.type, - data[column.key], - stats.full[column.key] - ), - filtered: makeApplicableColorTransferFunctions( - column.type, - data[column.key].filter((_, i) => filteredMask[i]), - stats.filtered[column.key] - ), - }; - return a; - }, {} as ColumnsTransferFunctions); + const colors = useColors.getState(); + + return columns.reduce((a, column) => { + a[column.key] = { + full: createColorTransferFunction( + data[column.key], + column.type, + colors.robust, + colors.continuousInts, + colors.continuousCategories + ), + filtered: createColorTransferFunction( + filteredIndices.map((i) => data[column.key][i]), + column.type, + colors.robust, + colors.continuousInts, + colors.continuousCategories + ), + }; + return a; + }, {} as ColumnsTransferFunctions); }; diff --git a/src/stores/dataset/dataset.ts b/src/stores/dataset/dataset.ts index 4916c015..123ec75d 100644 --- a/src/stores/dataset/dataset.ts +++ b/src/stores/dataset/dataset.ts @@ -44,8 +44,8 @@ export interface Dataset { colorTransferFunctions: Record< string, { - full: TransferFunction[]; - filtered: TransferFunction[]; + full: TransferFunction; + filtered: TransferFunction; } >; recomputeColorTransferFunctions: () => void; @@ -430,8 +430,7 @@ export const useDataset = create( const newTransferFunctions = makeColumnsColorTransferFunctions( get().columns.filter(({ key }) => columnsToCompute.includes(key)), get().columnData, - get().columnStats, - get().isIndexFiltered + get().filteredIndices ); set({ @@ -535,6 +534,8 @@ useDataset.subscribe( } ); +useColors.subscribe(useDataset.getState().recomputeColorTransferFunctions); + useDataset.subscribe( (state) => state.selectedIndices, useDataset.getState().recomputeColumnRelevance diff --git a/src/stores/dataset/statisticsFactory.tsx b/src/stores/dataset/statisticsFactory.tsx index 20265276..d559821f 100644 --- a/src/stores/dataset/statisticsFactory.tsx +++ b/src/stores/dataset/statisticsFactory.tsx @@ -1,4 +1,4 @@ -import { DataType, isNumerical } from '../../datatypes'; +import { DataType, isCategorical, isNumerical } from '../../datatypes'; import { max, mean, min, quantile, standardDeviation } from 'simple-statistics'; import { ColumnData, @@ -13,7 +13,7 @@ export const makeStats = ( data: ColumnData, mask?: boolean[] ): DataStatistics | undefined => { - if (!isNumerical(type)) { + if (!isNumerical(type) && !isCategorical(type)) { return; } diff --git a/src/widgets/DataGrid/Cell/CategoricalCell.tsx b/src/widgets/DataGrid/Cell/CategoricalCell.tsx index f7cb40d2..187cfcb5 100644 --- a/src/widgets/DataGrid/Cell/CategoricalCell.tsx +++ b/src/widgets/DataGrid/Cell/CategoricalCell.tsx @@ -22,7 +22,7 @@ const CategoricalCell: FunctionComponent = ({ value, column }) => { (d: Dataset) => d.colorTransferFunctions[column.key]?.[ tableView !== 'full' ? 'filtered' : 'full' - ][0], + ], [column.key, tableView] ); diff --git a/src/widgets/Histogram/Histogram.tsx b/src/widgets/Histogram/Histogram.tsx index 99feb64e..1188520a 100644 --- a/src/widgets/Histogram/Histogram.tsx +++ b/src/widgets/Histogram/Histogram.tsx @@ -81,7 +81,7 @@ const Histogram: Widget = () => { stackByColumnKey ? d.colorTransferFunctions[stackByColumnKey]?.[ filter ? 'filtered' : 'full' - ][0] + ] : createConstantTransferFunction(), [filter, stackByColumnKey] ); diff --git a/src/widgets/ScatterplotView/ScatterplotView.tsx b/src/widgets/ScatterplotView/ScatterplotView.tsx index 3e9dd600..b4560728 100644 --- a/src/widgets/ScatterplotView/ScatterplotView.tsx +++ b/src/widgets/ScatterplotView/ScatterplotView.tsx @@ -139,9 +139,7 @@ const ScatterplotView: Widget = () => { const transferFunctionSelector = useCallback( (d: Dataset) => colorByKey !== undefined && colorByKey.length > 0 - ? d.colorTransferFunctions[colorByKey]?.[ - filter ? 'filtered' : 'full' - ][0] + ? d.colorTransferFunctions[colorByKey]?.[filter ? 'filtered' : 'full'] : createConstantTransferFunction(), [colorByKey, filter] ); diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 023c4dd1..9bd9c55f 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -243,9 +243,7 @@ const SimilarityMap: Widget = () => { const transferFunctionSelector = useCallback( (d: Dataset) => colorByKey !== undefined && colorByKey.length > 0 - ? d.colorTransferFunctions[colorByKey]?.[ - filter ? 'filtered' : 'full' - ][0] + ? d.colorTransferFunctions[colorByKey]?.[filter ? 'filtered' : 'full'] : createConstantTransferFunction(colorBy?.type ?? unknownDataType), [colorByKey, filter, colorBy?.type] );