import React, { useMemo, useRef } from 'react'

import { getBlueGradientColors, getRainbowColors } from 'src/utils/colors'

import { path } from 'd3'

/** A point in SVG user units; `y` grows downwards, as in the SVG coordinate system. */
type Coordinate = Record<'x' | 'y', number>

/** One cell of the tree, as supplied by the caller of `TreeSVG`. */
export type Node = {
  /**
   * Selection state. When every node of a column has this set (to `true` or
   * `false`), the column receives connectors from the previous column's
   * selected node; if any node of the column leaves it `undefined`, no
   * connectors are drawn into that column.
   */
  selected?: boolean
  /** When `true`, no connector is drawn to this node (its content is still rendered). */
  invisible?: boolean
  /** Rendered inside the node's `foreignObject` box. */
  content: React.ReactNode
}

/** Props accepted by the `TreeSVG` component. */
export type TreeSVGProps = {
  /**
   * Extra attributes spread onto the root `<svg>` element. A `ref` given here
   * is overridden by the component's own ref.
   */
  props?: React.SVGProps<SVGSVGElement>
  /** Take the highlight colour from the rainbow palette instead of the blue gradient. */
  colorful?: boolean
  /** Nodes grouped by column: the outer index is the column, the inner index the row. */
  nodes: Node[][]
}

/**
 * Builds the SVG `d` string for an elbow connector from `source` to `target`.
 *
 * Points on the same horizontal line are joined by a straight segment.
 * Otherwise the connector is: a quarter circle that turns from horizontal to
 * vertical and ends at the horizontal midpoint, a vertical segment, and a
 * second quarter circle that turns back to horizontal into `target`. The
 * corner radius is half the horizontal distance between the two points.
 *
 * NOTE(review): the vertical segment assumes |target.y - source.y| is at
 * least twice the corner radius; with a smaller gap it runs backwards and the
 * connector overlaps itself.
 *
 * @param source - where the connector starts
 * @param target - where the connector ends
 * @returns the path data produced by d3's `path` builder
 */
const _drawPath = (source: Coordinate, target: Coordinate) => {
  const connector = path()
  connector.moveTo(source.x, source.y)

  if (target.y === source.y) {
    connector.lineTo(target.x, target.y)
    return connector.toString()
  }

  const elbowX = (source.x + target.x) / 2
  const radius = elbowX - source.x
  const cornerRadius = Math.abs(radius)
  const goingDown = target.y > source.y
  // Signed vertical distance from an endpoint to the centre of its corner arc.
  const offsetY = goingDown ? radius : -radius

  if (goingDown) {
    // Turn from heading right to heading down, then from down back to right.
    connector.arc(source.x, source.y + offsetY, cornerRadius, -Math.PI / 2, 0, false)
    connector.lineTo(elbowX, target.y - offsetY)
    connector.arc(target.x, target.y - offsetY, cornerRadius, Math.PI, Math.PI / 2, true)
  } else {
    // Turn from heading right to heading up, then from up back to right.
    connector.arc(source.x, source.y + offsetY, cornerRadius, Math.PI / 2, 0, true)
    connector.lineTo(elbowX, target.y - offsetY)
    connector.arc(target.x, target.y - offsetY, cornerRadius, -Math.PI, -Math.PI / 2, false)
  }

  return connector.toString()
}

// Layout constants, in SVG user units, shared by node placement and connector anchoring.
const COLUMN_SPACING = 256
const ROW_SPACING = 100
const FIRST_COLUMN_Y = 200
const NODE_WIDTH = 160
const NODE_HEIGHT = 96
// Vertical distance from a node's top edge to the point where connectors attach.
const CONNECTOR_OFFSET_Y = 12

/**
 * Renders `nodes` as left-to-right columns of `foreignObject` boxes, joined by
 * elbow connectors.
 *
 * Layout: column `i` sits at x = i * 256. Every node of the first column is
 * placed at y = 200; in the other columns a node's y is its row index * 100.
 *
 * Connectors: a column receives connectors only when the previous column has
 * a selected node and no node of the column has `selected === undefined`.
 * They run from the right edge of that selected node to the left edge of
 * each non-invisible node in the column.
 *
 * Colouring: a connector gets the last colour of the palette when it leads to
 * a selected node, or when it leads into the column right after the last
 * column that contains a selected node. All other connectors are light grey.
 * The palette is `getRainbowColors(8)` when `colorful`, else
 * `getBlueGradientColors(8)`. Connectors that lead to a selected node are
 * rendered last, so they paint over the others.
 */
const TreeSVG: React.FC<TreeSVGProps> = ({ nodes, props, colorful }) => {
  // NOTE(review): this ref is attached to the <svg> but never read in this component.
  const ref = useRef<SVGSVGElement>(null)

  const colors = useMemo(() => {
    const getColors = colorful ? getRainbowColors : getBlueGradientColors
    return getColors(8)
  }, [colorful])

  // Attach a layout position to every node.
  const _nodes = useMemo(
    () =>
      nodes.map((columnNodes, columnIndex) =>
        columnNodes.map((node, rowIndex) => ({
          ...node,
          position: {
            x: columnIndex * COLUMN_SPACING,
            y: columnIndex === 0 ? FIRST_COLUMN_Y : rowIndex * ROW_SPACING,
          },
        }))
      ),
    [nodes]
  )

  const _paths = useMemo(() => {
    // Index of the LAST column containing a selected node (-1 if none). The
    // connectors into the column after it are highlighted.
    const lastColumnWithSelected = _nodes.findLastIndex((columnNodes) => columnNodes.some((node) => node.selected))

    const connectors = _nodes.flatMap((columnNodes, columnIndex) => {
      const parent = _nodes[columnIndex - 1]?.find((node) => node.selected)
      // A column where any node leaves `selected` undefined gets no incoming connectors.
      const hasUndecidedNode = columnNodes.some((node) => node.selected === undefined)
      if (!parent || hasUndecidedNode) return []

      // Connectors leave from the parent's right edge...
      const start = { x: parent.position.x + NODE_WIDTH, y: parent.position.y + CONNECTOR_OFFSET_Y }

      return columnNodes.flatMap((node, rowIndex) => {
        if (node.invisible) return []
        // ...and enter each child at its left edge.
        const end = { x: node.position.x, y: node.position.y + CONNECTOR_OFFSET_Y }
        return [
          {
            // Each node has at most one incoming connector, so column-row is a unique, stable key.
            key: `${columnIndex}-${rowIndex}`,
            selected: node.selected,
            columnIndex,
            d: _drawPath(start, end),
          },
        ]
      })
    })

    // The comparator must depend on both arguments to be consistent. Array#sort
    // is stable, so connectors to selected nodes move to the end (painted on
    // top) and the rest keep their original order.
    connectors.sort((a, b) => Number(Boolean(a.selected)) - Number(Boolean(b.selected)))

    const highlightColor = colors[colors.length - 1]
    return connectors.map((connector) => (
      <path
        key={connector.key}
        d={connector.d}
        stroke={connector.selected || connector.columnIndex === lastColumnWithSelected + 1 ? highlightColor : 'lightgray'}
        fill="transparent"
      />
    ))
  }, [_nodes, colors])

  return (
    <svg {...props} ref={ref}>
      {_paths}
      {/* NOTE(review): invisible nodes are skipped for connectors but still rendered here — confirm intent. */}
      {_nodes.flatMap((column, columnIndex) =>
        column.map((node, rowIndex) => (
          <foreignObject
            key={`${columnIndex}-${rowIndex}`}
            x={node.position.x}
            y={node.position.y}
            width={NODE_WIDTH}
            height={NODE_HEIGHT}
          >
            {node.content}
          </foreignObject>
        ))
      )}
    </svg>
  )
}

// The component is the module's default export.
export { TreeSVG as default }
