






































































































































import loadImage from "blueimp-load-image";
import { runModelUtils } from "../../utils";

import modelStatus from "./ModelStatus.vue";
import { InferenceSession, Tensor } from "onnxruntime-web";
import { Vue, Component, Prop, Watch } from "vue-property-decorator";

/**
 * Image-model demo component.
 *
 * Fetches an ONNX model file, creates an inference session for the selected
 * backend (WebGL or WebAssembly), draws a user-chosen image (URL, preset or
 * uploaded file) onto the `#input-canvas` element and runs the model on it.
 * Pre-processing and interpretation of the raw output are delegated to the
 * `preprocess` and `getPredictedClass` props.
 */
@Component({
  components: {
    modelStatus,
  },
})
export default class ImageModelUI extends Vue {
  /** URL the model file is fetched from when the component is created. */
  @Prop({ type: String, required: true }) modelFilepath!: string;
  /** Edge length used for image scaling bounds and for the warm-up tensor dims. */
  @Prop({ type: Number, required: true }) imageSize!: number;
  /** Preset images offered to the user (`text` label, `value` URL). */
  @Prop({ type: Array, required: true }) imageUrls!: Array<{
    text: string;
    value: string;
  }>;
  /** Converts the content of the input canvas into the model's input tensor. */
  @Prop({ type: Function, required: true }) preprocess!: (
    ctx: CanvasRenderingContext2D
  ) => Tensor;
  /** Maps the raw model output to whatever the template displays. */
  @Prop({ type: Function, required: true }) getPredictedClass!: (
    output: Float32Array
  ) => {};

  /** Currently selected backend: "webgl" or "wasm". */
  sessionBackend: string;
  backendSelectList: Array<{ text: string; value: string }>;
  /** True while the session for the selected backend is being created. */
  modelLoading: boolean;
  /** True until the session has been created and warmed up. */
  modelInitializing: boolean;
  /** True when the model could not be fetched or the session could not be set up. */
  modelLoadingError: boolean;
  /** True from the moment an image is drawn until inference has finished. */
  sessionRunning: boolean;
  /** Session used for inference; one of the two cached sessions below. */
  session: InferenceSession | undefined;
  gpuSession: InferenceSession | undefined;
  cpuSession: InferenceSession | undefined;

  /** Second element of the tuple returned by `runModelUtils.runModel`. */
  inferenceTime: number;
  imageURLInput: string;
  /** URL picked from the preset list, or null when nothing is selected. */
  imageURLSelect: string | null;
  imageURLSelectList: Array<{ text: string; value: string }>;
  imageLoading: boolean;
  imageLoadingError: boolean;
  /** Raw data of the last output tensor; empty when there is no result. */
  output: Tensor.DataType;
  /** Bytes of the fetched model file, shared by both backends. */
  modelFile: ArrayBuffer;

  constructor() {
    super();
    this.sessionBackend = "wasm";
    this.backendSelectList = [
      { text: "GPU-WebGL", value: "webgl" },
      { text: "CPU-WebAssembly", value: "wasm" },
    ];
    this.modelLoading = true;
    this.modelInitializing = true;
    this.modelLoadingError = false;
    this.sessionRunning = false;
    this.inferenceTime = 0;
    this.imageURLInput = "";
    this.imageURLSelect = null;
    this.imageURLSelectList = this.imageUrls;
    this.imageLoading = false;
    this.imageLoadingError = false;
    this.output = [];
    this.modelFile = new ArrayBuffer(0);
  }

  /**
   * Lifecycle hook: downloads the model file and initialises the session for
   * the default backend. Failures are reflected in the status flags instead of
   * escaping as unhandled rejections.
   */
  async created(): Promise<void> {
    // fetch the model file to be used later
    try {
      const response = await fetch(this.modelFilepath);
      if (!response.ok) {
        throw new Error(
          "Failed to fetch model file: " +
            response.status +
            " " +
            response.statusText
        );
      }
      this.modelFile = await response.arrayBuffer();
    } catch (e) {
      console.error(e);
      this.modelLoading = false;
      this.modelInitializing = false;
      this.modelLoadingError = true;
      return;
    }

    try {
      await this.initSession();
    } catch (e) {
      if (this.sessionBackend !== "wasm") {
        // Fall back to WebAssembly; the "sessionBackend" watcher re-initialises.
        this.sessionBackend = "wasm";
      } else {
        // Already on the fallback backend: assigning "wasm" again would not
        // trigger the watcher, so surface the failure directly.
        this.modelLoadingError = true;
      }
    }
  }

  /**
   * Makes `session` point at a ready session for the current backend, creating
   * and warming one up on first use and reusing the cached one afterwards.
   *
   * @throws Error when the backend is unknown, the session cannot be created,
   *   or the WebAssembly warm-up fails.
   */
  async initSession(): Promise<void> {
    this.sessionRunning = false;
    this.modelLoadingError = false;

    // Capture the backend once so a change while awaiting cannot make the
    // branches below disagree with each other.
    const backend = this.sessionBackend;
    const isGpu = backend === "webgl";
    const isCpu = backend === "wasm";

    if (!isGpu && !isCpu) {
      this.modelLoading = false;
      this.modelInitializing = false;
      this.session = undefined;
      throw new Error("Error: Backend not supported. ");
    }

    const cached = isGpu ? this.gpuSession : this.cpuSession;
    if (cached) {
      this.session = cached;
      return;
    }

    this.modelLoading = true;
    this.modelInitializing = true;

    let session: InferenceSession;
    try {
      session = isGpu
        ? await runModelUtils.createModelGpu(this.modelFile)
        : await runModelUtils.createModelCpu(this.modelFile);
    } catch (e) {
      console.error(e);
      this.modelLoading = false;
      this.modelInitializing = false;
      // Do not keep running inference on the previously selected backend.
      this.session = undefined;
      throw new Error("Error: Backend not supported. ");
    }

    if (isGpu) {
      this.gpuSession = session;
    } else {
      this.cpuSession = session;
    }
    this.session = session;
    this.modelLoading = false;

    const warmupDims = [1, 3, this.imageSize, this.imageSize];
    if (isGpu) {
      // warm up session with a sample tensor. Use setTimeout(..., 0) to make it
      // an async execution so that UI update can be done.
      setTimeout(async () => {
        try {
          await runModelUtils.warmupModel(session, warmupDims);
        } catch (e) {
          // Nothing awaits this timer callback, so log instead of rejecting.
          console.error(e);
        } finally {
          this.modelInitializing = false;
        }
      }, 0);
    } else {
      try {
        await runModelUtils.warmupModel(session, warmupDims);
      } finally {
        // Never leave the UI stuck in the "initializing" state.
        this.modelInitializing = false;
      }
    }
  }

  /**
   * Resets the UI and (re)initialises the session whenever the backend changes.
   *
   * @param newVal the newly selected backend
   * @returns the same value that was passed in
   */
  @Watch("sessionBackend")
  async onSessionBackendChange(newVal: string): Promise<string> {
    this.clearAll();
    try {
      await this.initSession();
    } catch (e) {
      this.modelLoadingError = true;
    }
    return newVal;
  }

  /**
   * Mirrors the preset selection into the URL input and loads that image.
   * A null selection clears the UI via `loadImageToCanvas`.
   */
  @Watch("imageURLSelect")
  onImageURLSelectChange(newVal: string | null) {
    this.imageURLInput = newVal || "";
    this.loadImageToCanvas(newVal);
  }

  /** Lifecycle hook: drops the references to all sessions. */
  beforeDestroy() {
    this.session = undefined;
    this.gpuSession = undefined;
    this.cpuSession = undefined;
  }

  /** Result of `getPredictedClass` applied to a plain-array copy of `output`. */
  get outputClasses() : any {
    return this.getPredictedClass(
      Array.prototype.slice.call(this.output)
    );
  }

  /** Handler for the URL input: clears the preset selection and loads the typed value. */
  onImageURLInputEnter(e: any) {
    this.imageURLSelect = null;
    this.loadImageToCanvas(e.target.value);
  }

  /** Handler for the file input: emits the chosen file as "input" and loads it. */
  handleFileChange(e: any) {
    this.$emit("input", e.target.files[0]);
    this.loadImageToCanvas(e.target.files[0]);
  }

  /**
   * Loads an image, draws it onto the input canvas and schedules inference.
   *
   * @param url image URL or uploaded file; a falsy value resets the UI instead
   */
  loadImageToCanvas(url: string | File | null | undefined) {
    if (!url) {
      this.clearAll();
      return;
    }
    this.imageLoading = true;
    loadImage(
      url,
      (img) => {
        if ((img as Event).type === "error") {
          this.imageLoadingError = true;
          this.imageLoading = false;
          return;
        }

        // load image data onto input canvas
        const ctx = this.getInputContext();
        if (!ctx) {
          // No canvas to draw on: stop the loading indicator rather than
          // leaving it on forever.
          this.imageLoading = false;
          return;
        }

        ctx.drawImage(img as HTMLImageElement, 0, 0);
        this.imageLoadingError = false;
        this.imageLoading = false;
        this.sessionRunning = true;
        this.output = [];
        this.inferenceTime = 0;
        // session predict; deferred so the state changes above are rendered first
        this.$nextTick(() => {
          setTimeout(() => {
            void this.runModel();
          }, 10);
        });
      },
      {
        maxWidth: this.imageSize,
        maxHeight: this.imageSize,
        cover: true,
        crop: true,
        canvas: true,
        crossOrigin: "Anonymous",
      }
    );
  }

  /**
   * Runs the model on the current canvas content and stores the output data
   * and inference time. Errors are logged, and `sessionRunning` is always
   * reset so the UI cannot get stuck in the running state.
   */
  async runModel(): Promise<void> {
    const ctx = this.getInputContext();
    const session = this.session;
    if (!ctx || !session) {
      this.sessionRunning = false;
      return;
    }

    try {
      const preprocessedData = this.preprocess(ctx);
      const [tensorOutput, inferenceTime] = await runModelUtils.runModel(
        session,
        preprocessedData
      );
      this.inferenceTime = inferenceTime;
      this.output = tensorOutput.data;
    } catch (e) {
      // Callers do not await this method, so log rather than reject.
      console.error(e);
    } finally {
      this.sessionRunning = false;
    }
  }

  /** Resets image and result state, wipes the canvas and empties the file input. */
  clearAll() {
    this.sessionRunning = false;
    this.inferenceTime = 0;
    this.imageURLInput = "";
    this.imageURLSelect = null;
    this.imageLoading = false;
    this.imageLoadingError = false;
    this.output = [];

    const ctx = this.getInputContext();
    if (ctx) {
      ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
    }

    const file = document.getElementById("input-upload-image") as HTMLInputElement;
    if (file) {
      file.value = '';
    }
  }

  /** Returns the 2D context of `#input-canvas`, or null when it is unavailable. */
  private getInputContext(): CanvasRenderingContext2D | null {
    const element = document.getElementById(
      "input-canvas"
    ) as HTMLCanvasElement | null;
    return element ? element.getContext("2d") : null;
  }
}
