package cs241e.assignments

import ProgramRepresentation._
import cs241e.scanparse.Grammars._

import scala.collection.mutable

/** Implementation of semantic analysis for the Lacs language. */

object Typer {
  /** Representation of a Lacs type, which is either an Int or a function type with parameter types
    * and a return type.
    */
  sealed abstract class Type

  /** The Lacs integer type. */
  case object IntType extends Type

  /** A Lacs function type that takes arguments of `parameterTypes` and returns a value of `returnType`. */
  case class FunctionType(parameterTypes: Seq[Type], returnType: Type) extends Type

  /** Given a `tree`, finds all descendants of the `tree` whose root node has kind `lhsKind`.
    * Does not search within the found subtrees for any nested occurrences of additional descendants.
    * For example, searching the root of a program tree with `lhsKind = "defdef"` will return the trees of all
    * of the top-level procedures, but not any procedures nested within them.
    * Note that if the root of `tree` itself has kind `lhsKind`, the result is just `Seq(tree)`.
    *
    * @param tree    the tree to search
    * @param lhsKind the kind of root node to look for
    * @return the outermost matching subtrees, in left-to-right order
    */
  def collect(tree: Tree, lhsKind: String): Seq[Tree] =
    if (tree.lhs.kind == lhsKind) Seq(tree)
    else tree.children.flatMap(child => collect(child, lhsKind))

  /** Given a tree that is either a "type" or contains exactly one "type" nested within it, returns
    * an instance of `Type` representing the corresponding type.
    *
    * `collect` stops at the first "type" it finds on each path, so types nested inside a function
    * type do not count toward the "exactly one" requirement checked below.
    *
    * @throws IllegalArgumentException if `tree` does not contain exactly one outermost "type"
    */
  def parseType(tree: Tree): Type = {
    val types = collect(tree, "type")
    require(types.size == 1)
    // TODO: convert the single "type" subtree into an `IntType` or a `FunctionType`.
    ???
  }

  /** Pairs a `Variable` with the Lacs `Type` it was declared with.
    *
    * @param variable the underlying variable
    * @param tpe      its declared type
    */
  case class TypedVariable(
    variable: Variable,
    tpe: Type
  )

  /** Builds a fresh `Variable` called `name` for a value of type `tpe`.
    * Every type other than `IntType` (i.e. a function type) makes the variable a pointer.
    */
  def makeVariable(name: String, tpe: Type): Variable = {
    val holdsPointer = tpe match {
      case IntType => false
      case _       => true
    }
    new Variable(name, isPointer = holdsPointer)
  }

  /** Maps every name in scope either to the `TypedVariable` it denotes (`Left`) or to the
    * `ProcedureScope` of the procedure it names (`Right`).
    */
  type SymbolTable =
    Map[String, Either[TypedVariable, ProcedureScope]]

  /** Given a tree containing subtrees rooted at "vardef", creates a `TypedVariable` for each such tree.
    * Only the outermost "vardef" subtrees are considered (see `collect`).
    */
  def parseVarDefs(tree: Tree): Seq[TypedVariable] = {
    collect(tree, "vardef").map { varDef =>
      // TODO: build a `TypedVariable` from the name and declared type in `varDef`.
      ???
    }
  }

  /** Calls `sys.error` (throwing a `RuntimeException`) if any `String` occurs in `names` multiple times.
    * Each repeated name is listed once in the error message, however many times it occurs.
    *
    * @param names the identifiers to check
    */
  def checkDuplicates(names: Seq[String]): Unit = {
    // `diff` removes one occurrence of each distinct name, leaving only the surplus occurrences;
    // the final `distinct` collapses a name repeated three or more times into a single report.
    val duplicates = names.diff(names.distinct).distinct
    if (duplicates.nonEmpty) sys.error(s"Duplicate identifiers ${duplicates}")
  }

  /** A `ProcedureScope` holds the semantic information about a particular procedure that is needed to type-check
    * the body of the procedure, including information coming from outer procedure(s) within which this
    * procedure may be nested.
    *
    * @param tree  the tree defining the procedure (rooted at a "defdef")
    * @param outer the `ProcedureScope` of the outer procedure that immediately contains this one,
    *              or `None` (the default) when there is no enclosing procedure
    */
  class ProcedureScope(tree: Tree, outer: Option[ProcedureScope] = None) {
    assert(tree.production ==
      "defdef DEF ID LPAREN parmsopt RPAREN COLON type BECOMES LBRACE vardefsopt defdefsopt expras RBRACE")
    // Bind the 13 children of the "defdef" production asserted above; keyword/punctuation tokens are ignored.
    val Seq(_, id, _, parmsopt, _, _, retTypeTree, _, _, vardefs, defdefs, expras, _) = tree.children

    /** The name of the procedure. */
    val name: String = ???

    /** The parameters of the procedure. */
    val parms: Seq[TypedVariable] = ???

    /** The variables declared in the procedure. */
    val variables: Seq[TypedVariable] = ???

    /** The declared return type of the procedure. */
    val returnType: Type = ???

    /** The new `Procedure` object that will represent this procedure. */
    val procedure: Procedure = ???

    /** The `ProcedureScope`s of the nested procedures that are immediately nested within this procedure.
      * Note: this `val` will recursively call `new ProcedureScope(...)`.
      */
    val subProcedures: Seq[ProcedureScope] = ???

    /** The names of parameters, variables, and nested procedures that are newly defined within this procedure
      * (as opposed to being inherited from some outer procedure).
      */
    val newNames: Seq[String] = ???

    /** The symbol table to be used when type-checking the body of this procedure. It should contain all
      * symbols (parameters, variables, nested procedures) defined in this procedure, as well as those
      * defined in outer procedures within which this one is nested. Symbols defined in this procedure
      * override (shadow) those of outer procedures.
      * The symbol table is first initialized to null, and filled in later, after this `ProcedureScope`
      * has been constructed, by calling the `createSymbolTable` method. This is necessary
      * because computation of the `symbolTable` depends on the `symbolTable`s of any
      * outer procedures within which this procedure is nested. However, their `symbolTable`s contain
      * the `ProcedureScope`s for the procedures nested within them, which are not created until after
      * the enclosing procedure is created.
      */
    var symbolTable: SymbolTable = null

    /** Create a symbol table for the current procedure, and update the `symbolTable` field with it.
      * See the comments for `var symbolTable` for details on what the `symbolTable` should contain.
      * The `outerSymbolTable` parameter contains the symbol table of the enclosing scope (either an
      * outer procedure within which the current procedure is nested, or, if the current procedure
      * is a top-level procedure, a symbol table containing the names of all of the top-level procedures).
      * This method must also recursively call `createSymbolTable` on all of its `subProcedures` to
      * construct their symbol tables as well.
      */
    def createSymbolTable(outerSymbolTable: SymbolTable): Unit = ???

    /** Returns a sequence containing `this` `ProcedureScope` and the `ProcedureScope`s for all procedures
      * declared inside of this procedure, including those nested recursively within other nested procedures.
      * Scala hint: the `flatMap` method on Scala collections is useful here; see the `flatMap` entry
      * for `Seq` in the Scala standard library API documentation.
      */
    def descendantScopes: Seq[ProcedureScope] = ???
  }

  /** Checks that the body of a procedure satisfies the type-checking rules in the Lacs language specification.
    * Returns a `Map` that provides a `Type` for each `Tree` that has a `Type` according to the language
    * specification.
    *
    * @param scope the `ProcedureScope` of the procedure whose body is checked
    * @return an immutable snapshot of every `Tree` -> `Type` assignment recorded during checking
    */

  def typeCheck(scope: ProcedureScope): Map[Tree, Type] = {
    /** The map that will be returned containing the `Type` of each `Tree` that has a `Type`. */
    val treeToType = mutable.Map[Tree, Type]()

    /** Calls `sys.error` if `tpe1` and `tpe2` are not equal. If they are equal, returns them.
      * NOTE(review): the message labels `tpe2` as "expected" and `tpe1` as "got", but the call below
      * passes the declared return type as `tpe1`, so a mismatch there is reported with the labels swapped.
      */
    def mustEqual(tpe1: Type, tpe2: Type): Type =
      if(tpe1 == tpe2) tpe1 else sys.error(s"Type mismatch: expected $tpe2, got $tpe1")

    /** For a `tree` rooted at a node that has a `Type`, computes the `Type`, adds it to `treeToType`,
      * and returns it. `getOrElseUpdate` takes its default by name, so the type is computed only the
      * first time a given `tree` is seen; later calls return the cached entry.
      */
    def typeOf(tree: Tree): Type = treeToType.getOrElseUpdate(tree, ???)

    /* Check that the type of the expression returned from the procedure matches the declared type of the procedure. */
    mustEqual(scope.returnType, typeOf(scope.expras))
    // Copy into an immutable Map so callers cannot mutate the checker's working table.
    Map() ++ treeToType