import { create } from "zustand";
import { persist, devtools } from "zustand/middleware";
import {
  GridSortModel,
  GridFilterModel,
  GridColumnVisibilityModel,
  GridColDef,
  GridColumnResizeParams,
  MuiEvent,
  GridCallbackDetails,
} from "@mui/x-data-grid";

// Zustand store definition
/**
 * UI state tracked for a single DataGrid instance.
 *
 * One value of this shape is stored per grid key in the store's `grids` map.
 */
interface DataGridState {
  /** Column widths recorded from resize events, keyed by column `field`. */
  columnWidths: Record<string, number>;
  /** Which columns are shown or hidden. */
  columnVisibilityModel: GridColumnVisibilityModel;
  /** Active filter configuration. */
  filterModel: GridFilterModel;
  /** Active sort configuration. */
  sortModel: GridSortModel;
}

/**
 * Shape of the Zustand store: the per-grid state map plus the two accessors
 * used to read and update a single grid's entry.
 */
interface DataGridStore {
  /** State of every known grid, keyed by the grid's storage key. */
  grids: Record<string, DataGridState>;

  /**
   * Returns a complete state object for the grid stored under `key`.
   *
   * @param key - Storage key identifying the grid.
   * @param initialState - Optional values that may be used as fallbacks for
   *   fields that have no stored value.
   */
  getGridState: (
    key: string,
    initialState?: Partial<DataGridState>
  ) => DataGridState;

  /**
   * Merges the given partial state into the entry stored under `key`.
   *
   * @param key - Storage key identifying the grid.
   * @param state - Fields to overwrite; fields not listed are left as they are.
   */
  setGridState: (key: string, state: Partial<DataGridState>) => void;
}

/**
 * Zustand store holding DataGrid UI state for any number of grids, keyed by a
 * caller-chosen string. The store is wrapped in the `devtools` middleware and
 * persisted by the `persist` middleware under the name "datagrid-storage".
 */
export const useDataGridStorage = create<DataGridStore>()(
  devtools(
    persist(
      (set, get) => ({
        grids: {},

        // Shallow-merges `state` into the entry for `key`, creating the entry
        // if it does not exist yet. Entries of other grids are carried over
        // unchanged.
        setGridState: (key, state) =>
          set((prev) => ({
            grids: {
              ...prev.grids,
              [key]: {
                ...(prev.grids[key] ?? {}), // Entry may not exist yet
                ...state, // New values win over stored ones
              },
            },
          })),

        // Builds a complete state object for `key`. Every field is resolved
        // the same way: stored value, then `initialState`, then an empty
        // default.
        getGridState: (key, initialState = {}) => {
          // An entry can be absent, or partial when `setGridState` has only
          // ever been called with some of the fields, so treat each field as
          // optional.
          const stored: Partial<DataGridState> | undefined = get().grids[key];

          return {
            columnWidths:
              stored?.columnWidths ?? initialState.columnWidths ?? {},
            sortModel: stored?.sortModel ?? initialState.sortModel ?? [],
            filterModel: stored?.filterModel ??
              initialState.filterModel ?? { items: [] },
            columnVisibilityModel:
              stored?.columnVisibilityModel ??
              initialState.columnVisibilityModel ??
              {},
          };
        },
      }),
      {
        name: "datagrid-storage",
      }
    )
  )
);

// Hook to integrate with DataGrid
/**
 * Connects a DataGrid to the persistent store.
 *
 * Returns the controlled models (sort, filter, column visibility), the
 * matching change handlers, and a `columns` array with stored widths applied.
 * The result is meant to be passed on to the DataGrid as props.
 *
 * @param storageKey - Key under which this grid's state is kept in the store.
 * @param columns - Column definitions; not mutated.
 * @param initialState - Optional fallbacks handed to the store's
 *   `getGridState` for fields that have no stored value.
 * @returns Props object containing models, change handlers and columns.
 */
export const usePersistentDataGrid = (
  storageKey: string,
  columns: GridColDef[],
  initialState?: Partial<DataGridState>
) => {
  const { getGridState, setGridState } = useDataGridStorage();
  const storedState = getGridState(storageKey, initialState); // Use initial state if no stored state exists
  const storedWidths = storedState.columnWidths;

  return {
    sortModel: storedState.sortModel,
    onSortModelChange: (model: GridSortModel) =>
      setGridState(storageKey, { sortModel: model }),

    filterModel: storedState.filterModel,
    onFilterModelChange: (model: GridFilterModel) =>
      setGridState(storageKey, { filterModel: model }),

    columnVisibilityModel: storedState.columnVisibilityModel,
    onColumnVisibilityModelChange: (model: GridColumnVisibilityModel) =>
      setGridState(storageKey, { columnVisibilityModel: model }),

    columns: columns.map((column) => {
      // Own-property check on purpose: the `in` operator would also match
      // keys inherited from Object.prototype (e.g. a column whose field is
      // "constructor" or "toString") and yield a non-numeric width.
      if (!Object.prototype.hasOwnProperty.call(storedWidths, column.field)) {
        return column;
      }
      // Drop `flex` on a copy so the caller's column object is not mutated.
      // NOTE(review): presumably `flex` would otherwise take precedence over
      // the fixed `width` - confirm against the DataGrid column docs.
      const { flex, ...rest } = column;
      return { ...rest, width: storedWidths[column.field] };
    }),

    onColumnWidthChange: (
      params: GridColumnResizeParams,
      _event: MuiEvent,
      _details: GridCallbackDetails
    ) => {
      // Re-read the widths from the store at event time instead of using the
      // snapshot captured during render; otherwise several width changes
      // fired before the next render would overwrite each other.
      const latestWidths = getGridState(storageKey, initialState).columnWidths;
      setGridState(storageKey, {
        columnWidths: { ...latestWidths, [params.colDef.field]: params.width },
      });
    },
  };
};
