import {
  Alert,
  Badge,
  Center,
  Flex,
  Group,
  Loader,
  Progress,
  SegmentedControl,
  Stack,
  Table,
  Text,
  Title,
} from '@mantine/core';
import { useMemo, useState } from 'react';
import { match } from 'ts-pattern';
import { EChart } from '../echarts/BareEChart';
import { mixture } from '../util/mixture';

import { SankeyChart } from 'echarts/charts';
import { TooltipComponent } from 'echarts/components';
import * as echarts from 'echarts/core';
import { BinaryConfusionMatrixStats } from '../RecoveryGoal/BinaryConfusionMatrixStats';
import NetWeight from '../Weights/NetWeight';
import { LBSFORMAT } from '../Weights/formatting';
import { useFeedFlowGroupClassificationPerformance } from '../api/feedFlowGroup';
import { SuccessfulFeedFlowGroupClassificationPerformanceAnalysisResultDTO } from '../rest-client';
import cssClasses from './FeedFlowGroupClassificationPerformance.module.css';

// Register the Sankey series and the tooltip component with the echarts core
// module when this file is imported; the Sankey chart option built further
// down uses a `sankey` series together with a `tooltip` section.
echarts.use([SankeyChart, TooltipComponent]);

/**
 * Loads and displays the classification performance analysis of one feed
 * flow group.
 *
 * - While the query is loading, a centered loader is shown.
 * - When data is available and its status is `'Successful'`, the full
 *   performance report is rendered.
 * - When data is available with any other status, an alert shows the
 *   explanation supplied with the result.
 * - Otherwise the query error is thrown to the caller (e.g. for a React error
 *   boundary to catch).
 *
 * @param props.feedFlowGroupId - id of the feed flow group to analyse
 * @throws the query's error, or a generic `Error` when the query ended with
 *   neither data nor an error
 */
export function FeedFlowGroupClassificationPerformance(props: {
  feedFlowGroupId: string;
}) {
  const { feedFlowGroupId } = props;
  const classificationQuery = useFeedFlowGroupClassificationPerformance({
    feedFlowGroupId,
  });

  if (classificationQuery.isLoading) {
    return (
      <Center>
        <Flex direction='column' justify='center' align='center'>
          <Loader variant='bars' />{' '}
          <Text size='sm' color='dimmed'>
            Loading classification performance...
          </Text>
        </Flex>
      </Center>
    );
  }

  if (classificationQuery.data) {
    const perf = classificationQuery.data;
    return match(perf)
      .with({ status: 'Successful' }, (result) => (
        <SuccessfulFeedFlowGroupClassificationPerformance result={result} />
      ))
      .otherwise(({ explanation }) => (
        <Alert
          title='Classification Performance Analysis Not Possible'
          color='lime'
        >
          A classification performance analysis could not be conducted because{' '}
          <Text span weight={'bold'}>
            {explanation}
          </Text>
        </Alert>
      ));
  }

  // Never throw `null`/`undefined`: if the query ended without data and
  // without an error object, surface a real Error so whoever catches it has a
  // message and a stack trace to work with.
  throw (
    classificationQuery.error ??
    new Error(
      'Classification performance query finished without data or an error',
    )
  );
}

/**
 * Lays out the four sections of a successful classification performance
 * analysis: the output port confusion matrix table, the per-port binary
 * confusion matrix statistics, the material class performance table and the
 * material class flow Sankey chart.
 *
 * @param props.result - the successful analysis result to display
 */
function SuccessfulFeedFlowGroupClassificationPerformance(props: {
  result: SuccessfulFeedFlowGroupClassificationPerformanceAnalysisResultDTO;
}) {
  const { result } = props;

  // Output port ids in the key order of `result.outputPortNames`. Memoised so
  // the array identity only changes when the names object does; the Sankey
  // section lists this array in its `useMemo` dependencies and would rebuild
  // its chart option on every render otherwise.
  const outputPortIds = useMemo(
    () => Object.keys(result.outputPortNames),
    [result.outputPortNames],
  );

  return (
    <div>
      {/* TODO(2270): Explain compositional inferences here */}
      <section className={cssClasses.classificationPerformanceSection}>
        <Title order={5}>Classification Performance</Title>
        <ClassificationPerformanceOutputPortConfusionMatrix
          result={result}
          outputPortIds={outputPortIds}
        />
      </section>
      <section className={cssClasses.classificationPerformanceSection}>
        <Title order={5}>Output Port Statistics</Title>
        <PortBinaryConfusionMatrixStats
          result={result}
          outputPortIds={outputPortIds}
        />
      </section>

      <section
        className={`${cssClasses.classificationPerformanceSection} ${cssClasses.materialClassPerformanceSection}`}
      >
        <Title order={5}>Material Class Performance</Title>
        <MaterialClassPerformanceTable
          result={result}
          outputPortIdOrder={outputPortIds}
        />
      </section>
      <section
        className={`${cssClasses.classificationPerformanceSection} ${cssClasses.materialClassSankeySection}`}
      >
        <Title order={5}>Material Class Flow</Title>
        <ClassificationPerformanceClassToPortSankey
          result={result}
          outputPortIds={outputPortIds}
        />
      </section>
    </div>
  );
}

/**
 * Shows the one-vs-rest binary confusion matrix statistics for a single
 * output port, with a segmented control to switch between ports. The
 * currently selected port is labelled with a `+` prefix and all others with
 * a `-` prefix.
 *
 * @param props.result - the successful analysis result to read matrices and
 *   port names from
 * @param props.outputPortIds - ids of the selectable output ports, in display
 *   order; the first one is selected initially
 * @returns the control plus statistics, or `null` when there are no ports
 */
function PortBinaryConfusionMatrixStats(props: {
  result: SuccessfulFeedFlowGroupClassificationPerformanceAnalysisResultDTO;
  outputPortIds: string[];
}) {
  const { result, outputPortIds } = props;

  const [selectedOutputPort, setSelectedOutputPort] = useState(
    outputPortIds[0],
  );

  // The initial state is captured once, so the remembered selection can refer
  // to a port that is no longer in the list if the props change while this
  // component stays mounted. Fall back to the first available port then.
  const activeOutputPort = outputPortIds.includes(selectedOutputPort)
    ? selectedOutputPort
    : outputPortIds[0];

  // Without any output port there is no matrix to look up or show.
  if (activeOutputPort === undefined) {
    return null;
  }

  const matrix = result.oneVsRestBinaryConfusionMatrices[activeOutputPort];

  return (
    <Stack w='100%' className={cssClasses.portBinaryStatsWrapper}>
      <SegmentedControl
        value={activeOutputPort}
        onChange={setSelectedOutputPort}
        data={outputPortIds.map((outputPortId) => ({
          label: `${activeOutputPort === outputPortId ? '+' : '-'} ${
            result.outputPortNames[outputPortId]
          }`,
          value: outputPortId,
        }))}
      />
      <Group className={cssClasses.portBinaryStats}>
        <BinaryConfusionMatrixStats matrix={matrix} withCardWrapper />
      </Group>
    </Stack>
  );
}

/**
 * Table with one row per material class: its goal output port, its feed mass
 * and how much of it was sorted correctly versus incorrectly (as a
 * percentage, as a mass and as a two-colour progress bar).
 *
 * Rows follow the key order of `result.materialClassNames`. For a class whose
 * `feedstockMassLbs` is exactly 0 the performance cells show `-` instead.
 *
 * @param props.result - the successful analysis result to tabulate
 * @param props.outputPortIdOrder - output port ids; a port's position in this
 *   list decides the colour of its goal badge
 */
function MaterialClassPerformanceTable(props: {
  result: SuccessfulFeedFlowGroupClassificationPerformanceAnalysisResultDTO;
  outputPortIdOrder: string[];
}) {
  const { result, outputPortIdOrder } = props;

  // Badge colours are handed out by port position and wrap around once the
  // palette is exhausted.
  const badgePalette = ['orange', 'grape', 'indigo', 'violet'];
  const outputPortColors = Object.fromEntries(
    outputPortIdOrder.map((portId, position) => [
      portId,
      badgePalette[position % badgePalette.length],
    ]),
  );

  const alignRight = { textAlign: 'right' } as const;
  const alignCenter = { textAlign: 'center' } as const;
  const centeredNoWrap = { textAlign: 'center', whiteSpace: 'nowrap' } as const;

  const headerRow = (
    <tr>
      <th>Class</th>
      <th style={alignCenter}>Goal</th>
      <th style={{ textAlign: 'right', whiteSpace: 'nowrap' }}>Feed Mass</th>
      <th colSpan={2} style={centeredNoWrap}>
        Correct
      </th>
      <th colSpan={2} style={centeredNoWrap}>
        Error
      </th>

      <th />
    </tr>
  );

  const bodyRows = Object.entries(result.materialClassNames).map(
    ([classId, className]) => {
      const perf = result.materialClassPerformances[classId];
      const targetPortId = result.materialClassTargetOutputPorts[classId];
      // Performance figures are only rendered (and only evaluated) for
      // classes that actually had material in the feed.
      const hasMaterial = perf.feedstockMassLbs !== 0;

      return (
        <tr key={classId}>
          <th>{className}</th>
          <td>
            <Badge color={outputPortColors[targetPortId]}>
              {result.outputPortNames[targetPortId]}
            </Badge>
          </td>
          <td style={alignRight}>
            <NetWeight
              weight={
                result.estimatedFeedstockMaterialClassGenericMassMixture[
                  classId
                ]
              }
              sourceIconMode='icon-tooltip'
            />
          </td>
          <td style={alignRight}>
            {hasMaterial ? (
              // Font weight grows with the share that was sorted correctly.
              <Text fw={500 + perf.percentCorrectlySortedByMass}>
                {`${perf.percentCorrectlySortedByMass.toFixed(1)}%`}
              </Text>
            ) : (
              '-'
            )}
          </td>
          <td style={alignRight}>
            {hasMaterial ? (
              <NetWeight
                weight={perf.correctlySortedMass}
                sourceIconMode='icon-tooltip'
              />
            ) : (
              '-'
            )}
          </td>

          <td style={alignRight}>
            {hasMaterial ? (
              // Font weight grows as the correctly sorted share shrinks.
              <Text fw={800 - 3 * perf.percentCorrectlySortedByMass}>
                {(100 - perf.percentCorrectlySortedByMass).toFixed(1)}%
              </Text>
            ) : (
              '-'
            )}
          </td>
          <td style={alignRight}>
            {' '}
            {hasMaterial ? (
              <NetWeight
                weight={perf.incorrectlySortedMass}
                sourceIconMode='icon-tooltip'
              />
            ) : (
              '-'
            )}
          </td>
          <td style={alignCenter}>
            {hasMaterial ? (
              <Progress
                radius='xs'
                w='6ch'
                size='1.2em'
                sections={[
                  {
                    value: perf.percentCorrectlySortedByMass,
                    color: 'teal',
                  },
                  {
                    value: 100 - perf.percentCorrectlySortedByMass,
                    color: 'red',
                  },
                ]}
              />
            ) : (
              '-'
            )}
          </td>
        </tr>
      );
    },
  );

  return (
    <Table withBorder>
      <thead>{headerRow}</thead>
      <tbody>{bodyRows}</tbody>
    </Table>
  );
}

/**
 * Sankey chart of mass flowing from material classes (left-hand nodes) into
 * output ports (right-hand nodes).
 *
 * Only material classes with a feedstock mass above zero become nodes, and
 * only flows above zero become links. A link into the class's target output
 * port takes the target node's colour; any other link is drawn red, more
 * opaque and with a shadow.
 *
 * @param props.result - the successful analysis result to chart
 * @param props.outputPortIds - output port ids whose mixtures are turned into
 *   links, in this order
 */
function ClassificationPerformanceClassToPortSankey(props: {
  result: SuccessfulFeedFlowGroupClassificationPerformanceAnalysisResultDTO;
  outputPortIds: string[];
}) {
  const { result, outputPortIds } = props;

  const option = useMemo(() => {
    const orderedClassIds = result.materialClassSet.materialClasses.map(
      (materialClass) => materialClass.id,
    );
    const classNameById = new Map(Object.entries(result.materialClassNames));
    const portNameById = new Map(Object.entries(result.outputPortNames));
    const goalPortByClassId = new Map(
      Object.entries(result.materialClassTargetOutputPorts),
    );

    // Classes that contribute no feedstock mass are left out of the chart.
    const feedstock = mixture(
      Object.entries(result.estimatedFeedstockMaterialClassMixtureLbs),
    );
    const classIdsWithMass = new Set(
      [...feedstock.entries()]
        .filter(([, lbs]) => lbs > 0)
        .map(([classId]) => classId),
    );

    const mixtureByPortId = Object.fromEntries(
      Object.entries(
        result.estimatedOutputPortMaterialClassMassMixturesLbs,
      ).map(([portId, classMixture]) => [portId, mixture(classMixture)]),
    );

    // Nodes: material classes first (in material class set order), then
    // every output port.
    const classNodes = orderedClassIds
      .filter((classId) => classIdsWithMass.has(classId))
      .map((classId) => ({
        name: classNameById.get(classId),
        label: {
          show: true,
          position: 'right',
        },
      }));
    const portNodes = Array.from(portNameById.values(), (portName) => ({
      name: portName,
    }));

    // Links: one per (class, port) pair that carries mass.
    const links = outputPortIds.flatMap((portId) =>
      [...mixtureByPortId[portId].entries()]
        .filter(([classId, lbs]) => classIdsWithMass.has(classId) && lbs > 0)
        .map(([classId, lbs]) => {
          const lineStyle =
            goalPortByClassId.get(classId) === portId
              ? {
                  color: 'target',
                  opacity: 0.2,
                  shadowBlur: undefined,
                  shadowColor: undefined,
                  curveness: 0.6,
                }
              : {
                  color: 'red',
                  opacity: 0.4,
                  shadowBlur: 6,
                  shadowColor: 'rgba(0, 0, 0, 0.5)',
                  curveness: 0.6,
                };
          return {
            source: classNameById.get(classId),
            target: portNameById.get(portId),
            value: lbs.toFixed(1),
            lineStyle,
          };
        }),
    );

    return {
      tooltip: {
        valueFormatter: (v: number | string) => `${Number(v).toFixed(1)} lbs`,
      },
      series: {
        type: 'sankey',
        layoutIterations: 128, // Should we make this order match material class set order by setting to 0?
        layout: 'none',
        emphasis: {
          focus: 'adjacency',
        },
        data: [...classNodes, ...portNodes],
        links,
      },
    };
  }, [outputPortIds, result]);

  return <EChart h={500} w='100%' option={option} />;
}

/**
 * Table summarising, per output port, the mass that was sorted into it, the
 * mass wrongly included in it, the mass absent from it and the mass
 * correctly sorted into it, followed by a grand-total row.
 *
 * The figures come from `result.outputPortConfusionMatrixLbs`, indexed first
 * by goal port and then by the port the material was present in:
 * - same port on both axes: counted as `correct` for that port;
 * - different ports: counted as `absent` for the goal port and as
 *   `wronglyIncluded` for the port it was present in.
 *
 * Only the first three characters of each port name are displayed.
 *
 * @param props.result - the successful analysis result to summarise
 * @param props.outputPortIds - output port ids; one table row per id
 */
function ClassificationPerformanceOutputPortConfusionMatrix(props: {
  result: SuccessfulFeedFlowGroupClassificationPerformanceAnalysisResultDTO;
  outputPortIds: string[];
}) {
  const { result, outputPortIds } = props;

  type PortTally = {
    correct: number;
    absent: number;
    wronglyIncluded: number;
  };

  const portNameById = new Map(Object.entries(result.outputPortNames));
  const totalLbs = mixture(
    Object.entries(result.sortedOutputPortMassMixtureLbs),
  ).total();

  // Start every port at zero so the accumulation below can always look its
  // tally up.
  const tallyByPortId = new Map<string, PortTally>();
  outputPortIds.forEach((portId) => {
    tallyByPortId.set(portId, { correct: 0, absent: 0, wronglyIncluded: 0 });
  });

  outputPortIds.forEach((goalPortId) => {
    const goalTally = tallyByPortId.get(goalPortId) as PortTally;
    const lbsByPresencePort = result.outputPortConfusionMatrixLbs[goalPortId];

    outputPortIds.forEach((presencePortId) => {
      const lbs = lbsByPresencePort[presencePortId];
      if (goalPortId !== presencePortId) {
        // Material that belonged in the goal port but ended up elsewhere.
        goalTally.absent += lbs;
        const presenceTally = tallyByPortId.get(presencePortId) as PortTally;
        presenceTally.wronglyIncluded += lbs;
        return;
      }
      goalTally.correct += lbs;
    });
  });

  const lbsCell = (lbs: number) => <td>{LBSFORMAT.format(lbs)} lbs.</td>;

  const tableRows = [...tallyByPortId.entries()].map(([portId, tally]) => {
    const shortPortName = portNameById.get(portId)?.slice(0, 3);
    return (
      <tr key={portId}>
        <td>{shortPortName}</td>
        {lbsCell(result.sortedOutputPortMassMixtureLbs[portId])}
        {lbsCell(tally.wronglyIncluded)}
        {lbsCell(tally.absent)}
        {lbsCell(tally.correct)}
      </tr>
    );
  });

  return (
    <div>
      <Table withBorder withColumnBorders>
        <thead>
          <tr>
            <th>Process Output</th>
            <th>Output Total</th>
            <th>Wrongly Included</th>
            <th>Absent</th>
            <th>Correctly Sorted</th>
          </tr>
        </thead>
        <tbody>
          {tableRows}
          {/* TODO(2289): this should be in a <tfoot> but mantine doesn't style
            the content of <tfoot> */}
          <tr className={cssClasses.totalRow}>
            <td>Total</td>
            <td colSpan={4}> {LBSFORMAT.format(totalLbs)} lbs.</td>
          </tr>
        </tbody>
      </Table>
    </div>
  );
}
