"""
ASM3: dependency graph
============================================================================
"""
# Copyright (C) 2024 Juan Pablo Carbajal
# Copyright (C) 2024 Mariane Yvonne Schneider
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see https://www.gnu.org/licenses/.
# Author: Juan Pablo Carbajal
import matplotlib.pyplot as plt
import numpy as np
from sympy import *
from pint import UnitRegistry
try:
    # Probe only: succeed silently when the package is already importable.
    import dsafeatures
except ModuleNotFoundError:
    # Fall back to the in-repository sources.  NOTE: "../.." is resolved
    # against the current working directory, not this file's location.
    import os
    import sys

    sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))

import dsafeatures.odemodel as m
import dsafeatures.symtools as st
from dsafeatures.printing import *

init_printing()  # sympy pretty printing
ureg = UnitRegistry()  # pint units registry, used to typeset parameter units
# %%
# Load the model
# --------------
asm = m.ODEModel.from_csv(name="ASM3")
X = asm.states  # state symbols
r_sym = asm.rates  # process-rate symbols
r = asm.rates_expr  # process-rate expressions
M = asm.coef_matrix  # coefficient matrix
dotX = asm.derivative(1)  # state derivatives
# %%
# Graph
# -------
g = asm.graph
# Drawing style (node colour and marker size) for each node layer.
vis_attr = {
    layer: {"color": color, "size": size}
    for layer, color, size in (
        ("isolated", "grey", 1600),
        ("rate", "gold", 600),
        ("measurement", "pink", 1600),
        ("state", "red", 1600),
        ("input", "violet", 1600),
        ("other", "black", 800),
    )
}
# Initial positions: one column per value of the "layer" node attribute.
pos = st.nx.multipartite_layout(g, "layer")
xy = np.asarray(list(pos.values()))
# Centre for the core: right-most x and mean y of the initial layout.
cnt = np.asarray([xy.max(axis=0)[0], xy.mean(axis=0)[1]])

# Core = rate and state nodes, arranged with a seeded spring layout.
core = [node for node, data in g.nodes(data=True)
        if data["layer"] in ["rate", "state"]]
core_graph = g.subgraph(core).to_undirected()
pos.update(st.nx.spring_layout(core_graph, k=0.4, iterations=5000,
                               center=cnt, seed=1234321))

# Let every other node (e.g. measurements, inputs) settle around the fixed
# core and the fixed isolated/other nodes.
iso = [node for node, data in g.nodes(data=True)
       if data["layer"] in ["isolated", "other"]]
pos.update(st.nx.spring_layout(g.to_undirected(), pos=pos, fixed=core + iso,
                               k=0.5, iterations=1000))

# Pin isolated and other nodes to the left-most x coordinate.
left = min(point[0] for point in pos.values())
for node in iso:
    pos[node][0] = left
# Optional per-state position nudges (currently disabled):
#N2, _ = asm.state_by_name("S_N2")
#pos[N2] -= pos[N2] * 0.4
#NE, _ = asm.state_by_name("X_UE")
#pos[NE] -= pos[NE] * 0.2
#Xh, _ = asm.state_by_name("X_OHO")
#pos[Xh] += [0.2, -0.15]
# Lift the second rate node (in graph node order) by 0.25 in y.
rate_nodes = [node for node, data in g.nodes(data=True)
              if data["layer"] == "rate"]
pos[rate_nodes[1]][1] += 0.25

# Per-node style, in the same order networkx draws the nodes.
node_styles = [vis_attr[data["layer"]] for _, data in g.nodes(data=True)]
st.nx.draw(
    g,
    pos=pos,
    node_color=[style["color"] for style in node_styles],
    node_size=[style["size"] for style in node_styles],
    with_labels=True,
    labels={node: latex(node, mode="inline") for node in g.nodes},
    arrowsize=20,
)
# %%
# Model components
# ------------------
# Symbols and process rate names
# *******************************
def to_latex(x):
    r"""Wrap the name ``x`` in inline math, typesetting its subscript upright.

    Only the first underscore separates base from subscript, e.g.
    ``"S_NH4"`` -> ``"$S_\text{NH4}$"``; names without an underscore are
    simply wrapped: ``"K"`` -> ``"$K$"``.
    """
    if "_" not in x:
        return fr"${x}$"
    head, tail = x.split("_", 1)
    return fr"${head}_\text{{{tail}}}$"
# One LaTeX table (symbol, original name, description) per component kind.
for component in ("states", "processrates"):
    table = asm.component[component].rename(columns={"description": "Description"})
    table["This work"] = table["sympy"].apply(lambda sym: latex(sym, mode='inline'))
    table["Other name"] = table["name"].apply(to_latex)
    print(table[["This work", "Other name", "Description"]].to_latex(index=False))
# %%
# Description of active states
info_ = asm.component["states"].set_index("sympy")
# itemize list of the states that belong to the core (rate/state) subgraph.
# Raw f-string: "\item" in a plain string literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+); the emitted text is unchanged.
info_ = [fr"\item[{latex(x, mode='inline')}] {info_.loc[x].description}"
         for x in X if x in core]
info_ = [r"\begin{itemize}"] + info_ + [r"\end{itemize}"]
print("\n".join(info_))
# %%
# Parameters
# ************
info_ = (
    asm.component["parameters"][["sympy", "value", "name", "unit", "description"]]
    .copy()
    .rename(columns={"description": "Description",
                     "value": "Value",
                     "unit": "Units"})
)
# Render symbols, original names and numeric values as LaTeX.
info_["This work"] = info_["sympy"].apply(lambda sym: latex(sym, mode='inline'))
info_["Other name"] = info_["name"].apply(to_latex)
info_["Value"] = info_["Value"].apply(lambda val: latex(val, mode='inline'))
def units_to_latex(x):
    """Render the unit string ``x`` as LaTeX via pint's ``Lx`` format spec.

    Dimensionless units are returned as the original string (pint unit kinds
    are still WIP: https://github.com/hgrecco/pint/pull/1967).  Square
    brackets are removed from the formatted result.  If parsing or
    formatting raises ``AssertionError``, ``"$-$"`` is returned instead.
    """
    try:
        unit = ureg.parse_units(x)
        if unit.dimensionless:
            return x
        # Delete every "[" and "]" in one pass.
        return format(unit, "Lx").translate(str.maketrans("", "", "[]"))
    except AssertionError:
        return "$-$"
info_["Units"] = info_["Units"].apply(units_to_latex)
# Typeset "(A_B)" inside descriptions as "($A_\text{B}$)".
info_["Description"] = info_["Description"].str.replace(
    r"\(([\\\w]+)[_]([\\\w]+)\)", r"($\1_\\text{\2}$)", regex=True
)
info_ = info_.sort_values(by="This work")
columns = ["This work", "Other name", "Value", "Units", "Description"]
print(info_[columns].to_latex(index=False))
# %%
# States vector
# **************
print(latex(X))
# %%
# Matrix
# *******
states, rates = list(X), list(r_sym)
n_rows, n_cols = M.shape
# Symbolic stand-in m_{state, rate} for each coefficient (rows follow the
# states, columns the rates); purely numeric entries are kept verbatim.
Ms = Matrix([
    [M[i, j] if isinstance(M[i, j], Number)
     else Symbol(fr"m_{{{states[i]}\,{rates[j].name}}}")
     for j in range(n_cols)]
    for i in range(n_rows)
])
print(latex(Ms))
# %%
info_ = [r"\begin{align}"]
info_ += [
    fr"{latex(ms)} &\coloneqq {latex(coef)} \\"
    for ms, coef in zip(Ms, M)
    if not isinstance(coef, Number)  # numeric entries are already shown in Ms
]
info_.append(r"\end{align}")
print("\n".join(info_))
# %%
# Rates vector
# ************
print(latex(r_sym))
# %%
info_ = [r"\begin{align}"]
info_ += [fr"{latex(rate)} &\coloneqq {latex(expr)} \\" for rate, expr in zip(r_sym, r)]
info_.append(r"\end{align}")
print("\n".join(info_))
# %%
# ODE
# ****
def symb_to_dot(x):
    r"""Return a dotted LaTeX name for ``x`` (any object with a ``.name`` str).

    The dot is placed on the part before the first underscore, e.g. a symbol
    named ``"S_O2"`` gives ``"\dot{S}_O2"`` and ``"T"`` gives ``"\dot{T}"``.
    """
    base, sep, sub = x.name.partition("_")
    if not sep:
        return fr"\dot{{{base}}}"
    return fr"\dot{{{base}}}_{sub}"
dotX_sym = Matrix([Symbol(symb_to_dot(x)) for x in X])
dotX_expr = Ms * Matrix([Symbol(x_.name) for x_ in r_sym])
# One align environment per graph layer, emitted in this order.  States whose
# layer is not listed here are not printed.
sub_order = ["state", "measurement", "isolated", "other"]
for layer in sub_order:
    info_ = [
        fr"{latex(dx_)} &\coloneqq {latex(f_)} \\"
        for x_, dx_, f_ in zip(X, dotX_sym, dotX_expr)
        if g.nodes[x_]["layer"] == layer
    ]
    if info_:
        # Label carries this model's name (ASM3).  The former "asm1" prefix
        # looked like a copy-paste leftover and would clash with ASM1 labels
        # if both listings end up in the same LaTeX document.
        info_ = ([r"\begin{align}"] + info_
                 + [fr"\label{{eq:asm3-ode-{layer}}}\end{{align}}"])
        print("\n".join(info_))
# %%
plt.show()