diff --git a/examples/time_series_generation.ipynb b/examples/time_series_generation.ipynb index 247ff6c..2679732 100644 --- a/examples/time_series_generation.ipynb +++ b/examples/time_series_generation.ipynb @@ -67,7 +67,6 @@ "outputs": [], "source": [ "def plot_data(data, N=500):\n", - "\n", " fig, axs = plt.subplots(2, 2)\n", " time = np.linspace(0, 1, 20)\n", " for i in range(N):\n", @@ -88,7 +87,7 @@ " axs[1, 1].plot(time, data[i, :, 3], linewidth=0.1)\n", " axs[1, 1].set_title(\"Volatility process of asset 2\")\n", " axs[1, 1].set_xlabel(\"Time\")\n", - " \n", + "\n", " plt.tight_layout()" ] }, @@ -139,161 +138,150 @@ "outputs": [], "source": [ "class LearnedSinusoidalPosEmb(eqx.Module):\n", - " \n", " weights: jnp.ndarray\n", - " \n", + "\n", " def __init__(self, dim, *, key):\n", " assert dim % 2 == 0\n", - " half_dim = dim //2\n", + " half_dim = dim // 2\n", " self.weights = jrandom.normal(key=key, shape=(half_dim,))\n", - " \n", + "\n", " def __call__(self, x: Float, *, key=None):\n", " \"\"\"Return [x, sin(w*x), cos(w*x)]\"\"\"\n", " freqs = x * self.weights * 2 * math.pi\n", " fouriered = jnp.concatenate((jnp.sin(freqs), jnp.cos(freqs)), axis=0)\n", " fouriered = jnp.concatenate((x[None], fouriered), axis=0)\n", - " \n", + "\n", " return fouriered\n", - " \n", - " \n", + "\n", + "\n", "class TimeModulatedFeedForward(eqx.Module):\n", " \"\"\"A variant of feed forward neural network where time is introduced\"\"\"\n", - " \n", + "\n", " lin1: list[nn.Linear]\n", " lin2: list[nn.Linear]\n", " to_scale_shift: nn.Sequential\n", - " \n", + "\n", " def __init__(self, dim: int, order: int, d_ff: int, time_dim: int, *, key):\n", - " \n", " self.to_scale_shift = nn.Sequential(\n", - " layers=[\n", - " nn.Lambda(jax.nn.silu),\n", - " nn.Linear(time_dim, d_ff * 2, key=key)\n", - " ]\n", - " )\n", - " \n", + " layers=[nn.Lambda(jax.nn.silu), nn.Linear(time_dim, d_ff * 2, key=key)]\n", + " )\n", + "\n", " self.lin1 = []\n", " self.lin2 = []\n", " for i in range(1, order + 1):\n", - " lin1 = nn.Linear(dim**i, d_ff, key=jrandom.fold_in(key, i*2)) \n", + " lin1 = nn.Linear(dim**i, d_ff, key=jrandom.fold_in(key, i * 2))\n", " lin2 = nn.Linear(d_ff, dim**i, key=jrandom.fold_in(key, i * 2 + 1))\n", " self.lin1.append(lin1)\n", " self.lin2.append(lin2)\n", - " \n", + "\n", " def __call__(self, x: list[Array], t_embed: Float[Array, [\" time_dim\"]]):\n", " scale_shift = self.to_scale_shift(t_embed)\n", " scale, shift = jnp.split(scale_shift, 2, axis=-1)\n", - " \n", + "\n", " shapes = [xx.shape for xx in x]\n", " x = jax.tree_util.tree_map(lambda xx: einops.rearrange(xx, \"... -> (...)\"), x)\n", - " \n", + "\n", " # first layer\n", " x = [f(xx) for f, xx in zip(self.lin1, x)]\n", - " \n", + "\n", " # use activation function\n", " x = jax.tree_util.tree_map(jax.nn.silu, x)\n", - " \n", + "\n", " # time modulation\n", " x = jax.tree_util.tree_map(lambda xx: xx * (scale + 1) + shift, x)\n", - " \n", + "\n", " # second layer\n", " x = [f(xx) for f, xx in zip(self.lin2, x)]\n", " x = jax.tree_util.tree_map(lambda xx, shape: jnp.reshape(xx, shape), x, shapes)\n", - " \n", + "\n", " return x\n", - " \n", - " \n", + "\n", + "\n", "class Block(eqx.Module):\n", - " \n", " attn_block: TensorSelfAttention\n", " attn_norm: TensorLayerNorm\n", - " \n", + "\n", " mlp_block: TimeModulatedFeedForward\n", " mlp_norm: TensorLayerNorm\n", - " \n", - " \n", + "\n", " def __init__(self, dim, order, n_heads, time_dim, *, key):\n", " attn_key, mlp_key = jrandom.split(key)\n", " self.attn_block = TensorSelfAttention(\n", - " order=order, dim=dim,\n", + " order=order,\n", + " dim=dim,\n", " n_heads=n_heads,\n", " key=attn_key,\n", " )\n", " self.attn_norm = TensorLayerNorm(dim, order)\n", - " \n", - " self.mlp_block = TimeModulatedFeedForward(dim, \n", - " order=order, \n", - " d_ff=dim * dim * 4, \n", - " time_dim=time_dim,\n", - " key=mlp_key)\n", + "\n", + " self.mlp_block = TimeModulatedFeedForward(\n", + " dim, order=order, d_ff=dim * dim * 4, time_dim=time_dim, key=mlp_key\n", + " )\n", " self.mlp_norm = TensorLayerNorm(dim, order)\n", - " \n", - " \n", - " def __call__(self, \n", - " x: list[Array],\n", - " time_embed: Float[Array, \" time_dim\"],\n", - " *, \n", - " key=None):\n", - " \n", + "\n", + " def __call__(\n", + " self, x: list[Array], time_embed: Float[Array, \" time_dim\"], *, key=None\n", + " ):\n", " resid = x\n", " x = self.attn_block(self.attn_norm(x), key=key)\n", " x = jax.tree_util.tree_map(lambda x1, x2: x1 + x2, resid, x)\n", - " \n", + "\n", " resid = x\n", " x = jax.vmap(lambda xx: self.mlp_block(xx, time_embed))(self.mlp_norm(x))\n", " x = jax.tree_util.tree_map(lambda x1, x2: x1 + x2, resid, x)\n", - " \n", + "\n", " return x\n", - " \n", - " \n", + "\n", + "\n", "class SigFormer(eqx.Module):\n", - " \n", " blocks: list[Block]\n", - " \n", + "\n", " def __init__(self, dim, order, n_heads, time_dim, n_layers, *, key):\n", " self.blocks = [\n", " Block(dim, order, n_heads, time_dim, key=jrandom.fold_in(key, i))\n", " for i in range(n_layers)\n", - " ] \n", - " \n", - " \n", + " ]\n", + "\n", " def __call__(self, x, time_embed, *, key=None):\n", " for block in self.blocks:\n", " x = block(x, time_embed)\n", " return x\n", - " \n", - " \n", - " \n", "\n", "\n", "class ScoreFunction(eqx.Module):\n", " \"\"\"Score function will be approximated by this class\"\"\"\n", - " \n", + "\n", " conditional: bool = eqx.static_field()\n", " order: int = eqx.static_field()\n", - " \n", + "\n", " project: nn.Linear\n", " conditional_embedding: nn.Sequential\n", " time_mlp: nn.Sequential\n", " transformer: eqx.Module\n", " flatten: TensorFlatten\n", " readout: nn.Linear\n", - " \n", - " \n", - " def __init__(self, \n", - " input_embed_dim: int,\n", - " condition_embed_dim: int,\n", - " input_dim: int = 4,\n", - " depth: int = 4,\n", - " order: int = 2,\n", - " heads: int = 4,\n", - " conditional: bool = True,\n", - " learned_sinusoidal_dim: int = 16,\n", - " *,\n", - " key\n", - " ):\n", - " proj_key, time_mlp_key, readout_key, cond_emb_key, transformer_key = jrandom.split(key, 5)\n", - " \n", + "\n", + " def __init__(\n", + " self,\n", + " input_embed_dim: int,\n", + " condition_embed_dim: int,\n", + " input_dim: int = 4,\n", + " depth: int = 4,\n", + " order: int = 2,\n", + " heads: int = 4,\n", + " conditional: bool = True,\n", + " learned_sinusoidal_dim: int = 16,\n", + " *,\n", + " key,\n", + " ):\n", + " (\n", + " proj_key,\n", + " time_mlp_key,\n", + " readout_key,\n", + " cond_emb_key,\n", + " transformer_key,\n", + " ) = jrandom.split(key, 5)\n", + "\n", " self.conditional = conditional\n", " self.order = order\n", " if conditional:\n", @@ -301,28 +289,36 @@ " else:\n", " embed_dim = input_embed_dim\n", " self.project = nn.Linear(input_dim, input_embed_dim, key=proj_key)\n", - " \n", - " sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim, key=jrandom.fold_in(time_mlp_key, 0))\n", + "\n", + " sinu_pos_emb = LearnedSinusoidalPosEmb(\n", + " learned_sinusoidal_dim, key=jrandom.fold_in(time_mlp_key, 0)\n", + " )\n", " fourier_dim = learned_sinusoidal_dim + 1\n", " time_dim = embed_dim * 4\n", - " \n", + "\n", " self.time_mlp = nn.Sequential(\n", " layers=[\n", " sinu_pos_emb,\n", " nn.Linear(fourier_dim, time_dim, key=jrandom.fold_in(time_mlp_key, 1)),\n", " nn.Lambda(jax.nn.gelu),\n", - " nn.Linear(time_dim, time_dim, key=jrandom.fold_in(time_mlp_key, 2))\n", + " nn.Linear(time_dim, time_dim, key=jrandom.fold_in(time_mlp_key, 2)),\n", " ]\n", " )\n", - " \n", + "\n", " self.conditional_embedding = nn.Sequential(\n", " layers=[\n", - " nn.Linear(input_dim, embed_dim * 2, key=jrandom.fold_in(cond_emb_key, 1)),\n", + " nn.Linear(\n", + " input_dim, embed_dim * 2, key=jrandom.fold_in(cond_emb_key, 1)\n", + " ),\n", " nn.Lambda(jax.nn.relu),\n", - " nn.Linear(embed_dim * 2, condition_embed_dim, key=jrandom.fold_in(cond_emb_key, 2))\n", + " nn.Linear(\n", + " embed_dim * 2,\n", + " condition_embed_dim,\n", + " key=jrandom.fold_in(cond_emb_key, 2),\n", + " ),\n", " ]\n", " )\n", - " \n", + "\n", " self.transformer = SigFormer(\n", " dim=embed_dim,\n", " order=order,\n", @@ -331,35 +327,35 @@ " n_layers=depth,\n", " key=transformer_key,\n", " )\n", - " \n", + "\n", " self.flatten = TensorFlatten()\n", " readout_in_dim = sum(embed_dim ** (i + 1) for i in range(order))\n", " self.readout = nn.Linear(readout_in_dim, input_dim, key=readout_key)\n", - " \n", - " \n", - " \n", - " def __call__(self, \n", - " x: Float[Array, \"seq_len dim\"],\n", - " time: Float[Array, \" dim_t\"], \n", - " condition: Float[Array, \" dim\"]=None,):\n", + "\n", + " def __call__(\n", + " self,\n", + " x: Float[Array, \"seq_len dim\"],\n", + " time: Float[Array, \" dim_t\"],\n", + " condition: Float[Array, \" dim\"] = None,\n", + " ):\n", " x = jax.vmap(self.project)(x)\n", " t = self.time_mlp(time, key=None)\n", - " \n", + "\n", " if condition is not None and self.conditional:\n", " cond = self.conditional_embedding(condition)\n", " cond = cond[None, ...]\n", " cond = jnp.repeat(cond, x.shape[0], axis=0)\n", " x = jnp.concatenate((x, cond), axis=-1)\n", - " \n", + "\n", " # compute signature\n", " x = jnp.pad(x, ((1, 0), (0, 0)), constant_values=0.0)\n", " x = signax.signature(x, depth=self.order, stream=True, flatten=False)\n", - " \n", + "\n", " x = self.transformer(x, t)\n", - " \n", + "\n", " x = self.flatten(x)\n", " x = jax.vmap(self.readout)(x)\n", - " \n", + "\n", " return x" ] }, @@ -383,159 +379,162 @@ " t_max = math.atan(math.exp(-0.5 * logsnr_min))\n", " return -0.2 * jnp.log(jnp.tan(t_min + t * (t_max - t_min)))\n", "\n", + "\n", "def right_pad_dims(x, t):\n", " padded_dim = x.ndim - t.ndim\n", - " if padded_dim <=0:\n", + " if padded_dim <= 0:\n", " return t\n", - " new_shape = t.shape + (1, ) * padded_dim\n", + " new_shape = t.shape + (1,) * padded_dim\n", " return jnp.reshape(t, new_shape)\n", "\n", + "\n", "class GaussianDiffusion(eqx.Module):\n", " num_sample_steps: int = eqx.static_field()\n", " dim: int = eqx.static_field()\n", " min_snr_loss_weight: bool = eqx.static_field()\n", " min_snr_gamma: float = eqx.static_field()\n", - " \n", + "\n", " model: ScoreFunction\n", - " \n", - " def __init__(self,\n", - " model,\n", - " dim: int = 4,\n", - " num_sample_steps=500,\n", - " min_snr_loss_weight=True,\n", - " min_snr_gamma=5.,\n", - " ):\n", + "\n", + " def __init__(\n", + " self,\n", + " model,\n", + " dim: int = 4,\n", + " num_sample_steps=500,\n", + " min_snr_loss_weight=True,\n", + " min_snr_gamma=5.0,\n", + " ):\n", " self.dim = dim\n", " self.num_sample_steps = num_sample_steps\n", " self.min_snr_loss_weight = min_snr_loss_weight\n", " self.min_snr_gamma = min_snr_gamma\n", - " \n", + "\n", " self.model = model\n", - " \n", - " \n", + "\n", " def p_mean_variance(self, x, time, time_next, condition=None):\n", - " \n", " log_snr = logsnr_schedule_cosine(time)\n", " log_snr_next = logsnr_schedule_cosine(time_next)\n", - " c = - jnp.expm1(log_snr - log_snr_next)\n", - " squared_alpha, squared_alpha_next = jax.nn.sigmoid(log_snr), jax.nn.sigmoid(log_snr_next)\n", - " squared_sigma, squared_sigma_next = jax.nn.sigmoid(-log_snr), jax.nn.sigmoid(-log_snr_next)\n", - " \n", - " alpha, sigma, alpha_next = map(jnp.sqrt, (squared_alpha, squared_sigma, squared_alpha_next))\n", - " \n", + " c = -jnp.expm1(log_snr - log_snr_next)\n", + " squared_alpha, squared_alpha_next = jax.nn.sigmoid(log_snr), jax.nn.sigmoid(\n", + " log_snr_next\n", + " )\n", + " squared_sigma, squared_sigma_next = jax.nn.sigmoid(-log_snr), jax.nn.sigmoid(\n", + " -log_snr_next\n", + " )\n", + "\n", + " alpha, sigma, alpha_next = map(\n", + " jnp.sqrt, (squared_alpha, squared_sigma, squared_alpha_next)\n", + " )\n", + "\n", " pred = self.model(x, log_snr, condition)\n", - " \n", + "\n", " x_start = (x - sigma * pred) / alpha\n", - " \n", + "\n", " model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)\n", "\n", " posterior_variance = squared_sigma_next * c\n", "\n", " return model_mean, posterior_variance\n", - " \n", - " \n", + "\n", " def p_sample(self, x, time, time_next, condition=None, *, key):\n", - " \n", " model_mean, model_variance = self.p_mean_variance(\n", " x=x,\n", " time=time,\n", " time_next=time_next,\n", " condition=condition,\n", " )\n", - " \n", + "\n", " noise = jrandom.normal(key=key, shape=x.shape)\n", - " ret = jnp.where(time_next == 0, \n", - " model_mean,\n", - " model_mean + jnp.sqrt(model_variance) * noise\n", - " )\n", - " \n", + " ret = jnp.where(\n", + " time_next == 0, model_mean, model_mean + jnp.sqrt(model_variance) * noise\n", + " )\n", + "\n", " return ret\n", - " \n", - " \n", - " \n", + "\n", " def p_sample_loop(self, shape, condition=None, *, key):\n", " init_key, path_key = jrandom.split(key)\n", " x = jrandom.normal(key=init_key, shape=shape)\n", " steps = jnp.linspace(1.0, 0.0, self.num_sample_steps + 1)\n", - " \n", + "\n", " def scan_fn(carry, ind):\n", " x_, key_ = carry\n", " time = steps[ind]\n", " time_next = steps[ind + 1]\n", " carry = self.p_sample(\n", - " x=x_,\n", - " time=time,\n", - " time_next=time_next,\n", - " condition=condition,\n", - " key=key_\n", + " x=x_, time=time, time_next=time_next, condition=condition, key=key_\n", " )\n", - " \n", + "\n", " key_ = jrandom.split(key_, 1)[0]\n", " return (carry, key_), None\n", - " \n", - " ret, _ = jax.lax.scan(scan_fn,\n", - " (x, path_key),\n", - " jnp.arange(self.num_sample_steps),\n", - " )\n", + "\n", + " ret, _ = jax.lax.scan(\n", + " scan_fn,\n", + " (x, path_key),\n", + " jnp.arange(self.num_sample_steps),\n", + " )\n", " return ret\n", - " \n", - " def sample(self, \n", - " batch_size: int,\n", - " seq_len: int, \n", - " x0: Float[Array, \"batch_size dim\"]=None,\n", - " *, \n", - " key,):\n", + "\n", + " def sample(\n", + " self,\n", + " batch_size: int,\n", + " seq_len: int,\n", + " x0: Float[Array, \"batch_size dim\"] = None,\n", + " *,\n", + " key,\n", + " ):\n", " shape = (seq_len, self.dim)\n", " batch_key = jrandom.split(key, batch_size)\n", - " ret = jax.vmap(lambda x_start, k: self.p_sample_loop(shape=shape, condition=x_start, key=k))(x0, batch_key)\n", + " ret = jax.vmap(\n", + " lambda x_start, k: self.p_sample_loop(shape=shape, condition=x_start, key=k)\n", + " )(x0, batch_key)\n", " return ret\n", - " \n", - " \n", - " \n", + "\n", " def q_sample(self, x_start, times, noise=None, *, key=None):\n", - " \n", " if noise is None:\n", " noise = jrandom.normal(key=key, shape=x_start.shape)\n", - " \n", + "\n", " log_snr = logsnr_schedule_cosine(times)\n", " log_snr_padded = right_pad_dims(x_start, log_snr)\n", " alpha = jnp.sqrt(jax.nn.sigmoid(log_snr_padded))\n", " sigma = jnp.sqrt(jax.nn.sigmoid(-log_snr_padded))\n", - " \n", + "\n", " x_noise = x_start * alpha + noise * sigma\n", - " \n", + "\n", " return x_noise, log_snr\n", - " \n", - " \n", + "\n", " def p_losses(self, x_start, times, condition=None, noise=None, *, key):\n", - " \n", " if noise is None:\n", " noise = jrandom.normal(key=key, shape=x_start.shape)\n", - " \n", + "\n", " x, log_snr = self.q_sample(x_start, times=times, noise=noise)\n", - " \n", + "\n", " model_out = self.model(x, log_snr, condition=condition)\n", - " \n", + "\n", " target = noise\n", - " \n", + "\n", " loss = jnp.mean((model_out - target) ** 2)\n", - " \n", + "\n", " snr = jnp.exp(log_snr)\n", - " \n", + "\n", " maybe_clip_snr = jax.lax.stop_gradient(snr)\n", - " \n", + "\n", " if self.min_snr_loss_weight:\n", " maybe_clip_snr = jnp.clip(maybe_clip_snr, a_max=self.min_snr_gamma)\n", - " \n", + "\n", " loss_weight = maybe_clip_snr / snr\n", - " \n", + "\n", " return loss * loss_weight\n", - " \n", - " def compute_loss(self, x: Float[Array, \"seq_len dim\"], condition: Float[Array, \" dim\"]=None, *, key):\n", + "\n", + " def compute_loss(\n", + " self,\n", + " x: Float[Array, \"seq_len dim\"],\n", + " condition: Float[Array, \" dim\"] = None,\n", + " *,\n", + " key,\n", + " ):\n", " time_key, loss_key = jrandom.split(key)\n", - " times = jrandom.uniform(key=time_key, shape=(1, )).squeeze()\n", - " return self.p_losses(x, times, condition=condition, key=loss_key)\n", - " " + " times = jrandom.uniform(key=time_key, shape=(1,)).squeeze()\n", + " return self.p_losses(x, times, condition=condition, key=loss_key)" ] }, { @@ -553,59 +552,62 @@ "metadata": {}, "outputs": [], "source": [ - "def acf(x, max_lag, axis=(0,1)):\n", - " \n", + "def acf(x, max_lag, axis=(0, 1)):\n", " acfs = []\n", - " x = x - jnp.mean(x, axis=(0,1))\n", + " x = x - jnp.mean(x, axis=(0, 1))\n", " std = jnp.var(x, axis=(0, 1))\n", " for i in range(max_lag):\n", - " y = x[:, i:] * x[:, :-i] if i > 0 else x ** 2\n", + " y = x[:, i:] * x[:, :-i] if i > 0 else x**2\n", " acf_i = jnp.mean(y, axis=axis) / std\n", " acfs.append(acf_i)\n", - " \n", - " if axis == (0,1):\n", + "\n", + " if axis == (0, 1):\n", " return jnp.stack(acfs)\n", - " \n", + "\n", " return jnp.concatenate(acfs, 1)\n", "\n", + "\n", "def acf_nonstationary(x, symmetric=False):\n", - " \n", " b, t, d = x.shape\n", " correlations = jnp.zeros((t, t, d))\n", - " \n", + "\n", " for i in range(t):\n", " for j in range(i, t):\n", - " correlation = jnp.sum(x[:, i, :] * x[:, j, :], axis=0) \n", - " correlation = correlation / (jnp.linalg.norm(x[:, i, :], axis=0) * jnp.linalg.norm(x[:, j, :], axis=0))\n", - " correlations = correlations.at[i,j,:].set(correlation)\n", + " correlation = jnp.sum(x[:, i, :] * x[:, j, :], axis=0)\n", + " correlation = correlation / (\n", + " jnp.linalg.norm(x[:, i, :], axis=0)\n", + " * jnp.linalg.norm(x[:, j, :], axis=0)\n", + " )\n", + " correlations = correlations.at[i, j, :].set(correlation)\n", " if symmetric:\n", - " correlations = correlations.at[j,i,:].set(correlation)\n", - " \n", + " correlations = correlations.at[j, i, :].set(correlation)\n", + "\n", " return correlations\n", "\n", + "\n", "def cacf(x, lags, axis=(0, 1)):\n", - " \n", " def get_lower_triangle_indices(n):\n", " return [list(x) for x in jnp.tril_indices(n, n)]\n", - " \n", + "\n", " ind = get_lower_triangle_indices(x.shape[2])\n", - " \n", - " x = (x - jnp.mean(x, axis=axis, keepdims=True))/ jnp.std(x, axis=axis, keepdims=True)\n", - " \n", + "\n", + " x = (x - jnp.mean(x, axis=axis, keepdims=True)) / jnp.std(\n", + " x, axis=axis, keepdims=True\n", + " )\n", + "\n", " x_l = x[..., ind[0]]\n", " x_r = x[..., ind[1]]\n", - " \n", + "\n", " cacfs = []\n", " for i in range(lags):\n", - " \n", " y = x_l[:, i:] * x_r[:, :-i] if i > 0 else x_l * x_r\n", - " \n", + "\n", " cacf_i = jnp.mean(y, (1))\n", " cacfs.append(cacf_i)\n", - " \n", + "\n", " ret = jnp.concatenate(cacfs, 1)\n", " ret = jnp.reshape(ret, (ret.shape[0], -1, len(ind[0])))\n", - " \n", + "\n", " return ret\n", "\n", "\n", @@ -614,15 +616,21 @@ " x = jnp.reshape(x, (-1, L * C))\n", " return jnp.cov(x, rowvar=False)\n", "\n", + "\n", "def acf_diff(x, y):\n", " diff = acf_nonstationary(x) - acf_nonstationary(y)\n", " # diff = acf(x, max_lag=x.shape[1]) - acf(y, max_lag=y.shape[1])\n", " return jnp.sqrt(jnp.sum(diff**2, 0))\n", "\n", + "\n", "def cc_diff(x, y):\n", - " diff = jnp.mean(cacf(x, lags=x.shape[1]), 0)[0] - jnp.mean(cacf(y, lags=y.shape[1]), 0)[0]\n", + " diff = (\n", + " jnp.mean(cacf(x, lags=x.shape[1]), 0)[0]\n", + " - jnp.mean(cacf(y, lags=y.shape[1]), 0)[0]\n", + " )\n", " return jnp.sum(jnp.abs(diff), 0)\n", "\n", + "\n", "def cov_diff(x, y):\n", " diff = cov(x) - cov(y)\n", " return jnp.mean(jnp.abs(diff))" @@ -660,7 +668,7 @@ "outputs": [], "source": [ "key = jrandom.PRNGKey(0)\n", - "model_key, train_key, data_key = jrandom.split(key,3)" + "model_key, train_key, data_key = jrandom.split(key, 3)" ] }, { @@ -676,7 +684,7 @@ " depth=4,\n", " order=2,\n", " heads=4,\n", - " key=jrandom.PRNGKey(0)\n", + " key=jrandom.PRNGKey(0),\n", ")\n", "model = GaussianDiffusion(score_function, dim=4)\n", "\n", @@ -706,7 +714,9 @@ "source": [ "@eqx.filter_jit\n", "def make_step(model, x, x0, key, opt_state, opt_update):\n", - " loss, grads = eqx.filter_value_and_grad(lambda m: compute_loss(m, x, x0, key))(model)\n", + " loss, grads = eqx.filter_value_and_grad(lambda m: compute_loss(m, x, x0, key))(\n", + " model\n", + " )\n", " updates, opt_state = opt_update(grads, opt_state)\n", " model = eqx.apply_updates(model, updates)\n", " key = jrandom.split(key, 1)[0]\n", @@ -780,13 +790,14 @@ "batch_size = 200\n", "\n", "for iter in range(200):\n", - " \n", - " indices = jrandom.permutation(key=jrandom.fold_in(data_key, iter), x=jnp.arange(data.shape[0]))\n", + " indices = jrandom.permutation(\n", + " key=jrandom.fold_in(data_key, iter), x=jnp.arange(data.shape[0])\n", + " )\n", " current = 0\n", - " \n", - " total_loss = 0.\n", + "\n", + " total_loss = 0.0\n", " while current < data.shape[0] - batch_size:\n", - " batch_indices = indices[current: current + batch_size]\n", + " batch_indices = indices[current : current + batch_size]\n", " x0_batch = x0[batch_indices]\n", " x_batch = normalized_dx[batch_indices]\n", " model, loss, opt_state, train_key = make_step(\n", @@ -799,24 +810,27 @@ " )\n", " total_loss += loss.item()\n", " current += batch_size\n", - " \n", + "\n", " # evaluation\n", " if (iter + 1) % 10 == 0:\n", " print(f\"Iter {iter} \\t Loss: {total_loss:.4f}\")\n", " # this is how we get x0. Basically, this is from train data. There should be better way to do it\n", " # in fact, my submission is screwed because of not generating x0 correctly\n", " x0_fake = x0[:2000]\n", - " x_fake = sample(model, batch_size=2000, seq_len=19, x0=x0_fake, key=jrandom.PRNGKey(0))\n", + " x_fake = sample(\n", + " model, batch_size=2000, seq_len=19, x0=x0_fake, key=jrandom.PRNGKey(0)\n", + " )\n", " # reverse transform: rescale + cumsum + exp\n", " x_fake = x_fake[0] * std + mean\n", " x_fake = jnp.concatenate([jnp.log(x0_fake[:, None, :]), x_fake], axis=1)\n", " x_fake = jnp.exp(jnp.cumsum(x_fake, 1))\n", - " \n", + "\n", " acf_value = jnp.mean(acf_diff(x_fake, val_data))\n", " cacf_value = jnp.mean(cc_diff(x_fake, val_data))\n", " cov_value = jnp.mean(cov_diff(x_fake, val_data))\n", - " print(f\"ACF: {acf_value.item():.5f} \\t Correlation: {cacf_value.item():.5f} \\t Covariance: {cov_value.item():.5f}\")\n", - " " + " print(\n", + " f\"ACF: {acf_value.item():.5f} \\t Correlation: {cacf_value.item():.5f} \\t Covariance: {cov_value.item():.5f}\"\n", + " )" ] }, {