Q6 – Mixture Density Network

Contributed by Chin-Wei Huang.


Consider a univariate-output Mixture Density Network with $n$ mixture components (the latent variable $c$ takes $n$ values): $p(y|x) = \sum_{i=1}^n p(c=i|x)\,p(y|x;\theta_i)$, where $\theta=\{\theta_i\}_{i=1:n}$ is the set of class-conditional parameters.
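To make the mixture concrete, here is a minimal NumPy sketch that evaluates $p(y|x)$ for a single input once the mixing probabilities $p(c=i|x)$ and the component parameters at that $x$ are known. It assumes Gaussian components (the definition above leaves $p(y|x;\theta_i)$ unspecified), and the function names and numbers are hypothetical.

```python
import numpy as np

def gaussian_pdf(y, mean, var):
    """Univariate normal density N(y; mean, var), vectorised over components."""
    return np.exp(-0.5 * (y - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

def mixture_density(y, weights, means, variances):
    """p(y|x) = sum_i p(c=i|x) p(y|x; theta_i) for one fixed input x.

    weights   -- shape (n,), mixing probabilities p(c=i|x), summing to 1
    means     -- shape (n,), component means evaluated at this x
    variances -- shape (n,), component variances evaluated at this x
    Gaussian components are an assumption made for this sketch.
    """
    return float(np.sum(weights * gaussian_pdf(y, means, variances)))

# Hypothetical values for a single input x with n = 2 components.
w = np.array([0.3, 0.7])
mu = np.array([-1.0, 2.0])
var = np.array([0.5, 1.0])
print(mixture_density(0.0, w, mu, var))
```

Because the weights sum to one and each component is a density, the mixture integrates to one over $y$ for every fixed $x$.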

In the following questions, assume the “prior” probabilities $p(c=i|x)$ form a multinoulli (categorical) distribution, parameterized by a softmax function of the input (a single-layer network), i.e. $p(c=i|x;\zeta) = s_i(x;\zeta) = \frac{\exp(\zeta_i^T x)}{\sum_{i'}\exp(\zeta_{i'}^T x)}$.
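A minimal NumPy sketch of this softmax prior follows. It assumes the $\zeta_i$ are stored as the rows of an $(n, p)$ array, where $p$ is the input dimension; this layout and the name `softmax_gate` are illustrative choices, not part of the problem statement.

```python
import numpy as np

def softmax_gate(x, zeta):
    """s_i(x; zeta) = exp(zeta_i^T x) / sum_i' exp(zeta_i'^T x).

    x    -- input vector, shape (p,)
    zeta -- gating weights, row i is zeta_i, shape (n, p)
    Returns the n prior probabilities p(c=i|x; zeta).
    """
    logits = zeta @ x
    logits = logits - logits.max()  # a common shift leaves the softmax unchanged but avoids overflow
    e = np.exp(logits)
    return e / e.sum()

# Hypothetical sizes: p = 3 input features, n = 4 components.
rng = np.random.default_rng(0)
x = rng.normal(size=3)
zeta = rng.normal(size=(4, 3))
s = softmax_gate(x, zeta)
print(s, s.sum())
```

Adding the same vector to every $\zeta_i$ also leaves every $s_i$ unchanged, so $\zeta$ is only identified up to a common shift; fixing one $\zeta_i$ (for example to zero) is a common way to remove this redundancy.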

  1. Suppose $Y$ is continuous, and let $p(y|x;\theta_i) = \mathcal{N}(y; x^T\beta_i, \sigma^2_i)$. For prediction, use the conditional expectation as a point estimate of the output. Derive $\mathbf{E}[y|x]$ and $\mathbf{Var}[y|x]$.
  2. Holding the class-conditional parameters $\theta_i$ fixed, derive a stochastic (i.e. single-data-point) gradient ascent update for the softmax weight parameters $\zeta_i$ using the maximum likelihood principle. (Hint: M-step of the EM algorithm.)
  3. Now devise a prediction function $h:\mathcal{X}\rightarrow\mathcal{Y}$ defined as $h(x)=\sum_{i=1}^n \sigma_i(x)\mu_i(x)$, where in general $\sigma(\cdot)$ is an MLP and $\mu(\cdot)$ is a prediction function of the input $x$. Now let $\sigma(x)$ be a softmax regression over $n$ classes and $\mu$ a set of $n$ linear maps, i.e. $h(x) = \sum_{i=1}^n s_i(x;\zeta) \sum_{j=1}^p x_j\beta_{ij}$, where $p$ is the input dimension. If we want to minimise the quadratic loss $l(y_t,x_t) = (y_t - h(x_t))^2$ for each data point $t$, what is the gradient descent update for the parameter $\zeta_i$ when $\beta$ is fixed?
  4. Comment on the difference between the training objectives in parts 2 and 3.
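The derivations asked for in parts 1 to 3 can be checked numerically. The NumPy sketch below is one possible worked answer under stated assumptions, not an official solution. It assumes the Gaussian linear experts of part 1, so $\mu_i(x) = x^T\beta_i$ with $\beta_i$ stored as row $i$ of an $(n, p)$ array; $\zeta$ is stored the same way; and the helper names `moments`, `loglik_grad_zeta` and `sq_loss_grad_zeta` are hypothetical. It encodes $\mathbf{E}[y|x] = \sum_i s_i x^T\beta_i$ and $\mathbf{Var}[y|x] = \sum_i s_i(\sigma_i^2 + (x^T\beta_i)^2) - \mathbf{E}[y|x]^2$; the likelihood gradient $\partial \log p(y|x)/\partial\zeta_i = (r_i - s_i)\,x$ with responsibilities $r_i = s_i\mathcal{N}(y;x^T\beta_i,\sigma_i^2)/\sum_j s_j\mathcal{N}(y;x^T\beta_j,\sigma_j^2)$, which equals the gradient of the M-step objective $\sum_j r_j \log s_j(x;\zeta)$ with $r$ held fixed; and the squared-loss gradient $\partial l/\partial\zeta_i = -2(y - h(x))\,s_i\,(x^T\beta_i - h(x))\,x$.

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())  # shift by the max for numerical stability
    return e / e.sum()

def normal_pdf(y, mean, var):
    return np.exp(-0.5 * (y - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

def moments(x, zeta, beta, var):
    """Part 1: E[y|x] and Var[y|x] with experts N(y; x^T beta_i, var_i)."""
    s = softmax(zeta @ x)  # s_i(x; zeta), shape (n,)
    mu = beta @ x          # x^T beta_i, shape (n,)
    mean = s @ mu
    # Law of total variance: E[Var(y|x,c)] + Var(E[y|x,c]) = sum_i s_i (var_i + mu_i^2) - mean^2
    variance = s @ (var + mu ** 2) - mean ** 2
    return mean, variance

def loglik_grad_zeta(x, y, zeta, beta, var):
    """Part 2: d log p(y|x) / d zeta_i for every i (row i), beta and var held fixed."""
    s = softmax(zeta @ x)
    joint = s * normal_pdf(y, beta @ x, var)  # s_i N(y; x^T beta_i, var_i)
    r = joint / joint.sum()                   # responsibilities p(c=i | x, y), as in the E-step
    return np.outer(r - s, x)                 # row i: (r_i - s_i) x

def sq_loss_grad_zeta(x, y, zeta, beta):
    """Part 3: d (y - h(x))^2 / d zeta_i for every i (row i), beta held fixed."""
    s = softmax(zeta @ x)
    mu = beta @ x
    h = s @ mu                                # h(x) = sum_i s_i x^T beta_i
    return np.outer(-2.0 * (y - h) * s * (mu - h), x)  # row i: -2 (y - h) s_i (mu_i - h) x

# Hypothetical problem sizes and values, for illustration only.
rng = np.random.default_rng(1)
n, p = 3, 4
x, y = rng.normal(size=p), 0.5
zeta, beta = rng.normal(size=(n, p)), rng.normal(size=(n, p))
var = np.array([0.5, 1.0, 2.0])
eta = 0.1  # hypothetical step size

print(moments(x, zeta, beta, var))
zeta_ml = zeta + eta * loglik_grad_zeta(x, y, zeta, beta, var)  # ascent on the log-likelihood
zeta_sq = zeta - eta * sq_loss_grad_zeta(x, y, zeta, beta)      # descent on the squared loss
```

In this sketch the likelihood gradient moves each $s_i$ toward the posterior $r_i$, whereas the squared-loss gradient weights each component by how far its mean $x^T\beta_i$ lies from the current prediction $h(x)$, which coincides with $\mathbf{E}[y|x]$ from part 1. When $y$ is far from every component mean, the responsibilities are safer to compute in log space to avoid underflow.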