Week 12 - Multi-Task Learning

Multi-Task Learning

Multi-task learning is an approach that learns multiple related problems at the same time by exploiting the relatedness between them.

Problem Setup

Given $m$ learning tasks $\{\tau_i\}_{i=1}^m$, multi-task learning aims to improve the learning of a model for $\tau_i$ by using the knowledge contained in all or some of the $m$ tasks.

For task $\tau_i$, we have a training set $\mathcal{D}_i=\{(x_j^i, y_j^i)\}_{j=1}^{n_i}$.

Our task is to learn a hypothesis for each of $\{\tau_i\}_{i=1}^m$.

What to share?

  • Parameters
  • Features

MTL Model

Suppose we have:

Linear hypothesis function: $h(x)=w^Tx$.

Different tasks: $\{\tau_i\}_{i=1}^m$

The hypothesis for the $i^{th}$ task: $w^i$

The empirical risk minimization problem is \(\underset{W=[w^1,\dots,w^m]}{\min}\frac{1}{m} \sum^m_{i=1} \frac{1}{n_i} \sum^{n_i}_{j=1} l(x_j^i, y_j^i, w^i)\), but this treats every task independently; no relatedness between tasks is considered.
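
As a concrete baseline, here is a minimal sketch of this independent-task objective, assuming a squared loss $l(x, y, w) = (w^Tx - y)^2$; the task data is synthetic and each $w^i$ is fit by ordinary least squares.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 3, 5                                   # number of tasks, feature dimension
tasks = [(rng.normal(size=(20, d)),           # X_i: n_i x d design matrix
          rng.normal(size=20))                # y_i: n_i labels
         for _ in range(m)]

# Fit each task separately: w^i = argmin_w (1/n_i) sum_j (w^T x_j^i - y_j^i)^2.
W = np.stack([np.linalg.lstsq(X, y, rcond=None)[0] for X, y in tasks], axis=1)

# Empirical risk (1/m) sum_i (1/n_i) sum_j l(x_j^i, y_j^i, w^i).
risk = np.mean([np.mean((X @ W[:, i] - y) ** 2)
                for i, (X, y) in enumerate(tasks)])
print(f"independent-task empirical risk: {risk:.4f}")
```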

Parameter-based MTL Models

Assume the multiple tasks are related through their parameters: \(w^i = w_0+\Delta w^i\), where $w_0$ is the shared parameter vector and $\Delta w^i$ is the task-specific part.

Therefore, we have the new objective function: \(\underset{w_0,\Delta W=[\Delta w^1,\dots,\Delta w^m]}{\min}\frac{1}{m} \sum^m_{i=1} \frac{1}{n_i} \sum^{n_i}_{j=1} l(x_j^i, y_j^i, w_0+\Delta w^i)\)

To enforce stronger relatedness, we want $w_0$ to carry most of the signal and the task-specific offsets $\Delta w^i$ to stay small. We can add a regularizer to the objective function: \(\underset{w_0,\Delta W=[\Delta w^1,\dots,\Delta w^m]}{\min}\frac{1}{m} \sum^m_{i=1} \frac{1}{n_i} \sum^{n_i}_{j=1} l(x_j^i, y_j^i, w_0+\Delta w^i)+\lambda \|\Delta W\|^2_F\)
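
A minimal sketch of this regularized objective, again assuming a squared loss and synthetic data; `lam` is $\lambda$, `dW[:, i]` is $\Delta w^i$, and plain gradient descent stands in for whatever solver one would use in practice.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d, lam, lr = 3, 5, 0.1, 0.05
tasks = [(rng.normal(size=(20, d)), rng.normal(size=20)) for _ in range(m)]

w0, dW = np.zeros(d), np.zeros((d, m))        # shared and task-specific parts
for _ in range(500):                          # plain gradient descent
    g0, gD = np.zeros(d), np.zeros((d, m))
    for i, (X, y) in enumerate(tasks):
        r = X @ (w0 + dW[:, i]) - y           # residuals for task i
        g = (2 / (m * len(y))) * (X.T @ r)    # grad of the averaged loss term
        g0 += g                               # w0 collects gradients from all tasks
        gD[:, i] = g + 2 * lam * dW[:, i]     # plus gradient of lambda * ||dW||_F^2
    w0 -= lr * g0
    dW -= lr * gD
print("||w0|| =", np.linalg.norm(w0), " ||dW||_F =", np.linalg.norm(dW))
```

Increasing `lam` shrinks $\|\Delta W\|_F$ toward zero, pushing every task toward the shared $w_0$ and thus encoding stronger relatedness.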

Feature-based MTL Models

We want to project the features into a new feature space in which the tasks are more related: \(\mathcal{D}_i=\{(P^Tx_j^i, y_j^i)\}_{j=1}^{n_i}\), where $P$ is a projection matrix.

Then the new objective function is \(\underset{W,P}{\min}\frac{1}{m} \sum^m_{i=1} \frac{1}{n_i} \sum^{n_i}_{j=1} l(P^T x_j^i, y_j^i, w^i)+\lambda \operatorname{rank}(W) \quad \text{s.t. } PP^T=I\) The rank penalty encourages the task hypotheses $w^i$ to lie in a shared low-dimensional subspace, which is how this formulation couples the tasks.
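
The rank penalty is non-convex, so a full solver is beyond a quick sketch; the snippet below only illustrates the setup, fixing $P$ to a random orthogonal matrix (via QR, so $PP^T=I$) and fitting each $w^i$ by least squares on the projected features. In practice $\operatorname{rank}(W)$ is commonly relaxed (e.g., to the trace norm) and $P$ and $W$ are optimized alternately.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 3, 5
tasks = [(rng.normal(size=(20, d)), rng.normal(size=20)) for _ in range(m)]

# Random square orthogonal P, so P P^T = I as the constraint requires.
P, _ = np.linalg.qr(rng.normal(size=(d, d)))

# Fit each task on projected inputs: rows of X @ P are (P^T x_j^i)^T.
W = np.stack([np.linalg.lstsq(X @ P, y, rcond=None)[0] for X, y in tasks],
             axis=1)
print("rank(W) =", np.linalg.matrix_rank(W), " W shape:", W.shape)
```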