Continual learning with KAN #227

Open
lukmanulhakeem97 opened this issue May 22, 2024 · 8 comments

Comments

@lukmanulhakeem97

Regardless of computation, can KAN currently perform continual learning on 2D input data (or its flattened 1D feature vector)? If not, what is the main challenge KAN faces with it?


ASCIIJK commented May 23, 2024

> Regardless of computation, can KAN currently perform continual learning on 2D input data (or its flattened 1D feature vector)? If not, what is the main challenge KAN faces with it?

I have run some experiments. It seems that KAN struggles to learn 2D input data without forgetting. Specifically, we use a mixed 2D Gaussian distribution with 5 peaks to construct a CL task sequence, shown below:
[image: Ground_task5]
The model learns each peak from 50,000 data points. For example, the data points of the first task are shown below:
[image: Pred_task0]
Then we get the result after 5 tasks:
[image: Pred_task4]
This forgetting occurs at every task, e.g., task 1:
[image: Pred_task1]
PS: We use the model "model = KAN(width=[2, 16, 1], grid=5, k=6, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)", and we have made sure that the loss goes down to zero on each task, so you can see a perfect peak matching the current training data. We think KAN may be hard to train on high-dimensional data without forgetting.
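
For concreteness, here is a minimal sketch of the sequential setup above, assuming the pykan training API used in the tutorials (`model.train(dataset, opt='LBFGS', ...)`); the peak locations, sampling box, and step count are illustrative placeholders, not the exact values we used:

```python
import torch
from kan import KAN

# Five peaks of a mixed 2D Gaussian, presented as five sequential tasks.
peaks = [(-0.8, -0.8), (-0.4, 0.4), (0.0, -0.2), (0.4, 0.6), (0.8, -0.6)]  # illustrative centers

def make_task(cx, cy, n=50_000, sigma=0.1):
    # Each task only contains points near its own peak, so old regions are never revisited.
    x = torch.stack([cx + 0.6 * (torch.rand(n) - 0.5),
                     cy + 0.6 * (torch.rand(n) - 0.5)], dim=1)
    y = torch.exp(-((x[:, 0] - cx) ** 2 + (x[:, 1] - cy) ** 2) / (2 * sigma ** 2)).unsqueeze(1)
    return {'train_input': x, 'train_label': y, 'test_input': x, 'test_label': y}

model = KAN(width=[2, 16, 1], grid=5, k=6, noise_scale=0.1,
            bias_trainable=False, sp_trainable=False, sb_trainable=False)

for cx, cy in peaks:
    dataset = make_task(cx, cy)
    # Train the current task until its loss is ~0; predictions on earlier peaks then degrade.
    model.train(dataset, opt='LBFGS', steps=50)
```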


fangkuoyu commented May 27, 2024

> Regardless of computation, can KAN currently perform continual learning on 2D input data (or its flattened 1D feature vector)? If not, what is the main challenge KAN faces with it?

I have tried MNIST for continual learning, but I haven't obtained any positive results yet.

There are three stages in the process:

Stage_1: Train MNIST characters (0,1,2,3) from the train set, and test MNIST characters (0,1,2,3) from the test set; (establishing a baseline)

Stage_2: Train MNIST characters (4,5,6) and test MNIST characters (0,1,2,3,4,5,6); (with the hope that the model will memorize the results of Stage_1)

Stage_3: Train MNIST characters (7,8,9) and test MNIST characters (0,1,2,3,4,5,6,7,8,9); (with the hope that the model will memorize the results of Stage_1 and Stage_2)

I have tried two approaches:

Method_1: The original image (28x28) is resized to (7x7) and then flattened to (49). The KAN size is (49, 10, 10) with grid=3 and k=3. The training process is the same as Tutorial Example 7 (Ref).

Method_2: The original image (28x28) is mapped to (64) by nn.Linear(28*28, 64) under PyTorch. The KAN size is (64, 16, 10) with grid=3 and k=3. The training process is the same as PyTorch Training (Ref).

Roughly speaking, both methods can achieve train accuracy > 90% in all three stages, but test accuracy degrades across stages (Stage 1 > 90%, Stage 2 ~ 40%, Stage 3 ~ 20%). I have also tried increasing the grid size up to 100, but saw no significant improvement in test accuracy.

I am wondering if any width/grid/k setting under a memory/computation budget could reach a better accuracy of continual learning on MNIST.
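
For reference, here is a minimal sketch of Method_1's staged protocol, assuming torchvision for MNIST and the pykan classification-style training call (`loss_fn=CrossEntropyLoss`); the optimizer and step count are placeholders rather than the exact tutorial settings:

```python
import torch
from torchvision import datasets, transforms
from kan import KAN

# Method_1: resize 28x28 -> 7x7 and flatten to 49 inputs; KAN size (49, 10, 10), grid=3, k=3.
tfm = transforms.Compose([transforms.Resize((7, 7)), transforms.ToTensor()])
train_set = datasets.MNIST('./data', train=True, download=True, transform=tfm)
test_set = datasets.MNIST('./data', train=False, download=True, transform=tfm)

def subset(ds, classes):
    # Select only the requested digits; inputs are flattened to length-49 vectors.
    idx = [i for i, y in enumerate(ds.targets.tolist()) if y in classes]
    x = torch.stack([ds[i][0].view(-1) for i in idx])
    y = ds.targets[idx].clone()
    return x, y

model = KAN(width=[49, 10, 10], grid=3, k=3)
stages = [(0, 1, 2, 3), (4, 5, 6), (7, 8, 9)]

seen = []
for stage in stages:
    seen += list(stage)
    x_tr, y_tr = subset(train_set, stage)   # train only on the new digits
    x_te, y_te = subset(test_set, seen)     # test on every digit seen so far
    dataset = {'train_input': x_tr, 'train_label': y_tr,
               'test_input': x_te, 'test_label': y_te}
    model.train(dataset, opt='LBFGS', steps=20, loss_fn=torch.nn.CrossEntropyLoss())
    acc = (model(x_te).argmax(dim=1) == y_te).float().mean()
    print(f'stage {stage}: test accuracy on seen digits = {acc:.3f}')
```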

By the way, some implementations of convolutional KAN as drop-in layer replacements don't expose the settings 'bias_trainable=False, sp_trainable=False, sb_trainable=False', which limits the study of continual learning, e.g., on CIFAR-10.


ASCIIJK commented May 27, 2024

> @fangkuoyu: I have tried MNIST for continual learning, but I haven't obtained any positive results yet. [...] I am wondering if any width/grid/k setting under a memory/computation budget could reach a better accuracy of continual learning on MNIST.

Yes, I get the same results as yours. KAN seems to achieve continual learning only on some simple tasks, such as 1D data fitting and 2D scatter classification, and even there with many limitations, such as grid size and the number of layers. I find that it achieves continual learning on 2D scatter classification with a large grid size (at least 50?) and no intermediate hidden layers. But if you add intermediate hidden layers or use a smaller grid size, the model forgets very quickly in subsequent tasks. Maybe I need to try more combinations of hyper-parameters.


rafaelcp commented May 28, 2024

As I hypothesized in the efficient-kan repo, it seems KAN cannot do continual learning in more than one dimension if the output depends on more than one of them, as it cannot isolate ranges over groups of values the way it does over single values. I ran this experiment to show it:
[image]
Leftmost image: the dataset, composed of 10,000 pixels (100x100). The output depends on X and Y jointly.
Other images: model predictions after training on each of the 5 rows, starting from the bottom one. Each row is composed of 2,000 pixels (20x100). Notice how it generalizes the blob to the entire columns after each task, but erases it on the next task.

However, this is a [2,1] KAN without any hidden layers, and it turned out it couldn't learn the dataset even in batch mode. So I tried a [2,5,1] KAN, which learned it in batch mode to a reasonable degree. Unfortunately, no success with continual learning:
[image]

I'm using SGD with all biases turned off (Adam can mess things up in continual learning due to running statistics and momentum). Also, I'm using FastKAN.
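
A rough sketch of that row-wise protocol, assuming the `FastKAN([2, 5, 1])` constructor from the fast-kan repo accepts a list of layer widths; the blob target, learning rate, and step counts are placeholders, and the bias-disabling step is omitted since it depends on the FastKAN version:

```python
import torch
import torch.nn.functional as F
from fastkan import FastKAN  # assumed import path for the fast-kan package

# 100x100 image; the target blob depends on X and Y jointly.
ii, jj = torch.meshgrid(torch.arange(100), torch.arange(100), indexing='ij')  # ii = row, jj = column
xs, ys = jj.float() / 99.0, ii.float() / 99.0
inputs = torch.stack([xs.flatten(), ys.flatten()], dim=1)                     # (10000, 2)
target = torch.exp(-((xs - 0.5) ** 2 + (ys - 0.5) ** 2) / 0.02).flatten()     # illustrative blob
rows = ii.flatten()

model = FastKAN([2, 5, 1])                          # the [2,5,1] KAN discussed above
opt = torch.optim.SGD(model.parameters(), lr=1e-2)  # plain SGD, no momentum

# Five tasks of 20 rows (2000 pixels) each, presented sequentially.
for task in range(5):
    mask = (rows >= task * 20) & (rows < (task + 1) * 20)
    x_t, y_t = inputs[mask], target[mask]
    for step in range(1000):
        opt.zero_grad()
        loss = F.mse_loss(model(x_t).squeeze(-1), y_t)
        loss.backward()
        opt.step()
    # After each task, predictions on earlier rows are overwritten (catastrophic forgetting).
```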


ASCIIJK commented May 28, 2024

> @rafaelcp: As I hypothesized in the efficient-kan repo, it seems KAN cannot do continual learning in more than one dimension if the output depends on more than one of them [...] I'm using SGD with all biases turned off (Adam can mess things up in continual learning due to running statistics and momentum). Also, I'm using FastKAN.

I have also found this issue: a KAN with a hidden layer cannot achieve continual learning. I ran experiments on 2D scatter classification. The setup constructs 25 Gaussian distributions with different means, and the model (KAN(width=[2, 25], grid=50)) learns 5 of these 2D Gaussian classes at each task. The results show that a KAN with 50 or more grid points performs well in continual learning. I reckon that a KAN with more grid points reduces the importance of each individual activation function, which keeps the key functions from being rewritten. But when I add just one hidden layer (KAN(width=[2, 25, 25], grid=50)), the model exhibits catastrophic forgetting. So the robustness of a single-layer KAN gets treated as achieving continual learning, while this coincidence may be very fragile in multilayer KANs. Viewed another way, a KAN with more grid points has far more parameters; it amounts to fitting a small dataset with a large model in which many parameters are redundant, thereby preserving an easy decision boundary for old tasks. Finally, it seems that most efficient-KAN libraries cannot achieve continual learning at all, even on the 1D data-fitting task.
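
A minimal sketch of that class-incremental scatter setup, again assuming the pykan training call with `loss_fn=CrossEntropyLoss`; the Gaussian means, sample counts, and step numbers are placeholders:

```python
import torch
from kan import KAN

# 25 Gaussian clusters (one class each) on a 5x5 grid of means; 5 new classes per task.
means = torch.tensor([(i, j) for i in range(5) for j in range(5)], dtype=torch.float)

def task_data(classes, n_per_class=200, sigma=0.1):
    x = torch.cat([means[c] + sigma * torch.randn(n_per_class, 2) for c in classes])
    y = torch.cat([torch.full((n_per_class,), c, dtype=torch.long) for c in classes])
    return x, y

model = KAN(width=[2, 25], grid=50)        # single layer, large grid: forgets least in our runs
# model = KAN(width=[2, 25, 25], grid=50)  # one hidden layer: catastrophic forgetting

for task in range(5):
    new_classes = list(range(task * 5, (task + 1) * 5))
    x_tr, y_tr = task_data(new_classes)             # train only on the 5 new classes
    x_te, y_te = task_data(range((task + 1) * 5))   # evaluate on all classes seen so far
    dataset = {'train_input': x_tr, 'train_label': y_tr,
               'test_input': x_te, 'test_label': y_te}
    model.train(dataset, opt='LBFGS', steps=20, loss_fn=torch.nn.CrossEntropyLoss())
    acc = (model(x_te).argmax(dim=1) == y_te).float().mean()
    print(f'task {task}: accuracy on all seen classes = {acc:.3f}')
```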

@fangkuoyu

@ASCIIJK @rafaelcp The paper of KAN describes continual learning as follows:

> KANs have local plasticity and can avoid catastrophic forgetting by leveraging the locality of splines. The idea is simple: since spline bases are local, a sample will only affect a few nearby spline coefficients, leaving far-away coefficients intact (which is desirable since faraway regions may have already stored information that we want to preserve).

I think that the above statement depends on the distribution of the input data. In the case of modeling peaks in 1D, the distribution of peaks is sparse, so the locality holds. But in the case of modeling peaks in 2D, or MNIST, the target features are dense in the input space, so fitting new data disturbs the fit on old data.

The paper of KAN also says that:

> Here we simply present our preliminary results on an extremely simple example, to demonstrate how one could possibly leverage locality in KANs (thanks to spline parametrizations) to reduce catastrophic forgetting. However, it remains unclear whether our method can generalize to more realistic setups, especially in high-dimensional cases where it is unclear how to define “locality”.

Based on our experiments, I think that continual learning with KAN holds in special domains, but not for general purposes.

@rafaelcp

The 2D Gaussians domain is also sparse, so it is not a matter of sparse vs. dense. It is a matter of dependency between variables (which, unfortunately, is the norm).

@KindXiaoming (Owner)

Hi, just want to draw your attention to this paper which seems quite relevant:
Distal Interference: Exploring the Limits of Model-Based Continual Learning
