-
Notifications
You must be signed in to change notification settings - Fork 7
/
rlfit.m
82 lines (69 loc) · 2.25 KB
/
rlfit.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
function [beta, LL, Q] = rlfit(Qfun, choice, outcome, lb, ub, niter, ...
    ispresentx)
% fits a reinforcement learning model to a multi-option choice paradigm
% inputs:
%
% Qfun is a handle to a function that accepts a vector of parameters, a
% vector of choice indices, and a vector of outcomes, and returns
% the action values, Q
%
% choice is a vector, one entry per trial, the index of the chosen option
%
% outcome is a set of outcomes for each trial
%
% lb and ub are vectors of lower and upper bounds on the model parameters
% (bounds for the softmax inverse temperature are prepended internally)
%
% niter (optional) is the number of random restarts to use in fitting
% (default 10)
%
% ispresentx (optional) is an options x number of trials matrix,
% indicating the options present on each trial
% (1 if present, 0 if not present); defaults to true,
% i.e., all options present on every trial
%
% outputs:
%
% beta is the vector of fitted model parameters; first entry is the softmax
% inverse temperature, followed by parameters of the model
%
% LL is the log likelihood of the data (choice, outcome) given beta
%
% Q is a trials x options matrix of action values, in the original
% (unscaled) outcome units

% fill in defaults for optional arguments
if (~exist('ispresentx', 'var')) || isempty(ispresentx)
    ispresentx = true;
end
if ~exist('niter', 'var')
    niter = 10;
end
if ~exist('lb', 'var')
    lb = [];
end
if ~exist('ub', 'var')
    ub = [];
end

% rescale (z-score) outcomes to offer better fit convergence;
% guard against constant outcomes, where std is 0 and dividing would
% turn every z-score (and the final beta(1) rescale) into NaN/Inf
outmean = mean(outcome(:));
outstd = std(outcome(:));
if outstd == 0
    outstd = 1; % degenerate case: center only, no scaling
end
z = bsxfun(@minus, outcome, outmean)/outstd;

% first, define a log likelihood function that takes as its input a vector
% of parameters, the first of which is the inverse temperature of the
% softmax; values of absent options (ispresentx == 0) are zeroed out
LLfun = @(x, choice, z) LL_softmax(x(1) * Qfun(x(2:end), choice, z) .* ...
    ispresentx, choice);

% then define a function to be minimized (the total negative log
% likelihood)
fitfun = @(beta)(-1)*sum(LLfun(beta, choice, z));

% now combine upper and lower bounds on softmax temp with upper and lower
% bounds on other parameters
lb = [1e-5, lb]; %lower bounds (inverse temperature must stay positive)
ub = [10, ub]; %upper bounds

% optimize to fit model; suppress solver warnings during the restarts
w = warning ('off','all');
options = optimset('Display', 'off');
[beta,fval]=multmin(fitfun, lb, ub, niter, options);
warning(w);

% return log likelihood (fval is the minimized NEGATIVE log likelihood)
LL=-fval;

% get action values at the fitted parameters
Q = Qfun(beta(2:end), choice, z);

% undo scaling: map Q back to original outcome units, and rescale the
% inverse temperature so that beta(1)*Q (the softmax input) is preserved
Q = Q*outstd + outmean; % rescale appropriately
beta(1) = beta(1)/outstd;