From bbaf981195f703ab6a1e2fb8c53c2d900c3e07e4 Mon Sep 17 00:00:00 2001 From: jylim2016 Date: Sat, 5 Nov 2016 02:24:07 +0800 Subject: [PATCH] warm_start for SGD --- fastFM/ffm.pyx | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/fastFM/ffm.pyx b/fastFM/ffm.pyx index 56a29e8..e405c21 100644 --- a/fastFM/ffm.pyx +++ b/fastFM/ffm.pyx @@ -128,11 +128,20 @@ def ffm_sgd_fit(fm, X, double[:] y): pt_param = PyCapsule_GetPointer(param, "FFMParam") # allocate the coefs - cdef double w_0 = 0 - cdef np.ndarray[np.float64_t, ndim=1, mode='c'] w =\ - np.zeros(n_features, dtype=np.float64) - cdef np.ndarray[np.float64_t, ndim=2, mode='c'] V =\ - np.zeros((fm.rank, n_features), dtype=np.float64) + # allocate the coefs + cdef double w_0 + cdef np.ndarray[np.float64_t, ndim=1, mode='c'] w + cdef np.ndarray[np.float64_t, ndim=2, mode='c'] V + + if fm.warm_start: + w_0 = 0 if fm.ignore_w_0 else fm.w0_ + w = np.zeros(n_features, dtype=np.float64) if fm.ignore_w else fm.w_ + V = np.zeros((fm.rank, n_features), dtype=np.float64)\ + if fm.rank == 0 else fm.V_ + else: + w_0 = 0 + w = np.zeros(n_features, dtype=np.float64) + V = np.zeros((fm.rank, n_features), dtype=np.float64) cffm.ffm_sgd_fit(&w_0, w.data, V.data, pt_X, &y[0], pt_param) @@ -147,11 +156,20 @@ def ffm_fit_sgd_bpr(fm, X, np.ndarray[np.float64_t, ndim=2, mode='c'] pairs): pt_param = PyCapsule_GetPointer(param, "FFMParam") #allocate the coefs - cdef double w_0 = 0 - cdef np.ndarray[np.float64_t, ndim=1, mode='c'] w =\ - np.zeros(n_features, dtype=np.float64) - cdef np.ndarray[np.float64_t, ndim=2, mode='c'] V =\ - np.zeros((fm.rank, n_features), dtype=np.float64) + # allocate the coefs + cdef double w_0 + cdef np.ndarray[np.float64_t, ndim=1, mode='c'] w + cdef np.ndarray[np.float64_t, ndim=2, mode='c'] V + + if fm.warm_start: + w_0 = 0 if fm.ignore_w_0 else fm.w0_ + w = np.zeros(n_features, dtype=np.float64) if fm.ignore_w else fm.w_ + V = np.zeros((fm.rank, n_features), dtype=np.float64)\ + if fm.rank == 0 else fm.V_ + else: + w_0 = 0 + w = np.zeros(n_features, dtype=np.float64) + V = np.zeros((fm.rank, n_features), dtype=np.float64) cffm.ffm_sgd_bpr_fit(&w_0, w.data, V.data, pt_X, pairs.data, pairs.shape[0], pt_param)