Source code for mip_dmp.qt5.components.embedding_visualization_widget

# Copyright 2023 The HIP team, University Hospital of Lausanne (CHUV), Switzerland & Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module that defines the class dedicated to the widget that supports the visualization of the initial automated mapping matches via embedding."""

# External imports
import os
import numpy as np
import matplotlib.pyplot as plt
import pkg_resources
from matplotlib.backends.backend_qt5agg import (
    FigureCanvasQTAgg as FigureCanvas,
    NavigationToolbar2QT as NavigationToolbar,
)
from PySide2.QtCore import QCoreApplication
from PySide2.QtWidgets import QVBoxLayout, QWidget, QComboBox

# Internal imports
from mip_dmp.plot.embedding import scatterplot_embeddings
from mip_dmp.process.embedding import generate_embeddings, reduce_embeddings_dimension


# Constants
WINDOW_NAME = "Word Embedding Matches Visualization"
NB_KEPT_MATCHES = 15

[docs]class WordEmbeddingVisualizationWidget(QWidget): """Class for the widget that supports the visualization of the automated column / CDE code matches via embedding.""" def __init__(self, parent=None): """Initialize the widget. If parent is `None`, the widget renders as a separate window.""" super(WordEmbeddingVisualizationWidget, self).__init__(parent) self.adjustWindow() self.widgetLayout = QVBoxLayout() self.setLayout(self.widgetLayout) # Set up the combo box for selecting the dimensionality reduction method self.dimReductionMethodComboBox = QComboBox() self.dimReductionMethodComboBox.addItems(["tsne", "pca"]) self.widgetLayout.addWidget(self.dimReductionMethodComboBox) # Set up the combo box for selecting the word to visualize # its dimensionaly reduced embedding vector in the 3D scatter plot # with the ones of the CDE codes self.wordComboBox = QComboBox() self.widgetLayout.addWidget(self.wordComboBox) # Set up the matplotlib figure and canvas self.canvasLayout = QVBoxLayout() self.figure = plt.figure(figsize=(6, 6)) self.canvas = FigureCanvas(self.figure) self.toolbar = NavigationToolbar(self.canvas, self) self.canvasLayout.addWidget(self.canvas) self.canvasLayout.addWidget(self.toolbar) self.widgetLayout.addLayout(self.canvasLayout, stretch=1) # Initialize the class attributes self.inputDatasetColumns = list() self.targetCDECodes = list() self.inputDatasetColumnEmbeddings = list() self.targetCDECodeEmbeddings = list() self.matchedCdeCodes = dict() self.matchingMethod = None self.embeddings = dict() # Connect signals to slots self.dimReductionMethodComboBox.currentIndexChanged.connect( self.generate_embedding_figure ) self.wordComboBox.currentIndexChanged.connect(self.generate_embedding_figure)
[docs] def adjustWindow(self): """Adjust the window size, Qt Style Sheet, and title. Parameters ---------- mainWindow : QMainWindow The main window of the application. """ # Adjust the window size # self.resize(1280, 720) # Set the window Qt Style Sheet styleSheetFile = pkg_resources.resource_filename( "mip_dmp", os.path.join("qt5", "assets", "stylesheet.qss") ) with open(styleSheetFile, "r") as fh: self.setStyleSheet(fh.read()) # Set the window title self.setWindowTitle( QCoreApplication.translate(f"{WINDOW_NAME}", f"{WINDOW_NAME}", None) )
[docs] def set_word_list(self, wordList): """Set the list of words that can be visualized in the 3D scatter plot. wordList: list List of words to visualize in the 3D scatter plot """ self.wordComboBox.clear() self.wordComboBox.addItems(wordList)
[docs] def set_matching_method(self, matchingMethod): """Set the matching method. matchingMethod: str Matching method. Can be "glove" or "chars2vec" """ self.matchingMethod = matchingMethod
[docs] def generate_embeddings( self, inputDatasetColumns: list, targetCDECodes: list, matchingMethod: str ): """Generate the embeddings of the columns and CDE codes. Set the input dataset columns (`self.inputDatasetColumns`), the target CDE codes (`self.targetCDECodes`), the input dataset column embeddings (`self.inputDatasetColumnEmbeddings`) and the target CDE code embeddings (`self.targetCDECodeEmbeddings`). The embeddings are generated using the specified matching method (`matchingMethod`). The matching method can be "glove" or "chars2vec". inputDatasetColumns: list List of the input dataset columns. targetCDECodes: list List of the target CDE codes. matchingMethod: str Matching method. Can be "glove" or "chars2vec" """ self.set_matching_method(matchingMethod) self.inputDatasetColumns = inputDatasetColumns self.targetCDECodes = targetCDECodes self.inputDatasetColumnEmbeddings = generate_embeddings( inputDatasetColumns, matchingMethod ) self.targetCDECodeEmbeddings = generate_embeddings( targetCDECodes, matchingMethod )
[docs] def set_embeddings( self, inputDatasetColumnEmbeddings: list, inputDatasetColumns: list, targetCDECodeEmbeddings: list, targetCDECodes: list, matchedCdeCodes: dict, matchingMethod: str, ): """Set the input dataset column and target CDE code embeddings. inputDatasetColumnEmbeddings: list List of the input dataset column embeddings. inputDatasetColumns: list List of the input dataset columns. targetCDECodeEmbeddings: list List of the target CDE code embeddings. targetCDECodes: list List of the target CDE codes. matchedCdeCodes: dict Dictionary of the matched CDE codes in the form:: { "input_dataset_column1": { "words": ["cde_code1", "cde_code2", ...], "embeddings": [embedding_vector1, embedding_vector2, ...] "distances": [distance1, distance2, ...] }, "input_dataset_column2": { "words": ["cde_code1", "cde_code2", ...], "embeddings": [embedding_vector1, embedding_vector2, ...] "distances": [distance1, distance2, ...] }, ... } matchingMethod: str Matching method. Can be "glove" or "chars2vec". """ self.set_matching_method(matchingMethod) self.inputDatasetColumnEmbeddings = inputDatasetColumnEmbeddings self.inputDatasetColumns = inputDatasetColumns self.targetCDECodeEmbeddings = targetCDECodeEmbeddings self.targetCDECodes = targetCDECodes self.matchedCdeCodes = matchedCdeCodes # Reduce embeddings dimension to 3 components via t-SNE or PCA for visualization dim_reduction_method = self.dimReductionMethodComboBox.currentText() x, y, z = reduce_embeddings_dimension( self.inputDatasetColumnEmbeddings + self.targetCDECodeEmbeddings, reduce_method=dim_reduction_method, ) # Set the dictionary with the embeddings and their labels, format expected # by the scatterplot function self.embeddings = dict( { "x": x, "y": y, "z": z, "label": self.inputDatasetColumns + self.targetCDECodes, "type": ( ["column"] * len(self.inputDatasetColumns) + ["cde"] * len(self.targetCDECodes) ), } )
[docs] def set_wordcombobox_items(self, wordList): """Set the items of the word combo box. wordList: list List of words to visualize in the combo box of the widget that controls the selection of the word to visualize in the 3D scatter plot. """ self.wordComboBox.clear() self.wordComboBox.addItems(wordList)
[docs] def generate_embedding_figure(self): """Generate 3D scatter plot showing dimensionality-reduced embedding vectors of the words.""" if ( len(self.inputDatasetColumnEmbeddings) > 0 and len(self.targetCDECodeEmbeddings) > 0 ): matchedCdeCodes = self.matchedCdeCodes.copy() # Keep only the NB_KEPT_MATCHES most similar CDE codes for a variable for key in ["words", "distances"]: matchedCdeCodes[self.wordComboBox.currentText()][ key ] = matchedCdeCodes[self.wordComboBox.currentText()][ key ][:NB_KEPT_MATCHES] embeddings = self.embeddings.copy() embeddings = [embedding_vector[:NB_KEPT_MATCHES] for embedding_vector in embeddings] # Generate 3D scatter plot scatterplot_embeddings( self.figure, self.embeddings, self.matchedCdeCodes, self.wordComboBox.currentText(), ) # Draw the figure self.figure.canvas.draw()