use base::{
    fnv::FnvMap,
    symbol::{Symbol, SymbolRef},
};
use crate::core::{
    optimize::{walk_expr, DifferentLifetime, Visitor},
    Allocator, CExpr, Expr, Named, Pattern,
};
/// How pure an expression has been found to be.
///
/// The derived `Ord` follows declaration order, so
/// `None < Load < Call`; `merge` relies on that ordering.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum Pureness {
    /// Lowest level: nothing is recorded for symbols whose analysis ends here.
    None,
    /// Middle level: enough for `PurityMap::pure_load`, not for `PurityMap::pure_call`.
    Load,
    /// Highest level: the only level accepted by `PurityMap::pure_call`.
    Call,
}
impl Pureness {
    /// Lowers `self` to `pureness` when `pureness` is the smaller of the two.
    ///
    /// `self` is left untouched when it already is the smaller (or equal) value,
    /// so repeated merging can only ever move towards `Pureness::None`.
    fn merge(&mut self, pureness: Pureness) {
        if pureness < *self {
            *self = pureness;
        }
    }
}
/// Maps symbols to the pureness recorded for them by the analysis in this file.
///
/// A symbol without an entry has no recorded pureness.
#[derive(Clone, Default, Debug)]
pub struct PurityMap(FnvMap<Symbol, Pureness>);
impl PurityMap {
    /// Returns `true` only when `k` has an entry and that entry is `Pureness::Call`.
    pub fn pure_call(&self, k: &SymbolRef) -> bool {
        self.0.get(k) == Some(&Pureness::Call)
    }
    /// Returns `true` when `k` has any entry at all, whatever its recorded level.
    pub fn pure_load(&self, k: &SymbolRef) -> bool {
        self.0.contains_key(k)
    }
}
pub fn purity<'a>(expr: CExpr<'a>) -> PurityMap {
    let mut pure_symbols = PurityMap(FnvMap::default());
    let mut visitor = Pure {
        is_pure: Pureness::Call,
        pure_symbols: &mut pure_symbols,
    };
    visitor.visit_expr(expr);
    pure_symbols
}
/// Visitor state for the purity analysis.
struct Pure<'b> {
    /// Pureness determined so far for the expression currently being visited.
    is_pure: Pureness,
    /// Map that collects the pureness recorded for individual symbols.
    pure_symbols: &'b mut PurityMap,
}
impl Pure<'_> {
    /// Analyses `expr` in isolation and returns the pureness found for it.
    ///
    /// The analysis runs in a fresh visitor that starts at `Pureness::Call` and
    /// shares this visitor's symbol map. Unless the result is `Pureness::None`,
    /// it is also recorded in the map under `symbol`.
    fn is_pure(&mut self, symbol: &Symbol, expr: CExpr) -> Pureness {
        let pureness = {
            let mut nested = Pure {
                is_pure: Pureness::Call,
                pure_symbols: self.pure_symbols,
            };
            nested.visit_expr(expr);
            nested.is_pure
        };
        match pureness {
            // Impure results leave no entry behind.
            Pureness::None => (),
            Pureness::Load | Pureness::Call => {
                self.pure_symbols.0.insert(symbol.clone(), pureness);
            }
        }
        pureness
    }
    /// Records every symbol bound by `pat` as `Pureness::Load`.
    ///
    /// Record fields use their alias (the second tuple element) when one is
    /// present and the field's own name otherwise. Literal patterns bind nothing.
    fn mark_pure(&mut self, pat: &Pattern) {
        let known = &mut self.pure_symbols.0;
        let mut mark_load = |name: &Symbol| {
            known.insert(name.clone(), Pureness::Load);
        };
        match pat {
            Pattern::Ident(id) => mark_load(&id.name),
            Pattern::Record { fields, .. } => {
                for field in fields {
                    let bound = match &field.1 {
                        Some(alias) => alias,
                        None => &field.0.name,
                    };
                    mark_load(bound);
                }
            }
            Pattern::Constructor(_, params) => {
                for param in params {
                    mark_load(&param.name);
                }
            }
            Pattern::Literal(_) => (),
        }
    }
}
impl<'l, 'expr> Visitor<'l, 'expr> for Pure<'_> {
    type Producer = DifferentLifetime<'l, 'expr>;
    /// Visits `expr`, adjusting `self.is_pure` and recording bindings in
    /// `self.pure_symbols` along the way.
    ///
    /// Always returns `None`.
    fn visit_expr(&mut self, expr: CExpr<'expr>) -> Option<CExpr<'l>> {
        match *expr {
            Expr::Call(ref f, _) => {
                // A call keeps its pureness only when the callee is an identifier
                // that is recorded as `Pureness::Call` or is primitive.
                let callee_is_pure = match f {
                    Expr::Ident(ref id, ..) => {
                        self.pure_symbols.pure_call(&*id.name) || id.name.is_primitive()
                    }
                    _ => false,
                };
                if callee_is_pure {
                    walk_expr(self, expr);
                } else {
                    self.is_pure = Pureness::None;
                }
            }
            Expr::Ident(ref id, ..) => {
                // Identifiers without an entry that are neither primitive nor
                // global cap the pureness at `Pureness::Load`.
                let recognised = self.pure_symbols.pure_load(&id.name)
                    || id.name.is_primitive()
                    || id.name.is_global();
                if !recognised {
                    self.is_pure.merge(Pureness::Load);
                }
            }
            Expr::Let(ref bind, body) => {
                match bind.expr {
                    Named::Recursive(ref closures) => {
                        for closure in closures {
                            // Arguments are recorded before the closure body is analysed.
                            for arg in &closure.args {
                                self.pure_symbols.0.insert(arg.name.clone(), Pureness::Load);
                            }
                            // The closure's own result is recorded (if not `None`)
                            // but is not merged into the enclosing expression.
                            self.is_pure(&closure.name.name, closure.expr);
                        }
                    }
                    Named::Expr(bound) => {
                        let bound_pureness = self.is_pure(&bind.name.name, bound);
                        self.is_pure.merge(bound_pureness);
                    }
                }
                self.visit_expr(body);
            }
            Expr::Match(scrutinee, alts) => {
                // Measure the scrutinee on its own, starting from `Pureness::Call`,
                // then fold the previously accumulated pureness back in.
                let outer = std::mem::replace(&mut self.is_pure, Pureness::Call);
                self.visit_expr(scrutinee);
                let scrutinee_pureness = self.is_pure;
                self.is_pure.merge(outer);
                // All patterns are marked before any alternative body is visited.
                if scrutinee_pureness != Pureness::None {
                    for alt in alts {
                        self.mark_pure(&alt.pattern);
                    }
                }
                for alt in alts {
                    self.visit_expr(alt.expr);
                }
            }
            _ => {
                walk_expr(self, expr);
            }
        }
        None
    }
    /// Always returns `None`.
    fn detach_allocator(&self) -> Option<&'l Allocator<'l>> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use base::symbol::Symbols;
    use crate::core::interpreter::tests::parse_expr;
    /// Analysing `let x = global in x` must record exactly one entry:
    /// `x` with `Pureness::Load`.
    #[test]
    fn pure_global() {
        let mut symbols = Symbols::new();
        let allocator = Arc::new(Allocator::new());
        let expr = parse_expr(&mut symbols, &allocator, "let x = global in x");
        // Compare the whole map so that any additional entry fails the test.
        assert_eq!(
            purity(expr).0,
            collect![(symbols.simple_symbol("x"), Pureness::Load)]
        );
    }
}