More efficient sampling from KroneckerMultiTaskGP #2460

Closed (merged) · 11 commits

Conversation

@slishak-PX (Contributor) commented Aug 5, 2024

Motivation

See #2310 (comment)

import torch
from botorch.models import KroneckerMultiTaskGP

n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")

train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)

test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

gp = KroneckerMultiTaskGP(train_x, train_y)

posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))

The final line requires allocating 128 GB of GPU memory because of the call to torch.cholesky_solve with B shaped (256, 1, 8192, 1) and L shaped (8192, 8192): L is broadcast against B's batch dimensions to (256, 1, 8192, 8192), and 256 copies of an 8192×8192 float64 matrix come to 128 GiB.

By moving the largest batch dimension into the final (column) position of the right-hand side, the broadcast batched solve becomes a single solve against the un-broadcast factor, which should be far more memory-efficient.
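To illustrate the idea, here is a minimal sketch of the reshaping trick (not the PR's actual implementation; L and B are stand-ins with the shapes from the example above):

import torch

n, n_samples = 8192, 256

# Stand-ins for the tensors in the failing call.
L = torch.linalg.cholesky(torch.eye(n, dtype=torch.float64))  # (8192, 8192)
B = torch.randn(n_samples, 1, n, 1, dtype=torch.float64)      # (256, 1, 8192, 1)

# Naive batched solve: L broadcasts to (256, 1, 8192, 8192), i.e. 256
# float64 copies of the factor, ~128 GiB.
# X = torch.cholesky_solve(B, L)

# Folding the batch dimension into the columns of a single right-hand side
# keeps the factor un-broadcast: one (8192, 8192) solve with 256 columns.
B_cols = B.reshape(n_samples, n).transpose(-1, -2)  # (8192, 256)
X_cols = torch.cholesky_solve(B_cols, L)
X = X_cols.transpose(-1, -2).reshape(n_samples, 1, n, 1)

The reshaped solve touches only the original factor (~0.5 GB) plus an (8192, 256) right-hand side (~16 MB).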

Also fix docstring for MultitaskGPPosterior.

Have you read the Contributing Guidelines on pull requests?

Yes

Test Plan

Passes unit tests (specifically test_multitask.py).

Benchmarking results:
[Two plots: peak GPU memory and sampling time vs. n_samples, faceted by n_tasks and n_test, colored by n_train, comparing this PR against BoTorch 0.12.0]

Related PRs

N/A

@facebook-github-bot (Contributor) commented

Hi @slishak-PX!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

codecov bot commented Aug 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.98%. Comparing base (e29e30a) to head (6417409).
Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2460   +/-   ##
=======================================
  Coverage   99.98%   99.98%           
=======================================
  Files         193      193           
  Lines       17062    17072   +10     
=======================================
+ Hits        17059    17069   +10     
  Misses          3        3           


@Balandat (Contributor) left a comment


Thanks for putting this up. Would love to see some benchmarks for this. Ultimately, this is something that ideally should be handled upstream in gpytorch, could you add a comment to that extent?

3 review comments on botorch/posteriors/multitask.py (outdated, resolved)
@Balandat (Contributor) commented

@slishak-PX have you had a chance to run some benchmarks on this?

@slishak-PX (Contributor, Author) commented

@Balandat sorry, I will prioritise this as soon as we have the CLA signed, hopefully in the next week or two

@facebook-github-bot added the CLA Signed label on Sep 29, 2024.
@slishak-PX (Contributor, Author) commented

Added benchmarking results. Uses the following code:

benchmark.py (run on both the code in this PR and on BoTorch 0.12.0):

import pickle as pkl

import botorch
import torch
from tqdm import tqdm
from botorch.models import KroneckerMultiTaskGP

device = torch.device("cuda:0")


def get_data(n_inputs=10, n_tasks=4, n_train=128, n_test=1, seed=50):
    torch.manual_seed(seed)
    train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
    train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)
    test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

    return train_x, train_y, test_x


def instantiate_and_sample(train_x, train_y, test_x, n_samples=1):
    with torch.no_grad():
        gp = KroneckerMultiTaskGP(train_x, train_y)
        posterior = gp.posterior(test_x)
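        # The sample draw below is the memory-hungry operation this PR optimizes.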
        posterior.rsample(torch.Size([n_samples]))


def profile(func, *args, **kwargs):
    torch.cuda.reset_peak_memory_stats(device=device)
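    # After the reset, max_memory_allocated reports the bytes currently allocated (baseline).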
    m0 = torch.cuda.max_memory_allocated(device=device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    func(*args, **kwargs)
    end.record()

    torch.cuda.synchronize()
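    # elapsed_time returns the milliseconds between the two recorded events.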
    time = start.elapsed_time(end)

    m1 = torch.cuda.max_memory_allocated(device=device)
    torch.cuda.empty_cache()

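    # Peak-memory delta, converted from bytes to GiB.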
    memory = (m1 - m0) / 1024**3

    return memory, time


if __name__ == "__main__":
    fname = botorch.__version__ + "_results.pkl"
    print(fname)

    n_tasks_list = [2, 4, 8]
    n_train_list = [32, 128, 512]
    n_test_list = [1, 8, 64]
    n_samples_list = [1, 4, 16, 64, 256]

    results = []
    for n_tasks in tqdm(n_tasks_list, desc="n_tasks"):
        for n_train in tqdm(n_train_list, leave=False, desc="n_train"):
            for n_test in tqdm(n_test_list, leave=False, desc="n_test"):
                train_x, train_y, test_x = get_data(
                    n_tasks=n_tasks, n_train=n_train, n_test=n_test
                )
                for n_samples in tqdm(n_samples_list, leave=False, desc="n_sample"):
                    memory = []
                    time = []
                    for i in range(10):
                        try:
                            m, t = profile(
                                instantiate_and_sample,
                                train_x,
                                train_y,
                                test_x,
                                n_samples,
                            )
                        except Exception:
                            print("Failed!")
                            print(
                                {
                                    "n_tasks": n_tasks,
                                    "n_train": n_train,
                                    "n_test": n_test,
                                    "n_samples": n_samples,
                                }
                            )
                            raise

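                        # Discard the first iteration as a warm-up run.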
                        if i > 0:
                            memory.append(m)
                            time.append(t)

                    results.append(
                        {
                            "n_tasks": n_tasks,
                            "n_train": n_train,
                            "n_test": n_test,
                            "n_samples": n_samples,
                            "memory": memory,
                            "time": time,
                        }
                    )

    with open(fname, "wb") as f:
        pkl.dump(results, f)

Analysis notebook:

import pickle as pkl

import numpy as np
import pandas as pd
import plotly.express as px


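# The PR branch presumably reports botorch.__version__ == "Unknown" (dev install), hence this filename.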
with open("Unknown_results.pkl", "rb") as f:
    results = pkl.load(f)
df_new = pd.DataFrame(results)
df_new["version"] = "PR 2460"

with open("0.12.0_results.pkl", "rb") as f:
    results = pkl.load(f)
df_old = pd.DataFrame(results)
df_old["version"] = "BoTorch 0.12.0"

df = pd.concat([df_new, df_old])
df["memory_mean"] = df["memory"].apply(np.mean)
df["memory_std"] = df["memory"].apply(np.std)
df["time_mean"] = df["time"].apply(np.mean)
df["time_std"] = df["time"].apply(np.std)

px.line(
    df,
    x="n_samples",
    y="memory_mean",
    error_y="memory_std",
    facet_col="n_tasks",
    facet_row="n_test",
    color="n_train",
    line_dash="version",
    log_x=True,
    log_y=True,
    width=800,
    height=800,
)

px.line(
    df,
    x="n_samples",
    y="time_mean",
    error_y="time_std",
    facet_col="n_tasks",
    facet_row="n_test",
    color="n_train",
    line_dash="version",
    log_x=True,
    log_y=True,
    width=800,
    height=800,
)

@slishak-PX marked this pull request as ready for review on September 30, 2024 17:30
@facebook-github-bot (Contributor) commented

@Balandat has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@Balandat (Contributor) left a comment


This is awesome, thanks a lot for contributing this change and the comprehensive benchmarks. Seems like there is no downside to using this from a perf perspective and the logic is also sufficiently straightforward so I'm not worried about tech debt.

The only ask I have before merging this in (besides fixing the flake8 lint) is to write a short unittest.

cc @jandylin, @SebastianAment re sampling from Kronecker structured GPs and interesting matrix solve efficiency gains...

1 review comment on botorch/posteriors/multitask.py (outdated, resolved)
@slishak-PX (Contributor, Author) commented

You were absolutely right to ask for a unit test - the implementation was not entirely correct, although I think in the context of the use in MultitaskGPPosterior the error was inconsequential. Should all be corrected now. I've also re-run the benchmarks and there is no change.

@facebook-github-bot (Contributor) commented

@Balandat has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@Balandat merged this pull request in 8924d1b.

@Balandat (Contributor) commented Oct 1, 2024

Many thanks for the contribution, @slishak-PX !

Labels: CLA Signed, Merged