import { Alert, Flex, Stack } from '@mantine/core';
import { forwardRef } from 'react';
import tinycolor from 'tinycolor2';
import { match } from 'ts-pattern';
import {
  CommoditySpreadMaterialNodeDTO,
  CommoditySpreadRecoveryGoalNodeDTO,
  MaterialClassDTO,
} from '../rest-client';
import { useRecoveryStrategySimulationCtx } from './RecoveryStrategySimulationContext';

import { useCategoricalColors } from '../lib/colors';
import { MaterialClassLegendItem } from './MaterialClassLegendItem';
import classes from './RecoveryStrategySimulationDiagram.module.css';

/**
 * Renders the recovery strategy simulation as an SVG flow diagram, with a
 * material-class legend to its left.
 *
 * The simulation comes from `useRecoveryStrategySimulationCtx()`. Starting at
 * the root material node, the diagram follows the chain of recovery goal
 * nodes through each negative output
 * (`negativeMaterialNode.consumingRecoveryGoalNode`). It lays out one
 * `SepBlock` per recovery goal node, left to right.
 *
 * When the simulation has no root material node, a warning alert is shown
 * instead.
 */
export function RecoveryStrategySimulationDiagram() {
  const { simulation, selectedMaterialClassId, setSelectedMaterialClassId } =
    useRecoveryStrategySimulationCtx();

  const colors = useCategoricalColors();
  // Wrap around so every class gets a color even when there are more classes
  // than palette entries. This is the same indexing the hatch patterns use.
  const colorAt = (i: number) => colors[i % colors.length];

  if (simulation.rootMaterialNode === null) {
    return <Alert color='yellow'>No Simulation Results</Alert>;
  }

  const { materialClasses } = simulation.materialClassSet;
  const rootMaterialNode: CommoditySpreadMaterialNodeDTO =
    simulation.rootMaterialNode;

  // Stack the root input bands vertically. Each band is as thick as the
  // class's probability scaled by SCALE_FACTOR.
  let inputCursor = 0;
  const inputCenters: Record<string, number> = {};
  const inputThicknesses: Record<string, number> = {};
  for (const materialClass of materialClasses) {
    const thickness =
      rootMaterialNode.materialClassProbabilities[materialClass.id] *
      SCALE_FACTOR;
    inputThicknesses[materialClass.id] = thickness;
    inputCenters[materialClass.id] = inputCursor + thickness / 2;
    inputCursor += thickness;
  }

  // Walk the chain of recovery goal nodes. Each negative output leads to the
  // next consuming recovery goal, and the blocks are placed side by side.
  const negativeSubTreeBlocks: { sepBlock: SepBlock; offset: number }[] = [];
  let totalWidth = 0;
  let recoveryNode: CommoditySpreadRecoveryGoalNodeDTO | null =
    rootMaterialNode.consumingRecoveryGoalNode ?? null;
  while (recoveryNode !== null) {
    const sepBlock = toSepBlock(recoveryNode, materialClasses);
    negativeSubTreeBlocks.push({ sepBlock, offset: totalWidth });
    totalWidth += sepBlock.width;
    recoveryNode =
      recoveryNode.negativeMaterialNode.consumingRecoveryGoalNode ?? null;
  }

  // The first block's height sets the baseline for all terminal blocks: every
  // block is extended by `posEnd - height`. When the root material node is not
  // consumed by any recovery goal there are no blocks, so fall back to 0
  // instead of reading past the end of the array.
  const posEnd = (negativeSubTreeBlocks[0]?.sepBlock.height ?? 0) + 50;

  return (
    <Flex>
      <Stack spacing='xs' mt={30}>
        {materialClasses.map((materialClass, i) => (
          <MaterialClassLegendItem
            key={materialClass.id}
            color={colorAt(i)}
            materialClass={materialClass}
          />
        ))}
      </Stack>
      <svg
        viewBox={`-70 -30 ${totalWidth + 100} ${posEnd + 120}`}
        width={'100%'}
      >
        <defs>
          {materialClasses.map((materialClass, i) => {
            const color = colorAt(i);
            return (
              <pattern
                key={materialClass.id}
                id={`incorrect-${materialClass.id}`}
                width='100px'
                height='5px'
                patternUnits='userSpaceOnUse'
                patternContentUnits='userSpaceOnUse'
                viewBox='0 0 100 100'
                preserveAspectRatio='none'
                patternTransform='rotate(-30)'
              >
                <rect
                  fill={tinycolor(color).lighten().desaturate().toHexString()}
                  width={100}
                  height={100}
                />
                <line
                  x1='0'
                  y1='50'
                  x2='100'
                  y2='50'
                  strokeWidth='50'
                  stroke={color}
                />
              </pattern>
            );
          })}
          <filter
            id='coloredShadow'
            x='-50%'
            y='-50%'
            width='200%'
            height='200%'
          >
            <feGaussianBlur stdDeviation='2' result='coloredBlur' />
            <feMerge>
              <feMergeNode in='coloredBlur' />
              <feMergeNode in='SourceGraphic' />
            </feMerge>
          </filter>
          <filter id='glow' x='-50%' y='-50%' width='200%' height='200%'>
            <feGaussianBlur stdDeviation='4' result='coloredBlur' />
            <feMerge>
              <feMergeNode in='coloredBlur' />
              <feMergeNode in='SourceGraphic' />
            </feMerge>
          </filter>
        </defs>
        <g>
          {/* Input lines */}
          {materialClasses.map((materialClass, i) => (
            <path
              className={classes.materialPath}
              data-selected={selectedMaterialClassId === materialClass.id}
              key={materialClass.id}
              d={`m 0 ${inputCenters[materialClass.id]} h -50`}
              fill='none'
              stroke={colorAt(i)}
              strokeWidth={inputThicknesses[materialClass.id]}
              onClick={() => setSelectedMaterialClassId(materialClass.id)}
            />
          ))}
        </g>
        <g>
          {negativeSubTreeBlocks.map(({ sepBlock, offset }, i) => (
            <SepBlock
              key={i}
              sepBlock={sepBlock}
              extend={posEnd - sepBlock.height}
              transform={`translate(${offset}, 0)`}
            />
          ))}
        </g>
      </svg>
    </Flex>
  );
}

/**
 * Geometry of one side (positive or negative) of a material class's flow
 * through a separation block.
 */
interface MaterialClassPath {
  /**
   * Whether the material lands on the side its recovery goal intends:
   * positive for classes in the goal, negative for classes outside it.
   * Incorrect sides are stroked with the class's hatch pattern.
   */
  correct: boolean;
  /** Stroke width of the band, in SVG user units. */
  thickness: number;
  /** Relative SVG path data (`d`) for the band's centerline. */
  path: string;
}

/** Both halves of one material class's flow through a separation block. */
interface SepBlockPath {
  /** The part of the class routed to the positive material node. */
  positive: MaterialClassPath;
  /** The part of the class routed to the negative material node. */
  negative: MaterialClassPath;
  /** Id of the material class these paths belong to. */
  materialClassId: string;
}

/**
 * Layout of one recovery-goal separation block, as produced by `toSepBlock`.
 * Coordinates and lengths are in SVG user units, local to the block.
 */
export interface SepBlock {
  /** The recovery goal node this block visualizes. */
  recoveryNode: CommoditySpreadRecoveryGoalNodeDTO;
  /** Per-material-class path geometry, in material class order. */
  paths: SepBlockPath[];

  /** Horizontal extent; the next block in the chain is translated by this. */
  width: number;
  /** Vertical extent: total input thickness plus POS_BEND_RADIUS. */
  height: number;
  /** Total thickness of all stacked input bands. */
  inputThickness: number;

  /** Total thickness of the positive output bands. */
  posThickness: number;
  /** Top-left corner of the positive terminal block. */
  posBlockPoint: [number, number];

  /**
   * True when no recovery goal node consumes the negative output, so it ends
   * in its own terminal block.
   */
  negTerminal: boolean;
  /** Total thickness of the negative output bands. */
  negThickness: number;
  /** Top-left corner of the negative terminal block (drawn only when `negTerminal`). */
  negBlockPoint: [number, number];
}

/** Multiplier turning a class probability into a band thickness / width in SVG user units. */
const SCALE_FACTOR = 800;
/** Length of the straight horizontal lead-in (and, for negative paths, lead-out) segments. */
const PADDING = 20;
/**
 * Radius of the quarter-circle bend that turns positive paths downward; also
 * used as spacing around the output blocks.
 */
const POS_BEND_RADIUS = 25;
/**
 * Lower bound for the offset of the shared negative join-arc center, and extra
 * slack added to the width of the negative convergence section.
 */
const MIN_NEG_ARC_RADIUS = 25;

/**
 * Computes the layout of one recovery-goal separation block.
 *
 * Every material class enters on the block's left edge (x = 0) as a
 * horizontal band. Bands are stacked top to bottom in `materialClasses`
 * order, and each band's thickness is its probability times SCALE_FACTOR.
 * Each class's band is split into two parts:
 *  - a positive part, which runs right, bends a quarter circle of radius
 *    POS_BEND_RADIUS, and drops vertically to y = height;
 *  - a negative part, which runs right past the positive output block and
 *    then rises through two tangent arcs so that the negative bands become
 *    contiguous. When no recovery goal node consumes the negative material
 *    node (a terminal), the negative bands also bend down to y = height.
 *
 * All coordinates are local to the block; callers position it with a
 * transform.
 *
 * @param recoveryNode - Recovery goal node whose positive and negative
 *   material nodes are drawn.
 * @param materialClasses - Classes to draw, in stacking order.
 * @returns Relative SVG path data per class (each path starts with `m 0, y`),
 *   the block's width and height, and the anchor points and thicknesses of
 *   the positive and negative output bands.
 */
export function toSepBlock(
  recoveryNode: CommoditySpreadRecoveryGoalNodeDTO,
  materialClasses: MaterialClassDTO[],
): SepBlock {
  const recoveryGoalClassIds = new Set(
    recoveryNode.recoveryGoal.materialClasses.map((mc) => mc.id),
  );

  // If no recovery goal node consumes the negative output, it ends in its own
  // terminal block. `== null` treats a missing (undefined) link the same as
  // null, matching the `?? null` used when walking the recovery goal chain.
  const negTerminal =
    recoveryNode.negativeMaterialNode.consumingRecoveryGoalNode == null;

  // go from the start of the input block to the end, finding the offsets for each line
  let inputCursor = 0;
  let posOutputCursor = 0;
  let negOutputCursor = 0;

  const centers: {
    materialClassId: string;
    positive: {
      /**
       * Distance from the start of the input block to the center of the path
       */
      input: number;
      /**
       * Distance from the start of the output block to the center of the path
       */
      output: number;

      thickness: number;

      correct: boolean;
    };
    negative: {
      /**
       * Distance from the start of the input block to the center of the path
       */
      input: number;
      /**
       * Distance from the start of the output block to the center of the path
       */
      output: number;

      thickness: number;

      correct: boolean;
    };
  }[] = [];

  // Only 'right' is used for now (see the TODO(2288) notes below). The cast
  // keeps the declared union type, so TypeScript does not narrow the constant
  // to 'right' and reject the 'left' comparisons as impossible.
  const posDirection: 'left' | 'right' = 'right' as 'left' | 'right';

  // Largest upward shift any negative band needs between input and output.
  let maxNegDelta = 0;
  for (const materialClass of materialClasses) {
    const posProb =
      recoveryNode.positiveMaterialNode.materialClassProbabilities[
        materialClass.id
      ];
    const negProb =
      recoveryNode.negativeMaterialNode.materialClassProbabilities[
        materialClass.id
      ];

    const posThick = posProb * SCALE_FACTOR;
    const negThick = negProb * SCALE_FACTOR;

    let posInputCenter: number;
    let negInputCenter: number;

    // Split this class's input band in two. With 'left' the positive part
    // comes first (on top); with 'right' the negative part does. Each
    // `center = cursor += t / 2; cursor += t / 2;` pair records the center of
    // a band and then advances the cursor past it.
    if (posDirection === 'left') {
      posInputCenter = inputCursor += posThick / 2;
      inputCursor += posThick / 2;

      negInputCenter = inputCursor += negThick / 2;
      inputCursor += negThick / 2;
    } else {
      negInputCenter = inputCursor += negThick / 2;
      inputCursor += negThick / 2;

      posInputCenter = inputCursor += posThick / 2;
      inputCursor += posThick / 2;
    }

    // biome-ignore lint/suspicious/noAssignInExpressions: capture the band center while advancing the cursor
    const posOutputCenter = (posOutputCursor += posThick / 2);
    posOutputCursor += posThick / 2;

    // biome-ignore lint/suspicious/noAssignInExpressions: capture the band center while advancing the cursor
    const negOutputCenter = (negOutputCursor += negThick / 2);
    negOutputCursor += negThick / 2;

    const isInRg = recoveryGoalClassIds.has(materialClass.id);

    maxNegDelta = Math.max(maxNegDelta, negInputCenter - negOutputCenter);

    centers.push({
      materialClassId: materialClass.id,
      positive: {
        input: posInputCenter,
        output: posOutputCenter,
        correct: isInRg,
        thickness: posThick,
      },
      negative: {
        input: negInputCenter,
        output: negOutputCenter,
        correct: !isInRg,
        thickness: negThick,
      },
    });
  }

  // The positive outputs form a horizontal row of vertical bands. The row
  // starts one bend radius after the lead-in padding.
  const posOutputBlockOffset = POS_BEND_RADIUS;
  const posOutputBlockEnd = posOutputBlockOffset + posOutputCursor;

  // A shared y, below the negative output band, that every join arc curves
  // around. Each join arc's radius is measured from its output center down to
  // this line, so the join arcs are concentric and the bands stay parallel.
  const negJoinArcCenter =
    negOutputCursor + Math.max(maxNegDelta / 2, MIN_NEG_ARC_RADIUS);

  // find max required total negative arc width
  // For an S-curve of two tangent arcs spanning width W and rise d with radii
  // summing to R, W² = 2·R·d − d². This is the smallest W that keeps the
  // align arc radius non-negative once the join radius is fixed.
  let negArcSize = 0;
  for (const center of centers) {
    const negDelta = center.negative.input - center.negative.output;
    const joinArcRadius = negJoinArcCenter - center.negative.output;
    negArcSize = Math.max(
      negArcSize,
      Math.sqrt(2 * negDelta * joinArcRadius - negDelta * negDelta),
    );
  }
  negArcSize += MIN_NEG_ARC_RADIUS;

  const paths: SepBlock['paths'] = [];
  for (const center of centers) {
    const negDelta = center.negative.input - center.negative.output;
    const joinArcRadius = negJoinArcCenter - center.negative.output;
    // From W² = 2·R·d − d²: R = (W² + d²) / (2d), and the align arc takes
    // whatever the join arc does not. When negDelta is 0 this divides by zero,
    // but the results are then unused (a straight line is drawn instead).
    const alignArcRadius =
      (0.5 * negArcSize ** 2) / negDelta - joinArcRadius + negDelta / 2;

    const R = joinArcRadius + alignArcRadius;

    // The tangent point splits the total displacement in the ratio
    // alignArcRadius : joinArcRadius.
    const alignArcWidth = (alignArcRadius / R) * negArcSize;
    const alignArcHeight = (alignArcRadius / R) * negDelta;

    const joinArcWidth = (joinArcRadius / R) * negArcSize;
    const joinArcHeight = (joinArcRadius / R) * negDelta;

    // The negative path is a straight horizontal line across until we can start converging
    // the convergence can start right after the last positive output line
    const convergencePath =
      negDelta === 0
        ? `h ${negArcSize}`
        : `a ${alignArcRadius} ${alignArcRadius} 0 0 0 ${alignArcWidth} ${-alignArcHeight}
           a ${joinArcRadius} ${joinArcRadius} 0 0 1 ${joinArcWidth} ${-joinArcHeight}`;

    // The radius of the neg arc is given by the total of neg below it (width - center)
    // TODO(2288): make this do direction properly
    const negArcRadius =
      negOutputCursor - center.negative.output + POS_BEND_RADIUS;
    // Only terminal blocks bend the negative bands down. Otherwise the bands
    // leave the right edge horizontally and feed the next block's input.
    const negArc = negTerminal
      ? `a ${negArcRadius} ${negArcRadius} 0 0 1 ${negArcRadius} ${negArcRadius} v ${
          inputCursor - negOutputCursor
        }`
      : '';

    const negPath = `
    m 0, ${center.negative.input}
    h ${PADDING}
    h ${posOutputBlockEnd}
    ${convergencePath}
    h ${PADDING}
    ${negArc}
    `;

    // The positive path is a straight horizontal line across until we can start the quarter circle
    // the distance before the quarter circle is the output center distance minus the bend radius
    // after the quarter circle, we draw a vertical line to clear all the negative paths
    // this distance is given by its input distance since that's the orthogonal axis
    const posInputLineLength =
      match(posDirection)
        .with('left', () => posOutputBlockOffset + center.positive.output)
        .with('right', () => posOutputBlockEnd - center.positive.output)
        .exhaustive() - POS_BEND_RADIUS;
    const posPath = `
    m 0, ${center.positive.input}
    h ${PADDING}
    h ${posInputLineLength}
    a ${POS_BEND_RADIUS} ${POS_BEND_RADIUS} 0 0 1 ${POS_BEND_RADIUS} ${
      posDirection === 'left' ? -POS_BEND_RADIUS : POS_BEND_RADIUS
    }
    v ${
      posDirection === 'left'
        ? 0 - center.positive.input
        : inputCursor - center.positive.input // TODO(2288): Does this dangle something?
    }
    `;

    paths.push({
      materialClassId: center.materialClassId,
      positive: {
        path: posPath,
        thickness: center.positive.thickness,
        correct: center.positive.correct,
      },
      negative: {
        path: negPath,
        thickness: center.negative.thickness,
        correct: center.negative.correct,
      },
    });
  }

  // The total height is the total thickness of all inputs plus the bend radius.
  // Positive paths (and terminal negative paths) end exactly here: the bend
  // moves down POS_BEND_RADIUS, and the vertical run covers the rest of the
  // input stack.
  const height = inputCursor + POS_BEND_RADIUS;

  const width =
    posOutputBlockEnd +
    negArcSize +
    2 * PADDING +
    (negTerminal ? negOutputCursor + POS_BEND_RADIUS : 0);

  return {
    recoveryNode: recoveryNode,

    inputThickness: inputCursor,

    width,
    height,
    // x is the left edge of the positive output row. y is where the positive
    // paths end, so the terminal block sits flush against them. (This equals
    // the old `inputCursor + PADDING + 5` only because PADDING + 5 ===
    // POS_BEND_RADIUS.)
    posBlockPoint: [PADDING + POS_BEND_RADIUS, height],
    posThickness: posOutputCursor,
    negTerminal,

    negThickness: negOutputCursor,
    // For non-terminal blocks the negative bands leave the right edge
    // occupying y in [0, negOutputCursor]; no terminal is drawn then.
    negBlockPoint: [width - negOutputCursor, negTerminal ? height : negOutputCursor],
    paths,
  };
}

/**
 * Props for the `SepBlock` component: the computed block layout plus any
 * attributes (e.g. `transform`) forwarded to its outer `<g>`.
 */
type SepBlockProps = {
  /** Layout produced by `toSepBlock`. */
  sepBlock: SepBlock;
  /**
   * Extra vertical length appended to output paths. Terminal blocks are
   * shifted down by the same amount so that terminals line up across blocks.
   */
  extend: number;
} & React.SVGProps<SVGGElement>;

/**
 * Renders one separation block produced by `toSepBlock`: its material paths,
 * its terminal block(s), and the clickable recovery goal label.
 *
 * Draw order determines SVG stacking. Unselected negative paths are drawn
 * first, then unselected positive paths, then both halves of the selected
 * material class's path (so it sits on top), then the terminals and the label.
 *
 * `extend` lengthens every positive path and, in terminal blocks, every
 * negative path. It also shifts the terminal blocks down by the same amount.
 * Remaining props (e.g. `transform`) are forwarded to the outer `<g>`.
 */
export const SepBlock = forwardRef<SVGGElement, SepBlockProps>(
  function SepBlock(props: SepBlockProps, ref) {
    const { sepBlock, extend, ...rest } = props;

    const colors = useCategoricalColors();
    // Wrap around like the `incorrect-*` pattern defs in the diagram. This
    // keeps a class's solid and hatched strokes in the same color even when
    // there are more classes than palette entries.
    const colorAt = (i: number) => colors[i % colors.length];

    const {
      selectedRecoveryGoalNode,
      setSelectedRecoveryGoaNode,
      selectedMaterialClassId,
    } = useRecoveryStrategySimulationCtx();

    const { recoveryNode, negTerminal } = sepBlock;
    const rgSelected =
      recoveryNode.recoveryGoal.id === selectedRecoveryGoalNode?.recoveryGoal.id;

    const selectedPathIdx = sepBlock.paths.findIndex(
      (p) => p.materialClassId === selectedMaterialClassId,
    );
    const selectedPath =
      selectedPathIdx === -1 ? undefined : sepBlock.paths[selectedPathIdx];

    // Negative paths only extend down when they end in this block's terminal.
    const negExtend = negTerminal ? extend : undefined;

    return (
      <g ref={ref} {...rest}>
        {/* Negative paths first */}
        <g
          style={{
            filter: 'url(#coloredShadow)',
          }}
        >
          {sepBlock.paths.map((path, i) =>
            i === selectedPathIdx ? null : (
              <MaterialPath
                key={`${path.materialClassId}-neg`}
                path={path}
                materialNode={recoveryNode.negativeMaterialNode}
                color={colorAt(i)}
                direction='negative'
                extend={negExtend}
              />
            ),
          )}
        </g>

        {/* Then positive paths */}
        <g
          style={{
            filter: 'drop-shadow(0px 0px 8px rgba(0,0,0,0.45))',
          }}
        >
          {sepBlock.paths.map((path, i) =>
            i === selectedPathIdx ? null : (
              <MaterialPath
                key={`${path.materialClassId}-pos`}
                path={path}
                materialNode={recoveryNode.positiveMaterialNode}
                color={colorAt(i)}
                direction='positive'
                extend={extend}
              />
            ),
          )}
        </g>

        {/* The selected class is drawn last so it renders above the others */}
        {selectedPath !== undefined && (
          <g className={classes.selectedMaterialPathGroup}>
            <MaterialPath
              key={`${selectedPath.materialClassId}-pos`}
              path={selectedPath}
              materialNode={recoveryNode.positiveMaterialNode}
              direction='positive'
              color={colorAt(selectedPathIdx)}
              extend={extend}
            />
            <MaterialPath
              key={`${selectedPath.materialClassId}-neg`}
              path={selectedPath}
              materialNode={recoveryNode.negativeMaterialNode}
              direction='negative'
              color={colorAt(selectedPathIdx)}
              extend={negExtend}
            />
          </g>
        )}

        <TerminalBlock
          x={sepBlock.posBlockPoint[0]}
          y={sepBlock.posBlockPoint[1] + extend}
          materialNode={recoveryNode.positiveMaterialNode}
        />

        {negTerminal ? (
          <TerminalBlock
            x={sepBlock.negBlockPoint[0]}
            y={sepBlock.negBlockPoint[1] + extend}
            materialNode={recoveryNode.negativeMaterialNode}
          />
        ) : null}

        <g onClick={() => setSelectedRecoveryGoaNode(recoveryNode)}>
          <rect
            className={classes.rgBox}
            data-selected={rgSelected}
            rx={3}
            y={-10}
            height={sepBlock.inputThickness + 20}
            width={'1.6em'}
            x={'-.8em'}
          />
          <text
            style={{ userSelect: 'none' }}
            transform={`translate(0, ${sepBlock.height / 2}) rotate(90)`}
            dominantBaseline='middle'
            textAnchor='middle'
            fontWeight={500}
          >
            {recoveryNode.recoveryGoal.name}
          </text>
        </g>
      </g>
    );
  },
);

/**
 * Clickable bar summarizing a material node at the end of a path bundle.
 *
 * The bar's width is the node's total probability scaled by SCALE_FACTOR, and
 * clicking it selects the node.
 * - The first label shows the absolute amount (`feedTotal * probability`) when
 *   `feedTotal` is truthy, and the probability as a percentage otherwise.
 * - The second label shows the node's commodity name.
 */
function TerminalBlock(props: {
  x: number;
  y: number;
  materialNode: CommoditySpreadMaterialNodeDTO;
}) {
  const { selectedMaterialNode, setSelectedMaterialNode, feedTotal } =
    useRecoveryStrategySimulationCtx();
  const { x, y, materialNode } = props;

  // Sum the class probabilities in insertion order, starting from 0.
  let totalProb = 0;
  for (const prob of Object.values(materialNode.materialClassProbabilities)) {
    totalProb += prob;
  }

  const barWidth = totalProb * SCALE_FACTOR;
  const centerX = x + barWidth / 2;
  const isSelected = selectedMaterialNode?.index === materialNode.index;

  let amountLabel: string;
  if (feedTotal) {
    amountLabel = (feedTotal * totalProb).toFixed(1);
  } else {
    amountLabel = `${(totalProb * 100).toFixed(1)}%`;
  }
  const commodityLabel =
    materialNode.commodity?.name ?? 'Unspecified Commodity';

  const handleClick = () => {
    setSelectedMaterialNode(materialNode);
  };

  return (
    <g>
      <rect
        className={classes.terminalBlock}
        data-selected={isSelected}
        x={x}
        y={y}
        width={barWidth}
        height={'1.5em'}
        fill={'black'}
        onClick={handleClick}
      />

      <text
        x={centerX}
        y={y + 12}
        dominantBaseline='middle'
        textAnchor='middle'
        fontWeight={500}
        style={{ userSelect: 'none', pointerEvents: 'none' }}
      >
        {amountLabel}
      </text>
      <text
        x={centerX}
        y={y + 40}
        dominantBaseline='middle'
        textAnchor='middle'
        fontWeight={500}
        style={{ userSelect: 'none', pointerEvents: 'none' }}
      >
        {commodityLabel}
      </text>
    </g>
  );
}

/**
 * One stroke for one side (positive or negative) of a material class's path.
 *
 * Clicking the stroke selects both the material class and the given material
 * node. A side marked incorrect is stroked with the class's
 * `incorrect-<id>` pattern instead of its solid color. When `extend` is
 * truthy, a vertical segment of that length is appended to the path data.
 */
function MaterialPath(props: {
  path: SepBlockPath;
  materialNode: CommoditySpreadMaterialNodeDTO;
  color: string;
  direction: 'positive' | 'negative';
  extend: number | undefined;
}) {
  const {
    selectedMaterialClassId: activeClassId,
    setSelectedMaterialClassId: selectClass,
    setSelectedMaterialNode: selectNode,
  } = useRecoveryStrategySimulationCtx();
  const { path, materialNode, color, direction, extend } = props;

  const { materialClassId } = path;
  const side = path[direction];
  const strokePaint = side.correct
    ? color
    : `url(#incorrect-${materialClassId})`;
  const tail = extend ? `v ${extend}` : '';

  return (
    <path
      className={classes.materialPath}
      data-selected={materialClassId === activeClassId}
      stroke={strokePaint}
      strokeWidth={side.thickness}
      d={side.path + tail}
      onClick={() => {
        selectClass(materialClassId);
        selectNode(materialNode);
      }}
    />
  );
}
