# coding:utf-8
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2021-2026 Peng-Hui Guo <m@guo.ph>

"""
Integer Recourse (Gurobi)
=======================================================
"""

# %%
# Import necessary packages.

import json
import os
import sys
import time

from benderslib import CallbackBase, BendersContext, CST, IntegerLShaped, LShapedOCGen, SubProblems
from benderslib.solvers import Gurobi

from gurobipy import LinExpr

try:
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
except NameError:
    sys.path.insert(0, os.path.abspath("."))

from _utils import SMPSReader, first_stage_model, second_stage_model, deterministic_equivalent_model, collect_data, \
    draw, limit_memory


# %%
# Define a callback for the in-out stabilization.
#
# .. seealso::
#
#     - Fischetti, M., Ljubić, I., & Sinnl, M. (2016). Benders decomposition without separability:
#       A computational study for capacitated facility location problems. European Journal of
#       Operational Research, 253(3), 557–569. https://doi.org/10.1016/j.ejor.2016.03.002
#     - Fischetti, M., Ljubić, I., & Sinnl, M. (2017). Redesigning Benders decomposition for
#       large-scale facility location. Management Science, 63(7),
#       2146–2162. https://doi.org/10.1287/mnsc.2016.2461
#
# .. warning::
#
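#    This callback is only compatible with single-cut Benders methods: it uses only the first
#    estimator of the master problem and the first cut returned by the cut generator.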

class InOut(CallbackBase):
    """In-out stabilization callback for single-cut Benders methods.

    On every ``on_sub_build`` event the callback works on a private copy of the
    master model (``master_linear``).  It repeatedly separates L-shaped cuts from
    the LP-relaxed subproblems at a point between the copy's current solution and
    a running "core" point.  It adds each cut to the copy and re-solves it.  At
    the end, the cuts that are tight at the copy's final solution are forwarded
    to the real master problem.

    Whenever an optimality cut is generated and the relative gap is still at
    least ``gap_tol``, the callback also appends one cut built from the
    LP-relaxed subproblems at the current complicating values.

    Parameters
    ----------
    lambda_ : float
        Weight of the master-copy solution ``x`` in the separation point
        ``lambda_ * x + (1 - lambda_) * core + delta``.
    alpha : float
        Weight of ``x`` when updating the core point:
        ``core = alpha * x + (1 - alpha) * core``.
    delta : float
        Constant added to every coordinate of the separation point.
    n : int
        After ``n`` consecutive iterations without lower-bound improvement, the
        separation point falls back to the copy's solution itself
        (``lambda_ = 1``, ``delta = 0``).  After ``2 * n`` such iterations the
        stabilization loop stops.
    m : int
        Maximum number of stabilization iterations per ``on_sub_build`` call.
    integer : bool
        If True, the master copy keeps its integrality (``Model.copy``).
        Otherwise its LP relaxation is used (``Model.relax``).
    gap_tol : float
        Classical cuts are appended only while ``result.gap >= gap_tol``.
    """

    # Absolute tolerance on a constraint's Slack used to decide whether a
    # stabilization cut is tight at the final master-copy solution.
    slack_tol = 1e-6

    def __init__(self, lambda_, alpha, delta, n, m, integer=False, gap_tol=0.005):
        self.lambda_ = lambda_
        self.alpha = alpha
        self.delta = delta

        self.n = n
        self.m = m

        self.integer = integer
        self.gap_tol = gap_tol

        # The following are built lazily in on_sub_build.
        self.master_linear = None  # integer copy or LP relaxation of the master model
        self.sub_linear = None  # SubProblems over the LP relaxations of the subproblems
        self.cut_generator = None  # LShapedOCGen working on sub_linear
        self.core = None  # core point: complicating variable name -> value
        self.lb_not_improved_iter_num = 0  # consecutive iterations without LB improvement

        # Accumulated wall-clock time spent in on_opti_cut_generated.
        self.classical_cut_time = 0

    def on_sub_build(self, context: BendersContext):
        """Build the helper models on first use, then add stabilization cuts."""
        if self.sub_linear is None:
            sub_models = [p.model.relax() for p in context.sub_problem]
            self.sub_linear = SubProblems([Gurobi(s) for s in sub_models])
            self.sub_linear.complicating_vars = context.master_problem.complicating_vars

            # Set theta_lb from the mean objective of the subproblem LP relaxations.
            # NOTE(review): this is an unweighted mean over subproblems. It equals
            # the expected value only for equiprobable scenarios; confirm this.
            theta_lb = 0
            for sub in self.sub_linear:
                sub.model.optimize()
                theta_lb += sub.model.ObjVal

            theta_lb = theta_lb / len(sub_models)
            # Add a small margin to avoid numerical issues
            theta_lb -= abs(theta_lb) * 0.2
            context.benders.params.theta_lb = theta_lb
            est = context.master_problem.estimators[0]
            var = context.master_problem.model.getVarByName(est)
            var.setAttr("LB", theta_lb)
            print(f"InOut callback: set <params.theta_lb> to < {theta_lb:.2f} >.")

        if self.master_linear is None:
            master_model = context.master_problem.model
            # Gurobi applies modifications lazily, and Model.copy() does not
            # include pending changes (relax() likely behaves the same). Flush
            # them so that the copy sees, e.g., the estimator bound set above.
            master_model.update()
            if self.integer:
                self.master_linear = master_model.copy()
            else:
                self.master_linear = master_model.relax()

        if self.core is None:
            self.master_linear.optimize()
            self.core = {v: self.master_linear.getVarByName(v).X for v in context.master_problem.complicating_vars}

        if self.cut_generator is None:
            self.cut_generator = LShapedOCGen(context.master_problem, self.sub_linear, context.benders.params)

        self._add_stabilization_cuts(context, m=self.m)

    def on_opti_cut_generated(self, context: BendersContext):
        """Append a cut from the LP-relaxed subproblems while the gap is large."""
        if context.benders.result.gap < self.gap_tol:
            return

        t_ = time.perf_counter()

        self.sub_linear.fix_vars(context.current_comp_vals)
        self.sub_linear.prl_solve()
        cut = self.cut_generator.generate()[0]
        context.current_opti_cuts.append(cut)

        self.classical_cut_time += time.perf_counter() - t_

    def on_benders_end(self, context: BendersContext):
        """Report the total time spent generating classical cuts."""
        print(f"InOut callback: time for classical cuts is < {self.classical_cut_time:.2f} > seconds.")

    def _add_stabilization_cuts(self, context: BendersContext, m):
        """Run up to ``m`` in-out iterations and forward tight cuts to the master."""
        start_time = time.perf_counter()
        cuts = []
        constrs = []
        current_obj = -float('Inf')
        # The bound is tracked per call (current_obj restarts at -inf), so the
        # stall counter restarts as well.
        self.lb_not_improved_iter_num = 0

        for _ in range(m):
            # Check the stop condition before the fallback condition. Otherwise
            # the fallback would shadow it and the stop could never trigger.
            if self.lb_not_improved_iter_num >= self.n * 2:
                break
            if self.lb_not_improved_iter_num >= self.n:
                # Bound has stalled: separate at the master-copy solution itself.
                lambda_, delta_ = 1, 0
            else:
                lambda_, delta_ = self.lambda_, self.delta

            # Move the core point toward the current solution and build the separation point.
            point = {}
            for var_name in self.core:
                x = self.master_linear.getVarByName(var_name).X
                self.core[var_name] = self.alpha * x + (1 - self.alpha) * self.core[var_name]
                point[var_name] = lambda_ * x + (1 - lambda_) * self.core[var_name] + delta_

            # Generate a cut at the separation point.
            self.sub_linear.fix_vars(point)
            self.sub_linear.prl_solve()
            cut = self.cut_generator.generate()[0]

            # Add the cut to the master copy.
            master_vars = [self.master_linear.getVarByName(var_name) for var_name in cut.vars]
            expr = LinExpr(cut.coefs, master_vars)
            # LShapedOCGen returns only >= cuts.
            assert cut.sense == CST.GE
            cons = self.master_linear.addConstr(expr >= cut.rhs)
            constrs.append(cons)
            cuts.append(cut)

            # Check lower bound improvement.
            self.master_linear.optimize()
            obj_val = self.master_linear.ObjVal
            if obj_val > current_obj + 1e-6:
                current_obj = obj_val
                self.lb_not_improved_iter_num = 0
            else:
                self.lb_not_improved_iter_num += 1

        # Read every slack before queueing any removal, so that all values come
        # from the same (last) solve of the master copy.
        slacks = [cons.Slack for cons in constrs]

        # Forward cuts that are tight (zero slack within tolerance) at the final
        # solution. The absolute value makes the test independent of the solver's
        # sign convention for the slack of ">=" rows.
        cut_added_num = 0
        for cons, cut, slack in zip(constrs, cuts, slacks):
            if abs(slack) <= self.slack_tol and cut not in context.master_problem.optimality_cuts:
                context.master_problem.add_cut(cut)
                cut_added_num += 1
            else:
                self.master_linear.remove(cons)

        end_time = time.perf_counter()
        print(f"InOut callback: generated <{cut_added_num}> cuts in < {(end_time - start_time):.2f} > seconds.")


# %%
# Solve the instances using different methods and save the results.

@limit_memory(limit_gb=14.5)
def solve(smps_files, instance_name, time_limit, solve_methods):
    """Convert one SMPS instance to JSON and solve it with the requested methods.

    Parameters
    ----------
    smps_files : tuple of str
        SMPS file paths, passed positionally to ``SMPSReader``
        (``.cor``, ``.tim``, ``.sto`` as built by ``run``).
    instance_name : str
        Used for the instance file ``./_ins/<name>.json`` and the result files
        ``./_sol/<name>_de.json`` / ``./_sol/<name>_bd.json``.  Names containing
        ``"sslp"`` or ``"smkp"`` also register the in-out callback.
    time_limit : float
        Time limit in seconds, applied to each method separately.
    solve_methods : container of str
        ``"de"`` for the deterministic equivalent and/or ``"bd"`` for Benders.
    """
    ins_dir = "./_ins"
    sol_dir = "./_sol"
    # Create the output directories up front so that writing the instance or
    # the results does not fail when they do not exist yet.
    os.makedirs(ins_dir, exist_ok=True)
    os.makedirs(sol_dir, exist_ok=True)

    SMPS = SMPSReader(*smps_files)
    SMPS.parse()
    ins_file = f"{ins_dir}/{instance_name}.json"
    SMPS.to_json(ins_file)
    with open(ins_file, 'r') as f:
        data = json.load(f)

    # Solve using deterministic equivalent
    if "de" in solve_methods:
        model = deterministic_equivalent_model(data)
        model.setParam('TimeLimit', time_limit)
        model.optimize()
        model.write(f"{sol_dir}/{instance_name}_de.json")

        # model.write("_temp.mps")
        # from pyscipopt import Model
        # scip = Model()
        # scip.readProblem(filename="_temp.mps")
        # scip.optimize()
        # scip.writeStatisticsJson(f"./_sol/{instance_name}_de.json")

    # Solve using Benders decomposition
    if "bd" in solve_methods:
        master_model, complicating_vars = first_stage_model(data)
        sub_models, probs = second_stage_model(data)
        BD = IntegerLShaped.from_models(
            master_model=master_model,
            master_solver=Gurobi,
            sub_model=sub_models,
            sub_solver=Gurobi,
            complicating_vars=complicating_vars,
            prob=probs,
        )

        if "sslp" in instance_name:
            # The SSLP instances have negative objective values,
            # so we need to set a negative value to `theta_lb`.
            # A stronger `theta_lb` is provided in the callback.
            BD.params.theta_lb = -1e3
            BD.register(InOut(lambda_=0.2, alpha=0.3, delta=0.2, n=5, m=100, integer=True))

        if "smkp" in instance_name:
            BD.register(InOut(lambda_=0.2, alpha=0.3, delta=0.2, n=5, m=100))

        BD.params.parallel_sub = True
        BD.params.use_bnc = True
        BD.params.time_limit = time_limit
        BD.solve()
        BD.save(f"{sol_dir}/{instance_name}_bd.json")


# %%
# .. rubric:: Set 2 Instances
#
# - *SSLP*, *SMKP*: https://www2.isye.gatech.edu/~sahmed/siplib/
#
# We selected these instances based on these criteria:
#
# - They are two-stage stochastic programming instances.
# - They have binary complicating variables and (mixed) integer recourse, such that combinatorial cuts can be used.
#
# Other stochastic programming instance collections:
#
# - SPS Resources: https://www.stoprog.org/resources
# - splib: https://github.com/vitaut-archive/splib
# - SIPLIB: https://www2.isye.gatech.edu/~sahmed/siplib/ (the source of the instances used in this example)
# - POSTS: https://users.iems.northwestern.edu/~jrbirge/html/dholmes/post.html
# - List of Optimization Problem Libraries: https://github.com/ekhoda/optimization_problem_libraries
# - MSPLib-Library: https://github.com/bonnkleiford/MSPLib-Library
# - RANDOMRHS 2013: https://users.wpi.edu/~atrapp/randomrhs_2013.htm


def run(solve_methods=None, draw_result=False, dry_run=True, time_limit=3600):
    """Optionally solve all Set 2 instances, then collect (and plot) the results.

    Parameters
    ----------
    solve_methods : list of str, optional
        Methods passed to ``solve``; defaults to ``["de", "bd"]``.
    draw_result : bool
        If True, pass the collected data points to ``draw``.
    dry_run : bool
        If True (the default), nothing is solved; only the result files
        already present under ``./_sol`` are collected.
    time_limit : float
        Time limit in seconds passed to ``solve`` for each instance
        (default 3600).

    Returns
    -------
    The value returned by ``collect_data`` for all listed instances.
    """
    if solve_methods is None:
        solve_methods = ["de", "bd"]

    _dir = './set2'

    def _smps_triplet(subdir, stem):
        # An SMPS instance consists of a core, a time and a stochastic file.
        base = f"{_dir}/{subdir}/{stem}"
        return base + ".cor", base + ".tim", base + ".sto"

    # Source: https://www2.isye.gatech.edu/~sahmed/siplib/
    sslp_stems = [
        "sslp_15_45_5", "sslp_15_45_10", "sslp_15_45_15",
        "sslp_5_25_50", "sslp_5_25_100",
        "sslp_10_50_100", "sslp_10_50_500", "sslp_10_50_1000", "sslp_10_50_2000",
    ]
    smps_files = {f"sslp_{k}": _smps_triplet("sslp", stem) for k, stem in enumerate(sslp_stems, start=1)}

    # Source: https://www2.isye.gatech.edu/~sahmed/siplib/
    smps_files.update({f"smkp_{i}": _smps_triplet("smkp", f"smkp_{i}") for i in range(1, 31)})

    ins_names = []
    de_files = []
    bd_files = []
    ins_classes = []
    s_nums = []

    for ins_name, smps_files_ in smps_files.items():
        ins_names.append(ins_name)
        de_files.append(f"./_sol/{ins_name}_de.json")
        bd_files.append(f"./_sol/{ins_name}_bd.json")
        ins_classes.append(ins_name.split("_")[0])
        # The sample number is the last "_"-separated field of the core file name.
        # NOTE(review): for SSLP this is the scenario count, but for SMKP it is
        # the instance index; confirm what collect_data expects.
        s_nums.append(int(smps_files_[0].split("_")[-1].split(".")[0]))

        if not dry_run:
            solve(smps_files_, ins_name, time_limit=time_limit, solve_methods=solve_methods)

    data_points = collect_data(
        ins_names=ins_names,
        de_files=de_files,
        bd_files=bd_files,
        ins_classes=ins_classes,
        sample_nums=s_nums,
    )

    if draw_result:
        draw(data_points)

    return data_points


if __name__ == "__main__":
    # Executing this file does nothing by default. Uncommenting the call below
    # collects the results already saved under ./_sol and plots them. Solving
    # additionally requires dry_run=False.
    pass
    # run(draw_result=True)

# %%
#
# .. seealso::
#
#     * Tutorial on the integer L-shaped method: :doc:`../../tutorials/ilshaped`
#     * This example uses the following class: :class:`~benderslib.IntegerLShaped`
#
# .. tags:: benders: integer l-shaped, solver: gurobi, stochastic, branch-and-check, callback, enhancement
