import { z } from 'zod';

import { ModelType } from '@scale/llm-shared/interfaces/model';

// Type

// Raw shape shared by every provider schema; kept module-private so the only
// public handle is the `ModelParametersBase` schema built from it.
const modelParametersBaseShape = {
  id: z.string(),
  modelId: z.string(),
  modelType: ModelType,
};

/**
 * Fields common to all model-parameter schemas.
 *
 * `modelType` is validated here against the shared `ModelType` schema; the
 * provider-specific schemas in this file override it with a single string
 * literal so it can act as the discriminator of their unions.
 */
export const ModelParametersBase = z.object(modelParametersBaseShape);
/**
 * Parameters for the `'OpenAi'` model type.
 *
 * The numeric bounds follow OpenAI's documented ranges for the corresponding
 * completion request options, so out-of-range values are rejected at
 * validation time instead of by the upstream API.
 */
export const ModelParametersOpenAi = ModelParametersBase.extend({
  modelType: z.literal('OpenAi'),

  // Token counts are whole numbers; `.int()` rejects fractional values such as
  // 1.5, which `.positive()` alone would accept.
  maxTokens: z.optional(z.number().int().positive()),
  temperature: z.optional(z.number().gte(0).lte(2)),
  stop: z.optional(z.string()),
  suffix: z.optional(z.string()),
  // Nucleus-sampling probability mass, so it is bounded to [0, 1].
  topP: z.optional(z.number().gte(0).lte(1)),
  // A count of most-likely tokens to report: must be a non-negative integer.
  logprobs: z.optional(z.number().int().gte(0)),
  // Bias per token; OpenAI accepts bias values in [-100, 100].
  // NOTE(review): keys are presumably token ids as strings — confirm.
  logitBias: z.optional(z.record(z.string(), z.number().gte(-100).lte(100))),
});
/**
 * Parameters for the `'ScaleLaunch'` model type.
 *
 * `temperature` is limited to [0, 1] for this model type.
 */
export const ModelParametersLaunch = ModelParametersBase.extend({
  modelType: z.literal('ScaleLaunch'),

  // Token counts are whole numbers; `.int()` rejects fractional values such as
  // 1.5, which `.positive()` alone would accept.
  maxTokens: z.optional(z.number().int().positive()),
  temperature: z.optional(z.number().gte(0).lte(1)),
  stop: z.optional(z.string()),
});
/**
 * Parameters for the `'Ai21'` model type.
 *
 * `temperature` is limited to [0, 1] for this model type.
 */
export const ModelParametersAi21 = ModelParametersBase.extend({
  modelType: z.literal('Ai21'),

  // Token counts are whole numbers; `.int()` rejects fractional values such as
  // 1.5, which `.positive()` alone would accept.
  maxTokens: z.optional(z.number().int().positive()),
  temperature: z.optional(z.number().gte(0).lte(1)),
  stop: z.optional(z.string()),
});
/**
 * Parameters for the `'Cohere'` model type.
 *
 * `temperature` is limited to [0, 1] for this model type.
 */
export const ModelParametersCohere = ModelParametersBase.extend({
  modelType: z.literal('Cohere'),

  // Token counts are whole numbers; `.int()` rejects fractional values such as
  // 1.5, which `.positive()` alone would accept.
  maxTokens: z.optional(z.number().int().positive()),
  temperature: z.optional(z.number().gte(0).lte(1)),
  stop: z.optional(z.string()),
});
/**
 * Parameters for the `'Anthropic'` model type.
 *
 * `temperature` is limited to [0, 1] for this model type.
 */
export const ModelParametersAnthropic = ModelParametersBase.extend({
  modelType: z.literal('Anthropic'),

  // Token counts are whole numbers; `.int()` rejects fractional values such as
  // 1.5, which `.positive()` alone would accept.
  maxTokens: z.optional(z.number().int().positive()),
  temperature: z.optional(z.number().gte(0).lte(1)),
  stop: z.optional(z.string()),
});
/**
 * Model parameters for any supported provider, including the `id` field
 * inherited from `ModelParametersBase`.
 *
 * `z.discriminatedUnion` picks the variant by the value of `modelType`, so an
 * input whose `modelType` matches none of the listed literals fails to parse.
 * A new provider schema must be added here (and to the `*Create` / `*Run`
 * unions) before it is accepted.
 */
export const ModelParameters = z.discriminatedUnion('modelType', [
  ModelParametersOpenAi,
  ModelParametersLaunch,
  ModelParametersAi21,
  ModelParametersCohere,
  ModelParametersAnthropic,
]);

/**
 * Parsed (output) types of the schemas above. `z.output` is the same type that
 * `z.infer` resolves to; none of these schemas transform their input, so the
 * input and output shapes coincide.
 */
export type ModelParametersBase = z.output<typeof ModelParametersBase>;
export type ModelParametersOpenAi = z.output<typeof ModelParametersOpenAi>;
export type ModelParametersLaunch = z.output<typeof ModelParametersLaunch>;
export type ModelParametersAi21 = z.output<typeof ModelParametersAi21>;
export type ModelParametersCohere = z.output<typeof ModelParametersCohere>;
export type ModelParametersAnthropic = z.output<typeof ModelParametersAnthropic>;
/** Union of every provider variant, narrowed by its `modelType` literal. */
export type ModelParameters = z.output<typeof ModelParameters>;

// Create

// Shared `.omit` mask that drops `id`; `as const` keeps `id` typed as the
// literal `true`, which zod's `omit` requires.
const omitIdMask = { id: true } as const;

/** Creation-time schemas: each provider schema with its `id` field removed. */
export const ModelParametersOpenAiCreate = ModelParametersOpenAi.omit(omitIdMask);
export const ModelParametersLaunchCreate = ModelParametersLaunch.omit(omitIdMask);
export const ModelParametersAi21Create = ModelParametersAi21.omit(omitIdMask);
export const ModelParametersCohereCreate = ModelParametersCohere.omit(omitIdMask);
export const ModelParametersAnthropicCreate = ModelParametersAnthropic.omit(omitIdMask);

/**
 * Union of the `*Create` schemas, discriminated on `modelType`. Mirrors
 * `ModelParameters` except that no variant has an `id`.
 */
export const ModelParametersCreate = z.discriminatedUnion('modelType', [
  ModelParametersOpenAiCreate,
  ModelParametersLaunchCreate,
  ModelParametersAi21Create,
  ModelParametersCohereCreate,
  ModelParametersAnthropicCreate,
]);

/**
 * Parsed (output) types of the `*Create` schemas; identical to what `z.infer`
 * yields, since these schemas apply no transforms.
 */
export type ModelParametersOpenAiCreate = z.output<typeof ModelParametersOpenAiCreate>;
export type ModelParametersLaunchCreate = z.output<typeof ModelParametersLaunchCreate>;
export type ModelParametersAi21Create = z.output<typeof ModelParametersAi21Create>;
export type ModelParametersCohereCreate = z.output<typeof ModelParametersCohereCreate>;
export type ModelParametersAnthropicCreate = z.output<typeof ModelParametersAnthropicCreate>;
/** Any provider's creation payload, narrowed by its `modelType` literal. */
export type ModelParametersCreate = z.output<typeof ModelParametersCreate>;

// Run

// Run-time parameter schemas. Each `*Run` name is an alias of (the very same
// schema object as) its `*Create` counterpart, so run parameters carry no `id`.
// NOTE(review): kept as separate names presumably so the run shape can diverge
// from the create shape later — confirm before collapsing them.
export const ModelParametersOpenAiRun = ModelParametersOpenAiCreate;
export const ModelParametersLaunchRun = ModelParametersLaunchCreate;
export const ModelParametersAi21Run = ModelParametersAi21Create;
export const ModelParametersCohereRun = ModelParametersCohereCreate;
export const ModelParametersAnthropicRun = ModelParametersAnthropicCreate;

/**
 * Union of the `*Run` schemas, discriminated on `modelType`. It validates the
 * same inputs as `ModelParametersCreate` but is a distinct schema instance.
 */
export const ModelParametersRun = z.discriminatedUnion('modelType', [
  ModelParametersOpenAiRun,
  ModelParametersLaunchRun,
  ModelParametersAi21Run,
  ModelParametersCohereRun,
  ModelParametersAnthropicRun,
]);

/**
 * Parsed (output) types of the `*Run` schemas; identical to what `z.infer`
 * yields, since these schemas apply no transforms.
 */
export type ModelParametersOpenAiRun = z.output<typeof ModelParametersOpenAiRun>;
export type ModelParametersLaunchRun = z.output<typeof ModelParametersLaunchRun>;
export type ModelParametersAi21Run = z.output<typeof ModelParametersAi21Run>;
export type ModelParametersCohereRun = z.output<typeof ModelParametersCohereRun>;
export type ModelParametersAnthropicRun = z.output<typeof ModelParametersAnthropicRun>;
/** Any provider's run-time parameters, narrowed by its `modelType` literal. */
export type ModelParametersRun = z.output<typeof ModelParametersRun>;
