GPO: learning from critical steps to improve LLM reasoning
A novel fine-tuning strategy designed to improve LLM multi-step reasoning capabilities by focusing on pivotal moments.
Large language models (LLMs) are increasingly used across many domains and show impressive potential on a wide range of tasks. Recently, reasoning LLMs have been proposed to strengthen the reasoning, or "thinking", capabilities of LLMs so that they can solve complex problems. Despite these promising results, enhancing the multi-step reasoning capabilities of LLMs remains a significant challenge. While existing optimization methods have advanced LLM reasoning, they often treat a reasoning trajectory as a whole, without considering the critical steps within it. In this paper, we introduce Guided Pivotal Optimization (GPO), a novel fine-tuning strategy that looks inside the reasoning process to enable more effective improvements. GPO first identifies the "critical step" within a reasoning trajectory: a point at which the model must proceed carefully in order to solve the problem. We locate the critical step by estimating the advantage function. GPO then resets the policy to the critical step, samples new rollouts from that point, and prioritizes learning on those rollouts. This focus allows the model to learn more effectively from pivotal moments in the reasoning process and thereby improve its reasoning performance. We show that GPO is a general strategy that can be integrated with various optimization methods to improve reasoning performance. Beyond theoretical analysis, our experiments on challenging reasoning benchmarks show that GPO consistently and significantly improves the performance of existing optimization methods, demonstrating its effectiveness and generalizability in improving LLM reasoning by concentrating on pivotal moments within the generation process.
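The identify, reset, and resample loop described in the abstract can be illustrated with a minimal Python sketch, built on several assumptions. It treats a trajectory as a list of text steps and estimates the value of each prefix by Monte Carlo rollouts from a hypothetical `rollout_fn`, which stands in for sampling a completion from the current policy and checking the final answer. The per-step advantage is taken as the difference between consecutive prefix values, and the step with the largest-magnitude advantage is chosen as the critical step. That selection rule, the names `rollout_fn`, `find_critical_step`, and `reset_and_resample`, and the toy task are illustrative assumptions, not GPO's published implementation.

```python
import random
from typing import Callable, List, Sequence, Tuple

Step = str
# Hypothetical interface (an assumption, not GPO's API): given a prefix of reasoning
# steps, sample a continuation from the current policy and report whether the
# final answer is correct.
RolloutFn = Callable[[Sequence[Step]], Tuple[List[Step], bool]]


def estimate_value(prefix: Sequence[Step], rollout_fn: RolloutFn, k: int) -> float:
    """Monte Carlo estimate of V(prefix): the fraction of k completions that succeed."""
    successes = sum(rollout_fn(prefix)[1] for _ in range(k))
    return successes / k


def step_advantages(steps: Sequence[Step], rollout_fn: RolloutFn, k: int) -> List[float]:
    """A_t = Q(s_t, a_t) - V(s_t) = V(s_{t+1}) - V(s_t), since appending step t is deterministic."""
    values = [estimate_value(steps[:t], rollout_fn, k) for t in range(len(steps) + 1)]
    return [values[t + 1] - values[t] for t in range(len(steps))]


def find_critical_step(steps: Sequence[Step], rollout_fn: RolloutFn, k: int) -> int:
    """Assumed rule: the critical step is the one whose advantage has the largest magnitude."""
    if not steps:
        raise ValueError("trajectory must contain at least one step")
    advantages = step_advantages(steps, rollout_fn, k)
    return max(range(len(steps)), key=lambda t: abs(advantages[t]))


def reset_and_resample(
    steps: Sequence[Step], rollout_fn: RolloutFn, k: int, n_new: int
) -> Tuple[int, List[Tuple[List[Step], bool]]]:
    """Reset to the state just before the critical step and sample n_new fresh rollouts."""
    t_star = find_critical_step(steps, rollout_fn, k)
    prefix = list(steps[:t_star])
    new_rollouts = []
    for _ in range(n_new):
        continuation, success = rollout_fn(prefix)
        new_rollouts.append((prefix + continuation, success))
    return t_star, new_rollouts


# Toy policy (hypothetical): a "wrong" step in the prefix dooms the answer;
# otherwise the policy finishes correctly with probability 0.8.
rng = random.Random(0)


def toy_rollout(prefix: Sequence[Step]) -> Tuple[List[Step], bool]:
    if "wrong" in prefix:
        return ["answer: 41"], False
    if rng.random() < 0.8:
        return ["answer: 42"], True
    return ["answer: 41"], False


trajectory = ["parse", "setup", "wrong", "finish"]
t_star, new_rollouts = reset_and_resample(trajectory, toy_rollout, k=200, n_new=8)
print(t_star)  # 2: the "wrong" step, where the estimated success rate collapses
```

Estimating every prefix value costs roughly (T + 1) x k rollouts for a T-step trajectory, which is why a real system would likely reuse or batch these samples. In training, the resampled rollouts would then be handed, with higher priority, to whichever optimization method GPO is combined with; how that prioritization is weighted is left open in this sketch.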