<script setup lang="ts">
import {
  useGroundingStore,
  type Claim,
  type ClaimAndSource,
} from '@/modules/Project/useGroundingStore'
import { invariant } from '@/shared/utils/typeAssertions'
import { getBackgroundRainbowColor } from '@/uiKit/RainbowColor'
import { useEventListener } from '@vueuse/core'
import type { Node } from 'prosemirror-model'
import { computed, ref, useId, useTemplateRef, watch } from 'vue'
import ProseMirror, { type ProseMirrorProps } from '../ProseMirror/ProseMirror.vue'
import { parser } from './parser'
import ProseMirrorClaim from './ProseMirrorClaim.vue'
import { schema } from './schema'
import { serializer } from './serializer'

// `placeholder` and `value` reuse the prop types of the underlying ProseMirror editor component.
type Props = Pick<ProseMirrorProps, 'placeholder' | 'value'>
/**
 * - `value`: the text to display; each claim's `start`/`end` are character offsets into it.
 * - `claims`: the grounded spans to highlight; the editor is read-only while this is non-empty.
 * - `propertyId`: embedded in the serialized claim markers, compared against the grounding
 *   store's active field, and used to seed the highlight colours.
 */
const props = defineProps<Props & { claims: Claim[]; propertyId?: string }>()
/** `show-source` is emitted when a claim/source is focused by clicking grounded text or a claim pill. */
const emit = defineEmits<{
  'show-source': [payload: ClaimAndSource]
}>()

// Unique id passed to the ProseMirror component's `id` attribute in the template
const editorId = useId()

/**
 * The grounded text with claim markers spliced in at each claim's character
 * offsets, in the form our custom markdown-it plugin parses:
 *
 * - `<CLAIM_START:claimId>` at `claim.start`;
 * - `<CLAIM:claimId:sourceId:start:end:propertyId>` at `claim.end`, once per
 *   source of the claim.
 *
 * When one claim ends exactly where another begins, the end markers are placed
 * before the start marker, so each claim's range is closed before the next one
 * is opened.
 */
const textWithClaims = computed(() => {
  // The claim is kept in its own field so none of its properties can clobber
  // the marker's `position`/`kind`.
  type ClaimMarker = { position: number; kind: 'start' | 'end'; claim: Claim }

  const markers = props.claims.flatMap((claim): ClaimMarker[] => [
    { position: claim.start, kind: 'start', claim },
    { position: claim.end, kind: 'end', claim },
  ])

  const serializeMarker = ({ kind, claim }: ClaimMarker) => {
    if (kind === 'start') {
      return `<CLAIM_START:${claim.id}>`
    }

    // A claim end becomes one CLAIM node per source
    return claim.sources
      .map(
        (sourceId) =>
          `<CLAIM:${claim.id}:${sourceId}:${claim.start}:${claim.end}:${props.propertyId || ''}>`,
      )
      .join('')
  }

  // Markers sharing an offset are each spliced in at that same offset, so the
  // one spliced last ends up first in the text. Splicing starts before ends
  // therefore places an end marker ahead of a start marker at the same offset.
  const spliceRank = (marker: ClaimMarker) => (marker.kind === 'start' ? 0 : 1)

  return (
    markers
      // Splice from the end of the text backwards so offsets that are yet to be
      // processed keep pointing at the right characters
      .toSorted((a, b) => b.position - a.position || spliceRank(a) - spliceRank(b))
      .reduce(
        (text, marker) =>
          `${text.slice(0, marker.position)}${serializeMarker(marker)}${text.slice(marker.position)}`,
        props.value,
      )
  )
})

const editor = useTemplateRef('editor')

/** Number of claim highlight colour slots (`colorStyles` defines `--claim-color-0` … `--claim-color-15`). */
const CLAIM_COLOR_COUNT = 16

/**
 * The grounded span of one claim: from just after its `claim_start` node up to
 * its first `claim` node. `sourceId` is the source of that first `claim` node.
 */
type ClaimRange = { start: number; end: number; claimId: string; sourceId: string }

/**
 * Collects the grounded range of every claim in `doc`.
 *
 * Each `claim_start` node is paired with the first later `claim` node carrying
 * the same `claimId`. Further `claim` nodes for that claim are ignored; one
 * `claim` node is serialized per source. Open starts are tracked per claim id,
 * so nested, overlapping and back-to-back claims each keep their own range
 * instead of one claim's start replacing another's.
 *
 * @param doc - The editor document to scan.
 * @returns One range per claim that has both a start and an end node, in the
 *   document order of their end nodes.
 */
const collectClaimRanges = (doc: Node): ClaimRange[] => {
  const ranges: ClaimRange[] = []
  // claimId -> position right after its `claim_start` node, for claims not closed yet
  const openStarts = new Map<string, number>()

  doc.descendants((node, pos) => {
    if (node.type.name === 'claim_start') {
      openStarts.set(node.attrs.claimId, pos + node.nodeSize)
    } else if (node.type.name === 'claim') {
      const start = openStarts.get(node.attrs.claimId)
      if (start !== undefined) {
        ranges.push({
          start,
          end: pos,
          claimId: node.attrs.claimId,
          sourceId: node.attrs.sourceId,
        })
        openStarts.delete(node.attrs.claimId)
      }
    }
  })

  return ranges
}

/** The colour slot for a source, as the string stored in the `data-color-index` mark attribute. */
const getColorIndex = (sourceId: string) => String(Number(sourceId) % CLAIM_COLOR_COUNT)

/**
 * Applies an unfocused `grounded` mark to every claim range in the document.
 *
 * @throws if the editor view does not exist (callers only invoke this once it does).
 */
const applyGroundedMarks = () => {
  const view = editor.value?.view
  invariant(view)

  const ranges = collectClaimRanges(view.state.doc)
  // Nothing to mark; avoid dispatching a no-op transaction
  if (ranges.length === 0) return

  const tr = view.state.tr
  const groundedMark = view.state.schema.marks.grounded
  for (const range of ranges) {
    tr.addMark(
      range.start,
      range.end,
      groundedMark.create({
        'data-focused': 'false',
        'data-color-index': getColorIndex(range.sourceId),
        'data-grounded': 'true',
      }),
    )
  }

  view.dispatch(tr)
}

/**
 * Re-applies the `grounded` mark to every claim range, flagging the range of
 * `claimId` as focused and, when its source also matches `sourceId`, as
 * source-focused. Passing `null` for both clears the focus.
 *
 * Does nothing while the editor view does not exist. This is called from the
 * global click listener and from store watchers, which can fire before the
 * editor is mounted.
 */
const updateFocusedClaim = (claimId: string | null, sourceId: string | null) => {
  const view = editor.value?.view
  if (!view) return

  const ranges = collectClaimRanges(view.state.doc)
  // Nothing to mark; avoid dispatching a no-op transaction
  if (ranges.length === 0) return

  const tr = view.state.tr
  const groundedMark = view.state.schema.marks.grounded
  for (const range of ranges) {
    const isFocused = range.claimId === claimId
    const isSourceFocused = isFocused && range.sourceId === sourceId
    tr.addMark(
      range.start,
      range.end,
      groundedMark.create({
        'data-focused': isFocused ? 'true' : 'false',
        'data-source-focused': isSourceFocused ? 'true' : 'false',
        'data-source-id': range.sourceId,
        'data-color-index': getColorIndex(range.sourceId),
        'data-grounded': 'true',
      }),
    )
  }

  view.dispatch(tr)
}

// Apply grounded marks once the editor view exists and whenever the annotated text changes
watch([() => editor.value?.view, textWithClaims], ([newView]) => {
  if (newView) {
    applyGroundedMarks()
  }
})

const groundingStore = useGroundingStore()

// When the grounding store's active field belongs to another property, clear this editor's focus
watch(
  () => groundingStore.field?.propertyId,
  (newPropertyId) => {
    if (newPropertyId !== props.propertyId) {
      updateFocusedClaim(null, null)
    }
  },
)

// The claim/source focused by the last click, used to cycle through claims and their sources
const focusedClaimInfo = ref<{ claimId: string; sourceId: string; cycleIndex: number } | null>(null)

/**
 * Returns every claim whose grounded range contains `pos` (inclusive at both
 * ends), with the full list of that claim's source ids taken from `props.claims`.
 *
 * @throws if the editor view does not exist, or if a claim found in the
 *   document has no matching entry in `props.claims`.
 */
const getClaimsAtPosition = (pos: number) => {
  const view = editor.value?.view
  invariant(view)

  return collectClaimRanges(view.state.doc)
    .filter((range) => pos >= range.start && pos <= range.end)
    .map((range) => {
      // The document only carries one source per range; the claim data carries all of them
      const originalClaim = props.claims.find((c) => String(c.id) === range.claimId)
      invariant(originalClaim, `Could not find original claim data for claim ${range.claimId}`)

      return {
        claimId: String(originalClaim.id),
        sources: originalClaim.sources.map(String),
      }
    })
}

/**
 * Picks the claim/source to focus next when grounded text is clicked again.
 *
 * Steps through the sources of the currently focused claim in order, then
 * moves on to the first source of the following claim, wrapping from the last
 * claim back to the first. Starts at the first claim's first source when
 * nothing is focused, or when the focused claim is not among `claims`.
 *
 * @param claims - Claims under the clicked position, each with its source ids.
 * @param currentFocus - The claim/source currently focused, if any.
 * @returns The next focus target, with `cycleIndex` set to the source's index
 *   within its claim, or `null` when `claims` is empty.
 */
const cycleToNextClaim = (
  claims: Array<{ claimId: string; sources: string[] }>,
  currentFocus: { claimId: string; sourceId: string } | null,
) => {
  if (claims.length === 0) return null

  const firstSourceOf = (claimIndex: number) => ({
    claimId: claims[claimIndex].claimId,
    sourceId: claims[claimIndex].sources[0],
    cycleIndex: 0,
  })

  if (!currentFocus) return firstSourceOf(0)

  const focusedIndex = claims.findIndex((candidate) => candidate.claimId === currentFocus.claimId)
  if (focusedIndex === -1) return firstSourceOf(0)

  const { claimId, sources } = claims[focusedIndex]
  // An unknown source gives indexOf = -1, i.e. the claim's first source comes next
  const nextSourceIndex = sources.indexOf(currentFocus.sourceId) + 1
  if (nextSourceIndex < sources.length) {
    return { claimId, sourceId: sources[nextSourceIndex], cycleIndex: nextSourceIndex }
  }

  return firstSourceOf((focusedIndex + 1) % claims.length)
}

/**
 * Focuses `next`: remembers it for the next cycle step, emits `show-source`
 * and highlights it in the editor. Does nothing for `null`.
 */
const focusClaim = (next: { claimId: string; sourceId: string; cycleIndex: number } | null) => {
  if (!next) return

  focusedClaimInfo.value = next
  emit('show-source', {
    claimId: Number(next.claimId),
    sourceId: Number(next.sourceId),
  })
  updateFocusedClaim(next.claimId, next.sourceId)
}

/**
 * Global click handler that drives claim focus for this editor:
 * - clicks outside the editor clear the focus (and the store selection),
 *   unless they land on a source indicator (`[data-source]`);
 * - clicks on grounded text cycle through the claims/sources at that position;
 * - clicks on a claim pill focus that pill's claim/source directly;
 * - any other click inside the editor clears the focus, again unless it lands
 *   on a source indicator.
 */
useEventListener('click', (event: Event) => {
  // `Element` rather than `HTMLElement`, so clicks on SVG elements are handled too
  if (!(event.target instanceof Element)) {
    return
  }
  const target = event.target

  const view = editor.value?.view
  if (!view || !view.dom.contains(target)) {
    if (!target.closest('[data-source]')) {
      // Without a view there is no highlight to clear
      if (view) updateFocusedClaim(null, null)
      focusedClaimInfo.value = null
      groundingStore.selectedClaimsAndSource = null
    }
    return
  }

  const claimPill = target.closest('[data-claim-pill-claim]')
  const groundedText = target.closest('[data-grounded]')

  if (!claimPill && !groundedText) {
    if (!target.closest('[data-source]')) {
      updateFocusedClaim(null, null)
      focusedClaimInfo.value = null
    }
    return
  }

  if (groundedText) {
    const claimsHere = getClaimsAtPosition(view.posAtDOM(groundedText, 0))
    // If no claims are found at this position, try the parent element's position instead
    const parent = groundedText.parentElement
    const candidates =
      claimsHere.length === 0 && parent
        ? getClaimsAtPosition(view.posAtDOM(parent, 0))
        : claimsHere

    focusClaim(cycleToNextClaim(candidates, focusedClaimInfo.value))
    return
  }

  const claimId = claimPill?.getAttribute('data-claim-pill-claim')
  const sourceId = claimPill?.getAttribute('data-claim-pill-source')
  if (!claimId || !sourceId) {
    return
  }

  // Clicking a pill directly restarts the cycle at that pill's claim/source
  focusClaim({ claimId, sourceId, cycleIndex: 0 })
})

/**
 * Inline style declaring one `--claim-color-N` CSS variable per colour slot
 * (N = 0…15). Each is a single-colour `linear-gradient` so it can be used as
 * a `background-image`. Colours are seeded from `propertyId` and the slot.
 */
const colorStyles = computed(() =>
  Array.from({ length: 16 }, (_, slot) => {
    const color = getBackgroundRainbowColor(`${props.propertyId}-${(slot + 1) * 8}`)
    return `--claim-color-${slot}: linear-gradient(${color}, ${color})`
  }).join(';'),
)

// Mirror the grounding store's selection into this editor: focus the first
// selected claim and the selected source, or clear the focus when nothing is selected.
watch(
  () => groundingStore.selectedClaimsAndSource,
  (selection) => {
    if (selection) {
      const { claimIds, sourceId } = selection
      updateFocusedClaim(String(claimIds[0]), String(sourceId))
    } else {
      updateFocusedClaim(null, null)
    }
  },
)
</script>

<!--
  Renders the claim-annotated text. `colorStyles` provides the --claim-color-N
  variables used by the grounded-text styles, `claim` nodes are rendered with
  ProseMirrorClaim, and the editor is read-only while `claims` is non-empty.
  `placeholder` is declared as a prop, so it is not part of $attrs and must be
  forwarded explicitly.
-->
<template>
  <ProseMirror
    :id="editorId"
    ref="editor"
    :style="colorStyles"
    :mode="{
      parser,
      serializer,
      schema,
      nodeViewOptions: [
        {
          name: 'claim',
          component: ProseMirrorClaim,
        },
      ],
    }"
    :value="textWithClaims"
    :placeholder="placeholder"
    :readonly="claims.length > 0"
  />
</template>

<style lang="scss">
// List markers are rendered outside of the nodes to which we attach the `faint` mark,
// so we use this workaround to fade them out.
li:has([data-faint])::marker {
  color: var(--color-text-disabled);
}

// Grounded text is highlighted with a background image (the --claim-color-N gradients)
// rather than `text-decoration`, so the highlight also spans whitespace between words.
[data-grounded='true'] {
  text-decoration: none;
  background-position: 0 100%;
  background-repeat: repeat-x;
  background-size: 100% 100%;
  border-radius: 4px;
  margin-left: -2px;
  margin-right: -2px;
  padding-left: 2px;
  padding-right: 2px;
  color: inherit;
  cursor: pointer;

  // `colorStyles` defines --claim-color-0 … --claim-color-15 and colour indices are
  // taken modulo 16, so stop at 15 (`through` is inclusive).
  @for $i from 0 through 15 {
    &[data-color-index='#{$i}']:hover {
      background-image: var(--claim-color-#{$i});
    }
    &[data-color-index='#{$i}'][data-focused='true'] {
      background-image: var(--claim-color-#{$i});
    }
  }
}
</style>
