Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion DashAI/back/dataloaders/classes/audio_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def load_data(
prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
if prepared_path[1] == "dir":
dataset = load_dataset(
"audiofolder", data_dir=prepared_path[0], streaming=bool(n_sample)
"audiofolder",
data_dir=prepared_path[0],
streaming=bool(n_sample),
cache_dir=temp_path,
)
if n_sample:
if type(dataset) is IterableDatasetDict:
Expand Down
2 changes: 2 additions & 0 deletions DashAI/back/dataloaders/classes/csv_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,15 @@ def load_data(
data_files=prepared_path[0],
**clean_params,
streaming=bool(n_sample),
cache_dir=temp_path,
)
else:
dataset = load_dataset(
"csv",
data_dir=prepared_path[0],
**clean_params,
streaming=bool(n_sample),
cache_dir=temp_path,
)
shutil.rmtree(prepared_path[0])
if n_sample:
Expand Down
7 changes: 6 additions & 1 deletion DashAI/back/dataloaders/classes/json_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,15 @@ def load_data(
data_files=prepared_path[0],
field=field,
streaming=bool(n_sample),
cache_dir=temp_path,
)
else:
dataset = load_dataset(
"json", data_dir=prepared_path[0], field=field, streaming=bool(n_sample)
"json",
data_dir=prepared_path[0],
field=field,
streaming=bool(n_sample),
cache_dir=temp_path,
)
shutil.rmtree(prepared_path[0])
if n_sample:
Expand Down
2 changes: 2 additions & 0 deletions DashAI/front/src/components/models/ModelCenterContent.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export default function ModelsCenterContent() {
setSessions,
selectedTask,
tasks,
loadingTasks,
datasets,
selectedDatasetId,
step,
Expand Down Expand Up @@ -93,6 +94,7 @@ export default function ModelsCenterContent() {
/>
) : step === 0 ? (
<SelectOptionMenu
loading={loadingTasks}
title={
selectedDatasetId
? t("models:label.selectTaskForSession")
Expand Down
2 changes: 2 additions & 0 deletions DashAI/front/src/components/models/ModelsContext.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export function ModelsProvider({ children }) {

const {
tasks,
loadingTasks,
selectedTask,
selectedSessionId,
selectedSession,
Expand Down Expand Up @@ -116,6 +117,7 @@ export function ModelsProvider({ children }) {
replaceDatasets,
startDatasetPolling,
tasks,
loadingTasks,
selectedTask,
selectedSessionId,
selectedSession,
Expand Down
59 changes: 37 additions & 22 deletions DashAI/front/src/components/threeSectionLayout/SelectOptionMenu.jsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { useState } from "react";
import { Box, Grid, Button, Alert, AlertTitle } from "@mui/material";
import { Box, Grid, Button, Alert, AlertTitle, Skeleton } from "@mui/material";
import SearchBar from "./SearchBar";
import CustomLayout from "../custom/CustomLayout";
import OptionBox from "./OptionBox";
Expand All @@ -10,6 +10,7 @@ export default function SelectOptionMenu({
title,
subtitle,
options,
loading = false,
searchBar = false,
showNoDatasetAlert = false,
onGoToDatasets = null,
Expand Down Expand Up @@ -68,29 +69,43 @@ export default function SelectOptionMenu({
spacing={1}
sx={{ mt: 2, mx: 0, maxWidth: "100%" }}
>
{filteredOptions.map((option, index) => {
const { name, display_name, description, Icon, ...otherProps } =
option;
{loading
? Array.from({ length: 6 }).map((_, index) => (
<Grid
size={{ xl: 6, lg: 6, md: 6, sm: 12, xs: 12 }}
key={index}
>
<Skeleton variant="rounded" height={100} />
</Grid>
))
: null}
{!loading &&
filteredOptions.map((option, index) => {
const { name, display_name, description, Icon, ...otherProps } =
option;

return (
<Grid size={{ xl: 6, lg: 6, md: 6, sm: 12, xs: 12 }} key={index}>
<OptionBox
optionName={display_name}
description={description}
onClick={() => goToNextStep(option.name)}
Icon={Icon}
dataTour={
dataTour && dataTourTarget && name === dataTourTarget
? dataTour
: dataTour && !dataTourTarget
return (
<Grid
size={{ xl: 6, lg: 6, md: 6, sm: 12, xs: 12 }}
key={index}
>
<OptionBox
optionName={display_name}
description={description}
onClick={() => goToNextStep(option.name)}
Icon={Icon}
dataTour={
dataTour && dataTourTarget && name === dataTourTarget
? dataTour
: undefined
}
{...otherProps}
/>
</Grid>
);
})}
: dataTour && !dataTourTarget
? dataTour
: undefined
}
{...otherProps}
/>
</Grid>
);
})}
</Grid>
</Box>

Expand Down
5 changes: 5 additions & 0 deletions DashAI/front/src/hooks/models/useSessions.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { startJobPolling } from "../../utils/jobPoller";
export function useSessions({ t }) {
const { enqueueSnackbar } = useSnackbar();
const [tasks, setTasks] = useState([]);
const [loadingTasks, setLoadingTasks] = useState(true);
const [selectedTask, setSelectedTask] = useState(null);
const [selectedSessionId, setSelectedSessionId] = useState(null);
const [selectedSession, setSelectedSession] = useState(null);
Expand All @@ -45,6 +46,7 @@ export function useSessions({ t }) {
}, [enqueueSnackbar, t]);

const fetchTasks = useCallback(async () => {
setLoadingTasks(true);
try {
const data = await getComponents({
selectTypes: ["Task"],
Expand All @@ -56,6 +58,8 @@ export function useSessions({ t }) {
variant: "error",
});
console.error("Failed to fetch tasks:", error);
} finally {
setLoadingTasks(false);
}
}, [enqueueSnackbar, t]);

Expand Down Expand Up @@ -274,6 +278,7 @@ export function useSessions({ t }) {

return {
tasks,
loadingTasks,
setTasks,
selectedTask,
setSelectedTask,
Expand Down
30 changes: 1 addition & 29 deletions DashAI/front/src/pages/models/ModelsContent.jsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { useEffect, useState } from "react";
import { Box, CircularProgress } from "@mui/material";
import { useEffect } from "react";
import { useLocation } from "react-router-dom";
import { useTranslation } from "react-i18next";
import { TourProvider } from "../../components/tour/TourProvider";
Expand All @@ -20,14 +19,10 @@ import { useModels } from "../../components/models/ModelsContext";
export default function ModelsContent() {
const location = useLocation();
const threePanelLayout = useThreePanelLayout();
const [isInitialLoading, setIsInitialLoading] = useState(true);
const { t } = useTranslation(["models"]);

const {
fetchDatasets,
sessions,
fetchSessions,
fetchTasks,
step,
setStep,
selectedSessionId,
Expand All @@ -37,14 +32,6 @@ export default function ModelsContent() {
fetchRuns,
} = useModels();

useEffect(() => {
const loadInitialData = async () => {
await Promise.all([fetchDatasets(), fetchSessions(), fetchTasks()]);
setIsInitialLoading(false);
};
loadInitialData();
}, []);

useEffect(() => {
if (location.state?.openSessionId && sessions.length > 0) {
const sessionToOpen = sessions.find(
Expand Down Expand Up @@ -77,21 +64,6 @@ export default function ModelsContent() {
}
}, [selectedSessionId, sessions]);

if (isInitialLoading) {
return (
<Box
sx={{
display: "flex",
justifyContent: "center",
alignItems: "center",
height: "100vh",
}}
>
<CircularProgress />
</Box>
);
}

return (
<ThreePanelLayoutContext.Provider value={threePanelLayout}>
<ModuleContainer>
Expand Down
Loading