import difference from 'lodash/difference'
import {
  create,
  parseDependencies,
  addDependencies,
  subtractDependencies,
  divideDependencies,
  multiplyDependencies,
  EvalFunction,
  MathNode, SymbolNode, ParenthesisNode, FunctionNode, ArrayNode, OperatorNode
} from 'mathjs'

import { CompositeDimension, Dimension } from '../dimensions'

import { NumberReducer, Reducer, ReducerConstructor, SerializedReducer } from '.'

// Dedicated MathJS instance built from only the parser and the four arithmetic functions
// (`add`, `subtract`, `divide`, `multiply`), i.e. the `+ - / *` operators.
// NOTE(review): formulas calling other MathJS functions presumably still parse but throw at
// evaluation time, which FormulaReducer.post() turns into NaN — confirm against the MathJS version.
const { parse } = create({ parseDependencies, addDependencies, subtractDependencies, divideDependencies, multiplyDependencies })

/**
 * A pairing of two dimensions used by the `$adcalls` special case:
 * `dimensions[0]` holds a single value per line, while `dimensions[1]` holds a list of
 * values joined by `separator`.
 */
interface Condition {
  dimensions: string[]
  separator: string
}

/**
 * Special cases are specific metrics that need to be handled differently from regular ones.
 * They have a hardcoded behavior and are identified by a metric name in the `SPECIAL_CASES` constant.
 * `FormulaReducer` delegates each of these metrics to its special case instead of summing it per group.
 */
interface SpecialCase {
  /** Accounts for data line `v` being added to group record `p` (`key` is the reducer's output field on `p`). */
  add(p: Record<string, any>, key: string, v: Record<string, any>): void
  /** Undoes `add()` for data line `v` leaving group record `p`. */
  remove(p: Record<string, any>, key: string, v: Record<string, any>): void
  /** Writes this metric's value for group record `p` into the formula evaluation `scope`. */
  post(scope: Record<string, any>, p: Record<string, any>, key: string): void
}

/** Constructor of a special case; it receives the dimension being reduced, if any. */
type SpecialCaseConstructor = new (dimension?: Dimension) => SpecialCase

/**
 * Collects the variable names used by a parsed formula and appends them to `array`.
 * Names are appended in order of first appearance, without duplicates.
 *
 * Traversed node kinds: operators, arrays, function arguments, parentheses and conditionals
 * (`cond ? a : b`). A function's own name (e.g. `sqrt` in `sqrt($x)`) is not visited, so it
 * is not mistaken for a variable. Any other node kind (constants, ...) is ignored.
 *
 * @param node root of the parsed formula
 * @param array output list of `$`-prefixed symbol names, mutated in place
 * @throws Error if a visited symbol does not start with `$`
 */
function findSymbols (node: MathNode, array: string[]): void {
  const genericNode = node as any
  if (genericNode.isOperatorNode) {
    (node as OperatorNode).args.forEach(a => findSymbols(a, array))
  } else if (genericNode.isArrayNode) {
    (node as ArrayNode).items.forEach(i => findSymbols(i, array))
  } else if (genericNode.isFunctionNode) {
    // Arguments only: the function name node (`fn`) must not be validated as a variable
    (node as FunctionNode).args.forEach(a => findSymbols(a, array))
  } else if (genericNode.isParenthesisNode) {
    findSymbols((node as ParenthesisNode).content, array)
  } else if (genericNode.isConditionalNode) {
    // Without this branch, variables used only inside a conditional were neither validated
    // nor registered as metrics, so evaluation silently failed (NaN)
    findSymbols(genericNode.condition, array)
    findSymbols(genericNode.trueExpr, array)
    findSymbols(genericNode.falseExpr, array)
  } else if (genericNode.isSymbolNode) {
    const name = (node as SymbolNode).name
    if (!name.startsWith('$')) {
      throw new Error(`FormulaReducer: invalid formula: found symbol "${name}" that doesn't start with "$"`)
    }
    if (!array.includes(name)) {
      array.push(name)
    }
  }
}

/**
 * Versatile reducer taking a mathematical formula (string) as an input.
 * The formula is parsed and executed by the MathJS library (https://mathjs.org/) and supports
 * variables, identified by a dollar symbol (`$`).
 * e.g: `$impressions`, `$impressions * 1000`, `$impressions / $adcalls`, ...
 * The variable value is the sum of that value for the crossfilter group.
 *
 * Variables listed in `SPECIAL_CASES` are exceptions and are computed by their special case instead:
 *
 * `$adcalls` is always computed as a global sum.
 * If the reduced data is for a composite dimension containing the following pairs
 * of dimensions, a condition is added:
 * - adresponseBidder and whitelistedBidders
 * - adresponseMediatype and inventoryMediatypes
 * - size and sizes
 * - seatID and seatType
 * The dimensions will be removed from the adcalls global sum dimensions
 * (e.g.: orgaID, adresponseBidder, whitelistedBidders -> adcalls aggregated by orgaID only).
 * The sum condition works as follows: sum of adcalls for all lines in the group
 * (global sum dimensions like explained above) having the first dimension in the pair contained
 * in the second dimension in the pair.
 * e.g.: sum(adcalls) where orgaID = 1004 and whitelistedBidders contains '33across'
 *
 * To achieve this, the reducer keeps in memory the sum in a map. The key is the value of the
 * global sum dimensions concatenated with all the possible combinations of the values of the
 * conditions. (e.g.: '1004~33across~ban', '1004~33across~nat', '1004~33across~vidout', ...).
 * At the end of reducing (in the `post()` function), the correct sum of adcalls is retrieved
 * by reconstructing the corresponding key based on the dimensions included in the result record.
 *
 * `$adrequests` and `$totaladrequests` are global sums of the `adrequests` field in which only
 * lines whose `bidder` is `adagio` count (other lines count as 0). They are aggregated by the
 * reduced dimensions minus `bidder` (and minus `mvt` for `$totaladrequests`).
 *
 * The formula is only executed in the `post()` stage. `add()` and `remove()` only maintain the sums of
 * all the necessary metrics. If evaluating the formula throws, the result is `NaN`.
 */
export class FormulaReducer extends NumberReducer implements Reducer {
  private compiledFormula: EvalFunction
  /** `$`-prefixed variable names found in the formula, without duplicates. */
  private metrics: string[]
  /**
   * Per-group sums of the regular (non-special) metrics, keyed by the group record `p`.
   * A WeakMap: it is only ever read/written by key, and group records discarded by the caller
   * must not be kept alive by the reducer.
   */
  private sums: WeakMap<Record<string, any>, Record<string, number>>

  /** Special-case handlers, keyed by metric name, for the metrics listed in `SPECIAL_CASES`. */
  private specialCases: Record<string, SpecialCase>

  constructor (formula: string, dimension?: Dimension) {
    super()
    const node = parse(formula)
    this.compiledFormula = node.compile()
    this.metrics = []
    this.sums = new WeakMap()
    findSymbols(node, this.metrics)
    this.specialCases = {}
    this.metrics.forEach(m => {
      const Case = SPECIAL_CASES[m]
      if (Case !== undefined) {
        this.specialCases[m] = new Case(dimension)
      }
    })
  }

  /** Returns the regular-metric sums of group record `p`, creating an empty object on first use. */
  private getOrCreateMetricSums (p: Record<string, any>): Record<string, number> {
    let metricSums = this.sums.get(p)
    if (metricSums === undefined) {
      metricSums = {}
      this.sums.set(p, metricSums)
    }
    return metricSums
  }

  add (p: Record<string, any>, key: string, v: Record<string, any>): void {
    const metricSums = this.getOrCreateMetricSums(p)
    this.metrics.forEach(metric => {
      const specialCase = this.specialCases[metric]
      if (specialCase !== undefined) {
        specialCase.add(p, key, v)
      } else {
        // The data field is the variable name without its leading '$'
        metricSums[metric] = (metricSums[metric] ?? 0) + v[metric.slice(1)]
      }
    })
  }

  remove (p: Record<string, any>, key: string, v: Record<string, any>): void {
    // Same lazy initialization as add(): no crash (nor NaN) for a record/metric without a sum yet
    const metricSums = this.getOrCreateMetricSums(p)
    this.metrics.forEach(metric => {
      const specialCase = this.specialCases[metric]
      if (specialCase !== undefined) {
        specialCase.remove(p, key, v)
      } else {
        metricSums[metric] = (metricSums[metric] ?? 0) - v[metric.slice(1)]
      }
    })
  }

  post (p: Record<string, any>, key: string): void {
    const scope: Record<string, any> = {}
    const metricSums = this.sums.get(p)
    this.metrics.forEach(metric => {
      const specialCase = this.specialCases[metric]
      if (specialCase !== undefined) {
        specialCase.post(scope, p, key)
      } else if (metricSums !== undefined) {
        scope[metric] = metricSums[metric]
      }
    })
    try {
      p[key] = this.compiledFormula.evaluate(scope)
    } catch {
      // Any evaluation error yields NaN for this record instead of aborting the whole reduction
      p[key] = NaN
    }
  }

  static deserialize (m: SerializedReducer): ReducerConstructor {
    if (m.params === undefined || m.params.formula === undefined) {
      throw new Error('Invalid SerializedReducer (FormulaReducer): missing parameters')
    }
    return FormulaReducerConstructor(m.params.formula)
  }
}

/**
 * Builds a `FormulaReducer` subclass bound to a fixed formula. The subclass can be instantiated
 * with the dimension alone and serializes back to `{ constructor: 'FormulaReducer', params: { formula } }`.
 * The formula is parsed and validated eagerly, so an invalid formula is rejected here.
 *
 * @param formula the math formula used by the reducer
 * @returns FormulaReducer constructor using this formula
 * @throws Error if the formula cannot be parsed, or if it uses a variable not starting with `$`
 */
export function FormulaReducerConstructor (formula: string): ReducerConstructor {
  let parsed: MathNode
  try {
    parsed = parse(formula)
  } catch (parseError: any) {
    throw new Error('Invalid SerializedReducer (FormulaReducer): invalid formula: ' + parseError)
  }
  // Validation only: the collected symbol names are discarded
  findSymbols(parsed, [])

  return class extends FormulaReducer {
    constructor (dimension?: Dimension) {
      super(formula, dimension)
    }

    static serialize (): SerializedReducer {
      const params = { formula }
      return { constructor: 'FormulaReducer', params }
    }
  }
}

// -------------------------------
// SPECIAL CASES

/**
 * Cartesian product of string lists, each combination joined with `~`.
 * e.g. `cartesian(['a', 'b'], ['x', 'y'])` -> `['a~x', 'a~y', 'b~x', 'b~y']`.
 *
 * Every combination contains exactly one value per input list, and therefore `arrays.length - 1`
 * separators, even when some values are empty strings. This keeps the keys aligned with the ones
 * rebuilt by `AdCallsSpecialCase.post()`, which appends `'~' + value` once per condition.
 * With no input lists, returns `['']`.
 */
function cartesian (...arrays: string[][]): string[] {
  const combinations = arrays.reduce<string[][]>(
    (acc, values) => acc.flatMap(prefix => values.map(value => [...prefix, value])),
    [[]]
  )
  return combinations.map(combination => combination.join('~'))
}

/**
 * Base class for special cases computed as "global" sums. There is one running sum per bucket key,
 * shared by every group record that produces the same key. Subclasses may append a suffix to the
 * key returned by `getSumKey()`.
 */
abstract class GlobalSumSpecialCase {
  /** Dimensions whose values, read from the group record, form the bucket key. */
  protected sumDimensions: string[]
  /**
   * Running sum per bucket key. The object has no prototype because keys are built from data
   * values. With a plain `{}`, keys such as `constructor`, `toString` or `__proto__` would hit
   * inherited members instead of starting from undefined.
   */
  protected globalSum: Record<string, number>

  constructor () {
    this.sumDimensions = []
    this.globalSum = Object.create(null)
  }

  /** Builds the bucket key: the group's values of `sumDimensions`, joined with `~`. */
  getSumKey (p: Record<string, any>): string {
    return this.sumDimensions.map(d => p[d]).join('~')
  }

  /** Adds `value` (negative when removing) to the sum of bucket `key`, starting from 0. */
  addToGlobalSum (key: string, value: number): void {
    this.globalSum[key] = (this.globalSum[key] ?? 0) + value
  }
}

/**
 * Special case for `$adcalls`: a global sum of the `adcalls` field.
 * When the reduced composite dimension contains both dimensions of a `SPECIAL_DIMENSIONS` pair,
 * the pair becomes a condition. Sums are then also bucketed by every value listed in the pair's
 * second dimension, and `post()` picks the bucket matching the group's value of the first dimension.
 */
class AdCallsSpecialCase extends GlobalSumSpecialCase implements SpecialCase {
  // [single-value dimension, list dimension] pairs, with the separator used inside list values
  private static SPECIAL_DIMENSIONS: Condition[] = [
    { dimensions: ['adresponseBidder', 'whitelistedBidders'], separator: ',' },
    { dimensions: ['adresponseMediatype', 'inventoryMediatypes'], separator: '|' },
    { dimensions: ['size', 'sizes'], separator: ',' },
    { dimensions: ['seatID', 'seatType'], separator: ',' }
  ]

  private conditions: Condition[] // Adcalls conditions (e.g.: where bidder is in whitelistedBidders)

  constructor (dimension?: Dimension) {
    super()
    this.conditions = []
    if (dimension) {
      if (dimension instanceof CompositeDimension) {
        this.sumDimensions = dimension.dimensions.map(d => d.name)
        // A pair only becomes a condition when both of its dimensions are reduced
        this.conditions = AdCallsSpecialCase.SPECIAL_DIMENSIONS.filter(sd => difference(sd.dimensions, this.sumDimensions).length === 0)
        // Both dimensions of an active condition are handled by the key suffix, not by the sum dimensions
        this.conditions.forEach(sd => {
          sd.dimensions.forEach(d => {
            const i = this.sumDimensions.indexOf(d)
            if (i !== -1) {
              this.sumDimensions.splice(i, 1)
            }
          })
        })
        // Always remove the first dimension of every pair, even when its partner is absent
        AdCallsSpecialCase.SPECIAL_DIMENSIONS.forEach(sd => {
          const i = this.sumDimensions.indexOf(sd.dimensions[0])
          if (i !== -1) {
            this.sumDimensions.splice(i, 1)
          }
        })
      } else {
        this.sumDimensions = !AdCallsSpecialCase.SPECIAL_DIMENSIONS.some(sd => sd.dimensions[0] === dimension.name) ? [dimension.name] : []
      }
    }
  }

  add (p: Record<string, any>, key: string, v: Record<string, any>): void {
    this.accumulate(p, v, v.adcalls)
  }

  remove (p: Record<string, any>, key: string, v: Record<string, any>): void {
    this.accumulate(p, v, -v.adcalls)
  }

  /**
   * Rebuilds the bucket key from the group's sum-dimension values plus its value of each
   * condition's first dimension, and exposes that bucket's sum (0 if absent) as `$adcalls`.
   */
  post (scope: Record<string, any>, p: Record<string, any>): void {
    let sumKey = this.getSumKey(p)
    this.conditions.forEach(c => {
      sumKey += '~' + p[c.dimensions[0]]
    })
    const globalSum = this.globalSum[sumKey] || 0
    scope.$adcalls = globalSum
  }

  /** Adds `amount` to every bucket that data line `v` contributes to within group record `p`. */
  private accumulate (p: Record<string, any>, v: Record<string, any>, amount: number): void {
    const sumKey = this.getSumKey(p)
    if (this.conditions.length === 0) {
      this.addToGlobalSum(sumKey, amount)
      return
    }
    // Duplicates in a list are dropped: a line whose list contains a value twice still
    // "contains" it once, so it must be counted once for that value
    const conditionValues = this.conditions.map(c =>
      Array.from(new Set((v[c.dimensions[1]] as string).split(c.separator)))
    )
    cartesian(...conditionValues).forEach(combination => {
      this.addToGlobalSum(sumKey + '~' + combination, amount)
    })
  }
}

/**
 * Global sum of one data field (`metric`), exposed to the formula under the variable `name`.
 * Sums are bucketed by the reduced dimension(s) minus `ignoreDimensions`.
 * Data lines for which `condition` returns false contribute 0 instead of their value.
 */
class SumIgnoreSpecialCase extends GlobalSumSpecialCase implements SpecialCase {
  constructor (
    private readonly name: string,
    private readonly metric: string,
    dimension?: Dimension,
    ignoreDimensions: string[] = [],
    private readonly condition: (v: Record<string, any>) => boolean = () => true
  ) {
    super()
    if (dimension) {
      const reducedNames = dimension instanceof CompositeDimension
        ? dimension.dimensions.map(d => d.name)
        : [dimension.name]
      this.sumDimensions = difference(reducedNames, ignoreDimensions)
    } else {
      this.sumDimensions = []
    }
  }

  add (p: Record<string, any>, key: string, v: Record<string, any>): void {
    const bucket = this.getSumKey(p)
    let amount = 0
    if (this.condition(v)) {
      amount = v[this.metric]
    }
    this.addToGlobalSum(bucket, amount)
  }

  remove (p: Record<string, any>, key: string, v: Record<string, any>): void {
    const bucket = this.getSumKey(p)
    let amount = 0
    if (this.condition(v)) {
      amount = -v[this.metric]
    }
    this.addToGlobalSum(bucket, amount)
  }

  post (scope: Record<string, any>, p: Record<string, any>): void {
    const total = this.globalSum[this.getSumKey(p)]
    scope[this.name] = total || 0
  }
}

/** Only Adagio lines count towards the ad-request special cases below; other lines contribute 0. */
const isAdagioLine = (v: Record<string, any>): boolean => v.bidder === 'adagio'

/**
 * `$adrequests`: global sum of the `adrequests` field over Adagio lines,
 * bucketed by the reduced dimensions minus `bidder`.
 */
class AdRequestsSpecialCase extends SumIgnoreSpecialCase {
  constructor (dimension?: Dimension) {
    super('$adrequests', 'adrequests', dimension, ['bidder'], isAdagioLine)
  }
}

/**
 * `$totaladrequests`: global sum of the `adrequests` field over Adagio lines,
 * bucketed by the reduced dimensions minus `mvt` and `bidder`.
 */
class TotalAdRequestsSpecialCase extends SumIgnoreSpecialCase {
  constructor (dimension?: Dimension) {
    super('$totaladrequests', 'adrequests', dimension, ['mvt', 'bidder'], isAdagioLine)
  }
}

/** Formula variables computed by a special case instead of a plain per-group sum. */
const SPECIAL_CASES: Record<string, SpecialCaseConstructor> = {
  $adcalls: AdCallsSpecialCase,
  $adrequests: AdRequestsSpecialCase,
  $totaladrequests: TotalAdRequestsSpecialCase
}
