# Mixture Density Networks

## So far: RNNs that Model Categorical Data



## Expressive Data is Often Continuous



## So are Bio-Signals


Image Credit: Wikimedia


## Categorical vs. Continuous Models



## Normal (Gaussian) Distribution



## Problem: Normal distribution might not fit data

What if the data is complicated? A single normal distribution has only one peak, so it can't fit data that clusters around several different values.


## Mixture of Normals

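Three groups of parameters:

- means, $\boldsymbol\mu$: the centre of each normal component;
- standard deviations, $\boldsymbol\sigma$: the spread of each component;
- weights, $\boldsymbol\pi$: how much each component contributes (non-negative and summing to 1).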
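To make the three parameter groups concrete, here is a minimal NumPy sketch (not from the original slides) that evaluates a 1-D mixture-of-normals density, $p(y) = \sum_{i=1}^K \pi_i \mathcal{N}(\mu_i, \sigma_i^2; y)$. The function names `normal_pdf` and `mixture_pdf` and the example parameter values are hypothetical.

```python
import numpy as np


def normal_pdf(y, mu, sigma):
    """Density of a 1-D normal distribution N(mu, sigma^2) evaluated at y."""
    return np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


def mixture_pdf(y, mus, sigmas, pis):
    """Density of a mixture of normals: sum_i pi_i * N(y; mu_i, sigma_i^2).

    mus, sigmas and pis are length-K sequences; pis must be non-negative and sum to 1.
    """
    y = np.asarray(y, dtype=float)[..., np.newaxis]  # broadcast y against the K components
    components = normal_pdf(y, np.asarray(mus), np.asarray(sigmas))
    return np.sum(np.asarray(pis) * components, axis=-1)


# Hypothetical example: two evenly weighted components.
ys = np.linspace(-6.0, 6.0, 1201)
density = mixture_pdf(ys, mus=[-2.0, 2.0], sigmas=[0.5, 1.0], pis=[0.5, 0.5])
# A simple Riemann sum of the density should be close to 1, as for any PDF.
print(np.sum(density) * (ys[1] - ys[0]))
```

Because the weights sum to 1 and each component is itself a density, the mixture is a valid density too, but it can have as many peaks as it has components.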


## This solves our problem:

Returning to our modelling problem, let’s plot the PDF of an evenly weighted mixture of the two sample normal models.

We set:

In this case, I knew the right parameters, but normally you would have to estimate, or learn, them somehow…
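The claim that a mixture fixes the problem can also be checked numerically. The sketch below uses assumed data (two hypothetical clusters centred at -2 and 2), not the data plotted on this slide. It compares the average log-likelihood of the best single normal (fitted by its sample mean and standard deviation) with that of an evenly weighted two-component mixture whose parameters are known.

```python
import numpy as np


def normal_logpdf(y, mu, sigma):
    """Log-density of N(mu, sigma^2) at y."""
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


rng = np.random.default_rng(0)
# Hypothetical bimodal data: half the points near -2, half near +2.
data = np.concatenate([rng.normal(-2.0, 0.5, 500), rng.normal(2.0, 0.5, 500)])

# Best single normal: maximum-likelihood mean and standard deviation.
single_ll = normal_logpdf(data, data.mean(), data.std()).mean()

# Evenly weighted mixture with the (known) true parameters.
# np.logaddexp computes log(exp(a) + exp(b)) without underflow.
mixture_ll = np.logaddexp(
    np.log(0.5) + normal_logpdf(data, -2.0, 0.5),
    np.log(0.5) + normal_logpdf(data, 2.0, 0.5),
).mean()

print(f"single normal: {single_ll:.2f}, mixture: {mixture_ll:.2f}")
```

The single normal has to stretch its one peak across both clusters, so it assigns low density everywhere the data actually lies; the mixture puts a peak on each cluster and scores noticeably higher.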


## Mixture Density Networks




## Simple Example in Keras

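The slides don't show how `x_data` and `y_data` were generated. The sketch below is a hypothetical stand-in for an "inverse sine" style problem, an assumption suggested only by the `arcsine` image filenames: sample `y`, compute `x = sin(y)` plus a little noise, then try to predict `y` from `x`. Because sine repeats, most `x` values correspond to several valid `y` values.

```python
import numpy as np

N_SAMPLES = 3000  # hypothetical dataset size

rng = np.random.default_rng(42)
# Sample targets over two full periods so that each x has several valid y values.
y_data = rng.uniform(-2 * np.pi, 2 * np.pi, size=(N_SAMPLES, 1)).astype("float32")
# Inputs are a noisy sine of the targets; predicting y from x is one-to-many.
x_data = (np.sin(y_data) + rng.normal(0.0, 0.05, size=(N_SAMPLES, 1))).astype("float32")
```

Both arrays have shape `(N_SAMPLES, 1)`, matching the single input and single output of the networks on the following slides.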


## Feedforward MSE Network


Here’s a simple two-hidden-layer network (286 parameters), trained to produce the above result.

~~~~~{.python .numberLines}
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# x_data, y_data: training arrays of shape (n_samples, 1)
model = Sequential()
model.add(Dense(15, input_shape=(1,), activation='tanh'))
model.add(Dense(15, activation='tanh'))
model.add(Dense(1, activation='linear'))
model.compile(loss='mse', optimizer='rmsprop')
model.fit(x=x_data, y=y_data, batch_size=128, epochs=200, validation_split=0.15)
~~~~~
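The parameter count quoted above follows from the fact that a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. A quick check in plain Python (the helper `dense_params` is hypothetical, not a Keras function; Keras' own `model.count_params()` should report the same total):

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# 1 input -> 15 tanh units -> 15 tanh units -> 1 linear output
layer_sizes = [1, 15, 15, 1]
total = sum(dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))
print(total)  # 286 (30 + 240 + 16)
```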



---


## MDN Architecture:

![](/assets/mdn/mdn-network.png)

- The loss function for an MDN is the negative log of the likelihood function $\mathcal{L}$.
- $\mathcal{L}$ measures the likelihood of the target $\mathbf{t}$ being drawn from a mixture parametrised by $\mu$, $\sigma$, and $\pi$, which the network generates from its inputs $\mathbf{x}$:

$$\mathcal{L} = \sum_{i=1}^K\pi_i(\mathbf{x})\mathcal{N}\bigl(\mu_i(\mathbf{x}), \sigma_i^2(\mathbf{x}); \mathbf{t} \bigr)$$
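Written out over a batch of $N$ training pairs $(\mathbf{x}_n, \mathbf{t}_n)$, the quantity to minimise is the average negative log-likelihood, which is what the loss code on the later slides computes with `log_prob`, `tf.negative` and `tf.reduce_mean`:

$$E = -\frac{1}{N}\sum_{n=1}^{N}\log\sum_{i=1}^{K}\pi_i(\mathbf{x}_n)\,\mathcal{N}\bigl(\mu_i(\mathbf{x}_n), \sigma_i^2(\mathbf{x}_n); \mathbf{t}_n\bigr)$$

For this to be a valid likelihood, the weights must satisfy $\pi_i(\mathbf{x}) \ge 0$ and $\sum_i \pi_i(\mathbf{x}) = 1$, and every $\sigma_i(\mathbf{x}) > 0$.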


---


## Feedforward MDN Solution

![](/assets/mdn/arcsine-feedforward-mdn-predictions.png){ width=30% }
![](/assets/mdn/arcsine-feedforward-mdn-loss.png){ width=40% }

And here's a simple two-hidden-layer MDN (510 parameters) that achieves the above result. Much better!

~~~~~{.python .numberLines}
import mdn  # the keras-mdn-layer package
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

N_MIXES = 5

model = Sequential()
model.add(Dense(15, input_shape=(1,), activation='relu'))
model.add(Dense(15, activation='relu'))
model.add(mdn.MDN(1, N_MIXES))  # here's the MDN layer!
model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer='rmsprop')
model.summary()
~~~~~
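The 510 parameters can be accounted for the same way, assuming (as the functional version on the next slide spells out) that the `MDN` layer for a 1-D output contains three `Dense` heads of `N_MIXES` units each, for the means, scales and mixture weights. The helper `dense_params` is again hypothetical:

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


N_HIDDEN, N_MIXES, OUTPUT_DIM = 15, 5, 1

hidden = dense_params(1, N_HIDDEN) + dense_params(N_HIDDEN, N_HIDDEN)  # 30 + 240
# Three heads: means and scales (N_MIXES * OUTPUT_DIM each) plus mixture logits (N_MIXES).
mdn_heads = (
    dense_params(N_HIDDEN, N_MIXES * OUTPUT_DIM)  # means
    + dense_params(N_HIDDEN, N_MIXES * OUTPUT_DIM)  # scales
    + dense_params(N_HIDDEN, N_MIXES)  # mixture-weight logits
)
print(hidden + mdn_heads)  # 510 (270 + 3 * 80)
```

Compared with the MSE network, the extra parameters all sit in the output heads: the hidden layers are the same size, but the network now predicts a whole distribution instead of a single number.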

## Getting inside the MDN layer

Here’s the same network without the MDN layer abstraction, written with Keras’ functional API:

~~~~~{.python .numberLines}
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model

def elu_plus_one_plus_epsilon(x):
    """ELU activation with a very small addition to help prevent NaN in loss."""
    return tf.nn.elu(x) + 1 + 1e-8

N_HIDDEN = 15
N_MIXES = 5

inputs = Input(shape=(1,), name='inputs')
hidden1 = Dense(N_HIDDEN, activation='relu', name='hidden1')(inputs)
hidden2 = Dense(N_HIDDEN, activation='relu', name='hidden2')(hidden1)

# Means: unconstrained linear outputs
mdn_mus = Dense(N_MIXES, name='mdn_mus')(hidden2)
# Scales: must be strictly positive
mdn_sigmas = Dense(N_MIXES, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas')(hidden2)
# Mixture-weight logits: normalised later, inside the loss
mdn_pi = Dense(N_MIXES, name='mdn_pi')(hidden2)

mdn_out = Concatenate(name='mdn_outputs')([mdn_mus, mdn_sigmas, mdn_pi])

model = Model(inputs=inputs, outputs=mdn_out)
model.summary()
~~~~~
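Why `elu_plus_one_plus_epsilon`? A standard deviation must be strictly positive. ELU (with its default alpha of 1) returns `x` for `x > 0` and `exp(x) - 1` otherwise, so it is always greater than -1; adding 1 makes the output positive, and the extra `1e-8` keeps it away from exactly zero when `exp(x)` underflows for very negative inputs. A NumPy re-implementation (hypothetical, for illustration only) makes this easy to check:

```python
import numpy as np


def elu(x):
    """ELU with alpha = 1: x for x > 0, exp(x) - 1 otherwise (so always > -1)."""
    # np.minimum keeps expm1 from overflowing on large positive inputs.
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))


def elu_plus_one_plus_epsilon(x):
    """Maps any real input to a strictly positive scale parameter."""
    return elu(x) + 1.0 + 1e-8


xs = np.array([-1000.0, -5.0, 0.0, 3.0])
print(elu_plus_one_plus_epsilon(xs))  # every entry is > 0
```

Unlike `exp`, this activation grows only linearly for positive inputs, which keeps large scale predictions from exploding.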



---


## Loss Function: The Tricky Bit.

The loss function for the MDN should be the negative log-likelihood:

~~~~~{.python .numberLines}
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def mdn_loss(y_true, y_pred):
    # Split the inputs into parameters
    out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[N_MIXES, N_MIXES, N_MIXES],
                                         axis=-1, name='mdn_coef_split')
    mus = tf.split(out_mu, num_or_size_splits=N_MIXES, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=N_MIXES, axis=1)
    # Construct the mixture models
    cat = tfd.Categorical(logits=out_pi)
    coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
            in zip(mus, sigs)]
    mixture = tfd.Mixture(cat=cat, components=coll)
    # Calculate the loss function
    loss = mixture.log_prob(y_true)
    loss = tf.negative(loss)
    loss = tf.reduce_mean(loss)
    return loss

model.compile(loss=mdn_loss, optimizer='rmsprop')
~~~~~

Let’s go through it bit by bit…


## Loss Function: Part 1:

First we have to extract the mixture parameters.

~~~~~{.python .numberLines}
# Split the inputs into parameters
out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[N_MIXES, N_MIXES, N_MIXES],
                                     axis=-1, name='mdn_coef_split')
mus = tf.split(out_mu, num_or_size_splits=N_MIXES, axis=1)
sigs = tf.split(out_sigma, num_or_size_splits=N_MIXES, axis=1)
~~~~~


- Split up the parameters $\boldsymbol\mu$, $\boldsymbol\sigma$, and $\boldsymbol\pi$; remember that there are `N_MIXES` $= K$ of each of these.
- $\boldsymbol\mu$ and $\boldsymbol\sigma$ have to be split _again_ so that we get a Python list with one tensor per mixture component to iterate over (you can't loop over the columns of a tensor directly).
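The shapes involved can be checked without TensorFlow. This sketch uses NumPy's `np.split` as a stand-in for `tf.split` (an analogy, not the code the loss actually runs), assuming a batch of 4 and `N_MIXES = 5`; each row of the network output packs all means, then all scales, then all mixture logits:

```python
import numpy as np

BATCH, N_MIXES = 4, 5  # hypothetical sizes

# Stand-in for y_pred: [mu_1..mu_K, sigma_1..sigma_K, logit_1..logit_K] per row.
y_pred = np.arange(BATCH * 3 * N_MIXES, dtype=float).reshape(BATCH, 3 * N_MIXES)

# First split: three (BATCH, N_MIXES) blocks.
# Note: np.split takes split *indices*, whereas tf.split takes section *sizes*.
out_mu, out_sigma, out_pi = np.split(y_pred, [N_MIXES, 2 * N_MIXES], axis=-1)

# Second split: a Python list of N_MIXES arrays, each (BATCH, 1), one per component.
mus = np.split(out_mu, N_MIXES, axis=1)
sigs = np.split(out_sigma, N_MIXES, axis=1)

print(out_mu.shape, len(mus), mus[0].shape)  # (4, 5) 5 (4, 1)
```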


---


## Loss Function: Part 2:

Now we have to construct the mixture model's PDF. 

~~~~~{.python .numberLines}
# Construct the mixture models
cat = tfd.Categorical(logits=out_pi)
coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
        in zip(mus, sigs)]
mixture = tfd.Mixture(cat=cat, components=coll)
~~~~~
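`tfd.Categorical(logits=out_pi)` treats the raw network outputs as unnormalised log-probabilities, so the mixture weights are their softmax, $\pi_i = \exp(z_i) / \sum_j \exp(z_j)$. This is what guarantees the weights are positive and sum to 1 without any constraint on the `mdn_pi` layer. A small NumPy equivalent, for illustration only:

```python
import numpy as np


def softmax(logits):
    """Convert unnormalised logits (last axis) into probabilities that sum to 1."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)  # avoid overflow in exp
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


out_pi = np.array([[0.0, 0.0, 0.0, 0.0, 0.0],
                   [2.0, 1.0, 0.0, -1.0, 1000.0]])
pis = softmax(out_pi)
print(pis.sum(axis=-1))  # [1. 1.]
```

Equal logits give equal weights, and one very large logit pushes almost all of the weight onto that component.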

## Loss Function: Part 3:

Finally, we calculate the loss:

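~~~~~{.python .numberLines}
loss = mixture.log_prob(y_true)  # log-likelihood of each target in the batch
loss = tf.negative(loss)         # negate: we minimise the negative log-likelihood
loss = tf.reduce_mean(loss)      # average over the batch
~~~~~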
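For intuition, here is roughly what `mixture.log_prob`, `tf.negative` and `tf.reduce_mean` compute together for a 1-D target, written in NumPy. It is a sketch, not TensorFlow Probability's implementation, and `mdn_nll` is a hypothetical name. Working in log space with a log-sum-exp avoids underflow when every component density is tiny:

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along an axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def mdn_nll(y_true, mus, sigmas, logits):
    """Mean negative log-likelihood of 1-D targets under per-example mixtures.

    y_true: (batch, 1); mus, sigmas, logits: (batch, K).
    """
    log_pis = logits - logsumexp(logits)[:, np.newaxis]  # log-softmax of the logits
    log_normals = (-0.5 * np.log(2.0 * np.pi) - np.log(sigmas)
                   - 0.5 * ((y_true - mus) / sigmas) ** 2)  # y_true broadcasts over K
    log_prob = logsumexp(log_pis + log_normals)  # log p(y_true) per example
    return np.mean(-log_prob)


# One component, standard normal, target at the mean: NLL = 0.5 * log(2 * pi).
print(mdn_nll(np.array([[0.0]]), np.array([[0.0]]), np.array([[1.0]]), np.array([[0.0]])))
```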


## Some more details…



## MDN-RNNs

MDNs can be handy at the end of an RNN! Imagine a robot planning its next move through space: it might have to choose from a number of valid positions, each of which could be modelled by a 2D normal distribution.
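Once an MDN has produced mixture parameters, a prediction is made by sampling rather than reading off a single value: first choose a component $i$ with probability $\pi_i$, then draw from that component's normal distribution. The NumPy sketch below does this for the robot example, with hypothetical parameters for three candidate 2D positions and diagonal covariances (each coordinate has its own $\sigma$); `sample_mixture_2d` is an illustrative name, not a library function.

```python
import numpy as np


def sample_mixture_2d(pis, mus, sigmas, n_samples, rng):
    """Draw samples from a mixture of 2-D normals with diagonal covariance.

    pis: (K,) weights summing to 1; mus, sigmas: (K, 2) means and per-axis std devs.
    """
    components = rng.choice(len(pis), size=n_samples, p=pis)  # step 1: pick a component
    noise = rng.standard_normal((n_samples, 2))
    return mus[components] + sigmas[components] * noise  # step 2: sample within it


rng = np.random.default_rng(1)
# Hypothetical MDN output: three candidate positions for the robot's next move.
pis = np.array([0.6, 0.3, 0.1])
mus = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
sigmas = np.full((3, 2), 0.05)
moves = sample_mixture_2d(pis, mus, sigmas, n_samples=1000, rng=rng)
print(moves.shape)  # (1000, 2)
```

Each sample lands near one valid position; averaging the candidates instead (as an MSE model effectively does) could produce a point between them that is not valid at all.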


## MDN-RNN Architecture

It can be as simple as putting an MDN layer after the recurrent layers!
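To show the data flow, here is a NumPy forward pass through a tiny recurrent layer followed by an MDN head, with random, untrained weights. The plain `tanh` RNN cell, the layer sizes and all names are assumptions for illustration; a real model would more likely use Keras recurrent layers such as `LSTM` followed by the MDN layer. At every time step the head turns the hidden state into $K$ means, $K$ positive scales and $K$ mixture weights:

```python
import numpy as np

rng = np.random.default_rng(0)
N_INPUT, N_HIDDEN, N_MIXES, OUTPUT_DIM = 2, 16, 3, 2  # hypothetical sizes

# Recurrent layer weights (a plain tanh RNN cell, standing in for an LSTM).
W_xh = rng.normal(0, 0.3, (N_INPUT, N_HIDDEN))
W_hh = rng.normal(0, 0.3, (N_HIDDEN, N_HIDDEN))
b_h = np.zeros(N_HIDDEN)

# MDN head: one linear map per parameter group.
W_mu = rng.normal(0, 0.3, (N_HIDDEN, N_MIXES * OUTPUT_DIM))
W_sigma = rng.normal(0, 0.3, (N_HIDDEN, N_MIXES * OUTPUT_DIM))
W_pi = rng.normal(0, 0.3, (N_HIDDEN, N_MIXES))


def mdn_rnn_forward(sequence):
    """Run the RNN over a (time, N_INPUT) sequence; return mixture params per step."""
    h = np.zeros(N_HIDDEN)
    outputs = []
    for x_t in sequence:
        h = np.tanh(x_t @ W_xh + h @ W_hh + b_h)  # recurrent state update
        mus = (h @ W_mu).reshape(N_MIXES, OUTPUT_DIM)
        z = (h @ W_sigma).reshape(N_MIXES, OUTPUT_DIM)
        sigmas = np.where(z > 0, z, np.expm1(np.minimum(z, 0.0))) + 1.0 + 1e-8  # ELU + 1 + eps
        logits = h @ W_pi
        pis = np.exp(logits - logits.max())
        pis /= pis.sum()  # softmax over the mixture components
        outputs.append((pis, mus, sigmas))
    return outputs


steps = mdn_rnn_forward(rng.normal(size=(10, N_INPUT)))
pis, mus, sigmas = steps[-1]
print(len(steps), pis.shape, mus.shape, sigmas.shape)  # 10 (3,) (3, 2) (3, 2)
```

The head is identical to the feedforward case; the only change is that its input is the recurrent state, so the predicted distribution can depend on everything seen so far.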


## Use Cases: Handwriting Generation



## Use Cases: SketchRNN



## Use Cases: RoboJam



## Use Cases: World Models



## References

  1. Christopher M. Bishop. 1994. Mixture Density Networks. Technical Report NCRG/94/004. Neural Computing Research Group, Aston University.
  2. Axel Brando. 2017. Mixture Density Networks (MDN) for distribution and uncertainty estimation. Master’s thesis. Universitat Politècnica de Catalunya.
  3. Alex Graves. 2013. Generating Sequences With Recurrent Neural Networks. ArXiv e-prints (Aug. 2013). ArXiv:1308.0850
  4. David Ha and Douglas Eck. 2017. A Neural Representation of Sketch Drawings. ArXiv e-prints (April 2017). ArXiv:1704.03477
  5. Charles P. Martin and Jim Torresen. 2018. RoboJam: A Musical Mixture Density Network for Collaborative Touchscreen Interaction. In Evolutionary and Biologically Inspired Music, Sound, Art and Design: EvoMUSART ’18, A. Liapis et al. (Ed.). Lecture Notes in Computer Science, Vol. 10783. Springer International Publishing. DOI:10.1007/978-3-319-77583-8_11
  6. David Ha and Jürgen Schmidhuber. 2018. Recurrent World Models Facilitate Policy Evolution. ArXiv e-prints (Sept. 2018). ArXiv:1809.01999