use base::{
    fnv::{FnvMap, FnvSet},
    symbol::{Symbol, SymbolRef},
};
use crate::core::{
    is_primitive,
    optimize::{walk_expr, SameLifetime, Visitor},
    Allocator, CExpr, Expr, Named,
};
/// Abstract cost unit used by the cost analysis.
///
/// Costs are combined with saturating addition. `Cost::max_value()` is used both for symbols that
/// have no recorded cost and as an effectively-infinite charge for self-referential or cyclic
/// bindings.
pub type Cost = u32;
/// Cost-analysis results for a single symbol.
///
/// The derived `Default` is `cost: 0, uses: 0`.
#[derive(Clone, Default, Debug)]
pub struct Data {
    /// The estimated cost associated with the symbol (saturating at `Cost::max_value()`).
    pub cost: Cost,
    /// How many identifier references to the symbol were counted.
    pub uses: u32,
}
/// Table of per-symbol cost-analysis results, keyed by symbol.
#[derive(Default, Debug)]
pub struct Costs(FnvMap<Symbol, Data>);
impl Costs {
    pub fn cost(&self, s: &SymbolRef) -> Cost {
        self.0.get(s).map_or_else(Cost::max_value, |x| x.cost)
    }
    pub fn uses(&self, s: &SymbolRef) -> u32 {
        self.0.get(s).map_or(0, |x| x.uses)
    }
    pub fn data(&self, s: &SymbolRef) -> Data {
        self.0.get(s).cloned().unwrap_or_default()
    }
    pub fn insert(&mut self, s: Symbol, data: Data) {
        self.0.insert(s, data);
    }
}
/// Visitor state for computing expression costs and per-symbol use counts.
struct AnalyzeCost<'a, 'b> {
    /// Bindings whose binding cost receives an additional `Cost::max_value()` penalty.
    cyclic_bindings: &'b FnvSet<&'a SymbolRef>,
    /// Output table of per-symbol costs and use counts.
    costs: &'b mut Costs,
    /// Symbols whose bound expression is currently being analyzed; a reference to one of these
    /// is charged `Cost::max_value()`.
    bind_stack: FnvSet<&'a SymbolRef>,
    /// Stack of cost accumulators: `add_cost` adds to the last one, and `cost_of` pushes a fresh
    /// one for the duration of a sub-expression.
    current: Vec<Cost>,
}
impl<'a> AnalyzeCost<'a, '_> {
    /// Adds `cost` to the innermost accumulator, saturating instead of overflowing.
    ///
    /// Panics if the accumulator stack is empty.
    fn add_cost(&mut self, cost: Cost) {
        let top = self.current.last_mut().unwrap();
        *top = top.saturating_add(cost);
    }

    /// Measures `expr` in a fresh accumulator and returns its cost.
    ///
    /// The enclosing accumulator is not modified; callers decide how to charge the result.
    fn cost_of(&mut self, expr: CExpr<'a>) -> Cost {
        self.current.push(0);
        let _ = self.visit_expr(expr);
        self.current.pop().unwrap()
    }

    /// Computes, records and returns the cost of the expression bound to `name`.
    ///
    /// While `expr` is analyzed, `name` is on `bind_stack`, so any reference to it inside `expr`
    /// is charged `Cost::max_value()`. Names in `cyclic_bindings` additionally get a
    /// `Cost::max_value()` penalty. The saturated total is stored as the binding's cost.
    fn push_bind(&mut self, name: &'a Symbol, expr: CExpr<'a>) -> Cost {
        let key = &**name;
        self.bind_stack.insert(key);

        let base = self.cost_of(expr);
        let penalty = if self.cyclic_bindings.contains(key) {
            Cost::max_value()
        } else {
            0
        };
        let total = base.saturating_add(penalty);

        self.costs.0.entry(name.clone()).or_default().cost = total;
        self.bind_stack.remove(key);
        total
    }
}
impl<'a> Visitor<'a, 'a> for AnalyzeCost<'a, '_> {
    type Producer = SameLifetime<'a>;
    fn visit_expr(&mut self, expr: &'a Expr<'a>) -> Option<&'a Expr<'a>> {
        match *expr {
            Expr::Let(ref bind, body) => {
                match &bind.expr {
                    Named::Recursive(closures) => {
                        for closure in closures {
                            let cost = self.push_bind(&closure.name.name, &closure.expr);
                            self.add_cost(cost);
                        }
                    }
                    Named::Expr(expr) => {
                        let cost = self.push_bind(&bind.name.name, expr);
                        {
                            let data = self.costs.0.get_mut(&bind.name.name).unwrap();
                            data.cost = data.cost.saturating_add(1000);
                        }
                        self.add_cost(cost);
                    }
                }
                self.visit_expr(body);
            }
            Expr::Match(body, alts) => {
                self.add_cost(5);
                self.visit_expr(body);
                let alt_cost = alts
                    .iter()
                    .map(|alt| self.cost_of(alt.expr))
                    .max()
                    .unwrap_or(0);
                self.add_cost(alt_cost);
            }
            Expr::Call(Expr::Ident(id, ..), ..) if is_primitive(&id.name) => {
                self.add_cost(2);
                walk_expr(self, expr);
            }
            Expr::Call(..) => {
                self.add_cost(5);
                walk_expr(self, expr);
            }
            Expr::Ident(ref id, _) if self.bind_stack.contains(&*id.name) => {
                self.costs
                    .0
                    .entry(id.name.clone())
                    .or_insert_with(|| Data {
                        uses: 0,
                        cost: Cost::max_value(),
                    })
                    .uses += 1;
                self.add_cost(Cost::max_value());
            }
            Expr::Ident(ref id, _) => {
                self.costs
                    .0
                    .entry(id.name.clone())
                    .or_insert_with(|| Data {
                        uses: 0,
                        cost: Cost::max_value(),
                    })
                    .uses += 1;
                self.add_cost(1);
            }
            Expr::Const(..) => self.add_cost(1),
            Expr::Data(..) => {
                self.add_cost(5);
                walk_expr(self, expr);
            }
            Expr::Cast(..) => {
                walk_expr(self, expr);
            }
        }
        None
    }
    fn detach_allocator(&self) -> Option<&'a Allocator<'a>> {
        None
    }
}
/// Runs the cost analysis over `expr` and returns the per-symbol cost and use table.
///
/// Symbols listed in `cyclic_bindings` have `Cost::max_value()` added to their binding cost.
pub(crate) fn analyze_costs<'a>(
    cyclic_bindings: &FnvSet<&'a SymbolRef>,
    expr: CExpr<'a>,
) -> Costs {
    let mut table = Costs::default();
    AnalyzeCost {
        cyclic_bindings,
        costs: &mut table,
        bind_stack: FnvSet::default(),
        // Outermost accumulator; its final value is discarded.
        current: vec![0],
    }
    .visit_expr(expr);
    table
}