Influence Functions with (Eigenvalue-corrected) Kronecker-Factored Approximate Curvature

pomonam/kronfluence

Kronfluence


Kronfluence is a PyTorch package for computing influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.
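
To see why a Kronecker-factored curvature approximation keeps inverse-curvature products affordable, the sketch below numerically checks the standard identity (S ⊗ A)^{-1} vec(G) = vec(S^{-1} G A^{-1}) for a single linear layer. It illustrates the general idea only and is not Kronfluence's implementation. The names `A`, `S`, `grad`, and `random_spd` are placeholders chosen for this example, and the factors are random symmetric positive-definite matrices rather than statistics from a real model.

```python
import torch

torch.manual_seed(0)
d_in, d_out = 4, 3


def random_spd(n: int) -> torch.Tensor:
    """Return a random symmetric positive-definite n x n matrix (stand-in factor)."""
    m = torch.randn(n, n, dtype=torch.float64)
    return m @ m.T + n * torch.eye(n, dtype=torch.float64)


A = random_spd(d_in)   # stand-in for the input-activation covariance factor
S = random_spd(d_out)  # stand-in for the output-gradient covariance factor
grad = torch.randn(d_out, d_in, dtype=torch.float64)  # shaped like a weight gradient

# Naive route: materialize the (d_out * d_in) x (d_out * d_in) Kronecker product and solve.
dense = torch.kron(S, A)
naive = torch.linalg.solve(dense, grad.flatten()).reshape(d_out, d_in)

# Factored route: work only with the two small factors.
# With row-major flattening and symmetric A:
#   (S kron A)^{-1} vec(G) = vec(S^{-1} G A^{-1})
factored = torch.linalg.solve(S, grad) @ torch.linalg.inv(A)

max_abs_diff = (naive - factored).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The dense route scales with the square of the number of layer parameters, while the factored route only touches a `d_out x d_out` and a `d_in x d_in` matrix, which is what makes this family of approximations practical for large layers.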


Installation

Important

Requirements:

  • Python: Version 3.9 or later
  • PyTorch: Version 2.1 or later

To install the latest stable version, use the following pip command:

pip install kronfluence

Alternatively, you can install directly from source:

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .

Getting Started

Kronfluence supports influence computations on nn.Linear and nn.Conv2d modules. See the Technical Documentation page for a comprehensive guide.
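
As a quick way to see which layers of a model fall under the two supported module types, the sketch below lists the submodules that are instances of `nn.Linear` or `nn.Conv2d`. It only mirrors the statement above; the helper `list_supported_modules` and the toy model are hypothetical, and Kronfluence's own module selection logic may differ.

```python
from torch import nn

SUPPORTED_TYPES = (nn.Linear, nn.Conv2d)  # the module types named in the text above


def list_supported_modules(model: nn.Module) -> list[str]:
    """Return the names of submodules whose type is nn.Linear or nn.Conv2d."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, SUPPORTED_TYPES)
    ]


# Toy model for illustration; LayerNorm, ReLU, and Flatten are not in the supported set.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 26 * 26, 10),
    nn.LayerNorm(10),
)
print(list_supported_modules(model))
```

Parameters that live in other module types (here, the `nn.LayerNorm`) would not be covered by an analysis restricted to these two types.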

TL;DR: You need to prepare a trained model and datasets, and pass them to the Analyzer class.

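import torch
import torch.nn.functional as F
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task

# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the datasets. `ToTensor` converts the PIL images into tensors the model can consume.
transform = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=transform,
)
# Use the held-out test split as the query (evaluation) set.
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,
    transform=transform,
)


# Define the task. This is a minimal cross-entropy classification task;
# see the Technical Documentation page for details.
class MnistTask(Task):
    def compute_train_loss(self, batch, model, sample=False):
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            return F.cross_entropy(logits, labels, reduction="sum")
        # Sample labels from the model's own predictions instead of using the true labels.
        with torch.no_grad():
            probs = F.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch, model):
        # The quantity measured on query data; here, the summed cross-entropy loss.
        inputs, labels = batch
        logits = model(inputs)
        return F.cross_entropy(logits, labels, reduction="sum")


task = MnistTask()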

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
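
Once a score matrix of shape `len(eval_dataset) x len(train_dataset)` is available, a common next step is to look up which training examples score highest or lowest for each query. The sketch below does this with `torch.topk` on a random stand-in tensor, so it runs without Kronfluence. One assumption to verify against your installed version: the loaded object may be a dictionary whose `"all_modules"` entry holds this matrix, in which case index into it before applying the code.

```python
import torch

torch.manual_seed(0)

# Stand-in for the loaded scores: 5 query examples x 100 training examples.
scores = torch.randn(5, 100)

k = 3
# For each query (row), the k training examples with the highest scores.
top_values, top_indices = torch.topk(scores, k=k, dim=1)
# For each query (row), the k training examples with the lowest scores.
bottom_values, bottom_indices = torch.topk(scores, k=k, dim=1, largest=False)

for query_idx in range(scores.shape[0]):
    highest = top_indices[query_idx].tolist()
    lowest = bottom_indices[query_idx].tolist()
    print(f"query {query_idx}: highest-scoring train indices {highest}, lowest {lowest}")
```

The returned indices refer to positions in the training dataset, so `train_dataset[i]` retrieves the corresponding example for inspection.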

Kronfluence supports various PyTorch features, including Distributed Data Parallel (DDP) and Fully Sharded Data Parallel (FSDP).

The examples folder contains several examples demonstrating how to use Kronfluence.

LogIX

Kronfluence supports influence function computations on large-scale models like Meta-Llama-3-8B-Instruct. If you want to run influence analysis on even larger models or with a large number of query data points, our project LogIX may be worth exploring. It integrates with frameworks like HuggingFace Trainer and PyTorch Lightning and is compatible with many PyTorch features (DDP, FSDP, and DeepSpeed).

Contributing

Contributions are welcome! To get started, please review our Code of Conduct. For bug fixes, please submit a pull request. If you would like to propose new features or extensions, we kindly request that you open an issue first to discuss your ideas.

Setting Up Development Environment

To contribute to Kronfluence, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"

Style Testing

To maintain code quality and consistency, we run ruff and linting tests on pull requests. Before submitting a pull request, please ensure that your code adheres to our formatting and linting guidelines. The following commands will modify your code. It is recommended to create a Git commit before running them to easily revert any unintended changes.

Sort imports using isort:

isort kronfluence

Format code using ruff:

ruff format kronfluence

To view all pylint complaints, run the following command:

pylint kronfluence

Please address any reported issues before submitting your PR.

Acknowledgements

Omkar Dige contributed to the profiling, DDP, and FSDP utilities, and Adil Asif provided valuable insights and suggestions on structuring the DDP and FSDP implementations. I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

License

This software is released under the Apache 2.0 License, as detailed in the LICENSE file.