import { useCallback, useEffect, useMemo, useState } from 'react';

import { capitalize, head } from 'lodash';
import * as R from 'ramda';

import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import { Box, Button, MenuItem, Select, Tooltip, Typography } from '@mui/material';
import { Model } from '@prisma/client';
import { FrontendFeatureFlag } from '@scale/llm-shared/featureFlags';
import { BaseModelInternalId } from '@scale/llm-shared/modelProviders/types';
import { isBaseModelId } from '@scale/llm-shared/modelProviders/utils';
import { OpenAiFineTunedModel } from '@scale/llm-shared/types/defaults';

import { FEATURE_FLAG } from 'frontend/consts/featureFlags';
import { useUnleashFlags } from 'frontend/hooks/useUnleashFlags';
import { useModelStore } from 'frontend/stores/ModelStore';
import { Colors } from 'frontend/theme';

/**
 * Builds the display name of a fine-tunable model from its `fineTuneEndpoint`.
 *
 * The endpoint is capitalized (e.g. "davinci" -> "Davinci"). When the first
 * whitespace-separated word of `model.name` is "GPT-3", that word is kept as a
 * prefix, so a "GPT-3 ..." model with endpoint "davinci" becomes "GPT-3 Davinci".
 *
 * This relies on the current model naming format and may need updating for
 * future model families.
 */
function formatFineTuneModelName(model: Model): string {
  // lodash `capitalize` maps null/undefined to '', so `?? ''` preserves that behaviour explicitly.
  const endpointName = capitalize(model.fineTuneEndpoint ?? '');
  // A missing name simply means "no prefix" instead of throwing on `.trim()`.
  const modelType = head((model.name ?? '').trim().split(' '));

  if (modelType === 'GPT-3') {
    return `${modelType} ${endpointName}`;
  }
  return endpointName;
}

/**
 * Sort key for the model list: models with a truthy `order` sort by that value
 * first, then fine-tuned models (`id !== baseModelId`), then all remaining models.
 * Lives at module scope because it closes over nothing from the component.
 * NOTE(review): an `order` of 0 is falsy and is treated as "no explicit order" — confirm intended.
 */
function modelOrder(m: Model) {
  if (m.order) {
    return m.order;
  }
  if (m.id !== m.baseModelId) {
    return Number.MAX_SAFE_INTEGER - 1;
  }
  return Number.MAX_SAFE_INTEGER;
}

/**
 * Dropdown for choosing a model from the model store.
 *
 * By default the menu lists only featured entries: models flagged `featured`, the
 * currently selected model and fine-tuned models. A "Show all models" button at the
 * bottom of the menu expands it to the full filtered list.
 *
 * If `selectedModelId` is not in the filtered list, the first model of that list is
 * selected automatically through `onChange`.
 *
 * @param selectedModelId - Id of the currently selected model (controlled).
 * @param baseModelsOnly - When true, only models whose id passes `isBaseModelId` are listed.
 * @param onChange - Called with the id of the chosen model.
 * @param hiddenModelIds - Model ids to exclude from the list.
 * @param label - Forwarded to the MUI `Select` (as are `labelId`, `fullWidth`,
 *   `displayEmpty` and `size`).
 */
export default function ModelSelector({
  selectedModelId,
  baseModelsOnly,
  onChange,
  // Select props
  label,
  labelId,
  hiddenModelIds,
  fullWidth = true,
  displayEmpty = true,
  size = 'small',
}: {
  selectedModelId?: string;
  baseModelsOnly?: boolean;
  onChange: (value: BaseModelInternalId) => void;
  // Select props
  label?: string;
  labelId?: string;
  hiddenModelIds?: string[];
  fullWidth?: boolean;
  displayEmpty?: boolean;
  size?: 'small' | 'medium';
}) {
  const { modelById } = useModelStore(R.pick(['modelById']));
  const { checkUnleashFlag } = useUnleashFlags();
  const showScaleOpenAiFineTunedModels = !checkUnleashFlag(
    FrontendFeatureFlag.AllowCustomOpenAiKey,
  );

  // Hides models whose baseModelId is a key of OpenAiFineTunedModel when
  // showScaleOpenAiFineTunedModels is false; every other model passes.
  const filterByOpenAiFineTunedModels = useCallback(
    (model: Model) => {
      if (!model.baseModelId) {
        return true;
      }
      if (!showScaleOpenAiFineTunedModels) {
        const isPrebakedFineTunedModel = model.baseModelId in OpenAiFineTunedModel;
        return !isPrebakedFineTunedModel;
      }
      return true;
    },
    [showScaleOpenAiFineTunedModels],
  );

  // NOTE(review): this mutates shared module-level state (FEATURE_FLAG) during render.
  // It is idempotent for a given flag value; the set is presumably consumed by
  // `FEATURE_FLAG.filterFineTuningModels.omit`, which is why the flag value is a
  // dependency of the `models` memo below. Consider deriving this set locally instead.
  const allowScaleModelFinetuning = checkUnleashFlag(FrontendFeatureFlag.AllowScaleModelFinetuning);
  if (allowScaleModelFinetuning) {
    FEATURE_FLAG.filterFineTuningModels.filteredIds.add('Flan_XXL');
  } else {
    FEATURE_FLAG.filterFineTuningModels.filteredIds.delete('Flan_XXL');
  }

  const manuallyFilterModels = useCallback(
    (model: Model) => !hiddenModelIds || !hiddenModelIds.includes(model.id),
    [hiddenModelIds],
  );

  // Filtered models, sorted newest-first by `createdAt` and then (stable sort) by
  // `modelOrder`, so models sharing a sort key stay newest-first.
  const models = useMemo(() => {
    const modelByIdReducer = baseModelsOnly
      ? R.pipe(
          R.pickBy<Model>(
            m => isBaseModelId(m.id) && filterByOpenAiFineTunedModels(m) && manuallyFilterModels(m),
          ),
          FEATURE_FLAG.filterFineTuningModels.omit,
          R.values,
          R.sortBy<Model>(R.prop('createdAt')),
          R.reverse<Model>,
          R.sortBy(modelOrder),
        )
      : R.pipe(
          R.pickBy<Model>(m => filterByOpenAiFineTunedModels(m) && manuallyFilterModels(m)),
          FEATURE_FLAG.hideModels.omit,
          R.values,
          R.sortBy<Model>(R.prop('createdAt')),
          R.reverse<Model>,
          R.sortBy(modelOrder),
        );
    return modelByIdReducer(modelById);
  }, [
    modelById,
    baseModelsOnly,
    filterByOpenAiFineTunedModels,
    manuallyFilterModels,
    // Not read directly, but it controls the FEATURE_FLAG set mutated above.
    allowScaleModelFinetuning,
  ]);

  const featuredModels = useMemo(
    () => models.filter(m => m.featured || m.id === selectedModelId || m.id !== m.baseModelId),
    [models, selectedModelId],
  );

  // Fall back to the first listed model when the current selection is not available.
  // `models` is a dependency so this also runs once the store/flags populate the list.
  // `onChange` is intentionally omitted so an unstable callback identity does not
  // re-trigger the fallback on every render.
  useEffect(() => {
    if (models.some(m => m.id === selectedModelId)) {
      return;
    }
    const firstModelId = models[0]?.id as BaseModelInternalId | undefined;
    if (!firstModelId) {
      return;
    }
    console.warn(`Model ${selectedModelId} not among selection. Selecting ${firstModelId}`);
    onChange(firstModelId);
  }, [models, selectedModelId]);

  const [showAll, setShowAll] = useState<boolean>(false);

  return (
    <Select
      label={label}
      labelId={labelId}
      fullWidth={fullWidth}
      displayEmpty={displayEmpty}
      size={size}
      // '' keeps the Select controlled while no model is selected yet.
      value={selectedModelId ?? ''}
      onChange={e => onChange(e.target.value as BaseModelInternalId)}
    >
      {(showAll ? models : featuredModels).map(model => (
        <MenuItem key={model.id} value={model.id}>
          {model.id !== model.baseModelId && (
            <Tooltip title="This model is a fine-tuned model.">
              <FontAwesomeIcon
                icon="gears"
                color={Colors.CoolGray50}
                style={{ marginRight: '1ex' }}
              />
            </Tooltip>
          )}
          {/* If fine-tunable model, display finetune endpoint name rather than model name */}
          {model.fineTuneEndpoint ? formatFineTuneModelName(model) : model.name}
        </MenuItem>
      ))}
      {!showAll && (
        <Box px={1.5}>
          <Button size="small" onClick={() => setShowAll(true)}>
            Show all models
          </Button>
        </Box>
      )}
    </Select>
  );
}
