diktya.gan

class GAN(generator: keras.engine.training.Model, discriminator: keras.engine.training.Model)[source]

Bases: diktya.models.AbstractModel

Generative Adversarial Networks (GAN) are an unsupervised learning framework. It consists of a generator and a discriminator network. The generator receives a noise vector as input and produces some fake data. The discriminator is trained to distinguish between fake data from the generator and real data. The generator is optimized to fool the discriminator. Please refer to Goodfellow et al. for a detailed introduction to GANs.

Parameters:
  • generator (Model) – model of the generator. Must have exactly one output, and one input must be named z.
  • discriminator (Model) – model of the discriminator. Must have exactly one input named data. For every sample, the output must be a scalar between 0 and 1.
# Keras 1 imports; the import path of diktya's `sequential` helper is assumed.
from keras.layers import Input, Dense, Reshape, Flatten, Convolution2D
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.optimizers import Adam
from diktya.func_api_helpers import sequential

z = Input(shape=(20,), name='z')
data = Input(shape=(1, 32, 32), name='real')

n = 64
fake = sequential([
    Dense(2*16*n, activation='relu'),
    Reshape((2*n, 4, 4)),
])(z)

realness = sequential([
    Convolution2D(n, 3, 3, border_mode='same'),
    LeakyReLU(0.3),
    Flatten(),
    # sigmoid keeps the output a scalar between 0 and 1
    Dense(1, activation='sigmoid'),
])(data)

generator = Model(z, fake)
generator.compile(Adam(lr=0.0002, beta_1=0.5), 'binary_crossentropy')

discriminator = Model(data, realness)
discriminator.compile(Adam(lr=0.0002, beta_1=0.5), 'binary_crossentropy')
gan = GAN(generator, discriminator)

gan.fit_generator(...)
input_names
uses_learning_phase
train_on_batch(inputs)[source]

Runs a single weight update on a single batch of data. Updates both generator and discriminator.

Parameters:
  • inputs (optional) –

    Inputs for both the discriminator and the generator. It can either be a numpy array, a list, or a dict.

    • numpy array: real
    • list: [real], [real, z]
    • dict: {'real': real}, {'real': real, 'z': z},
      {'real': real, 'z': z, 'additional_input': x}
  • generator_inputs (optional dict) – These inputs are passed only to the generator.
  • discriminator_inputs (optional dict) – These inputs are passed only to the discriminator.
Returns:

A list of metrics. You can get the names of the metrics with metrics_names().
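
The three accepted input forms can be illustrated with plain numpy arrays. This is a hedged sketch: the shapes match the example networks above (1x32x32 real images, 20-dim z), and the commented `gan.train_on_batch` call assumes the `gan` object from that example.

```python
import numpy as np

# Hypothetical batch of real samples and noise vectors,
# matching the example networks above.
real = np.random.uniform(-1, 1, (64, 1, 32, 32)).astype('float32')
z = np.random.uniform(-1, 1, (64, 20)).astype('float32')

# The three equivalent input forms accepted by train_on_batch:
as_array = real                       # z is then sampled internally
as_list = [real, z]
as_dict = {'real': real, 'z': z}

# losses = gan.train_on_batch(as_dict)
```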

fit_generator(generator, nb_batches_per_epoch, nb_epoch, batch_size=128, verbose=1, train_on_batch='train_on_batch', callbacks=[])[source]

Fits the generator and discriminator on data yielded batch by batch by a Python generator. Unlike in keras, the generator is not run in parallel.

Parameters:
  • generator – a Python generator; every batch it yields must be a valid input to the train_on_batch method.
  • nb_batches_per_epoch (int) – run that many batches per epoch
  • nb_epoch (int) – run that many epochs
  • batch_size (int) – size of one batch
  • verbose – verbosity mode
  • callbacks – list of callbacks.
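
A minimal sketch of such a data generator, yielding batches in the dict form accepted by train_on_batch. The random placeholder data and the commented call (which assumes a `gan` object as in the example above) are illustrative only.

```python
import numpy as np

def data_generator(batch_size=128):
    # Yields batches in the dict form accepted by train_on_batch.
    while True:
        # Placeholder data; substitute your real samples here.
        yield {'real': np.random.uniform(-1, 1, (batch_size, 1, 32, 32))}

# Hypothetical usage:
# gan.fit_generator(data_generator(), nb_batches_per_epoch=100, nb_epoch=10)
```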
generate(inputs=None, nb_samples=None)[source]

Use the generator to generate data.

Parameters:
  • inputs – Dictionary of name to input arrays to the generator. Can include the random noise z or some conditional variables.
  • nb_samples – Specifies how many samples will be generated, if z is not in the inputs dictionary.
Returns:

A numpy array with the generated data.

random_z(batch_size=32)[source]

Samples z from a uniform distribution between -1 and 1. The returned array has shape (batch_size,) + self.z_shape[1:].
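
In plain numpy, the sampling behaviour can be sketched as follows; the default z_shape of (None, 20) is an assumption taken from the example above.

```python
import numpy as np

def random_z(z_shape=(None, 20), batch_size=32):
    # Uniform samples in [-1, 1] with shape (batch_size,) + z_shape[1:].
    return np.random.uniform(-1, 1, (batch_size,) + z_shape[1:])

z = random_z(batch_size=4)
```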

random_z_point()[source]

Returns one random point in the z space.

interpolate(x, y, nb_steps=100)[source]

Interpolates linearly between two points in the z-space.

Parameters:
  • x – point in the z-space
  • y – point in the z-space
  • nb_steps – interpolate that many points
Returns:

The generated data from the interpolated points. The data corresponding to x and y are on the first and last position of the returned array.
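
The interpolation itself, before the points are fed to the generator, amounts to the following numpy sketch; the function name is hypothetical.

```python
import numpy as np

def interpolate_z(x, y, nb_steps=100):
    # Linear interpolation in z-space; the endpoints x and y are
    # the first and last rows of the result.
    alphas = np.linspace(0, 1, nb_steps).reshape(-1, 1)
    return (1 - alphas) * x + alphas * y

points = interpolate_z(np.zeros(20), np.ones(20), nb_steps=5)
```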

neighborhood(z_point=None, std=0.25, n=128)[source]

Samples the neighborhood of z_point by adding Gaussian noise to it. You can control the standard deviation of the noise with std.
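
A minimal numpy sketch of this sampling, assuming a 20-dim z as in the example above; the function name is hypothetical.

```python
import numpy as np

def neighborhood_z(z_point, std=0.25, n=128):
    # n copies of z_point, each perturbed by Gaussian noise
    # with standard deviation `std`.
    return z_point + np.random.normal(0, std, (n,) + z_point.shape)

zs = neighborhood_z(np.zeros(20))
```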