import GPT3Tokenizer from 'gpt3-tokenizer';
import _ from 'lodash';
import * as R from 'ramda';

import { Model } from '@prisma/client';
import { isOpenAiBaseModelInternalId } from '@scale/llm-shared/modelProviders/openAi';
import { isScaleLaunchBaseModelInternalId } from '@scale/llm-shared/modelProviders/scaleLaunch';
import { BaseModelInternalId } from '@scale/llm-shared/modelProviders/types';
import {
  FINANCIAL_EVENT_PRICE,
  SpellbookFinancialEventSubType,
} from '@scale/llm-shared/services/billing/billingTypes';
import { parseQueryVariables } from '@scale/llm-shared/templating';
import { render } from '@scale/llm-shared/templating/parser';
import { renderPromptWithQuery } from '@scale/llm-shared/templating/promptsV2';
import { DEFAULT_TOKENIZER_TYPE } from '@scale/llm-shared/types/defaults';

import { StoredDataset } from 'frontend/stores/types';
import { StoredData, StoredDataColumn, StoredVariant } from 'frontend/storesV2/types';
import { Cryb128Generator } from 'frontend/utils/random';

/**
 * Shared GPT-3 tokenizer instance, configured with `DEFAULT_TOKENIZER_TYPE`.
 * The cost estimators in this module use it to count tokens.
 */
export const tokenizer = new GPT3Tokenizer({ type: DEFAULT_TOKENIZER_TYPE });

/**
 * Renders `prompt` using only the template variables that `query` actually
 * supplies.
 *
 * The variables referenced by the prompt are obtained from
 * `parseQueryVariables`; each one that `query` defines is copied into the
 * render context. Variables the prompt references but `query` does not define
 * are left out of the context (how they are rendered is up to `render`).
 *
 * @param prompt - Template text to render.
 * @param query - Variable values keyed by variable name.
 * @returns The result of `render` for the prompt and the filtered values.
 */
export function finetuneRenderPromptWithQuery(prompt: string, query: Record<string, string>) {
  const boundVars: Record<string, string> = {};
  for (const varName of parseQueryVariables(prompt)) {
    // Own-property check: the `in` operator also matches inherited members such
    // as `constructor` or `toString`, which would leak a function from
    // Object.prototype into the render context.
    if (Object.prototype.hasOwnProperty.call(query, varName)) {
      boundVars[varName] = query[varName];
    }
  }
  return render(prompt, boundVars);
}

/**
 * Estimates the dollar cost of fine-tuning `model` on `dataset`.
 *
 * Tokens are counted on a seeded pseudo-random sample of the rows
 * (`sampleSize(rows, 100)`), then scaled by `rows.length / sampleRows.length`
 * to cover the whole dataset. Each sampled row contributes the tokens of the
 * prompt rendered with that row's values plus the tokens of the row's
 * output-column value.
 *
 * @param dataset - Dataset whose rows are used for training.
 * @param model - Base model to fine-tune; also used as the price lookup key.
 * @param prompt - Prompt template rendered against each row.
 * @param outputColumn - Column holding the expected completion.
 * @returns The estimated cost in dollars; `0` when the dataset has no rows;
 *   `undefined` when the dataset or output column is missing, when no
 *   fine-tuning price is configured for the model, or when the model is
 *   neither an OpenAI nor a Scale Launch base model.
 */
export function fineTuningDollarsCostEstimateV2(
  dataset: StoredData | undefined,
  model: BaseModelInternalId,
  prompt: string,
  outputColumn: StoredDataColumn | undefined,
): number | undefined {
  if (!dataset || !outputColumn) {
    return;
  }
  const dollarsPerMillionTokens =
    FINANCIAL_EVENT_PRICE.FineTuningModelTraining?.[model as SpellbookFinancialEventSubType];
  if (!dollarsPerMillionTokens) {
    return;
  }

  // Only Open AI and Scale models are supported; both are estimated the same way.
  // TODO use T5 to tokenize for Scale models
  if (!isOpenAiBaseModelInternalId(model) && !isScaleLaunchBaseModelInternalId(model)) {
    return;
  }

  // Deterministic calculations: fixed seed so repeated estimates agree.
  const pseudoRandom = new Cryb128Generator(1001);
  const rows = Object.values(dataset.rowByIndex);
  const sampleRows = pseudoRandom.sampleSize(rows, 100);
  if (sampleRows.length === 0) {
    // Nothing to train on. Without this guard the extrapolation ratio below
    // would be 0 / 0 and the estimate would come out as NaN.
    return 0;
  }

  // Output column should only be a list of length 1
  const outputKey = outputColumn.name[0];

  const sampleTokenCount = sampleRows.reduce((count, row) => {
    const promptTokens = tokenizer.encode(finetuneRenderPromptWithQuery(prompt, row.valueByName))
      .text.length;
    const outputValue = row.valueByName[outputKey];
    const outputTokens = outputValue ? tokenizer.encode(outputValue).text.length : 0;
    return promptTokens + outputTokens + count;
  }, 0);

  // Scale the sampled token count up to the full dataset, then price it.
  return (
    (sampleTokenCount * dollarsPerMillionTokens * (rows.length / sampleRows.length)) / 1_000_000
  );
}

/**
 * Estimates the total dollar cost of running every variant over `data`.
 *
 * Tokens are counted on a seeded pseudo-random sample of the rows
 * (`sampleSize(rows, 100)`) and the resulting cost is scaled by
 * `rows.length / sampleRows.length`. For each variant, every sampled row
 * contributes the tokens of the variant's prompt rendered with the row's
 * values plus the tokens of the row's `outputColName` value.
 *
 * `GPT4` and `GPT4_32K` are priced separately for prompt and completion tokens
 * (price keys `<baseModelId>_Prompt` / `<baseModelId>_Completion`); every other
 * base model uses a single price keyed by its base model id. A variant whose
 * price is not configured contributes 0.
 *
 * @param data - Dataset to evaluate against.
 * @param outputColName - Name of the column whose value stands in for the completion.
 * @param variants - Variants to evaluate.
 * @param modelById - Loaded models keyed by model id.
 * @returns The summed estimate in dollars (0 when `data` has no rows), or
 *   `undefined` when any variant's model is not present in `modelById`.
 */
export function evaluationDollarsCostEstimate(
  data: StoredData,
  outputColName: string,
  variants: StoredVariant[],
  modelById: Record<string, Model>,
): number | undefined {
  // Deterministic calculations
  const pseudoRandom = new Cryb128Generator(1001);
  const rows = Object.values(data.rowByIndex);
  const sampleRows = pseudoRandom.sampleSize(rows, 100);
  // With no sampled rows the ratio would be 0 / 0 (NaN) and poison the sum;
  // an empty dataset costs nothing, so scale by 0 instead.
  const inverseSampleRatio = sampleRows.length > 0 ? rows.length / sampleRows.length : 0;

  const models = variants.map(variant => {
    const modelId = variant.modelParameters.modelId;
    const model: Model | undefined = modelById[modelId];
    return model;
  });
  if (models.some(m => m == null)) {
    // Model not loaded, will eventually load
    return;
  }

  const dollarsPerVariant = R.zip(variants, models).map(([variant, model]) => {
    const baseModelId = model.baseModelId as BaseModelInternalId;
    switch (baseModelId) {
      case 'GPT4':
      case 'GPT4_32K': {
        const promptDollarsPerMTokens =
          FINANCIAL_EVENT_PRICE.BaseModelQuerying?.[
            (baseModelId + '_Prompt') as SpellbookFinancialEventSubType
          ];
        const completionDollarsPerMTokens =
          FINANCIAL_EVENT_PRICE.BaseModelQuerying?.[
            (baseModelId + '_Completion') as SpellbookFinancialEventSubType
          ];
        // No set price, assume 0. Shouldn't happen
        if (!promptDollarsPerMTokens || !completionDollarsPerMTokens) {
          return 0;
        }

        const sampleDollarsCost = sampleRows.reduce((runningDollars, row) => {
          const promptDollarsCost =
            (countTokens(renderPromptWithQuery(variant.prompt, row.valueByName)) *
              promptDollarsPerMTokens) /
            1_000_000;
          const completionDollarsCost =
            (countTokens(row.valueByName[outputColName]) * completionDollarsPerMTokens) / 1_000_000;
          return promptDollarsCost + completionDollarsCost + runningDollars;
        }, 0);

        // Extrapolate from the sample to the full dataset.
        return sampleDollarsCost * inverseSampleRatio;
      }
      default: {
        const dollarsPerMTokens =
          FINANCIAL_EVENT_PRICE.BaseModelQuerying?.[baseModelId as SpellbookFinancialEventSubType];
        // No set price, assume 0. Shouldn't happen
        if (!dollarsPerMTokens) {
          return 0;
        }

        const sampleTokenCount = sampleRows.reduce((count, row) => {
          const promptTokens = countTokens(renderPromptWithQuery(variant.prompt, row.valueByName));
          return promptTokens + countTokens(row.valueByName[outputColName]) + count;
        }, 0);
        const sampleDollarsCost = (sampleTokenCount * dollarsPerMTokens) / 1_000_000;

        // Extrapolate from the sample to the full dataset.
        return sampleDollarsCost * inverseSampleRatio;
      }
    }
  });
  return _.sum(dollarsPerVariant);
}

/**
 * Counts the tokens the shared `tokenizer` produces for `str`.
 * A missing or empty string counts as 0 tokens without invoking the tokenizer.
 */
function countTokens(str?: string): number {
  if (!str) {
    return 0;
  }
  const encoded = tokenizer.encode(str);
  return encoded.text.length;
}
