Skip to content

model_explorer

ModelExplorer

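ModelExplorer is a two-tab widget ("Feature Importance" and "Partial Dependency") that displays the feature importances of the currently selected model as a horizontal bar chart and the partial dependence plot of one selectable feature.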

Source code in src/antakia/gui/tabs/model_explorer.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class ModelExplorer:
    """
    Two-tab widget showing details about the currently selected model.

    Tab 1 ("Feature Importance") holds a horizontal bar chart of the model's
    ``feature_importances_``. Tab 2 ("Partial Dependency") holds a feature
    selector and the partial dependence plot of the selected feature.

    Attributes
    ----------
    X : pd.DataFrame
        Dataset handed to the partial dependence computation.
    model : MLModel | None
        Model currently displayed, or None when nothing is selected.
    region : ModelRegion | None
        Region attached to the model; its ``mask`` is used to index ``X``.
    widget
        Root ``v.Tabs`` component to embed in the GUI.
    figure_fi, figure_pdp
        Last figure widgets built for each tab (None until first built).
    """

    def __init__(self, X: pd.DataFrame):
        """
        Store the dataset and build the (initially empty) widget tree.

        Parameters
        ----------
        X : pd.DataFrame
            Dataset whose columns are the features that can be explored.
        """
        # State is declared before the widgets are built so that every
        # attribute read by the callbacks exists from the start.
        self.X = X
        self.model: MLModel | None = None
        self.region: ModelRegion | None = None
        self.figure_fi: FigureWidget | None = None
        self.figure_pdp: FigureWidget | None = None
        self.build_widget()

    def build_widget(self):
        """
        Create the tabs, the feature selector and the figure container, and
        register ``display_pdp`` on the selector's 'change' event.
        """
        self.feature_importance_tab = v.TabItem(  # Tab 1) feature importances
            class_="mt-2",
            children=[]
        )
        self.pdp_feature_select = v.Select()
        self.pdp_figure = v.Container()
        self.widget = v.Tabs(
            v_model=0,  # default active tab
            children=[
                v.Tab(children=["Feature Importance"]),
                v.Tab(children=["Partial Dependency"]),
                self.feature_importance_tab,
                v.TabItem(  # Tab 2) Partial dependence
                    children=[
                        v.Col(
                            children=[
                                self.pdp_feature_select,
                                self.pdp_figure
                            ]
                        )
                    ]
                ),  # End of v.TabItem #2
            ]
        )
        self.pdp_feature_select.on_event('change', self.display_pdp)

    @staticmethod
    def _style_figure(figure, **extra_layout):
        """
        Apply the layout shared by both tabs and return ``figure``.

        Margins are set to zero, ``autosize`` is enabled and the
        ``displaylogo`` config entry is set to False. ``extra_layout`` is
        forwarded unchanged to ``update_layout``.
        """
        figure.update_layout(
            autosize=True,
            margin={'t': 0, 'b': 0, 'l': 0, 'r': 0},
            **extra_layout
        )
        figure._config = figure._config | {"displaylogo": False}
        return figure

    def update_selected_model(self, model: MLModel, region: ModelRegion):
        """
        Select a new model/region pair and refresh both tabs.

        Parameters
        ----------
        model : MLModel
            Model to display; must expose ``feature_importances_``.
        region : ModelRegion
            Region whose ``mask`` is used to index ``X`` in the PDP tab.
        """
        self.model = model
        self.region = region
        self.update_feature_importances()
        self.update_pdp_tab()

    def update_feature_importances(self):
        """
        Rebuild the feature importance bar chart, or empty the tab when no
        model is selected.
        """
        if self.model is None:
            self.feature_importance_tab.children = []
            return

        # Ascending sort: the bars are drawn in this order on the y axis.
        feature_importances = self.model.feature_importances_.sort_values(ascending=True)
        bars = Bar(x=feature_importances, y=feature_importances.index, orientation='h')
        self.figure_fi = self._style_figure(FigureWidget(data=[bars]))
        self.feature_importance_tab.children = [self.figure_fi]

    def update_pdp_tab(self):
        """
        Refresh the partial dependence tab.

        When a model is selected and the selector's current value is not a
        column of ``X``, the selector is repopulated with the model features
        sorted by decreasing importance and the first one is selected. A
        model without any feature leaves the selector empty instead of
        raising IndexError.
        """
        if self.model is not None and self.pdp_feature_select.v_model not in self.X.columns:
            features = list(self.model.feature_importances_.sort_values(ascending=False).index)
            self.pdp_feature_select.items = features
            self.pdp_feature_select.v_model = features[0] if features else None
        self.display_pdp()

    def display_pdp(self, *args):
        """
        Draw the partial dependence plot of the selected feature.

        Also used as the selector's 'change' callback, hence ``*args``: the
        event arguments are ignored.

        The figure container is emptied when there is no model, no region,
        or when the selected value is not a column of ``X``. A message is
        shown instead of a plot when the feature has at most one distinct
        value among the rows selected by the region mask.
        """
        if self.model is None or self.region is None:
            self.pdp_figure.children = []
            return

        selected_feature = self.pdp_feature_select.v_model
        if selected_feature not in self.X.columns:
            # Nothing valid selected (e.g. empty selector): nothing to plot.
            self.pdp_figure.children = []
            return

        if self.X[self.region.mask][selected_feature].nunique() <= 1:
            self.pdp_figure.children = ['only one feature value, no pdp to display']
            return

        # The unbound ``predict`` of the model's class is passed as the
        # prediction function, as in the original implementation.
        predict_func = self.model.__class__.predict
        figure = PDPIsolate(
            df=self.X.copy(), feature=selected_feature, feature_name=selected_feature,
            model=self.model, model_features=self.X.columns, pred_func=predict_func,
            n_classes=0  # regression
        ).plot()[0]
        self.figure_pdp = self._style_figure(
            FigureWidget(figure), width=None, height=None
        )
        self.pdp_figure.children = [self.figure_pdp]

    def reset(self):
        """
        Forget the selected model and region and empty both tabs.
        """
        self.model = None
        self.region = None
        self.update_feature_importances()
        self.update_pdp_tab()