
only_train_once
OTOv1-v3, NeurIPS, ICLR, TMLR, DNN Training, Compression, Structured Pruning, Erasing Operators, CNN, Diffusion, LLM
Stars: 261

Only Train Once (OTO) is an automatic, architecture-agnostic DNN training and compression framework that allows users to train a general DNN from scratch or a pretrained checkpoint to achieve high performance and slimmer architecture simultaneously in a one-shot manner without fine-tuning. The framework includes features for automatic structured pruning and erasing operators, as well as hybrid structured sparse optimizers for efficient model compression. OTO provides tools for pruning zero-invariant group partitioning, constructing pruned models, and visualizing pruning and erasing dependency graphs. It supports the HESSO optimizer and offers a sanity check for compliance testing on various DNNs. The repository also includes publications, installation instructions, quick start guides, and a roadmap for future enhancements and collaborations.
README:
This repository is the (deprecated) PyTorch implementation of Only-Train-Once (OTO). OTO is an automatic, architecture-agnostic DNN training and compression framework (via structured pruning and erasing operators). With OTO, users can train a general DNN, either from scratch or from a pretrained checkpoint, to achieve both high performance and a slimmer architecture simultaneously in a one-shot manner (without fine-tuning).
Please find our series of works below; BibTeX entries for citation are provided at the end.
- OTOv3: Automatic Architecture-Agnostic Neural Network Training and Compression from Structured Pruning to Erasing Operators preprint.
- LoRAShear: Efficient Large Language Model Structured Pruning and Knowledge Recovery preprint.
- An Adaptive Half-Space Projection Method for Stochastic Optimization Problems with Group Sparse Regularization in TMLR 2023.
- OTOv2: Automatic, Generic, User-Friendly in ICLR 2023.
- Only Train Once (OTO): A One-Shot Neural Network Training And Pruning Framework in NeurIPS 2021.
In addition, we recommend the following efficient ML works of ours.
- DREAM: Diffusion Rectification and Estimation-Adaptive Models, efficient diffusion training, in CVPR 2024.
- DISTILLM: Towards Streamlined Distillation for Large Language Models, LLM distillation, in ICML 2024.
We recommend running the framework with pytorch>=2.0. Use pip or git clone to install:
```bash
pip install only_train_once
```
or
```bash
git clone https://github.com/tianyic/only_train_once.git
```
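Because the framework is recommended for pytorch>=2.0, a quick way to confirm an environment meets that requirement is a small version check. This is a hypothetical helper (`meets_min_version` is not part of OTO) and, as a simplification, it compares only the leading numeric components of the version string.

```python
import re
from typing import Tuple


def meets_min_version(version: str, minimum: Tuple[int, ...] = (2, 0)) -> bool:
    """Return True if a PyTorch-style version string is at least `minimum`.

    Only the leading numeric components are compared: local build suffixes
    such as "+cu118" and pre-release tags such as "rc1" are ignored.
    """
    core = version.split("+", 1)[0]  # drop a local suffix, e.g. "+cu118"
    parts = []
    for piece in core.split(".")[: len(minimum)]:
        match = re.match(r"\d+", piece)  # leading digits of "0rc1" -> "0"
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (len(minimum) - len(parts))  # pad short versions like "2"
    return tuple(parts) >= tuple(minimum)
```

In practice you would call it as `meets_min_version(torch.__version__)` after `import torch`; alternatively, `pip install "torch>=2.0"` enforces the same requirement at install time.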
We provide an example of OTO framework usage below. More detailed explanations can be found in the tutorials.
```python
import torch
from sanity_check.backends import densenet121
from only_train_once import OTO

# Create the OTO instance; dummy_input is a sample input used to trace the model's graph
model = densenet121().cuda()
dummy_input = torch.zeros(1, 3, 32, 32).cuda()
oto = OTO(model=model, dummy_input=dummy_input)

# Create the HESSO optimizer
optimizer = oto.hesso(variant='sgd', lr=0.1, target_group_sparsity=0.7)

# Train the DNN as normal via HESSO.
# `max_epoch` and `trainloader` (a DataLoader yielding (X, y) batches) must be defined by the user.
model.train()
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(max_epoch):
    for X, y in trainloader:
        X, y = X.cuda(), y.cuda()
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        optimizer.step()

# A compressed DenseNet will be generated in out_dir.
oto.construct_subnet(out_dir='./')
```
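In the example above, `target_group_sparsity=0.7` tells HESSO how many groups to drive to zero; as we read it, roughly 70% of the prunable PZIGs. The framework-free sketch below shows how such a ratio can be measured. It is an illustration only: each group is modeled as a flat list of floats, and `group_norm` and `group_sparsity` are hypothetical helpers, not OTO functions.

```python
import math
from typing import List, Sequence


def group_norm(group: Sequence[float]) -> float:
    """Euclidean (L2) norm of one group of variables."""
    return math.sqrt(sum(v * v for v in group))


def group_sparsity(groups: List[Sequence[float]], tol: float = 0.0) -> float:
    """Fraction of groups whose L2 norm is <= tol (i.e. treated as zero)."""
    if not groups:
        return 0.0
    zero = sum(1 for g in groups if group_norm(g) <= tol)
    return zero / len(groups)


# Ten toy groups; seven of them have been driven exactly to zero.
toy_groups = [[0.0, 0.0]] * 7 + [[0.5, -1.0], [2.0, 0.0], [0.1, 0.1]]
print(group_sparsity(toy_groups))  # 0.7
```

A nonzero `tol` is useful in floating point, where a group that is "practically zero" may hold tiny residual values.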
- Pruning Zero-Invariant Group Partition. OTO first automatically analyzes the dependencies inside the target DNN to build a pruning dependency graph. It then partitions the DNN's trainable variables into so-called Pruning Zero-Invariant Groups (PZIGs). A PZIG is a minimal removal structure for pruning, which can roughly be interpreted as the smallest group of variables that must be pruned together.
- Hybrid Structured Sparse Optimizer. A structured sparsity optimization problem is formulated, and a hybrid structured sparse optimizer, such as HESSO, DHSPG, or LHSPG, is employed to identify which PZIGs are redundant and which are important for the model's prediction. The selected hybrid optimizer explores group sparsity more reliably and typically achieves higher generalization performance than other sparse optimizers.
- Construct pruned model. The structures corresponding to redundant PZIGs (those driven to zero) are removed to form the pruned model. Thanks to the zero-invariant property of PZIGs, the pruned model returns exactly the same output as the full model, so no further fine-tuning is required.
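The three steps can be illustrated on a toy two-layer MLP, y = W2 · relu(W1 · x + b1) + b2. In this sketch, which is our own simplification and not OTO's implementation, each hidden neuron forms one zero-invariant group: its row of W1, its entry of b1, and its column of W2. As a hypothetical stand-in for HESSO/DHSPG/LHSPG, `zero_small_groups` simply zeroes all but the `keep` largest-norm groups; removing the zeroed neurons then leaves the output exactly unchanged, which is why the construction step itself needs no fine-tuning.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def mlp_forward(w1: Matrix, b1: Vector, w2: Matrix, b2: Vector, x: Vector) -> Vector:
    """y = W2 @ relu(W1 @ x + b1) + b2"""
    h = relu([hi + bi for hi, bi in zip(matvec(w1, x), b1)])
    return [yi + bi for yi, bi in zip(matvec(w2, h), b2)]


def hidden_groups(w1: Matrix, b1: Vector, w2: Matrix) -> List[Vector]:
    """Step 1: one group per hidden neuron j (row j of W1, b1[j], column j of W2)."""
    return [w1[j] + [b1[j]] + [row[j] for row in w2] for j in range(len(b1))]


def zero_small_groups(w1: Matrix, b1: Vector, w2: Matrix, keep: int) -> List[int]:
    """Step 2 (stand-in optimizer): keep the `keep` largest-norm groups, zero the rest in place."""
    norms = [math.sqrt(sum(v * v for v in g)) for g in hidden_groups(w1, b1, w2)]
    ranked = sorted(range(len(norms)), key=lambda j: norms[j], reverse=True)
    for j in ranked[keep:]:
        w1[j] = [0.0] * len(w1[j])
        b1[j] = 0.0
        for row in w2:
            row[j] = 0.0
    return sorted(ranked[:keep])


def construct_subnet(w1: Matrix, b1: Vector, w2: Matrix, kept: List[int]):
    """Step 3: physically remove the zeroed hidden neurons."""
    return ([w1[j] for j in kept], [b1[j] for j in kept],
            [[row[j] for j in kept] for row in w2])


w1 = [[1.0, -2.0], [0.1, 0.05], [3.0, 1.0], [-0.2, 0.1]]
b1 = [0.5, 0.01, -1.0, 0.0]
w2 = [[1.0, 0.2, -1.0, 0.3], [0.5, -0.1, 2.0, 0.1]]
b2 = [0.1, -0.2]

kept = zero_small_groups(w1, b1, w2, keep=2)
sw1, sb1, sw2 = construct_subnet(w1, b1, w2, kept)
x = [0.7, -1.3]
full = mlp_forward(w1, b1, w2, b2, x)    # sparse model, full width
slim = mlp_forward(sw1, sb1, sw2, b2, x)  # constructed subnet
print(kept, full == slim)  # [0, 2] True
```

Note that the equality holds between the sparse model and its constructed subnet; whether accuracy is preserved relative to the original dense model depends on the optimizer training the remaining groups, which this toy does not attempt.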
The sanity check provides tests of OTO's pruning mode on various DNNs, from CNNs to LLMs. Passing the sanity check indicates that OTO is compatible with the target DNN.
```bash
python sanity_check/sanity_check.py
```
Note that some tests require additional dependencies; comment out the tests you do not need. We highly recommend running the sanity check on any new customized DNN to test its compliance.
The visual_examples folder provides visualizations of pruning dependency graphs and erasing dependency graphs. Visualization is a frequently used tool when applying OTO to new, unseen DNNs, especially for diagnosing errors.
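Dependency graphs matter because some operators tie layers together: for example, a residual addition `a + b` forces the output channels of the two layers producing `a` and `b` to be pruned jointly. The sketch below shows that idea with a union-find over a hypothetical list of tied layer pairs. The layer names and the edge-list format are illustrative assumptions, not OTO's internal graph representation.

```python
from typing import Dict, List, Tuple


class UnionFind:
    def __init__(self) -> None:
        self.parent: Dict[str, str] = {}

    def find(self, x: str) -> str:
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a: str, b: str) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra


def group_layers(layers: List[str], ties: List[Tuple[str, str]]) -> List[List[str]]:
    """Merge layers whose output channels must be pruned together."""
    uf = UnionFind()
    for layer in layers:
        uf.find(layer)  # register every layer, even untied ones
    for a, b in ties:
        uf.union(a, b)
    groups: Dict[str, List[str]] = {}
    for layer in layers:
        groups.setdefault(uf.find(layer), []).append(layer)
    return sorted(sorted(g) for g in groups.values())


# conv1 + conv3 feed one residual add, and that sum + conv5 feed another.
layers = ["conv1", "conv2", "conv3", "conv4", "conv5"]
ties = [("conv1", "conv3"), ("conv3", "conv5")]
print(group_layers(layers, ties))
# [['conv1', 'conv3', 'conv5'], ['conv2'], ['conv4']]
```

Union-find makes the tie transitive, which mirrors why a chain of residual blocks can end up sharing a single channel group across many layers.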
Our roadmap for future work includes:
- Add more explanations to the current repository.
- Release a technical report on the HESSO optimizer, which has not yet been discussed in our papers.
- Release refactored DHSPG and LHSPG.
- Release the full pipeline of LoRAShear (subject to business approval).
- Provide more tutorials covering the experiments in the pruning mode. The main OTOv2 experiments can be found in otov2_branch.
- Release the official erasing mode after the review process of OTOv3.
- Provide documentation for the OTO API.
We would greatly appreciate contributions in any form, such as bug fixes, new features, and new tutorials, from the open-source community.
We are humbled to contribute to the AI community and look forward to working with the community to make DNN training and compression more automatic and convenient.
We are open to collaborations. Feel free to reach out at [email protected] if you have any interesting ideas.
The previous OTOv2 repo has been moved into legacy_branch for academic replication.
If you find the repo useful, please consider starring this repository and citing our papers:
For OTOv3 preprint
@article{chen2023otov3,
title={OTOv3: Automatic Architecture-Agnostic Neural Network Training and Compression from Structured Pruning to Erasing Operators},
author={Chen, Tianyi and Ding, Tianyu and Zhu, Zhihui and Chen, Zeyu and Wu, HsiangTao and Zharkov, Ilya and Liang, Luming},
journal={arXiv preprint arXiv:2312.09411},
year={2023}
}
For LoRAShear preprint
@article{chen2023lorashear,
title={LoRAShear: Efficient Large Language Model Structured Pruning and Knowledge Recovery},
author={Chen, Tianyi and Ding, Tianyu and Yadav, Badal and Zharkov, Ilya and Liang, Luming},
journal={arXiv preprint arXiv:2310.18356},
year={2023}
}
For AdaHSPG+ publication in TMLR (theoretical optimization paper)
@article{dai2023adahspg,
title={An adaptive half-space projection method for stochastic optimization problems with group sparse regularization},
author={Dai, Yutong and Chen, Tianyi and Wang, Guanyi and Robinson, Daniel P},
journal={Transactions on machine learning research},
year={2023}
}
For OTOv2 publication in ICLR 2023
@inproceedings{chen2023otov2,
title={OTOv2: Automatic, Generic, User-Friendly},
author={Chen, Tianyi and Liang, Luming and Ding, Tianyu and Zhu, Zhihui and Zharkov, Ilya},
booktitle={International Conference on Learning Representations},
year={2023}
}
For OTOv1 publication in NeurIPS 2021
@inproceedings{chen2021otov1,
title={Only Train Once: A One-Shot Neural Network Training And Pruning Framework},
author={Chen, Tianyi and Ji, Bo and Ding, Tianyu and Fang, Biyi and Wang, Guanyi and Zhu, Zhihui and Liang, Luming and Shi, Yixin and Yi, Sheng and Tu, Xiao},
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021}
}
Alternative AI tools for only_train_once
Similar Open Source Tools

pytorch-forecasting
PyTorch Forecasting is a PyTorch-based package for time series forecasting with state-of-the-art network architectures. It offers a high-level API for training networks on pandas data frames and utilizes PyTorch Lightning for scalable training on GPUs and CPUs. The package aims to simplify time series forecasting with neural networks by providing a flexible API for professionals and default settings for beginners. It includes a timeseries dataset class, base model class, multiple neural network architectures, multi-horizon timeseries metrics, and hyperparameter tuning with optuna. PyTorch Forecasting is built on pytorch-lightning for easy training on various hardware configurations.

llumnix
Llumnix is a cross-instance request scheduling layer built on top of LLM inference engines such as vLLM, providing optimized multi-instance serving performance with low latency, reduced time-to-first-token (TTFT) and queuing delays, reduced time-between-tokens (TBT) and preemption stalls, and high throughput. It achieves this through dynamic, fine-grained, KV-cache-aware scheduling, continuous rescheduling across instances, KV cache migration mechanism, and seamless integration with existing multi-instance deployment platforms. Llumnix is easy to use, fault-tolerant, elastic, and extensible to more inference engines and scheduling policies.

MInference
MInference is a tool designed to accelerate pre-filling for long-context Language Models (LLMs) by leveraging dynamic sparse attention. It achieves up to a 10x speedup for pre-filling on an A100 while maintaining accuracy. The tool supports various decoding LLMs, including LLaMA-style models and Phi models, and provides custom kernels for attention computation. MInference is useful for researchers and developers working with large-scale language models who aim to improve efficiency without compromising accuracy.

Genesis
Genesis is a physics platform designed for general purpose Robotics/Embodied AI/Physical AI applications. It includes a universal physics engine, a lightweight, ultra-fast, pythonic, and user-friendly robotics simulation platform, a powerful and fast photo-realistic rendering system, and a generative data engine that transforms user-prompted natural language description into various modalities of data. It aims to lower the barrier to using physics simulations, unify state-of-the-art physics solvers, and minimize human effort in collecting and generating data for robotics and other domains.

dash-infer
DashInfer is a C++ runtime tool designed to deliver production-level implementations highly optimized for various hardware architectures, including x86 and ARMv9. It supports Continuous Batching and NUMA-Aware capabilities for CPU, and can fully utilize modern server-grade CPUs to host large language models (LLMs) up to 14B in size. With lightweight architecture, high precision, support for mainstream open-source LLMs, post-training quantization, optimized computation kernels, NUMA-aware design, and multi-language API interfaces, DashInfer provides a versatile solution for efficient inference tasks. It supports x86 CPUs with AVX2 instruction set and ARMv9 CPUs with SVE instruction set, along with various data types like FP32, BF16, and InstantQuant. DashInfer also offers single-NUMA and multi-NUMA architectures for model inference, with detailed performance tests and inference accuracy evaluations available. The tool is supported on mainstream Linux server operating systems and provides documentation and examples for easy integration and usage.

oat
Oat is a simple and efficient framework for running online LLM alignment algorithms. It implements a distributed Actor-Learner-Oracle architecture, with components optimized using state-of-the-art tools. Oat simplifies the experimental pipeline of LLM alignment by serving an Oracle online for preference data labeling and model evaluation. It provides a variety of oracles for simulating feedback and supports verifiable rewards. Oat's modular structure allows for easy inheritance and modification of classes, enabling rapid prototyping and experimentation with new algorithms. The framework implements cutting-edge online algorithms like PPO for math reasoning and various online exploration algorithms.

pytorch-forecasting
PyTorch Forecasting is a PyTorch-based package designed for state-of-the-art timeseries forecasting using deep learning architectures. It offers a high-level API and leverages PyTorch Lightning for efficient training on GPU or CPU with automatic logging. The package aims to simplify timeseries forecasting tasks by providing a flexible API for professionals and user-friendly defaults for beginners. It includes features such as a timeseries dataset class for handling data transformations, missing values, and subsampling, various neural network architectures optimized for real-world deployment, multi-horizon timeseries metrics, and hyperparameter tuning with optuna. Built on pytorch-lightning, it supports training on CPUs, single GPUs, and multiple GPUs out-of-the-box.

Macaw-LLM
Macaw-LLM is a pioneering multi-modal language modeling tool that seamlessly integrates image, audio, video, and text data. It builds upon CLIP, Whisper, and LLaMA models to process and analyze multi-modal information effectively. The tool boasts features like simple and fast alignment, one-stage instruction fine-tuning, and a new multi-modal instruction dataset. It enables users to align multi-modal features efficiently, encode instructions, and generate responses across different data types.

KAG
KAG is a logical reasoning and Q&A framework based on the OpenSPG engine and large language models. It is used to build logical reasoning and Q&A solutions for vertical domain knowledge bases. KAG supports logical reasoning, multi-hop fact Q&A, and integrates knowledge and chunk mutual indexing structure, conceptual semantic reasoning, schema-constrained knowledge construction, and logical form-guided hybrid reasoning and retrieval. The framework includes kg-builder for knowledge representation and kg-solver for logical symbol-guided hybrid solving and reasoning engine. KAG aims to enhance LLM service framework in professional domains by integrating logical and factual characteristics of KGs.

aimo-progress-prize
This repository contains the training and inference code needed to replicate the winning solution to the AI Mathematical Olympiad - Progress Prize 1. It consists of fine-tuning DeepSeekMath-Base 7B, high-quality training datasets, a self-consistency decoding algorithm, and carefully chosen validation sets. The training methodology involves Chain of Thought (CoT) and Tool Integrated Reasoning (TIR) training stages. Two datasets, NuminaMath-CoT and NuminaMath-TIR, were used to fine-tune the models. The models were trained using open-source libraries like TRL, PyTorch, vLLM, and DeepSpeed. Post-training quantization to 8-bit precision was done to improve performance on Kaggle's T4 GPUs. The project structure includes scripts for training, quantization, and inference, along with necessary installation instructions and hardware/software specifications.

langgraphjs
LangGraph.js is a library for building stateful, multi-actor applications with LLMs, offering benefits such as cycles, controllability, and persistence. It allows defining flows involving cycles, providing fine-grained control over application flow and state. Inspired by Pregel and Apache Beam, it includes features like loops, persistence, human-in-the-loop workflows, and streaming support. LangGraph integrates seamlessly with LangChain.js and LangSmith but can be used independently.

InsPLAD
InsPLAD is a dataset and benchmark for power line asset inspection in UAV images. It contains 10,607 high-resolution UAV color images of seventeen unique power line assets with six defects. The dataset is used for object detection, defect classification, and anomaly detection tasks in computer vision. InsPLAD offers challenges like multi-scale objects, intra-class variation, cluttered background, and varied lighting conditions, aiming to improve state-of-the-art methods in the field.

chembench
ChemBench is a project aimed at expanding chemistry benchmark tasks in a BIG-bench compatible way, providing a pipeline to benchmark frontier and open models. It enables benchmarking across a wide range of API-based models and employs an LLM-based extractor as a fallback mechanism. Users can evaluate models on specific chemistry topics and run comprehensive evaluations across all topics in the benchmark suite. The tool facilitates seamless benchmarking for any model supported by LiteLLM and allows running non-API hosted models.

k2
K2 (GeoLLaMA) is a large language model for geoscience, trained on geoscience literature and fine-tuned with knowledge-intensive instruction data. It outperforms baseline models on objective and subjective tasks. The repository provides K2 weights, core data of GeoSignal, GeoBench benchmark, and code for further pretraining and instruction tuning. The model is available on Hugging Face for use. The project aims to create larger and more powerful geoscience language models in the future.

FuseAI
FuseAI is a repository that focuses on knowledge fusion of large language models. It includes FuseChat, a state-of-the-art 7B LLM on MT-Bench, and FuseLLM, which surpasses Llama-2-7B by fusing three open-source foundation LLMs. The repository provides tech reports, releases, and datasets for FuseChat and FuseLLM, showcasing their performance and advancements in the field of chat models and large language models.
For similar tasks

ChaKt-KMP
ChaKt is a multiplatform app built using Kotlin and Compose Multiplatform to demonstrate the use of Generative AI SDK for Kotlin Multiplatform to generate content using Google's Generative AI models. It features a simple chat based user interface and experience to interact with AI. The app supports mobile, desktop, and web platforms, and is built with Kotlin Multiplatform, Kotlin Coroutines, Compose Multiplatform, Generative AI SDK, Calf - File picker, and BuildKonfig. Users can contribute to the project by following the guidelines in CONTRIBUTING.md. The app is licensed under the MIT License.

crawl4ai
Crawl4AI is a powerful and free web crawling service that extracts valuable data from websites and provides LLM-friendly output formats. It supports crawling multiple URLs simultaneously, replaces media tags with ALT, and is completely free to use and open-source. Users can integrate Crawl4AI into Python projects as a library or run it as a standalone local server. The tool allows users to crawl and extract data from specified URLs using different providers and models, with options to include raw HTML content, force fresh crawls, and extract meaningful text blocks. Configuration settings can be adjusted in the `crawler/config.py` file to customize providers, API keys, chunk processing, and word thresholds. Contributions to Crawl4AI are welcome from the open-source community to enhance its value for AI enthusiasts and developers.

sandbox
Sandbox is an open-source cloud-based code editing environment with custom AI code autocompletion and real-time collaboration. It consists of a frontend built with Next.js, TailwindCSS, Shadcn UI, Clerk, Monaco, and Liveblocks, and a backend with Express, Socket.io, Cloudflare Workers, D1 database, R2 storage, Workers AI, and Drizzle ORM. The backend includes microservices for database, storage, and AI functionalities. Users can run the project locally by setting up environment variables and deploying the containers. Contributions are welcome following the commit convention and structure provided in the repository.

void
Void is an open-source Cursor alternative, providing the full source code for users to build and develop. It is a fork of the vscode repository, offering a waitlist for the official release. Users can contribute by checking the Project board and following the guidelines in CONTRIBUTING.md. Support is available through Discord or email.

aphrodite-engine
Aphrodite is an inference engine optimized for serving HuggingFace-compatible models at scale. It leverages vLLM's Paged Attention technology to deliver high-performance model inference for multiple concurrent users. The engine supports continuous batching, efficient key/value management, optimized CUDA kernels, quantization support, distributed inference, and modern samplers. It can be easily installed and launched, with Docker support for deployment. Aphrodite requires Linux or Windows OS, Python 3.8 to 3.12, and CUDA >= 11. It is designed to utilize 90% of GPU VRAM but offers options to limit memory usage. Contributors are welcome to enhance the engine.

cua
Cua is a tool for creating and running high-performance macOS and Linux virtual machines on Apple Silicon, with built-in support for AI agents. It provides libraries like Lume for running VMs with near-native performance, Computer for interacting with sandboxes, and Agent for running agentic workflows. Users can refer to the documentation for onboarding, explore demos showcasing AI-Gradio and GitHub issue fixing, and utilize accessory libraries like Core, PyLume, Computer Server, and SOM. Contributions are welcome, and the tool is open-sourced under the MIT License.

momentum-core
Momentum is an open-source behavioral auditor for backend code that helps developers generate powerful insights into their codebase. It analyzes code behavior, tests it at every git push, and ensures readiness for production. Momentum understands backend code, visualizes dependencies, identifies behaviors, generates test code, runs code in the local environment, and provides debugging solutions. It aims to improve code quality, streamline testing processes, and enhance developer productivity.
For similar jobs

sweep
Sweep is an AI junior developer that turns bugs and feature requests into code changes. It automatically handles developer experience improvements like adding type hints and improving test coverage.

teams-ai
The Teams AI Library is a software development kit (SDK) that helps developers create bots that can interact with Teams and Microsoft 365 applications. It is built on top of the Bot Framework SDK and simplifies the process of developing bots that interact with Teams' artificial intelligence capabilities. The SDK is available for JavaScript/TypeScript, .NET, and Python.

ai-guide
This guide is dedicated to Large Language Models (LLMs) that you can run on your home computer. It assumes your PC is a lower-end, non-gaming setup.

classifai
Supercharge WordPress Content Workflows and Engagement with Artificial Intelligence. Tap into leading cloud-based services like OpenAI, Microsoft Azure AI, Google Gemini and IBM Watson to augment your WordPress-powered websites. Publish content faster while improving SEO performance and increasing audience engagement. ClassifAI integrates Artificial Intelligence and Machine Learning technologies to lighten your workload and eliminate tedious tasks, giving you more time to create original content that matters.

chatbot-ui
Chatbot UI is an open-source AI chat app that allows users to create and deploy their own AI chatbots. It is easy to use and can be customized to fit any need. Chatbot UI is perfect for businesses, developers, and anyone who wants to create a chatbot.

BricksLLM
BricksLLM is a cloud native AI gateway written in Go. Currently, it provides native support for OpenAI, Anthropic, Azure OpenAI and vLLM. BricksLLM aims to provide enterprise level infrastructure that can power any LLM production use case. Here are some use cases for BricksLLM:
- Set LLM usage limits for users on different pricing tiers
- Track LLM usage on a per user and per organization basis
- Block or redact requests containing PII
- Improve LLM reliability with failovers, retries and caching
- Distribute API keys with rate limits and cost limits for internal development/production use cases
- Distribute API keys with rate limits and cost limits for students

uAgents
uAgents is a Python library developed by Fetch.ai that allows for the creation of autonomous AI agents. These agents can perform various tasks on a schedule or take action on various events. uAgents are easy to create and manage, and they are connected to a fast-growing network of other uAgents. They are also secure, with cryptographically secured messages and wallets.

griptape
Griptape is a modular Python framework for building AI-powered applications that securely connect to your enterprise data and APIs. It offers developers the ability to maintain control and flexibility at every step. Griptape's core components include Structures (Agents, Pipelines, and Workflows), Tasks, Tools, Memory (Conversation Memory, Task Memory, and Meta Memory), Drivers (Prompt and Embedding Drivers, Vector Store Drivers, Image Generation Drivers, Image Query Drivers, SQL Drivers, Web Scraper Drivers, and Conversation Memory Drivers), Engines (Query Engines, Extraction Engines, Summary Engines, Image Generation Engines, and Image Query Engines), and additional components (Rulesets, Loaders, Artifacts, Chunkers, and Tokenizers). Griptape enables developers to create AI-powered applications with ease and efficiency.