"""
Fossati AI Bot - Motor SMC (Smart Money Concepts)
Detecta: Order Blocks, BOS, CHoCH, FVG, Liquidez, Estrutura
"""

import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import Optional, List, Tuple
from enum import Enum
import logging

log = logging.getLogger(__name__)


class Direction(Enum):
    """Directional bias used for trends, order blocks, FVGs and signals."""

    BULLISH = "bullish"  # upward bias
    BEARISH = "bearish"  # downward bias
    NEUTRAL = "neutral"  # no bias / undetermined


@dataclass
class OrderBlock:
    """A single order-block candle and its detection metadata."""

    direction: Direction
    # OHLC of the order-block candle itself.
    high: float
    low: float
    open: float
    close: float
    index: int                  # row position of the candle in the analysed frame
    timestamp: pd.Timestamp
    mitigated: bool = False     # True once price has traded back through the block
    strength: float = 0.0       # 0-1, strength of the order block
    is_breaker: bool = False    # order block that turned into a breaker block


@dataclass
class FairValueGap:
    """A fair value gap (price imbalance) bounded by ``bottom`` and ``top``."""

    direction: Direction
    top: float                  # upper edge of the gap
    bottom: float               # lower edge of the gap
    index: int                  # row position of the middle candle of the pattern
    timestamp: pd.Timestamp
    filled: bool = False        # True once later price has traded into the gap


@dataclass
class LiquidityLevel:
    """A resting-liquidity price level (e.g. equal highs / equal lows)."""

    level_type: str             # one of "EQH", "EQL", "BSL", "SSL"
    price: float
    count: int                  # how many times price touched this level
    index: int                  # row position of the first touch
    swept: bool = False         # True once price has run through the level


@dataclass
class MarketStructure:
    """Summary of market structure: trend, last structural breaks and swings.

    ``last_bos`` / ``last_choch`` are dicts describing the latest break of
    structure / change of character, or ``None`` when there is none.
    ``swing_highs`` / ``swing_lows`` hold swing-point prices, or ``None`` when
    not populated (hence ``Optional``: the default really is ``None``).
    """

    trend: Direction
    last_bos: Optional[dict] = None
    last_choch: Optional[dict] = None
    swing_highs: Optional[List[float]] = None
    swing_lows: Optional[List[float]] = None
    current_phase: str = "unknown"    # "expansion", "retracement", "consolidation"


@dataclass
class SMCAnalysis:
    """Full Smart Money Concepts snapshot for one symbol on one timeframe."""

    symbol: str
    timeframe: str
    timestamp: pd.Timestamp
    structure: MarketStructure
    # Detected zones, split by direction.
    bullish_obs: List[OrderBlock]
    bearish_obs: List[OrderBlock]
    bullish_fvg: List[FairValueGap]
    bearish_fvg: List[FairValueGap]
    liquidity_levels: List[LiquidityLevel]
    current_price: float
    # Order blocks closest to current_price (None when there is none).
    nearest_bullish_ob: Optional[OrderBlock] = None
    nearest_bearish_ob: Optional[OrderBlock] = None
    in_discount: bool = False    # below 50% of the range
    in_premium: bool = False     # above 50% of the range
    premium_discount_level: float = 0.0    # the 50% (equilibrium) price itself


class SMCAnalyzer:
    """Detects Smart Money Concepts features on an OHLC candle DataFrame.

    The ``config`` object is read for these attributes: ``SWING_LOOKBACK``,
    ``BOS_MIN_BREAK``, ``OB_LOOKBACK``, ``OB_MIN_BODY_RATIO`` and
    ``EQUAL_LEVEL_TOLERANCE``.
    """

    def __init__(self, config):
        self.cfg = config

    def analyze(self, df: pd.DataFrame, symbol: str, timeframe: str) -> Optional[SMCAnalysis]:
        """Run every detector on ``df`` and bundle the results.

        Args:
            df: candles with ``open``/``high``/``low``/``close`` columns, ordered
                so that the last row is the most recent candle. An optional
                ``timestamp`` column is used to stamp the result and zones.
            symbol: instrument label copied into the result.
            timeframe: timeframe label copied into the result.

        Returns:
            The combined ``SMCAnalysis``, or ``None`` (with a warning logged)
            when fewer than 50 candles are supplied.
        """
        if len(df) < 50:
            log.warning("%s %s: dados insuficientes (%d candles)", symbol, timeframe, len(df))
            return None

        df = df.copy().reset_index(drop=True)
        current_price = float(df["close"].iloc[-1])

        structure = self._detect_structure(df)
        bull_obs = self._detect_order_blocks(df, Direction.BULLISH)
        bear_obs = self._detect_order_blocks(df, Direction.BEARISH)
        bull_fvg = self._detect_fvg(df, Direction.BULLISH)
        bear_fvg = self._detect_fvg(df, Direction.BEARISH)
        liquidity = self._detect_liquidity(df)
        pd_level = self._get_premium_discount(df)

        return SMCAnalysis(
            symbol=symbol,
            timeframe=timeframe,
            timestamp=self._timestamp_at(df, -1),
            structure=structure,
            bullish_obs=bull_obs,
            bearish_obs=bear_obs,
            bullish_fvg=bull_fvg,
            bearish_fvg=bear_fvg,
            liquidity_levels=liquidity,
            current_price=current_price,
            # Order blocks nearest to the current price on each side.
            nearest_bullish_ob=self._nearest_ob(bull_obs, current_price, Direction.BULLISH),
            nearest_bearish_ob=self._nearest_ob(bear_obs, current_price, Direction.BEARISH),
            in_discount=current_price < pd_level,
            in_premium=current_price > pd_level,
            premium_discount_level=pd_level,
        )

    @staticmethod
    def _timestamp_at(df: pd.DataFrame, idx: int):
        """Return row ``idx`` of the ``timestamp`` column, or now() if the column is absent."""
        if "timestamp" in df.columns:
            return df["timestamp"].iloc[idx]
        return pd.Timestamp.now()

    def _detect_structure(self, df: pd.DataFrame) -> MarketStructure:
        """Derive swings, trend, last BOS/CHoCH and market phase from ``df``.

        - Swing high/low: a candle whose high/low is the max/min of the window
          of ``lb`` candles on either side, ``lb = min(SWING_LOOKBACK, n // 4)``.
        - Trend: BULLISH on higher high + higher low between the last two
          swings, BEARISH on lower high + lower low, otherwise NEUTRAL.
        - BOS: last close beyond the latest swing high/low by more than
          ``BOS_MIN_BREAK`` (relative).
        - CHoCH: a BOS against the current trend (stored as its own dict).
        - Phase: 20-candle range vs. mean of the last 10 valid 14-period ATR
          values (>15x expansion, <5x consolidation, else retracement;
          "unknown" with < 20 candles or no valid ATR).
        """
        highs = df["high"].to_numpy(dtype=float)
        lows = df["low"].to_numpy(dtype=float)
        closes = df["close"].to_numpy(dtype=float)
        n = len(df)
        lb = min(self.cfg.SWING_LOOKBACK, n // 4)

        swing_highs: List[Tuple[int, float]] = []
        swing_lows: List[Tuple[int, float]] = []
        for i in range(lb, n - lb):
            window = slice(i - lb, i + lb + 1)
            if highs[i] == highs[window].max():
                swing_highs.append((i, float(highs[i])))
            if lows[i] == lows[window].min():
                swing_lows.append((i, float(lows[i])))

        trend = Direction.NEUTRAL
        last_bos: Optional[dict] = None
        last_choch: Optional[dict] = None

        if len(swing_highs) >= 2 and len(swing_lows) >= 2:
            (_, prev_high), (high_idx, last_high) = swing_highs[-2:]
            (_, prev_low), (low_idx, last_low) = swing_lows[-2:]

            if last_high > prev_high and last_low > prev_low:      # HH + HL
                trend = Direction.BULLISH
            elif last_high < prev_high and last_low < prev_low:    # LH + LL
                trend = Direction.BEARISH

            # BOS: the latest close breaks the most recent swing level.
            close = closes[-1]
            min_break = self.cfg.BOS_MIN_BREAK
            if close > last_high * (1 + min_break):
                last_bos = {"direction": "bullish", "level": last_high, "index": high_idx}
            elif close < last_low * (1 - min_break):
                last_bos = {"direction": "bearish", "level": last_low, "index": low_idx}

            # CHoCH: a BOS against the prevailing trend. Build a new dict so
            # that last_bos is neither aliased nor mutated.
            if last_bos is not None:
                if trend == Direction.BEARISH and last_bos["direction"] == "bullish":
                    last_choch = {**last_bos, "type": "choch_bullish"}
                elif trend == Direction.BULLISH and last_bos["direction"] == "bearish":
                    last_choch = {**last_bos, "type": "choch_bearish"}

        phase = "unknown"
        if n >= 20:
            recent_atr = self._atr(df, 14)[-10:]
            # Drop rolling-window warm-up NaNs; a NaN mean would make every
            # comparison below False and silently report "retracement".
            recent_atr = recent_atr[~np.isnan(recent_atr)]
            if recent_atr.size:
                avg_atr = recent_atr.mean()
                price_range = highs[-20:].max() - lows[-20:].min()
                if price_range > avg_atr * 15:
                    phase = "expansion"
                elif price_range < avg_atr * 5:
                    phase = "consolidation"
                else:
                    phase = "retracement"

        return MarketStructure(
            trend=trend,
            last_bos=last_bos,
            last_choch=last_choch,
            swing_highs=[price for _, price in swing_highs],
            swing_lows=[price for _, price in swing_lows],
            current_phase=phase,
        )

    def _detect_order_blocks(self, df: pd.DataFrame, direction: Direction) -> List[OrderBlock]:
        """Return up to 5 unmitigated order blocks of ``direction``, strongest first.

        Candidates are the candles from the second-to-last back to index 2,
        within ``OB_LOOKBACK``, whose body is at least ``OB_MIN_BODY_RATIO`` of
        their range. A bullish OB is a down candle (close < open) after which
        the high of the next 5 candles exceeds its high by more than 0.5%;
        a bearish OB mirrors this. An OB is mitigated once any later candle
        trades beyond its far side (below the low for bullish, above the high
        for bearish). ``strength`` is the follow-through move x20, capped at 1.
        Any other ``direction`` yields an empty list.
        """
        if direction not in (Direction.BULLISH, Direction.BEARISH):
            return []
        bullish = direction == Direction.BULLISH

        opens = df["open"].to_numpy(dtype=float)
        highs = df["high"].to_numpy(dtype=float)
        lows = df["low"].to_numpy(dtype=float)
        closes = df["close"].to_numpy(dtype=float)
        n = len(df)
        lookback = min(self.cfg.OB_LOOKBACK, n - 1)

        obs: List[OrderBlock] = []
        # Walk backwards from the second-to-last candle: the last candle has
        # no follow-through to measure yet.
        for offset in range(1, lookback):
            idx = n - 1 - offset
            if idx < 2:
                break  # idx only decreases from here on
            rng = highs[idx] - lows[idx]
            if rng == 0:
                continue
            if abs(closes[idx] - opens[idx]) / rng < self.cfg.OB_MIN_BODY_RATIO:
                continue

            follow = slice(idx + 1, idx + 6)  # up to 5 candles of follow-through
            if bullish:
                if not closes[idx] < opens[idx]:
                    continue
                move = (np.nanmax(highs[follow]) - highs[idx]) / highs[idx]
                mitigated = np.nanmin(lows[idx + 1:]) < lows[idx]
            else:
                if not closes[idx] > opens[idx]:
                    continue
                move = (lows[idx] - np.nanmin(lows[follow])) / lows[idx]
                mitigated = np.nanmax(highs[idx + 1:]) > highs[idx]

            # Positive form so a NaN move is rejected; mitigated OBs are dropped.
            if not move > 0.005 or mitigated:
                continue
            obs.append(OrderBlock(
                direction=direction,
                high=float(highs[idx]),
                low=float(lows[idx]),
                open=float(opens[idx]),
                close=float(closes[idx]),
                index=idx,
                timestamp=self._timestamp_at(df, idx),
                mitigated=False,
                strength=float(min(move * 20, 1.0)),
            ))

        # Stable sort: ties keep the most-recent-first discovery order.
        obs.sort(key=lambda ob: ob.strength, reverse=True)
        return obs[:5]

    def _detect_fvg(self, df: pd.DataFrame, direction: Direction) -> List[FairValueGap]:
        """Return up to the 5 most recent unfilled fair value gaps.

        A bullish FVG at candle ``i`` exists when high[i-1] < low[i+1]; its
        zone is [high[i-1], low[i+1]] and it counts as filled once any candle
        from ``i+2`` on has a low at or below the zone top. A bearish FVG
        (any ``direction`` other than BULLISH) exists when low[i-1] > high[i+1]
        and is filled once a later high reaches the zone bottom.
        """
        highs = df["high"].to_numpy(dtype=float)
        lows = df["low"].to_numpy(dtype=float)
        n = len(df)
        if n < 3:
            return []
        bullish = direction == Direction.BULLISH

        # suffix_low[k] == min(lows[k:]) and suffix_high[k] == max(highs[k:]),
        # NaN-skipping like pandas; avoids rescanning the tail for every gap.
        suffix_low = np.fmin.accumulate(lows[::-1])[::-1]
        suffix_high = np.fmax.accumulate(highs[::-1])[::-1]

        fvgs: List[FairValueGap] = []
        for i in range(1, n - 1):
            if bullish:
                bottom, top = highs[i - 1], lows[i + 1]
                if not bottom < top:
                    continue
                filled = i + 2 < n and suffix_low[i + 2] <= top
            else:
                top, bottom = lows[i - 1], highs[i + 1]
                if not top > bottom:
                    continue
                filled = i + 2 < n and suffix_high[i + 2] >= bottom
            if filled:
                continue
            fvgs.append(FairValueGap(
                direction=Direction.BULLISH if bullish else Direction.BEARISH,
                top=float(top),
                bottom=float(bottom),
                index=i,
                timestamp=self._timestamp_at(df, i),
                filled=False,
            ))
        return fvgs[-5:]

    def _detect_liquidity(self, df: pd.DataFrame) -> List[LiquidityLevel]:
        """Return the 10 most recent unswept equal-high / equal-low levels.

        Candle ``i`` forms a level when a candle ``j`` in ``i+2 .. i+29`` has a
        high (EQH) or low (EQL) within ``EQUAL_LEVEL_TOLERANCE`` (relative) of
        it. A level is swept once any candle after ``j`` trades beyond it by
        more than the tolerance. Swept levels and duplicates are dropped.
        """
        tol = self.cfg.EQUAL_LEVEL_TOLERANCE
        levels = (
            self._equal_levels(df["high"].to_numpy(dtype=float), tol, "EQH", sweep_above=True)
            + self._equal_levels(df["low"].to_numpy(dtype=float), tol, "EQL", sweep_above=False)
        )

        unique: List[LiquidityLevel] = []
        seen = set()
        # Order by candle index so the final slice keeps the most recent levels
        # of both kinds, instead of letting EQL crowd out every EQH.
        for lv in sorted(levels, key=lambda level: level.index):
            if lv.swept:
                continue
            # Type + significant digits: round(price, 4) merged every level of
            # sub-0.0001 priced assets and let an EQH hide an EQL at that price.
            key = (lv.level_type, float(f"{lv.price:.10g}"))
            if key in seen:
                continue
            seen.add(key)
            unique.append(lv)
        return unique[-10:]

    @staticmethod
    def _equal_levels(prices: np.ndarray, tol: float, level_type: str,
                      sweep_above: bool) -> List[LiquidityLevel]:
        """Pair each price with its first near-equal match 2-29 candles later."""
        found: List[LiquidityLevel] = []
        n = len(prices)
        for i in range(n):
            ref = prices[i]
            for j in range(i + 2, min(i + 30, n)):
                if abs(ref - prices[j]) / ref < tol:
                    later = prices[j + 1:]
                    if sweep_above:
                        swept = bool((later > ref * (1 + tol)).any())
                    else:
                        swept = bool((later < ref * (1 - tol)).any())
                    found.append(LiquidityLevel(level_type, float(ref), 2, i, swept=swept))
                    break
        return found

    def _nearest_ob(self, obs: List[OrderBlock], price: float, direction: Direction) -> Optional[OrderBlock]:
        """Return the order block closest to ``price`` on its trading side.

        BULLISH: the OB lying entirely below price with the highest ``high``.
        Any other direction: the OB lying entirely above price with the lowest
        ``low``. Returns ``None`` when no OB qualifies.
        """
        if direction == Direction.BULLISH:
            below = [ob for ob in obs if ob.high < price]
            return max(below, key=lambda ob: ob.high, default=None)
        above = [ob for ob in obs if ob.low > price]
        return min(above, key=lambda ob: ob.low, default=None)

    def _get_premium_discount(self, df: pd.DataFrame) -> float:
        """Return the midpoint of the high-low range of the last (up to) 100 candles."""
        n = min(100, len(df))
        window = df.iloc[-n:]
        return (window["high"].max() + window["low"].min()) / 2

    def _atr(self, df: pd.DataFrame, period: int = 14) -> np.ndarray:
        """Return a simple-moving-average ATR; the first ``period - 1`` values are NaN."""
        h = df["high"].to_numpy(dtype=float)
        l = df["low"].to_numpy(dtype=float)
        c = df["close"].to_numpy(dtype=float)
        prev_close = np.roll(c, 1)
        tr = np.maximum(h - l, np.maximum(np.abs(h - prev_close), np.abs(l - prev_close)))
        tr[0] = h[0] - l[0]  # first candle has no previous close (roll wrapped around)
        return pd.Series(tr).rolling(period).mean().to_numpy()

    def score_setup(self, analysis: SMCAnalysis, signal_direction: Direction,
                    volume_ratio: float = 1.0, rsi: float = 50.0,
                    htf_trend: Direction = Direction.NEUTRAL) -> Tuple[int, List[str]]:
        """Score a setup against the 12-item Fossati confluence checklist.

        Each satisfied item adds one point; every item appends a label
        (✅ met, ❌ not met, ⚠️ partial/pending).

        Args:
            analysis: result of ``analyze`` for the entry timeframe.
            signal_direction: BULLISH for a long, BEARISH for a short.
            volume_ratio: current volume relative to its average.
            rsi: current RSI value.
            htf_trend: higher-timeframe trend.

        Returns:
            ``(score, confluences)``, with ``score`` capped at 12.
        """
        score = 0
        confluences: List[str] = []
        wanted = signal_direction.value  # "bullish" / "bearish" / "neutral"
        is_long = signal_direction == Direction.BULLISH
        structure = analysis.structure
        price = analysis.current_price

        # 1. Higher-timeframe trend aligned with the signal.
        if htf_trend == signal_direction:
            score += 1
            confluences.append("✅ HTF trend alinhado")
        else:
            confluences.append("❌ HTF trend não alinhado")

        # 2. BOS or CHoCH in the signal's direction.
        bos, choch = structure.last_bos, structure.last_choch
        if bos:
            if bos["direction"] == wanted:
                score += 1
                confluences.append("✅ BOS confirmado")
            else:
                confluences.append("❌ BOS contrário")
        elif choch:
            # A CHoCH only counts when it points the same way as the signal.
            if choch.get("direction") == wanted:
                score += 1
                confluences.append("✅ CHoCH confirmado")
            else:
                confluences.append("❌ CHoCH contrário")
        else:
            confluences.append("❌ Sem BOS/CHoCH")

        # 3. Strong order block at the entry side.
        ob = analysis.nearest_bullish_ob if is_long else analysis.nearest_bearish_ob
        if ob is not None and ob.strength > 0.5:
            score += 1
            confluences.append(f"✅ OB forte (força: {ob.strength:.1%})")
        else:
            confluences.append("❌ OB fraco ou ausente")

        # 4. Unfilled FVG in the signal's direction.
        if analysis.bullish_fvg if is_long else analysis.bearish_fvg:
            score += 1
            confluences.append("✅ FVG presente")
        else:
            confluences.append("❌ Sem FVG")

        # 5. Liquidity target: equal highs above (long) / equal lows below (short).
        #    The label reports the nearest such level.
        levels = analysis.liquidity_levels
        if is_long:
            bsl = [lv.price for lv in levels if lv.level_type == "EQH" and lv.price > price]
            if bsl:
                score += 1
                confluences.append(f"✅ BSL acima em {min(bsl):.4f}")
            else:
                confluences.append("❌ Sem BSL acima")
        else:
            ssl = [lv.price for lv in levels if lv.level_type == "EQL" and lv.price < price]
            if ssl:
                score += 1
                confluences.append(f"✅ SSL abaixo em {max(ssl):.4f}")
            else:
                confluences.append("❌ Sem SSL abaixo")

        # 6. Market structure trend aligned with the signal.
        if structure.trend == signal_direction:
            score += 1
            confluences.append("✅ Estrutura favorável")
        else:
            confluences.append("❌ Estrutura contrária")

        # 7. Volume above average (only > 1.3x scores).
        if volume_ratio > 1.3:
            score += 1
            confluences.append(f"✅ Volume {volume_ratio:.1f}x acima da média")
        elif volume_ratio > 1.0:
            confluences.append(f"⚠️ Volume levemente acima ({volume_ratio:.1f}x)")
        else:
            confluences.append(f"❌ Volume abaixo da média ({volume_ratio:.1f}x)")

        # 8. RSI in a favourable, non-extreme zone.
        if signal_direction == Direction.BULLISH and 25 < rsi < 55:
            score += 1
            confluences.append(f"✅ RSI em zona de compra ({rsi:.1f})")
        elif signal_direction == Direction.BEARISH and 45 < rsi < 75:
            score += 1
            confluences.append(f"✅ RSI em zona de venda ({rsi:.1f})")
        else:
            confluences.append(f"❌ RSI fora de zona ({rsi:.1f})")

        # 9. Premium/discount position: buy in discount, sell in premium.
        if signal_direction == Direction.BULLISH and analysis.in_discount:
            score += 1
            confluences.append("✅ Preço em desconto (zona de compra)")
        elif signal_direction == Direction.BEARISH and analysis.in_premium:
            score += 1
            confluences.append("✅ Preço em prêmio (zona de venda)")
        else:
            confluences.append("❌ Preço em posição desfavorável no range")

        # 10. Tradeable market phase.
        phase = structure.current_phase
        if phase in ("expansion", "retracement"):
            score += 1
            confluences.append(f"✅ Fase de mercado: {phase}")
        elif phase == "consolidation":
            confluences.append("❌ Mercado em consolidação")
        else:
            confluences.append(f"❌ Fase de mercado indefinida ({phase})")

        # 11. Minimum 2:1 R:R. Placeholder: the target is fixed at twice the
        #     stop distance (stop = far side of the OB), so rr is 2.0 whenever
        #     the stop sits on the protective side of price, else 0. Per the
        #     original note, the real R:R is calculated externally.
        if ob is not None:
            risk = price - ob.low if is_long else ob.high - price
            rr = 2.0 if risk > 0 else 0.0
            if rr >= 2.0:
                score += 1
                confluences.append(f"✅ R:R de {rr:.1f}:1")
            else:
                confluences.append(f"❌ R:R insuficiente ({rr:.1f}:1)")
        else:
            confluences.append("❌ R:R não calculado (sem OB)")

        # 12. Price-action confirmation: assumed valid when at least 9 of the
        #     other items are already met.
        if score >= 9:
            score += 1
            confluences.append("✅ Price Action confirma setup")
        else:
            confluences.append("⚠️ Aguardando confirmação de PA")

        return min(score, 12), confluences
