import React, { useEffect, useMemo, useRef } from 'react';
import {
  AxisScale,
  interpolateRgb,
  scaleBand,
  ScaleOrdinal,
  select,
  transition,
} from 'd3';
import { ChartDimensions, DEFAULT_ANIMATION_TIME } from './Util';
import { NoDataChart } from './NoDataChart';

/**
 * One heat-map cell as bound to its `<rect>`: the numeric `value` that drives
 * the cell's fill colour, plus the column key (`x`) and row key (`y`) that
 * place it on the grid.
 */
type HeatMapData = {
  value: number;
  x: number | string;
  y: number | string;
};

/**
 * Size and padding of a heat map. Currently a plain alias of the shared
 * `ChartDimensions` from './Util'; `HeatMap` reads `width`, `height` and
 * `padding.top/bottom/left/right` from it.
 */
export type HeatMapChartDimensions = ChartDimensions;

/**
 * Render-callback signature: receives the heat-map cells plus, optionally,
 * the chart dimensions and the x/y/colour scales, and returns React content.
 *
 * NOTE(review): `HeatMap` in this file never calls a function of this type;
 * presumably it is consumed elsewhere — confirm. `colorScale` is typed as an
 * ordinal scale, whereas `HeatMap` colours cells with an RGB interpolator.
 * `React.ReactFragment` is deprecated in newer `@types/react`; consider
 * `React.ReactNode` after checking the installed version.
 */
export type CustomDrawingFunction = (
  data: HeatMapData[],
  chartDimensions?: HeatMapChartDimensions,
  scaleX?: AxisScale<string> | AxisScale<number>,
  scaleY?: AxisScale<string> | AxisScale<number>,
  colorScale?: ScaleOrdinal<string, unknown> | ScaleOrdinal<number, unknown>
) => React.ReactFragment;

/** Props accepted by the `HeatMap` component. */
interface IHeatMapChart {
  /** Grid definition: row/column keys plus the cell values. */
  data: {
    /** Column keys, used in order as the domain of the x band scale. */
    xKeys: string[];
    /** Row keys, used in order as the domain of the y band scale. */
    yKeys: string[];
    /** Value used for any cell that has no entry in `data`. */
    defaultValue: number;
    /** Sparse cell values, looked up as `data[y][x]`. */
    data?: Record<string, Record<string, number>>;
  };
  /** Overall size and padding; `HeatMap` supplies a 520×272 default. */
  chartDimensions?: HeatMapChartDimensions;
}

/** Shared empty fallback so memo dependencies stay referentially stable. */
const EMPTY_KEYS: Array<string> = [];

/**
 * Heat map of `data.yKeys` (rows) × `data.xKeys` (columns), drawn
 * imperatively with d3 into an `<svg>` rendered by React.
 *
 * Each cell's value is `data.data[y][x]`, falling back to `data.defaultValue`.
 * Cells are filled by interpolating '#FDFDFD' → '#9373FF' at
 * `value / maxValue`, where `maxValue` is the largest cell value but never
 * less than 1. A vertical gradient legend labelled "High"/"Low" is drawn at
 * the right edge. Renders `NoDataChart` when either key list is empty.
 *
 * @param data - row/column keys, fallback value and sparse value lookup.
 * @param chartDimensions - overall size and padding; defaults to 520×272.
 */
export const HeatMap: React.FC<IHeatMapChart> = ({
  data,
  chartDimensions = {
    height: 272,
    width: 520,
    padding: {
      top: 20,
      bottom: 0,
      right: 0,
      left: 80,
    },
  },
}) => {
  const { width, height, padding } = chartDimensions;
  const svgRef = useRef<SVGSVGElement>(null);

  // Missing data/keys at runtime fall back to an empty list (→ NoDataChart)
  // instead of crashing on `.flatMap` / `.length` below.
  const xUniqueValues = data?.xKeys ?? EMPTY_KEYS;
  const yUniqueValues = data?.yKeys ?? EMPTY_KEYS;

  // NOTE(review): the far end of each range subtracts padding.left/top a
  // second time. The x-axis labels are placed just past the end of the y
  // range, so this looks deliberate (it leaves room for labels/legend) —
  // confirm before changing.
  const xScale = useMemo(
    () =>
      scaleBand()
        .range([padding.left, width - padding.left - padding.right])
        .domain(xUniqueValues),
    [padding.left, padding.right, width, xUniqueValues]
  );

  const yScale = useMemo(
    () =>
      scaleBand()
        .range([padding.top, height - padding.top - padding.bottom])
        .domain(yUniqueValues),
    [height, padding.bottom, padding.top, yUniqueValues]
  );

  const colorScale = useMemo(() => interpolateRgb('#FDFDFD', '#9373FF'), []);

  // One entry per (x, y) pair, x-major. Shared by maxValue and the d3 join
  // so the value lookup lives in exactly one place.
  const cells = useMemo<Array<HeatMapData>>(
    () =>
      xUniqueValues.flatMap((x) =>
        yUniqueValues.map((y) => ({
          x,
          y,
          value: data?.data?.[y]?.[x] ?? data?.defaultValue,
        }))
      ),
    [data?.data, data?.defaultValue, xUniqueValues, yUniqueValues]
  );

  // Colour normalisation divisor. The floor of 1 avoids dividing by zero;
  // it also means values below 1 never reach the top colour.
  const maxValue = useMemo(
    () => cells.reduce((max, cell) => (max < cell.value ? cell.value : max), 1),
    [cells]
  );

  useEffect(() => {
    const svgElement = svgRef.current;
    // The <svg> is not mounted while NoDataChart is shown.
    if (!svgElement) return;

    const svg = select(svgElement);
    // Lower edge of the plotting area (same as the end of yScale's range).
    const plotBottom = height - padding.top - padding.bottom;
    const labelColor = 'rgba(107, 114, 128, 1)';

    // Cells: static styling is set once on enter, but geometry is applied to
    // entering AND updating rects. The join is by index, so without this,
    // key or dimension changes would leave existing rects at stale positions.
    const rects = svg
      .selectAll<SVGRectElement, HeatMapData>('.heatmap-rect')
      .data(cells)
      .join((enter) =>
        enter
          .append('rect')
          .attr('class', 'heatmap-rect')
          .attr('fill', 'rgba(255, 255, 255, 0)')
          .attr('stroke', 'rgba(100, 116, 139, 0.1)')
          .attr('stroke-width', 1)
      )
      .attr('x', (d) => xScale(d.x as string) ?? 0)
      .attr('y', (d) => yScale(d.y as string) ?? 0)
      .attr('width', xScale.bandwidth())
      .attr('height', yScale.bandwidth());

    rects
      .transition(transition().duration(DEFAULT_ANIMATION_TIME))
      .attr('fill', (d) => colorScale(d.value / maxValue));

    // X labels, centred under each column.
    svg
      .selectAll('.x-label')
      .data(xUniqueValues)
      .join('text')
      .attr('class', 'x-label')
      .attr('x', (d) => (xScale(d) ?? 0) + xScale.bandwidth() / 2)
      .attr('y', plotBottom + 4)
      .attr('text-anchor', 'middle')
      .attr('dominant-baseline', 'hanging')
      .attr('font-size', 12)
      .attr('fill', labelColor)
      .text((d) => d);

    // Y labels, right-aligned just left of the plot.
    svg
      .selectAll('.y-label')
      .data(yUniqueValues)
      .join('text')
      .attr('class', 'y-label')
      .attr('x', padding.left - 8)
      .attr('y', (d) => (yScale(d) ?? 0) + yScale.bandwidth() / 2)
      .attr('text-anchor', 'end')
      .attr('dominant-baseline', 'middle')
      .attr('font-size', 12)
      .attr('fill', labelColor)
      .text((d) => d);

    // Legend: one <g> holding the gradient bar and High/Low labels. It is
    // joined on every run (not appended once) so it follows width/height/
    // padding changes.
    const legend = svg
      .selectAll('.legend')
      .data([null])
      .join('g')
      .attr('class', 'legend');

    legend
      .selectAll('.legend-bar')
      .data([null])
      .join('rect')
      .attr('class', 'legend-bar')
      .attr('x', width - 20)
      .attr('y', padding.top + 8)
      .attr('width', 16)
      .attr('height', plotBottom - 32)
      .attr('fill', 'url(#heatmapGradient)');

    legend
      .selectAll('.legend-label')
      .data([
        { label: 'High', y: padding.top + 16, baseline: 'hanging' },
        { label: 'Low', y: plotBottom - 12, baseline: 'baseline' },
      ])
      .join('text')
      .attr('class', 'legend-label')
      .attr('x', width - 24)
      .attr('y', (d) => d.y)
      .attr('text-anchor', 'end')
      .attr('dominant-baseline', (d) => d.baseline)
      .attr('font-size', 12)
      .attr('fill', labelColor)
      .text((d) => d.label);
  }, [
    cells,
    xUniqueValues,
    yUniqueValues,
    xScale,
    yScale,
    colorScale,
    maxValue,
    height,
    width,
    padding,
  ]);

  // NOTE(review): the gradient id is document-global, so several HeatMaps on
  // one page define (identical) duplicate ids — consider a per-instance id.
  return xUniqueValues.length === 0 || yUniqueValues.length === 0 ? (
    <NoDataChart chartDimensions={chartDimensions} />
  ) : (
    <svg ref={svgRef} viewBox={`0 0 ${width} ${height}`}>
      <defs>
        <linearGradient id="heatmapGradient" x1="0%" x2="0%" y2="0%" y1="100%">
          <stop offset="0%" stopColor="rgba(216, 204, 255, 1)" />
          <stop offset="50%" stopColor="rgba(170, 147, 255, 1)" />
          <stop offset="100%" stopColor="rgba(138, 103, 255, 1)" />
        </linearGradient>
      </defs>
    </svg>
  );
};
