import { createToken, CstChildrenDictionary, CstNode, CstParser, IToken, Lexer as ChevrotainLexer } from "chevrotain";
import { Decimal } from "./decimal";
import {
  CalculatorError,
  DivisionByZeroError,
  InfinityError,
  InvalidInputError,
  NaNError,
  NegativeBaseError,
} from "./errors";

// Lexer

// Whitespace is matched but skipped, so it never reaches the parser.
const Whitespace = createToken({ name: "Whitespace", pattern: /\s+/, group: ChevrotainLexer.SKIPPED });

// Abstract category (pattern NA: never matched on its own) shared by Plus and
// Minus, so the parser can CONSUME either one as a binary additive operator.
const AdditiveOperator = createToken({ name: "AdditiveOperator", pattern: ChevrotainLexer.NA });
const Plus = createToken({ name: "Plus", pattern: /\+/, categories: [AdditiveOperator] });
// Accepts both ASCII hyphen-minus and U+2212 MINUS SIGN.
const Minus = createToken({ name: "Minus", pattern: /[-−]/, categories: [AdditiveOperator] });

// Abstract category shared by Multiply and Divide.
const MultiplicativeOperator = createToken({ name: "MultiplicativeOperator", pattern: ChevrotainLexer.NA });
// `*` or U+00D7 MULTIPLICATION SIGN.
const Multiply = createToken({ name: "Multiply", pattern: /[*×]/, categories: [MultiplicativeOperator] });
// `/` or U+00F7 DIVISION SIGN.
const Divide = createToken({ name: "Divide", pattern: /[/÷]/, categories: [MultiplicativeOperator] });

// Exponentiation: `^`, `**`, or a doubled multiplication sign `××`.
const Power = createToken({ name: "Power", pattern: /\^|\*\*|××/ });
// Postfix percent sign.
const Percent = createToken({ name: "Percent", pattern: "%" });

// Decimal number: `12`, `12.`, `12.5` or `.5`. It may be followed by an
// ellipsis `…` and then by an optional exponent such as `e-7`. The ellipsis is
// how a rounded previous result is displayed; the evaluator accepts such a
// literal only if it equals that rendered result.
const NumberLiteral = createToken({ name: "NumberLiteral", pattern: /(\d+\.\d*|\.?\d+)…?([eE][-+]?\d+)?/ });
const LeftParenthesis = createToken({ name: "LeftParenthesis", pattern: "(" });
const RightParenthesis = createToken({ name: "RightParenthesis", pattern: ")" });

// Lexer vocabulary. ExpressionParser also receives it via `super(tokens)`.
// Chevrotain tries patterns in this order and keeps the first match (not the
// longest), so order matters wherever patterns overlap.
const tokens = [
  Whitespace,
  Power, // before Multiply to gobble up "**" and "××"
  Plus,
  Minus,
  Multiply,
  Divide,
  Percent,
  NumberLiteral,
  LeftParenthesis,
  RightParenthesis,

  // Category tokens have no pattern of their own. They are listed so they are
  // part of the vocabulary the parser CONSUMEs from.
  AdditiveOperator,
  MultiplicativeOperator,
];

// Shared lexer instance, used by `lex`.
export const ExpressionLexer = new ChevrotainLexer(tokens);

/**
 * Tokenize `input` with ExpressionLexer. Whitespace tokens are skipped.
 *
 * @param input Raw calculator input.
 * @returns The lexed tokens.
 * @throws The error built by `CalculatorError.fromLexingError` for the first
 *   lexing error, if there is one.
 */
export function lex(input: string): IToken[] {
  const { tokens: lexed, errors } = ExpressionLexer.tokenize(input);
  if (errors.length === 0) {
    return lexed;
  }
  throw CalculatorError.fromLexingError(errors[0], input);
}

/**
 * Drop an incomplete tail so a just-in-time result can be shown while the user
 * is still typing. For example, `1+2+` becomes `1+2`, and so does `1+2*-(`.
 *
 * The function first strips any run of trailing `+`, `-` and `(` tokens, then
 * at most one trailing `*`, `/` or power operator. It returns a new array and
 * leaves `tokens` unmodified.
 */
export function fixupTrailingOperators(tokens: IToken[]): IToken[] {
  const prefixLike = new Set(["Plus", "Minus", "LeftParenthesis"]);
  const binaryOnly = new Set(["Multiply", "Divide", "Power"]);

  let end = tokens.length;
  while (end > 0 && prefixLike.has(tokens[end - 1]?.tokenType.name)) {
    end--;
  }
  // Plus and Minus were already consumed above, so only the purely binary
  // operators are left to check here.
  if (end > 0 && binaryOnly.has(tokens[end - 1]?.tokenType.name)) {
    end--;
  }
  return tokens.slice(0, end);
}

/**
 * Append synthetic `)` tokens so that every `(` in `tokens` is closed.
 *
 * This works on tokens rather than on the source string because it is applied
 * after fixupTrailingOperators. It returns a new array and leaves `tokens`
 * unmodified.
 */
export function fixupTrailingParentheses(tokens: IToken[]): IToken[] {
  const output = tokens.slice();

  // Net count of "(" minus ")".
  let depth = 0;
  output.forEach((token) => {
    const kind = token.tokenType.name;
    if (kind === "LeftParenthesis") {
      depth += 1;
    } else if (kind === "RightParenthesis") {
      depth -= 1;
    }
  });

  if (depth <= 0) {
    return output;
  }

  // Every synthetic token reuses the last real token's location rather than
  // pointing one past it, because that is where any error should be reported.
  const anchor = output[output.length - 1];
  for (let added = 0; added < depth; added++) {
    output.push({
      image: ")",
      tokenType: RightParenthesis,
      tokenTypeIdx: RightParenthesis.tokenTypeIdx!,
      startOffset: anchor.startOffset,
      endOffset: anchor.endOffset,
      startLine: anchor.startLine,
      endLine: anchor.endLine,
      startColumn: anchor.startColumn,
      endColumn: anchor.endColumn,
    });
  }
  return output;
}

// Parser

// Operator precedence, from lowest to highest
// + -
// * /
// **
// + - (unary)
// % (unary)

/**
 * Chevrotain CST parser for calculator expressions.
 *
 * There is one rule per precedence level (see the table above), lowest first.
 * The LABELs (`lhs`, `rhs`, `operator`, `operand`) name the CST children that
 * ExpressionEvaluator and topLevelBinaryOperation read. Unlabelled children
 * are keyed by their token or rule name (e.g. `NumberLiteral`,
 * `parenthesisExpression`).
 */
export class ExpressionParser extends CstParser {
  constructor() {
    super(tokens);
    // Class field initializers (the RULEs below) run as soon as `super()`
    // returns, so every rule is already defined when self-analysis runs.
    this.performSelfAnalysis();
  }

  // Entry point; `parse` calls `expression()`.
  public expression = this.RULE("expression", () => {
    this.SUBRULE(this.additionExpression);
  });

  // lhs (("+" | "-") rhs)*: a flat, left-to-right list of binary +/-.
  private additionExpression = this.RULE("additionExpression", () => {
    this.SUBRULE(this.multiplicationExpression, { LABEL: "lhs" });
    this.MANY(() => {
      this.CONSUME(AdditiveOperator, { LABEL: "operator" });
      this.SUBRULE2(this.multiplicationExpression, { LABEL: "rhs" });
    });
  });

  // lhs (("*" | "/") rhs)*: a flat, left-to-right list of binary * and /.
  private multiplicationExpression = this.RULE("multiplicationExpression", () => {
    this.SUBRULE(this.powerExpression, { LABEL: "lhs" });
    this.MANY(() => {
      this.CONSUME(MultiplicativeOperator, { LABEL: "operator" });
      this.SUBRULE2(this.powerExpression, { LABEL: "rhs" });
    });
  });

  // lhs (Power rhs)*: parsed as a flat list. The evaluator applies it
  // right-associatively. Each operand is a negationExpression, so `2^-3`
  // parses and `-2^2` means `(-2)^2`.
  private powerExpression = this.RULE("powerExpression", () => {
    this.SUBRULE(this.negationExpression, { LABEL: "lhs" });
    this.MANY(() => {
      this.CONSUME(Power, { LABEL: "operator" });
      this.SUBRULE2(this.negationExpression, { LABEL: "rhs" });
    });
  });

  // Any number of prefix signs (via recursion), then a percentExpression.
  // Both alternatives label their sub-rule `operand`, so the evaluator reads
  // `ctx.operand` in either case. `ctx.operator` is present only for a sign.
  private negationExpression = this.RULE("negationExpression", () => {
    this.OR([
      {
        ALT: () => {
          this.OR2([
            { ALT: () => this.CONSUME(Plus, { LABEL: "operator" }) },
            { ALT: () => this.CONSUME(Minus, { LABEL: "operator" }) },
          ]);
          this.SUBRULE(this.negationExpression, { LABEL: "operand" });
        },
      },
      { ALT: () => this.SUBRULE(this.percentExpression, { LABEL: "operand" }) },
    ]);
  });

  // A primary expression, optionally followed by a postfix `%`.
  private percentExpression = this.RULE("percentExpression", () => {
    this.SUBRULE(this.primaryExpression, { LABEL: "operand" });
    this.OPTION(() => {
      this.CONSUME(Percent, { LABEL: "operator" });
    });
  });

  // A number literal or a parenthesized sub-expression.
  private primaryExpression = this.RULE("primaryExpression", () => {
    this.OR([{ ALT: () => this.CONSUME(NumberLiteral) }, { ALT: () => this.SUBRULE(this.parenthesisExpression) }]);
  });

  // "(" expression ")": recurses back to the lowest-precedence rule.
  private parenthesisExpression = this.RULE("parenthesisExpression", () => {
    this.CONSUME(LeftParenthesis);
    this.SUBRULE(this.expression);
    this.CONSUME(RightParenthesis);
  });
}

// A single shared parser instance. `parse` reassigns its `input` on every call
// and then reads its `errors`. The evaluator's visitor base class is also
// derived from this instance.
const expressionParser = new ExpressionParser();

/**
 * Find the binary operation to repeat when `=` is pressed again (e.g.
 * `10+2===` yields `16`).
 *
 * The function descends through levels that have no operator. It returns the
 * outermost `additionExpression` or `multiplicationExpression` node with at
 * least one operator, or a `powerExpression` node with exactly one power
 * operator. Otherwise it returns `null`.
 */
export function topLevelBinaryOperation(cst: CstNode): CstNode | null {
  const { name, children } = cst;

  if (name === "expression") {
    return topLevelBinaryOperation(children.additionExpression[0] as CstNode);
  }

  if (name === "additionExpression" || name === "multiplicationExpression") {
    // A chain with an operator is itself the operation to repeat; otherwise
    // look inside its only operand.
    const hasOperator = children.operator !== undefined;
    return hasOperator ? cst : topLevelBinaryOperation(children.lhs[0] as CstNode);
  }

  if (name === "powerExpression") {
    // Repeating stacked exponentiation (like `2^3^4`) would have to account
    // for right-associativity, so only a single power operator is supported.
    return children.operator?.length === 1 ? cst : null;
  }

  return null;
}

/**
 * Parse an expression into a CST rooted at the `expression` rule.
 *
 * @param input Source text. It is lexed when `tokens` is omitted and is always
 *   used for error reporting.
 * @param tokens Pre-lexed (possibly fixed-up) tokens to parse instead of
 *   lexing `input`.
 * @returns The CST produced by the shared parser.
 * @throws The error built by `CalculatorError.fromLexingError` /
 *   `CalculatorError.fromParsingError` for the first lexing or parsing error.
 */
export function parse(input: string, tokens?: IToken[]): CstNode {
  expressionParser.input = tokens === undefined ? lex(input) : tokens;
  const tree = expressionParser.expression();

  const { errors } = expressionParser;
  if (errors.length === 0) {
    return tree;
  }
  throw CalculatorError.fromParsingError(errors[0], input);
}

// Evaluator

// RenderedResult is used during evaluation for restoring a rounded rendering of
// the previous result to full precision. For example, if we press `1/3=`, the
// calculator might display "0.3333…" (that is, rounded to well below
// Decimal.precision and ending in an ellipsis). Then, if we type `*3=`, the
// evaluator is asked to evaluate "0.3333…*3". To do this, when it encounters the number
// literal "0.3333…" (stored in `renderedResult.rendered`), it substitutes the
// full-precision Decimal object (stored in `renderedResult.decimal`). The
// result of this full-precision operation then rounds to "1" rather than
// "0.9999", giving the illusion of lossless computation.
export type RenderedResult = {
  // The full-precision value.
  decimal: Decimal;
  // The text displayed for `decimal`. It may be rounded and end in "…", and
  // the evaluator matches a literal against it exactly.
  rendered: string;
};
// A Calculation is a RenderedResult, plus the CST that produced it.
export type Calculation = RenderedResult & {
  // NOTE(review): presumably null when no CST is available for this result —
  // confirm with the code that constructs Calculations.
  cst: CstNode | null;
};

// Base class for CST visitors over this grammar. The first type argument is
// the extra per-visit parameter type; no visitor method here takes one. The
// second is what each rule method returns: a Decimal.
const BaseCstVisitor = expressionParser.getBaseCstVisitorConstructor<any, Decimal>();

// Percent-ness is tracked per Decimal *instance*: the evaluator marks the
// Decimal produced by a `%` (and its negation), and binary +/- checks for the
// mark. A module-private WeakSet keeps the flag off the Decimal objects
// themselves, so there are no `any` casts and no foreign properties on library
// instances, and marked values can still be garbage-collected.
const percentValues = new WeakSet<object>();

/** Whether `n` is exactly an instance previously passed to markPercent. */
function isPercent(n: Decimal): boolean {
  return percentValues.has(n);
}

/** Flag the Decimal instance `n` as a percentage. Copies are not flagged. */
function markPercent(n: Decimal): void {
  percentValues.add(n);
}

/**
 * CST visitor that computes the value of a parsed expression.
 *
 * Each public visit method matches the ExpressionParser rule of the same name
 * and reads that rule's labelled children (`lhs`, `rhs`, `operator`,
 * `operand`). A method that applies no operator returns its operand's Decimal
 * instance unchanged. That is what lets the percent mark set by
 * `percentExpression` reach the enclosing `additionExpression`.
 */
class ExpressionEvaluator extends BaseCstVisitor {
  /**
   * @param previousResult The last result and its rendering. It is used to
   *   resolve a rounded literal ending in "…" to full precision; `null` if
   *   there is none.
   */
  constructor(public previousResult: RenderedResult | null) {
    super();
    this.validateVisitor();
  }

  public expression(ctx: CstChildrenDictionary): Decimal {
    return this.visit(ctx.additionExpression as CstNode[]);
  }

  /**
   * Left-to-right chain of binary `+` and `-`. A percent-marked right operand
   * is taken as a percentage of the running total, so `100+10%` is 110 and
   * `100-10%` is 90.
   */
  public additionExpression(ctx: CstChildrenDictionary): Decimal {
    let result = this.visit(ctx.lhs as CstNode[]);
    if (ctx.operator === undefined) {
      return result;
    }
    for (let i = 0; i < ctx.operator.length; i++) {
      const operator = ctx.operator[i] as IToken;
      let rhs = this.visit(ctx.rhs[i] as CstNode);
      if (isPercent(rhs)) {
        // `a ± b%` means `a ± a*(b/100)`. Use `times`, like the rest of
        // this class.
        rhs = result.times(rhs);
      }
      if (operator.tokenType.name === "Plus") {
        result = result.plus(rhs);
      } else if (operator.tokenType.name === "Minus") {
        result = result.minus(rhs);
      } else {
        throw new Error(`Unreachable: operator ${operator.image}`);
      }
    }
    return result;
  }

  /**
   * Left-to-right chain of `*` and `/`.
   *
   * @throws DivisionByZeroError when dividing by zero.
   */
  public multiplicationExpression(ctx: CstChildrenDictionary): Decimal {
    let result = this.visit(ctx.lhs as CstNode[]);
    if (ctx.operator === undefined) {
      return result;
    }
    for (let i = 0; i < ctx.operator.length; i++) {
      const operator = ctx.operator[i] as IToken;
      const rhs = this.visit(ctx.rhs[i] as CstNode);
      if (operator.tokenType.name === "Multiply") {
        result = result.times(rhs);
      } else if (operator.tokenType.name === "Divide") {
        if (rhs.isZero()) {
          throw new DivisionByZeroError();
        }
        result = result.dividedBy(rhs);
      } else {
        throw new Error(`Unreachable: operator ${operator.image}`);
      }
    }
    return result;
  }

  /**
   * Exponentiation, evaluated right-associatively: `2^3^2` is `2^(3^2)`.
   *
   * @throws DivisionByZeroError for zero raised to a negative power.
   * @throws NegativeBaseError for a negative base with a non-integer exponent.
   */
  public powerExpression(ctx: CstChildrenDictionary): Decimal {
    // The grammar is left-associative to avoid backtracking, but we evaluate
    // right-associative.
    const operands = (ctx.lhs as CstNode[]).concat((ctx.rhs ?? []) as CstNode[]).map((node) => this.visit(node));
    if (operands.length === 0) {
      throw new Error("Unreachable: empty power expression");
    }
    let result = operands[operands.length - 1];
    for (let i = operands.length - 2; i >= 0; i--) {
      const base = operands[i];
      const exponent = result;
      // Decimal's isNegative() is true for -0, but `0^-0` is 1 and `(-0)^0.5`
      // is 0. Only strictly negative values should trigger the errors below.
      const exponentBelowZero = exponent.isNegative() && !exponent.isZero();
      const baseBelowZero = base.isNegative() && !base.isZero();
      if (base.isZero() && exponentBelowZero) {
        throw new DivisionByZeroError();
      } else if (baseBelowZero && !exponent.isInteger()) {
        throw new NegativeBaseError();
      }
      result = base.pow(exponent);
    }
    return result;
  }

  /** Unary `+` (no-op) or `-`. Negating keeps any percent mark. */
  public negationExpression(ctx: CstChildrenDictionary): Decimal {
    let result = this.visit(ctx.operand as CstNode[]);
    if (ctx.operator !== undefined) {
      const operator = ctx.operator[0] as IToken;
      if (operator.tokenType.name === "Plus") {
        // do nothing
      } else if (operator.tokenType.name === "Minus") {
        const wasPercent = isPercent(result);
        result = result.negated();
        if (wasPercent) {
          markPercent(result);
        }
      } else {
        throw new Error(`Unreachable: operator ${operator.image}`);
      }
    }
    return result;
  }

  /** Postfix `%`: divides by 100 and marks the result as a percentage. */
  public percentExpression(ctx: CstChildrenDictionary): Decimal {
    let result = this.visit(ctx.operand as CstNode[]);
    if (ctx.operator !== undefined) {
      const operator = ctx.operator[0] as IToken;
      if (operator.tokenType.name === "Percent") {
        result = result.dividedBy(100);
        markPercent(result);
      } else {
        throw new Error(`Unreachable: operator ${operator.image}`);
      }
    }
    return result;
  }

  public primaryExpression(ctx: CstChildrenDictionary): Decimal {
    if (ctx.NumberLiteral !== undefined) {
      const numberLiteral = ctx.NumberLiteral[0] as IToken;
      return this.parseNumberLiteral(numberLiteral);
    } else if (ctx.parenthesisExpression !== undefined) {
      return this.visit(ctx.parenthesisExpression as CstNode[]);
    } else {
      throw new Error("Unreachable: unexpected primary expression");
    }
  }

  /**
   * Convert a NumberLiteral token to a Decimal.
   *
   * A literal containing "…" is accepted only if it equals
   * `previousResult.rendered`. In that case the full-precision
   * `previousResult.decimal` is substituted.
   *
   * @throws InvalidInputError for any other literal containing "…". The error
   *   covers the token's span.
   */
  public parseNumberLiteral(numberLiteral: IToken): Decimal {
    const literal = numberLiteral.image;
    if (literal.includes("…")) {
      if (this.previousResult === null || literal !== this.previousResult.rendered) {
        throw new InvalidInputError({
          start: numberLiteral.startOffset,
          end: (numberLiteral.endOffset ?? numberLiteral.startOffset) + 1, // IToken.endOffset is inclusive
        });
      }
      const { decimal } = this.previousResult;
      // The previous result stands for a plain number. If it still carries a
      // percent mark (e.g. it came from `(1/3)%`), substitute an unmarked copy.
      // Otherwise `5+0.003333…` would be read as `5+5*0.003333…`.
      return isPercent(decimal) ? new Decimal(decimal) : decimal;
    } else {
      return new Decimal(literal);
    }
  }

  public parenthesisExpression(ctx: CstChildrenDictionary): Decimal {
    return this.visit(ctx.expression as CstNode[]);
  }
}

/**
 * Evaluate a CST produced by `parse`.
 *
 * @param cst Root `expression` node.
 * @param previousResult Last displayed result. It is used to resolve a rounded
 *   literal ending in "…" back to full precision; `null` if there is none.
 * @returns The value as a plain Decimal that never carries a percent mark.
 * @throws NaNError if the result is NaN.
 * @throws InfinityError if the result is ±Infinity.
 * Errors raised while visiting (division by zero, negative base, invalid
 * input) propagate unchanged.
 */
export function evaluate(cst: CstNode, previousResult: RenderedResult | null): Decimal {
  const result = new ExpressionEvaluator(previousResult).visit(cst);
  if (result.isNaN()) {
    throw new NaNError();
  } else if (result.eq(new Decimal("Infinity")) || result.eq(new Decimal("-Infinity"))) {
    throw new InfinityError();
  }
  // For an input like `(1/3)%`, the visitor returns the percent-marked Decimal
  // itself. If a caller reuses this value as `previousResult` (as the
  // RenderedResult comment describes), the evaluator would treat it as a
  // percentage again: `5+0.003333…` would become `5+5*0.003333…`. Returning a
  // fresh copy drops the mark.
  return isPercent(result) ? new Decimal(result) : result;
}
