import _ from 'lodash';
import * as R from 'ramda';
import create, { StoreApi } from 'zustand';
import { subscribeWithSelector } from 'zustand/middleware';

import { JobStatus } from '@prisma/client';
import {
  InferenceBatch,
  InferenceBatchCreateDataUrl,
  InferenceBatchCreateVariant,
} from '@scale/llm-shared/interfaces/inferenceBatch';

import { client } from 'frontend/api/trpc';
import { StoredInferenceBatch } from 'frontend/storesV2/types';
import { singleIndexById } from 'frontend/utils/object';

/**
 * Zustand store state for inference batches fetched through the tRPC client.
 *
 * `inferenceBatchById` caches resolved batches. `inferenceBatchPromiseById`
 * tracks loads started by `findInferenceBatchById`, so concurrent lookups of
 * the same id share a single request.
 */
class InferenceBatchStore {
  /** Resolved inference batches, keyed by batch id. */
  inferenceBatchById: Record<string, StoredInferenceBatch> = {};
  /** In-flight single-batch loads started by `findInferenceBatchById`, keyed by batch id. */
  inferenceBatchPromiseById: Record<string, Promise<StoredInferenceBatch>> = {};

  constructor(
    public set: StoreApi<InferenceBatchStore>['setState'],
    public get: StoreApi<InferenceBatchStore>['getState'],
  ) {}

  /** Replace the entire resolved-batch index. */
  setInferenceBatchById = (byId: Record<string, StoredInferenceBatch>) =>
    this.set(R.set(indexLens, byId));
  /** Replace the entire in-flight-load index. */
  setInferenceBatchPromiseById = (byId: Record<string, Promise<StoredInferenceBatch>>) =>
    this.set(R.set(promiseIndexLens, byId));
  /** Merge `byId` into the resolved-batch index; entries in `byId` win on key collisions. */
  appendInferenceBatchById = (byId: Record<string, StoredInferenceBatch>) =>
    this.set(R.over(indexLens, R.mergeLeft(byId)));
  /** Merge `byId` into the in-flight-load index; entries in `byId` win on key collisions. */
  appendInferenceBatchPromiseById = (byId: Record<string, Promise<StoredInferenceBatch>>) =>
    this.set(R.over(promiseIndexLens, R.mergeLeft(byId)));

  /**
   * Create an inference batch on the server and cache it.
   *
   * @param inferenceBatchCreate - payload forwarded to `v2.inference.batch.create`
   * @returns the created batch, already merged into `inferenceBatchById`
   */
  createInferenceBatch = (inferenceBatchCreate: InferenceBatchCreateVariant) => {
    return (
      client
        // Create InferenceBatch
        .mutation('v2.inference.batch.create', inferenceBatchCreate)
        // Construct StoredInferenceBatch
        .then(InferenceBatchStore.constructStoredInferenceBatch)
        // Store StoredInferenceBatch
        .then(R.tap(this.storeInferenceBatch))
    );
  };

  /** Merge a single batch into `inferenceBatchById` (keyed via `singleIndexById`). */
  // Field initializer: must stay below `appendInferenceBatchById`, which it captures at construction.
  storeInferenceBatch: (inferenceBatch: StoredInferenceBatch) => void = R.pipe(
    // Prepare to be stored
    singleIndexById,
    // Store in inferenceBatchById
    this.appendInferenceBatchById,
  );

  /**
   * Index a list of batches by `id` and merge them into `inferenceBatchById`.
   * Shared by `getInferenceBatches` and `getProcessingInferenceBatches`.
   */
  // Field initializer: must stay below `appendInferenceBatchById`, which it captures at construction.
  storeInferenceBatches: (inferenceBatches: StoredInferenceBatch[]) => void = R.pipe(
    // Prepare to be stored
    R.indexBy(R.prop('id')),
    // Store in inferenceBatchById
    this.appendInferenceBatchById,
  );

  /**
   * Fetch a batch from the server and merge it into `inferenceBatchById`.
   * Does not consult either cache, so it always issues a request.
   *
   * @param id - inference batch id
   */
  loadInferenceBatch = (id: string): Promise<StoredInferenceBatch> => {
    return (
      client
        // Load inference batch from server
        .query('v2.inference.batch.findById', { inferenceBatchId: id })
        // Construct StoredInferenceBatch
        .then(InferenceBatchStore.constructStoredInferenceBatch)
        // Store StoredInferenceBatch
        .then(R.tap(this.storeInferenceBatch))
    );
  };

  /**
   * Resolve a batch by id, preferring cached data.
   *
   * Lookup order:
   * 1. the resolved cache;
   * 2. a load already in flight for the same id;
   * 3. a new server load.
   *
   * A new load is registered in `inferenceBatchPromiseById` while pending, so
   * concurrent callers share it. It is removed once it settles, including on
   * failure, so a later call can retry.
   *
   * @param id - inference batch id
   */
  findInferenceBatchById = (id: string): Promise<StoredInferenceBatch> => {
    const { inferenceBatchById, inferenceBatchPromiseById } = this.get();
    // Get inference batch from cache
    const inferenceBatch = inferenceBatchById[id];
    if (inferenceBatch) return Promise.resolve(inferenceBatch);
    // Reuse a load that is already in flight for this id
    const pendingPromise = inferenceBatchPromiseById[id];
    if (pendingPromise) return pendingPromise;

    // Remove this load's entry once it settles, unless the entry was replaced meanwhile.
    const clearPending = () => {
      const current = this.get().inferenceBatchPromiseById;
      if (current[id] !== promise) return;
      const remaining = { ...current };
      delete remaining[id];
      this.setInferenceBatchPromiseById(remaining);
    };

    // The resolved batch is already in `inferenceBatchById` (via loadInferenceBatch)
    // before clearPending runs, so there is no window where neither cache has it.
    const promise: Promise<StoredInferenceBatch> = this.loadInferenceBatch(id).then(
      (loaded) => {
        clearPending();
        return loaded;
      },
      (error: unknown) => {
        clearPending();
        throw error;
      },
    );
    this.appendInferenceBatchPromiseById({ [id]: promise });
    return promise;
  };

  /** Fetch all inference batches for the user and merge them into `inferenceBatchById`. */
  getInferenceBatches = (): Promise<StoredInferenceBatch[]> => {
    return (
      client
        // Load inference batches from server
        .query('v2.inference.batch.findMany')
        // Extract from response
        .then(R.prop('data'))
        // Construct StoredInferenceBatches
        .then(R.map(InferenceBatchStore.constructStoredInferenceBatch))
        // Store in inferenceBatchById
        .then(R.tap<StoredInferenceBatch[]>(this.storeInferenceBatches))
    );
  };

  /**
   * Fetch the user's batches with status 'Running' or 'Pending' and merge them
   * into `inferenceBatchById`.
   *
   * @returns Running batches first, then Pending ones.
   */
  getProcessingInferenceBatches = (): Promise<StoredInferenceBatch[]> => {
    // Load inference batches from server (both queries run in parallel)
    const inferenceBatchesPromise = Promise.all([
      client.query('v2.inference.batch.findMany', { status: 'Running' }).then(R.prop('data')),
      client.query('v2.inference.batch.findMany', { status: 'Pending' }).then(R.prop('data')),
    ]).then(_.flatten);
    return (
      inferenceBatchesPromise
        // Construct StoredInferenceBatches
        .then(R.map(InferenceBatchStore.constructStoredInferenceBatch))
        // Store in inferenceBatchById
        .then(R.tap<StoredInferenceBatch[]>(this.storeInferenceBatches))
    );
  };

  /**
   * Fetch the results for a batch.
   * The response is returned as-is and is not cached.
   */
  getInferenceBatchResults = (inferenceBatchId: string): Promise<string> => {
    return client.query('v2.inference.batch.getResults', { inferenceBatchId });
  };

  /** Convert a server `InferenceBatch` into its stored form (currently the identity). */
  static constructStoredInferenceBatch = (inferenceBatch: InferenceBatch): StoredInferenceBatch => {
    return inferenceBatch;
  };
}
// Property names on InferenceBatchStore that the lenses below focus on.
const indexKey = 'inferenceBatchById';
const promiseIndexKey = 'inferenceBatchPromiseById';

// Ramda lenses used by the store's set*/append* methods (via R.set / R.over) to update each index.
const indexLens = R.lensProp<InferenceBatchStore, typeof indexKey>(indexKey);
const promiseIndexLens = R.lensProp<InferenceBatchStore, typeof promiseIndexKey>(promiseIndexKey);

/**
 * Zustand store hook for the inference batch cache.
 *
 * The store state is an `InferenceBatchStore` instance wired to the store's own
 * setState/getState. `subscribeWithSelector` lets subscribers listen to a
 * selected slice of that state.
 */
export const useInferenceBatchStore = create<InferenceBatchStore>()(
  subscribeWithSelector((setState, getState) => new InferenceBatchStore(setState, getState)),
);

// NOTE(review): this export is not referenced in this file and looks like a leftover
// placeholder; confirm no other module imports `what` before removing it.
export const what = 'what';

// Expose the store as a global on `window` (presumably for console debugging — confirm).
// Guarded so that evaluating this module where `window` is undefined (e.g. SSR or
// node-based tests) does not throw a ReferenceError at import time.
if (typeof window !== 'undefined') {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ad-hoc debug global, not part of Window's type
  (window as any).useInferenceBatchStoreV2 = useInferenceBatchStore;
}
