Coconut is the official PyTorch implementation of the paper “Training Large Language Models to Reason in a Continuous Latent Space.” The method trains large language models (LLMs) to reason with continuous latent steps: instead of decoding every intermediate reasoning step into discrete tokens, the model feeds its last hidden state back as the next input embedding, so reasoning chains are generated and refined directly in a learned latent space rather than through discrete symbolic reasoning alone.

The framework supports multiple reasoning paradigms, including standard Chain-of-Thought (CoT), no-thought, and hybrid configurations, through configurable training stages and latent representations. It is built on Hugging Face Transformers, PyTorch Distributed, and Weights & Biases (wandb) for logging, and supports large-scale experiments on mathematical and logical reasoning datasets such as GSM8K, ProntoQA, and ProsQA.
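At the heart of the method is a simple loop: rather than sampling a token at each reasoning step, the model's last hidden state is appended to the input embeddings as a "continuous thought." The sketch below illustrates that loop under stated assumptions (an off-the-shelf GPT-2 stand-in and a fixed number of latent steps); it is a minimal illustration of the mechanism, not the repository's actual implementation.

```python
# Minimal sketch of continuous latent reasoning (illustrative only; the
# repository's real Coconut implementation differs in detail).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumption: any small causal LM works for the sketch
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompt = "Question: 2 + 3 * 4 = ?"
inputs = tokenizer(prompt, return_tensors="pt")
# Start from the prompt's token embeddings.
embeds = model.get_input_embeddings()(inputs["input_ids"])

num_latent_steps = 4  # assumption: a fixed number of continuous thoughts
with torch.no_grad():
    for _ in range(num_latent_steps):
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        # Take the final layer's hidden state at the last position and
        # append it as the next "input embedding" -- a continuous thought
        # that is never decoded into a discrete token.
        last_hidden = out.hidden_states[-1][:, -1:, :]
        embeds = torch.cat([embeds, last_hidden], dim=1)
```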
## Features
- Reproducible experiment scripts matching the paper’s benchmark protocols
- Distributed multi-GPU training with torchrun and bf16 mixed precision (see the training sketch after this list)
- Dataset preprocessing tools for GSM8K, ProntoQA, and ProsQA
- Integrated wandb logging and checkpoint management across training stages
- Modular YAML-based configuration for multi-stage training and evaluation (see the config sketch after this list)
- Continuous latent reasoning for LLMs beyond discrete CoT prompting
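The distributed-training bullet above points here: a minimal sketch of how a torchrun-launched, bf16, DDP training step with wandb logging and checkpointing could be wired up. The entry-point name, the stand-in model, the loop body, and the checkpoint filename are assumptions for illustration; the repository's own scripts define the real protocol. Requires a CUDA machine with wandb installed.

```python
# Hypothetical launch (the entry-point name is an assumption):
#   torchrun --nnodes 1 --nproc_per_node 8 train_sketch.py
import os

import torch
import torch.distributed as dist
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group(backend="nccl")  # torchrun sets the rank env vars
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(768, 768).cuda(local_rank)  # stand-in for the real LM
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    if dist.get_rank() == 0:
        wandb.init(project="coconut-sketch", mode="offline")  # offline: no login needed

    for step in range(10):  # stand-in training loop
        x = torch.randn(8, 768, device=local_rank)
        # bf16 mixed precision via autocast, matching the feature list.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model(x).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if dist.get_rank() == 0:
            wandb.log({"loss": loss.item(), "step": step})

    # Rank 0 saves a per-stage checkpoint (filename is illustrative).
    if dist.get_rank() == 0:
        torch.save(model.module.state_dict(), "checkpoint_stage0.pt")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```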
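Likewise, the YAML-configuration bullet points here: a sketch of how a multi-stage config might be organized and consumed. The keys, stage names, and values below are hypothetical, not the repository's actual schema.

```python
# Illustrative only: the stage structure and keys are assumptions,
# not the repository's config schema. Requires PyYAML.
import yaml

config_text = """
name: gsm8k-coconut-sketch
model: openai-community/gpt2
bf16: true
stages:
  - {name: cot_warmup, epochs: 3, latent_steps: 0}
  - {name: latent_stage_1, epochs: 3, latent_steps: 2}
  - {name: latent_stage_2, epochs: 3, latent_steps: 4}
"""

config = yaml.safe_load(config_text)
# Each stage trades more CoT tokens for latent steps as training progresses.
for stage in config["stages"]:
    print(f"stage={stage['name']} epochs={stage['epochs']} "
          f"latent_steps={stage['latent_steps']}")
```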