Introduction
The advent of ChatGPT and other large language models (LLMs) has already affected education. With mixed results and varying degrees of ethical acceptability, students can use chat-tuned LLMs to plan, to kick off research, to edit and suggest stylistic or grammatical improvements, or even to ghostwrite assignments.
The well-known non-profit Khan Academy offers its own personalized tutor, Khanmigo, developed in partnership with OpenAI to guide learners using an inductive approach. But despite impressive capabilities in many domains, even the largest and most advanced LLMs exhibit surprising failures, especially in math. If LLMs are prone to the same glaring mistakes that students make, how can an organization like Khan Academy expect them to act as a trustworthy teaching tool?
One method that vastly improves the ability of LLMs to solve grade school-level math problems is chain-of-thought reasoning and prompting. Remember when your teachers docked points when you didn’t show your work? When a fine-tuned LLM is instructed to break problems down and write out the steps, it often fares much better at solving them.
In the next few sections, we’ll discuss and distinguish chain-of-thought (CoT) and similar techniques and demonstrate the method with a few sample problems using the HuggingFace library.
Chain-of-Thought & Showing Your Work
Even as LLMs and their pre-training datasets have grown to the point where state-of-the-art models have hundreds of billions of parameters and are trained on multiple terabytes of data, they continue to struggle with basic math word problems.
Earlier work, Nye et al.’s 2021 “Show Your Work” paper, encouraged models to use a distinct “scratchpad” by fine-tuning them on supervised scratchpad target outputs and providing few-shot examples in the prompt. Published in the NeurIPS 2022 conference proceedings, Wei et al.’s chain-of-thought paper built on the scratchpad concept using few-shot prompt examples alone, with no gradient updates. Ultimately, then, Wei’s chain-of-thought strategy is a matter of prompt engineering.
Here’s a word problem in the Chain-of-Thought paper that gave LLMs difficulty:
“The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?”
In the paper, this prompt yielded an incorrect value of 27 apples in the few-shot scenario with no CoT. Adding examples that explicitly describe the necessary steps yields the correct value of 9 apples. Models are often perfectly capable of getting the right answer when presented with each individual component of a multistep problem, and with a CoT prompt the model has no trouble. A question paired with a worked, step-by-step answer looks like this:
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Originally, Leah had 32 chocolates. Her sister had 42. So, in total, they had 32 + 42 = 74. After eating 35, they had 74 – 35 = 39. The answer is 39.
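To make the difference between a standard few-shot prompt and a chain-of-thought prompt concrete, here is a minimal sketch that assembles both from the example above. The function name `build_few_shot_prompt` and the exact "Q:/A:" layout are assumptions made for illustration, and the "standard" answer is simply the worked answer cut down to its final sentence.

```python
# Exemplar question and worked answer, copied from the example above.
EXEMPLAR_Q = (
    "Leah had 32 chocolates and her sister had 42. "
    "If they ate 35, how many pieces do they have left in total?"
)
COT_ANSWER = (
    "Originally, Leah had 32 chocolates. Her sister had 42. "
    "So, in total, they had 32 + 42 = 74. "
    "After eating 35, they had 74 - 35 = 39. The answer is 39."
)
# Assumption: a standard exemplar keeps only the final answer sentence.
STANDARD_ANSWER = "The answer is 39."

TARGET_Q = (
    "The cafeteria had 23 apples. If they used 20 to make lunch "
    "and bought 6 more, how many apples do they have?"
)


def build_few_shot_prompt(exemplars, question):
    """Join (question, answer) exemplars and leave the last answer open."""
    blocks = [f"Q: {q}\nA: {a}" for q, a in exemplars]
    blocks.append(f"Q: {question}\nA:")
    return "\n\n".join(blocks)


standard_prompt = build_few_shot_prompt([(EXEMPLAR_Q, STANDARD_ANSWER)], TARGET_Q)
cot_prompt = build_few_shot_prompt([(EXEMPLAR_Q, COT_ANSWER)], TARGET_Q)

# Sanity-check the arithmetic quoted in the text.
assert 32 + 42 - 35 == 39
assert 23 - 20 + 6 == 9

print(cot_prompt)
```

The two prompts differ only in the exemplar answer, which is the point: no weights change, only the text the model is asked to continue.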
Using CoT prompts, Wei and colleagues found remarkable improvements across a range of task benchmarks. Most notably, PaLM 540B improved its solve rate on the GSM8K math word problem benchmark from 18% to 57%. The authors also found substantial improvements using chain-of-thought on the SVAMP, ASDiv, AQuA, and MAWPS datasets, all of which involve math word problems.
Exploring the effect of chain-of-thought on PaLM, LaMDA, and GPT-3 model families, the authors found that CoT improvements correlate strongly with the size of the model. This result, in keeping with previous work, forms the basis of the authors’ strong assertion that chain-of-thought reasoning is an “emergent property of model scale that allows sufficiently large language models to perform reasoning tasks that otherwise have flat scaling curves.”
Hey LLM, Let’s Think Step-by-Step
A different paper, by Kojima et al., found that this dependence on model size extends to the zero-shot regime as well. Kojima and colleagues showed that the simple prompt addendum “Let’s think step-by-step” (LTSBS) elicits the same type of multistep explanatory solutions as earlier CoT and scratchpad work, although the improvements were concentrated in larger models. Kojima et al. also broke their problem presentation into two prompts: a reasoning prompt (“Let’s think step-by-step”) and a second prompt for extracting the answer from the output of the first (using some variation of “the answer is”).
To get a feel for how chain-of-thought and related prompting techniques can affect LLM problem-solving, we created a mini-experiment using practice problems adapted from Khan Academy’s free exercises. The demo runs the different prompting methods using the HuggingFace transformers library and 3 fine-tuned checkpoints based on the 7-billion-parameter variant of Google’s Gemma.
As baselines, we included vanilla zero-shot and few-shot prompts, as well as a sabotaged zero-shot scenario that encourages short answers by ending the prompt with “the answer is:”. We also included chain-of-thought prompts, plus an augmented version of the zero-shot, few-shot, and chain-of-thought scenarios with LTSBS (“Let’s think step-by-step”) appended to each.
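The two-stage procedure can be sketched as follows. This is an illustration under stated assumptions rather than Kojima et al.'s code: `zero_shot_cot` and `fake_generate` are hypothetical names, the answer trigger "Therefore, the answer is" is just one possible variation of "the answer is", and the stand-in generator returns canned text so the control flow can be run without a model.

```python
from typing import Callable, Tuple

REASONING_TRIGGER = "Let's think step-by-step."
# Assumption: one possible wording of the answer-extraction trigger.
ANSWER_TRIGGER = "Therefore, the answer is"


def zero_shot_cot(question: str, generate: Callable[[str], str]) -> Tuple[str, str]:
    """Run the reasoning prompt, then feed its output to an extraction prompt."""
    # Stage 1: elicit step-by-step reasoning.
    reasoning_prompt = f"Q: {question}\nA: {REASONING_TRIGGER}"
    reasoning = generate(reasoning_prompt)
    # Stage 2: append the reasoning and ask for the final answer only.
    answer_prompt = f"{reasoning_prompt} {reasoning}\n{ANSWER_TRIGGER}"
    answer = generate(answer_prompt)
    return reasoning, answer.strip()


def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for an LLM call; returns canned text."""
    if prompt.endswith(ANSWER_TRIGGER):
        return " 9."
    return (
        "The cafeteria started with 23 apples. They used 20, "
        "leaving 23 - 20 = 3. They bought 6 more, so 3 + 6 = 9."
    )


question = (
    "The cafeteria had 23 apples. If they used 20 to make lunch "
    "and bought 6 more, how many apples do they have?"
)
reasoning, answer = zero_shot_cot(question, fake_generate)
print(reasoning)
print(answer)
```

To try this with a real model, pass any function that maps a prompt string to a completion string in place of `fake_generate`.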
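One way to organize these scenarios is to generate every prompt variant from the same question. The sketch below is a guess at the structure, not the code from the repo: the variant names, the `build_prompt` helper, and the placement of the LTSBS and sabotage suffixes at the start of the answer are all assumptions.

```python
LTSBS = "Let's think step-by-step."
SABOTAGE = "The answer is:"


def build_prompt(question, exemplars=(), suffix=""):
    """Format optional (question, answer) exemplars, then the open question."""
    blocks = [f"Q: {q}\nA: {a}" for q, a in exemplars]
    # rstrip() removes the trailing space when no suffix is given.
    blocks.append(f"Q: {question}\nA: {suffix}".rstrip())
    return "\n\n".join(blocks)


def prompt_variants(question, plain_exemplars, cot_exemplars):
    """Return a dict mapping a scenario name to its prompt string."""
    return {
        "zero-shot": build_prompt(question),
        "zero-shot + sabotage": build_prompt(question, suffix=SABOTAGE),
        "zero-shot + LTSBS": build_prompt(question, suffix=LTSBS),
        "few-shot": build_prompt(question, plain_exemplars),
        "few-shot + LTSBS": build_prompt(question, plain_exemplars, LTSBS),
        "CoT": build_prompt(question, cot_exemplars),
        "CoT + LTSBS": build_prompt(question, cot_exemplars, LTSBS),
    }


# Illustrative exemplars only; the repo's own examples may differ.
plain = [("What is 32 + 42 - 35?", "The answer is 39.")]
cot = [("What is 32 + 42 - 35?", "32 + 42 = 74. 74 - 35 = 39. The answer is 39.")]

variants = prompt_variants("What is 23 - 20 + 6?", plain, cot)
for name, prompt in variants.items():
    print(f"--- {name} ---\n{prompt}\n")
```

Keeping every variant behind one function makes it easy to run the same question set through each scenario and compare solve rates.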
You can find the practice problems used for evaluation, the prompt example variations (few-shot, chain-of-thought, etc.), and code for investigating the different prompt formulas in the GitHub repo.
With 7 questions averaged across the 3 Gemma 7B checkpoints, the highest average solve rate was about 81.0%, using CoT plus LTSBS. CoT alone was the second most successful prompt strategy, with an average solve rate of about 76.2%. Setting aside the sabotaged prompts, unmodified few-shot prompting yielded a 48% solve rate, worse than unmodified zero-shot prompting at 71%.
If you want to try it yourself, you’ll only need a couple of dependencies in a Python 3.8 environment:
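With only 7 questions and 3 checkpoints, each prompt strategy is scored on 21 attempts, so a single answer moves the solve rate by almost 5 percentage points. The short sketch below works through that arithmetic. The counts it prints are inferred from the reported percentages, assuming every attempt is weighted equally; they are not figures taken from the repo.

```python
QUESTIONS = 7
CHECKPOINTS = 3
TRIALS = QUESTIONS * CHECKPOINTS  # 21 attempts per prompt strategy


def solve_rate_percent(num_correct, trials=TRIALS):
    """Solve rate as a percentage of all attempts."""
    return 100.0 * num_correct / trials


def nearest_count(reported_percent, trials=TRIALS):
    """Whole number of correct answers closest to a reported percentage."""
    return round(reported_percent / 100.0 * trials)


# Percentages as reported in the text.
reported = {
    "CoT + LTSBS": 81.0,
    "CoT": 76.2,
    "zero-shot": 71.0,
    "few-shot": 48.0,
}

print(f"One answer is worth {solve_rate_percent(1):.2f} percentage points")
for name, percent in reported.items():
    count = nearest_count(percent)  # inferred, not taken from the repo
    print(f"{name}: about {count}/{TRIALS} = {solve_rate_percent(count):.1f}%")
```

Seen this way, the gap between the top two strategies amounts to roughly one question out of 21, so the ranking should be read as suggestive rather than conclusive.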
virtualenv hf --python=python3.8
source hf/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install transformers accelerate
# to convert slow tokenizers to fast ones
pip install sentencepiece
git clone https://github.com/riveSunder/chain_of_thought.git
cd chain_of_thought
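With the environment set up, a single generation call through the transformers library might look like the sketch below. It is not the repo's code. The checkpoint ID `google/gemma-7b-it` is an assumption, since the checkpoints used in the experiment are not named here, and `generate` and `last_number` are hypothetical helper names. The heavy imports sit inside `generate` so that the answer-parsing helper can be tried without a GPU or a model download.

```python
import re
from typing import Optional

# Assumption: an instruction-tuned Gemma 7B checkpoint, used only as an example.
MODEL_ID = "google/gemma-7b-it"


def generate(prompt: str, model_id: str = MODEL_ID, max_new_tokens: int = 256) -> str:
    """Greedy-decode a completion for `prompt` and return only the new text."""
    # Imported lazily so the rest of this file works without torch installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # relies on the accelerate package
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False
    )
    # Drop the prompt tokens so only the completion is decoded.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def last_number(text: str) -> Optional[str]:
    """Return the last number that appears in the text, or None if there is none."""
    matches = re.findall(r"-?\d+(?:\.\d+)?", text)
    return matches[-1] if matches else None


# Parsing a worked answer like the one quoted earlier in the article.
sample = "So, in total, they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39."
print(last_number(sample))
```

Calling `generate("Q: ...\nA: Let's think step-by-step.")` downloads the weights on first use. Gemma checkpoints on the Hugging Face Hub typically require accepting Google's license and logging in first, and taking the last number in the output is only a rough heuristic for grading.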
Conclusions and Future Outlook
Chain-of-thought and similar prompting techniques have been rapidly adopted over the last few years, and model families like Google’s Gemini and the related but more open Gemma models owe a significant portion of their capabilities to chain-of-thought prompting styles.
Why these techniques work has been less clear, and recent work by Feng et al. (2023) and by Merrill and Sabharwal (2024) has attempted to fill in the gaps. Feng and colleagues used circuit complexity theory to argue that, for some problems, transformers are intrinsically incapable of producing a direct, immediate answer, at least unless the models grow along with the size of the problem.
Current thinking is that encouraging models to explicitly go through each step increases the computation they apply to a given problem, with the generated text acting as a recurrent hidden state or memory. Techniques like CoT thereby allow transformers to overcome some limitations, especially in their intrinsic ability to simulate computational models or execute multistep algorithms.
Another way of editing the prompt to yield better results is to add context via retrieval-augmented generation (RAG). Combining RAG with CoT could produce models that are both context-aware and better at multistep problem solving.