# Source code for pm4py.visualization.decisiontree.util.dt_to_string

'''
    PM4Py – A Process Mining Library for Python
Copyright (C) 2024 Process Intelligence Solutions UG (haftungsbeschränkt)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see this software project's root or 
visit <https://www.gnu.org/licenses/>.

Website: https://processintelligence.solutions
Contact: info@processintelligence.solutions
'''
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
from typing import Dict, Tuple, Set, List


def apply(
    clf: DecisionTreeClassifier, columns: List[str]
) -> Tuple[Dict[str, str], Dict[str, Set[str]]]:
    """
    Translates a decision tree object into a dictionary
    associating a set of conditions for each target class

    The textual dump produced by ``export_text`` is parsed line by line.
    Every decision line updates the condition that is active at its depth;
    every leaf line emits the conjunction (``&&``) of the conditions that
    are active on the path leading to it. The paths reaching the same
    class are finally combined as a disjunction (``||``).

    Parameters
    ----------------
    clf
        Decision tree classifier
    columns
        Columns; the ``feature_<N>`` tokens of the textual dump are
        replaced by ``columns[N]``

    Returns
    ----------------
    dict_classes
        Dictionary associating a set of conditions for each target class
        (a single string of the form ``(c1 && c2) || (c3 && c4)``)
    variables
        Dictionary associating to each target class the set of column
        names that occur in its conditions
    """
    # export_text collapses everything below its default max_depth (10) into
    # "truncated branch of depth N" lines, which cannot be parsed as
    # conditions. Ask for the full depth when the tree is deeper than that.
    full_depth = max(10, clf.get_depth())
    tree_lines = export_text(clf, max_depth=full_depth).split("\n")

    # Condition text and column name currently active at each depth. Entries
    # deeper than the current line may be stale leftovers of a previous
    # branch; they are never read because only range(level) is consulted.
    conditions: Dict[int, str] = {}
    condition_columns: Dict[int, str] = {}

    class_rules: Dict[str, List[str]] = {}
    variables: Dict[str, Set[str]] = {}

    for line in tree_lines:
        prefix, marker, node_text = line.partition("--- ")
        if not marker:
            # Blank line (e.g. after the trailing newline): not a tree node.
            continue

        # Depth is given by the number of "|" in the indentation only, so a
        # "|" inside a class label cannot distort it.
        level = prefix.count("|") - 1

        if "class" in node_text:
            # Leaf: "class: <label>". Bound the split so that a label
            # containing ": " is kept whole.
            target_class = node_text.split(": ", 1)[-1]
            rule = (
                "("
                + " && ".join(conditions[depth] for depth in range(level))
                + ")"
            )
            class_rules.setdefault(target_class, []).append(rule)
            # Use the stored column names rather than re-splitting the
            # condition text, which would cut names containing spaces.
            variables.setdefault(target_class, set()).update(
                condition_columns[depth] for depth in range(level)
            )
        else:
            # Decision node: "feature_<N> <op> <threshold>".
            feature_token, _, comparison = node_text.partition(" ")
            column = columns[int(feature_token.split("_")[1])]
            conditions[level] = column + " " + comparison
            condition_columns[level] = column

    target_classes = {
        cl: " || ".join(rules) for cl, rules in class_rules.items()
    }
    return target_classes, variables