import { BareIdentityGraphNodeType, BareNodesColumnsType } from '../identityGraphTypes.ts'
import { keyBy, pullAt } from 'lodash'
import { Edge } from '@xyflow/react'

// Minimum number of nodes of the same type within a single column required for `groupNodes` to collapse them
// into one group node (columns with fewer nodes of that type are left untouched).
const compactGroupingThreshold = 12

// Every possible value of the `type` field of `BareIdentityGraphNodeType`, with `undefined` excluded.
type NodeTypes = Exclude<BareIdentityGraphNodeType['type'], undefined>

/**
 * Describes how nodes of one 'single' type are collapsed into one node of a 'group' type.
 *
 * - `singleNode`: the type of the individual nodes that get replaced.
 * - `groupNode`: the type of the node that replaces them.
 * - `groupFunction`: receives the `data` of every replaced single node and returns the `data` of the group node.
 */
type NodeGroupMapping<TSingleNode extends NodeTypes, TGroupNode extends NodeTypes> = {
	singleNode: TSingleNode
	groupNode: TGroupNode
	groupFunction: (
		// `data` payloads of the graph node variants whose `type` matches `TSingleNode`.
		data: Array<Extract<BareIdentityGraphNodeType, { type?: TSingleNode }>['data']>,
	) => Extract<BareIdentityGraphNodeType, { type?: TGroupNode }>['data']
}

/**
 * Identity helper: returns the mapping it was given, untouched.
 *
 * It exists purely for typing purposes - routing a mapping literal through this generic function lets the
 * compiler pick the single/group node types from the literal itself, so that `groupFunction` is checked
 * against the matching `data` shapes.
 *
 * @param mapping The single -> group node mapping definition.
 * @returns The very same `mapping` object.
 */
function createNodeGroupMapping<TSingle extends NodeTypes, TGroup extends NodeTypes>(
	mapping: NodeGroupMapping<TSingle, TGroup>,
): NodeGroupMapping<TSingle, TGroup> {
	return mapping
}

// Registry of all supported single -> group node mappings, keyed (via lodash `keyBy`) by each mapping's
// `singleNode` type. Every `groupFunction` collects one field from each single node's data into an array
// stored on the group node's data.
const nodeGroupMappers = keyBy(
	[
		createNodeGroupMapping({
			singleNode: 'enrichedAzureRole',
			groupNode: 'enrichedAzureRoles',
			groupFunction: (data) => ({ roles: data.map((datum) => datum.role) }),
		}),
		createNodeGroupMapping({
			singleNode: 'awsIamRole',
			groupNode: 'awsIamRoles',
			groupFunction: (data) => ({ awsIamRoleXcs: data.map((datum) => datum.awsIamRoleXc) }),
		}),
		createNodeGroupMapping({
			singleNode: 'gcpProject',
			groupNode: 'gcpProjects',
			groupFunction: (data) => ({ principalInfoXcs: data.map((datum) => datum.principalInfoXc) }),
		}),
		createNodeGroupMapping({
			singleNode: 'githubRepository',
			groupNode: 'githubRepositories',
			groupFunction: (data) => ({ repositories: data.map((datum) => datum.repository) }),
		}),
		createNodeGroupMapping({
			singleNode: 'detailedEntraIdRole',
			groupNode: 'detailedEntraIdRoles',
			groupFunction: (data) => ({ roles: data.map((datum) => datum.role) }),
		}),
		createNodeGroupMapping({
			singleNode: 'azureSubscription',
			groupNode: 'azureSubscriptions',
			groupFunction: (data) => ({ subscriptions: data.map((datum) => datum.subscription) }),
		}),
		createNodeGroupMapping({
			singleNode: 'azureManagementGroup',
			groupNode: 'azureManagementGroups',
			groupFunction: (data) => ({ managementGroups: data.map((datum) => datum.managementGroup) }),
		}),
		createNodeGroupMapping({
			singleNode: 'entraIDUser',
			groupNode: 'entraIDUsers',
			groupFunction: (data) => ({ users: data.map((datum) => datum.user) }),
		}),
		createNodeGroupMapping({
			singleNode: 'entraIDServicePrincipal',
			groupNode: 'entraIDServicePrincipals',
			groupFunction: (data) => ({ servicePrincipals: data.map((datum) => datum.servicePrincipal) }),
		}),
		createNodeGroupMapping({
			singleNode: 'snowflakeRole',
			groupNode: 'snowflakeRoles',
			groupFunction: (data) => ({ roles: data.map((datum) => datum.role) }),
		}),
		createNodeGroupMapping({
			singleNode: 'awsPolicy',
			groupNode: 'awsPolicies',
			groupFunction: (data) => ({ policies: data.map((datum) => datum.policy) }),
		}),
		createNodeGroupMapping({
			singleNode: 'gcpRole',
			groupNode: 'gcpRoles',
			groupFunction: (data) => ({ roles: data.map((datum) => datum.role) }),
		}),
		createNodeGroupMapping({
			singleNode: 'salesforcePermissionSet',
			groupNode: 'salesforcePermissionSets',
			groupFunction: (data) => ({ permissionSets: data.map((datum) => datum.permissionSet) }),
		}),
	],
	// Property of each mapping used as its key in the resulting dictionary.
	'singleNode',
)

// The type of the `singleNode` field across the entries of `nodeGroupMappers`, i.e. the node types that have
// a group mapping defined.
type MappedSingleNodeTypes = (typeof nodeGroupMappers)[string]['singleNode']

/**
 * Type guard telling whether a node type has a group mapping registered in `nodeGroupMappers`.
 *
 * @param nodeType The node's `type`; may be missing.
 * @returns `true` only when `nodeType` is a non-empty string that is a key of `nodeGroupMappers`.
 */
function nodeTypeInMap(nodeType?: string): nodeType is MappedSingleNodeTypes {
	// An own-property check is used instead of the `in` operator because `in` also matches keys inherited
	// from `Object.prototype` (e.g. 'constructor', 'toString'), which would make this predicate lie.
	return !!nodeType && Object.prototype.hasOwnProperty.call(nodeGroupMappers, nodeType)
}

/**
 * Groups together nodes (replaces many 'single' nodes with one 'group' node) in the graph if:
 *  1. The nodes are in the same column.
 *  2. The nodes are of the same type.
 *  3. That type has a group mapping defined in `nodeGroupMappers`.
 *  4. The amount of nodes from the same type in the same column is at least `compactGroupingThreshold`.
 *
 * The group node takes the position of the first 'single' node it replaces. Nodes of the same type do not
 * have to be adjacent in the column, and nodes of different groupable types may be interleaved.
 *
 * Note that the `nodes` array of every column in `nodeColumns` is modified in place. `edges` is not modified.
 *
 * @param nodeColumns The graph's columns of nodes.
 * @param edges The graph's edges, referencing nodes by ID through `source` and `target`.
 * @returns A tuple of the (same) `nodeColumns` array and a new array of edges in which every reference to a
 * grouped 'single' node was redirected to the group node that replaced it.
 */
export function groupNodes(nodeColumns: BareNodesColumnsType[], edges: Edge[]): [BareNodesColumnsType[], Edge[]] {
	// Keeping a mapping of old single node ID -> new group node ID for nodes that were grouped so we could later fix
	// the graph edges to point to and from the new grouped nodes.
	const singleNodeIdToGroupNodeId = new Map<string, string>()
	nodeColumns.forEach((nodeColumn) => {
		// First, collect the nodes of every type that has a group mapping. A `Map` iterates in insertion order, so
		// the types are later processed in order of their first appearance in the column.
		const typeToSingleNodes = new Map<MappedSingleNodeTypes, typeof nodeColumn.nodes>()
		nodeColumn.nodes.forEach((node) => {
			if (!nodeTypeInMap(node.type)) {
				return
			}

			const sameTypeNodes = typeToSingleNodes.get(node.type)
			if (sameTypeNodes) {
				sameTypeNodes.push(node)
			} else {
				typeToSingleNodes.set(node.type, [node])
			}
		})

		// Iterate over the groupable types, check if their nodes need to be grouped, and group them by removing the
		// nodes and inserting a new grouped node with an aggregation of the nodes' data.
		for (const [type, singleNodes] of typeToSingleNodes.entries()) {
			if (singleNodes.length < compactGroupingThreshold) {
				continue
			}

			// Previous iterations may have removed nodes from the column, so positions recorded up front could be
			// stale. Subtracting a single running offset is only right when every previously removed node sits
			// before the current ones, which breaks for interleaved types. Instead, look up where the nodes are
			// in the column right now.
			const singleNodesToGroup = new Set(singleNodes)
			const currentIndices: number[] = []
			nodeColumn.nodes.forEach((node, index) => {
				if (singleNodesToGroup.has(node)) {
					currentIndices.push(index)
				}
			})

			const groupPosition = currentIndices[0]
			const nodeGroupMapper = nodeGroupMappers[type]
			// The new grouped node ID will be based on the node ID of the first 'single' node in the group to be
			// replaced, since it will take its position in the graph.
			const newGroupedNodeId = `${nodeColumn.nodes[groupPosition].id.split('-')[0]}-${groupPosition}`
			const removedSingleNodes = pullAt(nodeColumn.nodes, currentIndices)
			removedSingleNodes.forEach((removedNode) => {
				singleNodeIdToGroupNodeId.set(removedNode.id, newGroupedNodeId)
			})

			// Create the new grouped node using the node group mapper that is defined in `nodeGroupMappers`.
			const newGroupedNode = {
				type: nodeGroupMapper.groupNode,
				id: newGroupedNodeId,
				// @ts-expect-error The compiler can't narrow down the type of the argument for `groupFunction`
				data: nodeGroupMapper.groupFunction(removedSingleNodes.map((removedNode) => removedNode.data)),
			} as BareIdentityGraphNodeType

			// All removed nodes were at or after `groupPosition`, so it is still the right insertion point.
			nodeColumn.nodes.splice(groupPosition, 0, newGroupedNode)
		}
	})

	// Redirect edges that pointed to/from a grouped 'single' node to the group node that replaced it.
	const revisedEdges = edges.map((edge) => ({
		...edge,
		source: singleNodeIdToGroupNodeId.get(edge.source) ?? edge.source,
		target: singleNodeIdToGroupNodeId.get(edge.target) ?? edge.target,
	}))

	return [nodeColumns, revisedEdges]
}
