Aussie AI Blog

Multi-Step Reasoning Inference Optimization

  • Dec 21st, 2024
  • by David Spuler, Ph.D.

What is Reasoning Inference Optimization?

Reasoning inference optimization is the application of the more than 500 known inference optimization techniques to advanced multi-step reasoning models, which are based on Chain-of-Thought and other test-time compute methods. These reasoning models, sometimes called Large Reasoner Models (LRMs), run multiple repeated steps of LLM inference to converge on a better answer. This is the new trend, exemplified by OpenAI's "o1" and "o3" models: LLMs can be smarter if you let them think a little longer.

Which is also slower!

Hence, the idea for a large new subarea of AI research: reasoning inference optimization. The goal is to make these multi-step reasoning algorithms faster, whilst retaining their powerful reasoning ability.

The commercial value of these speed optimizations is clear, and they are already being used in commercial platforms. OpenAI announced that its "o1" model uses on average 60% fewer reasoning tokens than its "o1-preview" model. No word yet on how the "o3" model compares on this metric.

There aren't many research papers on this efficiency area yet. The vast majority of papers on Chain-of-Thought and other multi-step reasoning algorithms have focused on improving their reasoning capabilities, with little concern for their speed. But this is changing, and a few papers on the efficiency of Chain-of-Thought have started to appear, so I expect many more soon!

Faster Reasoning Models

How can we run these reasoning algorithms faster? It's a well-known problem that these models run slower. The "o1" model is much slower to respond than other OpenAI one-shot inference models.

The basic options to run a reasoner model faster include:

  • All the existing LLM inference optimization methods.
  • High-level changes to the reasoning algorithm.
  • New reasoning-specific low-level inference optimizations.

Let's examine each of these major areas in turn.

Existing LLM Inference Optimization Techniques

The first point about running multi-step inference models faster is that each step involves inference, so all of the usual LLM inference optimization techniques apply. Running each step faster will run the whole thing faster. Hence, any of the major techniques can be used, such as:

  • Better GPUs
  • Smaller models
  • Quantization
  • Pruning
  • Speculative decoding

There are many others; in fact, we've cataloged research papers on over 500 distinct inference optimization methods. All of these methods can be applied to every inference sub-query in an advanced reasoner model.

Note that some methods are "lossy," in that they reduce the accuracy of the queries in exchange for the speed improvement. Other optimizations are "lossless" and produce exactly the same results, only faster. In the above list, better GPUs and speculative decoding are lossless, whereas small models, quantization, and pruning are lossy.
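To make "lossy" concrete, here is a minimal sketch of symmetric per-tensor int8 quantization. This is one common scheme among many, and the choice is an assumption: real engines differ in granularity and rounding. Weights are divided by a scale, rounded to integers in [-127, 127], and multiplied back. The round trip does not recover the original values exactly, and that error is the accuracy cost paid for smaller, faster arithmetic. A lossless method such as speculative decoding would instead produce exactly the same output tokens.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q, scale):
    """Map the integers back to approximate float weights."""
    return [v * scale for v in q]


# Illustrative weight values (hypothetical, not from any real model).
weights = [0.013, -0.402, 0.251, 0.998, -0.777]
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q)
print(f"max round-trip error: {max_error:.5f}")
```

Each weight's error is at most half of `scale`. That is small for any single weight, but it explains why a quantized model can give slightly different answers from the full-precision one.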

High-Level Chain-of-Thought Optimizations

The Chain-of-Thought algorithms are relatively new, and the best place to optimize their speed is at the top. As noted above, OpenAI stated that it had reduced token counts by 60% from the "o1-preview" model to the "o1" model. Although the details were not provided, presumably this was done by tweaking the algorithms at a high level.

The total cost of a reasoning algorithm is closely related to the total number of tokens processed. Token counts in a Chain-of-Thought algorithm could be reduced by:

  • Fewer steps!
  • Abandoning unproductive paths and steps earlier
  • Parallelization of steps and paths
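Here is a minimal sketch of the "abandon unproductive paths earlier" idea, seen as token accounting. It assumes that some verifier or reward model scores each path after every step. The `Step` class, the scores, the token counts, and the 0.5 threshold are all hypothetical values chosen for illustration.

```python
from dataclasses import dataclass


@dataclass
class Step:
    tokens: int   # tokens generated by this reasoning step
    score: float  # hypothetical verifier score for the path after this step


def total_tokens(paths, abandon_below=None):
    """Total tokens generated across all paths, optionally abandoning low-scoring paths."""
    total = 0
    for path in paths:
        for step in path:
            total += step.tokens  # a step is already generated before it can be scored
            if abandon_below is not None and step.score < abandon_below:
                break  # stop spending tokens on this unproductive path
    return total


paths = [
    [Step(120, 0.8), Step(150, 0.85), Step(90, 0.9)],   # promising path
    [Step(110, 0.3), Step(140, 0.2), Step(100, 0.1)],   # unproductive from the start
    [Step(130, 0.6), Step(160, 0.4), Step(120, 0.35)],  # fades after one step
]

print(total_tokens(paths))                     # 1120: every path runs to completion
print(total_tokens(paths, abandon_below=0.5))  # 760: two paths are cut short
```

The threshold is the design knob. If it is set too high, a path that would have recovered gets pruned, which trades accuracy for tokens in the same way as the lossy optimizations. Because the paths are independent, they could also run in parallel, which cuts latency but not the total token count.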

Some simpler, pragmatic ways to reduce the total token count would be:

  • Shorter meta-prompt instructions
  • Packing algorithms for each step
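To see why shorter meta-prompts matter, here is a back-of-the-envelope sketch. It assumes that each step is sent as a fresh query containing the meta-prompt plus all prior step outputs, with no prefix caching. Under that assumption the meta-prompt is reprocessed at every step, so the total input tokens for n steps come to n*M + S*n*(n-1)/2, where M is the meta-prompt length and S is the number of output tokens per step. The token counts below are hypothetical.

```python
def input_tokens_processed(steps, meta_prompt_tokens, step_output_tokens):
    """Total input tokens when every step re-sends the meta-prompt plus all prior outputs.

    Assumes no prefix/KV caching between steps (a hypothetical worst case).
    """
    total = 0
    context = 0  # tokens of prior step outputs carried forward as context
    for _ in range(steps):
        total += meta_prompt_tokens + context
        context += step_output_tokens
    return total


print(input_tokens_processed(5, 400, 200))  # 4000
print(input_tokens_processed(5, 100, 200))  # 2500
```

Cutting the meta-prompt from 400 to 100 tokens saves 300 tokens per step, or 1,500 over five steps. The carried-forward context grows quadratically with the number of steps, which is why shrinking or packing it can matter even more.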

We've seen many of these types of optimizations in RAG architectures. The interim steps of Chain-of-Thought effectively add a "context" section to each prompt, much like the context built from document chunks in a RAG application.

Low-Level Reasoning Inference Optimizations

An interesting question is whether any of the existing LLM inference optimization methods, as used for a single inference query, can be optimized even further in multi-step inference. The multiple inference steps in the "chain" of reasoning open up several new opportunities. Some of the main strategies that seem to offer promise include:


Low-Level Token Reductions

Token reduction strategies, as discussed at the high level above, can also be applied at the lower level of the inference engine. Examples of relevant techniques in this area include:

It's not clear that these methods are particularly applicable to multi-step reasoning. The goal of a reasoning algorithm is to perfect the answer, not to shrink it down to fewer words, so the two approaches may be at odds. On the other hand, the very first paper I've noted on reasoning model speed optimization (Kang et al., 2024) is about token pruning and context compression for reasoning models!
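To show what low-level token pruning looks like, here is a minimal sketch of importance-based pruning. It keeps only the highest-scoring tokens of a prior step's text, within a token budget, and preserves their original order. The importance scores are hypothetical; in practice they might come from attention weights or a small scoring model. This is a generic illustration, not the method of Kang et al.

```python
def prune_tokens(tokens, importance, budget):
    """Keep the `budget` most important tokens, preserving their original order."""
    if budget >= len(tokens):
        return list(tokens)
    # Rank token positions by importance (sorted() is stable, so ties keep text order).
    ranked = sorted(range(len(tokens)), key=lambda i: importance[i], reverse=True)
    keep = sorted(ranked[:budget])  # restore original positions
    return [tokens[i] for i in keep]


tokens = ["So", ",", "the", "total", "is", "12", "plus", "30", "=", "42", "."]
importance = [0.1, 0.05, 0.1, 0.6, 0.3, 0.9, 0.5, 0.9, 0.4, 1.0, 0.05]  # hypothetical
print(prune_tokens(tokens, importance, budget=6))
# ['total', '12', 'plus', '30', '=', '42']
```

The tension mentioned above is visible here. The pruned context still carries the arithmetic, but any nuance in the dropped words is gone for the next reasoning step.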

Caching and Computation Reuse

Caching and computation reuse ideas come to mind when we realize that the Chain-of-Thought algorithm iterates over things that tend to be similar. Every step has access to prior data and computations:

  • Tokens and text in each prior output (i.e., decoded text)
  • Abandoned paths of tokens and text
  • KV cache data from prior computations
  • Logit values for prior tokens
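Here is a minimal sketch of KV cache reuse across steps. It assumes that each step's prompt extends the previous one, so consecutive prompts share a prefix. The hypothetical `PrefixCache` class only records which token sequences have been processed and counts how many tokens of a new prompt need fresh computation. A real engine would store the actual key/value tensors, typically in fixed-size blocks; vLLM's automatic prefix caching is one example.

```python
class PrefixCache:
    """Toy prefix cache: remembers token sequences whose KV data was already computed."""

    def __init__(self):
        self.sequences = []

    def longest_cached_prefix(self, tokens):
        """Length of the longest prefix of `tokens` shared with any cached sequence."""
        best = 0
        for seq in self.sequences:
            n = 0
            for a, b in zip(seq, tokens):
                if a != b:
                    break
                n += 1
            best = max(best, n)
        return best

    def process(self, tokens):
        """Return how many tokens need fresh computation; the rest reuse cached KV data."""
        reused = self.longest_cached_prefix(tokens)
        self.sequences.append(list(tokens))
        return len(tokens) - reused


cache = PrefixCache()
step1 = "Solve : 12 + 30 . Think step by step .".split()
step2 = step1 + "12 + 30 = 42".split()              # step 2 extends step 1
step3 = step2 + "Check : 42 - 30 = 12 .".split()    # step 3 extends step 2
print([cache.process(s) for s in (step1, step2, step3)])  # [11, 5, 8]
```

In this example the three steps compute 24 tokens fresh, instead of 51 (11 + 16 + 24) without reuse. The linear scan over stored sequences is only for clarity; a trie or hashed blocks would scale better.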

Some of the first thoughts on this area of "optimizing the optimizations" include:

Grammatical Error Correction

An interesting point to note is that many steps of Chain-of-Thought are about revising answers. This is very similar to the task of Grammatical Error Correction (GEC), also known less formally as "editing" a document. The use case of revising and improving an input text has several specific characteristics:

  • Output text is "similar" to the input text.
  • Sub-sections of the input text will appear verbatim in the output text.
  • The token lengths of the input and output texts are similar.
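Revised text copies long runs of its input verbatim, so one relevant trick is copy-based drafting, similar in spirit to speculative decoding with prompt lookup. The next few tokens are proposed by copying from the previous answer, and the model then verifies them in one pass and keeps only the matching prefix. Under greedy decoding every kept token is one the model would have chosen anyway, so the method is lossless. The sketch below assumes greedy decoding. `target_next` is a hypothetical stand-in for the LLM that "knows" the revised answer, and the n-gram size and draft length are arbitrary choices.

```python
def propose_draft(source, output, ngram=2, max_draft=8):
    """Find the last `ngram` output tokens in the source and copy what followed them."""
    if not output:
        return source[:max_draft]  # nothing generated yet: guess the source's opening
    key = output[-ngram:]
    for i in range(len(source) - len(key) + 1):
        if source[i:i + len(key)] == key:
            return source[i + len(key):i + len(key) + max_draft]
    return []  # no match: fall back to ordinary one-token decoding


def decode_with_copy_drafts(source, target_next, max_len=100, eos="<eos>"):
    """Greedy decoding where each model pass verifies a draft copied from `source`."""
    output, passes = [], 0
    while len(output) < max_len:
        draft = propose_draft(source, output)[: max_len - len(output)]
        passes += 1  # in a real engine, one forward pass scores every draft position
        for tok in draft:
            if target_next(output) != tok:
                break  # first mismatch: discard the rest of the draft
            output.append(tok)
        if len(output) >= max_len:
            break
        nxt = target_next(output)  # the same pass also yields the next real token
        if nxt == eos:
            break
        output.append(nxt)
    return output, passes


previous = "The total cost is 40 dollars and the tax is 2 dollars .".split()
revised = "The total cost is 42 dollars and the tax is 2 dollars .".split()


def target_next(prefix):
    # Hypothetical stand-in for the LLM's greedy choice, given the tokens so far.
    return revised[len(prefix)] if len(prefix) < len(revised) else "<eos>"


out, passes = decode_with_copy_drafts(previous, target_next)
print(" ".join(out), passes)
```

On this toy example the 13-token revision takes 4 model passes rather than one pass per token. Only the changed number and the two tokens right after it are generated one per pass, until the 2-gram lookup re-synchronizes with the old text.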

Chain-of-Thought performs revisions frequently, perhaps not at every step, but often enough that these optimization techniques may be relevant. Hence, much of the research on speeding up GEC algorithms may also apply to speeding up Chain-of-Thought, such as:

Reduced Lossiness

Another way to reverse the overall question is to ask whether any of the lossy inference optimization techniques can retain more accuracy in a multi-step algorithm. In other words, rather than trying to speed up the optimizations further, can we keep their existing speedup but address their accuracy limitations? Many of the big LLM inference optimizations are lossy, such as quantization and pruning, but perhaps they can be improved using information from prior inference steps.

Some inference optimization techniques that are dismissed as too inaccurate for single-shot inference might become useful if they were less lossy. The question is thus whether the long chain of extra text and computation data from prior steps (i.e., from each of their inference queries) can somehow allow lossy optimizations to be improved. I don't think I've seen any research papers on this topic!

References

  1. OpenAI, Dec 2024, OpenAI o1 and new tools for developers, https://openai.com/index/o1-and-new-tools-for-developers/ ("Lower latency: o1 uses on average 60% fewer reasoning tokens than o1-preview for a given request.")
  2. Yu Kang, Xianghui Sun, Liangyu Chen, Wei Zou, 16 Dec 2024, C3oT: Generating Shorter Chain-of-Thought without Compromising Effectiveness, https://arxiv.org/abs/2412.11664 (Token pruning and prompt compression for Chain-of-Thought.)
