IsoDD: the noise scaled by sigma_line is sampled as a vector instead of a scalar
Created by: Lookatator
Hello,
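we may have found a small bug in the implementation of IsoDD.
It comes from a subtle point in the formula: the Gaussian noise scaled by sigma_line should be a single scalar per individual, not a vector with the same size as the arrays. In other words, each offspring should be $x' = x_1 + \sigma_{iso}\,\mathcal{N}(0, I) + \sigma_{line}\,\mathcal{N}(0, 1)\,(x_2 - x_1)$, where the $\mathcal{N}(0, 1)$ term is one scalar shared by all coordinates of that individual.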
In the figure below, both sets of points use
- sigma_iso=0.01
- sigma_line=0.1
And:
- the orange one is how it is meant to look
- and the blue one is what (we think) our implementation currently does.
Here is the code to reproduce this figure:
Plotting Script
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from chex import ArrayTree
from matplotlib import pyplot as plt
from typing_extensions import TypeAlias
# A genotype is a pytree of arrays (chex ``ArrayTree``).
Genotype: TypeAlias = ArrayTree
# ``jax.random.KeyArray`` was deprecated and later removed from JAX, so using it
# makes this module fail at import time on current releases; ``jax.Array`` is
# the annotation JAX recommends for PRNG keys.
RNGKey: TypeAlias = jax.Array
def isoline_variation(
    x1: Genotype,
    x2: Genotype,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
    """Apply the current (buggy) Iso+LineDD variation to two genotype trees.

    This operator is kept on purpose so the plot can contrast it with
    ``isoline_variation_fixed``. The line noise is sampled independently for
    every array element (shape ``x2.shape``) instead of once per individual.
    As a result, each coordinate of ``x2 - x1`` is scaled by a different
    random factor.

    Args:
        x1: First parent genotype (a pytree of arrays).
        x2: Second parent genotype, with the same tree structure and leaf
            shapes as ``x1``.
        random_key: JAX PRNG key.
        iso_sigma: Standard deviation of the isotropic Gaussian noise.
        line_sigma: Standard deviation of the noise applied along ``x2 - x1``.
        minval: Optional lower bound passed to ``jnp.clip``.
        maxval: Optional upper bound passed to ``jnp.clip``.

    Returns:
        A tuple ``(offspring, new_random_key)``.
    """

    def _variation_fn(
        x1: jnp.ndarray,
        x2: jnp.ndarray,
        random_key: RNGKey,
    ) -> jnp.ndarray:
        subkey1, subkey2 = jax.random.split(random_key)
        iso_noise = jax.random.normal(subkey1, shape=x1.shape) * iso_sigma
        # BUG (reproduced intentionally for the comparison plot): one line-noise
        # sample per element instead of one scalar per individual.
        line_noise = jax.random.normal(subkey2, shape=x2.shape) * line_sigma
        x = (x1 + iso_noise) + line_noise * (x2 - x1)
        # Back in bounds if necessary (floating point issues)
        if (minval is not None) or (maxval is not None):
            x = jnp.clip(x, minval, maxval)
        return x

    # Create a tree with one independent random key per leaf. The
    # ``jax.tree_util`` functions replace the removed top-level ``jax.tree_*``
    # aliases.
    leaves, treedef = jax.tree_util.tree_flatten(x1)
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, num=len(leaves))
    keys_tree = jax.tree_util.tree_unflatten(treedef, subkeys)

    # Apply the variation to each branch of the tree.
    x = jax.tree_util.tree_map(_variation_fn, x1, x2, keys_tree)
    return x, random_key
def isoline_variation_fixed(
    x1: Genotype,
    x2: Genotype,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
    """Apply Iso+LineDD variation with one scalar line noise per individual.

    Each offspring is ``x1 + iso_sigma * N(0, I) + line_sigma * N(0, 1) * (x2 - x1)``.

    - The isotropic noise is drawn per element.
    - The ``N(0, 1)`` line-noise sample is drawn once per individual, i.e.
      once per index along the leading batch axis.
    - That same line-noise sample is shared by every leaf of the genotype
      tree.

    Every leaf is assumed to have the same leading (batch) dimension. The batch
    size is read from the first leaf, and ``jax.vmap`` fails if another leaf
    disagrees.

    Args:
        x1: First parent genotype (a pytree of batched arrays).
        x2: Second parent genotype, with the same tree structure and leaf
            shapes as ``x1``.
        random_key: JAX PRNG key.
        iso_sigma: Standard deviation of the isotropic Gaussian noise.
        line_sigma: Standard deviation of the scalar noise along ``x2 - x1``.
        minval: Optional lower bound passed to ``jnp.clip``.
        maxval: Optional upper bound passed to ``jnp.clip``.

    Returns:
        A tuple ``(offspring, new_random_key)``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(x1)

    # Compute the line noise: one scalar per individual, shared by all leaves.
    random_key, key_line_noise = jax.random.split(random_key)
    batch_size = leaves[0].shape[0]
    line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma

    def _variation_fn(
        x1: jnp.ndarray,
        x2: jnp.ndarray,
        key_iso_noise: RNGKey,
    ) -> jnp.ndarray:
        iso_noise = jax.random.normal(key_iso_noise, shape=x1.shape) * iso_sigma
        # vmap over the batch axis: individual i's (x2 - x1) row is scaled by
        # its own scalar line_noise[i].
        x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise)
        # Back in bounds if necessary (floating point issues)
        if (minval is not None) or (maxval is not None):
            x = jnp.clip(x, minval, maxval)
        return x

    # Create a tree with one independent iso-noise key per leaf.
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, num=len(leaves))
    keys_tree = jax.tree_util.tree_unflatten(treedef, subkeys)

    # Apply the variation to each branch of the tree.
    x = jax.tree_util.tree_map(_variation_fn, x1, x2, keys_tree)
    return x, random_key
def main():
    """Sample offspring with both operators and scatter-plot them against the parents.

    Both parents differ only in ``"Params_0"`` (all zeros vs. all twos). The
    first two coordinates of that leaf are plotted for each operator, with the
    parents overlaid as single points.
    """
    n_individuals = 1000
    parent_a = {
        "Params_0": jnp.zeros(shape=(n_individuals, 2)),
        "Params_1": jnp.zeros(shape=(n_individuals, 5, 2)),
    }
    parent_b = {
        "Params_0": jnp.zeros(shape=(n_individuals, 2)) + 2,
        "Params_1": jnp.zeros(shape=(n_individuals, 5, 2)),
    }

    key = jax.random.PRNGKey(0)
    key, _ = jax.random.split(key)  # second half of the split is unused

    iso_sigma, line_sigma = 0.01, 0.1
    operators_and_labels = (
        (isoline_variation, "iso-dd bug"),
        (isoline_variation_fixed, "iso-dd"),
    )

    # Both operators are run with the same key before any plotting happens.
    results = [
        (jax.jit(operator)(parent_a, parent_b, key, iso_sigma, line_sigma), label)
        for operator, label in operators_and_labels
    ]

    _, ax = plt.subplots()
    for (offspring, _), label in results:
        points = offspring["Params_0"]
        ax.scatter(points[:, 0], points[:, 1], label=label)
    ax.scatter(0, 0, label="parent_1")
    ax.scatter(2, 2, label="parent_2")
    ax.legend()
    plt.show()


if __name__ == '__main__':
    main()