<template>
  <div class="text-center">
    <!-- Switches `metric` between the metric groups in `metricOptions` ('Accuracy' / 'Loss'). -->
    <b-form-radio-group
      v-model="metric"
      :options="metricOptions"
      buttons
      size="sm"
      button-variant="primary"
    />
    <!-- Line chart of the selected metric group, driven by the `chartData` / `chartOptions` computeds.
         NOTE(review): with vue-chartjs v5 (`data`/`options` props) `styles` may not be a declared prop and
         would then fall through as a plain DOM attribute; also `aspectRatio: 1` sizes the canvas — confirm
         whether the 100% height is actually applied. -->
    <base-line-chart
      class="mt-3"
      :data="chartData"
      :options="chartOptions"
      :styles="{ height: '100%' }"
    />
  </div>
</template>

<script>

import { mapState, mapGetters } from 'vuex';
import { Line as BaseLineChart } from 'vue-chartjs';
import {
  Chart as ChartJS,
  Title,
  Legend,
  Tooltip,
  LineElement,
  PointElement,
  LinearScale,
  CategoryScale,
} from 'chart.js';

// Chart.js only renders with components registered on the global Chart object,
// so register every element, scale and plugin imported above before any chart mounts.
const lineChartComponents = [
  Title,
  Legend,
  Tooltip,
  LineElement,
  PointElement,
  LinearScale,
  CategoryScale,
];
ChartJS.register(...lineChartComponents);

/**
 * Training-metrics panel: plots the accuracy or loss series recorded in the
 * active classifier version's training log as a line chart.
 */
const Metrics = {
  name: 'Metrics',
  components: {
    BaseLineChart,
  },
  data() {
    return {
      // Selected metric group; one of `metricOptions`.
      metric: 'Accuracy',
      metricOptions: ['Accuracy', 'Loss'],
    };
  },
  computed: {
    ...mapState('nlu/languageModel', [
      'languageModels',
    ]),
    ...mapGetters('nlu/classifier', [
      'activeVersion',
    ]),
    /** Training task attached to the active classifier version (may be falsy). */
    trainTask() {
      return this.activeVersion.classifierversionbot.train_task;
    },
    /**
     * Training log to plot: the task's log when a training task exists, otherwise
     * the log stored on the version. Falls back to `{}` so the lookups in
     * `chartData` never dereference a missing log.
     */
    trainLog() {
      const log = this.trainTask
        ? this.trainTask.train_log
        : this.activeVersion.classifierversionbot.train_log;
      return log || {};
    },
    /**
     * `type` of the language model configured for the active version, or `null`
     * when none is configured or it is missing from `languageModels`.
     */
    modelType() {
      const languageModelId = this.activeVersion.classifierversionbot.config.language_model;
      if (languageModelId === null || languageModelId === undefined) {
        return null;
      }
      const languageModel = this.languageModels && this.languageModels[languageModelId];
      return languageModel ? languageModel.type : null;
    },
    /** X-axis title: 'Epochs' for ULM models, 'Steps' otherwise. */
    xLabel() {
      return this.modelType === 'ulm' ? 'Epochs' : 'Steps';
    },
    /** Chart.js options: square chart, no animation, titled axes. */
    chartOptions() {
      return {
        aspectRatio: 1,
        plugins: {
          // Chart.js 3+ reads tooltip settings from `plugins.tooltip`;
          // `tooltips` is the Chart.js 2 key and would be ignored.
          tooltip: {
            mode: 'index',
            intersect: false,
          },
        },
        hover: {
          mode: 'nearest',
          intersect: true,
        },
        animation: {
          duration: 0,
        },
        scales: {
          x: {
            title: {
              display: true,
              text: this.xLabel,
            },
          },
          y: {
            title: {
              display: true,
              text: this.metric,
            },
          },
        },
      };
    },
    /**
     * Static label/colour for each series, keyed by series name.
     * Treated as read-only: `chartData` copies entries instead of mutating them.
     */
    datasets() {
      return {
        loss_train: {
          label: 'Loss (training data)',
          fill: false,
          borderColor: '#8e5ea2',
        },
        loss_val: {
          label: 'Loss (validation data)',
          fill: false,
          borderColor: '#3e95cd',
        },
        accuracy: {
          label: 'Perfect match rate',
          fill: false,
          borderColor: '#3e95cd',
        },
        accuracy_at_p: {
          label: 'Success rate',
          fill: false,
          borderColor: '#0a2c63',
        },
        f1_score: {
          label: 'Robust score (F1)',
          fill: false,
          borderColor: '#8e5ea2',
        },
      };
    },
    /**
     * Chart.js `data` object for the selected metric group: one dataset per
     * series (values rounded to 3 decimal places) and the log's `epoch` array as
     * x-axis labels. A series absent from the log gets `data: null`.
     */
    chartData() {
      const seriesByMetric = {
        Loss: ['loss_train', 'loss_val'],
        Accuracy: ['accuracy_at_p', 'accuracy', 'f1_score'],
      };
      // Keep `null` points as gaps instead of crashing on `null.toFixed`.
      const round = (x) => (x === null ? null : Number(x.toFixed(3)));
      const datasets = (seriesByMetric[this.metric] || []).map((series) => {
        // Until training has reached a certain epoch, the backend sends an empty
        // trainLog object, so a series may be missing entirely.
        const values = this.trainLog[this.logKey(series)];
        return {
          // Copy the static styling so the cached `datasets` entries stay untouched.
          ...this.datasets[series],
          // `null` for missing series mimics the behaviour as of commit
          // 8c11dfb5feaba5339f2bac438b79907c1d0915f9.
          data: Array.isArray(values) ? values.map(round) : null,
        };
      });
      return { datasets, labels: this.trainLog.epoch };
    },
  },
  methods: {
    /**
     * Maps a series name to the training-log key holding its values. For
     * non-ULM models the `accuracy` series is read from `accuracy_val`.
     */
    logKey(label) {
      if (label === 'accuracy' && this.modelType !== 'ulm') {
        return 'accuracy_val';
      }
      return label;
    },
  },
};

export default Metrics;

</script>

<style scoped>

</style>
