import _ from 'lodash';

// Default prompts for all PromptPage versions.
// These live here rather than in shared/ so that editing a prompt does not
// require rebuilding the shared folder.
import { ModelHost, TaskType } from '@prisma/client';
import { ModelParametersCreate } from '@scale/llm-shared/interfaces/modelParameters';
import { Prompt, PromptCreate } from '@scale/llm-shared/interfaces/prompt';
import { VariantCreate } from '@scale/llm-shared/interfaces/variant';
import { DEFAULT_MODEL_HOST, DEFAULT_MODEL_ID } from 'frontend/consts/defaults';
import { neverfail } from '@scale/llm-shared/utils/type';

import { insertBefore } from 'frontend/utils/string';

// Line labels used by the classification prompt ("Input: ...", "Class: ...").
export const CLASSIFICATION_INPUT = 'Input';
export const CLASSIFICATION_OUTPUT = 'Class';
export const CLASSIFICATION_INPUT_COL = 'input';

/**
 * Zero-shot classification prompt. It describes the expected format, then ends with
 * the live input (a `{{ input }}` template variable) and an open class label line.
 */
export const DEFAULT_CLASSIFICATION_PROMPT = [
  'Classify each input.',
  '',
  'Use this format:',
  '',
  `${CLASSIFICATION_INPUT}: <input>`,
  `${CLASSIFICATION_OUTPUT}: <class label>`,
  '',
  'Begin:',
  '',
  `${CLASSIFICATION_INPUT}: {{ ${CLASSIFICATION_INPUT.toLowerCase()} }}`,
  `${CLASSIFICATION_OUTPUT}:`,
].join('\n');

// Line labels used by the autocompletion prompt ("Input: ...", "Completion: ...").
export const AUTOCOMPLETION_INPUT = 'Input';
export const AUTOCOMPLETION_OUTPUT = 'Completion';

/**
 * Zero-shot autocompletion prompt. It asks for a continuation of the input, then ends
 * with the live `{{ input }}` line and an open completion line.
 */
export const DEFAULT_AUTOCOMPLETION_PROMPT = [
  'Output a continuation of the given input text without repeating it.',
  '',
  'Use this format:',
  '',
  `${AUTOCOMPLETION_INPUT}: <input>`,
  `${AUTOCOMPLETION_OUTPUT}: <continuation of input>`,
  '',
  'Begin:',
  '',
  `${AUTOCOMPLETION_INPUT}: {{ ${AUTOCOMPLETION_INPUT.toLowerCase()} }}`,
  `${AUTOCOMPLETION_OUTPUT}:`,
].join('\n');

// Line labels used by the summarization prompt ("Text: ...", "Summary: ...").
export const SUMMARIZATION_INPUT = 'Text';
export const SUMMARIZATION_OUTPUT = 'Summary';

/**
 * Zero-shot summarization prompt. It asks for a one-sentence summary, then ends with
 * the live `{{ text }}` line and an open summary line.
 */
export const DEFAULT_SUMMARIZATION_PROMPT = [
  'Summarize the input text.',
  '',
  'Use this format:',
  '',
  `${SUMMARIZATION_INPUT}: <text>`,
  `${SUMMARIZATION_OUTPUT}: <one-sentence summary>`,
  '',
  'Begin:',
  '',
  `${SUMMARIZATION_INPUT}: {{ ${SUMMARIZATION_INPUT.toLowerCase()} }}`,
  `${SUMMARIZATION_OUTPUT}:`,
].join('\n');

// Line labels used by the generation prompt ("Input: ...", "Output: ...").
export const GENERATION_INPUT = 'Input';
export const GENERATION_OUTPUT = 'Output';

/**
 * Zero-shot generation prompt. It describes the input/output format, then ends with
 * the live `{{ input }}` line and an open output line.
 */
export const DEFAULT_GENERATION_PROMPT = [
  'Generate text for each input.',
  '',
  'Use this format:',
  '',
  `${GENERATION_INPUT}: <input text>`,
  `${GENERATION_OUTPUT}: <output text>`,
  '',
  'Begin:',
  '',
  `${GENERATION_INPUT}: {{ ${GENERATION_INPUT.toLowerCase()} }}`,
  `${GENERATION_OUTPUT}:`,
].join('\n');

// Line labels used by the extraction prompt ("Text: ...", "Entities: ...").
export const EXTRACTION_INPUT = 'Text';
export const EXTRACTION_OUTPUT = 'Entities';

/**
 * Zero-shot named-entity extraction prompt. It asks for a JSON list of quoted strings,
 * then ends with the live `{{ text }}` line and an open entities line.
 */
export const DEFAULT_EXTRACTION_PROMPT = [
  'Extract all named entities from the text.',
  '',
  'Use this format:',
  '',
  `${EXTRACTION_INPUT}: <text>`,
  `${EXTRACTION_OUTPUT}: <JSON list of quoted strings>`,
  '',
  'Begin:',
  '',
  `${EXTRACTION_INPUT}: {{ ${EXTRACTION_INPUT.toLowerCase()} }}`,
  `${EXTRACTION_OUTPUT}:`,
].join('\n');

// TODO: when possible, use column names from dataset
/**
 * Builds the few-shot section of a prompt template.
 *
 * Emits `numShots` examples. Each example has one `Label: {{ examples[i].col }}` line
 * per input column, followed by the output column's line. Examples are separated by a
 * blank line, and the whole section ends with a newline. Labels are the column names
 * run through `_.capitalize` (first letter upper-cased, the rest lower-cased).
 *
 * @param numShots Number of examples to emit. Non-positive (or NaN) values yield ''.
 * @param inputColNames Columns rendered as input lines; may be empty or absent.
 * @param outputColName Column rendered as the output line; may be absent or ''.
 * @returns The few-shot section, or '' when there is nothing to render.
 */
export function generateFewShotStr(
  numShots: number,
  inputColNames: Maybe<string[]>,
  outputColName: Maybe<string>,
): string {
  const inputCols = inputColNames ?? [];

  // Degenerate requests previously produced stray newlines (or, for negative counts,
  // `examples[-1]` references, since `_.range(-n)` counts downward). Return nothing instead.
  // `!(numShots > 0)` also rejects NaN.
  if (!(numShots > 0) || (inputCols.length === 0 && !outputColName)) {
    return '';
  }

  const renderShot = (i: number): string => {
    const lines = inputCols.map(col => `${_.capitalize(col)}: {{ examples[${i}].${col} }}`);
    if (outputColName) {
      lines.push(`${_.capitalize(outputColName)}: {{ examples[${i}].${outputColName} }}`);
    }
    return lines.join('\n');
  };

  return _.range(numShots).map(renderShot).join('\n\n') + '\n';
}

/**
 * Default model settings for a new prompt.
 *
 * Summarization and Generation get `temperature: 1` and `maxTokens: 60`. Every other
 * task type, including an omitted `taskType`, gets `temperature: 0` and `maxTokens: 10`.
 * Both branches use `DEFAULT_MODEL_ID` / `DEFAULT_MODEL_HOST`.
 * NOTE(review): 10 tokens may be tight for Extraction's JSON-list output. Confirm intent.
 */
export function createDefaultModelParameters(taskType?: TaskType): Partial<ModelParametersCreate> {
  const isOpenEnded = taskType === TaskType.Summarization || taskType === TaskType.Generation;
  return {
    modelId: DEFAULT_MODEL_ID,
    modelType: DEFAULT_MODEL_HOST,
    temperature: isOpenEnded ? 1 : 0,
    maxTokens: isOpenEnded ? 60 : 10,
  };
}

/**
 * The default zero-shot prompt template for each task type.
 *
 * Annotated as `Readonly<Record<…>>` so that writes to this shared lookup are rejected
 * at compile time. The previous trailing `as const` had no effect, because the explicit
 * `Record<TaskType, string>` annotation determined the binding's (mutable) type.
 */
export const TEMPLATE_BY_TASK_TYPE: Readonly<Record<TaskType, string>> = {
  [TaskType.Autocompletion]: DEFAULT_AUTOCOMPLETION_PROMPT,
  [TaskType.Summarization]: DEFAULT_SUMMARIZATION_PROMPT,
  [TaskType.Classification]: DEFAULT_CLASSIFICATION_PROMPT,
  [TaskType.Generation]: DEFAULT_GENERATION_PROMPT,
  [TaskType.Extraction]: DEFAULT_EXTRACTION_PROMPT,
};

/**
 * Builds the starting fields for a new prompt.
 *
 * The template is chosen from `TEMPLATE_BY_TASK_TYPE` when `taskType` is given, and
 * otherwise falls back to the classification prompt. `exampleVariables` starts empty,
 * and `variablesSourceDataId` is passed through unchanged.
 */
export function createDefaultPrompt(
  options: {
    variablesSourceDataId?: string;
    taskType?: TaskType;
  } = {},
): Partial<PromptCreate> {
  const { variablesSourceDataId, taskType } = options;

  let template = DEFAULT_CLASSIFICATION_PROMPT;
  if (taskType) {
    template = TEMPLATE_BY_TASK_TYPE[taskType];
  }

  return { template, exampleVariables: {}, variablesSourceDataId };
}

/**
 * Builds the starting fields for a new variant.
 *
 * Uses `name` when it is non-empty, and 'New Variant' otherwise. `||` is used, so an
 * empty string also falls back to the default.
 */
export function createDefaultVariant(options: { name?: string } = {}): Partial<VariantCreate> {
  const fallbackName = 'New Variant';
  return { name: options.name || fallbackName };
}

/**
 * Returns the `[inputLabel, outputLabel]` pair that the default prompt for `taskType`
 * uses on its `Label: ...` lines.
 */
function getDefaultPromptLabels(taskType: TaskType): [string, string] {
  switch (taskType) {
    case TaskType.Autocompletion:
      return [AUTOCOMPLETION_INPUT, AUTOCOMPLETION_OUTPUT];
    case TaskType.Generation:
      return [GENERATION_INPUT, GENERATION_OUTPUT];
    case TaskType.Summarization:
      return [SUMMARIZATION_INPUT, SUMMARIZATION_OUTPUT];
    case TaskType.Extraction:
      return [EXTRACTION_INPUT, EXTRACTION_OUTPUT];
    case TaskType.Classification:
      return [CLASSIFICATION_INPUT, CLASSIFICATION_OUTPUT];
    default:
      // Compile-time exhaustiveness check over TaskType.
      neverfail(taskType, `Unknown taskType ${taskType}`);
  }
}

/**
 * Inserts few-shot examples (for use with a dataset) into a zero-shot prompt template.
 *
 * The few-shot section from `generateFewShotStr` is inserted before the template's live
 * input line, e.g. `Input: {{ input }}`. That line is built from the task type's default
 * input label.
 *
 * @param template The prompt template string (`Maybe<string>`), passed straight to `insertBefore`.
 * @param taskType The task type of the app. Selects the input/output labels.
 * @param numShots Number of few-shot examples to insert (defaults to 3).
 * @returns The result of `insertBefore` with the few-shot section inserted.
 */
export function insertTemplateVars(
  template: Maybe<string>,
  taskType: TaskType,
  numShots = 3,
): string {
  const [inputLabel, outputLabel] = getDefaultPromptLabels(taskType);
  // Must match the live input line exactly as written in the default templates.
  const beforeStr = `${inputLabel}: {{ ${inputLabel.toLowerCase()} }}`;
  const insertStr = generateFewShotStr(numShots, [inputLabel], outputLabel);
  return insertBefore(template, beforeStr, insertStr);
}
