Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning

UC Berkeley, UIUC, New York University
*Project Lead, †Equal Contribution
We propose a Reinforcement Learning (RL) training framework for large Vision-Language Models (VLMs), or Multimodal Large Language Models (MLLMs), using task-specific rewards. At each time step, the VLM takes the current observation and a system prompt as inputs and outputs a formatted utterance containing chain-of-thought (CoT) reasoning and a text action. The text action is parsed into an environment action, and executing it in the environment produces a reward. We then apply RL with the task reward to fine-tune the VLM.

Training VLMs with RL

We propose an algorithmic framework that directly fine-tunes a large vision-language model (VLM) or multimodal large language model (MLLM) with Reinforcement Learning (RL) using task rewards.

An overview of our RL training framework for VLMs. Our method contains three key components: (1) designing the input prompt $v_t^\text{in}$ to obtain a formatted output $v_t^\text{out}$; (2) post-processing the output text $v_t^\text{out}$ into a legal action $a_t$; (3) estimating the action probability $\pi_\theta(a_t|o_t,v_t^\text{in})$.
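The three components fit into an ordinary rollout loop that records, at every step, what the RL update later needs. The sketch below is a minimal, hypothetical illustration: `collect_rollout`, `Transition`, and the callables it accepts are names introduced here, not the authors' code, and the RL update itself is not shown.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Transition:
    observation: object  # o_t (e.g., the current image)
    prompt: str          # v_t^in
    output_text: str     # v_t^out (CoT reasoning + text action)
    action: str          # a_t, the parsed legal action
    reward: float        # reward returned by the environment


def collect_rollout(
    reset: Callable[[], object],
    vlm_generate: Callable[[object, str], str],
    parse_action: Callable[[str], str],
    env_step: Callable[[str], Tuple[object, float, bool]],
    prompt: str,
    max_steps: int = 10,
) -> List[Transition]:
    """Roll out one episode (hypothetical helper; all callables are supplied by the caller)."""
    transitions: List[Transition] = []
    obs = reset()
    for _ in range(max_steps):
        text = vlm_generate(obs, prompt)   # (1) query the VLM with o_t and v_t^in
        action = parse_action(text)        # (2) post-process v_t^out into a legal a_t
        next_obs, reward, done = env_step(action)
        transitions.append(Transition(obs, prompt, text, action, reward))
        obs = next_obs
        if done:
            break
    # (3) log pi_theta(a_t | o_t, v_t^in) is computed later from (observation, prompt, output_text)
    return transitions
```

Keeping the raw `output_text` in each transition matters because the action probability in component (3) is computed over the generated tokens, not over the parsed action alone.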

Designing the input prompt $v_t^\text{in}$

We design $v_t^\text{in}$ in the following format, containing (1) a task description; (2) the legal action space; and (3) the desired output format $v_t^\text{out}$ (which includes the CoT reasoning).
Task-specific CoT prompt input $v_t^\text{in}$, where the blue part represents the CoT reasoning and the red part is the text-based action.
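To make the three-part structure concrete, here is a minimal sketch of a prompt builder. The wording, the JSON keys `"thoughts"` and `"action"`, and the helper name `build_prompt` are illustrative assumptions, not the exact prompts used for any task; the only property the sketch relies on is that the requested output ends with an `"action": "..."` field that the post-processing step can search for.

```python
import json
from typing import Sequence


def build_prompt(task_description: str, legal_actions: Sequence[str]) -> str:
    """Assemble a CoT prompt v_t^in (hypothetical wording, not the authors' prompt)."""
    # (2) the legal action space, each action quoted as a JSON string
    action_list = ", ".join(json.dumps(a) for a in legal_actions)
    # (3) the desired output format: CoT reasoning first, then the text action
    output_format = (
        "{\n"
        '  "thoughts": "<step-by-step reasoning about the current image>",\n'
        '  "action": "<one of the legal actions>"\n'
        "}"
    )
    return (
        f"{task_description}\n"            # (1) the task description
        f"Legal actions: [{action_list}].\n"
        "Respond with a valid JSON object in the following format:\n"
        f"{output_format}"
    )
```

For NumberLine, for example, `build_prompt("Move the current number to the target number.", ["-", "+"])` yields a prompt that lists `["-", "+"]` as the legal actions.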

Post-processing the output text $v_t^\text{out}$

We post-process the output text $v_t^\text{out}$ by directly searching it for the pattern "action": "$a$", where $a\in\mathcal{A}$ is the text version of a legal action. If $v_t^\text{out}$ does not contain a legal action, we explore by sampling an action uniformly at random from the legal action space $\mathcal{A}$:
\begin{equation} f(v^\text{out}) = \begin{cases} a,& \text{if } \texttt{"}\mathtt{action}\texttt{"}: \texttt{"}a\texttt{"}\in v^\text{out},\\ \mathtt{Unif}(\mathcal{A}),& \text{otherwise.} \end{cases} \end{equation}
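The mapping $f$ can be sketched in a few lines of Python. This is a hedged illustration, not the authors' implementation: `parse_action` is a hypothetical name, the match is an exact substring test with a single space after the colon, and when several legal actions match we simply return the first one in `legal_actions` order (the equation does not specify tie-breaking).

```python
import random
from typing import Optional, Sequence


def parse_action(output_text: str, legal_actions: Sequence[str],
                 rng: Optional[random.Random] = None) -> str:
    """f(v^out): return a if '"action": "a"' occurs in the text, else a uniformly random legal action."""
    for a in legal_actions:
        # Including the closing quote prevents "hi" from matching inside "hit".
        if f'"action": "{a}"' in output_text:
            return a
    # No legal action found: random exploration over the legal action space.
    rng = rng or random.Random()
    return rng.choice(list(legal_actions))
```

Passing a seeded `random.Random` makes the fallback reproducible, which is convenient when debugging rollouts.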

Estimating the action probability $\pi_\theta(a_t|o_t,v_t^\text{in})$

A naive way to compute the log action probability $\log \pi_\theta(a_t|o_t,v_t^\text{in})$ is to directly sum the log probabilities of all tokens in $v_t^\text{out}$, i.e., of the CoT tokens ${\color{#6C8EBF} v^\text{tht}_t}$ and the action tokens ${\color{#B85450} v^\text{act}_t}$:
\begin{equation*} \log\pi_\theta(a_t|o_t,v_t^\text{in})\leftarrow\log \pi_\theta({\color{#6C8EBF} v^\text{tht}_t}|o_t,v^\text{in}_t) + \log \pi_\theta({\color{#B85450} v^\text{act}_t}|o_t,v^\text{in}_t,v^\text{tht}_t). \end{equation*}
However, our method obtains the legal action $a_t$ by parsing only the action tokens ${\color{#B85450} v^\text{act}_t}$, while the CoT tokens ${\color{#6C8EBF} v^\text{tht}_t}$ are generally much longer than the action tokens, so their log probability would dominate this sum. We therefore adopt a regularized version that scales down the log probability of the CoT tokens ${\color{#6C8EBF} v^\text{tht}_t}$ by a factor $\lambda\in[0,1]$:
\begin{equation*} \log\pi_\theta(a_t|o_t,v_t^\text{in})\leftarrow\lambda \log \pi_\theta({\color{#6C8EBF} v^\text{tht}_t}|o_t,v^\text{in}_t) + \log \pi_\theta({\color{#B85450} v^\text{act}_t}|o_t,v^\text{in}_t,v^\text{tht}_t). \end{equation*}
We found that a moderate value of $\lambda$ (e.g., $\lambda\in[0.2,0.5]$) works well in practice.
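The regularized estimate reduces to a masked, weighted sum of per-token log probabilities. The sketch below assumes those per-token values have already been read off the VLM's output distribution and that each output token is labeled as CoT or action; `scaled_action_log_prob` is a hypothetical helper. Setting `lam=1.0` recovers the naive sum, and `lam=0.0` keeps only the action tokens.

```python
from typing import Sequence


def scaled_action_log_prob(token_log_probs: Sequence[float],
                           is_action_token: Sequence[bool],
                           lam: float = 0.5) -> float:
    """lam * log pi(v^tht | ...) + log pi(v^act | ...), from per-token log probabilities."""
    if len(token_log_probs) != len(is_action_token):
        raise ValueError("need one mask entry per output token")
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lambda must lie in [0, 1]")
    # Log probability of a token sequence = sum of its per-token log probabilities.
    cot = sum(lp for lp, act in zip(token_log_probs, is_action_token) if not act)
    action = sum(lp for lp, act in zip(token_log_probs, is_action_token) if act)
    return lam * cot + action
```

As a worked example, with eight CoT tokens at log probability $-0.5$ and two action tokens at $-0.25$, the naive sum is $-4.0 + (-0.5) = -4.5$, while $\lambda=0.5$ gives $0.5\cdot(-4.0) + (-0.5) = -2.5$, halving the contribution of the CoT tokens.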

Evaluation Tasks

GymCards

Our GymCards environment is designed to evaluate VLMs' arithmetic capabilities using visual perception and language reasoning in deterministic and stochastic tasks. More precisely, tasks in the GymCards domain require the VLM to recognize the numbers (potentially from cards) and utilize the numbers for reasoning.
Visualization of a state in each card environment. From left to right: NumberLine, Blackjack, EZPoints, and Points24.

Goals

NumberLine. Move the current number to the target number by adding or subtracting 1 from the current number.

EZPoints. Generate an equation that equals 12 using the two numbers from the cards ('J', 'Q', and 'K' are treated as 10).

Points24. Generate an equation that equals 24 using the four numbers from the cards ('J', 'Q', and 'K' are treated as 10).

Blackjack. Win the current game by choosing "hit" or "stand".
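The NumberLine, EZPoints, and Points24 goals above can be stated as simple checks. The sketch below is hypothetical and not the environments' actual code: it assumes an 'A' card counts as 1, that each card's number must be used exactly once, and that an equation is written with integers and `+ - * /`; the real environments may represent equations differently (for instance, building them one token at a time). Exact `Fraction` arithmetic ensures that solutions with intermediate fractions, such as `8 / (3 - 8 / 3)`, are judged correctly. Blackjack is omitted.

```python
import ast
import operator
from fractions import Fraction
from typing import List

# 'J', 'Q', 'K' count as 10 per the goals above; 'A' = 1 is an assumption.
CARD_VALUES = {"A": 1, "J": 10, "Q": 10, "K": 10, **{str(n): n for n in range(2, 11)}}

_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}


def numberline_step(current: int, action: str) -> int:
    """NumberLine: add or subtract 1 from the current number."""
    if action not in ("+", "-"):
        raise ValueError("NumberLine actions are '+' and '-'")
    return current + 1 if action == "+" else current - 1


def _evaluate(node: ast.AST, used: List[int]) -> Fraction:
    """Exactly evaluate an expression of integers and + - * /, recording each integer used."""
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_evaluate(node.left, used), _evaluate(node.right, used))
    if isinstance(node, ast.Constant) and type(node.value) is int:
        used.append(node.value)
        return Fraction(node.value)
    raise ValueError("only integers and + - * / are allowed")


def solves_points(equation: str, cards: List[str], target: int) -> bool:
    """EZPoints (target=12, two cards) or Points24 (target=24, four cards)."""
    used: List[int] = []
    try:
        value = _evaluate(ast.parse(equation, mode="eval").body, used)
    except (ValueError, SyntaxError, ZeroDivisionError):
        return False
    # Every card number used exactly once, and the equation hits the target.
    return sorted(used) == sorted(CARD_VALUES[c] for c in cards) and value == target
```

For example, `solves_points("10 + 2", ["K", "2"], 12)` is `True`, while `"10 + 1 + 1"` is rejected for the same cards because it does not use the card numbers exactly once.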

Examples of Transitions in the GymCards Environment

ALFWorld

We also adopt the ALFWorld environment to evaluate VLMs' decision-making capabilities in tasks that require visual semantic understanding. See an example of the ALFWorld environment below.
An example of a transition in the ALFWorld environment.

Experimental Results

Improving VLM agents' decision-making capabilities

We show that our method enables a LLaVA-1.6-7b backbone model to outperform commercial models (GPT-4V and Gemini) as well as a supervised fine-tuned LLaVA-7b.
The average episode success rate of our method compared against other methods. Our method achieves the best overall performance in both the GymCards and ALFWorld (without expert data) environments.
Curves of the average episode success rate of our method compared against other methods. Curves for Points24 are not shown because no method achieves reasonable performance on it.

The importance of CoT reasoning

We demonstrate that CoT reasoning is a crucial component of our method: its performance deteriorates significantly without CoT reasoning.
The average episode success rate of our method with and without CoT reasoning.
Curves of the average episode success rate of our method with and without CoT reasoning. Curves for Points24 are not shown because no method achieves reasonable performance on it.

BibTeX

@misc{zhai2024fine,
  title={Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning},
  author={Zhai, Yuexiang and Bai, Hao and Lin, Zipeng and Pan, Jiayi and Tong, Shengbang and Zhou, Yifei and Suhr, Alane and Xie, Saining and LeCun, Yann and Ma, Yi and Levine, Sergey},
  year={2024},
  eprint={2405.10292},
  archivePrefix={arXiv},
  primaryClass={cs.AI}
}

Acknowledgment

The website template is taken from here. We thank Hyperbolic Labs for the generous compute support.