import _ from 'lodash';
import * as R from 'ramda';
import create, { StoreApi } from 'zustand';
import { subscribeWithSelector } from 'zustand/middleware';

import { DataRow } from '@scale/llm-shared/interfaces/data';
import { ModelParameters, ModelParametersRun } from '@scale/llm-shared/interfaces/modelParameters';
import { Prompt, PromptRun } from '@scale/llm-shared/interfaces/prompt';
import { Variant, VariantCreate, VariantRun } from '@scale/llm-shared/interfaces/variant';
import { OpenAiMultiTurnBaseModelInternalId } from '@scale/llm-shared/modelProviders/openAi';
import { TokenProbs } from '@scale/llm-shared/types';

import { client } from 'frontend/api/trpc';
import { useDataStore } from 'frontend/storesV2/DataStore';
import {
  StoredData,
  StoredModelParameters,
  StoredPrompt,
  StoredVariant,
} from 'frontend/storesV2/types';
import { PairedRowStatus } from 'frontend/types/types';
import { track } from 'frontend/utils/analytics';
import { singleIndexById } from 'frontend/utils/object';

/** One per-input result produced by `VariantStore.run`. */
type VariantRunOutput = {
  output: string;
  tokenProbs: TokenProbs | null;
  finishReason: string | null;
  errorReason: string | null;
  status: PairedRowStatus;
};

/** The parallel-array fields of a `v2.variant.run` response that `run` reads. */
type VariantRunResponse = {
  outputs: ReadonlyArray<string>;
  tokenProbs: ReadonlyArray<TokenProbs | null>;
  finishReasons: ReadonlyArray<string | null>;
  errorReasons: ReadonlyArray<string | null>;
};

/**
 * Zustand-backed store for prompt variants.
 *
 * Holds fully hydrated variants (variant + prompt + model parameters + optional
 * source data) keyed by id, and the promises of variant loads currently pending.
 */
class VariantStore {
  /** Hydrated variants, keyed by variant id. */
  variantById: Record<string, StoredVariant> = {};
  /** Pending variant loads, keyed by variant id. `getVariant` uses this to share one request. */
  variantPromiseById: Record<string, Promise<StoredVariant>> = {};

  constructor(
    public set: StoreApi<VariantStore>['setState'],
    public get: StoreApi<VariantStore>['getState'],
  ) {}

  /** Replaces the whole `variantById` map. */
  setVariantById = (byId: Record<string, StoredVariant>) => this.set(R.set(indexLens, byId));
  /** Replaces the whole `variantPromiseById` map. */
  setVariantPromiseById = (byId: Record<string, Promise<StoredVariant>>) =>
    this.set(R.set(promiseIndexLens, byId));
  /** Merges `byId` into `variantById`; entries in `byId` win on key collisions. */
  appendVariantById = (byId: Record<string, StoredVariant>) =>
    this.set(R.over(indexLens, R.mergeLeft(byId)));
  /** Merges `byId` into `variantPromiseById`; entries in `byId` win on key collisions. */
  appendVariantPromiseById = (byId: Record<string, Promise<StoredVariant>>) =>
    this.set(R.over(promiseIndexLens, R.mergeLeft(byId)));

  /**
   * Creates a variant on the server, hydrates it, caches it in `variantById`,
   * and then emits a 'Prompt Variant Created' analytics event.
   */
  createVariant = (variantCreate: VariantCreate) =>
    client
      // Create Variant
      .mutation('v2.variant.create', variantCreate)
      // Get Prompt and ModelParameters
      .then(VariantStore.hydrateVariant)
      // Construct StoredVariant
      .then(VariantStore.constructStoredVariant)
      .then(
        R.tap(
          R.pipe(
            // Prepare to be stored
            singleIndexById,
            // Store in variantById
            this.appendVariantById,
          ),
        ),
      )
      .then(() => {
        track('Prompt Variant Created', { name: variantCreate.name });
      });

  /**
   * Returns the hydrated variant with `id`.
   *
   * Resolution order: the `variantById` cache, then a load already pending for
   * the same id, and only then a new server request. A new request is registered
   * in `variantPromiseById` while it is pending, so concurrent callers share it.
   * The entry is removed once the request settles, which lets a failed load be
   * retried by a later call.
   */
  getVariant = (id: string): Promise<StoredVariant> => {
    const { variantById, variantPromiseById } = this.get();
    // Get variant from cache
    const variant = variantById[id];
    if (variant) return Promise.resolve(variant);
    // Get variant from loading
    const variantPromise = variantPromiseById[id];
    if (variantPromise) return variantPromise;

    const request = client
      .query('v2.variant.findUnique', { id })
      // Get Prompt and ModelParameters
      .then(VariantStore.hydrateVariant)
      // Construct StoredVariant
      .then(VariantStore.constructStoredVariant)
      // Store StoredVariant in variantById
      .then(R.tap(R.pipe(singleIndexById, this.appendVariantById)));

    // Register synchronously so a concurrent call made before this settles reuses it.
    this.appendVariantPromiseById({ [id]: request });

    const release = () =>
      this.set(state => {
        // Leave the entry alone if it was replaced by someone else meanwhile.
        if (state.variantPromiseById[id] !== request) return state;
        const remaining = { ...state.variantPromiseById };
        delete remaining[id];
        return { variantPromiseById: remaining };
      });
    // Handle both outcomes here so this bookkeeping chain never produces an
    // unhandled rejection; callers still observe failures through `request`.
    request.then(release, release);

    return request;
  };

  /** Returns the cached variants whose `appId` equals `appId` (cache only, no fetch). */
  getVariantsByAppId = (appId: string) => {
    const { variantById } = this.get();
    return R.filter(R.propEq('appId', appId), R.values(variantById));
  };

  /** Fetches all variants, hydrates each one, and merges them into `variantById`. */
  findAllVariants = () => {
    return (
      client
        .query('v2.variant.findAll', undefined, { context: { skipBatch: true } })
        // Get Prompt and ModelParameters
        // TODO make curried promise map
        .then(variants => Promise.all(R.map(VariantStore.hydrateVariant, variants)))
        // Construct StoredVariant
        .then(R.map(VariantStore.constructStoredVariant))
        // Store StoredVariant
        .then(
          R.tap(
            R.pipe(
              // Prepare to be stored
              R.indexBy(R.prop('id')),
              // Store in variantById
              this.appendVariantById,
            ),
          ),
        )
    );
  };

  /** Runs the persisted variant `id` on the server against `inputs`. */
  runVariant = (id: string, inputs: DataRow[]) => {
    return client.mutation('v2.variant.runVariant', { id, inputs });
  };

  /**
   * Runs an unsaved variant configuration against `inputs`.
   *
   * Returns one result per input. A row is FAILED when the server reports an
   * error reason for it. In the GPT-4 path, a row is also FAILED when its
   * request throws; the thrown message is then recorded as `errorReason`.
   */
  run = async (
    variant: VariantRun,
    prompt: PromptRun,
    modelParameters: ModelParametersRun,
    inputs: DataRow[],
  ): Promise<VariantRunOutput[]> => {
    if (modelParameters.modelId === OpenAiMultiTurnBaseModelInternalId.GPT4) {
      // Each input gets its own un-batched request, so one failed request only
      // fails its own row.
      return await Promise.all(
        inputs.map(async input => {
          try {
            const runResult = await client.mutation(
              'v2.variant.run',
              { variant, prompt, modelParameters, inputs: [input] },
              { context: { skipBatch: true } },
            );
            return VariantStore.toRunOutput(runResult, 0);
          } catch (err) {
            return {
              output: '',
              tokenProbs: null,
              finishReason: null,
              errorReason: err instanceof Error ? err.message : String(err),
              status: PairedRowStatus.FAILED,
            };
          }
        }),
      );
    }
    const runResults = await client.mutation('v2.variant.run', {
      variant,
      prompt,
      modelParameters,
      inputs,
    });
    return runResults.outputs.map((_o, i) => VariantStore.toRunOutput(runResults, i));
  };

  /**
   * Converts entry `index` of a `v2.variant.run` response into a row result.
   * A row is COMPLETE only when its error reason is exactly `null`.
   */
  private static toRunOutput = (
    runResult: VariantRunResponse,
    index: number,
  ): VariantRunOutput => {
    const errorReason = runResult.errorReasons[index];
    return {
      output: runResult.outputs[index],
      tokenProbs: runResult.tokenProbs[index],
      finishReason: runResult.finishReasons[index],
      errorReason,
      status: errorReason === null ? PairedRowStatus.COMPLETE : PairedRowStatus.FAILED,
    };
  };

  /**
   * Fetches a variant's prompt and model parameters in parallel. If the prompt
   * has a `variablesSourceDataId`, it then loads that data through the DataStore.
   * Resolves to the tuple `[variant, prompt, modelParameters, data | undefined]`.
   */
  static hydrateVariant = (variant: Variant) => {
    return Promise.all([
      variant,
      // Get Prompt
      client.query('v2.variant.findPrompt', { variantId: variant.id }),
      // Get ModelParameters
      client.query('v2.variant.findModelParameters', { variantId: variant.id }),
    ]).then(([fetchedVariant, prompt, modelParameters]) =>
      // Get Data if Prompt uses one
      Promise.all([
        fetchedVariant,
        prompt,
        modelParameters,
        prompt.variablesSourceDataId
          ? useDataStore.getState().getData(prompt.variablesSourceDataId)
          : undefined,
      ]),
    );
  };

  /** Merges a `hydrateVariant` tuple into a single stored-variant object. */
  static constructStoredVariant = ([variant, prompt, modelParameters, data]: [
    Variant,
    Prompt,
    ModelParameters,
    StoredData | undefined,
  ]) => {
    return {
      ...variant,
      prompt: {
        ...prompt,
        variablesSourceData: data,
      },
      modelParameters,
    };
  };
}

// Field names held as literal-typed constants so each lens is typed to that exact field.
const indexKey = 'variantById';
const promiseIndexKey = 'variantPromiseById';

/** Ramda lens onto `VariantStore.variantById`. */
const indexLens = R.lensProp<VariantStore, typeof indexKey>(indexKey);
/** Ramda lens onto `VariantStore.variantPromiseById`. */
const promiseIndexLens = R.lensProp<VariantStore, typeof promiseIndexKey>(promiseIndexKey);

/**
 * React hook for the variant store. The store is wrapped in zustand's
 * `subscribeWithSelector` middleware, so listeners can subscribe to selected slices.
 */
export const useVariantStore = create<VariantStore>()(
  subscribeWithSelector((setState, getState) => new VariantStore(setState, getState)),
);

// Also expose the hook as a global `useVariantStoreV2` (presumably for console debugging — confirm).
Object.assign(window, { useVariantStoreV2: useVariantStore });
