Zero-Shot Reinforcement Learning via Function Encoders

The University of Texas at Austin
ICML 2024
The function encoder workflow. First, the space of perturbing functions, e.g., the space of reward functions or the space of adversary policies, is encoded into representations via the learned basis functions. Then, these representations are passed into the RL algorithm.

Abstract

Although reinforcement learning (RL) can solve many challenging sequential decision-making problems, achieving zero-shot transfer across related tasks remains a challenge. The difficulty lies in finding a good representation for the current task so that the agent understands how it relates to previously seen tasks. To achieve zero-shot transfer, we introduce the function encoder, a representation learning algorithm which represents a function as a weighted combination of learned, non-linear basis functions. By using a function encoder to represent the reward function or the transition function, the agent has information on how the current task relates to previously seen tasks via a coherent vector representation. Thus, the agent is able to achieve transfer between related tasks at run time with no additional training. We demonstrate state-of-the-art data efficiency, asymptotic performance, and training stability in three RL fields by augmenting basic RL algorithms with a function encoder task representation.

The Function Encoder

The goal of this paper is to represent functions from an arbitrary function space, so that this representation can be used for downstream tasks such as reinforcement learning. Naturally, functions are well-represented by their coefficients with respect to a given basis. However, many practical function spaces are high-dimensional, and so are not amenable to classic basis functions such as Fourier series. Therefore, we aim to find basis functions for arbitrary function spaces from data. We introduce the function encoder, an algorithm which learns basis functions from data using a neural network.
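To make this concrete, below is a minimal sketch of one way to train a function encoder, assuming PyTorch and the scalar-valued setting of the quadratic example shown in the video below. The names (BasisNet, encode), network sizes, and training details are our illustrative choices rather than the paper's exact implementation: a single network outputs the values of k basis functions, each sampled function's coefficients are estimated from its example data via a Monte Carlo inner product, and the basis is trained to minimize reconstruction error.

import torch
import torch.nn as nn

class BasisNet(nn.Module):
    # Maps an input x to the values of k learned basis functions at x.
    def __init__(self, in_dim, k, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, k),
        )

    def forward(self, x):                          # (N, in_dim) -> (N, k)
        return self.net(x)

def encode(basis, xs, ys):
    # Estimate coefficients c_j ~ <f, g_j> from example pairs (x_i, y_i = f(x_i)).
    G = basis(xs)                                  # (N, k) basis values at the example inputs
    return (ys * G).mean(dim=0)                    # (k,) Monte Carlo inner-product estimate

basis = BasisNet(in_dim=1, k=10)
opt = torch.optim.Adam(basis.parameters(), lr=1e-3)
for step in range(5000):
    a, b, c = torch.rand(3) * 2 - 1                # sample a random quadratic f(x) = a x^2 + b x + c
    xs = torch.rand(200, 1) * 2 - 1
    ys = a * xs ** 2 + b * xs + c
    coeffs = encode(basis, xs, ys)                 # this function's vector representation
    y_hat = basis(xs) @ coeffs                     # reconstruction f_hat(x) = sum_j c_j g_j(x)
    loss = ((y_hat.unsqueeze(-1) - ys) ** 2).mean()
    opt.zero_grad(); loss.backward(); opt.step()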

A video of learned basis functions converging to span the space of quadratic functions.

This video demonstrates the function encoder algorithm applied to the space of quadratic functions. Initially, the basis functions are random, and the randomly sampled quadratic functions (bottom) are poorly approximated. However, as training progresses, the basis functions converge and this approximation becomes more and more accurate. Furthermore, we can use the same basis functions to extrapolate to out-of-distribution functions not seen during training. Due to the nature of basis functions, these out-of-distribution quadratics are still well represented as they lie within the span of the basis functions. By construction, a function's representation, i.e. its coefficients with respect to the basis functions, is fully informative and linear. This property makes function encoder representations great for downstream tasks such as reinforcement learning.
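As a usage sketch, reusing the BasisNet and encode names from the snippet above, encoding a new and possibly out-of-distribution quadratic requires only a handful of samples and no gradient steps; the particular coefficients below are arbitrary illustrations.

with torch.no_grad():
    xs_new = torch.rand(50, 1) * 2 - 1
    ys_new = 5.0 * xs_new ** 2 - 3.0 * xs_new + 2.0   # coefficients outside the training range
    c_new = encode(basis, xs_new, ys_new)             # zero-shot representation of the new function
    y_pred = basis(xs_new) @ c_new                    # approximation using the frozen basis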

Hidden-Parameter Dynamics Predictions

A MuJoCo Half Cheetah with variable limb sizes.

To demonstrate the efficacy of this approach, we first show that the function encoder can learn complex function spaces. We consider a modified version of the MuJoCo Half Cheetah environment where the lengths of the limbs and the control authority are varied each episode. These hidden parameters affect the system dynamics. The goal is to predict the dynamics given a small online dataset, but without direct knowledge of the hidden parameters.
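A hedged sketch of how such a dynamics predictor might look with a function encoder, again in PyTorch: the basis network maps a (state, action) pair to k candidate next-state deltas, the episode's online transitions determine the coefficients, and prediction is a linear combination of the basis outputs. The names, dimensions, and architecture below are our placeholders, not the paper's exact configuration.

import torch
import torch.nn as nn

STATE_DIM, ACTION_DIM, K = 17, 6, 100              # Half Cheetah-like sizes, chosen for illustration

class DynamicsBasis(nn.Module):
    # Maps (state, action) to k basis predictions of the next-state delta.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(STATE_DIM + ACTION_DIM, 256), nn.ReLU(),
            nn.Linear(256, K * STATE_DIM),
        )

    def forward(self, s, a):                       # (N, S), (N, A) -> (N, K, S)
        return self.net(torch.cat([s, a], dim=-1)).view(-1, K, STATE_DIM)

def encode_dynamics(basis, s, a, s_next):
    # Coefficients c_j ~ <f, g_j>, estimated from the episode's online transitions.
    G = basis(s, a)                                # (N, K, S)
    delta = (s_next - s).unsqueeze(1)              # (N, 1, S) target next-state deltas
    return (G * delta).sum(dim=-1).mean(dim=0)     # (K,)

def predict_next_state(basis, coeffs, s, a):
    # Zero-shot next-state prediction for the current (unknown) hidden parameters.
    return s + torch.einsum("nks,k->ns", basis(s, a), coeffs)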

A learning curve showing that the function encoder achieves better performance than a transformer.

We compare the function encoder against a transformer, which can incorporate the online dataset as input to its encoder. The function encoder outperforms the transformer in terms of data efficiency and asymptotic performance. We additionally compare against an oracle, which has access to the hidden parameters as an additional input. Interestingly, a variant of the function encoder even outperforms the oracle, which suggests the function encoder's inductive bias is well suited to this type of transfer.

A heat map shows that the function encoder's dynamics representation is smooth with respect to a change in hidden parameters.

Each axis of this plot shows a change in hidden parameter. The colors represent the cosine similarity between the dynamics representations of two different hidden parameter values. The figure shows that the function encoder's representation is smooth with respect to a change in hidden parameters. This suggests this representation is easy to work with for downstream tasks.
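For illustration, one cell of such a heat map could be computed as below, reusing the DynamicsBasis and encode_dynamics helpers from the sketch above; the random tensors stand in for transition data collected under two different hidden-parameter settings.

import torch.nn.functional as F

basis = DynamicsBasis()
roll_a = (torch.randn(32, STATE_DIM), torch.randn(32, ACTION_DIM), torch.randn(32, STATE_DIM))
roll_b = (torch.randn(32, STATE_DIM), torch.randn(32, ACTION_DIM), torch.randn(32, STATE_DIM))
c_a = encode_dynamics(basis, *roll_a)              # representation under hidden parameters A
c_b = encode_dynamics(basis, *roll_b)              # representation under hidden parameters B
cell = F.cosine_similarity(c_a, c_b, dim=0)        # the similarity value plotted in the figure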

Multi-Task Reinforcement Learning

Ms. Pacman going to different goal locations.

In multi-task RL, the reward function varies every episode. We consider a modified version of Ms. Pacman where the goal location changes every episode. However, the agent does not know the goal location and only has a dataset of (state, action, reward) tuples. Thus, this setting is more general than goal-reaching tasks alone. We proceed by first learning basis functions to span the space of reward functions. Then, for each reward function, we compute its representation according to the basis. Lastly, we pass this representation into the RL algorithm, in this case DQN. Due to the linearity of the representation space, we are also able to make architecture choices that take advantage of the inductive bias, e.g., by using an architecture similar to successor features.
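A minimal sketch of this pipeline, assuming PyTorch, a discrete action space, and placeholder dimensions (the names RewardBasis, encode_reward, and SuccessorStyleQ are ours, not the paper's): the basis spans the space of reward functions, each episode's reward is encoded from (state, action, reward) samples, and the critic exploits linearity through a successor-features-like form Q(s, a) = psi(s, a)^T c.

import torch
import torch.nn as nn

STATE_DIM, NUM_ACTIONS, K = 128, 9, 100            # placeholder sizes, ours

class RewardBasis(nn.Module):
    # Maps a state to k basis reward values for each discrete action.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(STATE_DIM, 256), nn.ReLU(),
                                 nn.Linear(256, NUM_ACTIONS * K))

    def forward(self, s):                          # (N, S) -> (N, A, K)
        return self.net(s).view(-1, NUM_ACTIONS, K)

def encode_reward(basis, s, a, r):
    # Coefficients of this episode's reward function from (state, action, reward) samples.
    G = basis(s)                                   # (N, A, K)
    G_sa = G[torch.arange(len(a)), a]              # (N, K) basis values at the taken actions
    return (r.unsqueeze(-1) * G_sa).mean(dim=0)    # (K,)

class SuccessorStyleQ(nn.Module):
    # Q(s, a) = psi(s, a)^T c, an architecture that exploits the linear representation.
    def __init__(self):
        super().__init__()
        self.psi = nn.Sequential(nn.Linear(STATE_DIM, 256), nn.ReLU(),
                                 nn.Linear(256, NUM_ACTIONS * K))

    def forward(self, s, coeffs):                  # (N, S), (K,) -> (N, A) Q-values
        psi = self.psi(s).view(-1, NUM_ACTIONS, K)
        return torch.einsum("nak,k->na", psi, coeffs)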

A learning curve showing that the function encoder achieves better performance than baselines.

We compare the function encoder against multi-task RL baselines. The function encoder outperforms the other approaches in terms of asymptotic performance. Interestingly, transformers also perform well in this setting.

A heat map shows that the function encoder's reward representation is smooth with respect to a change in goal location.

We compare the reward function representations of each goal location, relative to the goal in the top-left marked with a star. The colors represent the cosine similarity between the reward representations of two different goal locations. The figure shows that the function encoder's representation is smooth with respect to a change in goal location, whereas the same is not always true for the baseline representation learning algorithms. The transformer also learns a smooth representation space in this example, which likely explains why it performs so well.

Multi-Agent Reinforcement Learning

Two-player tag.

In this setting, the environment is a two-player, zero-sum game. Each episode, an adversary is randomly sampled, and the goal is for the ego agent to exploit the adversary. The environment is tag, and the ego agent is the tagger. Each episode, we presume access to a dataset of (state, adversary action) pairs. The ego agent should use this data to identify the adversary's policy, and then exploit that policy if possible. In the function encoder case, this data is used to compute the adversary's representation with respect to learned basis functions, and this representation is fed into the RL algorithm.
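A hedged sketch of the adversary-identification step, with our own names and placeholder dimensions, treating the adversary's action as a continuous vector for simplicity: the basis maps a state to k candidate adversary actions, the observed (state, adversary action) pairs determine the coefficients, and the coefficients are appended to the ego agent's observation before it is passed to the RL algorithm (here, PPO).

import torch
import torch.nn as nn

OBS_DIM, ADV_ACTION_DIM, K = 16, 5, 100            # placeholder sizes, ours

class AdversaryPolicyBasis(nn.Module):
    # Maps a state to k basis predictions of the adversary's action.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(OBS_DIM, 256), nn.ReLU(),
                                 nn.Linear(256, K * ADV_ACTION_DIM))

    def forward(self, s):                          # (N, OBS) -> (N, K, ADV_A)
        return self.net(s).view(-1, K, ADV_ACTION_DIM)

def encode_adversary(basis, states, adv_actions):
    # Representation of the adversary's policy from observed (state, adversary action) pairs.
    G = basis(states)                              # (N, K, ADV_A)
    return (G * adv_actions.unsqueeze(1)).sum(dim=-1).mean(dim=0)   # (K,)

def augment_observation(obs, coeffs):
    # The ego agent conditions on the adversary representation, e.g., by concatenation.
    return torch.cat([obs, coeffs.expand(len(obs), -1)], dim=-1)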

A learning curve showing that the function encoder achieves better performance than baselines.

We compare the function encoder against multi-agent baselines, and it achieves the best asymptotic performance. Interestingly, PPO alone is unstable in this environment. This is a result of the non-stationary nature of multi-agent RL. In contrast, PPO + FE is stable and achieves good performance. In addition, the transformer performs poorly in this environment. This is due to the inherent data inefficiency of transformers; data is much more restricted in multi-agent settings than it is in multi-task settings. Lastly, we observe that a one-hot encoding of the adversary is not sufficient to achieve good performance.

BibTeX

@inproceedings{FunctionEncoder,
  author       = {Tyler Ingebrand and
                  Amy Zhang and
                  Ufuk Topcu},
  title        = {Zero-Shot Reinforcement Learning via Function Encoders},
  booktitle    = {{ICML}},
  year         = {2024}
}