<template>
  <b-card
    class="r-75"
    body-class="p-3"
  >
    <!-- Shown only until the first fetch of the version details has finished. -->
    <b-row
      v-if="isFetching"
      class="my-5 text-center"
    >
      <b-col>
        <b-spinner />
      </b-col>
    </b-row>
    <div v-else>
      <!-- Header: title plus the "Stop Training" and "Back" actions. -->
      <b-row
        no-gutters
        class="mb-2"
      >
        <b-col class="my-auto">
          <h4 class="mb-0">
            Version details
          </h4>
        </b-col>
        <b-col cols="auto">
          <b-button
            v-if="isTraining"
            class="mr-2"
            variant="warning"
            @click="stopTrain"
          >
            {{ isStopping ? 'Stopping' : 'Stop Training' }}
            <b-spinner small />
          </b-button>
          <!-- The button itself is the router link; nesting a button inside a link is invalid
               interactive content. -->
          <b-button
            variant="primary"
            :to="{ name: 'nlu-versions-overview' }"
          >
            Back
          </b-button>
        </b-col>
      </b-row>
      <b-container
        v-if="activeVersion"
        class="p-0"
        fluid
      >
        <b-row v-if="!isSWML">
          <b-col>
            <p>
              Here you can see the details of a classifier training. The graph on
              the right shows the evolution of the
              <!-- Opens a modal rather than navigating, so it is a (link-styled) button. -->
              <b-button
                v-b-modal.metrics-modal
                variant="link"
                class="p-0 align-baseline"
              >
                validation metrics
              </b-button>
              over time. At the bottom of this page, you can see
              the statistics of each label and the performance of the classifier on each label.
            </p>
            <b-table
              class="mt-4"
              :items="items"
              :fields="['name', 'value']"
            >
              <template #cell(value)="data">
                {{ data.value }}
                <b-spinner
                  v-if="data.item.name === 'Training state' && data.value === 'Training'"
                  class="ml-1"
                  small
                />
              </template>
            </b-table>
            <small
              v-if="isTraining"
              class="text-muted form-text"
            >
              The training will stop automatically when the validation metrics
              have not improved for 5 epochs or 200 steps. Note that training computations block
              each other, so this training might be idle until other trainings finish.
            </small>
          </b-col>
          <b-col cols="6">
            <metrics />
          </b-col>
        </b-row>
        <p v-else>
          This version is uploaded and only information about labels is available.
        </p>
        <b-row class="mt-3">
          <b-col>
            <h4>Labels</h4>
            <!-- Clicking a row opens the confusion modal for that label (see labelClicked). -->
            <b-table
              class="mt-2"
              :tbody-tr-attr="{ style: 'cursor:pointer' }"
              :items="labelItems"
              :fields="labelFields"
              hover
              @row-clicked="labelClicked"
            />
            <small class="text-muted form-text">
              Click a row to see which other labels any label gets confused with.
            </small>
            <b-modal
              id="confusion-modal"
              :title="`Prediction distribution for label '${selectedLabelName}'`"
              ok-only
            >
              <template v-if="labelPredictions !== null">
                <p class="mb-3">
                  Here you can check which other labels the label '{{ selectedLabelName }}'
                  gets confused with when predicting.
                </p>
                <b-table
                  :items="labelPredictions"
                  :fields="confusionModalFields"
                  sort-by="count"
                  sort-desc
                />
              </template>
              <h5 v-else>
                No prediction distribution is available. This is expected for older classifiers and
                during training.
              </h5>
            </b-modal>
          </b-col>
        </b-row>
      </b-container>
      <b-modal
        id="metrics-modal"
        title="Explanation of metrics"
      >
        <p>
          <strong>Perfect match rate / accuracy:</strong> This metric measures how often the top 1
          prediction is correct.
        </p>
        <p class="mt-2">
          <strong>Success rate:</strong> This metric measures the fraction of data points for which
          the correct label is assigned a predicted probability of at least 10%. The success rate
          therefore corresponds to the fraction of times that the correct label would be shown by a
          smart node with a "display option threshold" of 10%.
        </p>
        <p class="mt-2">
          <strong>Robust score (F1):</strong> This metric is also a measure of how accurate
          the model is. Rather than only looking at whether a prediction is correct, it takes into
          account the proportion of false positives/false negatives of each label. In the case of a
          dataset having a mix of labels with few and many datapoints, this can make the score more
          robust.
        </p>
        <p class="mt-2">
          <strong>Loss:</strong> This is the metric that the training pipeline is trying
          to minimize. Mostly, you should just ignore it, but it can tell you if the training
          is still improving even when the accuracy seems stable.
        </p>
      </b-modal>
    </div>
  </b-card>
</template>

<script>
import axios from 'axios';
import { mapGetters, mapMutations, mapActions } from 'vuex';
import { percentageFormatter } from 'supwiz/util/formatters';
import endpoints from '@/js/urls';
import Metrics from '@/pages/NLU/Versions/Metrics.vue';

// The tables below are written as [display name, training-log key] pairs and expanded into
// { name, key } objects, which is the shape the `items` computed iterates over.

// Training-run facts listed in the details table (each only when its key is present in the log).
const trainData = [
  ['Language model type', 'lm_type'],
  ['Language model description', 'lm_description'],
  ['Type of training', 'train_type'],
  ['Related bot', 'bot_id'],
  ['Autolabels age limit', 'autolabels_age'],
  ['Training datapoints', 'n_data_train'],
  ['Validation datapoints', 'n_data_val'],
].map(([name, key]) => ({ name, key }));

// Validation metrics used when the version's type is 'ulm'; top-1 accuracy is logged as 'accuracy'.
const ulmMetrics = [
  ['Loss', 'loss_val'],
  ['Perfect match', 'accuracy'],
  ['Success rate', 'accuracy_at_p'],
  ['Robust score (F1)', 'f1_score'],
].map(([name, key]) => ({ name, key }));

// Validation metrics for all other version types: same rows, but top-1 accuracy is logged as
// 'accuracy_val'.
const transformerMetrics = ulmMetrics.map(
  ({ name, key }) => ({ name, key: key === 'accuracy' ? 'accuracy_val' : key }),
);

/**
 * Format a training metric value for display in the details table.
 *
 * Metrics whose key contains "accuracy" or "f1" are multiplied by 100 and shown as a percentage
 * with one decimal (so they are presumably stored as fractions); any other metric (e.g. loss) is
 * shown with four decimals.
 *
 * @param {number} num - Raw metric value taken from the training log.
 * @param {string} key - Training-log key of the metric, e.g. 'accuracy_val' or 'loss_val'.
 * @returns {string} The formatted value, or 'N/A' when `num` is not a finite number.
 */
function prettifyMetric(num, key) {
  // A missing epoch entry (undefined/null) would otherwise throw on toFixed() and break the
  // whole details table.
  if (!Number.isFinite(num)) {
    return 'N/A';
  }
  if (key.includes('accuracy') || key.includes('f1')) {
    return `${(num * 100).toFixed(1)}%`;
  }
  return num.toFixed(4);
}

/**
 * Details page for a single classifier version: training state, validation metrics and
 * per-label statistics. While the component is alive it re-fetches the version every 5 seconds.
 */
export default {
  name: 'ModelVersionDetails',
  components: {
    Metrics,
  },
  // Navigating between versions reuses this instance, so point the store at the new version
  // and fetch its details.
  beforeRouteUpdate(to, from, next) {
    this.setActiveVersionId(to.params.versionId);
    this.fetchVersionDetails();
    next();
  },
  // Clear the active version from the store when leaving the page.
  beforeRouteLeave(to, from, next) {
    this.setActiveVersionId(null);
    this.setActiveVersion(null);
    next();
  },
  data() {
    return {
      // Handle of the polling timer started in created() and cleared in beforeDestroy().
      intervalId: null,
      // Set in beforeDestroy() so a first fetch that resolves after teardown does not start a
      // timer that would never be cleared.
      isDestroyed: false,
      lastUpdated: null,
      // Bot nodes used by labelToName(); presumably keyed by node id (the label value) — verify
      // against the bots endpoint.
      nodes: null,
      // Row of the labels table the user clicked; drives the confusion modal.
      selectedLabelObj: null,
      isFetching: true,
      confusionModalFields: [
        { key: 'label' },
        { key: 'count' },
        { key: 'fraction', formatter: percentageFormatter },
      ],
    };
  },
  computed: {
    ...mapGetters('nlu/classifier', [
      'getModelFromId',
      'labels',
      'activeModel',
      'activeVersion',
    ]),
    ...mapGetters('botManipulation', [
      'botsList',
    ]),
    isSWML() {
      return Boolean(this.activeModel) && this.activeModel.type === 'swml';
    },
    config() {
      const versionBot = this.activeVersion && this.activeVersion.classifierversionbot;
      return versionBot ? versionBot.config : null;
    },
    trainTask() {
      if (this.activeVersion && this.activeVersion.classifierversionbot) {
        return this.activeVersion.classifierversionbot.train_task;
      }
      return null;
    },
    isTraining() {
      return Boolean(this.trainTask) && !this.trainTask.finished;
    },
    isStopping() {
      return this.isTraining && this.trainTask.aborted;
    },
    // Prefer the live train task's log; fall back to the log stored on the version itself.
    trainLog() {
      if (this.trainTask) {
        return this.trainTask.train_log;
      }
      const versionBot = this.activeVersion && this.activeVersion.classifierversionbot;
      return versionBot ? versionBot.train_log : null;
    },
    // Rows ({ name, value }) of the training details table; null when there is nothing to show.
    items() {
      if (this.isSWML || !this.trainLog) {
        return null;
      }
      const log = this.trainLog;
      const items = [
        { name: 'Training state', value: this.trainTask ? this.trainTask.status : 'Unknown' },
      ];
      // training data
      for (const { name, key } of trainData) {
        if (key in log) {
          let showVal = log[key];
          if (key === 'bot_id' && showVal) {
            // Show the bot's name instead of its id when the bot is known.
            const bot = (this.botsList || []).find((x) => x.id === showVal);
            if (bot) {
              showVal = bot.name;
            } else {
              console.log('Bot used for extraction not found in list.');
            }
          }
          if (showVal === null) {
            showVal = 'None';
          }
          items.push({ name, value: showVal });
        }
      }
      // metrics: per-epoch arrays in the log, read at the epoch whose model was saved
      const isULM = this.activeVersion.classifierversionbot.type === 'ulm';
      const metricList = isULM ? ulmMetrics : transformerMetrics;
      if ('epoch_saved' in log && log.epoch_saved !== null && Array.isArray(log.epoch)) {
        const idx = log.epoch.indexOf(log.epoch_saved);
        // Skip metrics entirely if the saved epoch is not in the epoch list.
        if (idx !== -1) {
          for (const { name, key } of metricList) {
            if (Array.isArray(log[key])) {
              items.push({ name, value: prettifyMetric(log[key][idx], key) });
            }
          }
        }
      }
      return items;
    },
    labelFields() {
      if (this.isSWML) {
        return [{ key: 'name', label: 'Name', sortable: true }];
      }
      return [
        {
          key: 'name',
          formatter: 'labelToName',
          sortable: true,
          sortByFormatted: true,
        },
        { key: 'count_train', label: 'Training rows', sortable: true },
        { key: 'count_val', label: 'Validation rows', sortable: true },
        {
          key: 'accuracy',
          label: 'Validation accuracy',
          formatter: percentageFormatter,
          sortable: true,
          // Sort missing or undefined (NaN/undefined) accuracies below every real value.
          sortByFormatted: (value) => (Number.isFinite(value) ? value : -1),
        },
      ];
    },
    // Rows of the labels table. Trained versions use per-label stats from the train log;
    // otherwise (e.g. uploaded/SWML versions) only the label names from the version meta.
    labelItems() {
      const logLabels = this.trainLog ? this.trainLog.labels : null;
      if (!this.isSWML && logLabels) {
        return Object.values(logLabels).map((label) => {
          const row = { ...label };
          if (row.preds && 'count_val' in row) {
            // preds counts predictions per label; a missing entry for the label itself means
            // zero correct predictions, not an unknown accuracy.
            row.accuracy = (row.preds[row.name] || 0) / row.count_val;
          }
          return row;
        });
      }
      const meta = this.activeVersion ? this.activeVersion.meta : null;
      if (meta && meta.labels) {
        return meta.labels.map((x) => ({ name: x }));
      }
      return [];
    },
    // Confusion rows for the selected label, or null when no prediction data exists.
    labelPredictions() {
      const selected = this.selectedLabelObj;
      if (selected === null || !selected.preds) {
        return null;
      }
      return Object.entries(selected.preds).map(([predLabel, count]) => ({
        label: this.labelToName(predLabel),
        count,
        // Highlight the row for correct predictions.
        _rowVariant: (predLabel === selected.name) ? 'success' : '',
        fraction: count / selected.count_val,
      }));
    },
    selectedLabelName() {
      if (this.selectedLabelObj === null) {
        return '';
      }
      return this.labelToName(this.selectedLabelObj.name);
    },
  },
  async created() {
    this.setActiveVersionId(this.$route.params.versionId);
    this.isFetching = true;
    try {
      await this.fetchVersionDetails();
    } catch (error) {
      // Keep going: the polling below may still succeed later.
      console.error('Failed to fetch version details.', error);
    } finally {
      this.isFetching = false;
    }
    // The component may have been destroyed while awaiting; starting the timer now would leak it.
    if (this.isDestroyed) {
      return;
    }
    this.intervalId = setInterval(() => {
      this.lastUpdated = new Date();
      this.fetchVersionDetails();
    }, 5000);
  },
  beforeDestroy() {
    this.isDestroyed = true;
    clearInterval(this.intervalId);
  },
  async mounted() {
    // get nodes for label names
    if (this.activeModel && this.activeModel.nodeId) {
      try {
        const resp = await axios.get(endpoints.botsBase + this.activeModel.botId, {
          headers: { Authorization: `JWT ${this.$store.state.auth.jwt}` },
        });
        this.nodes = resp.data.nodes;
      } catch (error) {
        // labelToName() falls back to the raw label when nodes is not set.
        console.error('Failed to fetch bot nodes for label names.', error);
      }
    }
  },
  methods: {
    ...mapMutations('nlu/classifier', ['setActiveVersionId', 'setIsAborted', 'setActiveVersion']),
    ...mapActions('nlu/classifier', ['fetchVersionDetails']),
    // Map a label to its node's display name; return the label unchanged if it is unknown.
    labelToName(label) {
      if (!this.nodes || !(label in this.nodes)) {
        return label;
      }
      return this.nodes[label].name;
    },
    // NOTE(review): not referenced by this component's template, and it targets a different
    // route ('nlu-label-single-overview') than the Back button — confirm whether still needed.
    goBack() {
      this.$router.push({ name: 'nlu-label-single-overview' });
    },
    // The first click requests a stop. Clicking again while stopping asks for confirmation
    // before sending another stop request.
    async stopTrain() {
      if (this.isStopping) {
        if (await this.$bvModal.msgBoxConfirm('Are you sure you want to force stop the training?')) {
          this.executeStop();
        }
        return;
      }
      this.executeStop();
    },
    // Mark the task as aborted in the store and ask the backend to delete the train task.
    async executeStop() {
      this.setIsAborted(true);
      try {
        await axios.delete(endpoints.trainTask, {
          params: { clfId: this.activeModel.id },
          headers: { Authorization: `JWT ${this.$store.state.auth.jwt}` },
        });
      } catch (error) {
        console.error('Failed to stop training.', error);
      }
    },
    labelClicked(item) {
      this.selectedLabelObj = item;
      this.$bvModal.show('confusion-modal');
    },
  },
};
</script>

<style scoped>

</style>
