clc; clearvars; format long;

% Start every run with an empty History folder so that trajectories from an
% earlier run cannot leak into the moment estimates below.
if exist('History', 'dir')
    rmdir('History', 's')
end
mkdir('History')

addpath('casadi-windows-matlabR2016a-v3.5.5')
import casadi.*
addpath('rawlings-group-octave-mpctools\')
mpc = import_mpctools();
rng(0)

%%
%######################################################################
% Sampling and estimation of moments
%######################################################################
global modelID;
modelID = 2

[xInit, v, k, maxStateValues, tStep, tEnd] = modelLoader()
timeGrid = 0:tStep:tEnd;
NsimGrid = length(timeGrid);
Nt = round(1.0/tStep);   % MPC horizon in grid steps (used as N.t below)
Nsim = NsimGrid - Nt;    % number of closed-loop MPC iterations

% All results are written to Sim<modelID>/. writematrix does not create
% missing folders, so make sure the output folder exists.
simDir = "Sim" + string(modelID);
if ~exist(simDir, 'dir')
    mkdir(simDir)
end

writematrix(timeGrid, simDir + "/timeGrid.csv")
VSize = size(v);
numSpecies = VSize(1);
numReacs = VSize(2);

% calculate the true CME solution and its moments
[CMEExp, CMEVar, CMESke, CMERho] = calcCMEMoments(xInit, v, k, maxStateValues, timeGrid);
CMEExp1D = squeeze(CMEExp);
CMEVar1D = squeeze(CMEVar);
writematrix(CMEExp, simDir + "/CMEExp.csv")
writematrix(CMEVar, simDir + "/CMEVar.csv")
writematrix(CMESke, simDir + "/CMESke.csv")
writematrix(CMERho, simDir + "/CMERho.csv")

% calculate the MoM solution with Truncation Closure
[tTruncated, yTruncated] = ode15s(@(t,x) odeTruncated(t,x,k), [0 tEnd], [xInit-1 0]);
writematrix(tTruncated, simDir + "/tTruncated.csv")
writematrix(yTruncated, simDir + "/yTruncated.csv")

%% Check how fast SSA is for different numbers of Samples and how precise it is
% calculate the moment estimation from different numbers of samples
sampleNumbers = [10, 100, 1000];
for i = 1:length(sampleNumbers)
    numSamples = sampleNumbers(i)
    % timing run only: trajectories are simulated but not written to disk
    tic
    createHistorySample(numSamples, tEnd, v, k, xInit);
    toc
    createAndStoreHistorySample(numSamples, tEnd, v, k, xInit);
    [expEstimation, varEstimation, skeEstimation, rhoEstimation] = ...
        estimateMomentsOverTimeGrid(numSamples, numSpecies, timeGrid);
    samplePrefix = simDir + "/" + string(numSamples);
    writematrix(expEstimation, samplePrefix + "SampleExp.csv")
    writematrix(varEstimation, samplePrefix + "SampleVar.csv")
    writematrix(skeEstimation, samplePrefix + "SampleSke.csv")
    writematrix(rhoEstimation, samplePrefix + "SampleRho.csv")
    % reshape the variance estimations to be matrices and not tensors
    varSize = size(varEstimation);
    varEstimation = reshape(varEstimation, [1, varSize(3)]);
    ExpErrorSSA = (CMEExp1D - expEstimation).^2;
    VarErrorSSA = (transpose(CMEVar1D) - varEstimation).^2;
    size(VarErrorSSA)
    size(ExpErrorSSA)
    writematrix(ExpErrorSSA, samplePrefix + "SampleExpError.csv")
    writematrix(VarErrorSSA, samplePrefix + "SampleVarError.csv")
end

% load the trajectories that will be followed by the controller
expEstimation10 = readmatrix(simDir + "/10SampleExp.csv");
varEstimation10 = readmatrix(simDir + "/10SampleVar.csv");
% load the trajectories that only rely on samples
expEstimation100 = readmatrix(simDir + "/100SampleExp.csv");
varEstimation100 = readmatrix(simDir + "/100SampleVar.csv");
% transpose the trajectories needed for control
expEstimation10 = transpose(expEstimation10);
varEstimation10 = transpose(varEstimation10);

%%
%######################################################################
% Virtual reference consistent with the dynamics
%######################################################################
% standard change-of-rate factor is 100.0
rateOfChangePenalties = [1.0, 10.0, 1000.0, 100.0];
for penaltyIndex = 1:length(rateOfChangePenalties)
    Nx = 2; % state dimension / number of considered moments
    Nu = 1; % scalar input (the first moment we neglected)
    % Simulator function.
    model = mpc.getCasadiIntegrator(@(x,u)ode(x,u,k), tStep, [Nx, Nu], {'x', 'u'});
    % Get nonlinear model and also linear approximation.
    fmpc = mpc.getCasadiFunc(@(x,u)ode(x,u,k), [Nx, Nu], {'x', 'u'}, ...
        'rk4', true(), 'Delta', tStep, 'M', 2, ...
        'funcname', 'fnonlin');
    xsp = [expEstimation10'; varEstimation10'];
    usp = zeros(Nu, Nsim+Nt-1);
    Q = diag([1000.0, 30.0]);
    R = diag([1.0]);
    P = diag([0.0, 0.0]); % For terminal penalty.
    S = rateOfChangePenalties(penaltyIndex)*R; % Rate-of-change penalty.
    % Define stage cost.
    lcasadi = mpc.getCasadiFunc(@stagecost, ...
        {Nx, Nu, Nu, Nx, Nu, [Nx, Nx], [Nu, Nu], [Nu, Nu]}, ...
        {'x', 'u', 'uprev', 'xsp', 'usp', 'Q', 'R', 'S'}, 'funcname', 'l');
    Vf = mpc.getCasadiFunc(@termcost, {Nx, [Nx, Nx]}, {'x', 'P'}, {'Vf'});
    % Vf = [];
    % Two-sided Hoeffding half-width at level 0.05 for the mean of 10
    % samples bounded in [0, maxStateValues-1]:
    %   t = sqrt(-(M^2)/(2n) * log(0.05/2)),  M = maxStateValues-1, n = 10
    symmetricHoeffdingBound = sqrt(-1*(maxStateValues-1)^2/(2*10)*log(0.05/2.0));
    % Build bounds, parameters, and N.
    lbx = [max(0, (expEstimation10 - symmetricHoeffdingBound))'; zeros(1, NsimGrid)];
    ubx = [min(maxStateValues-1, (expEstimation10 + symmetricHoeffdingBound))'; inf*ones(1, NsimGrid)];
    lbu = -inf*ones(1, NsimGrid-1);
    ubu = inf*ones(1, NsimGrid-1);
    lb = struct('x', lbx, 'u', lbu);
    ub = struct('x', ubx, 'u', ubu);
    % Build solver and optimize.
    x0 = [xInit-1; 0];
    N = struct('x', Nx, 'u', Nu, 't', Nt);
    par = struct('xsp', xsp(:,1:Nt+1), 'usp', usp(:,1:Nt), ...
        'Q', Q, 'R', R, 'P', P, 'S', S, 'uprev', zeros(Nu,1));
    controller = mpc.nmpc('f', fmpc, 'l', lcasadi, 'N', N, 'x0', x0, ...
        'Vf', Vf, 'par', par, 'verbosity', 1);
    tic
    % Start loop.
    x = NaN(Nx, Nsim+1);
    x(:,1) = x0;
    u = NaN(Nu, Nsim);
    for i = 1:Nsim
        fprintf('(%3d) ', i);
        % update time dependent parameters (receding-horizon window i..i+Nt)
        controller.fixvar('x', 1, x(:,i));
        controller.par.xsp = xsp(:,i:i+Nt);
        controller.par.usp = usp(:,i:i+Nt-1);
        controller.lb = struct('x', lb.x(:,i:i+Nt), 'u', lb.u(:,i:i+Nt-1));
        controller.ub = struct('x', ub.x(:,i:i+Nt), 'u', ub.u(:,i:i+Nt-1));
        % Apply control law.
        controller.solve();
        fprintf('Controller: %s, ', controller.status);
        if ~isequal(controller.status, 'Solve_Succeeded')
            fprintf('\n');
            warning('controller failed at time %d', i);
            break
        end
        u(:,i) = controller.var.u(:,1);
        controller.saveguess();
        controller.par.uprev = u(:,i); % Save previous u.
        % Evolve plant.
        x(:,i+1) = full(model(x(:,i), u(:,i)));
        % x(:,i+1) = controller.var.x(:,2);
        fprintf('\n');
    end
    % Finish the trajectory with the remaining open-loop inputs of the last
    % solve so that x covers the whole time grid.
    u = [u, controller.var.u(:,2:end)];
    for i = 1:Nt-1
        x(:,Nsim+i+1) = full(model(x(:,Nsim+i), u(:,Nsim+i)));
    end
    toc
    penaltyTag = string(cast(rateOfChangePenalties(penaltyIndex), "int32"));
    % writematrix((0:Nsim)*tStep, simDir + "/tControlled.csv")
    writematrix(timeGrid, simDir + "/tControlled.csv")
    writematrix(x, simDir + "/yControlled" + penaltyTag + ".csv")
    writematrix(u, simDir + "/uControlled" + penaltyTag + ".csv")
    % Error calculations
    MoMControlResSize = size(x);
    numControlSteps = MoMControlResSize(2);
    ExpErrorControl = (CMEExp1D(1:numControlSteps) - x(1,:)).^2;
    VarErrorControl = (CMEVar1D(1:numControlSteps) - transpose(x(2,:))).^2;
    writematrix(ExpErrorControl, simDir + "/expError" + penaltyTag + "Control.csv")
    writematrix(VarErrorControl, simDir + "/varError" + penaltyTag + "Control.csv")
end

%######################################################################
% Functions
% state x is a column array
% state change array v is of shape numSpecies x numReacs
%######################################################################

%% Load all parameters for a given model
% Reads the global modelID. Returns the initial state as a 1-based state
% (molecule count + 1), the stoichiometry v, rate constants k, the
% per-species size of the truncated state space, and the time settings.
% Raises an error for an unknown modelID instead of leaving outputs unset.
function [xInit, v, k, maxStateValues, tStep, tEnd] = modelLoader()
global modelID
if (modelID == 0)
    % constant growth, quadratic decrease
    xInit = 26;
    k = [1.25, 0.25];
    v = [1, -2];
    tEnd = 10.0;
    tStep = 0.1;
    maxStateValues = [31];
elseif (modelID == 1)
    % 2D Michaelis-Menten (not fully supported)
    xInit = [25; 10];
    k = [0.1, 1.0, 0.1];
    v = [-1, 1, 0; -1, 1, 1];
    tEnd = 10.0;
    tStep = 0.1;
    maxStateValues = xInit;
elseif (modelID == 2)
    % Reversible dimerization
    xInit = 41;
    k = [0.05, 0.025];
    v = [2, -2];
    tEnd = 10.0;
    tStep = 0.1;
    maxStateValues = [51];
else
    error('modelLoader:unknownModel', 'Unknown modelID %d.', modelID);
end
end

%% define different propensities depending on the choice of model
% this function takes the number of molecules! as input x
% propensities for modelID 1 have to be adjusted depending on the maximal
% number of Enzymes in the system (hard-coded 10 below).
% propensities for modelID 2 have to be adjusted depending on the maximal
% number of molecules in the system (hard-coded 50 below).
function propensities = calculatePropensities(x, k)
global modelID;
if (modelID == 0)
    propensities = [k(1); k(2)*x(1)*(x(1)-1)];
elseif (modelID == 1)
    propensities = [k(1)*x(1)*x(2); k(2)*(10-x(2)); k(3)*(10-x(2))];
elseif (modelID == 2)
    propensities = [k(1)*(50-x(1))/2.0; k(2)*x(1)*(x(1)-1)];
else
    error('calculatePropensities:unknownModel', 'No propensities defined for modelID %d.', modelID);
end
end

%% Perform 1 Gillespie step
% Draws the next reaction with probability proportional to its propensity
% and an exponentially distributed waiting time with rate sum(propensities).
% x is a molecule-count column vector; returns the new time and state.
function [tNew, xNew] = gillespieStep(t, x, v, k)
numReacs = size(v, 2);
propensities = calculatePropensities(x, k);
% draw the index of the next reaction proportionally to the propensities
nextReaction = randsample(1:numReacs, 1, true, propensities);
% draw the exponentially distributed time increment
timeIncrement = log(1.0/rand)/sum(propensities);
tNew = t + timeIncrement;
xNew = x + v(:, nextReaction);
end

%% Perform Gillespie steps until the final time is reached
% Returns the event times (1 x n) and the molecule counts after each event
% (numSpecies x n). The first column is t = 0 with xInit-1 molecules; the
% last recorded time is the first event time >= tEnd.
function [tHistory, xHistory] = gillespieHistory(tEnd, v, k, xInit)
numSpecies = size(v, 1);
% index of the last filled history entry
index = 1;
% preallocate; capacity is doubled whenever it runs out
xHistory = zeros(numSpecies, 1000);
tHistory = zeros(1, 1000);
% the number of initial molecules is obtained by state-1
xHistory(:, 1) = xInit-1;
tHistory(1, 1) = 0.0;
while (tHistory(1, index) < tEnd)
    index = index + 1;
    % grow the history arrays before writing past their end
    if (index > length(tHistory))
        xHistory = [xHistory, zeros(numSpecies, length(tHistory))];
        tHistory = [tHistory, zeros(1, length(tHistory))];
    end
    [tNew, xNew] = gillespieStep(tHistory(1, index-1), xHistory(:, index-1), v, k);
    xHistory(:, index) = xNew;
    tHistory(1, index) = tNew;
end
% return only the filled entries of the history arrays
xHistory = xHistory(:, 1:index);
tHistory = tHistory(1, 1:index);
end

%% Create a fixed number of Gillespie trajectories and store them as csv
% Trajectory i is written to History/tHistory<i>.csv and History/xHistory<i>.csv.
function [] = createAndStoreHistorySample(numSamples, tEnd, v, k, xInit)
for i = 1:numSamples
    [tHistory, xHistory] = gillespieHistory(tEnd, v, k, xInit);
    iString = string(i);
    writematrix(tHistory, strcat('History/tHistory', iString, '.csv'));
    writematrix(xHistory, strcat('History/xHistory', iString, '.csv'));
end
end

%% Create a fixed number of Gillespie trajectories and do not store them for performance tests
function [] = createHistorySample(numSamples, tEnd, v, k, xInit)
for i = 1:numSamples
    [~, ~] = gillespieHistory(tEnd, v, k, xInit);
end
end

%% Binary search for the state of one trajectory at the indicated time point
% Returns xHistory(:, idx), where idx is the last event index with
% tHistory(idx) <= time. This is the value the piecewise-constant SSA path
% holds at that time. Times at or after the last recorded event map to the
% final state. Times before the first event raise an error.
function [stateEvaluation] = evaluate1HistoryAt1Time(tHistory, xHistory, time)
lo = 1;
hi = length(tHistory);
if (time < tHistory(lo))
    error('evaluate1HistoryAt1Time:beforeStart', ...
        'Requested time %g lies before the first recorded time %g.', time, tHistory(lo));
end
if (time >= tHistory(hi))
    stateEvaluation = xHistory(:, hi);
    return
end
% invariant: tHistory(lo) <= time < tHistory(hi)
while (hi - lo > 1)
    mid = floor((lo + hi)/2);
    if (tHistory(mid) <= time)
        lo = mid;
    else
        hi = mid;
    end
end
stateEvaluation = xHistory(:, lo);
end

%% Estimate moments of all chemical species over a time grid from loaded Gillespie trajectories
% Reads History/{t,x}History<i>.csv for i = 1..numSamples.
% Returns:
%   SSAExp : numSpecies x numTimePoints sample means
%   SSAVar : numSpecies x numSpecies x numTimePoints sample covariances
%            (normalised by numSamples-1, full symmetric matrix)
%   SSASke : 1 x numTimePoints third central moment estimate, n/((n-1)(n-2)) scaling
%   SSARho : 1 x numTimePoints third absolute central moment, same scaling
% SSASke/SSARho are only filled when numSpecies == 1; they need numSamples >= 3.
function [SSAExp, SSAVar, SSASke, SSARho] = estimateMomentsOverTimeGrid(numSamples, numSpecies, timeGrid)
numTimePoints = length(timeGrid);
sampleStateEvaluation = zeros(numSpecies, numSamples, numTimePoints);
for i = 1:numSamples
    iString = string(i);
    tHistory = readmatrix(strcat('History/tHistory', iString, '.csv'));
    xHistory = readmatrix(strcat('History/xHistory', iString, '.csv'));
    for j = 1:numTimePoints
        sampleStateEvaluation(:,i,j) = evaluate1HistoryAt1Time(tHistory, xHistory, timeGrid(j));
    end
end
SSAExp = zeros(numSpecies, numTimePoints);
SSAVar = zeros(numSpecies, numSpecies, numTimePoints);
SSASke = zeros(1, numTimePoints);
SSARho = zeros(1, numTimePoints);
for t = 1:numTimePoints
    samplesAtT = sampleStateEvaluation(:,:,t);
    SSAExp(:,t) = sum(samplesAtT, 2)/numSamples;
    deviations = samplesAtT - SSAExp(:,t);
    % sum of cross products for every species pair (symmetric)
    SSAVar(:,:,t) = deviations * deviations';
    if (numSpecies == 1)
        SSASke(1,t) = sum(deviations(1,:).^3);
        SSARho(1,t) = sum(abs(deviations(1,:)).^3);
    end
end
SSAVar = SSAVar/(numSamples-1);
SSASke = SSASke*numSamples/(numSamples-1)/(numSamples-2);
SSARho = SSARho*numSamples/(numSamples-1)/(numSamples-2);
end

%% Quadratic MPC stage cost: tracking error, input deviation and input rate.
function cost = stagecost(x, u, uprev, xsp, usp, Q, R, S)
dx = x - xsp;
du = u - usp;
Deltau = u - uprev;
cost = dx'*Q*dx + du'*R*du + Deltau'*S*Deltau;
end

%% Moment ODE for x = [mean; variance] of the molecule count.
% The input u enters only the variance equation (coefficient -4*k(2)); in
% the script it stands in for the neglected higher moment. Only modelID 0
% and 2 are supported; other models raise an error.
function dxdt = ode(x, u, k)
global modelID
if (modelID == 0)
    dxdt = [k(1), 2*k(2), -2*k(2), 0 , -2*k(2);
            k(1), -4*k(2), 4*k(2), -8*k(2), 8*k(2)] ...
        * [1; x(1,:); x(1,:).^2; x(1,:).*x(2,:); x(2,:)] ...
        + [0; -4*k(2)] ...
        *u;
elseif (modelID == 2)
    dxdt = [k(1)*50, 2*k(2)-k(1), -2*k(2), 0 , -2*k(2);
            k(1)*2*50, -2*k(1)-4*k(2), 4*k(2), -8*k(2), 8*k(2)-2*k(1)] ...
        * [1; x(1,:); x(1,:).^2; x(1,:).*x(2,:); x(2,:)] ...
        + [0; -4*k(2)] ...
        *u;
else
    error('ode:unknownModel', 'No moment equations defined for modelID %d.', modelID);
end
end

%% Moment ODE with truncation closure: the neglected moment is set to zero.
% Signature (t, x, k) matches what ode15s expects; t is unused.
function dxdt = odeTruncated(t, x, k)
dxdt = ode(x, 0, k);
end

%% Quadratic terminal penalty.
function Vf = termcost(x, P)
Vf = x'*P*x;
end

%% Map a 1-based state (column vector) to its linear index in the truncated state space.
function index = state2index(state, maxStateValues)
index = 1;
numSpecies = size(state, 1);
for d = 1:numSpecies
    index = index + (state(d)-1)*prod(maxStateValues(1:d-1));
end
end

%% Inverse of state2index: linear index -> 1-based state column vector.
function state = index2state(index, maxStateValues)
tempIndex = index-1;
numDims = size(maxStateValues, 1);
state = zeros(numDims, 1);
for d = numDims:-1:1
    stride = prod(maxStateValues(1:d-1));
    % indices are nonnegative integer-valued doubles, so floor(a/b) is exact
    state(d) = floor(tempIndex/stride) + 1;
    tempIndex = mod(tempIndex, stride);
end
end

% Due to indexing problems, the number of molecules is derived from the
% state by subtracting 1
% Builds the CME generator A (column = source state, row = destination).
% Transitions leaving the truncated state space are dropped from both the
% off-diagonal and the diagonal entry, so every column sums to zero.
function CMEMatrix = constructCMEMatrix(v, k, maxStateValues)
numStates = prod(maxStateValues);
CMEMatrix = zeros(numStates, numStates);
numSpecies = size(v, 1);
numReacs = size(v, 2);
for index = 1:numStates
    state = index2state(index, maxStateValues);
    % calculate propensities for the number of molecules (state-1)
    props = calculatePropensities(state-1, k);
    for reac = 1:numReacs
        destinationState = state + v(:,reac);
        insideBounds = all(destinationState <= maxStateValues) && ...
                       all(destinationState >= ones(numSpecies,1));
        if insideBounds
            destinationIndex = state2index(destinationState, maxStateValues);
            CMEMatrix(destinationIndex, index) = CMEMatrix(destinationIndex, index) + props(reac);
            CMEMatrix(index, index) = CMEMatrix(index, index) - props(reac);
        end
    end
end
end

%% All 1-based states of the truncated state space, one per column, in index order.
function allStates = getStateList(maxStateValues)
numSpecies = size(maxStateValues, 1);
numStates = prod(maxStateValues);
allStates = zeros(numSpecies, numStates);
for i = 1:numStates
    allStates(:,i) = index2state(i, maxStateValues);
end
end

% only works for 1D so far!
function [CMEExp, CMEVar, CMESke, CMERho] = calcCMEMoments(initState, v, k, maxStateValues, timeGrid)
% Solves the truncated CME dp/dt = A*p exactly with the matrix exponential,
% starting from a point mass at the 1-based initState. Returns moments of
% the molecule counts (state - 1) at every time in timeGrid:
%   CMEExp : numSpecies x numTimePoints means
%   CMEVar : numSpecies x numSpecies x numTimePoints covariance matrices
%            (full symmetric matrices)
%   CMESke : 1 x numTimePoints third central moment of species 1
%   CMERho : 1 x numTimePoints third absolute central moment of species 1
numTimePoints = numel(timeGrid);
numSpecies = size(v, 1);
CMEExp = zeros(numSpecies, numTimePoints);
CMEVar = zeros(numSpecies, numSpecies, numTimePoints);
CMESke = zeros(1, numTimePoints);
CMERho = zeros(1, numTimePoints);
CMEMatrix = constructCMEMatrix(v, k, maxStateValues);
pInit = zeros(prod(maxStateValues), 1);
pInit(state2index(initState, maxStateValues)) = 1.0;
% molecule counts of every state, one column per state (states are 1-based)
allCounts = getStateList(maxStateValues) - 1;
for timeIndex = 1:numTimePoints
    pEnd = expm(CMEMatrix*timeGrid(timeIndex))*pInit;
    CMEExp(:,timeIndex) = allCounts*pEnd;
    deviations = allCounts - CMEExp(:,timeIndex);
    % probability-weighted cross products -> full covariance matrix
    CMEVar(:,:,timeIndex) = (deviations .* pEnd') * deviations';
    firstSpeciesDev = deviations(1,:);
    CMESke(1,timeIndex) = (firstSpeciesDev.^3) * pEnd;
    CMERho(1,timeIndex) = (abs(firstSpeciesDev).^3) * pEnd;
end
end