IsoDD: the sigma_line noise term is a vector instead of a scalar

Created by: Lookatator

Hello,

we may have found a small bug in the implementation of IsoDD.

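It comes from a subtle point in the formula: the Gaussian sample that multiplies sigma_line should be a single scalar per offspring, not a vector with one entry per array element. The intended operator is $x' = x_1 + \sigma_{iso}\,\mathcal{N}(0, I) + \sigma_{line}\,\mathcal{N}(0, 1)\,(x_2 - x_1)$, so that the line noise moves each offspring along the line through its two parents.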

On the figure below, both sets of points have

  • sigma_iso=0.01
  • sigma_line=0.1

And:

  • the orange one is what it is meant to look like
  • and the blue one is what (we think) our implementation currently does.

Iso-DD-bug-illustration

Here is the code to reproduce this figure:

Plotting Script


from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from chex import ArrayTree
from matplotlib import pyplot as plt
from typing_extensions import TypeAlias

# A genotype is a pytree of arrays; in this script each leaf holds a batch of
# individuals along its leading axis.
Genotype: TypeAlias = ArrayTree
# `jax.random.KeyArray` was deprecated and later removed from JAX, which made
# this line fail at import time; `jax.Array` is the type of PRNG keys in
# current JAX releases.
RNGKey: TypeAlias = jax.Array


def isoline_variation(
    x1: Genotype,
    x2: Genotype,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
    """Iso+Line variation as *currently* implemented (reproduces the bug).

    This version is kept deliberately faulty so the script can plot it next
    to ``isoline_variation_fixed``. The line noise is drawn with the same
    shape as each leaf, i.e. one independent N(0, 1) sample per genotype
    component, instead of one scalar per individual. As a result, offspring
    spread in every direction rather than along the line through x1 and x2.

    Args:
        x1: first parents, a pytree of arrays.
        x2: second parents, same pytree structure and leaf shapes as x1.
        random_key: JAX PRNG key.
        iso_sigma: scale of the isotropic Gaussian noise.
        line_sigma: scale of the noise applied along (x2 - x1).
        minval: optional lower clipping bound.
        maxval: optional upper clipping bound.

    Returns:
        The offspring pytree and a new random key.
    """

    def _variation_fn(
        x1: jnp.ndarray,
        x2: jnp.ndarray,
        random_key: RNGKey,
    ) -> jnp.ndarray:
        subkey1, subkey2 = jax.random.split(random_key)
        iso_noise = jax.random.normal(subkey1, shape=x1.shape) * iso_sigma
        # BUG (kept on purpose for the comparison plot): one line-noise sample
        # per element instead of one scalar per individual.
        line_noise = jax.random.normal(subkey2, shape=x2.shape) * line_sigma
        x = (x1 + iso_noise) + line_noise * (x2 - x1)

        # Back in bounds if necessary (floating point issues)
        if (minval is not None) or (maxval is not None):
            x = jnp.clip(x, minval, maxval)
        return x

    # One independent key per leaf of the genotype pytree. The `jax.tree_util`
    # functions replace the removed top-level `jax.tree_*` aliases.
    treedef = jax.tree_util.tree_structure(x1)
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, num=treedef.num_leaves)
    keys_tree = jax.tree_util.tree_unflatten(treedef, subkeys)

    # Apply the variation to each leaf of the pytree.
    x = jax.tree_util.tree_map(_variation_fn, x1, x2, keys_tree)

    return x, random_key


def isoline_variation_fixed(
    x1: Genotype,
    x2: Genotype,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
    """Iso+Line variation with one line-noise scalar per individual.

    For every individual i of the batch:

        x'_i = x1_i + iso_sigma * N(0, I) + line_sigma * n_i * (x2_i - x1_i)

    where n_i ~ N(0, 1) is a single scalar shared by all leaves of that
    individual's genotype. The leading axis of every leaf is treated as the
    batch axis, and all leaves must share the same batch size.

    Args:
        x1: first parents, a pytree of arrays batched on axis 0.
        x2: second parents, same pytree structure and leaf shapes as x1.
        random_key: JAX PRNG key.
        iso_sigma: scale of the isotropic Gaussian noise.
        line_sigma: scale of the per-individual noise along (x2 - x1).
        minval: optional lower clipping bound.
        maxval: optional upper clipping bound.

    Returns:
        The offspring pytree and a new random key. An empty pytree is
        returned unchanged together with the unused input key.
    """
    leaves = jax.tree_util.tree_leaves(x1)
    if not leaves:
        # Nothing to vary. Without this guard, `leaves[0]` below would raise
        # an opaque IndexError.
        return x1, random_key

    # One line-noise scalar per individual, shared by every leaf.
    random_key, key_line_noise = jax.random.split(random_key)
    batch_size = leaves[0].shape[0]
    line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma

    def _variation_fn(
        x1: jnp.ndarray,
        x2: jnp.ndarray,
        key_iso_noise: RNGKey,
    ) -> jnp.ndarray:
        iso_noise = jax.random.normal(key_iso_noise, shape=x1.shape) * iso_sigma
        # vmap over the leading (batch) axis scales row i of (x2 - x1) by the
        # scalar line_noise[i]; vmap requires matching batch sizes.
        x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise)

        # Back in bounds if necessary (floating point issues)
        if (minval is not None) or (maxval is not None):
            x = jnp.clip(x, minval, maxval)
        return x

    # One independent iso-noise key per leaf of the genotype pytree. The
    # `jax.tree_util` functions replace the removed top-level `jax.tree_*`
    # aliases.
    treedef = jax.tree_util.tree_structure(x1)
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, num=treedef.num_leaves)
    keys_tree = jax.tree_util.tree_unflatten(treedef, subkeys)

    # Apply the variation to each leaf of the pytree.
    x = jax.tree_util.tree_map(_variation_fn, x1, x2, keys_tree)

    return x, random_key


def main():
    """Scatter offspring of the buggy and the fixed Iso+Line operators.

    Both operators get the same parents, key and sigmas, so the two point
    clouds differ only in how the line noise is sampled. Only the 2-D
    "Params_0" leaf is plotted, together with the two parents.
    """
    batch = 1000
    genotype_1 = {
        "Params_0": jnp.zeros(shape=(batch, 2)),
        "Params_1": jnp.zeros(shape=(batch, 5, 2)),
    }
    genotype_2 = {
        "Params_0": jnp.zeros(shape=(batch, 2)) + 2,
        "Params_1": jnp.zeros(shape=(batch, 5, 2)),
    }

    # The second half of this split is not used.
    random_key, _ = jax.random.split(jax.random.PRNGKey(0))

    iso_sigma, line_sigma = 0.01, 0.1

    variants = (
        ("iso-dd bug", isoline_variation),
        ("iso-dd", isoline_variation_fixed),
    )
    offspring = {}
    for label, operator in variants:
        new_genotype, _ = jax.jit(operator)(
            genotype_1, genotype_2, random_key, iso_sigma, line_sigma
        )
        offspring[label] = new_genotype["Params_0"]

    _, ax = plt.subplots()

    for label, points in offspring.items():
        ax.scatter(points[:, 0], points[:, 1], label=label)

    parents = {"parent_1": (0, 0), "parent_2": (2, 2)}
    for label, (px, py) in parents.items():
        ax.scatter(px, py, label=label)

    ax.legend()
    plt.show()


# Build and show the comparison plot only when run as a script.
if __name__ == "__main__":
    main()
To upload designs, you'll need to enable LFS and have an admin enable hashed storage. More information