Cortical model with reinforcement learning drives realistic virtual arm (Dura-Bernal et al 2015)

 Download zip file   Auto-launch 
Help downloading and running models
Accession:183014
We developed a 3-layer sensorimotor cortical network consisting of 704 spiking model neurons, including excitatory neurons, fast-spiking interneurons and low-threshold-spiking interneurons. Neurons were interconnected with AMPA/NMDA and GABAA synapses. We trained our model using spike-timing-dependent reinforcement learning to control a virtual musculoskeletal human arm, with realistic anatomical and biomechanical properties, to reach a target. The virtual arm position was used to simultaneously control a robot arm via a network interface.
References:
1 . Dura-Bernal S, Zhou X, Neymotin SA, Przekwas A, Francis JT, Lytton WW (2015) Cortical Spiking Network Interfaced with Virtual Musculoskeletal Arm and Robotic Arm. Front Neurorobot 9:13 [PubMed]
2 . Dura-Bernal S, Li K, Neymotin SA, Francis JT, Principe JC, Lytton WW (2016) Restoring Behavior via Inverse Neurocontroller in a Lesioned Cortical Spiking Model Driving a Virtual Arm. Front Neurosci 10:28 [PubMed]
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism:
Cell Type(s): Neocortex M1 L5B pyramidal pyramidal tract GLU cell; Neocortex M1 L2/6 pyramidal intratelencephalic GLU cell; Neocortex M1 interneuron basket PV GABA cell; Neocortex fast spiking (FS) interneuron; Neostriatum fast spiking interneuron; Neocortex spiking regular (RS) neuron; Neocortex spiking low threshold (LTS) neuron;
Channel(s):
Gap Junctions:
Receptor(s): GabaA; AMPA; NMDA;
Gene(s):
Transmitter(s): Gaba; Glutamate;
Simulation Environment: NEURON; Python (web link to model);
Model Concept(s): Synaptic Plasticity; Learning; Reinforcement Learning; STDP; Reward-modulated STDP; Sensory processing; Motor control; Touch;
Implementer(s): Neymotin, Sam [Samuel.Neymotin at nki.rfmh.org]; Dura, Salvador [ salvadordura at gmail.com];
Search NeuronDB for information about:  Neocortex M1 L2/6 pyramidal intratelencephalic GLU cell; Neocortex M1 L5B pyramidal pyramidal tract GLU cell; Neocortex M1 interneuron basket PV GABA cell; GabaA; AMPA; NMDA; Gaba; Glutamate;
#include "MuscleStatusEventHandler.h"

#include <Multibody/MultibodyDyna.h> 

#include <sys/socket.h>
#include <netinet/in.h>
#include <stdio.h>
#include <fcntl.h>      /* To change socket to nonblocking mode */
#include <arpa/inet.h>  /* For inet_pton() */

// Project macro invoked once in this file for MuscleStatusEventHandler.
// NOTE(review): presumably forces this translation unit (and the factory registration
// in namespace mf_mbd below) to be linked even when nothing references it directly -
// confirm against the macro definition.
FILE_STATIC_CALLHACK(MuscleStatusEventHandler);

namespace mf_mbd
{
	const int numMuscles = 18;
	double muscleLengths[numMuscles] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
	bool verboseSend2 = 0;
	
	InitFactoryStaticMembersMacro(MuscleStatusEventHandler, PeriodicEventHandler);

	//class MultibodySystem;
	MuscleStatusEventHandler::MuscleStatusEventHandler()
		:PeriodicEventHandler()
	{
	}

	MuscleStatusEventHandler::MuscleStatusEventHandler(MultibodySystem& system)
		:PeriodicEventHandler()
	{
		setMultibodySystem(system);
	}

	MuscleStatusEventHandler::MuscleStatusEventHandler(MultibodySystem& system, Real interval)
		:PeriodicEventHandler(interval)
	{
		setMultibodySystem(system);
	}

	MuscleStatusEventHandler::~MuscleStatusEventHandler()
	{
		if(_out.is_open()) _out.close();
	}

	bool MuscleStatusEventHandler::handle(Real currTime)
	{
		//Real currTime = getMultibodySystem()->getMultibodyDyna()->getCurrTime();
		if(!needHandle(currTime)) 
			return false;
		PeriodicEventHandler::_setHandledTime(currTime);

		std::vector<double> data;
		data.reserve(_muscles.size() * _varNames.size());
		//data.reserve(_muscles.size() * _varNames.size() + 1);		
		//data.push_back(currTime); //put the time to the first column;

		for(int m = 0; m < _muscles.size(); ++m) {
			LOAMuscle::Ref muscle = _muscles[m];

			for(int i = 0; i < _varNames.size(); ++i) {
				std::string name = _varNames[i];
				if(name == "excitation") {
					data.push_back(muscle->getExcitation());
				}
				else if(name == "activation") {
					data.push_back(muscle->getActivation());
				}
				else if(name == "activationDeriv") {
					data.push_back(muscle->getActivationDeriv());
				}
				else if(name == "force") { //total or tendon force
					data.push_back(muscle->getForce());
				}
				else if(name == "stress") {
					data.push_back(muscle->getStress());
				}
				else if(name == "speed") {
					data.push_back(muscle->getSpeed());
				}
				else if(name == "activeFiberForce") { 
					data.push_back(muscle->getActiveFiberForce());
				}
				else if(name == "passiveFiberForce") { 
					data.push_back(muscle->getPassiveFiberForce());
				}
				else if(name == "normalizedFiberLength") { 
					data.push_back(muscle->getNormalizedFiberLength());
				}
				else if(name == "RelativeMaxContraction") {
					data.push_back(muscle->getForce()/muscle->getMaxIsometricForce());
				}			
				else if(name == "fiberLengthDeriv") {
					data.push_back(muscle->getFiberLengthDeriv());
				}
				else if(name == "isometricFiberForce") {
					data.push_back(muscle->computeFiberIsometricForce(muscle->getActivation(),muscle->getFiberLength()));
				}			
				else if(name == "capacity") {
					data.push_back(muscle->getCapacity());
				}			
				//else if(name = "momentArms") {
				//	const std::map<ArticulatedJoint::Ref,VecN>& getMomentArms() {return _momArms;} 
				//	data.push_back(muscle->getMomentArms());
				//}
				else{
					CHK_ERR(false, "Can not find muscle variable with name " + name);
					continue;
				}
			}
			
			// store all muscle normalized lengths in a vector that can be sent via udp
			// stored in the following order: DELT1  DELT2 DELT3 Infraspinatus Latissimus_dorsi_1 Latissimus_dorsi_2 Latissimus_dorsi_3 Teres_minor PECM1 PECM2 PECM3 Coracobrachialis TRIlong TRIlat TRImed BIClong BICshort BRA 
			//muscleLengths[m] = muscle->getNormalizedFiberLength();
			//muscleLengths[m] = muscle->getOptimalFiberLength(); // 
			muscleLengths[m] = muscle->getFiberLength();
		}

		// send packets
		for(int m = 0; m < numMuscles; ++m) {
			std::cout << muscleLengths[m] << "  ";
		}
		std::cout << std::endl;
		
		if (verboseSend2) {
			printf("\nSent muscle lengths to stdout\n");
		}
		return true;
	}

	void MuscleStatusEventHandler::initBeforeRun()
	{
		EventHandler::initBeforeRun();
	}

	void MuscleStatusEventHandler::readFromXML(DOMNode* node)
	{
		XMLDOM::DOMElement* tmpNode = NULL;
		
		/*
		tmpNode = XMLDOM::getFirstChildElementByTagName(node,"PortSend2");
		CHK_ERR(tmpNode, "Can not find socket port");
		portSend2 =  XMLDOM::getValueAsType<int>(tmpNode);
		*/
		
		tmpNode = XMLDOM::getFirstChildElementByTagName(node,"Interval");
		double interval = XMLDOM::getValueAsType<double>(tmpNode);
		this->setEventInterval(interval);
	
		tmpNode = XMLDOM::getFirstChildElementByTagName(node,"LOAMuscleForceSubsystem");
		CHK_ERR(tmpNode, "Can not find LOAMuscleForceSubsystem node");
		std::string sysName = XMLDOM::getAttribute(tmpNode, "name");

		ForceSubsystem* fsys = getMultibodySystem()->getForceSubsystem(sysName);
		LOAMuscleForceSubsystem* mfsys = dynamic_cast<LOAMuscleForceSubsystem*>(fsys);
		std::string err_msg = "The ForceSubsystem " + sysName + " must be a LOAMuscleForceSubsystem";
		CHK_ERR(mfsys, err_msg);

		_msclSys = mfsys;

		bool needAll = false;
		tmpNode = XMLDOM::getFirstChildElementByTagName(node,"MuscleNames");
		CHK_ERR(tmpNode, "Can not find MuscleNames node");
		std::string all = XMLDOM::getAttribute(tmpNode,"all");
		if(!all.empty()) {
			boost::to_lower(all);
			if(all == "true") {
				needAll = true;
			}
		}

		if(!needAll) { //read all muscle names
			std::string str = XMLDOM::getTextAsStdStringAndTrim(tmpNode);
			mf_utils::Tokenizer tokens(str);
			for(int i = 0; i < tokens.size(); ++i) {
				std::string name = tokens[i];
				//muscle* msl = _msclSys->getMuscle(name);
				LOAMuscle* msl = dynamic_cast<LOAMuscle*>(_msclSys->getForceMatt(name));
				if(!msl) CHK_ERR(tmpNode, "Can not find muscle with name " + name);
				_muscles.push_back(msl);
			}
		}
		else { //all the muscles
			//_muscles = _msclSys->getMuscles();
			_muscles.resize(_msclSys->getNumForceMatts());
			for(int i = 0; i < _muscles.size(); ++i) _muscles[i] = _msclSys->getForceMattTrueType(i);
		}

		needAll = false;
		tmpNode = XMLDOM::getFirstChildElementByTagName(node,"MuscleVars");
		CHK_ERR(tmpNode, "Can not find MuscleVars node");
		all = XMLDOM::getAttribute(tmpNode,"all");
		if(!all.empty()) {
			boost::to_lower(all);
			if(all == "true") {
				needAll = true;
			}
		}

		if(!needAll) { //read all muscle names
			std::string str = XMLDOM::getTextAsStdStringAndTrim(tmpNode);
			mf_utils::Tokenizer tokens(str);
			for(int i = 0; i < tokens.size(); ++i) {
				std::string name = tokens[i];
				if(name == "excitation") {
				}
				else if(name == "activation") {
				}
				else if(name == "activationDeriv") {
				}
				else if(name == "force") { //total or tendon force
				}
				else if(name == "stress") {
				}
				else if(name == "speed") {
				}
				else if(name == "activeFiberForce") { 
				}
				else if(name == "passiveFiberForce") { 
				}
				else if(name == "normalizedFiberLength") { 
				}
				else if(name == "RelativeMaxContraction") { 
				}			
				else if(name == "fiberLengthDeriv") { 
				}
				else if(name == "isometricFiberForce") { 
				}
				else if(name == "capacity")
				{
				}
				//else if(name = "momentArms") { 
				//}
				else{
					continue;
				}

				_varNames.push_back(name);
			}
		}
		else { //all the vars
			_varNames.push_back("excitation");
			_varNames.push_back("activation");
			_varNames.push_back("activationDeriv");
			_varNames.push_back("force");
			_varNames.push_back("stress");
			_varNames.push_back("speed");
			_varNames.push_back("activeFiberForce");
			_varNames.push_back("passiveFiberForce");
			_varNames.push_back("normalizedFiberLength");
			_varNames.push_back("RelativeMaxContraction");
			_varNames.push_back("fiberLengthDeriv");
			_varNames.push_back("isometricFiberForce");
			_varNames.push_back("capacity");
		}

		tmpNode = XMLDOM::getFirstChildElementByTagName(node,"PntOutput");
		if(tmpNode) {
			_outFileName = XMLDOM::getAttribute(tmpNode,"name");
		}

		if(_outFileName.empty()) {
			_outFileName = _msclSys->getName() + "_status.pnt";
		}

		_out.open(_outFileName.c_str());
		if(!_out.is_open()) CL_ERR("Can not open file " + _outFileName);

	}

} //end namespace 

Loading data, please wait...