\[ \newcommand{\figleft}{{\em (Left)}} \newcommand{\figcenter}{{\em (Center)}} \newcommand{\figright}{{\em (Right)}} \newcommand{\figtop}{{\em (Top)}} \newcommand{\figbottom}{{\em (Bottom)}} \newcommand{\captiona}{{\em (a)}} \newcommand{\captionb}{{\em (b)}} \newcommand{\captionc}{{\em (c)}} \newcommand{\captiond}{{\em (d)}} \newcommand{\newterm}[1]{{\bf #1}} \def\figref#1{figure~\ref{#1}} \def\Figref#1{Figure~\ref{#1}} \def\twofigref#1#2{figures \ref{#1} and \ref{#2}} \def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}} \def\secref#1{section~\ref{#1}} \def\Secref#1{Section~\ref{#1}} \def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}} \def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}} \def\eqref#1{equation~\ref{#1}} \def\Eqref#1{Equation~\ref{#1}} \def\plaineqref#1{\ref{#1}} \def\chapref#1{chapter~\ref{#1}} \def\Chapref#1{Chapter~\ref{#1}} \def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}} \def\algref#1{algorithm~\ref{#1}} \def\Algref#1{Algorithm~\ref{#1}} \def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}} \def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}} \def\partref#1{part~\ref{#1}} \def\Partref#1{Part~\ref{#1}} \def\twopartref#1#2{parts \ref{#1} and \ref{#2}} \def\ceil#1{\lceil #1 \rceil} \def\floor#1{\lfloor #1 \rfloor} \def\1{\bm{1}} \newcommand{\train}{\mathcal{D}} \newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}} \newcommand{\test}{\mathcal{D_{\mathrm{test}}}} \def\eps{{\epsilon}} \def\reta{{\textnormal{$\eta$}}} \def\ra{{\textnormal{a}}} \def\rb{{\textnormal{b}}} \def\rc{{\textnormal{c}}} \def\rd{{\textnormal{d}}} \def\re{{\textnormal{e}}} \def\rf{{\textnormal{f}}} \def\rg{{\textnormal{g}}} \def\rh{{\textnormal{h}}} \def\ri{{\textnormal{i}}} \def\rj{{\textnormal{j}}} \def\rk{{\textnormal{k}}} \def\rl{{\textnormal{l}}} \def\rn{{\textnormal{n}}} \def\ro{{\textnormal{o}}} \def\rp{{\textnormal{p}}} \def\rq{{\textnormal{q}}} \def\rr{{\textnormal{r}}} \def\rs{{\textnormal{s}}} \def\rt{{\textnormal{t}}} \def\ru{{\textnormal{u}}} 
\def\rv{{\textnormal{v}}} \def\rw{{\textnormal{w}}} \def\rx{{\textnormal{x}}} \def\ry{{\textnormal{y}}} \def\rz{{\textnormal{z}}} \def\rvepsilon{{\mathbf{\epsilon}}} \def\rvtheta{{\mathbf{\theta}}} \def\rva{{\mathbf{a}}} \def\rvb{{\mathbf{b}}} \def\rvc{{\mathbf{c}}} \def\rvd{{\mathbf{d}}} \def\rve{{\mathbf{e}}} \def\rvf{{\mathbf{f}}} \def\rvg{{\mathbf{g}}} \def\rvh{{\mathbf{h}}} \def\rvu{{\mathbf{i}}} \def\rvj{{\mathbf{j}}} \def\rvk{{\mathbf{k}}} \def\rvl{{\mathbf{l}}} \def\rvm{{\mathbf{m}}} \def\rvn{{\mathbf{n}}} \def\rvo{{\mathbf{o}}} \def\rvp{{\mathbf{p}}} \def\rvq{{\mathbf{q}}} \def\rvr{{\mathbf{r}}} \def\rvs{{\mathbf{s}}} \def\rvt{{\mathbf{t}}} \def\rvu{{\mathbf{u}}} \def\rvv{{\mathbf{v}}} \def\rvw{{\mathbf{w}}} \def\rvx{{\mathbf{x}}} \def\rvy{{\mathbf{y}}} \def\rvz{{\mathbf{z}}} \def\erva{{\textnormal{a}}} \def\ervb{{\textnormal{b}}} \def\ervc{{\textnormal{c}}} \def\ervd{{\textnormal{d}}} \def\erve{{\textnormal{e}}} \def\ervf{{\textnormal{f}}} \def\ervg{{\textnormal{g}}} \def\ervh{{\textnormal{h}}} \def\ervi{{\textnormal{i}}} \def\ervj{{\textnormal{j}}} \def\ervk{{\textnormal{k}}} \def\ervl{{\textnormal{l}}} \def\ervm{{\textnormal{m}}} \def\ervn{{\textnormal{n}}} \def\ervo{{\textnormal{o}}} \def\ervp{{\textnormal{p}}} \def\ervq{{\textnormal{q}}} \def\ervr{{\textnormal{r}}} \def\ervs{{\textnormal{s}}} \def\ervt{{\textnormal{t}}} \def\ervu{{\textnormal{u}}} \def\ervv{{\textnormal{v}}} \def\ervw{{\textnormal{w}}} \def\ervx{{\textnormal{x}}} \def\ervy{{\textnormal{y}}} \def\ervz{{\textnormal{z}}} \def\rmA{{\mathbf{A}}} \def\rmB{{\mathbf{B}}} \def\rmC{{\mathbf{C}}} \def\rmD{{\mathbf{D}}} \def\rmE{{\mathbf{E}}} \def\rmF{{\mathbf{F}}} \def\rmG{{\mathbf{G}}} \def\rmH{{\mathbf{H}}} \def\rmI{{\mathbf{I}}} \def\rmJ{{\mathbf{J}}} \def\rmK{{\mathbf{K}}} \def\rmL{{\mathbf{L}}} \def\rmM{{\mathbf{M}}} \def\rmN{{\mathbf{N}}} \def\rmO{{\mathbf{O}}} \def\rmP{{\mathbf{P}}} \def\rmQ{{\mathbf{Q}}} \def\rmR{{\mathbf{R}}} \def\rmS{{\mathbf{S}}} \def\rmT{{\mathbf{T}}} 
\def\rmU{{\mathbf{U}}} \def\rmV{{\mathbf{V}}} \def\rmW{{\mathbf{W}}} \def\rmX{{\mathbf{X}}} \def\rmY{{\mathbf{Y}}} \def\rmZ{{\mathbf{Z}}} \def\ermA{{\textnormal{A}}} \def\ermB{{\textnormal{B}}} \def\ermC{{\textnormal{C}}} \def\ermD{{\textnormal{D}}} \def\ermE{{\textnormal{E}}} \def\ermF{{\textnormal{F}}} \def\ermG{{\textnormal{G}}} \def\ermH{{\textnormal{H}}} \def\ermI{{\textnormal{I}}} \def\ermJ{{\textnormal{J}}} \def\ermK{{\textnormal{K}}} \def\ermL{{\textnormal{L}}} \def\ermM{{\textnormal{M}}} \def\ermN{{\textnormal{N}}} \def\ermO{{\textnormal{O}}} \def\ermP{{\textnormal{P}}} \def\ermQ{{\textnormal{Q}}} \def\ermR{{\textnormal{R}}} \def\ermS{{\textnormal{S}}} \def\ermT{{\textnormal{T}}} \def\ermU{{\textnormal{U}}} \def\ermV{{\textnormal{V}}} \def\ermW{{\textnormal{W}}} \def\ermX{{\textnormal{X}}} \def\ermY{{\textnormal{Y}}} \def\ermZ{{\textnormal{Z}}} \def\vzero{{\bm{0}}} \def\vone{{\bm{1}}} \def\vmu{{\bm{\mu}}} \def\vtheta{{\bm{\theta}}} \def\va{{\bm{a}}} \def\vb{{\bm{b}}} \def\vc{{\bm{c}}} \def\vd{{\bm{d}}} \def\ve{{\bm{e}}} \def\vf{{\bm{f}}} \def\vg{{\bm{g}}} \def\vh{{\bm{h}}} \def\vi{{\bm{i}}} \def\vj{{\bm{j}}} \def\vk{{\bm{k}}} \def\vl{{\bm{l}}} \def\vm{{\bm{m}}} \def\vn{{\bm{n}}} \def\vo{{\bm{o}}} \def\vp{{\bm{p}}} \def\vq{{\bm{q}}} \def\vr{{\bm{r}}} \def\vs{{\bm{s}}} \def\vt{{\bm{t}}} \def\vu{{\bm{u}}} \def\vv{{\bm{v}}} \def\vw{{\bm{w}}} \def\vx{{\bm{x}}} \def\vy{{\bm{y}}} \def\vz{{\bm{z}}} \def\evalpha{{\alpha}} \def\evbeta{{\beta}} \def\evepsilon{{\epsilon}} \def\evlambda{{\lambda}} \def\evomega{{\omega}} \def\evmu{{\mu}} \def\evpsi{{\psi}} \def\evsigma{{\sigma}} \def\evtheta{{\theta}} \def\eva{{a}} \def\evb{{b}} \def\evc{{c}} \def\evd{{d}} \def\eve{{e}} \def\evf{{f}} \def\evg{{g}} \def\evh{{h}} \def\evi{{i}} \def\evj{{j}} \def\evk{{k}} \def\evl{{l}} \def\evm{{m}} \def\evn{{n}} \def\evo{{o}} \def\evp{{p}} \def\evq{{q}} \def\evr{{r}} \def\evs{{s}} \def\evt{{t}} \def\evu{{u}} \def\evv{{v}} \def\evw{{w}} \def\evx{{x}} \def\evy{{y}} \def\evz{{z}} 
\def\mA{{\bm{A}}} \def\mB{{\bm{B}}} \def\mC{{\bm{C}}} \def\mD{{\bm{D}}} \def\mE{{\bm{E}}} \def\mF{{\bm{F}}} \def\mG{{\bm{G}}} \def\mH{{\bm{H}}} \def\mI{{\bm{I}}} \def\mJ{{\bm{J}}} \def\mK{{\bm{K}}} \def\mL{{\bm{L}}} \def\mM{{\bm{M}}} \def\mN{{\bm{N}}} \def\mO{{\bm{O}}} \def\mP{{\bm{P}}} \def\mQ{{\bm{Q}}} \def\mR{{\bm{R}}} \def\mS{{\bm{S}}} \def\mT{{\bm{T}}} \def\mU{{\bm{U}}} \def\mV{{\bm{V}}} \def\mW{{\bm{W}}} \def\mX{{\bm{X}}} \def\mY{{\bm{Y}}} \def\mZ{{\bm{Z}}} \def\mBeta{{\bm{\beta}}} \def\mPhi{{\bm{\Phi}}} \def\mLambda{{\bm{\Lambda}}} \def\mSigma{{\bm{\Sigma}}} \newcommand{\tens}[1]{\bm{\mathsfit{#1}}} \def\tA{{\tens{A}}} \def\tB{{\tens{B}}} \def\tC{{\tens{C}}} \def\tD{{\tens{D}}} \def\tE{{\tens{E}}} \def\tF{{\tens{F}}} \def\tG{{\tens{G}}} \def\tH{{\tens{H}}} \def\tI{{\tens{I}}} \def\tJ{{\tens{J}}} \def\tK{{\tens{K}}} \def\tL{{\tens{L}}} \def\tM{{\tens{M}}} \def\tN{{\tens{N}}} \def\tO{{\tens{O}}} \def\tP{{\tens{P}}} \def\tQ{{\tens{Q}}} \def\tR{{\tens{R}}} \def\tS{{\tens{S}}} \def\tT{{\tens{T}}} \def\tU{{\tens{U}}} \def\tV{{\tens{V}}} \def\tW{{\tens{W}}} \def\tX{{\tens{X}}} \def\tY{{\tens{Y}}} \def\tZ{{\tens{Z}}} \def\gA{{\mathcal{A}}} \def\gB{{\mathcal{B}}} \def\gC{{\mathcal{C}}} \def\gD{{\mathcal{D}}} \def\gE{{\mathcal{E}}} \def\gF{{\mathcal{F}}} \def\gG{{\mathcal{G}}} \def\gH{{\mathcal{H}}} \def\gI{{\mathcal{I}}} \def\gJ{{\mathcal{J}}} \def\gK{{\mathcal{K}}} \def\gL{{\mathcal{L}}} \def\gM{{\mathcal{M}}} \def\gN{{\mathcal{N}}} \def\gO{{\mathcal{O}}} \def\gP{{\mathcal{P}}} \def\gQ{{\mathcal{Q}}} \def\gR{{\mathcal{R}}} \def\gS{{\mathcal{S}}} \def\gT{{\mathcal{T}}} \def\gU{{\mathcal{U}}} \def\gV{{\mathcal{V}}} \def\gW{{\mathcal{W}}} \def\gX{{\mathcal{X}}} \def\gY{{\mathcal{Y}}} \def\gZ{{\mathcal{Z}}} \def\sA{{\mathbb{A}}} \def\sB{{\mathbb{B}}} \def\sC{{\mathbb{C}}} \def\sD{{\mathbb{D}}} \def\sF{{\mathbb{F}}} \def\sG{{\mathbb{G}}} \def\sH{{\mathbb{H}}} \def\sI{{\mathbb{I}}} \def\sJ{{\mathbb{J}}} \def\sK{{\mathbb{K}}} \def\sL{{\mathbb{L}}} \def\sM{{\mathbb{M}}} 
\def\sN{{\mathbb{N}}} \def\sO{{\mathbb{O}}} \def\sP{{\mathbb{P}}} \def\sQ{{\mathbb{Q}}} \def\sR{{\mathbb{R}}} \def\sS{{\mathbb{S}}} \def\sT{{\mathbb{T}}} \def\sU{{\mathbb{U}}} \def\sV{{\mathbb{V}}} \def\sW{{\mathbb{W}}} \def\sX{{\mathbb{X}}} \def\sY{{\mathbb{Y}}} \def\sZ{{\mathbb{Z}}} \def\emLambda{{\Lambda}} \def\emA{{A}} \def\emB{{B}} \def\emC{{C}} \def\emD{{D}} \def\emE{{E}} \def\emF{{F}} \def\emG{{G}} \def\emH{{H}} \def\emI{{I}} \def\emJ{{J}} \def\emK{{K}} \def\emL{{L}} \def\emM{{M}} \def\emN{{N}} \def\emO{{O}} \def\emP{{P}} \def\emQ{{Q}} \def\emR{{R}} \def\emS{{S}} \def\emT{{T}} \def\emU{{U}} \def\emV{{V}} \def\emW{{W}} \def\emX{{X}} \def\emY{{Y}} \def\emZ{{Z}} \def\emSigma{{\Sigma}} \newcommand{\etens}[1]{\mathsfit{#1}} \def\etLambda{{\etens{\Lambda}}} \def\etA{{\etens{A}}} \def\etB{{\etens{B}}} \def\etC{{\etens{C}}} \def\etD{{\etens{D}}} \def\etE{{\etens{E}}} \def\etF{{\etens{F}}} \def\etG{{\etens{G}}} \def\etH{{\etens{H}}} \def\etI{{\etens{I}}} \def\etJ{{\etens{J}}} \def\etK{{\etens{K}}} \def\etL{{\etens{L}}} \def\etM{{\etens{M}}} \def\etN{{\etens{N}}} \def\etO{{\etens{O}}} \def\etP{{\etens{P}}} \def\etQ{{\etens{Q}}} \def\etR{{\etens{R}}} \def\etS{{\etens{S}}} \def\etT{{\etens{T}}} \def\etU{{\etens{U}}} \def\etV{{\etens{V}}} \def\etW{{\etens{W}}} \def\etX{{\etens{X}}} \def\etY{{\etens{Y}}} \def\etZ{{\etens{Z}}} \newcommand{\pdata}{p_{\rm{data}}} \newcommand{\ptrain}{\hat{p}_{\rm{data}}} \newcommand{\Ptrain}{\hat{P}_{\rm{data}}} \newcommand{\pmodel}{p_{\rm{model}}} \newcommand{\Pmodel}{P_{\rm{model}}} \newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}} \newcommand{\pencode}{p_{\rm{encoder}}} \newcommand{\pdecode}{p_{\rm{decoder}}} \newcommand{\precons}{p_{\rm{reconstruct}}} \newcommand{\E}{\mathbb{E}} \newcommand{\Ls}{\mathcal{L}} \newcommand{\R}{\mathbb{R}} \newcommand{\emp}{\tilde{p}} \newcommand{\lr}{\alpha} \newcommand{\reg}{\lambda} \newcommand{\rect}{\mathrm{rectifier}} \newcommand{\softmax}{\mathrm{softmax}} \newcommand{\sigmoid}{\sigma} 
\newcommand{\softplus}{\zeta} \newcommand{\KL}{D_{\mathrm{KL}}} \newcommand{\Var}{\mathrm{Var}} \newcommand{\standarderror}{\mathrm{SE}} \newcommand{\Cov}{\mathrm{Cov}} \newcommand{\normlzero}{L^0} \newcommand{\normlone}{L^1} \newcommand{\normltwo}{L^2} \newcommand{\normlp}{L^p} \newcommand{\normmax}{L^\infty} \newcommand{\parents}{Pa} % See usage in notation.tex. Chosen to match Daphne's book. \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\Tr}{Tr} \let\ab\allowbreak \]

Can a Deep Learning Model be a Sure Bet for Tabular Prediction?

2024-12-11
deep-learning (11) machine-learning (7) model-architecture (6) tabular-data (16)

Summary

Deep learning on tabular data faces three challenges:

  1. lack of rotational variance -- MLP-like DNNs are rotationally invariant (rotating the input features does not change what they can learn), yet each tabular column carries its own meaning, so a good tabular learner should not be rotationally invariant.
  2. large data demand -- DNNs have a larger hypothesis space and require more training data than shallow algorithms.
  3. over-smooth solutions -- DNNs tend to produce overly smooth solutions, so they struggle when the decision boundary is irregular (as pointed out by Grinsztajn et al.).

Approach

Overview

Overview of approach
The approach has three main components: Semi-Permeable Attention (SPA), interpolation-based data augmentation (FEAT-mix and HID-mix), and attentive FFNs.

Semi-Permeable Attention

The authors propose adding a fixed mask to the attention score matrix so that less important features cannot influence more important ones, while more important features can still influence less important ones.

$$ z' = \text{softmax}(\cfrac{(z W_q) (z W_k)^T \underline{\oplus M}}{\sqrt{d}}) (z W_v) $$

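where \(z \in \mathbb{R}^{f \times d}\) holds the \(d\)-dimensional embeddings of the \(f\) features, \(\oplus\) denotes element-wise addition, and the underlined \(\oplus M\) is the only change to vanilla multi-head self-attention (MHSA). \(M \in \mathbb{R}^{f \times f}\) is a fixed mask, where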

$$ M[i,j] = \begin{cases} -\infty & I(\mathbf{f}_i) \gt I(\mathbf{f}_j) \\ 0 & I(\mathbf{f}_i) \leq I(\mathbf{f}_j) \end{cases} $$

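where \(I(\mathbf{f}_i)\) is the importance of the \(i\)-th feature. Row \(i\) of the attention matrix decides what feature \(i\) gathers from the other features, so a less informative feature may draw information from more informative ones (the \(0\) case), but the reverse is blocked (the \(-\infty\) case, which becomes a zero attention weight after the softmax).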
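To make the masking concrete, here is a minimal single-head NumPy sketch. It assumes the importance scores are already computed and passed in as an array; the function names, the single-head simplification, and the random projection matrices are hypothetical illustrations, not the authors' code. Following the equation above, the mask is added before dividing by \(\sqrt{d}\).

```python
import numpy as np


def spa_mask(importance):
    """Fixed f x f mask: M[i, j] = -inf if I(f_i) > I(f_j), else 0."""
    imp = np.asarray(importance, dtype=float)
    return np.where(imp[:, None] > imp[None, :], -np.inf, 0.0)


def semi_permeable_attention(z, w_q, w_k, w_v, mask):
    """Single-head attention with the additive SPA mask (hypothetical sketch).

    z: (f, d) feature embeddings; w_q, w_k, w_v: (d, d) projection matrices.
    Returns the updated embeddings and the (f, f) attention weights.
    """
    d = z.shape[-1]
    scores = ((z @ w_q) @ (z @ w_k).T + mask) / np.sqrt(d)
    # Row-wise softmax. The diagonal is never masked (I(f_i) <= I(f_i)), so
    # every row has a finite maximum, and -inf entries get exactly zero weight.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ (z @ w_v), weights


# Usage with hypothetical importance scores for f = 4 features.
rng = np.random.default_rng(0)
f, d = 4, 8
importance = np.array([0.9, 0.1, 0.5, 0.3])
z = rng.normal(size=(f, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
M = spa_mask(importance)
out, attn = semi_permeable_attention(z, w_q, w_k, w_v, M)
# Feature 0 (most important) can only attend to itself, while
# feature 1 (least important) may attend to every feature.
```

Because the mask depends only on the importance ranking, it can be built once and reused for every sample and layer.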

Interpolation-based data augmentation

Illustration of HID and FEAT mix

HID-mix operates at the embedding level, while FEAT-mix operates at the raw feature level.

HID-mix

Given two samples \(z_1^{(0)}, z_2^{(0)} \in \mathbb{R}^{f \times d}\) and their labels \(y_1, y_2\), a new sample can be formed by mixing the embedding dimensions of \(z_1^{(0)}\) and \(z_2^{(0)}\):

$$ \begin{gather} z_m^{(0)} = S_H \odot z_1^{(0)} + (\mathbb{1}-S_H) \odot z_2^{(0)}\\ y_m = \lambda_H y_1 + (1-\lambda_H) y_2 \end{gather} $$

where \(S_H \in \{0,1\}^{f \times d}\) is built by stacking \(f\) copies of the same binary row mask \(s_h \in \{0,1\}^{d}\), \(S_H = [s_h, s_h, \ldots, s_h]^T\), with \(\sum s_h = \floor{\lambda_H \cdot d}\), and \(\mathbb{1}\) is an \(f \times d\) matrix of \(1\)s. In other words, \(S_H\) takes the same \(\floor{\lambda_H \cdot d}\) embedding dimensions of every feature from \(z_1^{(0)}\), and the remaining \(d - \floor{\lambda_H \cdot d}\) dimensions come from \(z_2^{(0)}\).
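A minimal NumPy sketch of HID-mix, assuming \(\lambda_H\) is given (how it is sampled is not covered in these notes) and that labels are one-hot vectors. `hid_mix` is a hypothetical name; the key detail it shows is that one row mask is shared by all \(f\) feature rows.

```python
import numpy as np


def hid_mix(z1, z2, y1, y2, lam, rng):
    """HID-mix of two (f, d) embedding matrices z1, z2 with labels y1, y2.

    One row mask s_h with floor(lam * d) ones is sampled and stacked f times,
    so every feature takes the same embedding dimensions from z1.
    """
    f, d = z1.shape
    k = int(np.floor(lam * d))
    s_h = np.zeros(d)
    s_h[rng.choice(d, size=k, replace=False)] = 1.0
    S_H = np.tile(s_h, (f, 1))                      # (f, d) stacked mask
    z_m = S_H * z1 + (1.0 - S_H) * z2
    y_m = lam * y1 + (1.0 - lam) * y2
    return z_m, y_m, S_H


# Usage: all-ones vs. all-zeros embeddings make the selected dimensions visible.
rng = np.random.default_rng(0)
f, d = 3, 8
z1, z2 = np.ones((f, d)), np.zeros((f, d))
y1, y2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])  # one-hot labels
z_m, y_m, S_H = hid_mix(z1, z2, y1, y2, lam=0.6, rng=rng)
```

With \(\lambda_H = 0.6\) and \(d = 8\), each row of `z_m` takes \(\floor{4.8} = 4\) dimensions from `z1`, and the same 4 dimensions are chosen for all three features.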

Intuition
Since each embedding element is projected from a scalar feature value, each embedding dimension can be viewed as a distinct "profile" of the input data. Thus, HID-mix regularizes the classifier to behave like a bagging predictor.

FEAT-mix

Instead of mixing embeddings, FEAT-mix mixes the raw features. Given two samples \(x_1, x_2 \in \mathbb{R}^{f}\) and their labels \(y_1, y_2\), a new sample is formed by taking some features from \(x_1\) and the rest from \(x_2\):

$$ \begin{gather} x_m = s_F \odot x_1 + (\mathbb{1}_F-s_F) \odot x_2 \\ y_m = \Lambda y_1 + (1-\Lambda) y_2 \end{gather} $$

where \(s_F \in \{0,1\}^{f}\) is a binary mask vector with \(\sum s_F = \floor{\lambda_F \cdot f}\), \(\mathbb{1}_F\) is an \(f\)-dimensional vector of \(1\)s, and \(\Lambda\) is a scalar label-mixing weight.

If we set \(\Lambda = \lambda_F\), this is equivalent to CutMix [1] applied to feature vectors.

What distinguishes FEAT-mix from CutMix is that the label weight accounts for feature importance:

$$ \Lambda = \cfrac{\sum_{i:\, s_F^{(i)} = 1} I(\mathbf{f}_i)}{\sum_{i=1}^{f} I(\mathbf{f}_i)} $$

where \(s_F^{(i)}\) is the \(i\)-th element of \(s_F\) (so the numerator sums the importances of the features taken from \(x_1\)), and \(I(\mathbf{f}_i)\) is the importance of the \(i\)-th feature. As in the SPA module, the authors appear to measure importance with mutual information.
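A minimal NumPy sketch of FEAT-mix with the importance-weighted label. It assumes the per-feature importance scores are given and that \(\lambda_F\) is passed in; `feat_mix` and `feat_mix_label_weight` are hypothetical names. With uniform importances, \(\Lambda\) reduces to the fraction of features taken from \(x_1\), which is the CutMix-style weight.

```python
import numpy as np


def feat_mix_label_weight(s_f, importance):
    """Lambda: share of total feature importance taken from x1 (s_F == 1)."""
    imp = np.asarray(importance, dtype=float)
    return imp[np.asarray(s_f) == 1].sum() / imp.sum()


def feat_mix(x1, x2, y1, y2, lam, importance, rng):
    """FEAT-mix of two raw feature vectors x1, x2 in R^f with labels y1, y2."""
    f = x1.shape[0]
    k = int(np.floor(lam * f))
    s_f = np.zeros(f)
    s_f[rng.choice(f, size=k, replace=False)] = 1.0  # floor(lam * f) ones
    x_m = s_f * x1 + (1.0 - s_f) * x2
    big_lam = feat_mix_label_weight(s_f, importance)
    y_m = big_lam * y1 + (1.0 - big_lam) * y2
    return x_m, y_m, s_f, big_lam


# Worked example with hypothetical importances [0.5, 0.3, 0.1, 0.1] and
# s_F = [1, 0, 0, 1]: Lambda = (0.5 + 0.1) / 1.0 = 0.6, even though x1
# supplies only half of the features (lambda_F = 0.5).
importance = np.array([0.5, 0.3, 0.1, 0.1])
lam_example = feat_mix_label_weight([1, 0, 0, 1], importance)

# Random mixing of an all-ones and an all-zeros sample.
rng = np.random.default_rng(0)
x1, x2 = np.ones(4), np.zeros(4)
y1, y2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
x_m, y_m, s_f, big_lam = feat_mix(x1, x2, y1, y2, 0.5, importance, rng)
```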

Intuition
Since features contribute unequally to the label, weighting the two labels by how much importance each sample contributes to the mixed input filters out the influence of uninformative features on the mixed label.

Attentive FFNs

Finally, the authors replace the two-layer FFN at the end of each transformer block with a Gated Linear Unit (GLU) style module:

$$ z' = \text{tanh}(\text{Linear}_1(z)) \odot \text{Linear}_2(z) $$

where \(\odot\) is element-wise multiplication and the first term acts as the gate. Unlike the original GLU, which gates with a sigmoid, the gate here is a \(\tanh\).
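A minimal NumPy sketch of this gated FFN, assuming both linear layers map \(d \to d\) so the module is a drop-in replacement for the standard FFN; residual connections and normalization are omitted because these notes do not describe them. `attentive_ffn` is a hypothetical name.

```python
import numpy as np


def attentive_ffn(z, w1, b1, w2, b2):
    """GLU-style FFN: tanh(Linear_1(z)) * Linear_2(z), applied row-wise.

    z: (f, d) embeddings; w1, w2: (d, d) weights; b1, b2: (d,) biases.
    """
    gate = np.tanh(z @ w1 + b1)   # gate values in [-1, 1]
    value = z @ w2 + b2
    return gate * value


rng = np.random.default_rng(0)
f, d = 4, 8
z = rng.normal(size=(f, d))
w1, w2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
b1, b2 = np.zeros(d), np.zeros(d)
out = attentive_ffn(z, w1, b1, w2, b2)
```

A gate near 0 suppresses the corresponding entry of \(\text{Linear}_2(z)\), and a saturated gate near 1 passes it through unchanged.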

In addition, the authors replace the linear feature embedding \(z_i = \mathbf{f}_i W_{i,1} + b_{i,1}\) with a similar GLU setup, \(z_i = \tanh(\mathbf{f}_i W_{i,1} + b_{i,1}) \odot (\mathbf{f}_i W_{i,2} + b_{i,2})\).

However, the paper does not clearly motivate this change.
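The embedding change can be sketched in NumPy as follows, assuming the gated form with both bias terms inside their own branches (matching \(\tanh(\text{Linear}_1) \odot \text{Linear}_2\) in the FFN). Each scalar feature \(\mathbf{f}_i\) gets its own weight rows \(W_{i,1}, W_{i,2} \in \mathbb{R}^{d}\); the function names and the batched \((n, f)\) input layout are hypothetical.

```python
import numpy as np


def linear_embedding(x, W1, B1):
    """Vanilla per-feature embedding: z_i = f_i * W_{i,1} + b_{i,1}."""
    x = np.asarray(x, dtype=float)[..., None]          # (n, f, 1)
    return x * W1 + B1                                  # (n, f, d)


def glu_embedding(x, W1, B1, W2, B2):
    """Gated embedding: tanh(f_i W_{i,1} + b_{i,1}) * (f_i W_{i,2} + b_{i,2}).

    x: (n, f) raw features; W1, B1, W2, B2: (f, d), where row i holds the
    parameters of feature i, so each feature has its own projection.
    """
    x = np.asarray(x, dtype=float)[..., None]          # (n, f, 1)
    return np.tanh(x * W1 + B1) * (x * W2 + B2)        # (n, f, d)


rng = np.random.default_rng(0)
n, f, d = 2, 3, 4
x = rng.normal(size=(n, f))
W1, B1, W2, B2 = (rng.normal(size=(f, d)) for _ in range(4))
z = glu_embedding(x, W1, B1, W2, B2)
```

Broadcasting the \((n, f, 1)\) feature tensor against the \((f, d)\) parameter matrices applies every feature's own projection in a single expression.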

Findings

Resources

  1. Yun, Sangdoo, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. "CutMix: Regularization strategy to train strong classifiers with localizable features." In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6023-6032. 2019.
