// @(#)root/roostats:$Id$
// Authors: Kevin Belasco        17/06/2009
// Authors: Kyle Cranmer         17/06/2009
/*************************************************************************
 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers.               *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/** \class RooStats::MCMCInterval
    \ingroup Roostats

   MCMCInterval is a concrete implementation of the RooStats::ConfInterval
   interface.  It takes as input a Markov Chain of data points in the parameter
   space generated by Monte Carlo using the Metropolis algorithm.  From the Markov
   Chain, the confidence interval can be determined in two ways:

#### Using a Kernel-Estimated PDF: (not the default method)

   A RooNDKeysPdf is constructed from the data set using adaptive kernel width.
   With this RooNDKeysPdf F, we then integrate over the most likely domain in the
   parameter space (tallest points in the posterior RooNDKeysPdf) until the target
   confidence level is reached within an acceptable neighborhood as defined by
   SetEpsilon(). More specifically: we calculate the following for different
   cutoff values C until we reach the target confidence level: \f$\int_{ F >= C } F
   d{normset} \f$.
   Important note: this is not the default method because of a bug in constructing
   the RooNDKeysPdf from a weighted data set.  Configure to use this method by
   calling SetUseKeys(true), and the data set will be interpreted without weights.


#### Using a binned data set: (the default method)

   This is the binned analog of the continuous integrative method that uses the
   kernel-estimated PDF.  The points in the Markov Chain are put into a binned
   data set and the interval is then calculated by adding the heights of the bins
   in decreasing order until the desired level of confidence has been reached.
   Note that this means the actual confidence level is >= the confidence level
   prescribed by the client (unless the user calls SetHistStrict(false)).  This
   method is the default but may not remain as such in future releases, so you may
   wish to explicitly configure to use this method by calling SetUseKeys(false)


   These are not the only ways for the confidence interval to be determined, and
   other possibilities are being considered for addition, especially for the
   1-dimensional case.


   One can ask an MCMCInterval for the lower and upper limits on a specific
   parameter of interest in the interval.  Note that this works better for some
   distributions (ones with exactly one local maximum) than others, and sometimes
   has little value.
*/


#include "Rtypes.h"

#include "TMath.h"

#include "RooStats/MCMCInterval.h"
#include "RooStats/MarkovChain.h"
#include "RooStats/Heaviside.h"
#include "RooDataHist.h"
#include "RooNDKeysPdf.h"
#include "RooProduct.h"
#include "RooStats/RooStatsUtils.h"
#include "RooRealVar.h"
#include "RooArgList.h"
#include "TH1.h"
#include "TH1F.h"
#include "TH2F.h"
#include "TH3F.h"
#include "RooMsgService.h"
#include "RooGlobalFunc.h"
#include "TObject.h"
#include "THnSparse.h"
#include "RooNumber.h"

#include <cstdlib>
#include <string>
#include <algorithm>


using namespace RooFit;
using namespace RooStats;
using std::endl;

////////////////////////////////////////////////////////////////////////////////

/// Construct a named interval with no Markov chain and no parameters attached.
/// The tail-fraction limits start at -infinity / +infinity; the
/// LowerLimitTailFraction() / UpperLimitTailFraction() accessors compare
/// against exactly these values to decide whether the interval still has to
/// be determined.
MCMCInterval::MCMCInterval(const char *name)
   : ConfInterval(name), fTFLower(-RooNumber::infinity()), fTFUpper(RooNumber::infinity())
{
   // start with an empty index vector for the tail-fraction method
   fVector.clear();
}

////////////////////////////////////////////////////////////////////////////////

/// Construct a named interval from a set of parameters and a Markov chain.
/// \param name        name forwarded to the ConfInterval base class
/// \param parameters  parameters of interest; forwarded to SetParameters()
/// \param chain       Markov chain whose address is stored in fChain
///
/// NOTE(review): fChain is used with .get() elsewhere in this file, so it
/// looks like an owning smart pointer -- confirm that callers expect this
/// object to take ownership of `chain`.
MCMCInterval::MCMCInterval(const char *name, const RooArgSet &parameters, MarkovChain &chain)
   : ConfInterval(name), fChain(&chain), fTFLower(-RooNumber::infinity()), fTFUpper(RooNumber::infinity())
{
   // start with an empty index vector for the tail-fraction method
   fVector.clear();
   SetParameters(parameters);
}

////////////////////////////////////////////////////////////////////////////////

/// Defaulted destructor: no explicit cleanup is performed here.
MCMCInterval::~MCMCInterval() = default;

/// Ordering functor for RooDataHist bin indices: a bin sorts before another
/// when its weight is strictly smaller.
struct CompareDataHistBins {
   CompareDataHistBins(RooDataHist* hist) : fDataHist(hist) {}
   bool operator() (Int_t lhsBin, Int_t rhsBin) {
      // get() selects the bin that the following weight() call reads, so the
      // first weight has to be saved before the second bin is selected
      fDataHist->get(lhsBin);
      const double lhsWeight = fDataHist->weight();
      fDataHist->get(rhsBin);
      const double rhsWeight = fDataHist->weight();
      return lhsWeight < rhsWeight;
   }
   RooDataHist* fDataHist; ///< histogram whose bin weights define the order
};

/// Ordering functor for THnSparse bin indices: a bin sorts before another
/// when its content is strictly smaller.
struct CompareSparseHistBins {
   CompareSparseHistBins(THnSparse* hist) : fSparseHist(hist) {}
   bool operator() (Long_t lhsBin, Long_t rhsBin) {
      const double lhsContent = fSparseHist->GetBinContent(lhsBin);
      const double rhsContent = fSparseHist->GetBinContent(rhsBin);
      return lhsContent < rhsContent;
   }
   THnSparse* fSparseHist; ///< histogram whose bin contents define the order
};

/// Ordering functor for Markov-chain entry indices: entry i sorts before
/// entry j when the value of fParam in entry i is strictly smaller.
struct CompareVectorIndices {
   CompareVectorIndices(MarkovChain* chain, RooRealVar* param) :
      fChain(chain), fParam(param) {}
   bool operator() (Int_t i, Int_t j) {
      // Read the first value before fetching the second entry.
      // NOTE(review): Get() may hand back the same RooArgSet for every call --
      // keeping the reads sequential is safe either way.
      const double valueAtI = fChain->Get(i)->getRealValue(fParam->GetName());
      const double valueAtJ = fChain->Get(j)->getRealValue(fParam->GetName());
      return valueAtI < valueAtJ;
   }
   MarkovChain* fChain; ///< chain the indices refer to
   RooRealVar* fParam;  ///< parameter whose value defines the order
};

////////////////////////////////////////////////////////////////////////////////
/// kbelasco: for this method, consider running DetermineInterval() if
/// fKeysPdf==nullptr, fSparseHist==nullptr, fDataHist==nullptr, or fVector.empty()
/// rather than just returning false.  Though this should not be an issue
/// because nobody should be able to get an MCMCInterval that has their interval
/// or posterior representation nullptr/empty since they should only get this
/// through the MCMCCalculator

bool MCMCInterval::IsInInterval(const RooArgSet& point) const
{
   if (fIntervalType == kShortest) {
      if (fUseKeys) {
         if (fKeysPdf == nullptr)
            return false;

         // evaluate keyspdf at point and return whether >= cutoff
         RooStats::SetParameters(&point, const_cast<RooArgSet *>(&fParameters));
         return fKeysPdf->getVal(&fParameters) >= fKeysCutoff;
      } else {
         if (fUseSparseHist) {
            if (fSparseHist == nullptr)
               return false;

            // evaluate sparse hist at bin where point lies and return
            // whether >= cutoff
            RooStats::SetParameters(&point,
                                    const_cast<RooArgSet*>(&fParameters));
            Long_t bin;
            // kbelasco: consider making x static
            std::vector<double> x(fDimension);
            for (Int_t i = 0; i < fDimension; i++)
               x[i] = fAxes[i]->getVal();
            bin = fSparseHist->GetBin(x.data(), false);
            double weight = fSparseHist->GetBinContent((Long64_t)bin);
            return (weight >= (double)fHistCutoff);
         } else {
            if (fDataHist == nullptr)
               return false;

            // evaluate data hist at bin where point lies and return whether
            // >= cutoff
            Int_t bin;
            bin = fDataHist->getIndex(point);
            fDataHist->get(bin);
            return (fDataHist->weight() >= (double)fHistCutoff);
         }
      }
   } else if (fIntervalType == kTailFraction) {
      if (fVector.empty())
         return false;

      // return whether value of point is within the range
      double x = point.getRealValue(fAxes[0]->GetName());
      if (fTFLower <= x && x <= fTFUpper)
         return true;

      return false;
   }

   coutE(InputArguments) << "Error in MCMCInterval::IsInInterval: "
      << "Interval type not set.  Returning false." << std::endl;
   return false;
}

////////////////////////////////////////////////////////////////////////////////

/// Store the requested confidence level and immediately recompute the
/// interval for it through DetermineInterval().
/// \param cl  requested confidence level
void MCMCInterval::SetConfidenceLevel(double cl)
{
   fConfidenceLevel = cl;
   DetermineInterval();
}

// kbelasco: update this or just take it out
// kbelasco: consider keeping this around but changing the implementation
// to set the number of bins for each RooRealVar and then recreating the
// histograms
//void MCMCInterval::SetNumBins(Int_t numBins)
//{
//   if (numBins > 0) {
//      fPreferredNumBins = numBins;
//      for (Int_t d = 0; d < fDimension; d++)
//         fNumBins[d] = numBins;
//   }
//   else {
//      coutE(Eval) << "* Error in MCMCInterval::SetNumBins: " <<
//                     "Negative number of bins given: " << numBins << std::endl;
//      return;
//   }
//
//   // If the histogram already exists, recreate it with the new bin numbers
//   if (fHist != nullptr)
//      CreateHist();
//}

////////////////////////////////////////////////////////////////////////////////

/// Replace the axis variables with the entries of `axes`, in list order.
///
/// The list must contain exactly fDimension entries and every entry must be a
/// RooRealVar.  If either condition fails an error is logged and the current
/// axes are left untouched.
/// \param axes  list of RooRealVar objects, one per dimension
void MCMCInterval::SetAxes(RooArgList& axes)
{
   const Int_t size = axes.size();
   if (size != fDimension) {
      coutE(InputArguments) << "* Error in MCMCInterval::SetAxes: " <<
                               "number of variables in axes (" << size <<
                               ") doesn't match number of parameters ("
                               << fDimension << ")" << std::endl;
      return;
   }

   // Validate every entry before touching fAxes: an unchecked static_cast of
   // something that is not a RooRealVar would be undefined behaviour, and a
   // partially applied list would leave the axes inconsistent.
   std::vector<RooRealVar*> checked(size, nullptr);
   for (Int_t i = 0; i < size; i++) {
      checked[i] = dynamic_cast<RooRealVar*>(axes.at(i));
      if (checked[i] == nullptr) {
         coutE(InputArguments) << "* Error in MCMCInterval::SetAxes: " <<
                                  "element " << i << " of axes is not a RooRealVar*"
                                  << std::endl;
         return;
      }
   }

   for (Int_t i = 0; i < size; i++)
      fAxes[i] = checked[i];
}

////////////////////////////////////////////////////////////////////////////////

/// Build the kernel-estimated posterior (fKeysPdf) from the Markov chain,
/// together with the helper objects used to integrate it above a cutoff:
/// fCutoffVar, fHeaviside and fProduct (keys pdf times Heaviside).
///
/// Nothing is built when the parameters or the chain are missing.  When the
/// number of burn-in steps is >= the chain size, all four objects are reset
/// so that callers can detect the failure through fKeysPdf == nullptr.
void MCMCInterval::CreateKeysPdf()
{
   // kbelasco: check here for memory leak.  does RooNDKeysPdf use
   // the RooArgList passed to it or does it make a clone?
   // also check for memory leak from chain, does RooNDKeysPdf clone that?
   if (fAxes.empty() || fParameters.empty()) {
      coutE(InputArguments) << "Error in MCMCInterval::CreateKeysPdf: "
         << "parameters have not been set." << std::endl;
      return;
   }

   // The chain is dereferenced below; an interval built with the name-only
   // constructor has none.
   if (fChain == nullptr) {
      coutE(InputArguments) << "Error in MCMCInterval::CreateKeysPdf: "
         << "Markov chain has not been set." << std::endl;
      return;
   }

   if (fNumBurnInSteps >= fChain->Size()) {
      coutE(InputArguments) <<
         "MCMCInterval::CreateKeysPdf: creation of Keys PDF failed: " <<
         "Number of burn-in steps (num steps to ignore) >= number of steps " <<
         "in Markov chain." << std::endl;
      fKeysPdf.reset();
      fCutoffVar.reset();
      fHeaviside.reset();
      fProduct.reset();
      return;
   }

   // chain entries after burn-in, restricted to the parameters of interest
   std::unique_ptr<RooAbsData> chain{fChain->GetAsConstDataSet()->reduce(SelectVars(fParameters), EventRange(fNumBurnInSteps, fChain->Size()))};

   RooArgList paramsList;
   for (Int_t i = 0; i < fDimension; i++)
      paramsList.add(*fAxes[i]);

   fKeysPdf = std::make_unique<RooNDKeysPdf>("keysPDF", "Keys PDF", paramsList, static_cast<RooDataSet&>(*chain), "a");
   fCutoffVar = std::make_unique<RooRealVar>("cutoff", "cutoff", 0);
   fHeaviside = std::make_unique<Heaviside>("heaviside", "Heaviside", *fKeysPdf, *fCutoffVar);
   fProduct = std::make_unique<RooProduct>("product", "Keys PDF & Heaviside Product",
                                        RooArgSet(*fKeysPdf, *fHeaviside));
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::CreateHist()
{
   if (fAxes.empty() || fChain == nullptr) {
      coutE(Eval) << "* Error in MCMCInterval::CreateHist(): " <<
                     "Crucial data member was nullptr." << std::endl;
      coutE(Eval) << "Make sure to fully construct/initialize." << std::endl;
      return;
   }
   fHist.reset();

   if (fNumBurnInSteps >= fChain->Size()) {
      coutE(InputArguments) <<
         "MCMCInterval::CreateHist: creation of histogram failed: " <<
         "Number of burn-in steps (num steps to ignore) >= number of steps " <<
         "in Markov chain." << std::endl;
      return;
   }

   if (fDimension == 1) {
      fHist = std::make_unique<TH1F>("posterior", "MCMC Posterior Histogram",
            fAxes[0]->numBins(), fAxes[0]->getMin(), fAxes[0]->getMax());

   } else if (fDimension == 2) {
      fHist = std::make_unique<TH2F>("posterior", "MCMC Posterior Histogram",
            fAxes[0]->numBins(), fAxes[0]->getMin(), fAxes[0]->getMax(),
            fAxes[1]->numBins(), fAxes[1]->getMin(), fAxes[1]->getMax());

   } else if (fDimension == 3) {
      fHist = std::make_unique<TH3F>("posterior", "MCMC Posterior Histogram",
            fAxes[0]->numBins(), fAxes[0]->getMin(), fAxes[0]->getMax(),
            fAxes[1]->numBins(), fAxes[1]->getMin(), fAxes[1]->getMax(),
            fAxes[2]->numBins(), fAxes[2]->getMin(), fAxes[2]->getMax());

   } else {
      coutE(Eval) << "* Error in MCMCInterval::CreateHist() : " <<
                     "TH1* couldn't handle dimension: " << fDimension << std::endl;
      return;
   }

   // Fill histogram
   Int_t size = fChain->Size();
   const RooArgSet* entry;
   for (Int_t i = fNumBurnInSteps; i < size; i++) {
      entry = fChain->Get(i);
      if (fDimension == 1) {
         (static_cast<TH1F&>(*fHist)).Fill(entry->getRealValue(fAxes[0]->GetName()),
                              fChain->Weight());
      } else if (fDimension == 2) {
         (static_cast<TH2F&>(*fHist)).Fill(entry->getRealValue(fAxes[0]->GetName()),
                              entry->getRealValue(fAxes[1]->GetName()),
                              fChain->Weight());
      } else {
         (static_cast<TH3F &>(*fHist))
            .Fill(entry->getRealValue(fAxes[0]->GetName()), entry->getRealValue(fAxes[1]->GetName()),
                   entry->getRealValue(fAxes[2]->GetName()), fChain->Weight());
      }
   }

   if (fDimension >= 1)
      fHist->GetXaxis()->SetTitle(fAxes[0]->GetName());
   if (fDimension >= 2)
      fHist->GetYaxis()->SetTitle(fAxes[1]->GetName());
   if (fDimension >= 3)
      fHist->GetZaxis()->SetTitle(fAxes[2]->GetName());
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::CreateSparseHist()
{
   if (fAxes.empty() || fChain == nullptr) {
      coutE(InputArguments) << "* Error in MCMCInterval::CreateSparseHist(): "
                            << "Crucial data member was nullptr." << std::endl;
      coutE(InputArguments) << "Make sure to fully construct/initialize."
                            << std::endl;
      return;
   }
   std::vector<double> min(fDimension);
   std::vector<double> max(fDimension);
   std::vector<Int_t> bins(fDimension);
   for (Int_t i = 0; i < fDimension; i++) {
      min[i] = fAxes[i]->getMin();
      max[i] = fAxes[i]->getMax();
      bins[i] = fAxes[i]->numBins();
   }
   fSparseHist = std::make_unique<THnSparseF>("posterior", "MCMC Posterior Histogram",
         fDimension, bins.data(), min.data(), max.data());

   // kbelasco: it appears we need to call Sumw2() just to get the
   // histogram to keep a running total of the weight so that Getsumw doesn't
   // just return 0
   fSparseHist->Sumw2();

   if (fNumBurnInSteps >= fChain->Size()) {
      coutE(InputArguments) <<
         "MCMCInterval::CreateSparseHist: creation of histogram failed: " <<
         "Number of burn-in steps (num steps to ignore) >= number of steps " <<
         "in Markov chain." << std::endl;
   }

   // Fill histogram
   Int_t size = fChain->Size();
   const RooArgSet* entry;
   std::vector<double> x(fDimension);
   for (Int_t i = fNumBurnInSteps; i < size; i++) {
      entry = fChain->Get(i);
      for (Int_t ii = 0; ii < fDimension; ii++)
         x[ii] = entry->getRealValue(fAxes[ii]->GetName());
      fSparseHist->Fill(x.data(), fChain->Weight());
   }
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::CreateDataHist()
{
   if (fParameters.empty() || fChain == nullptr) {
      coutE(Eval) << "* Error in MCMCInterval::CreateDataHist(): " <<
                     "Crucial data member was nullptr or empty." << std::endl;
      coutE(Eval) << "Make sure to fully construct/initialize." << std::endl;
      return;
   }

   if (fNumBurnInSteps >= fChain->Size()) {
      coutE(InputArguments) <<
         "MCMCInterval::CreateDataHist: creation of histogram failed: " <<
         "Number of burn-in steps (num steps to ignore) >= number of steps " <<
         "in Markov chain." << std::endl;
      fDataHist = nullptr;
      return;
   }

   std::unique_ptr<RooAbsData> data{fChain->GetAsConstDataSet()->reduce(SelectVars(fParameters), EventRange(fNumBurnInSteps, fChain->Size()))};
   fDataHist = std::unique_ptr<RooDataHist>{static_cast<RooDataSet &>(*data).binnedClone()};
}

////////////////////////////////////////////////////////////////////////////////

/// Build fVector: the indices of all chain entries after burn-in, sorted by
/// increasing value of `param`, and accumulate their total weight in
/// fVecWeight.
///
/// On failure (no chain, or burn-in steps >= chain size) fVector is left
/// empty and fVecWeight is 0.
/// \param param  parameter whose value defines the sort order
void MCMCInterval::CreateVector(RooRealVar* param)
{
   fVector.clear();
   fVecWeight = 0;

   if (fChain == nullptr) {
      coutE(InputArguments) << "* Error in MCMCInterval::CreateVector(): " <<
                     "Crucial data member (Markov chain) was nullptr." << std::endl;
      coutE(InputArguments) << "Make sure to fully construct/initialize."
                            << std::endl;
      return;
   }

   if (fNumBurnInSteps >= fChain->Size()) {
      coutE(InputArguments) <<
         "MCMCInterval::CreateVector: creation of vector failed: " <<
         "Number of burn-in steps (num steps to ignore) >= number of steps " <<
         "in Markov chain." << std::endl;
      // Stop here: with more burn-in steps than chain entries the size below
      // would be negative, and resize() would be asked for a huge unsigned
      // count.
      return;
   }

   // Fill vector with the chain indices that survive burn-in
   const Int_t size = fChain->Size() - fNumBurnInSteps;
   fVector.resize(size);
   for (Int_t i = 0; i < size; i++) {
      const Int_t chainIndex = i + fNumBurnInSteps;
      fVector[i] = chainIndex;
      fVecWeight += fChain->Weight(chainIndex);
   }

   // stable sort keeps chain order among entries with equal parameter value
   stable_sort(fVector.begin(), fVector.end(),
               CompareVectorIndices(fChain.get(), param));
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::SetParameters(const RooArgSet& parameters)
{
   fParameters.removeAll();
   fParameters.add(parameters);
   fDimension = fParameters.size();
   fAxes.resize(fDimension);
   Int_t n = 0;
   for (auto *obj : fParameters) {
      if (dynamic_cast<RooRealVar *>(obj) != nullptr) {
         fAxes[n] = static_cast<RooRealVar*>(obj);
      } else {
         coutE(Eval) << "* Error in MCMCInterval::SetParameters: " << obj->GetName() << " not a RooRealVar*"
                     << std::endl;
      }
      n++;
   }
}

////////////////////////////////////////////////////////////////////////////////

/// Recompute the interval using the method selected by fIntervalType.
/// Logs an error and does nothing when the type is neither kShortest nor
/// kTailFraction.
void MCMCInterval::DetermineInterval()
{
   if (fIntervalType == kShortest) {
      DetermineShortestInterval();
   } else if (fIntervalType == kTailFraction) {
      DetermineTailFractionInterval();
   } else {
      coutE(InputArguments) << "MCMCInterval::DetermineInterval(): " <<
         "Error: Interval type not set" << std::endl;
   }
}

////////////////////////////////////////////////////////////////////////////////

/// Determine the shortest interval, from the keys pdf when fUseKeys is set
/// and from a histogram otherwise.
void MCMCInterval::DetermineShortestInterval()
{
   if (!fUseKeys) {
      DetermineByHist();
      return;
   }
   DetermineByKeys();
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::DetermineTailFractionInterval()
{
   if (fLeftSideTF < 0 || fLeftSideTF > 1) {
      coutE(InputArguments) << "MCMCInterval::DetermineTailFractionInterval: "
         << "Fraction must be in the range [0, 1].  "
         << fLeftSideTF << "is not allowed." << std::endl;
      return;
   }

   if (fDimension != 1) {
      coutE(InputArguments) << "MCMCInterval::DetermineTailFractionInterval(): "
         << "Error: Can only find a tail-fraction interval for 1-D intervals"
         << std::endl;
      return;
   }

   if (fAxes.empty()) {
      coutE(InputArguments) << "MCMCInterval::DetermineTailFractionInterval(): "
                            << "Crucial data member was nullptr." << std::endl;
      coutE(InputArguments) << "Make sure to fully construct/initialize."
                            << std::endl;
      return;
   }

   // kbelasco: fill in code here to find interval
   //
   // also make changes so that calling GetPosterior...() returns nullptr
   // when fIntervalType == kTailFraction, since there really
   // is no posterior for this type of interval determination
   if (fVector.empty())
      CreateVector(fAxes[0]);

   if (fVector.empty() || fVecWeight == 0) {
      // if size is still 0, then creation failed.
      // if fVecWeight == 0, then there are no entries (indicates the same
      // error as fVector.empty() because that only happens when
      // fNumBurnInSteps >= fChain->Size())
      // either way, reset and return
      fVector.clear();
      fTFLower = -1.0 * RooNumber::infinity();
      fTFUpper = RooNumber::infinity();
      fTFConfLevel = 0.0;
      fVecWeight = 0;
      return;
   }

   RooRealVar* param = fAxes[0];

   double c = fConfidenceLevel;
   double leftTailCutoff  = fVecWeight * (1 - c) * fLeftSideTF;
   double rightTailCutoff = fVecWeight * (1 - c) * (1 - fLeftSideTF);
   double leftTailSum  = 0;
   double rightTailSum = 0;

   // kbelasco: consider changing these values to +infinity and -infinity
   double ll = param->getMin();
   double ul = param->getMax();

   double x;
   double w;

   // save a lot of GetName() calls if compiler does not already optimize this
   const char* name = param->GetName();

   // find lower limit
   Int_t i;
   for (i = 0; i < (Int_t)fVector.size(); i++) {
      x = fChain->Get(fVector[i])->getRealValue(name);
      w = fChain->Weight();

      if (std::abs(leftTailSum + w - leftTailCutoff) <
          std::abs(leftTailSum - leftTailCutoff)) {
         // moving the lower limit to x would bring us closer to the desired
         // left tail size
         ll = x;
         leftTailSum += w;
      } else
         break;
   }

   // find upper limit
   for (i = (Int_t)fVector.size() - 1; i >= 0; i--) {
      x = fChain->Get(fVector[i])->getRealValue(name);
      w = fChain->Weight();

      if (std::abs(rightTailSum + w - rightTailCutoff) <
          std::abs(rightTailSum - rightTailCutoff)) {
         // moving the lower limit to x would bring us closer to the desired
         // left tail size
         ul = x;
         rightTailSum += w;
      } else
         break;
   }

   fTFLower = ll;
   fTFUpper = ul;
   fTFConfLevel = 1 - (leftTailSum + rightTailSum) / fVecWeight;
}

////////////////////////////////////////////////////////////////////////////////

/// Determine the shortest interval from the kernel-estimated posterior.
///
/// Searches for a cutoff on the keys pdf such that the confidence level
/// reported by CalcConfLevel() for that cutoff is accepted by
/// AcceptableConfLevel().  The search first brackets the target by doubling
/// or halving an initial guess and then bisects the bracket.  Results are
/// stored in fKeysCutoff and fKeysConfLevel; fFull receives the integral of
/// the product at cutoff 0.
///
/// If the keys pdf cannot be created, fFull and fKeysConfLevel are set to 0
/// and fKeysCutoff to -1.
void MCMCInterval::DetermineByKeys()
{
   if (fKeysPdf == nullptr)
      CreateKeysPdf();

   if (fKeysPdf == nullptr) {
      // if fKeysPdf is still nullptr, then it means CreateKeysPdf failed
      // so clear all the data members this function would normally determine
      // and return
      fFull = 0.0;
      fKeysCutoff = -1;
      fKeysConfLevel = 0.0;
      return;
   }

   // now we have a keys pdf of the posterior

   // integrate the product with the cutoff at 0; this value is passed to every
   // CalcConfLevel() call below
   double cutoff = 0.0;
   fCutoffVar->setVal(cutoff);
   double full = std::unique_ptr<RooAbsReal>{fProduct->createIntegral(fParameters, NormSet(fParameters))}->getVal(fParameters);
   fFull = full;
   if (full < 0.98) {
      coutW(Eval) << "Warning: Integral of Keys PDF came out to " << full
         << " instead of expected value 1.  Will continue using this "
         << "factor to normalize further integrals of this PDF." << std::endl;
   }

   // kbelasco: Is there a better way to set the search range?
   // from 0 to max value of Keys
   // kbelasco: how to get max value?
   //double max = product.maxVal(product.getMaxVal(fParameters));

   // volume of the parameter box: product of the ranges of all parameters
   double volume = 1.0;
   for (auto *var : static_range_cast<RooRealVar*>(fParameters))
      volume *= (var->getMax() - var->getMin());

   // initial guess for the cutoff: the integral divided by the box volume
   double topCutoff = full / volume;
   double bottomCutoff = topCutoff;
   double confLevel = CalcConfLevel(topCutoff, full);
   if (AcceptableConfLevel(confLevel)) {
      fKeysConfLevel = confLevel;
      fKeysCutoff = topCutoff;
      return;
   }
   bool changed = false;
   // find high end of range: keep doubling while the level is above target
   while (confLevel > fConfidenceLevel) {
      topCutoff *= 2.0;
      confLevel = CalcConfLevel(topCutoff, full);
      if (AcceptableConfLevel(confLevel)) {
         fKeysConfLevel = confLevel;
         fKeysCutoff = topCutoff;
         return;
      }
      changed = true;
   }
   if (changed) {
      // the top was raised, so the previous (halved) value is the bottom
      bottomCutoff = topCutoff / 2.0;
   } else {
      // the initial guess was already below target: find the low end of the
      // range by halving until the level is no longer below target
      changed = false;
      bottomCutoff /= 2.0;
      confLevel = CalcConfLevel(bottomCutoff, full);
      if (AcceptableConfLevel(confLevel)) {
         fKeysConfLevel = confLevel;
         fKeysCutoff = bottomCutoff;
         return;
      }
      while (confLevel < fConfidenceLevel) {
         bottomCutoff /= 2.0;
         confLevel = CalcConfLevel(bottomCutoff, full);
         if (AcceptableConfLevel(confLevel)) {
            fKeysConfLevel = confLevel;
            fKeysCutoff = bottomCutoff;
            return;
         }
         changed = true;
      }
      if (changed) {
         // the bottom was lowered further, so tighten the top to match
         topCutoff = bottomCutoff * 2.0;
      }
   }

   coutI(Eval) << "range set: [" << bottomCutoff << ", " << topCutoff << "]"
               << std::endl;

   // bisect the bracket [bottomCutoff, topCutoff]
   cutoff = (topCutoff + bottomCutoff) / 2.0;
   confLevel = CalcConfLevel(cutoff, full);

   // need to use WithinDeltaFraction() because sometimes the integrating the
   // posterior in this binary search seems to not have enough granularity to
   // find an acceptable conf level (small no. of strange cases).
   // WithinDeltaFraction causes the search to terminate when
   // topCutoff is essentially equal to bottomCutoff (compared to the magnitude
   // of their mean).
   while (!AcceptableConfLevel(confLevel) &&
          !WithinDeltaFraction(topCutoff, bottomCutoff)) {
      if (confLevel > fConfidenceLevel) {
         // level above target: move the bottom of the bracket up to the midpoint
         bottomCutoff = cutoff;
      } else if (confLevel < fConfidenceLevel) {
         // level below target: move the top of the bracket down to the midpoint
         topCutoff = cutoff;
      }
      cutoff = (topCutoff + bottomCutoff) / 2.0;
      coutI(Eval) << "cutoff range: [" << bottomCutoff << ", "
                  << topCutoff << "]" << std::endl;
      confLevel = CalcConfLevel(cutoff, full);
   }

   fKeysCutoff = cutoff;
   fKeysConfLevel = confLevel;
}

////////////////////////////////////////////////////////////////////////////////

/// Determine the shortest interval from a histogram: the sparse histogram
/// when fUseSparseHist is set, the RooDataHist otherwise.
void MCMCInterval::DetermineByHist()
{
   if (!fUseSparseHist) {
      DetermineByDataHist();
      return;
   }
   DetermineBySparseHist();
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::DetermineBySparseHist()
{
   Long_t numBins;
   if (fSparseHist == nullptr)
      CreateSparseHist();

   if (fSparseHist == nullptr) {
      // if fSparseHist is still nullptr, then CreateSparseHist failed
      fHistCutoff = -1;
      fHistConfLevel = 0.0;
      return;
   }

   numBins = (Long_t)fSparseHist->GetNbins();

   std::vector<Long_t> bins(numBins);
   for (Int_t ibin = 0; ibin < numBins; ibin++)
      bins[ibin] = (Long_t)ibin;
   std::stable_sort(bins.begin(), bins.end(), CompareSparseHistBins(fSparseHist.get()));

   double nEntries = fSparseHist->GetSumw();
   double sum = 0;
   double content;
   Int_t i;
   // see above note on indexing to understand numBins - 3
   for (i = numBins - 1; i >= 0; i--) {
      content = fSparseHist->GetBinContent(bins[i]);
      if ((sum + content) / nEntries >= fConfidenceLevel) {
         fHistCutoff = content;
         if (fIsHistStrict) {
            sum += content;
            i--;
            break;
         } else {
            i++;
            break;
         }
      }
      sum += content;
   }

   if (fIsHistStrict) {
      // keep going to find the sum
      for ( ; i >= 0; i--) {
         content = fSparseHist->GetBinContent(bins[i]);
         if (content == fHistCutoff) {
            sum += content;
         } else {
            break; // content must be < fHistCutoff
         }
      }
   } else {
      // backtrack to find the cutoff and sum
      for ( ; i < numBins; i++) {
         content = fSparseHist->GetBinContent(bins[i]);
         if (content > fHistCutoff) {
            fHistCutoff = content;
            break;
         } else // content == fHistCutoff
            sum -= content;
         if (i == numBins - 1) {
            // still haven't set fHistCutoff correctly yet, and we have no bins
            // left, so set fHistCutoff to something higher than the tallest bin
            fHistCutoff = content + 1.0;
         }
      }
   }

   fHistConfLevel = sum / nEntries;
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::DetermineByDataHist()
{
   Int_t numBins;
   if (fDataHist == nullptr)
      CreateDataHist();
   if (fDataHist == nullptr) {
      // if fDataHist is still nullptr, then CreateDataHist failed
      fHistCutoff = -1;
      fHistConfLevel = 0.0;
      return;
   }

   numBins = fDataHist->numEntries();

   std::vector<Int_t> bins(numBins);
   for (Int_t ibin = 0; ibin < numBins; ibin++)
      bins[ibin] = ibin;
   std::stable_sort(bins.begin(), bins.end(), CompareDataHistBins(fDataHist.get()));

   double nEntries = fDataHist->sum(false);
   double sum = 0;
   double content;
   Int_t i;
   for (i = numBins - 1; i >= 0; i--) {
      fDataHist->get(bins[i]);
      content = fDataHist->weight();
      if ((sum + content) / nEntries >= fConfidenceLevel) {
         fHistCutoff = content;
         if (fIsHistStrict) {
            sum += content;
            i--;
            break;
         } else {
            i++;
            break;
         }
      }
      sum += content;
   }

   if (fIsHistStrict) {
      // keep going to find the sum
      for ( ; i >= 0; i--) {
         fDataHist->get(bins[i]);
         content = fDataHist->weight();
         if (content == fHistCutoff) {
            sum += content;
         } else {
            break; // content must be < fHistCutoff
         }
      }
   } else {
      // backtrack to find the cutoff and sum
      for ( ; i < numBins; i++) {
         fDataHist->get(bins[i]);
         content = fDataHist->weight();
         if (content > fHistCutoff) {
            fHistCutoff = content;
            break;
         } else // content == fHistCutoff
            sum -= content;
         if (i == numBins - 1) {
            // still haven't set fHistCutoff correctly yet, and we have no bins
            // left, so set fHistCutoff to something higher than the tallest bin
            fHistCutoff = content + 1.0;
         }
      }
   }

   fHistConfLevel = sum / nEntries;
}

////////////////////////////////////////////////////////////////////////////////

/// Return the confidence level stored by the last interval determination for
/// the method currently selected: fKeysConfLevel or fHistConfLevel for the
/// shortest interval, fTFConfLevel for the tail-fraction interval.  Logs an
/// error and returns 0 for any other interval type.
double MCMCInterval::GetActualConfidenceLevel()
{
   switch (fIntervalType) {
      case kShortest:
         return fUseKeys ? fKeysConfLevel : fHistConfLevel;
      case kTailFraction:
         return fTFConfLevel;
      default:
         coutE(InputArguments) << "MCMCInterval::GetActualConfidenceLevel: "
            << "not implemented for this type of interval.  Returning 0." << std::endl;
         return 0;
   }
}

////////////////////////////////////////////////////////////////////////////////

/// Return the lower limit of the interval on `param`, using the method
/// selected by fIntervalType.  Logs an error and returns
/// RooNumber::infinity() when the interval type is not set.
double MCMCInterval::LowerLimit(RooRealVar& param)
{
   if (fIntervalType == kShortest)
      return LowerLimitShortest(param);

   if (fIntervalType == kTailFraction)
      return LowerLimitTailFraction(param);

   coutE(InputArguments) << "MCMCInterval::LowerLimit(): " <<
      "Error: Interval type not set" << std::endl;
   return RooNumber::infinity();
}

////////////////////////////////////////////////////////////////////////////////

/// Return the upper limit of the interval on `param`, using the method
/// selected by fIntervalType.  Logs an error and returns
/// RooNumber::infinity() when the interval type is not set.
double MCMCInterval::UpperLimit(RooRealVar& param)
{
   if (fIntervalType == kShortest)
      return UpperLimitShortest(param);

   if (fIntervalType == kTailFraction)
      return UpperLimitTailFraction(param);

   coutE(InputArguments) << "MCMCInterval::UpperLimit(): " <<
      "Error: Interval type not set" << std::endl;
   return RooNumber::infinity();
}

////////////////////////////////////////////////////////////////////////////////

/// Return the lower limit of the tail-fraction interval.  The interval is
/// determined first if the stored limit is still -infinity.  The argument is
/// unused.
double MCMCInterval::LowerLimitTailFraction(RooRealVar& /*param*/)
{
   const bool notYetDetermined = (fTFLower == -RooNumber::infinity());
   if (notYetDetermined)
      DetermineTailFractionInterval();

   return fTFLower;
}

////////////////////////////////////////////////////////////////////////////////

/// Return the upper edge of the tail-fraction interval (fTFUpper).
///
/// fTFUpper equal to +infinity is treated as "not computed yet" and triggers
/// DetermineTailFractionInterval() before the value is returned.
/// The parameter argument is unused.
double MCMCInterval::UpperLimitTailFraction(RooRealVar& /*param*/)
{
   const bool notDeterminedYet = (fTFUpper == RooNumber::infinity());
   if (notDeterminedYet) {
      DetermineTailFractionInterval();
   }
   return fTFUpper;
}

////////////////////////////////////////////////////////////////////////////////

/// Lower limit of the shortest interval for `param`: computed from the keys
/// pdf when fUseKeys is set, otherwise from the binned data set.
double MCMCInterval::LowerLimitShortest(RooRealVar& param)
{
   return fUseKeys ? LowerLimitByKeys(param) : LowerLimitByHist(param);
}

////////////////////////////////////////////////////////////////////////////////

/// Upper limit of the shortest interval for `param`: computed from the keys
/// pdf when fUseKeys is set, otherwise from the binned data set.
double MCMCInterval::UpperLimitShortest(RooRealVar& param)
{
   return fUseKeys ? UpperLimitByKeys(param) : UpperLimitByHist(param);
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the lower limit for param on this interval using the binned
/// data set. Delegates to the sparse-histogram implementation when
/// fUseSparseHist is set, and to the RooDataHist implementation otherwise.

double MCMCInterval::LowerLimitByHist(RooRealVar& param)
{
   if (fUseSparseHist)
      return LowerLimitBySparseHist(param);

   return LowerLimitByDataHist(param);
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the upper limit for param on this interval using the binned
/// data set. Delegates to the sparse-histogram implementation when
/// fUseSparseHist is set, and to the RooDataHist implementation otherwise.

double MCMCInterval::UpperLimitByHist(RooRealVar& param)
{
   if (fUseSparseHist)
      return UpperLimitBySparseHist(param);

   return UpperLimitByDataHist(param);
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the lower limit for param on this interval using the sparse
/// histogram (fSparseHist).
///
/// Only supported when fDimension == 1. The result is the smallest bin center
/// along param's axis among the bins whose content is >= fHistCutoff.
/// On any failure (wrong dimension, cutoff could not be determined, param not
/// among fAxes) param.getMin() is returned.

double MCMCInterval::LowerLimitBySparseHist(RooRealVar& param)
{
   if (fDimension != 1) {
      coutE(InputArguments) << "In MCMCInterval::LowerLimitBySparseHist: "
         << "Sorry, will not compute lower limit unless dimension == 1" << std::endl;
      return param.getMin();
   }

   // A negative cutoff means the interval has not been determined yet.
   if (fHistCutoff < 0) {
      DetermineBySparseHist(); // this initializes fSparseHist
   }

   // Still negative: the determination failed.
   if (fHistCutoff < 0) {
      coutE(Eval) << "In MCMCInterval::LowerLimitBySparseHist: "
         << "couldn't determine cutoff.  Check that num burn in steps < num "
         << "steps in the Markov chain.  Returning param.getMin()." << std::endl;
      return param.getMin();
   }

   std::vector<Int_t> binCoords(fDimension);
   for (Int_t axis = 0; axis < fDimension; ++axis) {
      // Find the axis that corresponds to param (matched by name).
      if (strcmp(fAxes[axis]->GetName(), param.GetName()) != 0)
         continue;

      const Long_t nBins = static_cast<Long_t>(fSparseHist->GetNbins());
      double smallest = param.getMax();
      for (Long_t bin = 0; bin < nBins; ++bin) {
         // GetBinContent also fills binCoords with the coordinates of this bin.
         const bool insideInterval =
            fSparseHist->GetBinContent(bin, binCoords.data()) >= fHistCutoff;
         if (!insideInterval)
            continue;

         const double center = fSparseHist->GetAxis(axis)->GetBinCenter(binCoords[axis]);
         if (center < smallest)
            smallest = center;
      }
      return smallest;
   }

   // param is not one of the axes
   return param.getMin();
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the lower limit for param on this interval using the binned
/// data set stored as a RooDataHist (fDataHist).
///
/// The result is the smallest value of param among the bins whose weight is
/// >= fHistCutoff. If the cutoff cannot be determined, or param is not among
/// fAxes, param.getMin() is returned.

double MCMCInterval::LowerLimitByDataHist(RooRealVar& param)
{
   // A negative cutoff means the interval has not been determined yet.
   if (fHistCutoff < 0) {
      DetermineByDataHist(); // this initializes fDataHist
   }

   // Still negative: the determination failed.
   if (fHistCutoff < 0) {
      coutE(Eval) << "In MCMCInterval::LowerLimitByDataHist: "
         << "couldn't determine cutoff.  Check that num burn in steps < num "
         << "steps in the Markov chain.  Returning param.getMin()." << std::endl;
      return param.getMin();
   }

   for (Int_t axis = 0; axis < fDimension; ++axis) {
      // Find the axis that corresponds to param (matched by name).
      if (strcmp(fAxes[axis]->GetName(), param.GetName()) != 0)
         continue;

      const Int_t nBins = fDataHist->numEntries();
      double smallest = param.getMax();
      for (Int_t bin = 0; bin < nBins; ++bin) {
         // Load this bin; the argument-less weight()/get() calls below are
         // expected to refer to the bin loaded here.
         fDataHist->get(bin);
         const bool insideInterval = fDataHist->weight() >= fHistCutoff;
         if (!insideInterval)
            continue;

         const double value = fDataHist->get()->getRealValue(param.GetName());
         if (value < smallest)
            smallest = value;
      }
      return smallest;
   }

   // param is not one of the axes
   return param.getMin();
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the upper limit for param on this interval using the sparse
/// histogram (fSparseHist).
///
/// Only supported when fDimension == 1. The result is the largest bin center
/// along param's axis among the bins whose content is >= fHistCutoff.
/// On any failure (wrong dimension, cutoff could not be determined, param not
/// among fAxes) param.getMax() is returned.

double MCMCInterval::UpperLimitBySparseHist(RooRealVar& param)
{
   if (fDimension != 1) {
      coutE(InputArguments) << "In MCMCInterval::UpperLimitBySparseHist: "
         << "Sorry, will not compute upper limit unless dimension == 1" << std::endl;
      return param.getMax();
   }

   // A negative cutoff means the interval has not been determined yet.
   if (fHistCutoff < 0) {
      DetermineBySparseHist(); // this initializes fSparseHist
   }

   // Still negative: the determination failed.
   if (fHistCutoff < 0) {
      coutE(Eval) << "In MCMCInterval::UpperLimitBySparseHist: "
         << "couldn't determine cutoff.  Check that num burn in steps < num "
         << "steps in the Markov chain.  Returning param.getMax()." << std::endl;
      return param.getMax();
   }

   std::vector<Int_t> binCoords(fDimension);
   for (Int_t axis = 0; axis < fDimension; ++axis) {
      // Find the axis that corresponds to param (matched by name).
      if (strcmp(fAxes[axis]->GetName(), param.GetName()) != 0)
         continue;

      const Long_t nBins = static_cast<Long_t>(fSparseHist->GetNbins());
      double largest = param.getMin();
      for (Long_t bin = 0; bin < nBins; ++bin) {
         // GetBinContent also fills binCoords with the coordinates of this bin.
         const bool insideInterval =
            fSparseHist->GetBinContent(bin, binCoords.data()) >= fHistCutoff;
         if (!insideInterval)
            continue;

         const double center = fSparseHist->GetAxis(axis)->GetBinCenter(binCoords[axis]);
         if (center > largest)
            largest = center;
      }
      return largest;
   }

   // param is not one of the axes
   return param.getMax();
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the upper limit for param on this interval using the binned
/// data set stored as a RooDataHist (fDataHist).
///
/// The result is the largest value of param among the bins whose weight is
/// >= fHistCutoff. If the cutoff cannot be determined, or param is not among
/// fAxes, param.getMax() is returned.

double MCMCInterval::UpperLimitByDataHist(RooRealVar& param)
{
   // A negative cutoff means the interval has not been determined yet.
   if (fHistCutoff < 0) {
      DetermineByDataHist(); // this initializes fDataHist
   }

   // Still negative: the determination failed.
   if (fHistCutoff < 0) {
      coutE(Eval) << "In MCMCInterval::UpperLimitByDataHist: "
         << "couldn't determine cutoff.  Check that num burn in steps < num "
         << "steps in the Markov chain.  Returning param.getMax()." << std::endl;
      return param.getMax();
   }

   for (Int_t axis = 0; axis < fDimension; ++axis) {
      // Find the axis that corresponds to param (matched by name).
      if (strcmp(fAxes[axis]->GetName(), param.GetName()) != 0)
         continue;

      const Int_t nBins = fDataHist->numEntries();
      double largest = param.getMin();
      for (Int_t bin = 0; bin < nBins; ++bin) {
         // Load this bin; the argument-less weight()/get() calls below are
         // expected to refer to the bin loaded here.
         fDataHist->get(bin);
         const bool insideInterval = fDataHist->weight() >= fHistCutoff;
         if (!insideInterval)
            continue;

         const double value = fDataHist->get()->getRealValue(param.GetName());
         if (value > largest)
            largest = value;
      }
      return largest;
   }

   // param is not one of the axes
   return param.getMax();
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the lower limit for param on this interval using the keys pdf.
///
/// Makes sure the keys cutoff (fKeysCutoff) and the histogram of the keys
/// product (fKeysDataHist) exist, then returns the smallest value of param
/// among the bins of fKeysDataHist whose weight is >= fKeysCutoff.
/// If either ingredient cannot be produced, or param is not among fAxes,
/// param.getMin() is returned.

double MCMCInterval::LowerLimitByKeys(RooRealVar& param)
{
   // A negative cutoff means the interval has not been determined yet.
   if (fKeysCutoff < 0) {
      DetermineByKeys();
   }

   if (fKeysDataHist == nullptr) {
      CreateKeysDataHist();
   }

   const bool setupFailed = (fKeysCutoff < 0 || fKeysDataHist == nullptr);
   if (setupFailed) {
      // failure in determination of cutoff and/or creation of histogram
      coutE(Eval) << "in MCMCInterval::LowerLimitByKeys(): "
         << "couldn't find lower limit, check that the number of burn in "
         << "steps < number of total steps in the Markov chain.  Returning "
         << "param.getMin()" << std::endl;
      return param.getMin();
   }

   for (Int_t axis = 0; axis < fDimension; ++axis) {
      // Find the axis that corresponds to param (matched by name).
      if (strcmp(fAxes[axis]->GetName(), param.GetName()) != 0)
         continue;

      const Int_t nBins = fKeysDataHist->numEntries();
      double smallest = param.getMax();
      for (Int_t bin = 0; bin < nBins; ++bin) {
         // Load this bin; the argument-less weight()/get() calls below are
         // expected to refer to the bin loaded here.
         fKeysDataHist->get(bin);
         const bool insideInterval = fKeysDataHist->weight() >= fKeysCutoff;
         if (!insideInterval)
            continue;

         const double value = fKeysDataHist->get()->getRealValue(param.GetName());
         if (value < smallest)
            smallest = value;
      }
      return smallest;
   }

   // param is not one of the axes
   return param.getMin();
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the upper limit for param on this interval using the keys pdf.
///
/// Makes sure the keys cutoff (fKeysCutoff) and the histogram of the keys
/// product (fKeysDataHist) exist, then returns the largest value of param
/// among the bins of fKeysDataHist whose weight is >= fKeysCutoff.
/// If either ingredient cannot be produced, or param is not among fAxes,
/// param.getMax() is returned.

double MCMCInterval::UpperLimitByKeys(RooRealVar& param)
{
   // A negative cutoff means the interval has not been determined yet.
   if (fKeysCutoff < 0) {
      DetermineByKeys();
   }

   if (fKeysDataHist == nullptr) {
      CreateKeysDataHist();
   }

   const bool setupFailed = (fKeysCutoff < 0 || fKeysDataHist == nullptr);
   if (setupFailed) {
      // failure in determination of cutoff and/or creation of histogram
      coutE(Eval) << "in MCMCInterval::UpperLimitByKeys(): "
         << "couldn't find upper limit, check that the number of burn in "
         << "steps < number of total steps in the Markov chain.  Returning "
         << "param.getMax()" << std::endl;
      return param.getMax();
   }

   for (Int_t axis = 0; axis < fDimension; ++axis) {
      // Find the axis that corresponds to param (matched by name).
      if (strcmp(fAxes[axis]->GetName(), param.GetName()) != 0)
         continue;

      const Int_t nBins = fKeysDataHist->numEntries();
      double largest = param.getMin();
      for (Int_t bin = 0; bin < nBins; ++bin) {
         // Load this bin; the argument-less weight()/get() calls below are
         // expected to refer to the bin loaded here.
         fKeysDataHist->get(bin);
         const bool insideInterval = fKeysDataHist->weight() >= fKeysCutoff;
         if (!insideInterval)
            continue;

         const double value = fKeysDataHist->get()->getRealValue(param.GetName());
         if (value > largest)
            largest = value;
      }
      return largest;
   }

   // param is not one of the axes
   return param.getMax();
}

////////////////////////////////////////////////////////////////////////////////
/// Determine the approximate maximum value of the Keys PDF

double MCMCInterval::GetKeysMax()
{
   if (fKeysCutoff < 0)
      DetermineByKeys();

   if (fKeysDataHist == nullptr)
      CreateKeysDataHist();

   if (fKeysDataHist == nullptr) {
      // failure in determination of cutoff and/or creation of histogram
      coutE(Eval) << "in MCMCInterval::KeysMax(): "
         << "couldn't find Keys max value, check that the number of burn in "
         << "steps < number of total steps in the Markov chain.  Returning 0"
         << std::endl;
      return 0;
   }

   Int_t numBins = fKeysDataHist->numEntries();
   double max = 0;
   double w;
   for (Int_t i = 0; i < numBins; i++) {
      fKeysDataHist->get(i);
      w = fKeysDataHist->weight();
      if (w > max)
         max = w;
   }

   return max;
}

////////////////////////////////////////////////////////////////////////////////

/// Return the histogram cutoff (fHistCutoff), running DetermineByHist() first
/// if the cutoff is still negative, i.e. has not been determined yet.
/// The value may remain negative if the determination did not set it.
double MCMCInterval::GetHistCutoff()
{
   const bool cutoffPending = (fHistCutoff < 0);
   if (cutoffPending) {
      DetermineByHist();
   }
   return fHistCutoff;
}

////////////////////////////////////////////////////////////////////////////////

/// Return the keys cutoff divided by fFull (fKeysCutoff / fFull), running
/// DetermineByKeys() first if the cutoff is still negative.
double MCMCInterval::GetKeysPdfCutoff()
{
   const bool cutoffPending = (fKeysCutoff < 0);
   if (cutoffPending) {
      DetermineByKeys();
   }

   // kbelasco: if fFull hasn't been set (because Keys creation failed because
   // fNumBurnInSteps >= fChain->Size()) then this will return infinity, which
   // seems ok to me since it will indicate error
   const double normalizedCutoff = fKeysCutoff / fFull;
   return normalizedCutoff;
}

////////////////////////////////////////////////////////////////////////////////

/// Compute the confidence level that corresponds to a given cutoff.
///
/// Sets fCutoffVar to `cutoff`, integrates fProduct over fParameters and
/// divides the result by `full`. The cutoff and the resulting level are
/// reported at info level.
///
/// \param cutoff value assigned to fCutoffVar before integrating
/// \param full   normalisation the integral is divided by
/// \return integral of fProduct over fParameters, divided by `full`
double MCMCInterval::CalcConfLevel(double cutoff, double full)
{
   fCutoffVar->setVal(cutoff);

   // The integral object is only needed for this one evaluation.
   std::unique_ptr<RooAbsReal> productIntegral{
      fProduct->createIntegral(fParameters, NormSet(fParameters))};
   const double integralValue = productIntegral->getVal(fParameters);
   const double level = integralValue / full;

   coutI(Eval) << "cutoff = " << cutoff << ", conf = " << level << std::endl;
   return level;
}

////////////////////////////////////////////////////////////////////////////////

/// Return a clone of the posterior histogram, named "MCMCposterior_hist".
///
/// fHist is created on demand via CreateHist(); if it is still null
/// afterwards, nullptr is returned. An error is logged when the confidence
/// level is 0 (not set), but the histogram is still produced.
/// The returned object is a fresh clone; ownership presumably passes to the
/// caller.
TH1* MCMCInterval::GetPosteriorHist()
{
   if (fConfidenceLevel == 0) {
      coutE(InputArguments) << "Error in MCMCInterval::GetPosteriorHist: "
                            << "confidence level not set " << std::endl;
   }

   if (!fHist) {
      CreateHist();
   }

   // if fHist is still null here, then CreateHist failed
   return fHist ? static_cast<TH1*>(fHist->Clone("MCMCposterior_hist")) : nullptr;
}

////////////////////////////////////////////////////////////////////////////////

/// Return a clone of the posterior keys pdf, named "MCMCPosterior_keys".
///
/// fKeysPdf is created on demand via CreateKeysPdf(); if it is still null
/// afterwards, nullptr is returned. An error is logged when the confidence
/// level is 0 (not set), but the pdf is still produced.
/// The returned object is a fresh clone; ownership presumably passes to the
/// caller.
RooNDKeysPdf* MCMCInterval::GetPosteriorKeysPdf()
{
   if (fConfidenceLevel == 0) {
      coutE(InputArguments) << "Error in MCMCInterval::GetPosteriorKeysPdf: "
                            << "confidence level not set " << std::endl;
   }

   if (!fKeysPdf) {
      CreateKeysPdf();
   }

   // if fKeysPdf is still null here, then CreateKeysPdf failed
   return fKeysPdf ? static_cast<RooNDKeysPdf*>(fKeysPdf->Clone("MCMCPosterior_keys")) : nullptr;
}

////////////////////////////////////////////////////////////////////////////////

/// Return a clone of the keys product (fProduct), named
/// "MCMCPosterior_keysproduct".
///
/// If fProduct does not exist yet, CreateKeysPdf() and DetermineByKeys() are
/// run to produce it; if it is still null afterwards, nullptr is returned.
/// An error is logged when the confidence level is 0 (not set), but the
/// product is still produced.
/// The returned object is a fresh clone; ownership presumably passes to the
/// caller.
RooProduct* MCMCInterval::GetPosteriorKeysProduct()
{
   if (fConfidenceLevel == 0) {
      coutE(InputArguments) << "MCMCInterval::GetPosteriorKeysProduct: "
                            << "confidence level not set " << std::endl;
   }

   if (!fProduct) {
      CreateKeysPdf();
      DetermineByKeys();
   }

   // if fProduct is still null here, then its creation failed
   return fProduct ? static_cast<RooProduct*>(fProduct->Clone("MCMCPosterior_keysproduct")) : nullptr;
}

////////////////////////////////////////////////////////////////////////////////

/// Return the list of parameters as a newly heap-allocated RooArgSet built
/// from fParameters. The set is created with `new` and returned as a raw
/// pointer, so the caller is responsible for deleting it.
RooArgSet* MCMCInterval::GetParameters() const
{
   RooArgSet* parametersCopy = new RooArgSet(fParameters);
   return parametersCopy;
}

////////////////////////////////////////////////////////////////////////////////

/// Return true if `confLevel` differs from the target fConfidenceLevel by
/// strictly less than fEpsilon.
bool MCMCInterval::AcceptableConfLevel(double confLevel)
{
   const double deviation = std::abs(confLevel - fConfidenceLevel);
   return deviation < fEpsilon;
}

////////////////////////////////////////////////////////////////////////////////

/// Return true if |a - b| is strictly smaller than the fraction fDelta of the
/// mean of a and b (in absolute value). Since the comparison is strict, the
/// result is false when both sides are zero, e.g. for a == b == 0.
bool MCMCInterval::WithinDeltaFraction(double a, double b)
{
   const double difference = std::abs(a - b);
   const double tolerance = std::abs(fDelta * (a + b)/2);
   return difference < tolerance;
}

////////////////////////////////////////////////////////////////////////////////

void MCMCInterval::CreateKeysDataHist()
{
   if (fAxes.empty())
      return;
   if (fProduct == nullptr)
      DetermineByKeys();
   if (fProduct == nullptr) {
      // if fProduct still nullptr, then creation failed
      return;
   }

   //RooAbsBinning** savedBinning = new RooAbsBinning*[fDimension];
   std::vector<Int_t> savedBins(fDimension);
   Int_t i;
   double numBins;
   RooRealVar* var;

   // kbelasco: Note - the accuracy is only increased here if the binning for
   // each RooRealVar is uniform

   // kbelasco: look into why saving the binnings and replacing them doesn't
   // work (replaces with 1 bin always).
   // Note: this code modifies the binning for the parameters (if they are
   // uniform) and sets them back to what they were.  If the binnings are not
   // uniform, this code does nothing.

   // first scan through fAxes to make sure all binnings are uniform, or else
   // we can't change the number of bins because there seems to be an error
   // when setting the binning itself rather than just the number of bins
   bool tempChangeBinning = true;
   for (i = 0; i < fDimension; i++) {
      if (!fAxes[i]->getBinning(nullptr, false, false).isUniform()) {
         tempChangeBinning = false;
         break;
      }
   }

   // kbelasco: for 1 dimension this should be fine, but for more dimensions
   // the total number of bins in the histogram increases exponentially with
   // the dimension, so don't do this above 1-D for now.
   if (fDimension >= 2)
      tempChangeBinning = false;

   if (tempChangeBinning) {
      // set high number of bins for high accuracy on lower/upper limit by keys
      for (i = 0; i < fDimension; i++) {
         var = fAxes[i];
         //savedBinning[i] = &var->getBinning("__binning_clone", false, true);
         savedBins[i] = var->getBinning(nullptr, false, false).numBins();
         numBins = (var->getMax() - var->getMin()) / fEpsilon;
         var->setBins((Int_t)numBins);
      }
   }

   fKeysDataHist = std::make_unique<RooDataHist>("_productDataHist",
         "Keys PDF & Heaviside Product Data Hist", fParameters);
   fProduct->fillDataHist(fKeysDataHist.get(), &fParameters, 1.);

   if (tempChangeBinning) {
      // set the binning back to normal
      for (i = 0; i < fDimension; i++) {
         //fAxes[i]->setBinning(*savedBinning[i], nullptr);
         //fAxes[i]->setBins(savedBinning[i]->numBins(), nullptr);
         fAxes[i]->setBins(savedBins[i], nullptr);
      }
   }
}

////////////////////////////////////////////////////////////////////////////////

/// Check that `parameterPoint` matches the parameters of this interval.
///
/// Returns false (after logging an error) if the sizes differ or if
/// RooArgSet::equals() reports a mismatch with fParameters; true otherwise.
/// The content comparison is only performed when the sizes agree.
bool MCMCInterval::CheckParameters(const RooArgSet& parameterPoint) const
{
   const bool sameSize = (parameterPoint.size() == fParameters.size());
   if (!sameSize) {
      coutE(Eval) << "MCMCInterval: size is wrong, parameters don't match" << std::endl;
      return false;
   }

   const bool sameContent = parameterPoint.equals( fParameters );
   if (!sameContent) {
      coutE(Eval) << "MCMCInterval: size is ok, but parameters don't match" << std::endl;
      return false;
   }

   return true;
}
