import { D3Attr, DrawShape, IDrawEdge, IDrawNode, ITreeChartNode, TreeChart, drawEdgePath, drawNodeRect } from '@patsnap/synapse_tree_chart'
import { Selection } from 'd3'
// Module augmentation: declares on TreeChart the members that the highlight plugin below installs
// on the prototype / instance.
declare module '@patsnap/synapse_tree_chart' {
  interface TreeChart<Datum> {
    /** Highlights the node rendered with element id `${ID_PREFIX}${id}` together with its ancestors and descendants, then re-renders. */
    highlight: (id?: string) => void
    /** Node draw hook that applies the highlight styling. */
    highlightDrawNode: IDrawNode<Datum>
    /** Node update hook that applies the highlight styling. */
    highlightUpdateNode: IDrawNode<Datum>
    /** Edge draw hook that applies the highlight styling. */
    highlightDrawEdge: IDrawEdge<Datum>
    /** Edge update hook that applies the highlight styling. */
    highlightUpdateEdge: IDrawEdge<Datum>
    /** Data keys (from `dataKeyAccessor`) of the highlight target, its ancestors and its descendants. */
    _highlightIds: Set<string>
    /** Data key of the highlight target; `undefined` when no target is set. */
    _highlightId: string | undefined
  }
}

/** Options accepted by the highlight plugin's `install` function. */
export interface IHighlightOptions {
  /**
   * Attribute overrides for highlighted shapes. This plugin reads `stroke` (applied to highlighted
   * nodes and edges) and `fill` (applied to the highlight target's rect); other keys are ignored here.
   */
  attr?: Record<string, D3Attr>
  /** Data key (as produced by `dataKeyAccessor`) of the node to highlight when the hierarchy is ready. */
  defaultHighlightId?: string
  /** Toggles for automatically registering the highlight hooks on the chart instance; each defaults to `true`. */
  autoRegister?: {
    drawNode?: boolean
    updateNode?: boolean
    drawEdge?: boolean
    updateEdge?: boolean
  }
}
/**
 * Installs node/edge highlighting on a TreeChart.
 *
 * Adds `highlight(id?)` and the `highlight*` draw/update hooks to `treeChart.prototype`, registers
 * bound copies of those hooks on `entity` (each one gated by the `options.autoRegister` flag of the
 * same name), and subscribes to the entity's `hierarchyReady`, `edgeRendered` and `mounted` events.
 *
 * A highlight covers the target node, all of its ancestors and all of its descendants.
 *
 * NOTE: prototype members are shared by every instance, so the hook functions stored on the
 * prototype close over the `options` of the most recent `install` call.
 *
 * @param treeChart - TreeChart constructor whose prototype is extended.
 * @param entity - Chart instance that receives the hooks and event listeners.
 * @param options - Highlight styling, default target and auto-registration toggles.
 */
function highlight<Datum>(treeChart: typeof TreeChart, entity: TreeChart<Datum>, options?: IHighlightOptions) {
  const proto = treeChart.prototype

  /**
   * Recomputes `_highlightId` / `_highlightIds` for `datum`. Without a `datum` both are cleared.
   * Otherwise the keys of the target's ancestors (which include the target itself) and descendants
   * are recorded, and `_collapsed` is reset on every ancestor other than the target
   * (presumably so the target becomes visible).
   */
  function genHighlightIds(this: TreeChart<Datum>, datum?: ITreeChartNode<Datum>) {
    this._highlightIds.clear()
    this._highlightId = undefined
    if (!datum) return
    this._highlightId = this._options.dataKeyAccessor(datum.data)
    datum.ancestors().forEach((i) => {
      const identify = this._options.dataKeyAccessor(i.data)
      this._highlightIds.add(identify)
      if (this._highlightId !== identify) {
        i._collapsed = false
      }
    })
    datum.descendants().forEach((i) => this._highlightIds.add(this._options.dataKeyAccessor(i.data)))
  }

  /**
   * Moves the highlight to the node rendered with element id `${treeChart.ID_PREFIX}${id}` and
   * re-renders. Passing no id, or an id without a rendered element, clears the highlight.
   */
  proto.highlight = function (id) {
    const nodeGroup = this._nodeGroup
    const edgeGroup = this._edgeGroup
    // Reset the node styles of the currently highlighted nodes.
    const nodes = nodeGroup.selectAll<SVGGElement, ITreeChartNode<Datum>>('g').filter((d) => {
      return this._highlightIds.has(this._options.dataKeyAccessor(d.data))
    })
    const rect = nodes.select<SVGGElement>('rect')
    drawNodeRect(rect)
    // Reset the edge styles of the currently highlighted edges.
    const edges = edgeGroup
      .selectAll<SVGGElement, ITreeChartNode<Datum>>(`.${treeChart.EDGE_CLASS}`)
      .filter((d) => this._highlightIds.has(this._options.dataKeyAccessor(d.data)))
    const path = edges.select('path')
    drawEdgePath(path)

    // NOTE(review): `id` is interpolated into a CSS selector unescaped; ids containing selector
    // metacharacters may fail to match or make `select` throw.
    const node = id ? nodeGroup.select(`#${treeChart.ID_PREFIX}${id}`) : undefined
    const target = node && node.node() ? node.datum() : undefined
    // Always recompute the state: leaving stale ids in place would make the post hooks re-apply
    // the previous highlight during render(), so the highlight could never be cleared.
    genHighlightIds.call(this, target)

    this.render()
  }

  /** Node hook: strokes every highlighted node's rect and additionally fills the target node's rect. */
  const drawHighlightNode: DrawShape<Datum, IDrawNode<Datum>> = function (this: TreeChart<Datum>, g, rect) {
    rect
      .filter((d) => this._highlightIds.has(this._options.dataKeyAccessor(d.data)))
      .attr('stroke', options?.attr?.stroke ?? '#1976D2')
      .filter((d) => this._options.dataKeyAccessor(d.data) === this._highlightId)
      .attr('fill', options?.attr?.fill ?? '#EDF4FC')
    return g
  }

  /** Re-inserts highlighted edge groups as the last children of their parent so they paint on top. */
  const raiseEdge = function (this: TreeChart<Datum>, g: Selection<SVGGElement, ITreeChartNode<Datum>, SVGGElement, undefined>) {
    g.filter((d) => this._highlightIds.has(this._options.dataKeyAccessor(d.data))).raise()
  }
  /** Edge hook: strokes the paths of highlighted edges. */
  const updateHighlightEdge: DrawShape<Datum, IDrawEdge<Datum>> = function (this: TreeChart<Datum>, g, path) {
    path.filter((d) => this._highlightIds.has(this._options.dataKeyAccessor(d.data))).attr('stroke', options?.attr?.stroke ?? '#1976D2')
    return g
  }
  /** Edge draw hook; identical styling to the update hook. */
  const drawHighlightEdge: DrawShape<Datum, IDrawEdge<Datum>> = function (this: TreeChart<Datum>, g, path) {
    updateHighlightEdge.call(this, g, path)
    return g
  }
  /** Raises highlighted edges once the chart is mounted. */
  const afterDraw = function (this: TreeChart<Datum>) {
    // Use the injected constructor, consistent with `highlight` above.
    const path = this._edgeGroup.selectAll<SVGGElement, ITreeChartNode<Datum>>(`.${treeChart.EDGE_CLASS}`)
    raiseEdge.call(this, path)
  }

  const { drawNode = true, updateNode = true, drawEdge = true, updateEdge = true } = options?.autoRegister ?? {}
  proto.highlightDrawNode = drawHighlightNode
  proto.highlightUpdateNode = drawHighlightNode
  proto.highlightDrawEdge = drawHighlightEdge
  proto.highlightUpdateEdge = updateHighlightEdge
  const boundDrawHighlightNode = drawHighlightNode.bind(entity)
  const boundUpdateHighlightNode = drawHighlightNode.bind(entity)
  const boundDrawHighlightEdge = drawHighlightEdge.bind(entity)
  const boundUpdateHighlightEdge = updateHighlightEdge.bind(entity)
  // Each autoRegister flag gates the hook of the same name.
  if (drawNode) entity.addDrawNode(boundDrawHighlightNode, 'post')
  if (updateNode) entity.addUpdateNode(boundUpdateHighlightNode, 'post')
  if (drawEdge) entity.addDrawEdge(boundDrawHighlightEdge, 'post')
  if (updateEdge) entity.addUpdateEdge(boundUpdateHighlightEdge, 'post')

  // On every `hierarchyReady` event: reset the highlight state and target `defaultHighlightId`, if configured.
  const init = () => {
    entity._highlightIds = new Set()
    const defaultId = options?.defaultHighlightId
    genHighlightIds.call(
      entity,
      defaultId === undefined
        ? undefined
        : entity._hierarchyNode.find((d) => entity._options.dataKeyAccessor(d.data) === defaultId)
    )
  }
  entity._emitter.on('hierarchyReady', init)
  const boundRaiseEdge = raiseEdge.bind(entity)
  entity._emitter.on('edgeRendered', boundRaiseEdge)
  const boundAfterDraw = afterDraw.bind(entity)
  entity._emitter.on('mounted', boundAfterDraw)
}

/** Default export: exposes the installer as `install(treeChart, entity, options?)`. */
const highlightPlugin = {
  install: highlight,
}

export default highlightPlugin
