maml_step
- ivy.maml_step(batch, inner_cost_fn, outer_cost_fn, variables, inner_grad_steps, inner_learning_rate, /, *, inner_optimization_step=ivy.gradient_descent_update, inner_batch_fn=None, outer_batch_fn=None, average_across_steps=False, batched=True, inner_v=None, keep_inner_v=True, outer_v=None, keep_outer_v=True, return_inner_v=False, num_tasks=None, stop_gradients=True)
Perform a step of vanilla second-order MAML (Model-Agnostic Meta-Learning), differentiating the outer-loop gradient through the inner-loop updates.
- Parameters:
  - batch (Container) – The input batch.
  - inner_cost_fn (Callable) – Callable for the inner-loop cost function, receiving the sub-batch, inner vars and outer vars.
  - outer_cost_fn (Callable) – Callable for the outer-loop cost function, receiving the task-specific sub-batch, inner vars and outer vars. If None, the cost from the inner loop will also be optimized in the outer loop.
  - variables (Container) – Variables to be optimized during the meta step.
  - inner_grad_steps (int) – Number of gradient steps to perform during the inner loop.
  - inner_learning_rate (float) – The learning rate of the inner loop.
  - inner_optimization_step (Callable, default: ivy.gradient_descent_update) – The function used for the inner-loop optimization.
  - inner_batch_fn (Optional[Callable], default: None) – Function to apply to the task sub-batch before passing it to the inner_cost_fn.
  - outer_batch_fn (Optional[Callable], default: None) – Function to apply to the task sub-batch before passing it to the outer_cost_fn.
  - average_across_steps (bool, default: False) – Whether to average the inner-loop steps for the outer-loop update.
  - batched (bool, default: True) – Whether to batch along the time dimension and run the meta steps in batch.
  - inner_v (Optional[Container], default: None) – Nested variable keys to be optimized during the inner loop, with the same keys and boolean values.
  - keep_inner_v (bool, default: True) – If True, the key chains in inner_v will be kept; otherwise they will be removed.
  - outer_v (Optional[Container], default: None) – Nested variable keys to be optimized during the outer loop, with the same keys and boolean values.
  - keep_outer_v (bool, default: True) – If True, the key chains in outer_v will be kept; otherwise they will be removed.
  - return_inner_v (Union[str, bool], default: False) – Either 'first', 'all', or False. With 'first', the variables for the first task's inner loop are also returned; with 'all', the variables for all tasks are returned.
  - num_tasks (Optional[int], default: None) – Number of unique tasks to inner-loop optimize for the meta step. Determined from the batch by default.
  - stop_gradients (bool, default: True) – Whether to stop the gradients of the cost.
- Returns:
ret – The cost and the gradients with respect to the outer loop variables.
Examples

With ivy.Container input:

>>> import ivy
>>> from ivy.functional.ivy.gradients import _variable
>>> ivy.set_backend("torch")
>>> def inner_cost_fn(sub_batch, v):
...     return sub_batch.mean().x / v.mean().latent
>>> def outer_cost_fn(sub_batch, v):
...     return sub_batch.mean().x / v.mean().latent
>>> num_tasks = 2
>>> batch = ivy.Container({"x": ivy.arange(1, num_tasks + 1, dtype="float32")})
>>> variables = ivy.Container({
...     "latent": _variable(ivy.repeat(ivy.array([[1.0]]), num_tasks, axis=0))
... })
>>> cost = ivy.maml_step(batch, inner_cost_fn, outer_cost_fn, variables, 5, 0.01)
>>> print(cost)
(ivy.array(1.40069818), {
    latent: ivy.array([-1.13723135])
}, ())
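To make the "second-order" part concrete, here is a rough, Ivy-independent sketch of what a single MAML meta step computes. It uses a hypothetical scalar quadratic per-task loss L_t(θ) = (θ − t)², with one inner gradient step whose derivative can be written out by hand, so no autodiff framework is needed. The function `maml_step_toy` and its tasks are illustrative only, not part of the Ivy API:

```python
# Toy second-order MAML step for per-task loss L_t(theta) = (theta - t)^2.
# Inner step: theta' = theta - alpha * 2 * (theta - t), so
# d(theta')/d(theta) = 1 - 2 * alpha. The second-order outer gradient
# chains through this factor; first-order MAML simply drops it.

def maml_step_toy(theta, tasks, alpha, second_order=True):
    """One meta step averaged over tasks. Returns (outer_cost, outer_grad)."""
    costs, grads = [], []
    for t in tasks:
        inner_grad = 2.0 * (theta - t)            # dL_t/dtheta at theta
        theta_adapted = theta - alpha * inner_grad  # one inner gradient step
        costs.append((theta_adapted - t) ** 2)      # outer cost after adaptation
        g = 2.0 * (theta_adapted - t)               # dL_t/dtheta' at theta'
        if second_order:
            g *= 1.0 - 2.0 * alpha                  # chain rule through inner step
        grads.append(g)
    return sum(costs) / len(costs), sum(grads) / len(grads)

cost, grad = maml_step_toy(theta=0.0, tasks=[1.0, 2.0], alpha=0.01)
# cost ≈ 2.401, grad ≈ -2.8812
```

Like `ivy.maml_step`, the sketch returns the averaged post-adaptation cost and its gradient with respect to the meta-variables; toggling `second_order=False` shows how the first-order approximation changes only the gradient, not the cost.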