import * as R from 'ramda';

import { BaseModelInternalId } from '@scale/llm-shared/modelProviders/types';
import { GenericFineTuningPromptV2 } from '@scale/llm-shared/templating/genericPrompts';
import { ScaleFinetuningMethod } from '@scale/llm-shared/types/requests';

import { useSelectionStore } from 'frontend/storesV2/SelectionStore';
import { StoredDataColumn } from 'frontend/storesV2/types';
import { createFromStates, state } from 'frontend/utils/store';

/**
 * Tabs of the fine-tuning page. Each value is the string used for both the
 * `text` and the `value` of the matching entry in `TABS`.
 */
export const Tab = {
  Create: 'Create',
  View: 'View Fine-Tuned Models',
} as const;

/** Union of the string values declared in `Tab`. */
export type Tab = (typeof Tab)[keyof typeof Tab];

/** One `{ text, value }` entry per tab, in the declaration order of `Tab`. */
export const TABS = Object.values(Tab).map((tab) => ({ text: tab, value: tab }));

/**
 * Declarations for every piece of fine-tuning page state, each created with
 * `state(key, initialValue)`.
 *
 * The `as` assertions give each initial value its wide type (e.g. `number`,
 * `Tab`, `string | undefined`) — presumably so the value type inferred by
 * `state` is not narrowed to the literal; confirm against `state`'s signature.
 *
 * NOTE: `resetFineTuningPageState` repeats these initial values; keep the two
 * in sync when changing a default.
 */
const states = [
  state('selectedTab', Tab.Create as Tab),
  state('includePrompt', false as boolean),

  state('numEpochs', 5 as number),
  state('modelName', '' as string),
  state('model', BaseModelInternalId.GPT3TextDavinci001 as BaseModelInternalId),
  // No assertion here: the value type is whatever `GenericFineTuningPromptV2` is declared as.
  state('prompt', GenericFineTuningPromptV2),
  state('learningRateModifier', 1 as number),
  // Dataset ids start out unset.
  state('trainDatasetId', undefined as string | undefined),
  state('validationDatasetId', undefined as string | undefined),

  state('outputColumn', undefined as StoredDataColumn | undefined),
  state('stopSequence', '' as string),

  // Exclusively for fine-tuning our own models
  state('scaleFinetuningMethod', ScaleFinetuningMethod.IA3 as ScaleFinetuningMethod),
];

/** Union of the return types of the individual entries of `states`. */
type StateSlice = ReturnType<(typeof states)[number]>;

/** Full page-state shape: all slices from `states` combined via `UnionToIntersection`. */
type Store = UnionToIntersection<StateSlice>;

/** Store for the fine-tuning page, created from the `states` declarations. */
export const useFineTuningPageState = createFromStates<Store>(states);

/**
 * Restores every fine-tuning page field to its initial value by calling the
 * corresponding setter on the given store state.
 *
 * The values written here mirror the initial values declared in `states`;
 * keep both in sync when a default changes.
 *
 * @param store - Store state exposing one `set<Key>` setter per field. Named
 *   `store` rather than `state` so it does not shadow the `state` helper
 *   imported from `frontend/utils/store`.
 */
function resetFineTuningPageState(store: Store): void {
  const {
    setSelectedTab,
    setIncludePrompt,
    setNumEpochs,
    setModel,
    setModelName,
    setPrompt,
    setScaleFinetuningMethod,
    setLearningRateModifier,
    setTrainDatasetId,
    setValidationDatasetId,
    setOutputColumn,
    setStopSequence,
  } = store;
  setSelectedTab(Tab.Create);
  setIncludePrompt(false);
  setNumEpochs(5);
  setModel(BaseModelInternalId.GPT3TextDavinci001);
  setModelName('');
  setPrompt(GenericFineTuningPromptV2);
  setScaleFinetuningMethod(ScaleFinetuningMethod.IA3);
  setLearningRateModifier(1);
  setTrainDatasetId(undefined);
  setValidationDatasetId(undefined);
  setOutputColumn(undefined);
  setStopSequence('');
}

/**
 * Listener that resets the fine-tuning page state using the store's current state.
 */
const resetPageStateOnAppChange = (): void => {
  const pageState = useFineTuningPageState.getState();
  resetFineTuningPageState(pageState);
};

// When selectedApp changes in the selection store, reset the page state.
useSelectionStore.subscribe(R.prop('selectedApp'), resetPageStateOnAppChange);
