"""
ASM1: equations for the 0-level set of the 2nd time derivative of SO
============================================================================

"""
# Copyright (C) 2025 Juan Pablo Carbajal
# Copyright (C) 2025 Mariane Yvonne Schneider
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

# Author: Juan Pablo Carbajal <ajuanpi@gmail.com>

import matplotlib.pyplot as plt
from sympy import *


# Make the ``dsafeatures`` package importable: if it is not installed, fall
# back to a source checkout two directories up.
# NOTE(review): "../.." is resolved against the current working directory, not
# against this file — presumably the script is run from its own directory; confirm.
try:
    import dsafeatures
except ModuleNotFoundError:
    import sys
    import os

    sys.path.insert(0, os.path.abspath("../.."))

from dsafeatures.printing import *
from dsafeatures.odemodel import ODEModel

# Set up pretty printing of symbolic expressions.
init_printing()

# Apply the project figure style.
# NOTE(review): 'figures.mplstyle' is not a built-in style name, so matplotlib
# treats it as a file path relative to the cwd — confirm it exists there.
plt.style.use('figures.mplstyle')

# %%
# Select model
# -------------
model = ODEModel.from_csv(name='ASM1')

X = model.states
r_sym, r = model.rates, model.rates_expr
M = model.coef_matrix
O2, idxO2 = model.state_by_name("S_O2")

# Aeration
# Define aeration: an (unspecified) oxygen inflow rate depending on the current
# O2 concentration, the saturation concentration and the maximum inflow rate.
O2_max = Symbol(r'\text{SO}_\text{sat}', real=True, positive=True)
max_O2_inflow = Symbol(r'\dot{q}_{\text{SO}\text{max}}', real=True, positive=True)
# ``positive`` was previously misspelt ``postivive``; SymPy does not reject
# unknown keywords here, so the positivity assumption was silently dropped.
aer = Function(r'\dot{q}_\text{SO}', real=True, positive=True)(O2, O2_max, max_O2_inflow)

# Extended system: append the aeration rate to the rate vector and a unit
# column to the coefficient matrix, so aeration feeds only the S_O2 equation.
r_sym_ext = r_sym.col_join(Matrix([aer]))
one_O2 = eye(len(X))[:, [idxO2]]
M = M.row_join(one_O2)

# First derivative: dX/dt = M r
dotX = M @ r_sym_ext
# 2nd derivative of SO: since M is constant, d2(S_O2)/dt2 = M[S_O2, :] (dr/dX) (dX/dt)
Jr = r_sym_ext.jacobian(X)
ddotO2 = (M[idxO2, :] @ (Jr @ dotX))[0]

# %%
# We just use numbered variables for states and parameters
n_states_ = len(X)
x = symbols(fr'x1:{n_states_+1}', real=True, positive=True)
state_subs = {old: new for old, new in zip(X, x)}
params_ = model.parameters()
p = symbols(fr'p1:{len(params_)+1}', real=True, positive=True)
param_subs = {old: new for old, new in zip(params_, p)}

# %%
# We use static symbols for the rates and their derivatives
f = symbols(fr'r1:{len(r_sym_ext)+1}', real=True)
rate_subs = {old: new for old, new in zip(r_sym_ext, f)}
partials = list(filter(lambda deriv: not deriv.is_zero, Jr.flat()))


def _partial_symbol(deriv):
    """Return a real Symbol standing for the partial derivative ``deriv``.

    The name is LaTeX ``\\,{}_{x_j}r_i``: the differentiation variable (as its
    numbered state symbol) is written as a left subscript of the rate symbol.
    """
    wrt = latex(state_subs[deriv.args[1][0]])
    rate = latex(rate_subs[deriv.args[0]])
    return Symbol(fr'\,{{}}_{{{wrt}}}{rate}', real=True)


df = [_partial_symbol(deriv) for deriv in partials]
dr_subs = {old: new for old, new in zip(partials, df)}
# %%
# replace these on the 2nd derivative
ddO2_s = ddotO2.subs(dr_subs).subs(rate_subs).subs(param_subs)

# Variables to collect on: every free symbol that is not a model parameter.
# ``free_symbols`` is a set whose iteration order changes between interpreter
# runs (string hashing is randomised), and the form ``collect`` produces may
# depend on the order of its symbols, so fix a reproducible order:
# rates first, then their partial derivatives, then anything else by name.
param_syms_ = set(param_subs.values())
free_ = ddO2_s.free_symbols - param_syms_
ddO2_vars = [s for s in (*f, *df) if s in free_]
ddO2_vars += sorted(free_ - set(ddO2_vars), key=str)

# collect variables until nothing changes
ddO2_s = ddO2_s.expand()
while (expr_ := ddO2_s.collect(ddO2_vars)) != ddO2_s:
    ddO2_s = expr_

# %%
# Further, we replace constant expressions with new parameter names.
# To do this we need to detect all expressions that contain only parameters
def has_only(expr, syms):
    """Check whether every free symbol of ``expr`` is among ``syms``.

    An expression with no free symbols (e.g. a number) trivially qualifies.
    """
    return expr.free_symbols <= set(syms)

def subs_expr_withall(expr, syms, x):
    """Return ``x`` if ``expr`` depends only on ``syms``, else ``expr`` unchanged.

    Previously the non-matching case fell off the end of the function and
    returned ``None``, which would erase the expression if used as a replacer.
    """
    return x if has_only(expr, syms) else expr


pnew = numbered_symbols(prefix='k', start=1)
pbundle_subs = {}
# Walk the expression tree top-down and give every composite (non-atomic)
# sub-expression built solely from parameters its own k-symbol; repeated
# occurrences reuse the symbol assigned at the first visit.
for node_ in preorder_traversal(ddO2_s):
    if node_.is_Atom:
        continue
    if node_ in pbundle_subs:
        continue
    if has_only(node_, p):
        pbundle_subs[node_] = next(pnew)
# %%
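# We now substitute the new symbols.
# However, some of the new constants may no longer appear afterwards (e.g. a
# parameter-only sub-expression nested inside a larger one that was replaced
# first), leaving gaps in the numbering.
# Hence, we find the remaining constants and renumber them consecutively.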
ddO2_s = ddO2_s.subs(pbundle_subs).collect(pbundle_subs.values())
# Constants still present after substitution: surviving k-symbols plus any
# bare parameters (atoms are never bundled).
remain_ = ddO2_s.free_symbols - set(ddO2_vars)
# ``remain_`` is a set whose iteration order varies between interpreter runs
# (string hashing is randomised), so numbering straight from it made the
# C-names non-reproducible. Fix the order: k-symbols in creation order, then
# bare parameters in ``p`` order, then anything else by name.
renumber_order_ = [s for s in pbundle_subs.values() if s in remain_]
renumber_order_ += [s for s in p if s in remain_]
renumber_order_ += sorted(remain_ - set(renumber_order_), key=str)
subs_ = dict(zip(renumber_order_, numbered_symbols(prefix='C', start=1)))
ddO2_s = ddO2_s.subs(subs_)
# Map each bundled parameter expression to its final C-name, dropping bundles
# that no longer appear. Renamed bare parameters are recorded only in ``subs_``.
pbundle_subs = {k: subs_[v] for k, v in pbundle_subs.items() if v in subs_}

# %%
# Final expression
ddO2_s.expand().collect(ddO2_vars).collect(pbundle_subs.values())  # bare expression: presumably rendered as this cell's output

# %%
plt.show()  # NOTE(review): no figure is created in the visible code; presumably kept for the gallery template