import { useCallback, useEffect, useMemo, useState } from 'react';

import * as diff from 'diff';
import _ from 'lodash';
import * as R from 'ramda';

import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import {
  Accordion,
  AccordionDetails,
  AccordionSummary,
  Alert,
  Box,
  Button,
  MenuItem,
  Paper,
  Select,
  styled,
  Switch,
  Table,
  TableBody,
  TableCell,
  tableCellClasses,
  TableContainer,
  TableHead,
  TableRow,
  Tooltip,
  Typography,
  TypographyProps,
} from '@mui/material';
import {
  CompletedClassificationEvaluation,
  CompletedEvaluation,
  Evaluation,
  EvaluationType,
} from '@scale/llm-shared/interfaces/evaluation';

import AddedText from 'frontend/components/AddedText';
import { Container } from 'frontend/components/Container';
import GrayDataGrid from 'frontend/components/datagrid/GrayDataGrid';
import { MultiInputPreview } from 'frontend/components/evaluation/MultiInputPreview';
import FlexBox from 'frontend/components/FlexBox';
import PageTitle from 'frontend/components/PageTitle';
import { HSpace, VSpace } from 'frontend/components/Spacer';
import DownloadComparisonResultsButton from 'frontend/components/v2/compare/DownloadComparisonResultsButton';
import { DataSelectorAndUploader } from 'frontend/components/v2/data/DataSelectorAndUploader';
import { SingleVariantSelector } from 'frontend/components/v2/variant/SingleVariantSelector';
import VariantName from 'frontend/components/v2/variant/VariantName';
import { useLLMNavigation } from 'frontend/hooks/useLLMNavigation';
import { useComparePageState } from 'frontend/models/v2/useComparePageState';
import LoadingPage from 'frontend/pages/LoadingPage';
import PageContainer from 'frontend/pages/PageContainer';
import { useSettingsStore } from 'frontend/stores/SettingsStore';
import { useEvaluationStore } from 'frontend/storesV2/EvaluationStore';
import { useSelectionStore } from 'frontend/storesV2/SelectionStore';
import { StoredEvaluation, StoredVariant } from 'frontend/storesV2/types';
import { useVariantStore } from 'frontend/storesV2/VariantStore';
import theme, { Colors } from 'frontend/theme';
import { round, toPercent } from 'frontend/utils';
import { padArray } from 'frontend/utils/array';

/**
 * Table cell used for the label column and the header row of the comparison
 * tables. Head cells get small, uppercase, medium-weight text; body cells get
 * their own background. Both use colors from the CoolGray palette.
 */
const DarkTableCell = styled(TableCell)({
  // Applies when the cell is rendered inside a table head.
  ['&.' + tableCellClasses.head]: {
    backgroundColor: Colors.CoolGray10,
    color: Colors.CoolGray50,
    fontSize: 12,
    fontWeight: 500,
    textTransform: 'uppercase',
  },
  // Applies when the cell is rendered inside a table body.
  ['&.' + tableCellClasses.body]: {
    backgroundColor: Colors.CoolGray05,
    color: Colors.CoolGray50,
  },
});

/**
 * Human-readable label for each evaluation type offered in the evaluation-type
 * dropdown of the compare page. `AIFeedback` is excluded from the key type, so
 * lookups must first filter that type out.
 */
const EVALUATION_ACCURACY_DESCRIPTION_TEXT: Record<
  Exclude<EvaluationType, 'AIFeedback'>,
  string
> = {
  [EvaluationType.ClassificationEvaluation]: 'Classification',
  [EvaluationType.HumanEvaluation]: 'Human Evaluation',
  [EvaluationType.MauveEvaluation]: 'Mauve Score',
};

interface PromptOneSidedDiffProps {
  /** The text rendered by this component. */
  text: string;
  /** The text that `text` is diffed against. */
  other: string;
}

/**
 * Renders `text` with every segment that the word diff reports as added
 * relative to `other` wrapped in `AddedText`. Segments that exist only in
 * `other` are omitted, so two instances with swapped props form a side-by-side
 * diff.
 *
 * The box height grows with the line count of the longer text and is capped at
 * 225px, after which the content scrolls.
 */
function PromptOneSidedDiff(props: PromptOneSidedDiffProps) {
  const { text, other } = props;

  // The word diff is the costly part of this component; recompute it only when
  // one of the two texts actually changes instead of on every render.
  const segments = useMemo(() => diff.diffWords(other, text), [text, other]);

  const height = useMemo(() => {
    const lines = Math.max(text.split(/\n/).length, other.split(/\n/).length);
    return Math.min(225, lines * 32);
  }, [text, other]);

  return (
    <Box sx={{ height, overflowY: 'scroll' }}>
      <Typography sx={{ fontFamily: 'monospace', whiteSpace: 'pre-wrap' }}>
        {segments.map((segment, index) => {
          if (segment.added) {
            return <AddedText key={`added-${segment.value}-${index}`}>{segment.value}</AddedText>;
          }
          // Text that only exists in `other` is not shown on this side.
          if (segment.removed) {
            return '';
          }
          return segment.value;
        })}
      </Typography>
    </Box>
  );
}

/**
 * Placeholder icon for a model parameter that one of the compared variants
 * does not define; hovering it shows an explanatory tooltip.
 */
function UndefinedValue() {
  const placeholderIcon = <FontAwesomeIcon icon="question-circle" color={Colors.CoolGray30} />;

  return (
    <Tooltip title="This model parameter is undefined for this variant">{placeholderIcon}</Tooltip>
  );
}

type TemplateVarDict = {
  [varName: string]: string;
};

/**
 * Collapsible grid listing the example prompt variables of two variants side
 * by side, one row per variable name that appears in either variant.
 *
 * @param templateVarsChanged Whether the caller found the two variable sets to
 *   differ; selects the summary text of the accordion.
 * @param variantLeft Variant shown in the first value column.
 * @param variantRight Variant shown in the second value column.
 */
function PromptVariableDiffView({
  templateVarsChanged,
  variantLeft,
  variantRight,
}: {
  templateVarsChanged: boolean;
  variantLeft: StoredVariant;
  variantRight: StoredVariant;
}) {
  // Fall back to an empty dictionary so a variant without example variables
  // cannot make the per-key lookups below throw.
  const left = (variantLeft.prompt.exampleVariables ?? {}) as TemplateVarDict;
  const right = (variantRight.prompt.exampleVariables ?? {}) as TemplateVarDict;
  const keys = _.uniq(_.concat(_.keys(left), _.keys(right)));

  if (keys.length === 0) {
    return <Typography>No prompt variables.</Typography>;
  }

  const rows = keys.map(key => ({
    id: key,
    varName: key,
    valueLeft: left[key],
    valueRight: right[key],
  }));

  // Count only the variables that really differ. A variable also counts as
  // differing when it is declared on one side only, even with an undefined value.
  const differingCount = keys.filter(
    key => _.has(left, key) !== _.has(right, key) || !_.isEqual(left[key], right[key]),
  ).length;

  return (
    <Accordion variant="outlined" sx={{ width: '100%' }}>
      <AccordionSummary expandIcon={<FontAwesomeIcon icon="caret-down" />}>
        <Typography>
          {templateVarsChanged
            ? `Example Variables: Differing example values (${differingCount})`
            : 'Example Variables: Identical'}
        </Typography>
      </AccordionSummary>
      <AccordionDetails>
        <GrayDataGrid
          rows={rows}
          columns={[
            {
              field: 'varName',
              headerName: 'Variable Name',
              width: 120,
            },
            {
              field: 'valueLeft',
              headerName: variantLeft.name,
              flex: 1,
            },
            {
              field: 'valueRight',
              headerName: variantRight.name,
              flex: 1,
            },
          ]}
          pageSize={10}
          density="compact"
        />
      </AccordionDetails>
    </Accordion>
  );
}

/**
 * Model-parameter keys that the compare page shows, mapped to their display
 * names. The insertion order is the order in which the rows are rendered.
 */
const DISPLAY_KEYS: Record<string, string> = {
  temperature: 'Temperature',
  batchSize: 'Batch Size',
  learningRate: 'Learning Rate',
  maxTokens: 'Max Tokens',
  stop: 'Stop Sequence',
  suffix: 'Suffix',
  topP: 'Top P',
  logprobs: 'Log Probs',
  logitBias: 'Logit Bias',
  taxonomyEnabled: 'Taxonomy Enabled',
};

/**
 * Centered call-to-action shown where evaluation stats are missing; clicking
 * it navigates to the evaluation page.
 */
function EvaluationButton(): JSX.Element {
  const { goToEvaluationPage } = useLLMNavigation();

  // Wrapped so the click event is never forwarded to the navigation helper.
  const openEvaluationPage = useCallback((): void => {
    goToEvaluationPage();
  }, [goToEvaluationPage]);

  const button = (
    <Button variant="contained" onClick={openEvaluationPage}>
      Run evaluation from the Evaluations page to generate stats
    </Button>
  );

  return <FlexBox sx={{ justifyContent: 'center' }}>{button}</FlexBox>;
}

// HACK: To handle heterogeneous comparisons, model parameters are read through
// a loose string-keyed view and converted to strings before they are rendered.
type ModelParameterKey = string;
type GenericModelParameters = { [key: ModelParameterKey]: unknown };

/**
 * Converts a raw model parameter into text that can be rendered in a table cell.
 *
 * Strings are returned unchanged, numbers and booleans are stringified, and any
 * other value (arrays, objects) is JSON-encoded, since React does not render
 * booleans and throws on plain objects used as children.
 *
 * @returns The display text, or `undefined` when the value is `null`/`undefined`.
 */
function formatModelParameter(value: unknown): string | undefined {
  if (_.isNil(value)) {
    return undefined;
  }
  if (typeof value === 'string') {
    return value;
  }
  if (typeof value === 'number' || typeof value === 'boolean') {
    return String(value);
  }
  return JSON.stringify(value);
}

/**
 * Page that compares two variants of the selected app.
 *
 * The upper section lists the differences between the variants (prompt, system
 * message, model, hyperparameters and example variables). The lower section
 * shows evaluation stats for both variants against a chosen data set, followed
 * by an input preview.
 *
 * Renders an error notice when fewer than two variants exist and a loading page
 * while the two selected variants cannot be resolved yet.
 */
export function ComparePage() {
  // NOTE(review): `settings`/`setSettings` are not used in this component. The
  // hook call is kept unchanged because removing it would also change what this
  // component subscribes to; confirm that is not relied upon, then delete it.
  const { settings, setSettings } = useSettingsStore();

  const {
    showIdentical,
    setShowIdentical,
    dataId,
    setDataId,
    variantLeftId,
    setVariantLeftId,
    variantRightId,
    setVariantRightId,
  } = useComparePageState();

  const { selectedAppVariants } = useSelectionStore(R.pick(['selectedAppVariants']));
  const { evaluationById } = useEvaluationStore(R.pick(['evaluationById']));
  const { variantById } = useVariantStore(R.pick(['variantById']));

  const [selectedEvaluationType, setSelectedEvaluationType] = useState<EvaluationType>(
    EvaluationType.ClassificationEvaluation,
  );

  const variantLeft = variantLeftId ? variantById[variantLeftId] : undefined;
  const variantRight = variantRightId ? variantById[variantRightId] : undefined;
  const variantsAvailable = !_.isNil(variantLeft) && !_.isNil(variantRight);

  const canCompare = selectedAppVariants && selectedAppVariants.length >= 2;

  // Default to the first two variants whenever a side has no selection. The
  // selected ids are dependencies so the defaults are restored if one is cleared.
  useEffect(() => {
    if (canCompare && selectedAppVariants && (!variantLeftId || !variantRightId)) {
      setVariantLeftId(selectedAppVariants[0].id);
      setVariantRightId(selectedAppVariants[1].id);
    }
  }, [
    canCompare,
    selectedAppVariants,
    variantLeftId,
    variantRightId,
    setVariantLeftId,
    setVariantRightId,
  ]);

  const promptChanged = useMemo(() => {
    return variantsAvailable && variantLeft.prompt.template !== variantRight.prompt.template;
  }, [variantsAvailable, variantLeft, variantRight]);

  const leftSys = variantLeft?.prompt.systemMessage;
  const rightSys = variantRight?.prompt.systemMessage;
  const systemMessageChanged = useMemo(() => {
    return variantsAvailable && leftSys !== rightSys;
  }, [variantsAvailable, leftSys, rightSys]);

  const templateVarsChanged = useMemo(() => {
    return (
      variantsAvailable &&
      !_.isEqual(variantLeft.prompt.exampleVariables, variantRight.prompt.exampleVariables)
    );
  }, [variantsAvailable, variantLeft, variantRight]);

  // One entry per table row: the model first, then every displayable
  // hyperparameter that at least one of the two variants defines.
  const hyperparameterPairs = useMemo(() => {
    if (!variantsAvailable) {
      return [];
    }
    const leftParams = variantLeft.modelParameters as unknown as GenericModelParameters;
    const rightParams = variantRight.modelParameters as unknown as GenericModelParameters;
    const hyperparameterDiffs = _.compact(
      Object.entries(DISPLAY_KEYS).map(([paramKey, displayName]) => {
        const key = paramKey as ModelParameterKey;
        const left = leftParams[key];
        const right = rightParams[key];
        if (left === undefined && right === undefined) {
          return undefined;
        }
        return {
          key: `${key}-${variantLeft.id}-${variantRight.id}`,
          displayName,
          left: formatModelParameter(left),
          right: formatModelParameter(right),
          // Deep comparison: array/object parameters with equal contents are
          // distinct references and would always look changed with `!==`.
          isChanged: !_.isEqual(left, right),
        };
      }),
    );
    const modelDiff = {
      key: `Model-${variantLeft.id}-${variantRight.id}`,
      displayName: 'Model',
      left: variantLeft.modelParameters.modelId,
      right: variantRight.modelParameters.modelId,
      isChanged: variantLeft.modelParameters.modelId !== variantRight.modelParameters.modelId,
    };
    return [modelDiff, ...hyperparameterDiffs];
  }, [variantsAvailable, variantLeft, variantRight]);

  // Index: variant id -> evaluation type -> input data id -> evaluation.
  // Sorting by creation time before keying means the last-sorted (latest
  // `createdAt`) evaluation wins when several share the same key.
  const evaluationByDataByTypeByVariant = useMemo(() => {
    const evaluationByVariantId = _.groupBy(Object.values(evaluationById), e => e.variantId);
    return _.mapValues(evaluationByVariantId, variantEvals => {
      const byType = _.groupBy(variantEvals, e => e.type);
      return _.mapValues(byType, typeEvals => {
        return _.keyBy(
          _.sortBy(typeEvals, e => e.createdAt),
          e => e.inputDataId,
        );
      });
    });
  }, [evaluationById]);

  // The index itself is a dependency: without it these lookups keep returning
  // stale results after evaluations are loaded or updated.
  const evaluationLeft = useMemo(() => {
    if (!variantLeftId || !dataId) return;
    return evaluationByDataByTypeByVariant[variantLeftId]?.[selectedEvaluationType]?.[dataId];
  }, [evaluationByDataByTypeByVariant, variantLeftId, selectedEvaluationType, dataId]);

  const evaluationRight = useMemo(() => {
    if (!variantRightId || !dataId) return;
    return evaluationByDataByTypeByVariant[variantRightId]?.[selectedEvaluationType]?.[dataId];
  }, [evaluationByDataByTypeByVariant, variantRightId, selectedEvaluationType, dataId]);

  const renderTableBody = useCallback(() => {
    if (!variantLeftId || !variantRightId) {
      // Text must live inside a cell; a bare string inside <tr> is invalid DOM.
      return (
        <TableRow>
          <TableCell colSpan={3}>You must select 2 variants to compare!</TableCell>
        </TableRow>
      );
    }

    switch (selectedEvaluationType) {
      case EvaluationType.ClassificationEvaluation: {
        return statsTable(evaluationLeft, evaluationRight);
      }
      default: {
        // Only classification evaluations have a stats table.
        return null;
      }
    }
  }, [variantLeftId, variantRightId, selectedEvaluationType, evaluationLeft, evaluationRight]);

  // A variant already chosen on one side cannot be picked on the other.
  const disabledLeftVariants = useMemo(() => {
    return [{ variantId: variantRightId, reason: 'Variant is used on the right' }];
  }, [variantRightId]);
  const disabledRightVariants = useMemo(() => {
    return [{ variantId: variantLeftId, reason: 'Variant is used on the left' }];
  }, [variantLeftId]);

  if (!canCompare) {
    return (
      <PageContainer page="app-compare">
        <PageTitle title="Compare" />
        <Paper>
          <Alert severity="error">
            You will need to create at least two variants for comparison.
          </Alert>
        </Paper>
      </PageContainer>
    );
  }
  if (!variantLeft || !variantRight) {
    // This should only happen in a brief period before the app variants are
    // fully loaded.
    return <LoadingPage />;
  }

  return (
    <PageContainer page="app-compare">
      <PageTitle title="Compare" />
      <Container sx={{ flexDirection: 'row', alignItems: 'center' }}>
        <Typography variant="h2">Comparing</Typography>
        <SingleVariantSelector
          value={variantLeftId}
          onChange={setVariantLeftId}
          disabledVariants={disabledLeftVariants}
        />
        <Typography variant="h2">to</Typography>
        <SingleVariantSelector
          value={variantRightId}
          onChange={setVariantRightId}
          disabledVariants={disabledRightVariants}
        />
      </Container>
      <VSpace s={2} />
      <Container>
        <Box>
          <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
            <Typography variant="h3">Variant Differences</Typography>
            <Box sx={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
              <Box sx={{ display: 'flex', alignItems: 'center' }}>
                <Typography variant="h4">Show Identical Fields</Typography>
                <Switch
                  size="small"
                  checked={showIdentical}
                  onClick={() => setShowIdentical(!showIdentical)}
                  sx={{
                    '& .Mui-checked .MuiSwitch-switchBase': {
                      color: theme.palette.primary.dark,
                    },
                    '& .MuiSwitch-switchBase': {
                      color: theme.palette.primary.light,
                    },
                  }}
                />
              </Box>
            </Box>
          </Box>
          <VSpace s={2} />
          <TableContainer sx={{ borderRadius: 1 }}>
            <Table
              size="small"
              sx={{
                borderSpacing: 0,
                [`& .${tableCellClasses.body}, .${tableCellClasses.head}`]: {
                  border: `1px solid ${Colors.CoolGray20}`,
                },
              }}
            >
              <TableHead>
                <TableRow>
                  <DarkTableCell width="16%">Parameters</DarkTableCell>
                  <DarkTableCell width="42%">
                    <VariantName name={variantLeft.name} />
                  </DarkTableCell>
                  <DarkTableCell width="42%">
                    <VariantName name={variantRight.name} />
                  </DarkTableCell>
                </TableRow>
              </TableHead>
              <TableBody>
                {(showIdentical || promptChanged) && (
                  <TableRow>
                    <DarkTableCell>Prompt</DarkTableCell>
                    <TableCell>
                      <PromptOneSidedDiff
                        text={variantLeft.prompt.template}
                        other={variantRight.prompt.template}
                      />
                    </TableCell>
                    <TableCell>
                      <PromptOneSidedDiff
                        text={variantRight.prompt.template}
                        other={variantLeft.prompt.template}
                      />
                    </TableCell>
                  </TableRow>
                )}

                {(showIdentical || systemMessageChanged) && (
                  <TableRow>
                    <DarkTableCell>System Message</DarkTableCell>
                    <TableCell>
                      {leftSys ? (
                        <PromptOneSidedDiff text={leftSys} other={rightSys ? rightSys : ''} />
                      ) : (
                        'N/A'
                      )}
                    </TableCell>
                    <TableCell>
                      {rightSys ? (
                        <PromptOneSidedDiff text={rightSys} other={leftSys ? leftSys : ''} />
                      ) : (
                        'N/A'
                      )}
                    </TableCell>
                  </TableRow>
                )}

                {hyperparameterPairs.map(pair => {
                  if (!showIdentical && !pair.isChanged) {
                    return null;
                  }
                  if (!pair.isChanged) {
                    return (
                      <TableRow key={pair.key}>
                        <DarkTableCell>{pair.displayName}</DarkTableCell>
                        <TableCell>{pair.left}</TableCell>
                        <TableCell>{pair.right}</TableCell>
                      </TableRow>
                    );
                  }
                  // Changed rows highlight both values and flag a missing one.
                  return (
                    <TableRow key={pair.key}>
                      <DarkTableCell>{pair.displayName}</DarkTableCell>
                      <TableCell>
                        {pair.left === undefined ? (
                          <UndefinedValue />
                        ) : (
                          <AddedText>{pair.left}</AddedText>
                        )}
                      </TableCell>
                      <TableCell>
                        {pair.right === undefined ? (
                          <UndefinedValue />
                        ) : (
                          <AddedText>{pair.right}</AddedText>
                        )}
                      </TableCell>
                    </TableRow>
                  );
                })}
                {(showIdentical || templateVarsChanged) && (
                  <TableRow>
                    <DarkTableCell>Variables</DarkTableCell>
                    <TableCell colSpan={2}>
                      <FlexBox p={1} sx={{ justifyContent: 'center' }}>
                        <PromptVariableDiffView
                          templateVarsChanged={templateVarsChanged}
                          variantLeft={variantLeft}
                          variantRight={variantRight}
                        />
                      </FlexBox>
                    </TableCell>
                  </TableRow>
                )}
              </TableBody>
            </Table>
          </TableContainer>
        </Box>
      </Container>

      <VSpace s={2} />
      <Container>
        <FlexBox>
          <Typography variant="h3">Evaluate</Typography>
          <Select
            size="small"
            sx={{ width: 200 }}
            value={selectedEvaluationType}
            onChange={e => setSelectedEvaluationType(e.target.value as EvaluationType)}
          >
            {Object.values(EvaluationType)
              .filter(
                (evaluationType): evaluationType is Exclude<EvaluationType, 'AIFeedback'> =>
                  evaluationType !== EvaluationType.AIFeedback,
              )
              .map(method => (
                <MenuItem key={method} value={method}>
                  {EVALUATION_ACCURACY_DESCRIPTION_TEXT[method]}
                </MenuItem>
              ))}
          </Select>
          <Typography variant="h3">against</Typography>
          <Box sx={{ maxWidth: 400 }}>
            <DataSelectorAndUploader
              dataId={dataId}
              setDataId={nextDataId => setDataId(nextDataId ?? undefined)}
            />
          </Box>
          {evaluationLeft && evaluationRight && (
            <>
              <HSpace />
              <FlexBox sx={{ marginLeft: 'auto' }}>
                <DownloadComparisonResultsButton
                  evaluationIds={[evaluationLeft.id, evaluationRight.id]}
                />
              </FlexBox>
            </>
          )}
        </FlexBox>
        {dataId && (
          <TableContainer sx={{ borderRadius: 1 }}>
            <Table
              size="small"
              sx={{
                borderSpacing: 0,
                [`& .${tableCellClasses.body}, .${tableCellClasses.head}`]: {
                  border: `1px solid ${Colors.CoolGray20}`,
                },
              }}
            >
              <TableHead>
                <TableRow>
                  <DarkTableCell width="160">Evaluation</DarkTableCell>
                  <DarkTableCell>
                    <VariantName name={variantLeft.name} />
                  </DarkTableCell>
                  <DarkTableCell>
                    <VariantName name={variantRight.name} />
                  </DarkTableCell>
                </TableRow>
              </TableHead>
              <TableBody>{renderTableBody()}</TableBody>
            </Table>
          </TableContainer>
        )}
        <Box>
          <MultiInputPreview dataId={dataId} variants={[variantLeft, variantRight]} />
        </Box>
      </Container>
      <Box p={2} />
    </PageContainer>
  );
}

/**
 * Renders one table row per stat: a label cell followed by the value cells for
 * that row.
 *
 * @param stats Label elements, one per row.
 * @param cellsRows Value cells per row, indexed like `stats`. A row may be
 *   missing or contain `undefined` entries, in which case nothing is rendered
 *   in that position.
 */
function StatsTableBody({
  stats,
  cellsRows,
}: {
  stats: JSX.Element[];
  cellsRows: (JSX.Element | undefined)[][];
}): JSX.Element {
  const rows = stats.map((label, rowIndex) => (
    <TableRow key={rowIndex}>
      <DarkTableCell>{label}</DarkTableCell>
      {cellsRows[rowIndex]}
    </TableRow>
  ));

  return <>{rows}</>;
}

/**
 * Text followed by an info icon, with an explanatory tooltip on hover.
 *
 * @param text Text to display.
 * @param variant Typography variant used for the text; defaults to `body1`.
 *   Typed from the Typography props so an invalid variant is rejected at
 *   compile time instead of being hidden behind an `any` cast.
 * @param tooltip Tooltip content.
 */
function TooltipText({
  text,
  variant = 'body1',
  tooltip,
}: {
  text: string;
  variant?: TypographyProps['variant'];
  tooltip: string;
}): JSX.Element {
  return (
    <Tooltip title={tooltip} disableInteractive>
      <Typography variant={variant}>
        {text} <FontAwesomeIcon icon="circle-info" color={Colors.CoolGray40} />
      </Typography>
    </Tooltip>
  );
}

/**
 * Rows of the classification stats table, keyed by the label shown to the user.
 * Each entry carries the tooltip explaining the stat and a function producing
 * the displayed value from a completed classification evaluation. The insertion
 * order is the row order.
 */
const STAT_DEF_BY_NAME = {
  Accuracy: {
    tooltip: 'Number of correct predictions divided by number of total predictions.',
    calculate: (evaluation: CompletedClassificationEvaluation) => {
      return toPercent(evaluation.stats.microAccuracy, 2);
    },
  },
  'Macro Precision': {
    tooltip:
      'Number of true positive predicted divided by the number of true positive and false positive classes averaged by per-class',
    calculate: (evaluation: CompletedClassificationEvaluation) => {
      const macroPrecision = valuesMean(evaluation.stats.precisionByClass);
      return toPercent(macroPrecision, 2);
    },
  },
  'Macro Recall': {
    tooltip:
      'Number of true positive predicted divided by the number of true positive and false negative classes averaged by per-class',
    calculate: (evaluation: CompletedClassificationEvaluation) => {
      const macroRecall = valuesMean(evaluation.stats.recallByClass);
      return toPercent(macroRecall, 2);
    },
  },
  'Macro F1': {
    tooltip:
      'The average of each per-class F1 score. The per-class F1 score is the harmonic mean of precision and recall for that class.',
    calculate: (evaluation: CompletedClassificationEvaluation) => {
      const macroF1 = valuesMean(evaluation.stats.f1ScoreByClass);
      return toPercent(macroF1, 2);
    },
  },
};

// Number of stat rows; used as the row span of cells that cover a whole column.
const NUM_STAT_ROWS = Object.keys(STAT_DEF_BY_NAME).length;

/**
 * Builds the body of the classification stats table for two evaluations.
 *
 * When neither evaluation exists, a single cell spanning both value columns and
 * all stat rows holds the "run evaluation" button. Otherwise each side gets its
 * own column produced by `statsColumn`.
 */
function statsTable(evaluationLeft?: StoredEvaluation, evaluationRight?: StoredEvaluation) {
  const statLabels = _.map(STAT_DEF_BY_NAME, (statDef, statName) => (
    <TooltipText key={statName} text={statName} tooltip={statDef.tooltip} />
  ));

  if (evaluationLeft || evaluationRight) {
    const leftColumn = statsColumn(evaluationLeft);
    const rightColumn = statsColumn(evaluationRight);
    // Turn the two columns into rows of [leftCell, rightCell].
    return <StatsTableBody stats={statLabels} cellsRows={_.zip(leftColumn, rightColumn)} />;
  }

  // Only the first row carries a cell; its row span covers the remaining rows.
  const runEvaluationCell = (
    <TableCell key="eval-button" colSpan={2} rowSpan={NUM_STAT_ROWS}>
      <EvaluationButton />
    </TableCell>
  );
  return <StatsTableBody stats={statLabels} cellsRows={[[runEvaluationCell]]} />;
}

/**
 * Produces the cells of one value column of the stats table, one array slot per
 * stat row.
 *
 * - Missing or errored evaluation: one button cell spanning every row.
 * - Evaluation that is not completed: one "still processing" cell spanning
 *   every row.
 * - Completed classification evaluation: one value cell per stat.
 * - Completed evaluation of another type: no cells.
 *
 * Spanning cells sit in the first slot; the rest of the array is padded with
 * `undefined` up to the number of stats.
 */
function statsColumn(evaluation?: StoredEvaluation) {
  const statEntries = Object.entries(STAT_DEF_BY_NAME);
  const rowCount = statEntries.length;

  // TODO handle evaluation error
  if (!evaluation || evaluation.status === 'Errored') {
    const runEvaluationCell = (
      <TableCell key={`eval-button-${evaluation?.id}`} colSpan={1} rowSpan={NUM_STAT_ROWS}>
        <EvaluationButton />
      </TableCell>
    );
    return padArray([runEvaluationCell], undefined, rowCount);
  }

  if (!isCompletedEvaluation(evaluation)) {
    const processingCell = (
      <TableCell key={`eval-processing-${evaluation.id}`} colSpan={1} rowSpan={NUM_STAT_ROWS}>
        <Typography variant="body2">
          Evaluation is still processing.
          <br />
          Check the Evaluation page for more details.
        </Typography>
      </TableCell>
    );
    return padArray([processingCell], undefined, rowCount);
  }

  if (!isClassificationEvaluation(evaluation)) {
    return padArray([], undefined, rowCount);
  }

  return statEntries.map(([statName, statDef]) => (
    <TableCell key={`${statName}-${evaluation.id}`}>{statDef.calculate(evaluation)}</TableCell>
  ));
}

/**
 * Averages the values of a string-keyed record of numbers and rounds the
 * result via `round(..., 4)`.
 */
function valuesMean(object: Record<string, number>): number {
  const values = Object.values(object);
  const average = _.mean(values);
  return round(average, 4);
}

/** Type guard: true when the evaluation's status is `'Completed'`. */
function isCompletedEvaluation(evaluation: Evaluation): evaluation is CompletedEvaluation {
  const { status } = evaluation;
  return status === 'Completed';
}

/**
 * Type guard for a completed classification evaluation.
 *
 * Both the status and the type are checked, so the narrowed type is truthful
 * even when the caller has not already verified that the evaluation completed.
 */
function isClassificationEvaluation(
  evaluation: Evaluation,
): evaluation is CompletedClassificationEvaluation {
  return evaluation.status === 'Completed' && evaluation.type === 'ClassificationEvaluation';
}
