#include "../include/specie.h"

void species::set_params(string m_name, vector<string> m_traitsId, vector< int > m_ageClasses,
               int m_nLoci, vector<int> m_nAlleles, vector<int> m_ploidy, vector<int> m_chromoSplit,
               vector<trait> m_traits,
               vector<vector <int>> m_inherit, vector<double> m_crossRate,
               vector<double> m_mutationRate,  vector<list<fluxMat>> m_fluxMat, vector<double> m_selectMatrix,
               int m_cap, double m_growth, double m_selfing, int m_maxAge, string m_assortMating)
{
    name = m_name;
    traitsId = m_traitsId;
    ageClasses = m_ageClasses;
    nLoci = m_nLoci;
    nAlleles = m_nAlleles;
    ploidy = m_ploidy;
    //cout << "GP" << gammaParam << endl;
    chromoSplit = m_chromoSplit;

    traits = m_traits;
    //inherit = m_inherit;
    crossRate = m_crossRate;
    mutationRate = m_mutationRate;
    initVar = -1;
    flux = m_fluxMat;
    //fluxTimes =
    vector<string> propagOrder = {"seed", "pollen"};
    for (int s=0;s<propagOrder.size();s++) // female then male
    {
        fluxMat matTimeOne;
        int pf_pos = 0;
        //vector<int> timesForPropagule;
        for (auto pflux = flux[s].begin(); pflux != flux[s].end(); ++pflux)
        {
            //cout << "TIMES" << pflux->get_time() << endl;
            int t = pflux->get_time();
            if (t == 1)
            {
                matTimeOne = *pflux;
            }
            pf_pos++;
            //timesForPropagule.push_back(t);
        }
        //fluxTime.push_back(timesForPropagule);
        currentTimeFlux.push_back(matTimeOne);
    }



    selectMatrix = m_selectMatrix;
    cap = m_cap;
    growth = m_growth;
    selfing =  m_selfing;
    maxAge = m_maxAge;
    assortMating = m_assortMating;
    assortSuffix =  "-phenoWithE";

    assortThres = 0.0;

    int sumLocPos = 0;
    for(int i = 0; i < nLoci; i++)
    {
        locIndex.push_back(sumLocPos);
        for(int j = sumLocPos; j < sumLocPos+ploidy[i]; j++)
            revLocIndex.push_back(i);
        sumLocPos += ploidy[i];
    }

    //Here we set a mask to select mother an father traits.
    //Once crossings are over, there is two possibilities of gametes (two halves)
    //if(Random() < 0.5)
    mask_male = valarray<bool>((bool)0, sumLocPos);
    mask_female = valarray<bool>((bool)0, sumLocPos);


    for(int loc = 0; loc < nLoci; loc++) //For each locus
    {
        for(int j = 0; j < ploidy[loc]; j++) //For each copy of the locus
        {
            if (loc < m_inherit[0].size())
            {
                if(j < m_inherit[0][loc])
                {
                    mask_female[locIndex[loc] + j] = 1;
                }
                else
                {
                    mask_male[locIndex[loc] + j] = 1;
                }
            }
            else
            {
                cout << "BUG:loc >= m_inherit[0].size()" << endl;
            }
        }
    }
    //sumLocPos of both valarry should be 1 on all spots

    gSize = sumLocPos; // Also equal to revLocIndex.size()
    //Link trait and species (Could find a better design)
}

species::species(string m_name, vector<string> m_traitsId, vector< int > m_ageClasses,
               int m_nLoci, vector<int> m_nAlleles, vector<int> m_ploidy, vector<int> m_chromoSplit,
               vector<trait> m_traits,
               vector<vector <int>> m_inherit, vector<double> m_crossRate,
               vector<double> m_mutationRate,  vector<list<fluxMat>> m_fluxMat, vector<double> m_selectMatrix,
               int m_cap, double m_growth, double m_selfing, int m_maxAge, string m_assortMating)
{
    // All initialisation is delegated to set_params.
    // NOTE(review): this overload does not assign `rho`. Unless the class
    // declaration gives it a default, update_assort_threshold would read an
    // indeterminate value. Confirm against the header.
    this->set_params(m_name, m_traitsId, m_ageClasses,
                     m_nLoci, m_nAlleles, m_ploidy, m_chromoSplit,
                     m_traits,
                     m_inherit, m_crossRate, m_mutationRate,
                     m_fluxMat, m_selectMatrix,
                     m_cap, m_growth, m_selfing, m_maxAge,
                     m_assortMating);
}

species::species(string m_name, vector<string> m_traitsId, vector< int > m_ageClasses,
               int m_nLoci, vector<int> m_nAlleles, vector<int> m_ploidy, vector<int> m_chromoSplit,
               vector<trait> m_traits,
               vector<vector <int>> m_inherit, vector<double> m_crossRate,
               vector<double> m_mutationRate,  vector<list<fluxMat>> m_fluxMat, vector<double> m_selectMatrix,
               int m_cap, double m_growth, double m_selfing, int m_maxAge, string m_assortMating,
               double m_rho)
{
    // Same as the overload without rho, but also stores rho. update_assort_threshold
    // uses rho to derive the threshold sd as sqrt(var / rho^2 - var).
    //
    // rho must be strictly positive. Written as !(m_rho > 0.0) so that NaN is
    // rejected too, which `m_rho <= 0.0` let through. Values above 1 make the
    // radicand above negative.
    if (!(m_rho > 0.0))
    {
        cout << "Rho must be strictly positive" << endl;
        exit(EXIT_FAILURE); // was exit(0): a fatal parameter error must not report success
    }

    set_params(m_name, m_traitsId, m_ageClasses, m_nLoci, m_nAlleles, m_ploidy, m_chromoSplit,
               m_traits, m_inherit, m_crossRate, m_mutationRate,  m_fluxMat, m_selectMatrix,
               m_cap, m_growth, m_selfing, m_maxAge, m_assortMating);

    rho = m_rho;
}


species::~species()
{
    // The species owns every group_index stored in `index`; make_gIndex
    // allocates them with `new`. Iterate by reference: `for (auto i : index)`
    // copied every key string for nothing.
    // NOTE(review): because of this ownership, copying a species would
    // double-delete these pointers. Confirm that copy/assignment are disabled
    // in the class declaration.
    for (auto & entry : index)
    {
        delete entry.second;
    }
}



trait * species::get_trait(unsigned int index)
{
    // Non-owning pointer to the index-th trait, or NULL when index is out of range.
    return (index < traits.size()) ? &traits[index] : NULL;
}


vector<trait *> species::get_traits()
{
    vector<trait *> rv;
    rv.reserve(traits.size());
    for(unsigned int i = 0; i < traits.size(); i++)
        rv.push_back(&(traits[i]));

    return rv;
}


void species::adapt_fluxMat(int t)
{
    for (int p=0;p<currentTimeFlux.size();p++)
    {
        for (auto pflux = flux[p].begin(); pflux != flux[p].end(); ++pflux)
        {
            //cout << "TIMES" << pflux->get_time() << endl;
            if ((pflux->get_time() == t) && (t != 1))
            {
                cout << t << ": FLUX CHANGED FOR " << p << endl;
                currentTimeFlux[p] = *pflux;

                // test new flux
                /*vector<double> CTF = currentTimeFlux[p].get_donnors(23);
                cout << CTF[22] << " " << CTF[23] << " " << CTF[24] << endl;*/
            }

        }
    }
}


/* Basically an alias for flux_matrix::get_donnors
 */
vector<double> species::get_flux_mat(int cellNumber, int propagule)
{
    // Donor values for `cellNumber` from the flux matrix currently active for
    // `propagule` (0 = seed, 1 = pollen, following the order used in set_params).
    auto & activeFlux = currentTimeFlux[propagule];
    return activeFlux.get_donnors(cellNumber);
}


bool species::is_assort_strategy(string askedStrategy)
{
    // Only one assortative-mating strategy is currently recognised.
    // TODO EXTERNALIZE STRAT
    return askedStrategy.compare("expand_tolerance") == 0;
}


void species::clear_unmatchable()
{
       unmatchable.clear();
}

void species::add_unmatchable(indId u)
{
    // Flag id `u` as unmatchable. Adding an id already present has no effect.
    this->unmatchable.insert(u);
}

set<indId> species::get_unmatchable()
{
       return unmatchable;
}


/* Basically alias for flux_matrix::draw
 */
int species::select_weighted_pop(int index, int propagule, vector<double> &weights)
{
    // Delegate to fluxMat::draw on the flux matrix currently active for
    // `propagule`, for cell `index`. `weights` is forwarded unchanged; its
    // exact role is defined by fluxMat::draw, which is not visible here.
    auto & activeFlux = currentTimeFlux[propagule];
    return activeFlux.draw(index, weights);
}

void species::checkChromoStruct(bool isDisplay = false)
{
    int beginD = 5;
    string beginS(5, ' ');
    vector<trait *> traits = get_traits();
    vector<string> chromoStruct;

    bool isLtypeOverlay = false;

    chromoStruct.push_back(beginS);//beginD+nLoci*2, ' '
    chromoStruct.push_back(beginS);
    for (int lo=0; lo < nLoci; lo++)
    {
        if (lo%5 == 0 and lo < 100)
        {
            string tN = to_string(lo);
            string tS = string(2-tN.length(), ' ');
            chromoStruct[0] += tN + tS;
        } else {
            chromoStruct[0] += "  ";
        }
        chromoStruct[1] += ". ";
    }

    for (int t=0; t < traits.size(); t++)
    {
        string nT = traits[t]->get_name();
        string beginT(chromoStruct[1].length() - nT.length(), ' ');
        chromoStruct.push_back(nT+beginT);//beginD+nLoci*2, ' '

        // how to find traits loci ?
        int nlt = traits[t]->get_nLtype();
        //ltype typesIndex
        //vector<char> nltypechar = {'a', 'b', 'c', '?'};
        for (int n=0;n < nlt;n++)
        {
            vector<unsigned int> loci = traits[t]->get_ltype(n)->get_locList();
            for (int nl=0;nl < loci.size();nl++)
            {
                if (chromoStruct[t+2][beginD+loci[nl]*2] == ' ')
                {
                    chromoStruct[t+2][beginD+loci[nl]*2] = traits[t]->get_ltype(n)->get_id().at(0);
                } else {
                    chromoStruct[t+2][beginD+loci[nl]*2] = '*';
                    isLtypeOverlay = true;
                }
            }
        }
    }
    for (int c=0; c < chromoSplit.size(); c++) {
        for (int l=1; l < chromoStruct.size(); l++) {
            int pos = beginD+2*chromoSplit[c]-1;
            //cout << "pos" << pos << endl;
            if (pos < chromoStruct[l].length()) {
                chromoStruct[l][pos] = '|';
            }
        }
    }

    if (isDisplay)
    {
        for (int l=0; l < chromoStruct.size(); l++) {
            cout << chromoStruct[l] << endl;
        }
        cout << "___" << endl;
    }

    if (isLtypeOverlay)
    {
        cout << "Overlay of ltypes not possible. Redefine the loci list with one locus in only one ltype." << endl;
        exit(0);
    }
}


//OPTIMIZE HERE ! + revamp
valarray<int> species::spawn_gamete(individual * pInd, int sex)
{
    // Build a gamete from pInd's genome.
    //  Part I: walk the loci in genome order, carrying a strand permutation
    //    (vectSwitch) from one locus to the next. At each locus, every pair of
    //    strands is swapped with probability crossRate[loc] (crossing over).
    //    The copies of the locus are then reordered so that copy i is taken
    //    from strand vectSwitch[i].
    //  Part II: zero every position not transmitted by this sex (mask_female
    //    for sex 0, mask_male for sex 1).
    // The result holds 0 at each non-transmitted position, so a maternal and a
    // paternal gamete can simply be summed.
    //
    // NOTE(review): vectSwitch starts as the identity, so unless crossRate[0]
    // randomises it, the first strand in genome order is favoured at the
    // first locus. The original author raised this bias as an open question.
    if (pInd == NULL) { cout << "WARNING pInd == NULL" << endl; exit(EXIT_FAILURE); }

    valarray<int> gamete = pInd->get_genome_va();

    vector<int> vectSwitch = {0};
    vector<int> strands; // scratch copy of one locus, reused across loci

    for (unsigned int loc = 0; loc < locIndex.size(); loc++)
    {
        const unsigned int plo = ploidy[loc];

        // STEP 1: resize vectSwitch to this locus' ploidy, keeping it a
        // permutation of 0..plo-1.
        while (vectSwitch.size() < plo)
        {
            vectSwitch.push_back(vectSwitch.size());
        }
        while (vectSwitch.size() > plo)
        {
            // Remove the highest strand id by moving the last entry into its
            // slot. This is O(1), unlike shifting every following entry.
            unsigned int vmax = 0;
            for (unsigned int i = 0; i < vectSwitch.size(); i++)
                if (vectSwitch[i] > vectSwitch[vmax]) vmax = i;
            vectSwitch[vmax] = vectSwitch.back();
            vectSwitch.pop_back();
        }

        // STEP 2: crossing over. Each pair of strands is exchanged with probability coRate.
        const double coRate = crossRate[loc];
        for (unsigned int i = 0; i < vectSwitch.size(); i++)
        {
            for (unsigned int j = i + 1; j < vectSwitch.size(); j++)
            {
                if (Random() < coRate)
                {
                    swap(vectSwitch[i], vectSwitch[j]);
                }
            }
        }

        // STEP 3: copy i of the locus receives strand vectSwitch[i].
        // Work from a copy of the locus. The former in-place sequence
        // swap(g[i], g[p[i]]) does not apply a permutation: for a diploid with
        // p = {1,0} the second swap undoes the first, so crossing over never
        // reached the gamete.
        const unsigned int start = locIndex[loc];
        strands.assign(plo, 0);
        for (unsigned int i = 0; i < plo; i++)
            strands[i] = gamete[start + i];
        for (unsigned int i = 0; i < plo; i++)
            gamete[start + i] = strands[vectSwitch[i]];
    }

    // Part II: erase every position that this sex does not transmit.
    if (sex == 0) gamete[!mask_female] = 0;
    if (sex == 1) gamete[!mask_male] = 0;

    return gamete;
}


bool species::index_exist(string key)
{
    // True when a group_index is registered under `key`.
    return index.count(key) != 0;
}


group_index * species::make_gIndex(landscape &myMap, string key, int m_nClasses)
{
    // Create, or re-create, the group_index registered under `key` and return
    // it. The species owns it; ~species deletes it. m_nClasses is currently unused.
    auto previous = index.find(key);
    if (previous != index.end())
    {
        delete previous->second;
        // Drop the entry now. If the constructor below throws, the map must
        // not keep a pointer to the deleted object, which the destructor
        // would otherwise delete a second time.
        index.erase(previous);
    }
    group_index * created = new group_index(myMap, this, key);
    index[key] = created;
    return created;
}



group_index * species::get_index(string key)
{
    // Registered group_index for `key`. If there is none, print a warning and return NULL.
    auto found = index.find(key);
    if (found != index.end())
        return found->second;

    cout << "WARNING, index key not found: " << key << endl;
    return NULL;
}





void species::calc_demo(landscape &myMap)
{
    // Refresh the census: demo[c] = number of individuals of this species in
    // cell c. The vector is (re)allocated only when the map size changed.
    const auto nCells = myMap.max_size();
    if (demo.size() != nCells)
        demo = vector<double>(nCells, 0.0);

    for (unsigned int cell = 0; cell < nCells; ++cell)
        demo[cell] = myMap.get_cell(cell)->get_pop(this)->get_nInd();
}



vector<string> species::get_origins(bool getInternal, int cellRef, int largMap)
{
  // Report the migrations recorded in allMigrations. Each record, as built by
  // addMigration, is [origin cell, destination cell, trait phenotypes...].
  //
  // merge == true (the only mode currently used):
  //   - keep only migrants whose destination row (cell / largMap) equals that of cellRef;
  //   - group them by origin row;
  //   - emit one line per origin row with the mean and variance of the FIRST
  //     trait, followed by each individual value.
  // merge == false: one line per migration, with all trait values.
  //
  // Records with origin == destination are skipped unless getInternal is true.
  // When allMigrations is non-empty, the first line is a header.
  const bool merge = true;
  vector<string> migrations;

  // Rows are obtained by dividing cell numbers by the map width. A zero
  // width used to crash with an integer division by zero.
  if (largMap == 0)
  {
    cout << "[WARNING] get_origins: map width is 0" << endl;
    return migrations;
  }
  const int larg = largMap;
  const int altDest = cellRef/larg;
  // origin latitude -> first-trait phenotypes; std::map keeps output sorted by latitude
  map<int, vector<double>> meanTraitByAlt; // CAUTION: MERGE ONLY WORKS FOR FIRST TRAIT

  if (!allMigrations.empty())
  {
    string title;
    if (merge)
    {
      title = "Coming into latitude "+to_string(altDest)+"\nOrigin latitude\tmean(T1) [var(T1)]\tT1(ind)";
    }
    else
    {
      title = "origin\tdestination";
      // Compared as size == traits + 2 so that a short record cannot underflow.
      if (allMigrations[0].size() == traits.size() + 2)
      {
        for (size_t t = 0; t < traits.size(); t++)
        {
          title += "\t"+traits[t].get_name();
        }
      }
      else
      {
        cout << "[BUG] trait number (" << traits.size() << ") different from origin values (" << allMigrations[0].size()-2 << ")";
      }
    }
    migrations.push_back(title);
  }

  for (const auto & mig : allMigrations)
  {
    // Skip internal moves (origin == destination) unless they are requested.
    if (mig[0] == mig[1] && !getInternal) continue;

    if (merge)
    {
      // A record without any trait value has nothing to merge. It used to be
      // read out of range at index 2.
      if (mig.size() < 3) continue;
      if (static_cast<int>(mig[1]/larg) == altDest)
      {
        const int currentOriginLat = static_cast<int>(mig[0]/larg);
        meanTraitByAlt[currentOriginLat].push_back(mig[2]);
      }
    }
    else
    {
      string line = to_string(static_cast<int>(mig[0])) + "\t" + to_string(static_cast<int>(mig[1]));
      for (size_t t = 2; t < mig.size(); t++)
      {
        line += "\t" + to_string(mig[t]);
      }
      migrations.push_back(line);
    }
  }

  if (merge)
  {
    for (auto & latEntry : meanTraitByAlt)
    {
      // Each group holds at least one value, so the division below is safe.
      vector<double> & phenos = latEntry.second;
      const double meanPheno = accumulate(phenos.begin(), phenos.end(), 0.0) / phenos.size();
      const double varPheno = calc_var(phenos);
      string line = to_string(latEntry.first);
      line += "\t" + to_string(meanPheno) + " [" + to_string(varPheno) + "]";
      for (double p : phenos)
      {
        line += "\t" + to_string(p);
      }
      migrations.push_back(line);
    }
  }
  return migrations;
}

void species::purgeMigrations()
{
  allMigrations.clear();
}

void species::addMigration(population* parentPop, unsigned int dest, indId parent)
{
    // Append one migration record to allMigrations:
    //   [origin cell (parentPop->get_abs_pos()), dest, then the phenotype of
    //    `parent` for each trait, in trait order].
    // Phenotypes are read from the parent population's "<trait name>-pheno" index.
    vector<double> record;
    record.reserve(2 + traits.size());
    record.push_back(parentPop->get_abs_pos());
    record.push_back(dest);
    for (trait & tr : traits)
    {
        const string phenoKey = tr.get_name() + "-pheno";
        record.push_back(static_cast<double>(parentPop->get_index(phenoKey)->get_val(parent)));
    }
    allMigrations.push_back(record);
}

//Placeholder
void species::get_selTraits(vector<trait*>& selTraits)
{
    // Placeholder: every trait is currently treated as under selection.
    // Replaces the contents of selTraits with pointers into `traits`. The
    // pointers are invalidated if `traits` is resized.
    selTraits.clear();
    selTraits.reserve(traits.size()); // was reserve(selTraits.size()), i.e. reserve(0) right after clear()
    for (trait & t : traits) selTraits.push_back(&t);
}


// Debug function
bool species::check_integrity(individual * pInd)
{
    // Debug check: every genome position must hold a value below the allele
    // count of the locus it belongs to. On the first violation, print it and
    // return false. Return true if none is found.
    for (int pos = 0; pos < gSize; ++pos)
    {
        const unsigned int loc = revLocIndex[pos];
        const auto allele = pInd->gen_val(pos);
        if (allele >= nAlleles[loc])
        {
            cout << "Locus " << loc << " found " << allele << " on  max of " << nAlleles[loc] << endl;
            cout << "Species = " << name << endl;
            return false;
        }
    }
    return true;
}




vector<double> species::get_all_values(string indexedValue, landscape &myMap)
{
    // Step 1 : gather all values for this trait
    vector<double> values;
    double totDemo = accumulate(begin(demo),end(demo), 0.0);
    values.reserve(totDemo);

    // Iterate on map*
    vector<int> populatedPops = myMap.getPopulatedPops(this);
    for(unsigned int i = 0; i < populatedPops.size(); i++)
    {
        //Get pop pointer
        population * pPop = myMap.get_cell(populatedPops[i])->get_pop(this);
        //Get inex pointers
        individual_index * pIndex = pPop->get_index_safe(indexedValue);
        vector<double> temp = pIndex->get_values();
        values.insert(values.end(), temp.begin(), temp.end());
    }
    return values;
}

void species::update_assort_threshold(landscape &myMap)
{
    // Draw a new assortative-mating threshold, a phenotypic distance:
    //   assortThres = |N(0,1) * sd|,  sd = sqrt(var / rho^2 - var)
    // `var` is the variance of the "<assortMating>-pheno" values. It is
    // measured once, on the first call (initVar == -1 is the "not yet
    // computed" sentinel), and reused on every later call.
    if (initVar == -1)
    {
        vector<double> values = get_all_values(assortMating+"-pheno", myMap);
        initVar = calc_var(values);
    }

    const double varTrait = initVar;
    // For rho > 1 the radicand is negative. sqrt would then return NaN and
    // poison assortThres, so clamp it to 0 (the threshold obtained at rho == 1).
    double thresVar = (varTrait / (rho*rho)) - varTrait;
    if (thresVar < 0.0)
    {
        cout << "[WARNING] rho > 1 gives a negative threshold variance; using 0" << endl;
        thresVar = 0.0;
    }
    const double sdassortThres = sqrt(thresVar);
    // abs value: the threshold is a distance
    assortThres = fabs(gasdev() * sdassortThres);
    cout << "New assort threshold = " << assortThres << " from dist of sd " << sdassortThres << endl;
}

void species::update_SelInt_matrix(landscape &myMap)
{
    // Build one copy of selectMatrix per cell. In each copy, entry
    // nbTraits*t + t is replaced by trait t's cell-specific selection
    // intensity, traits[t].get_SelInt(cell).
    // NOTE(review): this assumes selectMatrix is a row-major
    // nbTraits x nbTraits matrix, so those entries are its diagonal. Confirm.
    const int nbTraits = traits.size();
    vector<vector<double>> perCell(myMap.max_size());
    for (unsigned int cell = 0; cell < myMap.max_size(); cell++)
    {
        vector<double> & row = perCell[cell];
        row.assign(selectMatrix.begin(), selectMatrix.end()); // shared, time-independent values
        for (int t = 0; t < nbTraits; t++)
            row[nbTraits*t+t] = traits[t].get_SelInt(cell); // per-cell override
    }
    selectMatrixByPatch = perCell;
}
