import { yupResolver } from "@hookform/resolvers/yup";
import { Alert, AlertTitle } from "@mui/material";
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo } from "react";
import { Controller, useForm } from "react-hook-form";
import { useNavigate } from "react-router-dom";
import Button from "src/components/Button";
import FileInput from "src/components/FileInput";
import LabelTitle from "src/components/LabelTitle";
import Paper from "src/components/Paper";
import PhraseTitle from "src/components/PhraseTitle";
import Select from "src/components/Select";
import Spinner from "src/components/Spinner";
import TextArea from "src/components/TextArea";
import TextInput from "src/components/TextInput";
import { postModel } from "src/library/apis";
import { ReactComponent as BatchChannelDimensionExample } from "../../../assets/svgs/batch_channel_dimension.svg";
import { ReactComponent as BatchDimensionChannelExample } from "../../../assets/svgs/batch_dimension_channel.svg";

import {
  getFileFromEvent,
  getFileNameExtension,
  getMetricUnitsFromTask,
} from "src/library/utils";
import * as yup from "yup";
import { ReactComponent as OpenVINOModel } from "../../../assets/svgs/OpenVINOModel.svg";
import { ReactComponent as TensorFlowKerasModel } from "../../../assets/svgs/TensorFlowKerasModel.svg";

/**
 * Allowed model-file extensions (lower-case, without the leading dot),
 * keyed by the form's `framework` value.
 *
 * Used by the upload schema to check that the selected file matches the
 * chosen framework. Framework values absent from this map (for example
 * "keras_saved_model") have no allowed extensions listed here.
 */
const frameworkToExtensions = {
  onnx: ["onnx"], // ONNX (.onnx)
  pytorch: ["pt"], // PyTorch GraphModule (.pt)
  tensorflow_keras: ["h5", "zip"], // Keras H5 or zipped SavedModel
  tensorflow_lite: ["tflite"], // TensorFlow Lite (.tflite)
  tensorrt: ["trt"], // TensorRT (.trt)
  openvino: ["zip"], // OpenVINO (.zip)
};

/** Largest accepted model file: 600 MB. */
const MAX_MODEL_FILE_SIZE = 600 * 1024 * 1024;

/** Matches an integer from 1 to 100 with no leading zero. */
const ONE_TO_HUNDRED_PATTERN = /^[1-9][0-9]?$|^100$/;

/**
 * Matches one or more positive integers (no leading zero) separated by
 * commas; whitespace around numbers and commas is tolerated.
 */
const DIMENSION_LIST_PATTERN = /^\s*[1-9]([0-9]*)(\s*,\s*[1-9]([0-9])*)*\s*$/;

/** Returns the batch, channel and dimension input-shape fields, in order. */
const getInputShapeValues = (values) => [
  values.inputShapeBatch,
  values.inputShapeChannel,
  values.inputShapeDimension,
];

/**
 * Validation schema for the model upload form.
 *
 * Field-level rules validate each input on its own. The object-level tests
 * at the end check rules that span several fields: input-shape completeness,
 * file extension vs. framework, and metric unit/value pairing. In this form
 * an empty string means "not entered".
 */
const modelUploadInputSchema = yup
  .object({
    modelName: yup
      .string()
      .matches(
        /^[0-9a-zA-Z_]+$/,
        "Only alphabets, numbers, and underscores can be entered in the model name."
      )
      .required("Please enter a model name"),
    description: yup.string(),
    task: yup
      .string()
      .oneOf([
        "object_detection",
        "image_classification",
        "image_segmentation",
        "semantic_segmentation",
        "instance_segmentation",
        "panoptic_segmentation",
        "other",
      ])
      .required("Please select a task"),
    framework: yup
      .string()
      .oneOf([
        "tensorflow_keras",
        "onnx",
        "pytorch",
        "keras_saved_model",
        "openvino",
        "tensorflow_lite",
        "tensorrt",
      ])
      .required(),
    file: yup
      .mixed()
      .test(
        "file is needed",
        "Please select a file to upload",
        (value) => Boolean(value)
      )
      .test(
        "file size",
        "The maximum size of the model file is 600MB",
        // A missing file is reported by the test above; do not throw on it here.
        (value) => !value || !(value.size > MAX_MODEL_FILE_SIZE)
      ),
    metricUnit: yup.string(),
    metricUnitText: yup.string(),
    metricValue: yup
      .string()
      .test(
        "metric value should be numeric",
        "Metric value should be numeric",
        // Number("") is 0, so an empty value passes; whether a value is
        // required is decided by the object-level "metric value is needed" test.
        (value) => !Number.isNaN(Number(value))
      ),
    inputShapeBatch: yup
      .string()
      .test(
        "The batch of input shape must be an integer 1 ~ 100",
        "The batch of input shape must be an integer 1 ~ 100",
        (value) => value === "" || ONE_TO_HUNDRED_PATTERN.test(value)
      ),
    inputShapeChannel: yup
      .string()
      .test(
        "The channel of input shape must be an integer 1 ~ 100",
        "The channel of input shape must be an integer 1 ~ 100",
        (value) => value === "" || ONE_TO_HUNDRED_PATTERN.test(value)
      ),
    inputShapeDimension: yup
      .string()
      .test(
        "The dimension of the input shape must be an integer of 1 or more, and if it is more than one dimension, it must be separated by a comma without space.",
        "The dimension of the input shape must be an integer of 1 or more, and if it is more than one dimension, it must be separated by a comma without space.",
        (value) => value === "" || DIMENSION_LIST_PATTERN.test(value)
      ),
  })
  .test(
    "If the framework is PyTorch GraphModule, Input shape must be entered.",
    "If the framework is PyTorch GraphModule, Input shape must be entered.",
    (values) =>
      values.framework !== "pytorch" ||
      getInputShapeValues(values).every((value) => value !== "")
  )
  .test(
    "InputShape must be entered in all or not in all.",
    "InputShape must be entered in all or not in all.",
    (values) => {
      const inputShapeValues = getInputShapeValues(values);
      const isEmpty = inputShapeValues.every((value) => value === "");
      const isFull = inputShapeValues.every((value) => value !== "");
      return isEmpty || isFull;
    }
  )
  .test(
    "This file is not available to upload",
    "This file is not available to upload",
    (values) => {
      if (!values.file) {
        return false;
      }
      // Framework values without an entry in frameworkToExtensions (e.g.
      // "keras_saved_model", which the field schema accepts) allow no
      // extension instead of throwing a TypeError.
      const allowedExtensions = frameworkToExtensions[values.framework] || [];
      return allowedExtensions.includes(
        getFileNameExtension(values.file.name)
      );
    }
  )
  .test(
    "metric unit is needed",
    "Please enter a metric unit",
    (values) =>
      !(values.metricUnit === "Other" && values.metricUnitText === "")
  )
  .test(
    "metric value is needed",
    (params) =>
      `Please enter a value of ${
        params.value.metricUnit === "Other"
          ? params.value.metricUnitText
          : params.value.metricUnit
      }`,
    (values) => {
      const hasMetricUnit = values.metricUnit !== "";
      // metricValue is a string field defaulting to "", so an empty string
      // (not only null/undefined) means no value was entered.
      const hasMetricValue =
        values.metricValue != null && values.metricValue !== "";
      return !hasMetricUnit || hasMetricValue;
    }
  );

export default function ModelUploadPage() {
  const navigate = useNavigate();

  const queryClient = useQueryClient();

  const { register, setValue, watch, handleSubmit, reset, control } = useForm({
    resolver: yupResolver(modelUploadInputSchema),
    defaultValues: {
      modelName: "",
      description: "",
      task: "image_classification",
      framework: "onnx",
      file: "",
      metricUnit: "",
      metricUnitText: "",
      metricValue: "",
      inputShapeBatch: "1",
      inputShapeChannel: "3",
      inputShapeDimension: "",
    },
  });

  const [file, task, metricUnit, framework] = [
    watch("file"),
    watch("task"),
    watch("metricUnit"),
    watch("framework"),
  ];

  const PostModel = useMutation((model) => postModel(model), {
    onSuccess: ({ data }) => {
      queryClient.resetQueries(["parentsModels"]);
      alert("Model is uploaded successfully");
      navigate(-1);
    },
    onError: ({ data }) => {
      console.log(data);
      PostModel.reset();
    },
  });

  const metricUnits = useMemo(() => getMetricUnitsFromTask(task), [task]);

  useEffect(() => {
    setValue("metricUnit", "");
    if (task === "other") {
      setValue("metricUnit", "Other");
    }
  }, [setValue, task]);

  useEffect(() => {
    setValue("inputShapeBatch", "1");
    setValue("inputShapeChannel", "3");
    setValue("inputShapeDimension", "");
  }, [framework, setValue]);

  const handleChangeFile = (e) => {
    const file = getFileFromEvent(e);
    if (file) {
      setValue("file", file);
    }
  };

  const handleValidationError = (error) => {
    Object.entries(error).every(([name, { message, type, ref }]) => {
      alert(message);
      return false;
    });
  };

  const handleClickUpload = (data) => {
    PostModel.mutate({
      modelName: data.modelName,
      description: data.description,
      file: data.file,
      framework: data.framework,
      task: data.task,
      metricUnit:
        data.metricUnit === "Other" ? data.metricUnitText : data.metricUnit,
      metricValue: data.metricValue,
      inputShapeBatch: Number(data.inputShapeBatch),
      inputShapeChannel: Number(data.inputShapeChannel),
      inputShapeDimension: data.inputShapeDimension
        .split(",")
        .map((e) => Number(e.trim())),
    });
  };

  const handleClickCancel = () => {
    reset();
    PostModel.reset();
    navigate(-1);
  };

  const isDisabled = useMemo(
    () => PostModel.isLoading || PostModel.isSuccess,
    [PostModel.isLoading, PostModel.isSuccess]
  );

  return (
    <Paper>
      <form onSubmit={handleSubmit(handleClickUpload, handleValidationError)}>
        <div className="p-4 flex flex-col gap-8">
          <div className="flex flex-col gap-4">
            <PhraseTitle>Model info</PhraseTitle>
            <div className="flex flex-col gap-1">
              <LabelTitle>Model name *</LabelTitle>
              <TextInput
                {...register("modelName")}
                disabled={isDisabled}
                data-cy="upload-model-name-input"
              />
            </div>
            <div className="flex flex-col gap-1">
              <LabelTitle>Memo</LabelTitle>
              <TextArea
                {...register("description")}
                disabled={isDisabled}
                data-cy="upload-description-textarea"
              />
            </div>
            <div className="flex flex-col gap-1">
              <LabelTitle>Task *</LabelTitle>
              <div>
                <Select
                  {...register("task")}
                  disabled={isDisabled}
                  data-cy="upload-task-select"
                >
                  <option value="image_classification">
                    Image Classification
                  </option>
                  <option value="object_detection">Object Detection</option>
                  <option value="semantic_segmentation">
                    Semantic Segmentation
                  </option>
                  <option value="instance_segmentation">
                    Instance Segmentation
                  </option>
                  <option value="instance_segmentation">
                    Panoptic Segmentation
                  </option>
                  <option value="other">Other</option>
                </Select>
              </div>
            </div>

            <div className="flex flex-col gap-1">
              <LabelTitle>Evaluation metric</LabelTitle>
              <div className="flex items-center gap-1">
                <div>
                  <Select
                    {...register("metricUnit")}
                    disabled={isDisabled}
                    data-cy="upload-metric-unit-select"
                  >
                    <option value="">Select metric</option>
                    {metricUnits.map((metricUnit) => (
                      <option key={metricUnit} value={metricUnit}>
                        {metricUnit}
                      </option>
                    ))}
                    <option value="Other">Other</option>
                  </Select>
                </div>
                {metricUnit === "Other" && (
                  <div>
                    <TextInput
                      {...register("metricUnitText")}
                      disabled={isDisabled}
                    />
                  </div>
                )}
              </div>
            </div>
            <div className="flex flex-col gap-1">
              <LabelTitle>Evaluation value</LabelTitle>
              <div className="flex items-center gap-1">
                <div>
                  <TextInput
                    {...register("metricValue")}
                    disabled={isDisabled}
                    data-cy="upload-metric-value-input"
                  />
                </div>
              </div>
            </div>
          </div>
          <div className="flex flex-col gap-4">
            <PhraseTitle>Model file</PhraseTitle>

            <div className="flex flex-col gap-1">
              <LabelTitle>Framework *</LabelTitle>
              <div className="flex flex-col items-start gap-1">
                <label className="flex items-center">
                  <input
                    type="radio"
                    {...register("framework")}
                    value="onnx"
                    className="mr-2 peer"
                    disabled={isDisabled}
                    data-cy="upload-framework-input"
                  />
                  <span className="peer-disabled:opacity-75">ONNX (.onnx)</span>
                </label>
                <label className="flex items-center">
                  <input
                    type="radio"
                    {...register("framework")}
                    value="pytorch"
                    className="mr-2 peer"
                    disabled={isDisabled}
                    data-cy="upload-framework-input"
                  />
                  <span className="peer-disabled:opacity-75">
                    PyTorch GraphModule (.pt)
                  </span>
                </label>
                <label className="flex items-center">
                  <input
                    type="radio"
                    {...register("framework")}
                    value="tensorflow_keras"
                    className="mr-2 peer"
                    disabled={isDisabled}
                    data-cy="upload-framework-input"
                  />
                  <span className="peer-disabled:opacity-75">
                    TensorFlow-Keras (.h5, .zip)
                  </span>
                </label>
                <label className="flex items-center">
                  <input
                    type="radio"
                    name="task"
                    {...register("framework")}
                    value="tensorflow_lite"
                    className="mr-2 peer"
                    disabled={isDisabled | true}
                    data-cy="upload-framework-input"
                  />
                  <span className="peer-disabled:opacity-75">
                    TensorFlow Lite (.tflite)
                  </span>
                </label>
                <label className="flex items-center">
                  <input
                    type="radio"
                    name="task"
                    {...register("framework")}
                    value="tensorrt"
                    className="mr-2 peer"
                    disabled={isDisabled | true}
                    data-cy="upload-framework-input"
                  />
                  <span className="peer-disabled:opacity-75">
                    TensorRT (.trt)
                  </span>
                </label>
                <label className="flex items-center">
                  <input
                    type="radio"
                    name="task"
                    {...register("framework")}
                    value="openvino"
                    className="mr-2 peer"
                    disabled={isDisabled | true}
                    data-cy="upload-framework-input"
                  />
                  <span className="peer-disabled:opacity-75">
                    OpenVINO (.zip)
                  </span>
                </label>
              </div>

              {framework === "onnx" && (
                <Alert color="info">
                  <AlertTitle>ONNX model</AlertTitle>
                  <ul className="list-disc list-inside">
                    <b>Model Compressor</b>
                    <li>
                      {"Supported version: PyTorch, ONNX version >= 1.10.x."}
                    </li>
                    <li>
                      If a model is defined in PyTorch, it should be converted
                      into the ONNX format before being uploaded.
                    </li>
                    <li>
                      How-to-guide for ONNX conversion is at the{" "}
                      <a
                        href="https://github.com/Nota-NetsPresso/NetsPresso-CompressionToolkit-ModelZoo/tree/main/models/torch#conversion-of-pytorch-into-onnx"
                        target="_blank"
                        rel="noreferrer"
                      >
                        ModelZoo-torch
                      </a>
                      .
                    </li>
                  </ul>
                  <br />
                  <ul className="list-disc list-inside">
                    <b>Model Launcher</b>
                    <li>
                      ONNX to TensorRT, TensorFlow Lite, OpenVINO converting
                      will be available soon.
                    </li>
                  </ul>
                </Alert>
              )}

              {framework === "pytorch" && (
                <Alert color="info">
                  <AlertTitle>PyTorch GraphModule model</AlertTitle>
                  <ul className="list-disc list-inside">
                    <b>Model Compressor</b>
                    <li>{"Supported version: PyTorch version >= 1.10.x."}</li>
                    <li>
                      If a model is defined in PyTorch, it should be converted
                      into the GraphModule before being uploaded.
                    </li>
                    <li>
                      The model must contain not only the status dictionary but
                      also the structure of the model (do not use state_dict).
                    </li>
                    <li>
                      How-to-guide for the conversion is at the{" "}
                      <a
                        href="https://github.com/Nota-NetsPresso/NetsPresso-CompressionToolkit-ModelZoo/tree/main/models/torch#conversion-of-pytorch-into-onnx"
                        target="_blank"
                        rel="noreferrer"
                      >
                        ModelZoo-torch
                      </a>
                      .
                    </li>
                  </ul>
                  <br />
                  <ul className="list-disc list-inside">
                    <b>Model Launcher</b>
                    <li>
                      PyTorch GraphModule to TensorRT, TensorFlow Lite, OpenVINO
                      converting will be available soon.
                    </li>
                  </ul>
                </Alert>
              )}

              {framework === "tensorflow_keras" && (
                <Alert color="info">
                  <AlertTitle>TensorFlow-Keras model</AlertTitle>
                  <ul className="list-disc list-inside">
                    <b>Model Compressor</b>
                    <li>Supported version: TensorFlow 2.2.x ~ 2.5.x.</li>
                    <li>
                      Custom layer must not be included in Keras H5 (.h5)
                      format.
                    </li>
                    <li>
                      The model must contain not only weights but also the
                      structure of the model (do not use save_weights).
                    </li>
                    <li>
                      If there is a custom layer in the model, please upload
                      TensorFlow SavedModel format (.zip).
                    </li>
                  </ul>
                  <br />
                  <ul className="list-disc list-inside">
                    <b>Model Launcher</b>
                    <li>
                      TensorFlow-Keras to TensorFlow Lite converting will be
                      available soon.
                    </li>
                  </ul>
                  <br />
                  <TensorFlowKerasModel />
                </Alert>
              )}

              {framework === "tensorflow_lite" && (
                <Alert color="info">
                  <AlertTitle>TensorFlow Lite model</AlertTitle>
                  <ul className="list-disc list-inside">
                    <b>Model Launcher</b>
                    <li>
                      Model Launcher TensorFlow Lite models can be packaged.
                    </li>
                  </ul>
                </Alert>
              )}

              {framework === "tensorrt" && (
                <Alert color="info">
                  <AlertTitle>TensorRT model</AlertTitle>
                  <ul className="list-disc list-inside">
                    <b>Model Launcher</b>
                    <li>TensorRT models can be packaged.</li>
                    <li>
                      Packaged TensorRT models will only work properly on the
                      machine on which converted TensorRT.
                    </li>
                  </ul>
                </Alert>
              )}

              {framework === "openvino" && (
                <Alert color="info">
                  <AlertTitle>OpenVINO model</AlertTitle>
                  <ul className="list-disc list-inside">
                    <b>Model Launcher</b>
                    <li>OpenVINO models can be packaged.</li>
                  </ul>
                  <br />
                  <OpenVINOModel />
                </Alert>
              )}
            </div>
            <div className="flex flex-col gap-1">
              <LabelTitle>Model File *</LabelTitle>
              <div>
                You can upload one of the following three file extention formats
              </div>
              {/* <FileExtensionsTab /> */}
              <Controller
                control={control}
                name="file"
                render={() => (
                  <FileInput
                    drag={true}
                    buttonText={"Select File"}
                    inputText="Please Select a Model File"
                    dragText={"(Less than 600MB is available)"}
                    onChange={handleChangeFile}
                    file={file}
                    disabled={isDisabled}
                    data-cy="upload-file-input"
                  />
                )}
              />
            </div>
            <div className="flex flex-col gap-1">
              <LabelTitle>
                Input shape {framework === "pytorch" && "*"}
              </LabelTitle>
              <div className="flex gap-2">
                <div className="w-24">
                  <LabelTitle>Batch</LabelTitle>
                  <TextInput
                    {...register("inputShapeBatch")}
                    disabled={isDisabled}
                    data-cy="upload-input-shape-batch-input"
                  />
                </div>
                {(framework === "onnx" || framework === "pytorch") && (
                  <>
                    <div className="w-24">
                      <LabelTitle>Channel</LabelTitle>
                      <TextInput
                        {...register("inputShapeChannel")}
                        disabled={isDisabled}
                        data-cy="upload-input-shape-channel-input"
                      />
                    </div>
                    <div className="w-24">
                      <LabelTitle>Dimension</LabelTitle>
                      <TextInput
                        {...register("inputShapeDimension")}
                        disabled={isDisabled}
                        placeholder="ex)128,128"
                        data-cy="upload-input-shape-dimension-input"
                      />
                    </div>
                  </>
                )}
                {framework === "tensorflow_keras" && (
                  <>
                    <div className="w-24">
                      <LabelTitle>Dimension</LabelTitle>
                      <TextInput
                        {...register("inputShapeDimension")}
                        disabled={isDisabled}
                        placeholder="ex)128,128"
                        data-cy="upload-input-shape-dimension-input"
                      />
                    </div>
                    <div className="w-24">
                      <LabelTitle>Channel</LabelTitle>
                      <TextInput
                        {...register("inputShapeChannel")}
                        disabled={isDisabled}
                        data-cy="upload-input-shape-channel-input"
                      />
                    </div>
                  </>
                )}
              </div>
              {["onnx", "pytorch"].includes(framework) && (
                <Alert color="info">
                  <ul className="list-disc list-inside shrink">
                    For input shape, use the same values that you used to train
                    the model.
                    <li>
                      Only channels first format is supported (batch, channel,
                      dimension).
                    </li>
                    <li>Channel: 3 for RGB images, 1 for gray images.</li>
                    <li>
                      For example, width=1024, height=768 RGB images and the
                      batch size is 1.
                    </li>
                  </ul>
                  <br />
                  <BatchChannelDimensionExample />
                  <br />
                  Currently, only single input models are supported.
                </Alert>
              )}
              {framework === "tensorflow_keras" && (
                <Alert color="info">
                  <ul className="list-disc list-inside shrink">
                    Use the same values for the input shape you used to train
                    the model.
                    <li>
                      Only channels first format is supported (batch, dimension,
                      channel).
                    </li>
                    <li>Channel: 3 for RGB images, 1 for gray images.</li>
                    <li>
                      For example, width=1024, height=768 RGB images and the
                      batch size is 1.
                    </li>
                  </ul>
                  <br />
                  <BatchDimensionChannelExample />
                  <br />
                  Currently, only single input models are supported.
                </Alert>
              )}
            </div>
          </div>
          <div className="flex justify-end items-center gap-2">
            {PostModel.isLoading && <Spinner />}
            <Button
              color="red"
              onClick={handleClickCancel}
              disabled={PostModel.isLoading}
            >
              Cancel
            </Button>
            <Button
              type="submit"
              disabled={isDisabled}
              data-cy="upload-upload-button"
            >
              Upload
            </Button>
          </div>
        </div>
      </form>
    </Paper>
  );
}
