import React, { useCallback, useMemo } from 'react';
import { AxisScale, interpolateRgb, scaleBand, ScaleOrdinal } from 'd3';
import { ChartDimensions } from './Util';
import { NoDataChart } from './NoDataChart';

/**
 * One heat-map cell, as carried in the list handed to a
 * `CustomDrawingFunction`.
 */
type HeatMapData = {
  /** Cell position along x — NOTE(review): presumably one of the x keys; confirm. */
  x: number | string;
  /** Cell position along y — NOTE(review): presumably one of the y keys; confirm. */
  y: number | string;
  /** Numeric value of the cell. */
  value: number;
};

/**
 * Chart size and padding accepted by `HeatMap` (it reads `width`, `height`
 * and `padding` from it). Currently just an alias of the shared
 * `ChartDimensions` type.
 */
export type HeatMapChartDimensions = ChartDimensions;

/**
 * Signature for a caller-supplied drawing routine. It receives the heat-map
 * cells and, optionally, the chart layout, the x/y axis scales and a colour
 * scale, and returns React content to render.
 *
 * NOTE(review): `HeatMap` does not reference this type; confirm where it is
 * consumed. Also check that `React.ReactFragment` still exists in the
 * installed `@types/react` version.
 */
export type CustomDrawingFunction = (
  data: HeatMapData[],
  chartDimensions?: HeatMapChartDimensions,
  scaleX?: AxisScale<number> | AxisScale<string>,
  scaleY?: AxisScale<number> | AxisScale<string>,
  colorScale?: ScaleOrdinal<number, unknown> | ScaleOrdinal<string, unknown>
) => React.ReactFragment;

/** Props accepted by the `HeatMap` component. */
interface IHeatMapChart {
  /** Optional layout override; `HeatMap` falls back to a 520×272 default. */
  chartDimensions?: HeatMapChartDimensions;
  /** Grid definition and the values of its cells. */
  data: {
    /** Keys forming the x band-scale domain, in band order. */
    xKeys: string[];
    /** Keys forming the y band-scale domain, in band order. */
    yKeys: string[];
    /** Value used for any (x, y) cell missing from `data`. */
    defaultValue: number;
    /** Cell values indexed as `data[yKey][xKey]`; may be sparse or absent. */
    data?: Record<string, Record<string, number>>;
  };
}

/** Maps a normalised value t to a cell fill: t = 0 is near-white, t = 1 is purple. */
const interpolateCellColor = interpolateRgb('#FDFDFD', '#9373FF');

/** Text colour shared by the axis labels and the legend captions. */
const LABEL_FILL = 'rgba(107, 114, 128, 1)';

/**
 * Renders `data` as an SVG heat map with these parts:
 * - one band-scaled rectangle per (x key, y key) pair, coloured by its value
 *   divided by the largest value (floored at 1);
 * - x labels under the grid;
 * - y labels to its left;
 * - a Low/High legend on the right.
 *
 * Renders `NoDataChart` instead when either key list is empty.
 */
export const HeatMap: React.FC<IHeatMapChart> = ({
  data,
  chartDimensions = {
    height: 272,
    width: 520,
    padding: {
      top: 20,
      bottom: 0,
      right: 0,
      left: 80,
    },
  },
}) => {
  const { width, height, padding } = chartDimensions;

  // `data` is typed as required, but untyped callers may still omit it. Fall
  // back to empty key lists so the NoDataChart branch renders instead of
  // throwing on `.length`. The fallbacks are memoised so their identity stays
  // stable for the hooks below.
  const xKeys = data?.xKeys;
  const yKeys = data?.yKeys;
  const cellValues = data?.data;
  const defaultValue = data?.defaultValue;

  const xUniqueValues = useMemo(() => xKeys ?? [], [xKeys]);
  const yUniqueValues = useMemo(() => yKeys ?? [], [yKeys]);

  // Value of cell (x, y). `data.data` is keyed by y first, then x; cells
  // missing from it take `defaultValue`.
  const valueAt = useCallback(
    (x: string, y: string) => cellValues?.[y]?.[x] ?? defaultValue,
    [cellValues, defaultValue]
  );

  // NOTE: the x range ends at `width - padding.left - padding.right`, which
  // leaves a strip as wide as `padding.left` free on the right. This is
  // presumably for the legend drawn at `width - 20`. Likewise, the y range
  // stops `padding.top` short of the bottom padding, which is where the
  // x labels are drawn.
  const xScale = useMemo(
    () =>
      scaleBand()
        .range([padding.left, width - padding.left - padding.right])
        .domain(xUniqueValues),
    [padding.left, padding.right, width, xUniqueValues]
  );

  const yScale = useMemo(
    () =>
      scaleBand()
        .range([padding.top, height - padding.top - padding.bottom])
        .domain(yUniqueValues),
    [height, padding.bottom, padding.top, yUniqueValues]
  );

  // Largest cell value, floored at 1. Fills are `value / maxValue`, so the
  // floor rules out division by zero. It also means a grid whose values are
  // all below 1 is not stretched to the full colour range.
  const maxValue = useMemo(() => {
    let max = 1;
    for (const x of xUniqueValues) {
      for (const y of yUniqueValues) {
        const v = valueAt(x, y);
        if (max < v) {
          max = v;
        }
      }
    }
    return max;
  }, [valueAt, xUniqueValues, yUniqueValues]);

  const rects = useMemo(
    () =>
      xUniqueValues.flatMap((x) =>
        yUniqueValues.map((y) => {
          const v = valueAt(x, y);
          // Key on the cell coordinates only, so a value change updates the
          // rect in place rather than remounting it. JSON-encoding the pair
          // keeps keys distinct even when x/y contain '-'.
          return (
            <rect
              key={`d-${JSON.stringify([x, y])}`}
              x={xScale(x)}
              y={yScale(y)}
              width={xScale.bandwidth()}
              height={yScale.bandwidth()}
              fill={interpolateCellColor(v / maxValue)}
              stroke="rgba(100, 116, 139, 0.1)"
              strokeWidth={1}
            />
          );
        })
      ),
    [maxValue, valueAt, xScale, xUniqueValues, yScale, yUniqueValues]
  );

  // One label per x key, centred under its band, hanging just below the grid.
  const xLabels = useMemo(
    () =>
      xUniqueValues.map((label) => {
        const x = xScale(label) ?? 0;
        return (
          <text
            key={label}
            x={x + xScale.bandwidth() / 2}
            y={height - padding.top - padding.bottom + 4}
            textAnchor="middle"
            dominantBaseline="hanging"
            fontSize={12}
            fill={LABEL_FILL}
          >
            {label}
          </text>
        );
      }),
    [height, padding.bottom, padding.top, xScale, xUniqueValues]
  );

  // One label per y key, right-aligned 8px left of the grid and vertically
  // centred on its band.
  const yLabels = useMemo(
    () =>
      yUniqueValues.map((label) => {
        const y = yScale(label) ?? 0;
        return (
          <text
            key={label}
            x={padding.left - 8}
            y={y + yScale.bandwidth() / 2}
            textAnchor="end"
            dominantBaseline="middle"
            fontSize={12}
            fill={LABEL_FILL}
          >
            {label}
          </text>
        );
      }),
    [padding.left, yScale, yUniqueValues]
  );

  // Vertical gradient bar at the right edge, with "High" at the top and
  // "Low" at the bottom.
  const legend = useMemo(
    () => (
      <g>
        <rect
          x={width - 20}
          y={padding.top + 8}
          width={16}
          height={height - padding.top - padding.bottom - 32}
          fill="url(#heatmapGradient)"
        />
        <text
          x={width - 24}
          y={padding.top + 16}
          textAnchor="end"
          dominantBaseline="hanging"
          fontSize={12}
          fill={LABEL_FILL}
        >
          High
        </text>
        <text
          x={width - 24}
          y={height - padding.top - padding.bottom - 12}
          textAnchor="end"
          dominantBaseline="baseline"
          fontSize={12}
          fill={LABEL_FILL}
        >
          Low
        </text>
      </g>
    ),
    [height, padding.bottom, padding.top, width]
  );

  if (xUniqueValues.length === 0 || yUniqueValues.length === 0) {
    return <NoDataChart chartDimensions={chartDimensions} />;
  }

  return (
    <svg viewBox={`0 0 ${width} ${height}`}>
      <defs>
        {/* Gradient runs bottom (offset 0, light) to top (offset 100%, dark).
            NOTE(review): every HeatMap instance on a page shares this fixed
            id, and the stops differ from interpolateCellColor's endpoints, so
            the legend does not exactly match the cell fills. Confirm both
            are intended. */}
        <linearGradient id="heatmapGradient" x1="0%" x2="0%" y2="0%" y1="100%">
          <stop offset="0%" stopColor="rgba(216, 204, 255, 1)" />
          <stop offset="50%" stopColor="rgba(170, 147, 255, 1)" />
          <stop offset="100%" stopColor="rgba(138, 103, 255, 1)" />
        </linearGradient>
      </defs>
      {rects}
      {xLabels}
      {yLabels}
      {legend}
    </svg>
  );
};
