import json
import logging
import os

import pymysql
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from schemas import DiagnoseRequest
from tb_knowledge import run_rule_engine
from ml_model import train_model_from_db, answers_to_features, predict_ml_model

# Application object; the route decorators further down register handlers on it.
app = FastAPI()

# CORS is configured fully open: every origin, method, and header is allowed.
# NOTE(review): consider restricting allow_origins before exposing this service publicly.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

def get_db():
    """Open and return a new PyMySQL connection to the application database.

    Connection settings are taken from the environment variables
    ``COPD_TB_DB_HOST``, ``COPD_TB_DB_USER``, ``COPD_TB_DB_PASSWORD``,
    ``COPD_TB_DB_NAME`` and ``COPD_TB_DB_PORT``. Any variable that is not set
    falls back to the value that used to be hard-coded here, so existing
    deployments keep working unchanged.

    The caller owns the returned connection and is responsible for closing it.

    Raises:
        ValueError: if ``COPD_TB_DB_PORT`` is set but is not an integer.
    """
    # NOTE(review): the fallback credentials below are committed to source
    # control. Rotate them and drop the defaults once every environment
    # supplies its own values.
    env = os.environ
    return pymysql.connect(
        host=env.get("COPD_TB_DB_HOST", "103.216.116.241"),
        user=env.get("COPD_TB_DB_USER", "root"),
        password=env.get("COPD_TB_DB_PASSWORD", "Tung_2002@"),
        database=env.get("COPD_TB_DB_NAME", "copd_tb_db"),
        port=int(env.get("COPD_TB_DB_PORT", "13306")),
    )

def threshold_logic(answers):
    """Flag conditions whose count of positive answers reaches a threshold.

    Each item in ``answers`` must expose ``questionnaire_id`` and ``answer``.
    Truthy answers to questions with id <= 7 count toward ``"lao"`` (flagged at
    2 or more); truthy answers to questions with id >= 8 count toward
    ``"copd"`` (flagged at 3 or more).

    Returns:
        A list containing ``"lao"`` and/or ``"copd"``, in that order; empty
        when neither threshold is reached.
    """
    # (label, question-id membership test, minimum number of positive answers)
    rules = (
        ("lao", lambda qid: qid <= 7, 2),
        ("copd", lambda qid: qid >= 8, 3),
    )

    flagged = []
    for label, in_group, minimum in rules:
        hits = 0
        for entry in answers:
            if in_group(entry.questionnaire_id) and entry.answer:
                hits += 1
        if hits >= minimum:
            flagged.append(label)
    return flagged


@app.get("/questions")
def get_questions():
    """Return every row of the ``questions`` table.

    Returns:
        A list of ``{"id": ..., "question": ...}`` dicts, one per row, in the
        order the database returns them.
    """
    conn = get_db()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id, question FROM questions")
            rows = cur.fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()
    return [{"id": question_id, "question": text} for question_id, text in rows]

def _ml_labels(features, training_rows):
    """Train on ``training_rows`` and return the ML labels for ``features``.

    This step is best-effort: it returns an empty list when there is no
    training data or when training/prediction raises, so the rule-based and
    threshold results are still reported.
    """
    if not training_rows:
        return []
    try:
        train_model_from_db(training_rows)
        prediction = predict_ml_model(features)
    except Exception:
        # Keep the request alive, but leave a traceback for whoever debugs it.
        logging.getLogger(__name__).exception(
            "ML prediction failed; continuing without it"
        )
        return []
    try:
        labels = [part.strip() for part in prediction.split(",")]
    except Exception:
        logging.getLogger(__name__).exception(
            "ML prediction had an unexpected shape; continuing without it"
        )
        return []
    # Drop blanks and the sentinel the pipeline uses for "not enough ML data".
    return [label for label in labels if label and label != "Không đủ dữ liệu ML"]


@app.post("/diagnose")
def diagnose(req: DiagnoseRequest):
    """Combine rule-based, threshold, and ML results into one diagnosis.

    The request is stored in ``results`` with a ``'_pending_'`` diagnosis,
    previously finished rows are loaded as training data for the ML model,
    and the pending row is then updated with the combined result.

    Returns:
        A dict with the sorted, de-duplicated labels from each method
        (``rule_based``, ``logic_threshold``, ``ml_prediction``), their union
        (``final_diagnosis``), and user-facing ``status`` and ``suggestion``
        strings.
    """
    rule_based = run_rule_engine(req.answers)
    threshold_based = threshold_logic(req.answers)
    features = answers_to_features(req.answers)

    # Store the submission as pending, then fetch finished rows for training.
    conn = get_db()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO results (name, age, gender, answers, diagnosis) VALUES (%s, %s, %s, %s, %s)",
                (req.user.name, req.user.age, req.user.gender, json.dumps(features), "_pending_"),
            )
            conn.commit()

            cur.execute("SELECT answers, diagnosis FROM results WHERE diagnosis != '_pending_'")
            history = cur.fetchall()
    finally:
        # Release the connection even when a statement raises.
        conn.close()

    training_rows = [{"answers": row[0], "diagnosis": row[1]} for row in history]
    ml_result = _ml_labels(features, training_rows)

    # Combine the three methods into the final diagnosis.
    final = set(rule_based + threshold_based + ml_result)
    if not final:
        status = "Không đủ dấu hiệu rõ ràng để nghi ngờ Lao hoặc COPD"
        suggestion = "Bạn nên theo dõi thêm và thực hiện lại khảo sát sau một thời gian nếu có triệu chứng mới."
        final_result = "Không đủ dữ liệu ML"
    else:
        status = "Phân tích hoàn tất"
        suggestion = "Bạn nên thực hiện xét nghiệm chẩn đoán chuyên sâu với cơ sở y tế."
        final_result = ", ".join(sorted(final))

    # Write the final diagnosis back onto the pending row.
    # NOTE(review): this matches every pending row with the same name and age,
    # so concurrent or previously failed requests for a same-named user are
    # overwritten too. Prefer targeting the inserted row by its primary key
    # (e.g. cursor.lastrowid) once the results table's key column is confirmed.
    conn = get_db()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE results SET diagnosis = %s WHERE diagnosis = '_pending_' AND name = %s AND age = %s",
                (final_result, req.user.name, req.user.age),
            )
            conn.commit()
    finally:
        conn.close()

    return {
        "rule_based": sorted(set(rule_based)),
        "logic_threshold": sorted(set(threshold_based)),
        "ml_prediction": sorted(set(ml_result)),
        "final_diagnosis": sorted(final),
        "status": status,
        "suggestion": suggestion,
    }
