import React from "react";
import { useRef, useEffect } from "react";
import * as d3 from "d3";
import * as d3Sankey from "d3-sankey";
import { useDarkModeActive } from "theme";

/**
 * Renders a Sankey diagram into the <svg> referenced by `svgRef`, replacing any
 * previous contents of that element.
 *
 * Link records are read through the `linkSource` / `linkTarget` / `linkValue`
 * accessors. Endpoints are expected to be node objects with an `id`. When
 * `nodes` is not supplied, the node list is derived from the link endpoints.
 * The layout itself is computed by `d3Sankey.sankey()`.
 *
 * Node objects may also carry:
 *   - `label`, `description`, `percent`: shown in labels / the hover tooltip;
 *   - `position`: used as the label's `text-anchor` when set;
 *   - `linkColor`: "source" | "target", overriding the colour of links that
 *     enter this node (only in the "source-target" link colour mode);
 *   - `isOthersLabel`: copied onto the layout node for use by `nodeLabel`.
 *
 * `nodeLabel`, when a function, is called twice per node. The first call
 * happens before the label has been measured. After measuring, the text is
 * recomputed with `d.bbox` set to the rendered label's bounding box. This lets
 * the label callback shorten labels that overflow.
 *
 * Note: `format` is converted to a formatter but is not used by the rendering
 * below, and `nodeLabelPadding` is accepted but unused.
 *
 * @param {{links: Array<object>, svgRef: {current: Element}}} data - link
 *   records and a React-style ref to the target <svg> element.
 * @param {object} [options] - layout, styling and accessor options; every key
 *   is optional and falls back to the default in the parameter list.
 * @returns {void} Draws into the svg as a side effect and attaches
 *   `{ scales: { color } }` to the svg DOM node.
 */
function SankeyChart(
  { links, svgRef },
  {
    nodes,
    format = ",",
    align = "justify",
    nodeId = d => d,
    nodeGroup,
    nodeGroups,
    nodeLabel,
    nodeTitle = d => null,
    nodeAlign = align,
    nodeWidth = 25,
    nodePadding = 38,
    nodeLabelPadding = 6,
    nodeStroke = "currentColor",
    nodeStrokeWidth,
    nodeStrokeOpacity,
    nodeStrokeLinejoin,
    linkSource = ({ source }) => source,
    linkTarget = ({ target }) => target,
    linkValue = ({ value }) => value,
    linkPath = d3Sankey.sankeyLinkHorizontal(),
    linkTitle = d => null,
    linkColor = "source-target",
    linkStrokeOpacity = 0.5,
    linkMixBlendMode = "normal",
    colors = d3.schemeTableau10,
    width = 640,
    height = 400,
    marginTop = 50,
    marginRight = 300,
    marginBottom = 50,
    marginLeft = 300,
    labelBackgroundPadding = 5,
    labelBackgroundColor = "rgb(255, 255, 255, 0.8)",
    fontSize = 25
  } = {}
) {
  // Accept either an alignment function or one of the named d3-sankey aligners.
  if (typeof nodeAlign !== "function") {
    nodeAlign =
      {
        left: d3Sankey.sankeyLeft,
        right: d3Sankey.sankeyRight,
        center: d3Sankey.sankeyCenter
      }[nodeAlign] || d3Sankey.sankeyJustify;
  }

  const LS = d3.map(links, linkSource).map(intern);
  const LT = d3.map(links, linkTarget).map(intern);
  const LV = d3.map(links, linkValue);

  if (nodes === undefined) {
    // Derive nodes from link endpoints, keeping the first object seen for each
    // id (all sources first, then targets). A Map keeps this linear instead of
    // rescanning the growing node list for every endpoint.
    const nodesById = new Map();
    for (const endpoint of LS.concat(LT)) {
      if (!nodesById.has(endpoint.id)) nodesById.set(endpoint.id, endpoint);
    }
    nodes = Array.from(nodesById.values());
  }

  const N = d3.map(nodes, nodeId).map(intern);
  const G = nodeGroup == null ? null : d3.map(nodes, nodeGroup).map(intern);

  // Fresh layout objects: d3-sankey mutates these (index, x0, y0, layer, ...),
  // so the caller's node objects are left untouched.
  nodes = d3.map(nodes, (_, i) => ({
    id: N[i].id,
    label: N[i].label,
    position: N[i].position,
    description: N[i].description,
    percent: N[i].percent,
    linkColor: N[i].linkColor,
    isOthersLabel: N[i].isOthersLabel
  }));
  links = d3.map(links, (_, i) => ({
    source: LS[i].id,
    target: LT[i].id,
    value: LV[i]
  }));

  // Group-based link colouring is impossible without groups.
  if (!G && ["source", "target", "source-target"].includes(linkColor)) linkColor = "currentColor";

  if (G && nodeGroups === undefined) nodeGroups = G;

  const color = nodeGroup == null ? null : d3.scaleOrdinal(nodeGroups, colors);
  d3Sankey
    .sankey()
    .nodeId(({ index: i }) => N[i].id)
    .nodeAlign(nodeAlign)
    .nodeWidth(nodeWidth)
    .nodePadding(nodePadding)
    .extent([[marginLeft, marginTop], [width - marginRight, height - marginBottom]])({ nodes, links });

  if (typeof format !== "function") format = d3.format(format);
  // Without an explicit nodeLabel, fall back to each node's label (or id);
  // the interned node objects themselves would stringify as "[object Object]".
  const Tl =
    nodeLabel === undefined
      ? d3.map(nodes, d => (d.label != null ? d.label : d.id))
      : nodeLabel == null
      ? null
      : d3.map(nodes, nodeLabel);
  const Tt = nodeTitle == null ? null : d3.map(nodes, nodeTitle);
  const Lt = linkTitle == null ? null : d3.map(links, linkTitle);

  // Random prefix so gradient ids do not collide between charts on one page.
  const uid = `O-${Math.random()
    .toString(16)
    .slice(2)}`;

  const svg = d3.select(svgRef.current);

  svg.selectAll("*").remove();

  svg
    .attr("width", width)
    .attr("height", height)
    .attr("viewBox", [0, 0, width, height])
    .attr("style", "max-width: 100%; height: auto; height: intrinsic; overflow: visible;");

  // Index of the right-most column, i.e. the largest layer assigned by the layout.
  const totalLayers = d3.max(nodes, d => d.layer);

  const node = svg
    .append("g")
    .attr("stroke", nodeStroke)
    .attr("stroke-width", nodeStrokeWidth)
    .attr("stroke-opacity", nodeStrokeOpacity)
    .attr("stroke-linejoin", nodeStrokeLinejoin)
    .selectAll("rect")
    .data(nodes)
    .join("rect")
    .attr("x", d => d.x0)
    .attr("y", d => d.y0)
    .attr("height", d => d.y1 - d.y0)
    .attr("width", d => d.x1 - d.x0);

  if (G) node.attr("fill", ({ index: i }) => color(G[i]));
  if (Tt) node.append("title").text(({ index: i }) => Tt[i]);

  const link = svg
    .append("g")
    .attr("fill", "none")
    .attr("stroke-opacity", linkStrokeOpacity)
    .selectAll("g")
    .data(links)
    .join("g")
    .style("mix-blend-mode", linkMixBlendMode);

  if (linkColor === "source-target") {
    // Single-stop gradient in the source node's colour; paths below refer to it by id.
    link
      .append("linearGradient")
      .attr("id", d => `${uid}-link-${d.index}`)
      .attr("gradientUnits", "userSpaceOnUse")
      .attr("x1", d => d.source.x1)
      .attr("x2", d => d.target.x0)
      .call(gradient =>
        gradient
          .append("stop")
          .attr("offset", "100%")
          .attr("stop-color", ({ source: { index: i } }) => color(G[i]))
      );
  }

  link
    .append("path")
    .attr("d", linkPath)
    .attr(
      "stroke",
      linkColor === "source-target"
        ? d => {
            // A target node may force its incoming links to a solid colour.
            if (d.target.linkColor === "source") {
              return color(G[d.source.index]);
            } else if (d.target.linkColor === "target") {
              return color(G[d.target.index]);
            }
            return `url(#${uid}-link-${d.index})`;
          }
        : linkColor === "source"
        ? ({ source: { index: i } }) => color(G[i])
        : linkColor === "target"
        ? ({ target: { index: i } }) => color(G[i])
        : linkColor
    )
    .attr("stroke-width", ({ width }) => Math.max(1, width))
    .call(Lt ? path => path.append("title").text(({ index: i }) => Lt[i]) : () => {});

  if (Tl) {
    const labelParent = svg
      .append("g")
      .attr("font-family", "Inter")
      .attr("font-size", fontSize);

    // Background rects are appended before the text so they render beneath it;
    // their geometry is filled in once the labels have been measured.
    const labelBackground = labelParent
      .selectAll("rect")
      .data(nodes)
      .join("rect")
      .attr("fill", labelBackgroundColor)
      .style("opacity", 1);

    // The tooltip element is shared via its id and created only once.
    let chartHoverTipDiv = d3.select("#tooltip-donut");
    if (!chartHoverTipDiv.node()) {
      chartHoverTipDiv = d3
        .select("body")
        .append("div")
        .attr("id", "tooltip-donut")
        .style("opacity", 0)
        .style("display", "none");
    }

    const labels = labelParent
      .selectAll("text")
      .data(nodes)
      .join("text")
      .attr("x", d => {
        if (d.position === "middle") {
          return d.x0;
        }
        // Outer columns: place the label outside the node.
        // NOTE(review): compares depth against the maximum layer; confirm this
        // is intended for sink nodes drawn in the last column at a smaller depth.
        if (d.depth === 0) {
          return d.x1 - nodeWidth - 10;
        } else if (d.depth === totalLayers) {
          return d.x0 + nodeWidth + 10;
        }
        return d.x0;
      })
      .attr("y", d => (d.y1 + d.y0) / 2)
      .attr("dy", "0.35em")
      .attr("text-anchor", d => {
        if (d.position) {
          return d.position;
        }
        if (d.depth === 0) {
          return "end";
        } else if (d.depth === totalLayers) {
          return "start";
        }
        return "middle";
      })
      .attr("cursor", "default")
      .on("mouseover", function(event, d) {
        chartHoverTipDiv
          .transition()
          .duration(50)
          .style("opacity", 1)
          .style("display", "block");

        // NOTE(review): inserted as HTML; if labels/descriptions can come from
        // untrusted input, switch to .text() or sanitize.
        chartHoverTipDiv
          .html(`${d.description || d.label} ${d.percent ? `(${d.percent}%)` : ""}`)
          .style("z-index", 10000)
          .style("left", event.target.getBoundingClientRect().left + "px")
          .style("top", event.target.getBoundingClientRect().top + 25 + "px");
      })
      .on("mouseout", function() {
        chartHoverTipDiv
          .transition()
          .duration(50)
          .style("opacity", 0)
          .style("display", "none");
      })
      .text(({ index: i }) => Tl[i]);

    // Store each rendered label's bounding box on its datum.
    const measureLabels = () =>
      labels.each(function(d) {
        d.bbox = this.getBBox();
      });

    measureLabels();

    // Second pass: nodeLabel can now inspect d.bbox (e.g. to shorten labels
    // that overflow), after which the final text is measured again.
    if (typeof nodeLabel === "function") {
      labels.text(d => nodeLabel(d));
      measureLabels();
    }

    labelBackground
      .attr("x", d => d.bbox.x - labelBackgroundPadding)
      .attr("y", d => d.bbox.y - labelBackgroundPadding)
      .attr("rx", 5)
      .attr("ry", 5)
      .attr("height", d => d.bbox.height + 2 * labelBackgroundPadding)
      .attr("width", d => d.bbox.width + 2 * labelBackgroundPadding);
  }

  // Unwraps boxed primitives (e.g. new String("a")) so they compare by value.
  function intern(value) {
    return value !== null && typeof value === "object" ? value.valueOf() : value;
  }
  Object.assign(svg.node(), { scales: { color } });
}

function SankeyChartComponent({
  className,
  data,
  nodeAlign,
  formatValue,
  maxLayerDepth,
  nodePadding = 38,
  fontSize = 25,
  margin = { left: 400, right: 250, top: 50, bottom: 50 },
  respectDarkMode = true
}) {
  var svgRef = useRef(null);
  const isDarkModeActive = useDarkModeActive();

  useEffect(() => {
    SankeyChart(
      { links: data, svgRef },
      {
        nodePadding: nodePadding,
        height: nodePadding * maxLayerDepth + 500,
        width: 2000,
        nodeGroup: d => d.id.split(/\W/)[0],
        nodeLabel: d => {
          if (!d.bbox || !d.bbox.width || d.bbox.x >= 0 || !d.isOthersLabel === false) {
            return `${d.label}: ${formatValue(d.value)}`;
          }
          var nameComponents = d.label.split(" ");
          var nameLabel = nameComponents.slice(0, 2).join(" ");
          if (nameComponents.length > 2) {
            nameLabel += " ...";
          }
          return `${nameLabel}: ${formatValue(d.value)}`;
        },
        nodeAlign: nodeAlign,
        nodeStroke: "transparent",
        format: formatValue,
        fontSize: fontSize,
        marginLeft: margin.left,
        marginRight: margin.right,
        marginTop: margin.top,
        marginBottom: margin.bottom,
        labelBackgroundColor: respectDarkMode && isDarkModeActive ? "rgb(0, 0, 0, 0.6)" : "rgb(255, 255, 255, 0.8)"
      }
    );
  }, [data, nodeAlign, formatValue, maxLayerDepth, nodePadding, fontSize, margin, isDarkModeActive, respectDarkMode]);

  return (
    <div id="sankey-chart" className={className}>
      <svg ref={svgRef} />
    </div>
  );
}

export default SankeyChartComponent;
