diktya.distributions
class TruncNormal(a, b, mean, std)
Bases: diktya.distributions.Distribution

Normal distribution truncated to the interval [a, b].
length

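Inverse-CDF sampling is one standard way to draw from a normal distribution truncated to [a, b]. It maps a uniform draw into the interval [Φ(a), Φ(b)] of the untruncated CDF and then inverts the CDF. The sketch below uses only the Python standard library. The helper name trunc_normal_sample and its rng parameter are hypothetical, and this is not necessarily how diktya's TruncNormal samples internally.

```python
import random
from statistics import NormalDist


def trunc_normal_sample(a, b, mean, std, n, rng=None):
    """Draw n samples from N(mean, std**2) truncated to [a, b].

    Hypothetical helper illustrating inverse-CDF sampling.
    """
    if not a < b:
        raise ValueError("expected a < b")
    if rng is None:
        rng = random.Random()
    base = NormalDist(mean, std)
    # Probability mass of the untruncated normal below each bound.
    lo, hi = base.cdf(a), base.cdf(b)
    eps = 1e-12
    samples = []
    for _ in range(n):
        # Map a uniform draw in [0, 1) onto [cdf(a), cdf(b)).
        u = lo + (hi - lo) * rng.random()
        # inv_cdf requires 0 < u < 1, so keep u away from the endpoints.
        u = min(max(u, eps), 1.0 - eps)
        # Invert the CDF, then clamp to guard against rounding just outside [a, b].
        samples.append(min(max(base.inv_cdf(u), a), b))
    return samples


# Example: 5 reproducible draws from N(0, 1) truncated to [-1, 2].
print(trunc_normal_sample(-1.0, 2.0, 0.0, 1.0, 5, random.Random(0)))
```

Because this approach works through the CDF in double precision, it loses accuracy when the whole interval lies far out in one tail of the distribution. Rejection sampling or a log-space CDF is the usual remedy in that case.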
class Uniform(low, high)
Bases: diktya.distributions.Distribution

length

class DistributionCollection(distributions)
Bases: diktya.distributions.Distribution, diktya.distributions.Normalization

A collection of multiple named distributions.

Parameters: distributions – A list of tuples, each in one of these forms:
  * (name, distribution)
  * (name, distribution, nb_elems)
  * (name, distribution, nb_elems, normalization)
distribution must be an instance of a Distribution subclass. nb_elems specifies how many elements are drawn from the distribution; if omitted, it is set to 1. normalization specifies how the samples are normalized; if omitted, it is set to distribution.default_normalization().

Example:
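
    import numpy as np
    from diktya.distributions import (
        DistributionCollection, Normal, Uniform, SinCosAngleNorm)

    dist = DistributionCollection([
        ("x_rotation", Normal(0, 1)),
        ("y_rotation", Uniform(-np.pi, np.pi), 1, SinCosAngleNorm()),
        ("center", Normal(0, 2), 2),
    ])

    # Sample 10 vectors from the collection.
    arr = dist.sample(10)

    # The array is a structured numpy array. Its field names are the
    # names given in the list of tuples passed to the constructor.
    arr["x_rotation"][0]

    # Normalize the samples according to each entry's normalization.
    normed = dist.normalize(arr)

    # Normalizing and then denormalizing should recover the samples.
    # np.allclose does not accept structured arrays, so compare field by field.
    denormed = dist.denormalize(normed)
    for name in arr.dtype.names:
        assert np.allclose(denormed[name], arr[name])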
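The example pairs Uniform(-np.pi, np.pi) with SinCosAngleNorm(). Its name suggests that this normalization encodes an angle as its sine and cosine. That encoding removes the jump between -π and π, and atan2 inverts it exactly. The sketch below illustrates the round trip with the standard math module. The function names are hypothetical, and the sin/cos encoding is an assumption, not diktya's confirmed implementation.

```python
import math


def sincos_normalize(angle):
    """Encode an angle in radians as (sin, cos). Hypothetical stand-in."""
    return math.sin(angle), math.cos(angle)


def sincos_denormalize(sin_val, cos_val):
    """Recover the angle, in (-pi, pi], from its sine and cosine."""
    return math.atan2(sin_val, cos_val)


# Round trip: angles inside (-pi, pi] come back unchanged up to rounding.
for angle in (-3.0, -1.0, 0.0, 0.5, 3.1):
    s, c = sincos_normalize(angle)
    print(angle, sincos_denormalize(s, c))
```

Under this encoding, angles just above -π and just below π map to nearby (sin, cos) points, which is why it suits a full-circle rotation such as y_rotation.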