A course in ML/AI for creative expression
Charles Martin - The Australian National University
Image Credit: Wikimedia
The “Standard” probability distribution
Has two parameters: the mean \(\mu\) and the standard deviation \(\sigma\).
Probability Density Function:
\[\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2} } e^{ -\frac{(x-\mu)^2}{2\sigma^2} }\]
What if the data is complicated?
It’s easy to “fit” a normal model to any data.
Just calculate \(\mu\) and \(\sigma\)
(might not fit the data well)
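As a rough illustration of that caveat, the sketch below fits a single normal model to made-up two-cluster data. The cluster centres (-2 and 3) and spreads are assumptions chosen for the example, not values from these slides.

```python
import math
import random
import statistics

random.seed(0)

# Hypothetical bimodal data: two clusters, centred at -2 and 3 (assumed values).
data = [random.gauss(-2.0, 0.5) for _ in range(500)] + \
       [random.gauss(3.0, 0.5) for _ in range(500)]

# "Fitting" a normal model is just computing the mean and standard deviation.
mu = statistics.mean(data)
sigma = statistics.pstdev(data)

def normal_pdf(x, mu, sigma):
    """Density of N(x | mu, sigma^2)."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / math.sqrt(2 * math.pi * sigma ** 2)

# The fitted model puts its peak at mu, yet almost no data lies near mu.
near_mean = sum(abs(x - mu) < 0.5 for x in data) / len(data)
print("mu:", mu, "sigma:", sigma)
print("density at mu:", normal_pdf(mu, mu, sigma))
print("fraction of data within 0.5 of mu:", near_mean)
```

The fitted mean lands between the two clusters, where there is hardly any data, which is the sense in which the model "might not fit the data well".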
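A mixture of \(K\) normal models has three groups of parameters: the means \(\mu_i\), the standard deviations \(\sigma_i\), and the mixing weights \(\pi_i\) (which sum to 1).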
Probability Density Function:
\[p(x) = \sum_{i=1}^K \pi_i\mathcal{N}(x \mid \mu_i, \sigma_i^2)\]
Returning to our modelling problem, let’s plot the PDF of an evenly weighted mixture of the two sample normal models.
We set:
(bold used to indicate the vector of parameters for each component)
In this case, I knew the right parameters, but normally you would have to estimate, or learn, these somehow…
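A minimal sketch of evaluating a mixture PDF of this kind is given here. The parameter values are hypothetical stand-ins (the values used on the slide are not reproduced), and `mixture_pdf` is a helper name invented for this example.

```python
import math

def normal_pdf(x, mu, sigma):
    """Density of N(x | mu, sigma^2)."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / math.sqrt(2 * math.pi * sigma ** 2)

def mixture_pdf(x, pis, mus, sigmas):
    """p(x) = sum_i pi_i * N(x | mu_i, sigma_i^2)."""
    return sum(pi * normal_pdf(x, mu, sigma) for pi, mu, sigma in zip(pis, mus, sigmas))

# Hypothetical parameters for an evenly weighted two-component mixture.
pis = [0.5, 0.5]
mus = [-2.0, 3.0]
sigmas = [0.5, 0.5]

# Crude Riemann sum over [-10, 10] to check that the density integrates to about 1.
dx = 0.01
xs = [-10.0 + i * dx for i in range(2001)]
area = sum(mixture_pdf(x, pis, mus, sigmas) * dx for x in xs)
print("area under mixture PDF:", area)
```

Because the weights sum to 1 and each component is itself a valid density, the mixture is a valid density too; the numerical check only confirms this for the assumed parameters.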
Difficult data is not hard to find! Think about modelling an inverse sine (arcsine) function.
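To see why this is difficult, here is a small sketch that builds an assumed toy dataset by sampling \(y\), computing \(x = \sin(y)\) plus a little noise, and then treating \(x\) as the input and \(y\) as the target. The range, noise level, and sample count are illustrative choices, not the ones used in the course.

```python
import math
import random

random.seed(1)

# Assumed toy dataset: sample y uniformly, then compute x = sin(y) + small noise.
ys = [random.uniform(-2 * math.pi, 2 * math.pi) for _ in range(2000)]
xs = [math.sin(y) + random.gauss(0.0, 0.05) for y in ys]

# Inverse problem: given x, predict y. Collect all targets whose input is near zero.
near_zero = [y for x, y in zip(xs, ys) if abs(x) < 0.05]

# Label each target by the nearest multiple of pi to see how many separate
# "branches" of the inverse function share (almost) the same input.
branches = sorted({round(y / math.pi) for y in near_zero})
mean_y = sum(near_zero) / len(near_zero)
print("number of targets with x near 0:", len(near_zero))
print("branches (multiples of pi):", branches)
print("mean target for x near 0:", mean_y)
```

One input value maps to several well-separated targets, so a model that predicts a single mean (or a single normal distribution) for each input cannot represent the data; a mixture can place one component on each branch.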
The loss function for an MDN is the negative log of the likelihood function \(\mathcal{L}\).
\(\mathcal{L}\) measures the likelihood of \(\mathbf{t}\) being drawn from a mixture parametrised by \(\mu\), \(\sigma\), and \(\pi\), which the network generates from its inputs \(\mathbf{x}\):
\[\mathcal{L} = \sum_{i=1}^K\pi_i(\mathbf{x})\mathcal{N}\bigl(\mathbf{t} \mid \mu_i(\mathbf{x}), \sigma_i^2(\mathbf{x})\bigr)\]
Two-hidden-layer MDN (510 parameters)---code snippet:
N_MIXES = 5
model = Sequential()
model.add(Dense(15, batch_input_shape=(None, 1), activation='relu'))
model.add(Dense(15, activation='relu'))
model.add(mdn.MDN(1, N_MIXES)) # here's the MDN layer!
model.compile(loss=mdn.get_mixture_loss_func(1,N_MIXES), optimizer='rmsprop')
model.summary()
Two-hidden-layer MDN (510 parameters)---works much better!
def elu_plus_one_plus_epsilon(x):
    # ELU is greater than -1, so ELU + 1 is positive; the epsilon keeps sigma away from zero
    return (K.elu(x) + 1 + 1e-8)
N_HIDDEN = 15; N_MIXES = 5
inputs = Input(shape=(1,), name='inputs')
hidden1 = Dense(N_HIDDEN, activation='relu', name='hidden1')(inputs)
hidden2 = Dense(N_HIDDEN, activation='relu', name='hidden2')(hidden1)
mdn_mus = Dense(N_MIXES, name='mdn_mus')(hidden2)
mdn_sigmas = Dense(N_MIXES, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas')(hidden2)
mdn_pi = Dense(N_MIXES, name='mdn_pi')(hidden2)
mdn_out = Concatenate(name='mdn_outputs')([mdn_mus, mdn_sigmas, mdn_pi])
model = Model(inputs=inputs, outputs=mdn_out)
Loss function for the MDN should be the negative log likelihood:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def mdn_loss(y_true, y_pred):
    # Split the network outputs into parameters
    out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[N_MIXES, N_MIXES, N_MIXES],
                                         axis=-1, name='mdn_coef_split')
    mus = tf.split(out_mu, num_or_size_splits=N_MIXES, axis=1)
    sigs = tf.split(out_sigma, num_or_size_splits=N_MIXES, axis=1)
    # Construct the mixture models
    cat = tfd.Categorical(logits=out_pi)
    coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
            in zip(mus, sigs)]
    mixture = tfd.Mixture(cat=cat, components=coll)
    # Calculate the loss function
    loss = mixture.log_prob(y_true)
    loss = tf.negative(loss)
    loss = tf.reduce_mean(loss)
    return loss
model.compile(loss=mdn_loss, optimizer='rmsprop')
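As a sanity check on what this loss computes, the following sketch evaluates the negative log likelihood for a single one-dimensional target in plain Python. It assumes the same output layout (means, standard deviations, then unnormalised mixture logits); `mixture_nll` is a name made up for this example and is not part of TensorFlow or TensorFlow Probability.

```python
import math

def mixture_nll(mus, sigmas, pi_logits, t):
    """Negative log likelihood of target t under a 1D mixture of normals.

    pi_logits are unnormalised; a softmax turns them into mixing weights,
    which is how a categorical distribution built from logits behaves.
    """
    # Log of the softmax normaliser, computed stably (log-sum-exp).
    m = max(pi_logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in pi_logits))
    log_terms = []
    for mu, sigma, logit in zip(mus, sigmas, pi_logits):
        log_pi = logit - log_z
        log_normal = -0.5 * math.log(2 * math.pi * sigma ** 2) - (t - mu) ** 2 / (2 * sigma ** 2)
        log_terms.append(log_pi + log_normal)
    # log(sum_i pi_i * N(t | mu_i, sigma_i^2)), again via log-sum-exp.
    top = max(log_terms)
    log_likelihood = top + math.log(sum(math.exp(v - top) for v in log_terms))
    return -log_likelihood

# Hypothetical network output for one example with two components.
good = mixture_nll(mus=[-2.0, 3.0], sigmas=[0.5, 0.5], pi_logits=[0.0, 0.0], t=3.0)
bad = mixture_nll(mus=[-2.0, 3.0], sigmas=[0.5, 0.5], pi_logits=[0.0, 0.0], t=0.5)
print("NLL for a target on a component mean:", good)
print("NLL for a target between the components:", bad)
```

Targets that fall where the mixture has high density give a low loss, and targets in low-density regions give a high one; averaging this value over a batch corresponds to the final `tf.reduce_mean` step.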
Let’s go through bit by bit…
First we have to extract the mixture parameters.
# Split the network outputs into parameters
out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[N_MIXES, N_MIXES, N_MIXES],
                                     axis=-1, name='mdn_coef_split')
mus = tf.split(out_mu, num_or_size_splits=N_MIXES, axis=1)
sigs = tf.split(out_sigma, num_or_size_splits=N_MIXES, axis=1)
Now we have to construct the mixture model’s PDF.
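# Construct the mixture models
cat = tfd.Categorical(logits=out_pi)
coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
        in zip(mus, sigs)]
mixture = tfd.Mixture(cat=cat, components=coll)
For this we use the Mixture abstraction provided in tensorflow-probability.distributions (tfd). It takes a categorical distribution over the components and a list of component distributions. Each component here is a one-dimensional normal, built with tfd.MultivariateNormalDiag as in the full loss function above (tfd.Normal is the univariate equivalent).
Finally, we calculate the loss: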
loss = mixture.log_prob(y_true)
loss = tf.negative(loss)
loss = tf.reduce_mean(loss)
mixture.log_prob(y_true) means “the log-likelihood of sampling y_true from the distribution called mixture.”
MDNs can be handy at the end of an RNN! Imagine a robot calculating moves forward through space: it might have to choose from a number of valid positions, each of which could be modelled by a 2D Normal model.
Can be as simple as putting an MDN layer after recurrent layers!