Distilling neural networks into wavelet models using interpretations

Fig 1. A wavelet adapting to new data.

Recent deep neural networks (DNNs) often predict extremely well, but sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is the goal itself. Moreover, interpretable models are concise and often yield computational efficiency.

In our recent paper, we propose adaptive wavelet distillation (AWD), a method that distills information from a trained DNN into a wavelet transform. Surprisingly, we find that the resulting transform improves on state-of-the-art predictive performance, despite being extremely concise, interpretable, and computationally efficient! (The trick is making wavelets adaptive and using interpretations; more on that later.) In close collaboration with domain experts, we show how AWD addresses challenges in two real-world settings: cosmological parameter inference 🌌 and molecular-partner prediction 🦠.

Background and motivation

What’s a wavelet? Wavelets are an extremely powerful signal-processing tool for transforming structured signals, such as images and time-series data. Wavelets have many useful properties, including fast computation, multi-scale structure, an orthonormal basis, and interpretation in both spatial and frequency domains (if you’re interested, see this whole book on wavelets). This makes wavelets a great candidate for an interpretable model in many signal-processing applications. Traditionally, wavelets are hand-designed to satisfy human-specified criteria. Here, we allow wavelets to be adaptive: they change their form based on the given input data and trained DNN.
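To make the properties listed above concrete, here is a minimal NumPy sketch of a classic hand-designed wavelet, the Haar wavelet. It is not the adaptive wavelet that AWD learns, and the function names `haar_dwt` and `haar_idwt` are hypothetical helpers written for this illustration. The sketch shows the multi-scale structure (one set of detail coefficients per level) and the orthonormal basis (perfect reconstruction and preserved energy).

```python
import numpy as np

def haar_dwt(x, levels):
    """Multi-level orthonormal Haar transform of a 1-D signal.

    Assumes len(x) is divisible by 2**levels.
    Returns [approximation, coarsest detail, ..., finest detail].
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2.0))  # high-pass: local differences
        approx = (even + odd) / np.sqrt(2.0)         # low-pass: local averages
    return [approx] + details[::-1]

def haar_idwt(coeffs):
    """Invert haar_dwt by undoing one level at a time, coarsest first."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        out = np.empty(2 * approx.size)
        out[0::2] = (approx + detail) / np.sqrt(2.0)
        out[1::2] = (approx - detail) / np.sqrt(2.0)
        approx = out
    return approx

signal = np.array([4.0, 6.0, 10.0, 12.0, 8.0, 8.0, 0.0, 2.0])
coeffs = haar_dwt(signal, levels=3)
reconstructed = haar_idwt(coeffs)

# Orthonormality means the transform preserves the signal's energy.
energy_in = float(np.sum(signal ** 2))
energy_out = float(sum(np.sum(c ** 2) for c in coeffs))

print([c.shape for c in coeffs])
print(np.allclose(reconstructed, signal), np.isclose(energy_in, energy_out))
```

Each level only touches half as many samples as the previous one, which is where the fast computation comes from. An adaptive wavelet would replace the fixed averaging and differencing weights with learned ones.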

TRIM. To adapt the wavelet model to a given DNN, we need a way to convey information from the DNN to the wavelet parameters. This can be done using Transformation Importance (TRIM). Given a transformation $\Psi$ (here the wavelet transform) and an input $x$, TRIM obtains feature importances for a given DNN $f$ in the transformed domain using a simple reparameterization trick: the DNN is prepended with the transform $\Psi$, coefficients $\Psi x$ are extracted, and then the inverse transform is
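The description above is cut off, so the following is only a rough sketch of the reparameterization idea under stated assumptions, not the authors' implementation. It assumes the inverse transform maps the coefficients back to the input domain before the unchanged model is applied, so that importances can be computed with respect to the coefficients. The toy model `f`, the single-level Haar matrix `W`, and the use of gradient times input as the attribution method are all stand-ins chosen for illustration; the text does not say which attribution method TRIM uses.

```python
import numpy as np

def haar_matrix(n):
    """Orthonormal single-level Haar analysis matrix for even n (hypothetical helper)."""
    W = np.zeros((n, n))
    half = n // 2
    for k in range(half):
        W[k, 2 * k] = 1.0 / np.sqrt(2.0)             # rows 0..half-1: local averages
        W[k, 2 * k + 1] = 1.0 / np.sqrt(2.0)
        W[half + k, 2 * k] = 1.0 / np.sqrt(2.0)      # rows half..n-1: local differences
        W[half + k, 2 * k + 1] = -1.0 / np.sqrt(2.0)
    return W

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 8))  # weights of a toy stand-in for the trained DNN

def f(x):
    """Toy differentiable model standing in for the DNN (assumption)."""
    return float(np.sum(np.tanh(A @ x)))

W = haar_matrix(8)           # plays the role of the transform
x = rng.normal(size=8)
s = W @ x                    # coefficients of x in the transformed domain

def f_reparam(coeffs):
    """Same model, reparameterized to take coefficients as its input."""
    return f(W.T @ coeffs)   # W is orthonormal, so its inverse is W.T

def grad_fd(fn, z, eps=1e-6):
    """Central finite-difference gradient, used to avoid an autograd dependency."""
    g = np.zeros_like(z)
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = eps
        g[i] = (fn(z + step) - fn(z - step)) / (2.0 * eps)
    return g

# Gradient times input, computed per coefficient instead of per input sample.
importance = grad_fd(f_reparam, s) * s
print(importance.round(3))
```

Because `f_reparam(s)` equals `f(x)` exactly, the model's predictions are unchanged; only the variables that receive importance scores differ.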

This article is purposely trimmed; please visit the source to read the full article.

The post Distilling neural networks into wavelet models using interpretations appeared first on The Berkeley Artificial Intelligence Research Blog.
