# -*- coding: utf-8 -*-
# fitter.py

"""
File to handle linear fitting for Physics 50.

"""

from django import forms
import os
os.environ['MPLCONFIGDIR'] = '/tmp'

import matplotlib
matplotlib.use('Agg')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import urllib
from django.shortcuts import render
from django.utils.translation import gettext as _
from django.forms import ValidationError


def download(request):
    """Serve the zip archive that bundles this code and its Jupyter notebook.

    The archive, ``Fitter.zip``, is looked up in the directory containing
    this module. ``request`` is required by Django's view signature but
    is not otherwise used.
    """
    from pathlib import Path
    from django.http import FileResponse

    archive = Path(__file__).parent / "Fitter.zip"
    # The open handle is handed to FileResponse, which streams it out.
    return FileResponse(open(archive, "rb"))


# (value, label) pairs offered by the fit-function drop-down. Every model
# is labelled by its own name; only the data-only option has a longer
# label. A sinusoid option has been sketched but is currently disabled.
function_choices = (('data', 'plot data only'),) + tuple(
    (model, model)
    for model in ('constant', 'linear', 'quadratic', 'exponential', 'power')
)


def myconst(x, b):
    """Constant model, y = b, producing one value for every entry of ``x``."""
    # x * 0 carries the shape of x, so adding b yields one value per point
    zeros = x * 0
    return zeros + b


myconst.tex, myconst.params, myconst.p0 = "$y = b$", ("$b$",), (1,)


def myline(x, m, b):
    """Straight-line model, y = m x + b."""
    rise = m * x
    return rise + b


myline.tex = "$y = mx + b$"
myline.params = ["$m$", "$b$"]
myline.p0 = (1, 1)


def myquad(x, a, b, c):
    """Quadratic model, y = a x^2 + b x + c."""
    # summed in the same order as c + b*x + a*x*x to keep results bit-identical
    lower_order = c + b * x
    square_term = a * x * x
    return lower_order + square_term


myquad.tex = "$y = a x^2 + b x + c$"
myquad.params = ["$a$", "$b$", "$c$"]
myquad.p0 = (1, 1, 1)


def sort_data(x, y):
    """Return ``x`` and ``y`` reordered so that ``x`` is ascending.

    Both inputs are converted with ``np.asarray`` first, so plain
    sequences work as well as NumPy arrays.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    order = np.argsort(x)
    return x[order], y[order]


# Neutral initial guesses used when the closed-form end-point estimates
# below are not finite (equal end values, non-positive ratios, repeated x).
_FALLBACK_P0 = (1.0, 1.0)


def _finite_or_fallback(guess):
    """Return ``guess`` if every entry is finite, else ``_FALLBACK_P0``."""
    if np.all(np.isfinite(guess)):
        return guess
    return _FALLBACK_P0

# Set up exponential fit to attempt a guess for initial values


def myexpop0(x, y):
    """Estimate (A, d) for y = A exp(-x/d) from the two end points.

    After sorting by x, the first and last points give
    d = (x_last - x_first) / ln(y_first / y_last) and
    A = y_first * exp(x_first / d). When those formulas break down
    (e.g. y_first == y_last, or the y ratio is not positive) the result
    would be inf/nan, which curve_fit cannot start from, so
    ``(1.0, 1.0)`` is returned instead.
    """
    x, y = sort_data(x, y)
    # degenerate data are handled by the finiteness check, not by warnings
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        xrange = x[-1] - x[0]
        yrange = np.log(y[0] / y[-1])
        d = xrange / yrange
        A = y[0] / np.exp(-x[0] / d)
    return _finite_or_fallback((A, d))


def myexpo(x, A, d):
    """Exponential decay model, y = A exp(-x/d)."""
    return A * np.exp(-x / d)


myexpo.tex = r"$y = A e^{-x/d}$"
myexpo.params = [f"${x}$" for x in r"A;d".split(';')]
myexpo.p0 = myexpop0

# Set up power law fit to attempt a guess for initial values


def mypowerp0(x, y):
    """Estimate (A, n) for y = A x**n from the two end points.

    After sorting by x, n = ln(y_first / y_last) / ln(x_first / x_last)
    and A = y_first / x_first**n. Non-finite estimates (e.g. x values of
    mixed sign or repeated end points) are replaced by ``(1.0, 1.0)``.
    """
    x, y = sort_data(x, y)
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        xrange = np.log(x[0] / x[-1])
        yrange = np.log(y[0] / y[-1])
        n = yrange / xrange
        A = y[0] / x[0] ** n
    return _finite_or_fallback((A, n))


def mypower(x, A, n):
    """Power-law model, y = A x**n."""
    return A * x ** n


mypower.tex = r"$y = A x^n$"
mypower.params = [f"${x}$" for x in r"A;n".split(';')]
mypower.p0 = mypowerp0


# Plot styling: CAPSIZE is presumably the error-bar cap size used by
# Fit.plot (its use is not visible here); DATACOLOR colours the data
# points and FITCOLOR the fitted curve.
CAPSIZE, DATACOLOR, FITCOLOR = 3, 'r', 'b'


class FreeFitForm(forms.Form):
    """Form that takes pasted tabular data (x, y[, y_unc]) plus fit options."""

    data = forms.CharField(
        widget=forms.Textarea,
        min_length=30, max_length=500,
        help_text='Paste data with columns (x, y, y_unc) separated by tabs or commas'
    )
    kind = forms.ChoiceField(choices=function_choices)
    options = forms.CharField(
        max_length=255,
        min_length=0,
        strip=True,
        required=False,
        help_text='optional key-value pairs to format the plot',
    )

    def dollar(self, x: str):
        """
        Screen a column heading for use as a (possibly LaTeX) axis label.

        Raises ValidationError if the heading has an odd number of ``$``
        signs, unbalanced braces, or parses as a number (meaning the
        first row held data rather than labels).
        """
        if x.count('$') % 2 != 0:
            raise ValidationError(
                _('Imbalanced $ in “%s”') % x, code='dollars'
            )
        if x.count('{') != x.count('}'):
            raise ValidationError(
                _('Imbalanced braces in “%s”') % x, code='braces'
            )
        try:
            float(x)
        except (TypeError, ValueError):
            return  # not a number, so it is a usable label
        raise ValidationError(
            _('The first row of the CSV file must hold axes labels'),
            code='labels',
        )

    def clean(self):
        """
        Parse the pasted text into a table and stash derived values.

        On success cleaned_data gains 'text' (the pasted text with
        normalized newlines), 'header', 'xlabel', 'ylabel', 'x', 'y',
        'yunc' (None for two columns), 'data' (the table as a NumPy
        array), 'source' (the text again) and 'file' (None; there is no
        upload). Keeping the text lets the page redisplay it so the
        user does not have to paste it again.
        """
        cd = self.cleaned_data
        raw = cd.get('data')
        if raw is None:
            # field-level validation (required/min/max length) already
            # failed and recorded its own error
            return cd
        data = raw.replace('\r\n', '\n')
        cd['text'] = data  # store original text data

        # choose the separator: commas win over tabs, tabs over spaces
        if ',' in data:
            sep = ','
        elif '\t' in data:
            sep = '\t'
        else:
            sep = ' '
        try:
            df = pd.read_csv(io.StringIO(data), sep=sep)
        except Exception as err:  # any parse failure becomes a form error
            raise ValidationError(
                _('Could not load data file; confirm that it is a valid CSV file'),
                code='file',
            ) from err

        for n, dt in enumerate(df.dtypes):
            if dt not in (np.dtype('float64'), np.dtype('int64')):
                raise ValidationError(
                    _('Column %d has one or more nonnumeric values') % n,
                    code='columns',
                )

        # The table must have either two or three columns
        cols = df.shape[1]
        if not 2 <= cols <= 3:
            raise ValidationError(
                _('The data table must have 2 or 3 columns, not %d') % cols,
                code='columns',
            )

        # make sure the column heads make sense
        cd['header'] = heads = list(df.columns)
        for h in heads[:2]:
            self.dollar(h)
        cd['xlabel'], cd['ylabel'] = heads[0], heads[1]
        cd['x'] = df.iloc[:, 0].to_numpy()
        cd['y'] = df.iloc[:, 1].to_numpy()
        cd['yunc'] = df.iloc[:, 2].to_numpy() if cols == 3 else None
        cd['data'] = df.to_numpy()
        cd['source'] = data
        cd['file'] = None  # parallels FitForm, which stores the upload here
        return cd


class FitForm(forms.Form):
    """Form that takes an uploaded CSV file (x, y[, y_unc]) plus fit options."""

    data = forms.FileField(
        help_text='Upload a CSV file with columns x,y,y_uncertainty',
        allow_empty_file=True
    )
    kind = forms.ChoiceField(choices=function_choices)
    options = forms.CharField(
        max_length=255,
        min_length=0,
        strip=True,
        required=False,
        help_text='optional key-value pairs to format the plot',
    )

    def dollar(self, x: str):
        """
        Screen a column heading for use as a (possibly LaTeX) axis label.

        Raises ValidationError if the heading has an odd number of ``$``
        signs, unbalanced braces, or parses as a number (meaning the
        first row held data rather than labels).
        """
        if x.count('$') % 2 != 0:
            raise ValidationError(
                _('Imbalanced $ in “%s”') % x, code='dollars'
            )
        if x.count('{') != x.count('}'):
            raise ValidationError(
                _('Imbalanced braces in “%s”') % x, code='braces'
            )
        try:
            float(x)
        except (TypeError, ValueError):
            return  # not a number, so it is a usable label
        raise ValidationError(
            _('The first row of the CSV file must hold axes labels'),
            code='labels',
        )

    def clean(self):
        """
        Read the uploaded CSV into a table and stash derived values.

        On success cleaned_data gains 'file' (the upload), 'header',
        'xlabel', 'ylabel', 'x', 'y', 'yunc' (None for two columns),
        'data' (the table as a NumPy array) and 'source' (the decoded
        text). The text is what process_form carries into a FreeFitForm,
        so the user does not have to upload the file again.
        """
        cd = self.cleaned_data
        data = cd.get('data')
        if data is None:
            # the file field already failed validation with its own error
            return cd
        cd['file'] = data  # store file info
        source = cd.get('source', "")
        if source:
            setattr(data, 'file', io.StringIO(source))
        else:
            raw = data.file.read()
            data.file.seek(0)  # leave the upload readable for later use
            try:
                # utf-8-sig also drops a leading byte-order mark, which
                # spreadsheet exports often add and which would otherwise
                # be carried into the redisplayed text
                source = raw.decode('utf-8-sig')
            except UnicodeDecodeError as err:
                raise ValidationError(
                    _('The uploaded file must be UTF-8 encoded text'),
                    code='encoding',
                ) from err
        # set the separator
        sep = '\t' if '\t' in source else ','
        try:
            df = pd.read_csv(io.StringIO(source), sep=sep)
        except Exception as err:  # any parse failure becomes a form error
            raise ValidationError(
                _('Could not load data file; confirm that it is a valid CSV file'),
                code='file',
            ) from err

        for n, dt in enumerate(df.dtypes):
            if dt not in (np.dtype('float64'), np.dtype('int64')):
                raise ValidationError(
                    _('Column %d has one or more nonnumeric values') % n,
                    code='columns',
                )

        # The csv must have either two or three columns
        cols = df.shape[1]
        if not 2 <= cols <= 3:
            raise ValidationError(
                _('The CSV file must have 2 or 3 columns, not %d') % cols,
                code='columns',
            )

        # make sure the column heads make sense
        cd['header'] = heads = list(df.columns)
        for h in heads[:2]:
            self.dollar(h)
        cd['xlabel'], cd['ylabel'] = heads[0], heads[1]
        cd['x'] = df.iloc[:, 0].to_numpy()
        cd['y'] = df.iloc[:, 1].to_numpy()
        cd['yunc'] = df.iloc[:, 2].to_numpy() if cols == 3 else None
        cd['data'] = df.to_numpy()
        cd['source'] = source
        return cd


def _parse_option_value(text):
    """
    Convert one option value taken from the form's text.

    "a;b;c" becomes a list of floats (left as text if any piece is not a
    number); otherwise try int, then float, and fall back to the text.
    """
    if ';' in text:
        try:
            return [float(piece) for piece in text.split(';')]
        except ValueError:
            return text
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            pass
    return text


def _parse_options(options):
    """
    Split "key=value, key=value" into ``(parsed, raw)`` dictionaries.

    Fields that do not contain exactly one "=" are ignored. ``parsed``
    maps keys to values converted by _parse_option_value; ``raw`` keeps
    the stripped text, which matters for digit strings such as
    hold=0110 whose leading zeros int() would drop.
    """
    parsed, raw = {}, {}
    for field in (options or '').split(','):
        pieces = field.split('=')
        if len(pieces) != 2:
            continue
        key, value = pieces[0].strip(), pieces[1].strip()
        raw[key] = value
        parsed[key] = _parse_option_value(value)
    return parsed, raw


def process_form(fitform):
    """
    Validate ``fitform``, run the requested fit and render the figure.

    Returns the template context: always 'fitform' and 'picformat'; on
    success also 'figure' (URL-quoted base64 image bytes), 'header' and
    'data'. A valid FitForm is replaced in the context by a FreeFitForm
    pre-filled with the uploaded text so the user can edit it next time.
    If plotting fails, the error is attached to the form being returned
    and 'data' is set to the uploaded file (None for pasted data).
    """
    import urllib.parse  # `import urllib` alone does not load the submodule

    result = dict(fitform=fitform, picformat='png')
    if not fitform.is_valid():
        return result

    cd = fitform.cleaned_data
    if isinstance(fitform, FitForm):
        result['fitform'] = FreeFitForm({
            'data': cd['source'],
            'kind': cd['kind'],
            'options': cd['options'],
        })
    x, y, yunc = cd['x'], cd['y'], cd['yunc']
    kind = cd['kind']
    opts, raw_opts = _parse_options(cd['options'])

    mapper = dict(
        data=None, constant=myconst, linear=myline, quadratic=myquad,
        exponential=myexpo, power=mypower
    )
    func = mapper[kind]

    # the p0 attribute of a function can either be a tuple of
    # initial guesses or a function taking the (x,y) data that
    # provides initial guesses
    p0 = None
    if func is not None and hasattr(func, "p0"):
        p0 = func.p0 if isinstance(func.p0, tuple) else func.p0(x, y)
    tex = func.params if func is not None else ""

    p0 = opts.get('p0', p0)  # override guesses if supplied
    if isinstance(p0, (int, float)):
        p0 = [p0]  # a single number is the guess for a one-parameter model
    # hold and mask are digit strings such as "0110": use the raw text so
    # leading zeros survive (int("0110") would give 110)
    hold = raw_opts.get("hold")
    mask = raw_opts.get("mask", "")
    f = Fit(func, x, y, p0=p0, yunc=yunc, tex=tex, mask=mask, hold=hold)

    # user options may override the default labels instead of clashing
    plot_kwargs = dict(xlabel=cd['xlabel'], ylabel=cd['ylabel'], tex=tex)
    plot_kwargs.update(opts)
    try:
        f.plot(**plot_kwargs)

        # convert plot to an image
        # NOTE(review): dpi (and any figsize option) comes straight from
        # user input; very large values could exhaust server memory.
        buff = io.BytesIO()
        picformat = opts.get('format', 'png')
        f.fig.savefig(buff, format=picformat, dpi=opts.get('dpi', 150),
                      pad_inches=opts.get('pad_inches', 0.4))

        result['figure'] = urllib.parse.quote(
            base64.b64encode(buff.getvalue()))
        result['header'] = cd['header']
        result['data'] = cd['data']
        result['picformat'] = picformat
    except Exception as oops:  # show any plotting failure to the user
        # attach the error to the form the template will actually render
        result['fitform'].add_error('data', str(oops))
        result['data'] = cd.get('file')  # put back the file information
    finally:
        plt.close('all')  # never leak figures, even when plotting fails
    return result


def fit(request):
    """View for the CSV-upload version of the fitter.

    A POST is validated and fitted by process_form; any other request
    gets a blank FitForm. The template receives 'usefile', True only for
    the initial blank form (presumably it toggles the upload widget).
    """
    if request.method == 'POST':
        result = process_form(FitForm(request.POST, request.FILES))
        result['usefile'] = False
    else:
        result = dict(fitform=FitForm(), usefile=True)  # blank form
    return render(request, 'igor/fitform2.html', result)


def fit2(request):
    """View for the paste-your-data version of the fitter.

    POSTs are validated and fitted by process_form; other requests get
    a blank FreeFitForm. 'usefile' is always False for this view.
    """
    if request.method != 'POST':
        result = dict(fitform=FreeFitForm())
    else:
        result = process_form(FreeFitForm(request.POST, request.FILES))

    result['usefile'] = False
    return render(request, 'igor/fitform2.html', result)
# -----------------------------------------------------------
#
# Fitting code
#
# ------------------------------------------------------------


import inspect
import re
from scipy.optimize import curve_fit, OptimizeWarning
from scipy.odr import *
from scipy.stats import chi2
from matplotlib import rcParams


class Fit:
    """
    Base class that performs (non)linear least-squares fits and
    can make carefully formatted plots of the results.

    The function to apply may either be specified by passing in
    appropriate keyword arguments or by subclassing this class and
    providing the requisite functions as methods of the subclass.

    By default, a Fit object immediately attempts to run the fit with
    the information passed to the constructor. If that fails, or if the
    information is insufficient, you can see what the data look like and
    the calculated curve with the passed values for p0 (or the values generated
    by the estimate_p0 function) by calling plot on the resulting object.
    If you want to start the process without attempting a fit, but just
    seeing the plot with the curve generated by values in p0, pass the
    key-value pair adjust=True to the constructor. You can

    These requisite functions/methods are
    - function(x:np.ndarray, *params, **kwargs):
    - tex_f: (optional) a string or callable that returns a string providing
    a representation of the function in LaTeX.
    - estimate_p0(**kwargs): (optional) a function that takes a dictionary
    of keyword arguments, including the instance of this Fit class with the
    name 'self' and returns a vector of initial values for the
    fit parameters. The function may avail itself of data fields in
    the Fit object, including x, y.


    should be subclassed to implement a particular
    (nonlinear) fitting function. This base class provides all of the
    smarts, leaving to the subclass just the essentials particular to
    the chosen function. The subclass should implement the following
    fields and methods:

    - function(self, x:np.ndarray, *params), where self is an instance of a
        subclass of Fit
    - tex_f, a string or @property callable returning a string providing
        a representation of the function in LaTeX. If not defined,
        a crude approximation is generated by introspection of the function.
    - __str__(self), to produce a string representation of the fit
    - __init__(self, x:np.ndarray, y:np.ndarray, **kwargs), constructor
        that must call super().__init__(x, y, p0, tex, **kwargs), where
        p0 is an initial guess for the fitting parameters and
        tex is a list of strings with a LaTeX version of the names
        of the fit parameters

    If the fit is successful, the routine self.after_fit() is called; the
        default routine does nothing, but this would be an opportunity to
        adjust parameters, such as making sure that a parameter that
        enters quadratically has a positive value. In addition, the
        following fields are set:

    - self.valid is set to True
    - self.params is set to the list of optimized fitting parameters

    If yunc holds valid uncertainties, then
    - self.param_uncs is set to the list of parameter uncertainties
    - self.chisq is set
    - self.dof is set to the number of degrees of freedom of the fit
    - self.prob_greater indicates the probability that a greater value
        of chi-squared on repeating the experiment


    Optional keyword inputs:
    - yunc:np.ndarray, an array of uncertainties in the dependent variable
    - xunc:np.ndarray, an array of uncertainties in the independent variable
    - hold:str, a string of the form "01100", where "0" means the variable is
        optimized by the fitting routine and "1" means that it is held at the
        value given in p0
    - tex:[str], a list of strings providing LaTeX code to represent the
        fit parameters"""

    def __init__(self, function, x, y, p0, **kwargs):
        """
        kwargs may include hold="1011", where each 0 corresponds
        to a parameter allowed to vary and 1 to a value held fixed.
        If the function passed has a .tex attribute, it is assumed
        to be a string that can be passed to tex to describe the
        function in the fit annotation. Subclasses may define a
        string variable tex_f for this purpose.

        """
        self.function = function
        self.x = np.array(x)
        self.y = np.array(y)
        assert len(self.x) == len(
            self.y
        ), "The x and y arrays must have the same length"
        self.p0 = p0

        if not hasattr(self, "function_tex"):
            if hasattr(self.function, "tex"):
                self.function_tex = self.function.tex
            else:
                self.function_tex = kwargs.get('function_tex')

        self.error = None  # reason for failure of fit
        self.xunc = None  # standard error for independent variable
        self.yunc = None  # standard error for dependent variable

        self.xerrors = (
            None  # will be x errors if we consider x uncertainties
        )
        self.params = []
        self.param_uncs = []
        self.weighting = ""
        self.hold = kwargs.get("hold")
        try:
            self.dof = len(x) - len(self.p0)  # assumes nothing held
        except:
            pass
        # tex representation of variables
        self.tex_labels = kwargs.get('tex')
        self.chisq = None
        self.prob_greater = None
        self.axdata = None  # axis for plotting the data
        self.axyresiduals = None  # axis for y residuals
        self.axynorm = None  # axis for normalized y residuals
        self.axxresiduals = None  # axis for x residuals
        self.axxnorm = None  # axis for normalized x residuals
        self.top_axis = None  # use this one for the figure title

        # Handle the hold parameter. If it is nontrivial,
        # we need to use a modified function that adjusts the
        # p0 vector and the function definition to account only for
        # the parameters allowed to vary.

        try:
            # set the uncertainties
            for unc in ('yunc', 'xunc'):
                try:
                    uncs = kwargs[unc]
                except:
                    setattr(self, unc, None)
                    continue
                if isinstance(uncs, (float, int)):
                    uncs = np.ones(len(x)) * uncs
                else:
                    try:
                        uncs = np.array(uncs)
                    except:
                        uncs = None
                setattr(self, unc, uncs)
                assert len(uncs) == len(x)

            if isinstance(self.yunc, np.ndarray):
                self.weighting = (
                    "xy" if isinstance(self.xunc, np.ndarray) else "y"
                )
            if 'adjust' in kwargs or self.function == None:
                self.plot()
            else:
                self.run_fit(self.p0)
                self.after_fit()
        except Exception as eeps:
            self.error = str(eeps)

    def fhold(self, x, *params):
        """ """
        p = self.parameters(params)
        return self.function(x, *p)

    def run_fit(self, p0):
        """Run or re-run the fitting procedure, starting from
        the passed parameter values.
        """
        self.p0 = list(p0)  # update the initial parameter values

        # remove any remnants of a previous fitting procedure
        fields = "error;xerrors;yerrors;chisq;prob_greater;norm_yresiduals;reduced_chisq"
        fields += ";params;param_uncs"
        for f in fields.split(';'):
            setattr(self, f, None)

        if self.hold:
            if not isinstance(self.hold, str):
                raise "hold must be a string with 1 for fixed, 0 for variable"
            if len(self.hold) != len(self.p0):
                raise "The number of digits in hold must match the number of parameters"

            # We need to copy the parameters held fixed and respect the ones
            # being varied. The variable field holds the indices of the parameters
            # that are allowed to vary.

            self.variable = [
                x for x in range(len(p0)) if self.hold[x] == "0"
            ]
            self.dof = len(self.x) - len(self.variable)

            # def fhold(t, *params):
            # """This function is used when some parameters are held,
            # to reconstitute a complete parameter vector to use with
            # the fitting function."""
            # p = self.parameters(params)
            # return self.function(t, *p)

            f = lambda x, *p: self.fhold(x, *p)
            p0 = [self.p0[n] for n in self.variable]
        else:
            # f = self.function
            f = lambda x, *p: self.function(x, *p)

        try:
            if isinstance(self.xunc, np.ndarray):
                self.run_odr(f)
            else:
                self.params, self.covars = curve_fit(
                    f,
                    self.x,
                    self.y,
                    p0=p0,
                    sigma=self.yunc,
                    absolute_sigma=True,
                )

        except OptimizeWarning:
            raise "Failed to converge"

        if np.inf in self.covars or np.nan in self.covars:
            self.error = "Failed to converge"
        else:
            # adjust for held parameters, if necessary
            self.params = self.parameters(self.params)
            self.yresiduals = self.y - self.function(self.x, *self.params)
            if 'y' in self.weighting:
                self.norm_yresiduals = self.yresiduals / self.yunc
                self.chisq = np.sum(self.norm_yresiduals ** 2)

                # x errors?
                if 'x' in self.weighting:
                    self.norm_xresiduals = self.xresiduals / self.xunc
                    # What do we do about chisq?????
                    # The following is wrong; it needs to account for the slope
                    m = self._slope
                    unc = np.sqrt(
                        np.power(self.yunc, 2) + np.power(m * self.xunc, 2)
                    )
                    self.norm_residuals = self.yresiduals / unc
                    self.chisq = np.sum(self.norm_residuals ** 2)

                self.reduced_chisq = self.chisq / self.dof
                self.prob_greater = 1 - chi2.cdf(self.chisq, self.dof)

            # calculate the errors in the parameter estimations
            errs = np.sqrt(np.diag(self.covars))
            if self.hold:
                self.param_uncs = np.zeros(self.params.size)
                for n, pnum in enumerate(self.variable):
                    self.param_uncs[pnum] = errs[n]
            else:
                self.param_uncs = errs

    def after_fit(self):
        """Override to perform any post-fit alteration of parameters"""
        pass

    def run_odr(self, f):
        """Run an orthogonal distance regression to handle errors along x"""
        data = Data(
            self.x,
            self.y,
            wd=1.0 / np.power(self.xunc, 2),
            we=1.0 / np.power(self.yunc, 2),
        )

        # an odr.Model assumes a call signature f(beta, x) -> y
        # where beta is the fitting parameters. The call signature
        # of f, however, is f(x, *params), so we need to remap.

        def f_odr(beta, x): return f(x, *beta)

        model = Model(f_odr)
        odr = ODR(data, model, self.p0)
        odr.run()
        self.odr = odr
        output = odr.output
        if output.info == 1:
            # fitting was successful
            self.params = output.beta
            self.covars = output.cov_beta
            self.yresiduals = output.eps
            self.xresiduals = output.delta

    def parameters(self, par):
        "Handle the subset of parameters that are being allowed to vary."
        if self.hold:
            p = self.p0
            for n, pnum in enumerate(self.variable):
                p[pnum] = par[n]
            return np.array(p)
        return par

    def __call__(self, t: np.ndarray):
        "Evaluate the fitting function using current parameter values."
        if self.function == None:
            return []
        if self.valid:
            return self.function(t, *self.params)
        return self.function(t, *self.p0)

    @property
    def _slope(self, dx=1e-6):
        "Return a numerical approximation  to the slope at each data pt"
        dy = self(self.x + dx) - self(self.x - dx)
        return dy / (2 * dx)

    def __str__(self):
        name = self.function.__name__
        args = inspect.getfullargspec(self.function)
        argnames = args.args[1:]
        lines = [
            name,
        ]
        has_unc = isinstance(self.yunc, np.ndarray)
        if self.valid:
            for n in range(len(self.params)):
                lines.append(f"{argnames[n]:>16s} = {self.params[n]:^8.4g}")
                if has_unc:
                    lines[-1] += f" ± {self.errors[n]:.2g}"
                    if self.param_uncs[n] > 0:
                        lines[
                            -1
                        ] += f" ({abs(self.errors[n] / self.params[n]) * 100:.2g}%)"
            if has_unc:
                lines.append(f"N_dof = {self.dof}, chisq = {self.chisq:.3g}")
                lines[-1] += f" ({self.reduced_chisq:.3g})"
                lines[-1] += f", P> = {100 * self.prob_greater:.2g}%"
        else:
            lines.append(self.error)
        return "\n".join(lines)

    def texval(self, x):
        """
        Render a string representation of a value in TeX format
        """
        if "e" not in x:
            return x
        m = re.search(r'([-+0-9.]*)\s?e\s?([+-])\s?([0-9]*)', x)
        if m:
            s = m.group(1) + r" \times 10^{"
            if m.group(2) == '-':
                s += '-'
            s += str(int(m.group(3))) + "}"
            return s
        print("Ack! for " + x)
        return x

    def tex_val_unc(self, x, dx):
        """
        Produce a LaTeX representation of the value and its uncertainty
        """
        if dx == 0.0:
            return (str(x), "", "")
        try:
            assert dx != 0.0 and x != 0.0
            xdigits = int(np.log10(abs(x)))
            dxdigits = int(np.log10(dx))
            digits = 2 + xdigits - dxdigits
            round_spot = -xdigits + digits
            xround = round(x, round_spot)
            dxround = round(dx, round_spot)

            ratio = dx / abs(x)  # this will fail if dx is 0
            # digits = 2 + int(np.round(np.log10(ratio), 2))
            fmt = "{:0." + str(round_spot) + "f}"
            main = self.texval(fmt.format(xround))
            # To get the right number of digits, we need to
            # figure out the place of the LSD of x
            unc = self.texval(fmt.format(dxround))
            if ratio > 1:
                rel = f"{ratio:.2f}"
            elif ratio > 0.001:
                rel = f"{100 * ratio:.1f}" + "\\%"
            else:
                rel = f"{int(1e6 * ratio)}" + r"\;\rm ppm"

            return (main, unc)  # , rel)
        except Exception as eeps:
            # print(f"tex_val_unc error {eeps} for {x}, {dx}")
            return (f"{x:.3g}", f"{dx:.2g}")

    def legend_tex(self, use_table=True):
        """Generate an annotation showing the fit function
        and the fitting parameters.
        """
        name = (
            self.function_tex
            if hasattr(self, 'function_tex') and self.function_tex
            else self.function.__name__
        )

        if self.tex_labels:
            argnames = self.tex_labels
        else:
            args = inspect.getfullargspec(self.function)
            argnames = args.args[1:]
        lines = [
            name,
        ]

        has_unc = isinstance(self.yunc, np.ndarray)
        if has_unc and self.valid:
            stats = [
                r"$\chi_\nu^2 = "
                + self.texval(f"{self.reduced_chisq:.3g}")
                + r"\;\; \mathrm{for}\;\; \nu = %d$" % self.dof,
                r"$P_> = " + f"{100 * self.prob_greater:.1f}" + r"\%$",
            ]
        if self.valid:
            if use_table:
                lines.insert(0, r"\begin{tabular}{lcc}\multicolumn{3}{c}{")
                lines.append(r"}\\[0.05in]")
                lines.append(
                    # & \textbf{Rel. Unc.}"
                    r"\textbf{Param} & \textbf{Value} & \textbf{Unc.}"
                )
            for n in range(len(self.params)):
                if use_table:
                    lines.append("\\\\ " + argnames[n] + " & ")
                    lines.append(
                        " & ".join(
                            [
                                f"${x}$"
                                for x in self.tex_val_unc(
                                    self.params[n], self.param_uncs[n]
                                )
                            ]
                        )
                    )

                else:
                    lines.append(f"{argnames[n]:>16s} = $")
                    v, u, r = self.tex_val_unc(
                        self.params[n], self.param_uncs[n]
                    )
                    if r == "":
                        vals = f"{v}$ (fixed)"
                    else:
                        vals = "{0} \\pm {1}\\; ({2})$".format(v, u, r)
                    lines[-1] += vals
            if has_unc:
                if use_table:
                    lines.append(r"\\[0.05in]\multicolumn{3}{c}{")
                    lines.append(stats[0] + r" \qquad " + stats[1] + "}")
                    lines.append(r"\end{tabular}")
                else:
                    lines.append(stats[0])
                    lines.append(stats[1])
            else:
                if use_table:
                    lines.append(r"\end{tabular}")

        else:
            lines.append(self.error)

        res = "\n".join(lines)
        # print(res)
        return res

    @property
    def valid(self):
        "Did the fit exit successfully?"
        return self.error is None and hasattr(self, 'covars')

    @property
    def error_scale(self):
        """
        By what factor errors need to grow to yield chisq/DoF = 1
        """
        if not self.valid or not isinstance(self.param_uncs, np.ndarray):
            return None
        return np.sqrt(self.chisq / self.dof)

    def plot(self, **kwargs):
        """
        Generate a plot showing the data, fit, residuals, and
        normalized residuals.

        Optional kwargs:

        - residuals (boolean) defaults to True when the fit is valid
        - normalized_residuals (boolean); currently has no effect because
          the normalized-residual panels are disabled (NORM = False below)
        - title (str)
        - xlabel (str)
        - ylabel (str)
        - figsize (width, height)
        - msize (marker size, in pts)
        - logx (boolean)
        - logy (boolean)
        - xmin (float)
        - xmax (float)
        - ymin (float)
        - ymax (float)
        - xfit (array-like of positions at which to compute the fitted curve)
        - yfit (array of precomputed fit values). If yfit is not supplied,
          the fitted curve will be evaluated at xfit values.
        - npoints (int) defaults to 200 and specifies the number of points
          to use in computing the fitted curve
        - legend (xfrac, yfrac) specifies the anchor position of the
          annotation, expressed as a fraction of the width and height of the
          plot area. Alternatively, can be a suitable combination of north,
          east, south, and west for position outside the plot area.
          Defaults to 'south' (the default applied by make_gspec).
        - table (boolean) passed on to make_gspec / legend_tex
        """
        yresiduals, normalized_yresiduals = False, False
        title = kwargs.get('title', '')
        # Normalized residual panels are switched off for now; set NORM to
        # True to re-enable them.
        NORM = False
        axes_order = []
        if self.valid:
            # See if we should plot residuals
            normalized_yresiduals = (
                kwargs.get('normalized_residuals', "y" in self.weighting)
                and NORM
            )
            if normalized_yresiduals:
                axes_order.append('axynorm')
            yresiduals = kwargs.get('residuals', True)
            if yresiduals:
                axes_order.append('axyresiduals')
        axes_order.append('axdata')
        # make_gspec lays out the figure from this list
        self.axes_order = axes_order

        self.make_gspec(**kwargs)
        self.top_axis.set_title(title)

        # Lay down the fit curve first, then plot the data

        logx = kwargs.get('logx', False)
        logy = kwargs.get('logy', False)
        xmin = kwargs.get('xmin', np.min(self.x))
        xmax = kwargs.get('xmax', np.max(self.x))
        ymin = kwargs.get('ymin', np.min(self.y))
        ymax = kwargs.get('ymax', np.max(self.y))

        xfit, yfit = kwargs.get('xfit'), kwargs.get('yfit')
        if xfit is not None:
            # accept lists/tuples as well as numpy arrays
            xfit = np.asarray(xfit)
        npoints = kwargs.get('npoints', 200)
        if logx:
            self._set_log_scale(self.axdata, 'x')
            if xfit is None:
                xfit = np.power(
                    10, np.linspace(np.log10(xmin), np.log10(xmax), npoints)
                )
        elif xfit is None:
            xfit = np.linspace(xmin, xmax, npoints)
        # Compare with None: ``not yfit`` raises for multi-element arrays
        if yfit is None:
            yfit = self.__call__(xfit)

        if logy:
            self._set_log_scale(self.axdata, 'y')

        if len(xfit) == len(yfit):
            self.axdata.plot(xfit, yfit, c=FITCOLOR, alpha=0.6)

        msize = kwargs.get('msize', 8 if len(self.x) < 24 else 6)
        # NOTE: errorbar's markersize is in points, whereas scatter's ``s``
        # is an area in points**2, so unweighted markers come out smaller.
        if self.weighting:
            self.axdata.errorbar(
                self.x,
                self.y,
                yerr=self.yunc,
                xerr=self.xunc,
                fmt='o',
                c=DATACOLOR,
                alpha=0.5,
                markersize=msize,
                capsize=3
            )
        else:
            self.axdata.scatter(self.x, self.y, alpha=0.5, s=msize,
                                c=DATACOLOR)

        if 'xlabel' in kwargs:
            self.axdata.set_xlabel(kwargs['xlabel'])
        if 'ylabel' in kwargs:
            self.axdata.set_ylabel(kwargs['ylabel'])

        # Now handle residuals

        if yresiduals:
            if logx:
                self._set_log_scale(self.axyresiduals, 'x')
            # plot the zero line
            self.axyresiduals.plot([xmin, xmax], [0, 0], 'k-', alpha=0.5)

            # Per-point colors are only computed for xy weighting
            colors = None
            if self.weighting == 'xy':
                colors = self.norm_res_colors('xy')
                self.axyresiduals.scatter(
                    self.x,
                    self.yresiduals,
                    c=colors,
                    marker='o',
                    s=5 ** 2,
                )
                ms = 1
            else:
                ms = 3

            # Unfortunately, the errorbars routine does not handle
            # colors-by-point, so its markers are shrunk when the colored
            # scatter is underneath
            self.axyresiduals.errorbar(
                self.x,
                self.yresiduals,
                yerr=self.yunc,
                ls='None',
                marker='o',
                ms=ms,
                c=DATACOLOR,
                capsize=3
            )
            self.axyresiduals.set_ylabel('Res.')

            # are there x residuals, too? Only some layouts create the axis.
            if 'x' in self.weighting:
                axx = self._axis_in_current_figure('axxresiduals')
                if axx is not None:
                    if logy:
                        self._set_log_scale(axx, 'y')
                    # plot the zero line
                    axx.plot([0, 0], [ymin, ymax], 'k-', alpha=0.5)
                    axx.scatter(
                        self.xresiduals,
                        self.y,
                        c=DATACOLOR if colors is None else colors,
                        s=5 ** 2,
                        marker='o',
                    )
                    axx.errorbar(
                        self.xresiduals,
                        self.y,
                        xerr=self.xunc,
                        marker='o',
                        ls='None',
                        ms=ms,
                        capsize=3
                    )
                    axx.set_xlabel('Res.')

        if normalized_yresiduals:
            if logx:
                self._set_log_scale(self.axynorm, 'x')
            self.axynorm.plot([xmin, xmax], [0, 0], 'k-', alpha=0.5)
            # Compute colors
            colors = self.norm_res_colors('y')
            self.axynorm.scatter(
                self.x, self.norm_yresiduals, s=5 ** 2, marker='o', c=colors
            )
            self.axynorm.set_ylabel('N.R.')

            if 'x' in self.weighting:
                ax = self._axis_in_current_figure('axxnorm')
                if ax is not None:
                    ax.plot([0, 0], [ymin, ymax], 'k-', alpha=0.5)
                    colors = self.norm_res_colors('x')
                    ax.scatter(
                        self.norm_xresiduals,
                        self.y,
                        s=5 ** 2,
                        marker='o',
                        c=colors,
                    )
                    ax.set_xlabel('N.R.')

        fig = self.fig
        fig.align_ylabels(fig.axes)
        # Only some layouts record a gridspec; make sure it belongs to this
        # figure rather than to an earlier plot before tightening it.
        gs = getattr(self, 'gs', None)
        if gs is not None and getattr(gs, 'figure', None) is fig:
            try:
                gs.tight_layout(fig)
                gs.update(hspace=0.06, wspace=0.04)
            except Exception:
                # Layout tuning is cosmetic; don't let it break the plot.
                pass

        # Handle any explicitly designated axis limits
        xlim = (kwargs.get('xmin'), kwargs.get('xmax'))
        if xlim[0] is not None or xlim[1] is not None:
            self.axdata.set_xlim(*xlim)
        ylim = (kwargs.get('ymin'), kwargs.get('ymax'))
        if ylim[0] is not None or ylim[1] is not None:
            self.axdata.set_ylim(*ylim)

    @staticmethod
    def _set_log_scale(ax, axis):
        """
        Put the 'x' or 'y' axis of ``ax`` on a log scale, clipping
        non-positive values.

        Matplotlib 3.3 renamed ``nonposx``/``nonposy`` to ``nonpositive``
        and later releases dropped the old names; older releases reject
        the new name, so try it first and fall back.
        """
        setter = ax.set_xscale if axis == 'x' else ax.set_yscale
        try:
            setter('log', nonpositive='clip')
        except (TypeError, ValueError):
            setter('log', **{f'nonpos{axis}': 'clip'})

    def _axis_in_current_figure(self, name):
        """
        Return the axes stored as attribute ``name`` if it belongs to the
        current figure (``self.fig``), otherwise None.
        """
        ax = getattr(self, name, None)
        if ax is None or getattr(ax, 'figure', None) is not getattr(
            self, 'fig', None
        ):
            return None
        return ax

    def make_gspec(self, **kwargs):
        """
        Figure out how to lay out the axes of the plot, and create them.

        Relevant issues:
          - do we have residuals and normalized residuals in y
          - in x?
          - is the legend outside the plot area?

        Reads ``self.axes_order`` (a list of axis attribute names) and may
        insert ``'axlegend'`` into it. Creates ``self.fig``, one attribute
        per created axis (``axdata``, ``axyresiduals``, ...), and sets
        ``self.top_axis``. ``self.gs`` is set only for north/south legends
        and reset to None otherwise.

        Note: this updates matplotlib's global ``rcParams``.

        Optional kwargs:
          - legend: 'north', 'south', 'east' or 'west' (optionally followed
            by a secondary direction, which is recorded in ``self._legend``)
            to place the parameter summary outside the plot area, or an
            (xfrac, yfrac) pair to annotate inside the data axes.
            Defaults to 'south'.
          - figsize: (width, height) in inches.
          - table: passed to ``legend_tex``; defaults to True.

        Raises:
          TypeError: if ``legend`` is neither a string nor a sequence.
        """
        # Will we be putting the legend inside the plot area or outside?
        legend = kwargs.get('legend', 'south')
        if isinstance(legend, (list, np.ndarray)):
            # accept [xfrac, yfrac] as well as (xfrac, yfrac)
            legend = tuple(legend)
        figsize = kwargs.get('figsize', rcParams["figure.figsize"])
        use_table = kwargs.get('table', True)

        # Adjust plot defaults
        rcParams["font.family"] = "serif"
        rcParams["font.size"] = 12.0
        rcParams["text.usetex"] = True
        rcParams["xtick.top"] = True
        rcParams["xtick.direction"] = "in"
        rcParams["ytick.right"] = True
        rcParams["ytick.direction"] = "in"
        rcParams["savefig.transparent"] = True
        # Read the font size *after* setting it so the legend-height
        # estimate uses the size the legend is actually drawn at.
        fontsize = rcParams['font.size']

        # Only the north/south layout records a gridspec; drop any left
        # over from a previous figure so it is never reused.
        self.gs = None

        axes_order = self.axes_order
        width, height = figsize
        # make a preliminary dictionary
        gspec = dict(
            width_ratios=[1],
            height_ratios=[1],
            hspace=0.025,
            wspace=0.025,
            left=0.125,
            right=0.975,
        )

        # If we don't have a valid fit, make a simple plot
        if not self.valid:
            self.fig, self.axdata = plt.subplots(
                figsize=figsize, gridspec_kw=gspec
            )
            self.top_axis = self.axdata
            return

        # we will have a legend describing fitting parameters
        msg = self.legend_tex(use_table)
        # rough estimate of the legend height in inches
        num_lines = 4.5 + len(self.params)
        msg_height = num_lines * fontsize * 1.5 / 72  # should be improved

        if "tabular" in msg:
            # presumably the tabular must be a single line for usetex
            # rendering -- TODO confirm
            msg = msg.replace("\n", "")

        self._legend = dict(
            text=msg,
            num_lines=num_lines,
            msg_height=msg_height,
            position=legend,
            x=0.1,
            y=0.1,
            ha='center',
            va='center',  # 'middle' is not a valid matplotlib alignment
            primary=None,
            secondary=None,
        )
        leg = self._legend

        # The opts dictionary will get passed to a call to annotate
        # in the axes used for the legend
        opts = dict(
            xy=(0.5, 0.5), xycoords='axes fraction', va='center', ha='center'
        )
        legend_padding = " \n \n"

        if isinstance(legend, tuple):
            # a legend position of the form (x,y) puts the legend at that
            # fraction of the plot area. This part probably needs work
            opts['xy'] = legend
            if len(axes_order) == 3:
                gspec['height_ratios'] = [1, 1, 4]
            elif len(axes_order) == 2:
                gspec['height_ratios'] = [1.5, 4]

            self.fig, axes = plt.subplots(
                nrows=len(axes_order),
                sharex=True,
                figsize=figsize,
                gridspec_kw=gspec,
                squeeze=False,
            )
            # squeeze=False always yields a 2-D array, even for one row
            self.axes = axes[:, 0]
            for ax in ['axdata', 'axyresiduals', 'axynorm']:
                if ax in axes_order:
                    setattr(self, ax, self.axes[axes_order.index(ax)])

            self.top_axis = self.axes[0]
            self.axdata.annotate(msg, **opts)

        elif isinstance(legend, str):
            # the legend will be placed outside the plot area
            # using frameon=False
            fields = [
                x.strip('"\'') for x in legend.lower().split()
            ] or ['']
            leg['primary'] = fields[0]
            leg['secondary'] = None if len(fields) == 1 else fields[1]

            if leg['primary'] in ('north', 'south'):
                # extend the plot height to accommodate the
                # extra axes, which will hold the annotation
                self.fig = plt.figure(
                    figsize=(width, height + msg_height + 0.25),
                )
                ncols = 1

                if leg['primary'] == 'north':
                    axes_order.insert(0, 'axlegend')
                    msg += legend_padding
                else:
                    axes_order.append('axlegend')
                    msg = legend_padding + msg

                if len(axes_order) == 3:
                    # no normalized y residuals
                    gspec['height_ratios'] = [
                        1,
                        4,
                        msg_height / 4.5 * height,
                    ]
                    if "x" in self.weighting:
                        ncols = 2
                        gspec['width_ratios'] = [4, 1]
                elif len(axes_order) == 4:
                    # normalized y residuals, too
                    gspec['height_ratios'] = [
                        1.5,
                        1.5,
                        4,
                        msg_height * 7 / height,
                    ]
                    if "x" in self.weighting:
                        ncols = 3
                        gspec['width_ratios'] = [4, 1, 1]
                else:
                    # no residuals
                    gspec['height_ratios'] = [4, msg_height / 4.25 * height]

                # if the legend is in the north, we need to roll 1 position
                if leg['primary'] == 'north':
                    gspec['height_ratios'] = np.roll(
                        gspec['height_ratios'], 1
                    )
                gs = self.fig.add_gridspec(
                    ncols=ncols, nrows=len(axes_order), **gspec
                )
                self.gs = gs
                gs.update(hspace=0.02)
                ndata = axes_order.index('axdata')
                self.axdata = self.fig.add_subplot(gs[ndata, 0])

                for n, which in enumerate(axes_order):
                    if which in ('axdata', 'axlegend'):
                        continue
                    a = self.fig.add_subplot(gs[n, 0], sharex=self.axdata)
                    # suppress x ticks on residuals
                    plt.setp(a.get_xticklabels(), visible=False)
                    plt.subplots_adjust(wspace=0, hspace=0)
                    setattr(self, which, a)
                    if n == 0:
                        self.top_axis = a

                # if there are x errors, add those plots in the extra
                # columns; column indices run from 1 to ncols - 1
                if ncols > 1:
                    for n, ax in zip([1, 2], ['axxresiduals', 'axxnorm']):
                        if n >= ncols:
                            break
                        a = self.fig.add_subplot(
                            gs[ndata, n], sharey=self.axdata
                        )
                        plt.setp(a.get_yticklabels(), visible=False)
                        plt.subplots_adjust(wspace=0, hspace=0)
                        setattr(self, ax, a)

                # add the legend
                self.axlegend = self.fig.add_subplot(
                    gs[axes_order.index('axlegend'), 0]
                )
                self.axlegend.axis('off')
                self.axlegend.annotate(msg, **opts)
            else:
                self.fig = plt.figure(
                    figsize=(figsize[0] + 2.25, figsize[1])
                )
                legwidth = 6
                gspec['width_ratios'] = (
                    [width, legwidth]
                    if leg['primary'] == 'east'
                    else [legwidth, width]
                )
                if len(axes_order) == 2:
                    gspec['height_ratios'] = [1, 4]
                elif len(axes_order) == 3:
                    gspec['height_ratios'] = [1, 1, 4]
                gs = self.fig.add_gridspec(
                    ncols=2, nrows=len(axes_order), **gspec
                )

                col = 0 if leg['primary'] == 'west' else 1
                # set the legend
                self.axlegend = self.fig.add_subplot(gs[:, col])
                self.axlegend.annotate(msg, **opts)
                self.axlegend.set_axis_off()

                # To manage the axis sharing properly, we need to generate
                # the data plot first
                dindex = axes_order.index('axdata')
                self.axdata = self.fig.add_subplot(gs[dindex, 1 - col])

                # add axes to the object
                for n, which in enumerate(axes_order):
                    if which in ('axdata', 'axlegend'):
                        continue
                    a = self.fig.add_subplot(
                        gs[n, 1 - col], sharex=self.axdata
                    )
                    setattr(self, which, a)
                    plt.setp(a.get_xticklabels(), visible=False)
                    if n == 0:
                        self.top_axis = a
        else:
            raise TypeError(
                f"legend must be a direction string or an (x, y) pair, "
                f"not {type(legend).__name__}"
            )

        # make sure the top axis is set
        topper = axes_order[0]
        a = getattr(self, topper)
        self.top_axis = a

    def fit_results(self, fig, xbounds, ybounds, **kwargs):
        """
        Write the fit summary produced by ``self.legend_tex()`` onto ``fig``.

        Parameters:
          fig: the figure to annotate.
          xbounds, ybounds: (min, max) pairs in data coordinates, used to
            position the text when ``legend`` is given as fractions.
          legend (kwarg): either 'north' or 'south', optionally followed by
            'east' or 'west'. In that case the figure is made taller and the
            text goes in the freed strip above or below the axes. Otherwise
            an (xfrac, yfrac) pair places the text inside the last axes at
            that fraction of the bounds. Defaults to (0.1, 0.1).

        Raises:
          ValueError: if ``legend`` is a string whose first word is not
            'north' or 'south'.
        """
        legend = kwargs.get('legend', (0.1, 0.1))
        msg = self.legend_tex()
        num_lines = len(msg.split('\n'))
        if isinstance(legend, str):
            # the legend string should be a cardinal direction
            fields = legend.lower().split()
            primary = fields[0] if fields else None
            secondary = fields[1] if len(fields) > 1 else None
            if primary not in ('north', 'south'):
                raise ValueError(
                    f"unsupported legend position {legend!r}; "
                    "expected 'north' or 'south'"
                )

            legend_height = num_lines / 3  # rough estimate, in inches
            _, height = fig.get_size_inches()
            new_height = height + legend_height
            fig.set_figheight(new_height)
            strip = legend_height / new_height

            if primary == 'south':
                # keep the axes out of a strip at the bottom of the figure
                fig.tight_layout(rect=(0, strip, 1, 1))
                ypos, valign = 0, 'bottom'
            else:
                # keep the axes out of a strip at the top of the figure
                fig.tight_layout(rect=(0, 0, 1, 1 - strip))
                ypos, valign = 1, 'top'

            if secondary == 'west':
                xpos, halign = 0, 'left'
            elif secondary == 'east':
                xpos, halign = 1, 'right'
            else:
                xpos, halign = 0.5, 'center'

            fig.text(
                xpos,
                ypos,
                msg,
                horizontalalignment=halign,
                verticalalignment=valign,
            )
        else:
            # interpolate the fractional position into data coordinates
            xmin, xmax = xbounds
            ymin, ymax = ybounds
            xpos = (1 - legend[0]) * xmin + legend[0] * xmax
            ypos = (1 - legend[1]) * ymin + legend[1] * ymax
            halign = (
                'left'
                if legend[0] < 0.3
                else ('right' if legend[0] > 0.7 else 'center')
            )
            valign = (
                'bottom'
                if legend[1] < 0.3
                else ('top' if legend[1] > 0.7 else 'center')
            )

            ax = fig.axes[-1]
            ax.text(
                xpos,
                ypos,
                msg,
                horizontalalignment=halign,
                verticalalignment=valign,
            )

    def norm_res_colors(self, kind='y'):
        """
        Return one color per point encoding the size of the normalized
        residuals.

        kind: 'y' uses ``self.norm_yresiduals``, 'x' uses
        ``self.norm_xresiduals``, anything else uses ``self.norm_residuals``.

        |residual| is truncated to a whole number of sigmas and capped at 3.
        Non-finite values (e.g. from a zero uncertainty) count as 3 sigma.
        The mapping to colors is done by ``_res_colors``.
        """
        if kind == 'y':
            norm_res = self.norm_yresiduals
        elif kind == 'x':
            norm_res = self.norm_xresiduals
        else:
            norm_res = self.norm_residuals
        magnitude = np.abs(np.asarray(norm_res, dtype=float))
        # Clip *before* the integer cast: casting inf/NaN or very large
        # floats straight to uint16 is undefined and can wrap to 0 (green).
        magnitude = np.where(np.isfinite(magnitude), magnitude, 3.0)
        sigmas = np.minimum(magnitude, 3.0).astype(np.uint16)
        return self._res_colors(sigmas)

    def _res_colors(self, sigmas):
        """
        Map whole-sigma residual magnitudes to RGBA tuples.

        0 -> green, 1 -> yellow, 2 -> orange, 3 or more -> red, all with
        alpha 0.75. Accepts any array-like of numbers; values are clipped
        to [0, 3] and truncated to integers. The input is not modified.
        """
        alpha = 0.75
        colormap = [
            (0, 0.7, 0, alpha),
            (0.9, 0.9, 0, alpha),
            (1.0, 0.7, 0, alpha),
            (1, 0, 0, alpha),
        ]
        # clip into a fresh array instead of editing the caller's in place
        indices = np.clip(np.asarray(sigmas), 0, 3).astype(int)
        return [colormap[i] for i in indices]

