import { memo, useRef, useState } from 'react';
import { ChevronRight, FileDownload, FilterList, Refresh } from '@mui/icons-material';
import SearchIcon from '@mui/icons-material/Search';
import {
  alpha,
  Badge,
  Box,
  Checkbox,
  FormControlLabel,
  IconButton,
  LinearProgress,
  styled,
  Switch,
  SxProps,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableFooter,
  TableHead,
  TablePagination,
  TableRow,
  Tooltip,
  Typography,
  useMediaQuery,
} from '@mui/material';
import {
  Column,
  flexRender,
  GroupingState,
  Row,
  Table as TableConfig,
} from '@tanstack/react-table';
import isEmpty from 'lodash/isEmpty';
import meanBy from 'lodash/meanBy';
import partition from 'lodash/partition';
import pickBy from 'lodash/pickBy';
import { FieldValues } from 'react-hook-form';
import { Field } from '@/classes';
import ColumnHeader from '@/components/DataTable/ColumnHeader';
import ColumnsDialog from '@/components/DataTable/ColumnsDialog';
import FilterDrawer from '@/components/DataTable/FilterDrawer';
import TableBodyLoadingOverlay from '@/components/DataTable/TableBodyLoadingOverlay';
import { NonFormField } from '@/components/Form/FormField';
import DebouncedTextField from '@/components/Shared/DebouncedTextField';
import getValueFromEvent from '@/utils/getValueFromEvent';

/**
 * Flex container for the table's action bar.
 *
 * When `highlight` is set (e.g. while rows are selected), the background becomes a
 * translucent tint of the theme's primary color. Otherwise it uses the paper background.
 * `highlight` is consumed by the styles and is not forwarded to the DOM element.
 */
const Toolbar = styled('div', {
  shouldForwardProp: (propName) => propName !== 'highlight',
})<{ highlight?: boolean }>(({ theme, highlight }) => {
  const horizontalPadding = theme.spacing(2);
  const verticalPadding = theme.spacing(1);
  const backgroundColor = highlight
    ? alpha(theme.palette.primary.main, 0.08)
    : theme.palette.background.paper;

  return {
    display: 'flex',
    alignItems: 'center',
    gap: theme.spacing(4),
    minHeight: theme.spacing(10),
    paddingTop: verticalPadding,
    paddingRight: horizontalPadding,
    paddingBottom: verticalPadding,
    paddingLeft: horizontalPadding,
    backgroundColor,
  };
});

/**
 * Builds the `sx` styles shared by header and body cells of a column.
 *
 * Every column except the synthetic `checkbox` column gets an explicit width/minWidth
 * taken from the column's current size. When any column is pinned, pinned cells are
 * additionally made `position: sticky` at their pinned offset. The boundary cell of each
 * pinned group (last left-pinned / first right-pinned column) receives an inset shadow.
 *
 * @param column - TanStack column the cell belongs to.
 * @param somePinned - Whether any visible leaf column is currently pinned.
 * @returns An `sx` object; empty for the `checkbox` column.
 */
const getCommonPinningStyles = (column: Column<any>, somePinned: boolean): SxProps => {
  // The selection column keeps MUI's own checkbox padding/sizing.
  if (column.id === 'checkbox') {
    return {};
  }

  const width = column.getSize();
  if (!somePinned) {
    return { width, minWidth: width };
  }

  const pinnedSide = column.getIsPinned();
  const pinned = Boolean(pinnedSide);

  // Shadow only the cell where a pinned group meets the scrolling columns.
  let boxShadow: string | undefined;
  if (pinnedSide === 'left' && column.getIsLastColumn('left')) {
    boxShadow = '-4px 0 4px -4px inset rgba(0,0,0,0.3)';
  } else if (pinnedSide === 'right' && column.getIsFirstColumn('right')) {
    boxShadow = '4px 0 4px -4px inset rgba(0,0,0,0.3)';
  }

  return {
    boxShadow,
    left: pinnedSide === 'left' ? `${column.getStart('left')}px` : undefined,
    right: pinnedSide === 'right' ? `${column.getAfter('right')}px` : undefined,
    opacity: pinned ? 0.95 : 1,
    position: pinned ? 'sticky' : 'relative',
    zIndex: pinned ? 2 : 0,
    bgcolor: 'background.paper',
    width,
    minWidth: width,
  };
};

/**
 * Renders the `<tbody>` for a TanStack table.
 *
 * Emits one row per entry in the table's current row model, with optional leading
 * expand-toggle and selection-checkbox cells. When the table reports zero rows, it
 * renders a single full-width empty-state row instead.
 *
 * @param table - TanStack table instance supplying rows and visible cells.
 * @param sectionRef - Ref forwarded to the MUI `TableBody` element.
 * @param onRowClick - Called with a row's original data when the row is clicked.
 * @param someCanExpand - Render the leading expand-toggle column.
 * @param someCanSelect - Render the leading selection-checkbox column.
 * @param someColumnsPinned - Forwarded to the pinning style helper for every cell.
 * @param isFetching - Switches the empty-state text to "Loading...".
 */
function ReactTableBody<T>({
  table,
  sectionRef,
  onRowClick,
  someCanExpand,
  someCanSelect,
  someColumnsPinned,
  isFetching,
}: {
  table: TableConfig<T>;
  sectionRef?: React.Ref<HTMLTableSectionElement>;
  onRowClick?: (row: T) => void;
  someCanExpand: boolean;
  someCanSelect: boolean;
  someColumnsPinned: boolean;
  isFetching?: boolean;
}) {
  // The expand/select cells are rendered outside TanStack's leaf columns, so a
  // full-width cell has to add them to its span explicitly.
  const totalColumnCount =
    table.getVisibleLeafColumns().length + Number(someCanExpand) + Number(someCanSelect);

  return (
    <TableBody ref={sectionRef}>
      {table.getRowModel().rows.map((row) => {
        // Grouped (aggregate) rows get a tinted background to stand out from leaf rows.
        const bgcolor = row.getIsGrouped() ? 'action.hover' : 'inherit';
        return (
          <TableRow key={row.id} hover={!!onRowClick} onClick={() => onRowClick?.(row.original)}>
            {someCanExpand && (
              <TableCell padding="checkbox" sx={{ bgcolor }}>
                {row.getCanExpand() && (
                  <IconButton
                    aria-label={row.getIsExpanded() ? 'Collapse row' : 'Expand row'}
                    onClick={(e) => {
                      // Don't let the toggle also trigger the row's onRowClick.
                      e.stopPropagation();
                      row.toggleExpanded();
                    }}
                    sx={{
                      transform: `rotate(${row.getIsExpanded() ? '90deg' : '0'})`,
                    }}
                  >
                    <ChevronRight />
                  </IconButton>
                )}
              </TableCell>
            )}
            {someCanSelect && (
              <TableCell padding="checkbox" sx={{ bgcolor }}>
                <Checkbox
                  checked={row.getIsSelected()}
                  // Group rows with only some sub-rows selected show a partial state.
                  indeterminate={row.getIsSomeSelected()}
                  inputProps={{ 'aria-label': 'Select row' }}
                  onClick={(e) => {
                    // Don't let selection also trigger the row's onRowClick.
                    e.stopPropagation();
                    row.toggleSelected();
                  }}
                />
              </TableCell>
            )}
            {row.getVisibleCells().map((cell) => (
              <TableCell
                key={cell.id}
                padding={cell.column.id === 'checkbox' ? 'checkbox' : undefined}
                sx={{
                  bgcolor,
                  ...getCommonPinningStyles(cell.column, someColumnsPinned),
                }}
              >
                {cell.getIsAggregated()
                  ? flexRender(
                      cell.column.columnDef.aggregatedCell ?? cell.column.columnDef.cell,
                      cell.getContext(),
                    )
                  : cell.getIsPlaceholder()
                    ? null
                    : flexRender(cell.column.columnDef.cell, cell.getContext())}
              </TableCell>
            ))}
          </TableRow>
        );
      })}
      {table.getRowCount() === 0 && (
        <TableRow>
          <TableCell colSpan={totalColumnCount}>
            <Typography color="textSecondary">
              {isFetching ? 'Loading...' : 'No results were found.'}
            </Typography>
          </TableCell>
        </TableRow>
      )}
    </TableBody>
  );
}

/**
 * `ReactTableBody` wrapped in `React.memo` with a custom comparator.
 *
 * The comparator reports "equal" whenever `table.options.data` keeps the same identity,
 * so every other prop change (including `isFetching` and `onRowClick`) is ignored until
 * the data array changes. `ReactTable` swaps this in only while a column is being
 * resized, presumably so the body is not re-rendered on every drag update.
 *
 * The cast restores the generic call signature that `memo` erases.
 */
const MemoizedTableBody = memo(ReactTableBody, (previous, upcoming) => {
  const previousData = previous.table.options.data;
  const upcomingData = upcoming.table.options.data;
  return previousData === upcomingData;
}) as typeof ReactTableBody;

/** Page-size choices offered by the pagination control. */
const ROWS_PER_PAGE_OPTIONS = [10, 25, 50, 100, 250, 1000, 2500];

/**
 * Decides whether a value submitted from the filter drawer should become an active
 * column filter.
 *
 * lodash `isEmpty` alone is not enough here: it returns `true` for every number,
 * boolean and Date, which silently discarded such filters. Those now count as values.
 * `null`/`undefined`, `NaN`, `''`, and empty arrays/objects/maps/sets still count as
 * "no filter".
 * NOTE(review): `false` is now kept as a filter value; confirm boolean filter fields expect that.
 */
const hasFilterValue = (value: unknown): boolean => {
  if (value == null) {
    return false;
  }
  if (typeof value === 'number') {
    return !Number.isNaN(value);
  }
  if (typeof value === 'boolean' || value instanceof Date) {
    return true;
  }
  return !isEmpty(value);
};

/**
 * MUI table rendered from a TanStack `table` instance.
 *
 * Renders:
 * - an optional toolbar (reload, column chooser, export, quick filters, grouping toggle,
 *   global search, filter drawer trigger, or bulk actions while rows are selected);
 * - header, body and footer rows, including optional sum/avg aggregation rows;
 * - pagination once there is more than one page at the smallest page size in use.
 *
 * @param table - TanStack table instance; all state is read from / written to it.
 * @param size - MUI table density.
 * @param isFetching - Shows the progress bar and loading overlay, and disables reload.
 * @param refetch - When given, renders a reload button.
 * @param onDownload - When given, renders an export button.
 * @param onRowClick - Called with a row's original data when it is clicked.
 * @param getBulkActions - Renders actions for the currently selected rows.
 * @param slots - Extra nodes rendered at the left/right end of the toolbar.
 * @param definedGrouping - Grouping applied when the "Grouped" switch is turned on.
 * @param filterable - Fields offered as filters (quick filters or drawer filters).
 * @param enableToolbar - Set `false` to hide the toolbar entirely.
 */
export default function ReactTable<T>({
  table,
  size,
  isFetching,
  refetch,
  onDownload,
  onRowClick,
  getBulkActions,
  slots,
  definedGrouping = [],
  filterable = [],
  enableToolbar = true,
}: {
  table: TableConfig<T>;
  size?: 'small' | 'medium';
  isFetching?: boolean;
  refetch?: () => void;
  onDownload?: () => void;
  onRowClick?: (row: T) => void;
  getBulkActions?: (rows: Row<T>[]) => React.ReactNode;
  slots?: Partial<Record<'leftActions' | 'rightActions', React.ReactNode>>;
  definedGrouping?: GroupingState;
  filterable?: Field[];
  enableToolbar?: boolean;
}) {
  const [isSearching, setIsSearching] = useState(false);
  const [isFiltering, setIsFiltering] = useState(false);
  const isMobile = useMediaQuery('(max-width: 900px)');
  const tableBodyRef = useRef<HTMLTableSectionElement | null>(null);

  const {
    pagination,
    globalFilter,
    grouping,
    columnFilters,
    avgs = {},
    sums = {},
  } = table.getState();

  const { rows } = table.getRowModel();
  const selectedRows = table.getSelectedRowModel().rows;
  const hasFooter = table
    .getFooterGroups()
    .some((fg) => fg.headers.some((h) => h.column.columnDef.footer));
  const someCanExpand = rows.some((r) => r.getCanExpand());
  const someCanSelect = rows.some((r) => r.getCanSelect());
  const someColumnsPinned = table.getVisibleLeafColumns().some((c) => c.getIsPinned());
  // Expand/select cells live outside TanStack's leaf columns, so full-width cells
  // must add them to their span explicitly.
  const totalColumnCount =
    table.getVisibleLeafColumns().length + Number(someCanExpand) + Number(someCanSelect);
  const canGlobalFilter = table.options.enableGlobalFilter;
  const canGroup = definedGrouping.length > 0 || grouping.length > 0;
  const showBulkActions = selectedRows.length > 0;
  const [quickFilters, formFilters] = partition(filterable, (i) => i.shouldQuickFilter(isMobile));
  const quickFilterIds = quickFilters.map((f) => f.name);
  const [quickColumnFilters, formColumnFilters] = partition(columnFilters, (cf) =>
    quickFilterIds.includes(cf.id),
  );
  const someAggregations = Object.keys(pickBy({ ...sums, ...avgs })).length > 0;
  // Hide the pager only when every row fits on one page at the smallest page size in
  // use; a custom pageSize below the smallest option must still be able to page.
  const showPagination =
    table.getRowCount() > Math.min(pagination.pageSize, ...ROWS_PER_PAGE_OPTIONS);

  const getSumForColumn = (columnId: string) => {
    if (table.options.meta?.getSum) {
      return table.options.meta.getSum(columnId);
    }
    return table
      .getFilteredRowModel()
      .rows.reduce((agg, r) => Number(r.getValue(columnId)) + agg, 0);
  };

  const getAvgForColumn = (columnId: string) => {
    if (table.options.meta?.getAvg) {
      return table.options.meta.getAvg(columnId);
    }
    // Coerce like the sum does, so non-number cell values are averaged numerically.
    return meanBy(table.getFilteredRowModel().rows, (r) => Number(r.getValue(columnId)));
  };

  return (
    <>
      {enableToolbar && (
        <Toolbar highlight={showBulkActions}>
          {showBulkActions ? (
            <>
              <Typography color="inherit">{selectedRows.length} selected</Typography>

              <Box ml="auto">{getBulkActions?.(selectedRows)}</Box>
            </>
          ) : (
            <>
              <Box display="flex" alignItems="center">
                {refetch && (
                  <Tooltip title="Reload" placement="top">
                    <IconButton onClick={refetch} size="large" disabled={isFetching}>
                      <Refresh />
                    </IconButton>
                  </Tooltip>
                )}

                <ColumnsDialog table={table} />

                {onDownload && (
                  <Tooltip title="Export to Spreadsheet" placement="top">
                    <IconButton onClick={onDownload} size="large">
                      <FileDownload />
                    </IconButton>
                  </Tooltip>
                )}

                {slots?.leftActions}
              </Box>

              <Box display="flex" alignItems="center" gap={1} flexGrow={1} justifyContent="end">
                {quickFilters.map((f) => (
                  <Box key={f.name} mr={2} minWidth={225}>
                    <NonFormField
                      field={f.getFilterField().with({ margin: 'none' })}
                      value={columnFilters.find((cf) => cf.id === f.name)?.value}
                      onChange={(e) =>
                        table.setColumnFilters((prev) => {
                          const index = prev.findIndex((cf) => cf.id === f.name);
                          const value = getValueFromEvent(e);
                          if (index === -1) {
                            return [...prev, { id: f.name, value }];
                          }
                          return prev.map((cf, i) => (i === index ? { id: f.name, value } : cf));
                        })
                      }
                    />
                  </Box>
                ))}

                {canGroup && (
                  <FormControlLabel
                    control={
                      <Switch
                        checked={grouping.length > 0}
                        onChange={(e) =>
                          table.setGrouping(e.currentTarget.checked ? definedGrouping : [])
                        }
                      />
                    }
                    label="Grouped"
                  />
                )}

                {canGlobalFilter &&
                  (isSearching || globalFilter ? (
                    <DebouncedTextField
                      initialValue={globalFilter ? String(globalFilter) : ''}
                      onChange={(value: string) => table.setGlobalFilter(value)}
                      size="small"
                      type="search"
                      label="Search"
                      fullWidth
                      sx={{ maxWidth: 250 }}
                      autoFocus
                      onBlur={(e) => {
                        table.setGlobalFilter(e.target.value);
                        setIsSearching(false);
                      }}
                    />
                  ) : (
                    <IconButton
                      aria-label="Search"
                      onClick={() => setIsSearching(true)}
                      size="large"
                    >
                      <SearchIcon />
                    </IconButton>
                  ))}

                {formFilters.length > 0 && (
                  <div>
                    <Tooltip title="Filters" placement="top">
                      <IconButton onClick={() => setIsFiltering(true)} size="large">
                        <Badge
                          badgeContent={formColumnFilters.length}
                          color="secondary"
                          invisible={formColumnFilters.length === 0}
                        >
                          <FilterList />
                        </Badge>
                      </IconButton>
                    </Tooltip>
                  </div>
                )}

                {slots?.rightActions}
              </Box>
            </>
          )}
        </Toolbar>
      )}

      <FilterDrawer
        open={isFiltering}
        onClose={() => setIsFiltering(false)}
        filterableFields={formFilters}
        onSuccess={(values) =>
          // Build a fresh array: quick filters are kept as-is, drawer filters replace
          // the previous drawer filters, and blank drawer values are dropped.
          table.setColumnFilters([
            ...quickColumnFilters,
            ...Object.entries(values)
              .filter(([, v]) => hasFilterValue(v))
              .map(([id, value]) => ({ id, value })),
          ])
        }
        initialValues={formColumnFilters.reduce((acc, f) => {
          acc[f.id] = f.value;
          return acc;
        }, {} as FieldValues)}
      />

      <TableContainer sx={{ position: 'relative' }}>
        <Table size={size}>
          <TableHead>
            {table.getHeaderGroups().map((headerGroup) => (
              <TableRow key={headerGroup.id}>
                {someCanExpand && (
                  <TableCell padding="checkbox">
                    <IconButton
                      aria-label={
                        table.getIsAllRowsExpanded() ? 'Collapse all rows' : 'Expand all rows'
                      }
                      onClick={() => table.toggleAllRowsExpanded()}
                      sx={{
                        transform: `rotate(${table.getIsAllRowsExpanded() ? '90deg' : '0'})`,
                      }}
                    >
                      <ChevronRight />
                    </IconButton>
                  </TableCell>
                )}
                {someCanSelect && (
                  <TableCell padding="checkbox">
                    <Checkbox
                      checked={table.getIsAllRowsSelected()}
                      indeterminate={table.getIsSomeRowsSelected()}
                      inputProps={{ 'aria-label': 'Select all rows' }}
                      onClick={() => table.toggleAllRowsSelected()}
                    />
                  </TableCell>
                )}
                {headerGroup.headers.map((header) => (
                  <TableCell
                    key={header.id}
                    sortDirection={header.column.getIsSorted()}
                    sx={{
                      whiteSpace: 'nowrap',
                      ...getCommonPinningStyles(header.column, someColumnsPinned),
                    }}
                    // Key on the column id (as the body does): header ids differ from
                    // column ids for placeholder headers in multi-level header groups.
                    padding={header.column.id === 'checkbox' ? 'checkbox' : 'normal'}
                  >
                    <ColumnHeader header={header} />
                  </TableCell>
                ))}
              </TableRow>
            ))}
            <TableRow>
              <TableCell
                colSpan={totalColumnCount}
                sx={{
                  p: 0,
                  position: 'relative',
                  height: 1,
                  border: 0,
                }}
              >
                {isFetching && (
                  <Box position="absolute" top={-4} right={0} left={0}>
                    <LinearProgress />
                  </Box>
                )}
              </TableCell>
            </TableRow>
          </TableHead>

          {/* While a column is being resized, use the memoized body, which only
              re-renders when the data array changes. */}
          {table.getState().columnSizingInfo.isResizingColumn ? (
            <MemoizedTableBody
              table={table}
              sectionRef={tableBodyRef}
              onRowClick={onRowClick}
              someCanExpand={someCanExpand}
              someCanSelect={someCanSelect}
              someColumnsPinned={someColumnsPinned}
              isFetching={isFetching}
            />
          ) : (
            <ReactTableBody
              table={table}
              sectionRef={tableBodyRef}
              onRowClick={onRowClick}
              someCanExpand={someCanExpand}
              someCanSelect={someCanSelect}
              someColumnsPinned={someColumnsPinned}
              isFetching={isFetching}
            />
          )}

          {someAggregations && (
            <TableFooter>
              <TableRow>
                {someCanExpand && <TableCell padding="checkbox" />}
                {someCanSelect && <TableCell padding="checkbox" />}
                {table.getVisibleLeafColumns().map((column) => {
                  const renderValue = (value: number) => {
                    const cell = column.columnDef.aggregatedCell ?? column.columnDef.cell;

                    return cell
                      ? // @ts-expect-error this is a hack to get the cell to render
                        cell({
                          renderValue: () => value,
                          getValue: () => value,
                          row: { original: {}, subRows: [] },
                          column: column,
                        })
                      : value.toLocaleString();
                  };

                  return (
                    <TableCell key={column.id}>
                      {/* Coerce to boolean so a falsy flag never renders as text. */}
                      {Boolean(sums[column.id]) && (
                        <div>Sum: {renderValue(getSumForColumn(column.id))}</div>
                      )}
                      {Boolean(avgs[column.id]) && (
                        <div>Avg: {renderValue(getAvgForColumn(column.id))}</div>
                      )}
                    </TableCell>
                  );
                })}
              </TableRow>
            </TableFooter>
          )}
          {hasFooter && (
            <TableFooter>
              {table.getFooterGroups().map((footerGroup) => (
                <TableRow key={footerGroup.id}>
                  {someCanExpand && <TableCell padding="checkbox" />}
                  {someCanSelect && <TableCell padding="checkbox" />}
                  {footerGroup.headers.map((header) => (
                    <TableCell key={header.id} colSpan={header.colSpan}>
                      {header.isPlaceholder
                        ? null
                        : flexRender(header.column.columnDef.footer, header.getContext())}
                    </TableCell>
                  ))}
                </TableRow>
              ))}
            </TableFooter>
          )}
          {/* NOTE(review): ref.current is read during render. It is null on the first
              render and can point at a stale tbody after the memoized/unmemoized body
              swap remounts it. Consider a callback ref stored in state. */}
          <TableBodyLoadingOverlay
            el={tableBodyRef.current}
            show={Boolean(isFetching) && table.getRowCount() > 0}
          />
        </Table>
      </TableContainer>

      {showPagination && (
        <TablePagination
          component="div"
          count={table.getRowCount()}
          page={pagination.pageIndex}
          onPageChange={(e, p) => table.setPageIndex(p)}
          onRowsPerPageChange={(e) => table.setPageSize(Number(e.target.value))}
          rowsPerPage={pagination.pageSize}
          rowsPerPageOptions={ROWS_PER_PAGE_OPTIONS}
        />
      )}
    </>
  );
}
