#include <limits.h>
#include <float.h>
#include "NsSystem.hh"
#include "NsTract.hh"
// Abort with a descriptive message unless `val` lies within [min, max] as
// judged by Util::isInRange (presumably inclusive -- see Util::isInRange).
// `#val` stringizes the argument so the message names the offending
// parameter. NOTE: `val` is evaluated more than once, so pass only
// side-effect-free expressions.
#define CHECK_RANGE(val, min, max) \
ABORT_UNLESS(Util::isInRange(val, min, max), \
"bad value for '{}': {}", #val, val)
/**
 * Construct a tract from fromLayer to toLayer.
 *
 * All tract parameters are read from the property store under the key
 * "<type>.<name>" and sanity-checked; then every unit of fromLayer is
 * connected to every unit of toLayer, except that a unit is never connected
 * to itself (relevant e.g. when both layers share units).
 *
 * @param id        Tract identifier (used in traces and toStr)
 * @param fromLayer Source layer
 * @param toLayer   Destination layer
 * @param type      Property-name prefix for this tract's parameters
 */
NsTract::NsTract(const string &id,
                 NsLayer *fromLayer,
                 NsLayer *toLayer,
                 const string &type)
    : id(id), type(type), fromLayer(fromLayer), toLayer(toLayer),
      e3Level(0), lastE3Level(DBL_MAX), lastTimeStep(UINT_MAX)
{
    // Every parameter lives in the property store as "<type>.<name>"
    auto getParam = [&](const char *name) {
        return props.getDouble(type + '.' + name);
    };

    acqLearnRate = getParam("acqLearnRate");
    reactE3Level = getParam("reactE3Level");
    consLearnRate01h = getParam("consLearnRate01h");
    psdDecayRate01h = getParam("psdDecayRate01h");
    cpAmparRemovalRate01h = getParam("cpAmparRemovalRate01h");
    ciAmparInsertionRate01h = getParam("ciAmparInsertionRate01h");
    ciAmparRemovalRate01h = getParam("ciAmparRemovalRate01h");
    baseDepotProb01h = getParam("baseDepotProb01h");
    maxE3DepotProb01h = getParam("maxE3DepotProb01h");
    e3DecayRate01h = getParam("e3DecayRate01h");
    maxPotProb01h = getParam("maxPotProb01h");

    // Sanity check: fractional rates, levels and probabilities must be
    // within [0, 1].
    //
    CHECK_RANGE(acqLearnRate, 0.0, 1.0);
    CHECK_RANGE(reactE3Level, 0.0, 1.0);
    CHECK_RANGE(consLearnRate01h, 0.0, 1.0);
    CHECK_RANGE(psdDecayRate01h, 0.0, 1.0);
    CHECK_RANGE(cpAmparRemovalRate01h, 0.0, 1.0);
    CHECK_RANGE(ciAmparRemovalRate01h, 0.0, 1.0);
    CHECK_RANGE(baseDepotProb01h, 0.0, 1.0);
    CHECK_RANGE(e3DecayRate01h, 0.0, 1.0);
    CHECK_RANGE(maxE3DepotProb01h, 0.0, 1.0);
    CHECK_RANGE(maxPotProb01h, 0.0, 1.0);
    // The CI-AMPAR insertion rate is scaled linearly with the timestep
    // (calcConstantRate), not as a fraction, so it has no natural upper
    // bound -- but a negative value is never meaningful.
    CHECK_RANGE(ciAmparInsertionRate01h, 0.0, DBL_MAX);

    // Fully connect the layers, skipping unit-to-itself connections.
    // NOTE(review): connections are allocated with raw `new`; presumably
    // they are deleted in the destructor (not visible here) -- confirm.
    for (auto fu : fromLayer->units) {
        for (auto tu : toLayer->units) {
            if (fu != tu) {
                connections.push_back(new NsConnection(this, fu, tu));
            }
        }
    }
}
/**
 * Given an exponential decay rate for some interval A, calculate the
 * equivalent rate for some other interval B.
 *
 * Exponential decay over one interval A is:
 *
 *   x(t+A) = (1 - rateA) * x(t)
 *
 * In "increasing exponential decay" it is the distance to some asymptote S
 * that shrinks by the factor (1 - rateA) each interval A:
 *
 *   S - x(t+A) = (1 - rateA) * (S - x(t))
 *
 * In both cases the surviving fraction over B is (1 - rateA)^(B/A), so
 *
 *   rateB = 1 - (1 - rateA)^(B/A)
 *
 * "Increasing exponential decay" is also known as "exponential decay
 * (increasing form)" or "exponential decay (rising form)".
 *
 * @param rateA     Fraction decayed per interval A (expected in [0, 1])
 * @param intervalA Length of interval A; must be non-zero
 * @param intervalB Length of interval B (same units as intervalA)
 * @return Fraction decayed per interval B
 */
static double calcExpDecayRate(double rateA, double intervalA, double intervalB)
{
    return 1.0 - pow(1.0 - rateA, intervalB/intervalA);
}
/**
 * Given the probability of an event happening during some interval A,
 * calculate the equivalent probability for some other interval B.
 *
 *   P(B) = 1 - (1 - P(A))^(B/A)
 *
 * i.e. the event fails to occur over B only if it fails to occur over each
 * of the B/A sub-intervals of length A.
 *
 * Note: this is exactly the same formula as calcExpDecayRate, which is not
 * surprising, since a constant probability of decay at the particle
 * level translates to a constant *rate* of decay at the population
 * level.
 *
 * @param probA     Probability of the event per interval A (expected in [0, 1])
 * @param intervalA Length of interval A; must be non-zero
 * @param intervalB Length of interval B (same units as intervalA)
 * @return Probability of the event per interval B
 */
static double calcProb(double probA, double intervalA, double intervalB)
{
    return 1.0 - pow(1.0 - probA, intervalB/intervalA);
}
/**
 * Rescale a constant (linear) rate from interval A to interval B.
 *
 * For a process that proceeds at a constant rate, the amount accrued over
 * B is simply B/A times the amount accrued over A.
 *
 * @param rateA     Amount per interval A
 * @param intervalA Length of interval A; must be non-zero
 * @param intervalB Length of interval B (same units as intervalA)
 * @return Amount per interval B
 */
static double calcConstantRate(double rateA, double intervalA, double intervalB)
{
    const double intervalRatio = intervalB / intervalA;
    return intervalRatio * rateA;
}
/**
* Calculates rates for the current timeStep value
*/
void NsTract::calcRates()
{
consLearnRate = calcExpDecayRate(consLearnRate01h, 1.0, timeStep);
psdDecayRate = calcExpDecayRate(psdDecayRate01h, 1.0, timeStep);
cpAmparRemovalRate = calcExpDecayRate(cpAmparRemovalRate01h, 1.0,
timeStep);
ciAmparInsertionRate = calcConstantRate(ciAmparInsertionRate01h, 1.0,
timeStep);
ciAmparRemovalRate = calcExpDecayRate(ciAmparRemovalRate01h, 1.0,
timeStep);
baseDepotProb = calcProb(baseDepotProb01h, 1.0, timeStep);
e3DecayRate = calcExpDecayRate(e3DecayRate01h, 1.0, timeStep);
maxPotProb = calcProb(maxPotProb01h, 1.0, timeStep);
calcDepotProb();
}
/**
 * Calculate the total per-timestep probability of depotentiation as the
 * combination of two independent probabilities: constitutive
 * depotentiation (baseDepotProb) and depotentiation due to the E3 enzyme
 * (e3DepotProb).
 *
 * The E3 contribution scales linearly with e3Level: its per-hour
 * probability is maxE3DepotProb01h * e3Level.
 *
 * This function is called whenever e3Level or timeStep changes. It caches
 * on those two values: if neither differs from the last computation,
 * nothing is recomputed.
 */
inline void NsTract::calcDepotProb()
{
    // Don't waste time if nothing changed
    //
    if (e3Level != lastE3Level || timeStep != lastTimeStep) {
        double e3DepotProb01h = maxE3DepotProb01h * e3Level;
        e3DepotProb = calcProb(e3DepotProb01h, 1.0, timeStep);

        // P(A or B) for independent A and B, computed as 1 - P(neither).
        // Algebraically equal to base + e3 - base * e3, but this form cannot
        // round above 1.0 when both inputs lie in [0, 1], so the check below
        // only fires on genuinely out-of-range inputs.
        depotProb = 1.0 - (1.0 - baseDepotProb) * (1.0 - e3DepotProb);
        ABORT_IF(depotProb > 1.0, "impossible");

        lastE3Level = e3Level;
        lastTimeStep = timeStep;
    }
}
/**
 * Stimulate every connection in the tract.
 * @param learnRate     Learning rate forwarded to each connection
 * @param numStimCycles Number of stimulation cycles
 * @param tag           Label forwarded to each connection
 */
void NsTract::stimulate(double learnRate, uint numStimCycles,
                        const char *tag)
{
    for (auto *conn : connections) {
        conn->stimulate(learnRate, numStimCycles, tag);
    }
}
/**
 * Acquisition: stimulate all connections at the tract's acquisition
 * learning rate (acqLearnRate, used as-is without timestep scaling).
 * @param numStimCycles Number of stimulation cycles
 * @param tag           Label forwarded to each connection
 */
void NsTract::acquire(uint numStimCycles, const char *tag)
{
    const double learnRate = acqLearnRate;
    this->stimulate(learnRate, numStimCycles, tag);
}
/**
 * Consolidation: stimulate all connections at the per-timestep
 * consolidation rate computed in calcRates(), tagged "cons".
 * @param numStimCycles Number of stimulation cycles
 */
void NsTract::consolidate(uint numStimCycles)
{
    const char *const consTag = "cons";
    this->stimulate(consLearnRate, numStimCycles, consTag);
}
/**
 * Apply one timestep of AMPAR trafficking to every connection, using the
 * per-timestep rates computed in calcRates().
 */
void NsTract::amparTrafficking()
{
    for (auto *conn : connections) {
        conn->amparTrafficking(cpAmparRemovalRate, ciAmparInsertionRate,
                               ciAmparRemovalRate);
    }
}
/**
 * Randomly depotentiate potentiated connections: each one independently
 * depotentiates with probability depotProb.
 */
void NsTract::depotentiateSome()
{
    for (auto *conn : connections) {
        // A random number is drawn only for potentiated connections, so
        // the RNG sequence depends on how many are potentiated.
        if (!conn->isPotentiated) {
            continue;
        }
        if (Util::randDouble(0.0, 1.0) < depotProb) {
            conn->depotentiate("random");
        }
    }
}
/**
 * Run one timestep of maintenance: random depotentiation, AMPAR
 * trafficking, E3 decay, and a refresh of the depotentiation probability.
 */
void NsTract::maintain()
{
    // These two use the depotProb / rates in effect before this step's
    // E3 decay.
    depotentiateSome();
    amparTrafficking();

    // One timestep of exponential E3 decay
    e3Level -= e3DecayRate * e3Level;

    // e3Level changed, so the depotentiation probability must follow
    //
    calcDepotProb();

    debugTrace("time: {} tract: {} e3Level: {} depotProb: {}\n",
               simTime, id, e3Level, depotProb);
}
/**
 * Switch PSI on or off for every connection in the tract.
 * @param state true to enable PSI, false to disable it
 */
void NsTract::togglePsi(bool state)
{
    for (auto *conn : connections)
        conn->togglePsi(state);
}
/**
 * Raise the E3 level to reactE3Level and run reactivation processing on
 * every connection in the Hebbian condition, i.e. whose from-unit and
 * to-unit are both active.
 */
void NsTract::reactivate()
{
    // E3 increases the probability of depotentiation, so refresh depotProb
    // right after setting it.
    // TODO: should this be restricted to connections originating from or
    // terminating on the units selected in makePattern, i.e. units
    // activated by reactivation?
    //
    e3Level = reactE3Level;
    calcDepotProb();

    for (auto *conn : connections) {
        const bool hebbian = conn->fromUnit->isActive && conn->toUnit->isActive;
        if (hebbian) {
            conn->reactivate();
        }
    }
}
/**
 * Count the tract's connections that are currently potentiated.
 * @return Number of potentiated connections
 */
uint NsTract::getNumPotentiated() const
{
    uint numPotentiated = 0;
    for (auto *conn : connections) {
        numPotentiated += conn->isPotentiated ? 1u : 0u;
    }
    return numPotentiated;
}
/**
 * Print the column header line for the numPotentiated printouts.
 *
 * The four header words line up with the rows emitted by
 * printNumPotentiated(): simulation time, the literal word "tract", the
 * tract id, and the number of potentiated connections.
 */
void NsTract::printNumPotentiatedHdr()
{
infoTrace("time tract id numPotentiated\n");
}
/**
* Print number of potentiated connections in the tract
*/
void NsTract::printNumPotentiated() const
{
infoTrace("{} tract {} {}\n", simTime, id, getNumPotentiated());
}
/**
 * Print a summary line (potentiated-connection count) followed by the
 * state of each connection in the tract.
 */
void NsTract::printState() const
{
    printNumPotentiated();

    for (auto *conn : connections)
        conn->printState();
}
/**
 * Build a multi-line, indented description of the tract: its id, its
 * acquisition and consolidation learning rates, and then each connection's
 * own description one indentation level deeper.
 *
 * @param iLvl Indentation level of the tract line
 * @param iStr String repeated iLvl times to form one indentation
 * @return The formatted description (no trailing newline)
 */
string NsTract::toStr(uint iLvl, const string &iStr) const
{
    const auto indent = Util::repeatStr(iStr, iLvl);
    const auto childIndent = Util::repeatStr(iStr, iLvl + 1);

    string out = fmt::format("{}NsTract[{}]: ", indent, id);
    out += fmt::format("\n{}acqLearnRate={}", childIndent, acqLearnRate);
    out += fmt::format("\n{}consLearnRate={}", childIndent, consLearnRate);

    for (auto *conn : connections) {
        out += '\n';
        out += conn->toStr(iLvl + 1, iStr);
    }
    return out;
}