# Source code for mip_dmp.qt5.components.embedding_visualization_widget
# Copyright 2023 The HIP team, University Hospital of Lausanne (CHUV), Switzerland & Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module that defines the class dedicated to the widget that supports the visualization of the initial automated mapping matches via embedding."""
# External imports
import os
import numpy as np
import matplotlib.pyplot as plt
import pkg_resources
from matplotlib.backends.backend_qt5agg import (
    FigureCanvasQTAgg as FigureCanvas,
    NavigationToolbar2QT as NavigationToolbar,
)
from PySide2.QtCore import QCoreApplication
from PySide2.QtWidgets import QVBoxLayout, QWidget, QComboBox
# Internal imports
from mip_dmp.plot.embedding import scatterplot_embeddings
from mip_dmp.process.embedding import generate_embeddings, reduce_embeddings_dimension
# Constants
# Title of the widget window; also used as the Qt translation context and source text
WINDOW_NAME = "Word Embedding Matches Visualization"
# Maximum number of matched CDE codes (and their distances) kept per word when plotting
NB_KEPT_MATCHES = 15
class WordEmbeddingVisualizationWidget(QWidget):
    """Widget that visualizes the automated column / CDE code matches via embedding.

    The widget stacks, from top to bottom, a combo box selecting the
    dimensionality reduction method ("tsne" or "pca"), a combo box selecting
    the word (input dataset column) to inspect, and a matplotlib canvas with
    its navigation toolbar in which the 3D scatter plot is drawn.
    """

    def __init__(self, parent=None):
        """Initialize the widget. If parent is `None`, the widget renders as a separate window."""
        super().__init__(parent)
        self.adjustWindow()
        self.widgetLayout = QVBoxLayout()
        self.setLayout(self.widgetLayout)
        # Set up the combo box for selecting the dimensionality reduction method
        self.dimReductionMethodComboBox = QComboBox()
        self.dimReductionMethodComboBox.addItems(["tsne", "pca"])
        self.widgetLayout.addWidget(self.dimReductionMethodComboBox)
        # Set up the combo box for selecting the word to visualize
        # its dimensionality-reduced embedding vector in the 3D scatter plot
        # with the ones of the CDE codes
        self.wordComboBox = QComboBox()
        self.widgetLayout.addWidget(self.wordComboBox)
        # Set up the matplotlib figure and canvas
        self.canvasLayout = QVBoxLayout()
        self.figure = plt.figure(figsize=(6, 6))
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        self.canvasLayout.addWidget(self.canvas)
        self.canvasLayout.addWidget(self.toolbar)
        self.widgetLayout.addLayout(self.canvasLayout, stretch=1)
        # Initialize the class attributes
        self.inputDatasetColumns = list()
        self.targetCDECodes = list()
        self.inputDatasetColumnEmbeddings = list()
        self.targetCDECodeEmbeddings = list()
        self.matchedCdeCodes = dict()
        self.matchingMethod = None
        self.embeddings = dict()
        # Connect signals to slots.
        # Changing the reduction method must recompute the 3D coordinates
        # before redrawing; changing the word only needs a redraw.
        self.dimReductionMethodComboBox.currentIndexChanged.connect(
            self._on_dim_reduction_method_changed
        )
        self.wordComboBox.currentIndexChanged.connect(self.generate_embedding_figure)

    def adjustWindow(self):
        """Apply the Qt Style Sheet of the application and set the window title.

        The style sheet is read from the `qt5/assets/stylesheet.qss` resource
        of the `mip_dmp` package. The window size is not modified.
        """
        # Set the window Qt Style Sheet
        styleSheetFile = pkg_resources.resource_filename(
            "mip_dmp", os.path.join("qt5", "assets", "stylesheet.qss")
        )
        with open(styleSheetFile, "r") as fh:
            self.setStyleSheet(fh.read())
        # Set the window title
        self.setWindowTitle(
            QCoreApplication.translate(WINDOW_NAME, WINDOW_NAME, None)
        )

    def set_word_list(self, wordList):
        """Set the list of words that can be visualized in the 3D scatter plot.

        Equivalent to `set_wordcombobox_items()`; kept for backward compatibility.

        Parameters
        ----------
        wordList: list
            List of words to visualize in the 3D scatter plot
        """
        self.set_wordcombobox_items(wordList)

    def set_matching_method(self, matchingMethod):
        """Set the matching method.

        Parameters
        ----------
        matchingMethod: str
            Matching method. Can be "glove" or "chars2vec"
        """
        self.matchingMethod = matchingMethod

    def generate_embeddings(
        self, inputDatasetColumns: list, targetCDECodes: list, matchingMethod: str
    ):
        """Generate the embeddings of the columns and CDE codes.

        Set the input dataset columns (`self.inputDatasetColumns`), the target CDE codes (`self.targetCDECodes`),
        the input dataset column embeddings (`self.inputDatasetColumnEmbeddings`) and the target CDE code embeddings
        (`self.targetCDECodeEmbeddings`).

        The embeddings are generated using the specified matching method (`matchingMethod`).
        This method neither sets the matched CDE codes nor the reduced
        coordinates used by the plot; use `set_embeddings()` for that.

        Parameters
        ----------
        inputDatasetColumns: list
            List of the input dataset columns.
        targetCDECodes: list
            List of the target CDE codes.
        matchingMethod: str
            Matching method. Can be "glove" or "chars2vec"
        """
        self.set_matching_method(matchingMethod)
        self.inputDatasetColumns = inputDatasetColumns
        self.targetCDECodes = targetCDECodes
        self.inputDatasetColumnEmbeddings = generate_embeddings(
            inputDatasetColumns, matchingMethod
        )
        self.targetCDECodeEmbeddings = generate_embeddings(
            targetCDECodes, matchingMethod
        )

    def set_embeddings(
        self,
        inputDatasetColumnEmbeddings: list,
        inputDatasetColumns: list,
        targetCDECodeEmbeddings: list,
        targetCDECodes: list,
        matchedCdeCodes: dict,
        matchingMethod: str,
    ):
        """Set the input dataset column and target CDE code embeddings.

        The embeddings are also reduced to 3 components with the method
        currently selected in the dimensionality reduction combo box, and the
        result is stored in `self.embeddings` for the scatter plot.

        Parameters
        ----------
        inputDatasetColumnEmbeddings: list
            List of the input dataset column embeddings.
        inputDatasetColumns: list
            List of the input dataset columns.
        targetCDECodeEmbeddings: list
            List of the target CDE code embeddings.
        targetCDECodes: list
            List of the target CDE codes.
        matchedCdeCodes: dict
            Dictionary of the matched CDE codes in the form::

                {
                    "input_dataset_column1": {
                        "words": ["cde_code1", "cde_code2", ...],
                        "embeddings": [embedding_vector1, embedding_vector2, ...],
                        "distances": [distance1, distance2, ...]
                    },
                    "input_dataset_column2": {
                        "words": ["cde_code1", "cde_code2", ...],
                        "embeddings": [embedding_vector1, embedding_vector2, ...],
                        "distances": [distance1, distance2, ...]
                    },
                    ...
                }

        matchingMethod: str
            Matching method. Can be "glove" or "chars2vec".
        """
        self.set_matching_method(matchingMethod)
        self.inputDatasetColumnEmbeddings = inputDatasetColumnEmbeddings
        self.inputDatasetColumns = inputDatasetColumns
        self.targetCDECodeEmbeddings = targetCDECodeEmbeddings
        self.targetCDECodes = targetCDECodes
        self.matchedCdeCodes = matchedCdeCodes
        self._update_reduced_embeddings()

    def set_wordcombobox_items(self, wordList):
        """Set the items of the word combo box.

        Parameters
        ----------
        wordList: list
            List of words to visualize in the combo box of the widget
            that controls the selection of the word to visualize in the
            3D scatter plot.
        """
        # Both calls can emit `currentIndexChanged`; the slot tolerates the
        # transient state where no (or an unknown) word is selected.
        self.wordComboBox.clear()
        self.wordComboBox.addItems(wordList)

    def generate_embedding_figure(self):
        """Generate 3D scatter plot showing dimensionality-reduced embedding vectors of the words.

        Nothing is drawn when no embeddings are set, or when the word
        currently selected in the word combo box has no entry in
        `self.matchedCdeCodes` (e.g. while the combo box is being repopulated).
        Only the first `NB_KEPT_MATCHES` matched words / distances of the
        selected word are passed to the plot; `self.matchedCdeCodes` itself is
        left untouched.
        """
        if not self._has_embeddings():
            return
        word = self.wordComboBox.currentText()
        if word not in self.matchedCdeCodes:
            return
        # Keep only the NB_KEPT_MATCHES most similar CDE codes for the word.
        # Copy the outer dict and the selected entry so that slicing does not
        # truncate the caller-owned match lists.
        matches = dict(self.matchedCdeCodes)
        wordMatches = dict(matches[word])
        for key in ("words", "distances"):
            wordMatches[key] = wordMatches[key][:NB_KEPT_MATCHES]
        matches[word] = wordMatches
        # Generate 3D scatter plot
        scatterplot_embeddings(self.figure, self.embeddings, matches, word)
        # Draw the figure
        self.figure.canvas.draw()

    def _has_embeddings(self):
        """Return True if both column and CDE code embeddings are non-empty."""
        # `len()` rather than truthiness: also valid if arrays are stored.
        return (
            len(self.inputDatasetColumnEmbeddings) > 0
            and len(self.targetCDECodeEmbeddings) > 0
        )

    def _update_reduced_embeddings(self):
        """Reduce the embeddings to 3 components and store them in `self.embeddings`."""
        # `list(...)` guarantees concatenation (not element-wise addition)
        # even if a non-list sequence was provided.
        columnEmbeddings = list(self.inputDatasetColumnEmbeddings)
        cdeCodeEmbeddings = list(self.targetCDECodeEmbeddings)
        columns = list(self.inputDatasetColumns)
        cdeCodes = list(self.targetCDECodes)
        # Reduce embeddings dimension to 3 components via t-SNE or PCA for visualization
        x, y, z = reduce_embeddings_dimension(
            columnEmbeddings + cdeCodeEmbeddings,
            reduce_method=self.dimReductionMethodComboBox.currentText(),
        )
        # Set the dictionary with the embeddings and their labels, format expected
        # by the scatterplot function
        self.embeddings = {
            "x": x,
            "y": y,
            "z": z,
            "label": columns + cdeCodes,
            "type": ["column"] * len(columns) + ["cde"] * len(cdeCodes),
        }

    def _on_dim_reduction_method_changed(self):
        """Recompute the reduced embeddings with the newly selected method and redraw."""
        if self._has_embeddings():
            self._update_reduced_embeddings()
        self.generate_embedding_figure()