# -*- coding: utf-8 -*-
r"""
  Author:   Peter N. Saeta
  Purpose:  perform nonlinear least-squares fits with appropriate
  			weighting, residuals, chisq, and probabilities, and
              generate plots with residuals and statistics
  Created:  19 October 2020

  Defines a class Fit:
    Base class that performs (non)linear least-squares fits and
    can make carefully formatted plots of the results.

    Minimal example:
    amp, center, width, background = 0.1, 14, 0.2, 1.2
    gus = Fit(
            x,
            y,
            yunc=noise,
            function_tex=r"$f(x)=y_0+A*\exp[-0.5((x-\mu)/\sigma)^2]$",
            tex=r"A;\mu;\sigma;y_0",
            function=lambda x, A, mu, sigma, back: back + A * np.exp(-0.5 * ((x - mu) / sigma) ** 2),
            p0=(amp, center, width, background)
        )
    gus.plot()

    The function to apply may either be specified by passing in
    appropriate keyword arguments or by subclassing this class and
    providing the requisite functions as methods of the subclass.

    By default, a Fit object immediately attempts to run the fit with
    the information passed to the constructor. If that fails, or if the
    information is insufficient, you can see what the data look like and
    the calculated curve with the passed values for p0 (or the values generated
    by the estimate_p0 function) by calling plot on the resulting object.
    If you want to start the process without attempting a fit, but just
    seeing the plot with the curve generated by values in p0, pass the
    key-value pair adjust=True to the constructor.

    These requisite functions/methods are
    - function(x:np.ndarray, *params, **kwargs):
    - tex_f: (optional) a string or callable that returns a string providing
      a representation of the function in LaTeX.
    - estimate_p0(**kwargs): (optional) a function that takes a dictionary
      of keyword arguments, including the instance of this Fit class with the
      name 'self' and returns a vector of initial values for the
      fit parameters. The function may avail itself of data fields in
      the Fit object, including x, y.

    Alternatively, you can subclass this Fit class and implement
    the following methods:

    - function(self, x:np.ndarray, *params), where self is an instance of a
        subclass of Fit
    - tex_f, a string or @property callable returning a string providing
        a representation of the function in LaTeX. If not defined,
        a crude approximation is generated by introspection of the function.
    - __str__(self), to produce a string representation of the fit
    - __init__(self, x:np.ndarray, y:np.ndarray, **kwargs), constructor
        that must call super().__init__(x, y, p0, tex, **kwargs), where
        p0 is an initial guess for the fitting parameters and
        tex is a list of strings with a LaTeX version of the names
        of the fit parameters

    If the fit is successful, the routine self.after_fit() is called; the
        default routine does nothing, but this would be an opportunity to
        adjust parameters, such as making sure that a parameter that
        enters quadratically has a positive value. In addition, the
        following fields are set:

    - self.valid is set to True
    - self.params is set to the list of optimized fitting parameters

    If yunc holds valid uncertainties, then
    - self.param_uncs is set to the list of parameter uncertainties
    - self.chisq is set
    - self.dof is set to the number of degrees of freedom of the fit
    - self.prob_greater is set to the probability of obtaining a greater
        value of chi-squared on repeating the experiment


    Optional keyword inputs:
    - yunc:np.ndarray, an array of uncertainties in the dependent variable
    - xunc:np.ndarray, an array of uncertainties in the independent variable
    - hold:str, a string of the form "01100", where "0" means the variable is
        optimized by the fitting routine and "1" means that it is held at the
        value given in p0
    - tex:[str], a list of strings providing LaTeX code to represent the
        fit parameters
    - mask:np.ndarray, an array to indicate a subset of data to exclude from
        the fitting procedure. If mask is an iterable of booleans, values that
        are True are used, while values that are False are removed.
        If mask is numeric, values that are nonzero are excluded while zero
        values are *not* masked out.

"""

import inspect
import re
import numpy as np
from scipy.optimize import curve_fit, OptimizeWarning
from scipy.odr import *
from scipy.stats import chi2
import matplotlib.pyplot as plt
from matplotlib import rcParams

# When true, Fit.run_fit multiplies the parameter uncertainties by
# sqrt(reduced chi-squared) whenever the reduced chi-squared exceeds 1.
CONSERVATIVE = True

# Palette of three-component tuples, each labelled with a hue name.
# NOTE(review): presumably (r, g, b) values in the range 0-1 used to color
# successive plot traces -- confirm against the plotting code.
traces = [
    (0, 0, 0.5),  # blue
    (0.5, 0, 0),  # red
    (0, 0.5, 0),  # green
    (0.5, 0.5, 0),  # yellow
    (0.4, 0.3, 0),  # brown ?
    (0.5, 0, 0.4),  # purple
    (0.5, 0.5, 0.5),  # gray
]


class Fit:
    r"""
    Base class that performs (non)linear least-squares fits and
    can make carefully formatted plots of the results.

    Minimal example:
    amp, center, width, background = 0.1, 14, 0.2, 1.2
    gus = Fit(
            x,
            y,
            yunc=noise,
            function_tex=r"$f(x)=y_0+A*\exp[-0.5((x-\mu)/\sigma)^2]$",
            tex=r"A;\mu;\sigma;y_0",
            function=lambda x, A, mu, sigma, back: back + A * np.exp(-0.5 * ((x - mu) / sigma) ** 2),
            p0=(amp, center, width, background)
        )
    gus.plot()

    The function to apply may either be specified by passing in
    appropriate keyword arguments or by subclassing this class and
    providing the requisite functions as methods of the subclass.

    By default, a Fit object immediately attempts to run the fit with
    the information passed to the constructor. If that fails, or if the
    information is insufficient, you can see what the data look like and
    the calculated curve with the passed values for p0 (or the values generated
    by the estimate_p0 function) by calling plot on the resulting object.
    If you want to start the process without attempting a fit, but just
    seeing the plot with the curve generated by values in p0, pass the
    key-value pair adjust=True to the constructor.

    These requisite functions/methods are
    - function(x:np.ndarray, *params, **kwargs):
    - tex_f: (optional) a string or callable that returns a string providing
      a representation of the function in LaTeX.
    - estimate_p0(**kwargs): (optional) a function that takes a dictionary
      of keyword arguments, including the instance of this Fit class with the
      name 'self' and returns a vector of initial values for the
      fit parameters. The function may avail itself of data fields in
      the Fit object, including x, y.


    Alternatively, this class may be subclassed to implement a particular
    (nonlinear) fitting function. This base class provides all of the
    smarts, leaving to the subclass just the essentials particular to
    the chosen function. The subclass should implement the following
    fields and methods:

    - function(self, x:np.ndarray, *params), where self is an instance of a
        subclass of Fit
    - tex_f, a string or @property callable returning a string providing
        a representation of the function in LaTeX. If not defined,
        a crude approximation is generated by introspection of the function.
    - __str__(self), to produce a string representation of the fit
    - __init__(self, x:np.ndarray, y:np.ndarray, **kwargs), constructor
        that must call super().__init__(x, y, p0, tex, **kwargs), where
        p0 is an initial guess for the fitting parameters and
        tex is a list of strings with a LaTeX version of the names
        of the fit parameters

    If the fit is successful, the routine self.after_fit() is called; the
        default routine does nothing, but this would be an opportunity to
        adjust parameters, such as making sure that a parameter that
        enters quadratically has a positive value. In addition, the
        following fields are set:

    - self.valid is set to True
    - self.params is set to the list of optimized fitting parameters

    If yunc holds valid uncertainties, then
    - self.param_uncs is set to the list of parameter uncertainties
    - self.chisq is set
    - self.dof is set to the number of degrees of freedom of the fit
    - self.prob_greater is set to the probability of obtaining a greater
        value of chi-squared on repeating the experiment


    Optional keyword inputs:
    - yunc:np.ndarray, an array of uncertainties in the dependent variable
    - xunc:np.ndarray, an array of uncertainties in the independent variable
    - hold:str, a string of the form "01100", where "0" means the variable is
        optimized by the fitting routine and "1" means that it is held at the
        value given in p0
    - tex:[str], a list of strings providing LaTeX code to represent the
        fit parameters
    - mask:np.ndarray, an array to indicate a subset of data to exclude from
        the fitting procedure. If mask is an iterable of booleans, values that
        are True are used, while values that are False are removed.
        If mask is numeric, values that are nonzero are excluded while zero
        values are *not* masked out.
    """

    def __init__(
        self,
        x: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ):
        """
        kwargs may include hold="1011", where each 0 corresponds
        to a parameter allowed to vary and 1 to a value held fixed.
        If the function passed has a .tex attribute, it is assumed
        to be a string that can be passed to tex to describe the
        function in the fit annotation. Subclasses may define a
        string variable tex_f for this purpose.

        """
        if not hasattr(self, 'function'):
            self.function = kwargs.get('function')
        if not hasattr(self, 'estimate_p0'):
            self.est_p0 = kwargs.get('estimate_p0')
        self.x = np.array(x)
        self.y = np.array(y)
        # self.function = function
        if not hasattr(self, "function_tex"):
            if hasattr(self.function, "function_tex"):
                self.set_function_tex(self.function.function_tex)
            else:
                self.set_function_tex(kwargs.get('function_tex'))
        # tex representation of variables
        tex = kwargs.get('tex')
        if not tex and hasattr(self.function, "tex"):
            tex = self.function.tex
        self.tex_labels = self.get_tex_labels(tex)
        if 'p0' in kwargs:
            self.p0 = kwargs['p0']
        else:
            try:
                if hasattr(self, 'estimate_p0'):
                    self.p0 = self.estimate_p0(**kwargs)
                elif hasattr(self, 'est_p0'):
                    self.p0 = self.est_p0(self=self, **kwargs)
            except Exception as eeps:
                print(f"I could not find a way to estimate the initial parameters p0")
                raise eeps
        self.error = None  # reason for failure of fit
        self.xunc = None  # standard error for independent variable
        self.yunc = None  # standard error for dependent variable

        self.xerrors = None  # will be x errors if we consider x uncertainties
        self.params = []
        self.param_uncs = []
        self.weighting = ""
        self.hold = kwargs.get("hold")
        self.set_mask(kwargs.get("mask"))
        self.dof = self.set_dof()

        self.chisq = None
        self.prob_greater = None
        self.axdata = None  # axis for plotting the data
        self.axyresiduals = None  # axis for y residuals
        self.axynorm = None  # axis for normalized y residuals
        self.axxresiduals = None  # axis for x residuals
        self.axxnorm = None  # axis for normalized x residuals
        self.top_axis = None  # use this one for the figure title
        self._plot_params = dict()

        # Handle the hold parameter. If it is nontrivial,
        # we need to use a modified function that adjusts the
        # p0 vector and the function definition to account only for
        # the parameters allowed to vary.

        try:
            if len(x) != len(y):
                raise f"x [{len(x)}] and y [{len(y)}] must have the same length"

            # set the uncertainties
            for unc in ('yunc', 'xunc'):
                try:
                    uncs = kwargs[unc]
                except:
                    setattr(self, unc, None)
                    continue
                if isinstance(uncs, (float, int)):
                    uncs = np.ones(len(x)) * uncs
                else:
                    try:
                        uncs = np.array(uncs)
                    except:
                        uncs = None
                setattr(self, unc, uncs)
                assert len(uncs) == len(x)

            if isinstance(self.yunc, np.ndarray):
                self.weighting = "xy" if isinstance(self.xunc, np.ndarray) else "y"
            if 'adjust' in kwargs:
                self.plot()
            else:
                self.run_fit(self.p0)
                self.after_fit()
        except Exception as eeps:
            self.error = str(eeps)

    def set_mask(self, mask):
        """
        If mask is None, create a mask of all True
        If mask is an iterable of numeric type, treat
        non-zero values as False and zero values as True.
        Internally, the values from the x and y arrays that
        get used are those corresponding to True values in the mask,
        which is the reverse of 0 vs 1 values passed in (1 to mask OUT).
        """
        if mask is None:
            mask = np.zeros(len(self.x))
        else:
            assert len(mask) == len(self.x)
            if isinstance(mask[0], (bool, np.bool_)):
                self.mask = mask
                return
        self.mask = np.where(mask, False, True)

    def set_dof(self):
        lenx = len(self.data('x'))
        dof = lenx - len(self.p0)
        if self.hold is not None:
            dof += self.hold.count('1')
        return dof

    def data(self, key, complement=False):
        v = getattr(self, key)
        if v is None:
            return v
        mask = np.logical_not(self.mask) if complement else self.mask
        return v[mask]

    def fhold(self, x, *params):
        """ """
        p = self.parameters(params)
        return self.function(x, *p)

    def run_fit(self, p0):
        """Run or re-run the fitting procedure, starting from
        the passed parameter values.
        """
        self.p0 = list(p0)  # update the initial parameter values
        if hasattr(self, 'pre_fit'):
            self.pre_fit()

        # remove any remnants of a previous fitting procedure
        fields = (
            "error;xerrors;yerrors;chisq;prob_greater;norm_yresiduals;reduced_chisq"
        )
        fields += ";params;param_uncs"
        for f in fields.split(';'):
            setattr(self, f, None)

        if self.hold:
            if not isinstance(self.hold, str):
                raise "hold must be a string with 1 for fixed, 0 for variable"
            if len(self.hold) != len(self.p0):
                raise "The number of digits in hold must match the number of parameters"

            # We need to copy the parameters held fixed and respect the ones
            # being varied. The variable field holds the indices of the parameters
            # that are allowed to vary.

            self.variable = [x for x in range(len(p0)) if self.hold[x] == "0"]
            # self.dof = len(self.x) - len(self.variable)

            # def fhold(t, *params):
            # """This function is used when some parameters are held,
            # to reconstitute a complete parameter vector to use with
            # the fitting function."""
            # p = self.parameters(params)
            # return self.function(t, *p)

            f = lambda x, *p: self.fhold(x, *p)
            p0 = [self.p0[n] for n in self.variable]
        else:
            # f = self.function
            f = lambda x, *p: self.function(x, *p)

        try:
            if isinstance(self.xunc, np.ndarray):
                self.run_odr(f)
            else:
                sigma = self.data('yunc')
                self.params, self.covars = curve_fit(
                    f,
                    self.data('x'),
                    self.data('y'),
                    p0=p0,
                    sigma=self.data('yunc'),
                    absolute_sigma=True,
                )

        except OptimizeWarning:
            raise "Failed to converge"

        if np.any(np.isnan(self.covars)) or np.any(np.isinf(self.covars)):
            self.error = "Failed to converge"
        else:
            # adjust for held parameters, if necessary
            self.params = self.parameters(self.params)
            self.yresiduals = self.y - self.function(self.x, *self.params)
            if 'y' in self.weighting:
                self.norm_yresiduals = self.yresiduals / self.yunc
                self.chisq = np.sum(self.data('norm_yresiduals') ** 2)

                # x errors?
                if 'x' in self.weighting:
                    self.norm_xresiduals = self.xresiduals / self.xunc
                    # What do we do about chisq?????
                    # The following is wrong; it needs to account for the slope
                    m = self._slope
                    unc = np.sqrt(np.power(self.yunc, 2) + np.power(m * self.xunc, 2))
                    self.norm_residuals = self.yresiduals / unc
                    self.chisq = np.sum(self.data('norm_residuals') ** 2)

                self.reduced_chisq = self.chisq / self.dof

                self.prob_greater = 1 - chi2.cdf(self.chisq, self.dof)
            else:
                # we have no error estimates; set pearsonR2
                self.pearsonR2 = np.corrcoef(self.data('x'), self.data('y'))[0, 1]

            # calculate the errors in the parameter estimations
            errs = np.sqrt(np.diag(self.covars))

            if CONSERVATIVE and self.reduced_chisq and self.reduced_chisq > 1:
                errs *= np.sqrt(self.reduced_chisq)
            if self.hold:
                self.param_uncs = np.zeros(self.params.size)
                for n, pnum in enumerate(self.variable):
                    self.param_uncs[pnum] = errs[n]
            else:
                self.param_uncs = errs

    def after_fit(self):
        """Override to perform any post-fit alteration of parameters"""
        pass

    def run_odr(self, f):
        """Run an orthogonal distance regression to handle errors along x"""
        data = Data(
            self.x,
            self.y,
            wd=1.0 / np.power(self.xunc, 2),
            we=1.0 / np.power(self.yunc, 2),
        )

        # an odr.Model assumes a call signature f(beta, x) -> y
        # where beta is the fitting parameters. The call signature
        # of f, however, is f(x, *params), so we need to remap.
        f_odr = lambda beta, x: f(x, *beta)
        model = Model(f_odr)
        odr = ODR(data, model, self.p0)
        odr.run()
        self.odr = odr
        output = odr.output
        if output.info == 1:
            # fitting was successful
            self.params = output.beta
            self.covars = output.cov_beta
            self.yresiduals = output.eps
            self.xresiduals = output.delta

    def parameters(self, par):
        "Handle the subset of parameters that are being allowed to vary."
        if self.hold:
            p = self.p0
            for n, pnum in enumerate(self.variable):
                p[pnum] = par[n]
            return np.array(p)
        return par

    def __call__(self, t: np.ndarray):
        "Evaluate the fitting function using current parameter values."
        if self.valid:
            return self.function(t, *self.params)
        return self.function(t, *self.p0)

    @property
    def _slope(self, dx=1e-6):
        "Return a numerical approximation  to the slope at each data pt"
        dy = self(self.x + dx) - self(self.x - dx)
        return dy / (2 * dx)

    def __str__(self):
        name = self.function.__name__
        args = inspect.getfullargspec(self.function)
        astart = 2 if args[0][0] == 'self' else 1
        argnames = args.args[astart:]
        lines = []
        if name != "function":
            lines.append(name)

        has_unc = isinstance(self.yunc, np.ndarray)
        if self.valid:
            for n in range(len(self.params)):
                lines.append(f"{argnames[n]:>16s} = {self.params[n]:^8.4g}")
                if has_unc:
                    lines[-1] += f" ± {self.param_uncs[n]:.2g}"
                    if self.param_uncs[n] > 0:
                        lines[
                            -1
                        ] += f" ({abs(self.param_uncs[n]/self.params[n])*100:.2g}%)"
            if has_unc:
                lines.append(f"N_dof = {self.dof}, chisq = {self.chisq:.3g}")
                lines[-1] += f" ({self.reduced_chisq:.3g})"
                lines[-1] += f", P> = {100*self.prob_greater:.2g}%"
        else:
            lines.append(self.error)
        return "\n".join(lines)

    def texval(self, x):
        """
        Render a string representation of a value in TeX format
        """
        if "e" not in x:
            return x
        m = re.search(r'([-+0-9.]*)\s?e\s?([+-])\s?([0-9]*)', x)
        if m:
            s = m.group(1) + r" \times 10^{"
            if m.group(2) == '-':
                s += '-'
            s += str(int(m.group(3))) + "}"
            return s
        print("Ack! for " + x)
        return x

    def tex_val_unc(self, x, dx):
        """
        Produce a LaTeX representation of the value and its uncertainty
        """
        if dx == 0.0:
            return (str(x), "", "")
        try:
            assert dx != 0.0 and x != 0.0
            xdigits = int(np.log10(abs(x)))
            dxdigits = int(np.log10(dx))
            digits = 2 + xdigits - dxdigits
            round_spot = -xdigits + digits
            xround = round(x, round_spot)
            dxround = round(dx, round_spot)

            ratio = dx / abs(x)  # this will fail if dx is 0
            # digits = 2 + int(np.round(np.log10(ratio), 2))
            fmt = "{:0." + str(round_spot) + "f}"
            main = self.texval(fmt.format(xround))
            # To get the right number of digits, we need to
            # figure out the place of the LSD of x
            unc = self.texval(f"{dxround:.2g}")
            if ratio > 1:
                rel = f"{ratio:.2f}"
            elif ratio > 0.001:
                rel = f"{100*ratio:.1f}" + "\\%"
            else:
                rel = f"{int(1e6*ratio)}" + r"\;\rm ppm"

            return (main, unc, rel)
        except Exception as eeps:
            # print(f"tex_val_unc error {eeps} for {x}, {dx}")
            return (f"{x:.3g}", f"{dx:.2g}", "")

    def get_tex_labels(self, vals):
        if vals is None:
            args = inspect.getfullargspec(self.function)
            argnames = args.args
        else:
            if isinstance(vals, str):
                sep = ';' if ';' in vals else ','
                an = vals.split(sep)
                if not an[0].startswith('$'):
                    argnames = [r"$%s$" % x for x in an]
                else:
                    argnames = an
            else:
                argnames = vals
        return argnames

    def set_function_tex(self, txt):
        if txt is None:
            self.function_tex = ""
        else:
            if not "$" in txt:
                txt = f"${txt}$"
            self.function_tex = txt

    def legend_tex(self, use_table=True):
        """Generate an annotation showing the fit function
        and the fitting parameters.
        """
        name = (
            self.function_tex
            if hasattr(self, 'function_tex') and self.function_tex
            else self.function.__name__
        )

        if self.tex_labels:
            argnames = self.tex_labels
        else:
            args = inspect.getfullargspec(self.function)
            argnames = args.args[1:]
        lines = [
            name,
        ]

        has_unc = isinstance(self.yunc, np.ndarray)
        if self.valid:
            if has_unc:
                stats = [
                    r"$\chi_\nu^2 = "
                    + self.texval(f"{self.reduced_chisq:.3g}")
                    + r"\;\; \mathrm{for}\;\; \nu = %d$" % self.dof,
                    r"$P_> = " + f"{100*self.prob_greater:.1f}" + r"\%$",
                ]
            else:
                stats = [r"$R^2 = " + self.texval(f"{self.pearsonR2:.3f}$")]
        if self.valid:
            if use_table:
                lines.insert(0, r"\begin{tabular}{lccc}\multicolumn{4}{c}{")
                lines.append(r"}\\[0.05in]")
                lines.append(
                    r"\textbf{Param} & \textbf{Value} & \textbf{Unc.} & \textbf{Rel. Unc.}"
                )
            for n in range(len(self.params)):
                if use_table:
                    lines.append("\\\\ " + argnames[n] + " & ")
                    lines.append(
                        " & ".join(
                            [
                                f"~~${x}$~~"
                                for x in self.tex_val_unc(
                                    self.params[n], self.param_uncs[n]
                                )
                            ]
                        )
                    )

                else:
                    lines.append(f"{argnames[n]:>16s} = $")
                    v, u, r = self.tex_val_unc(self.params[n], self.param_uncs[n])
                    if r == "":
                        vals = f"{v}$ (fixed)"
                    else:
                        vals = "{0} \\pm {1}\\; ({2})$".format(v, u, r)
                    lines[-1] += vals
            if has_unc:
                if use_table:
                    lines.append(r"\\[0.05in]\multicolumn{4}{c}{")
                    lines.append(stats[0] + r" \qquad " + stats[1] + "}")
                    lines.append(r"\end{tabular}")
                else:
                    lines.append(stats[0])
                    lines.append(stats[1])
            else:
                if use_table:
                    lines.append(r"\\[0.05in]\multicolumn{4}{c}{")
                    lines.append(f"$R^2 = {self.pearsonR2:.2f}$" + "}")
                    lines.append(r"\end{tabular}")

        else:
            lines.append(self.error)

        res = "\n".join(lines)
        # print(res)
        return res

    @property
    def valid(self):
        "Did the fit exit successfully?"
        return self.error is None and hasattr(self, 'covars')

    @property
    def error_scale(self):
        """
        By what factor errors need to grow to yield chisq/DoF = 1
        """
        if not self.valid or not isinstance(self.param_uncs, np.ndarray):
            return None
        return np.sqrt(self.chisq / self.dof)

    @property
    def alpha(self):
        """
        Return a vector of alpha (transparency) values to points,
        to illustrate which points were (not) used in making the fit.
        """
        return np.where(self.mask, 0.6, 0.3)

    def _set_plot_params(self, dic: dict, reset=False):
        """
        To handle independent calls to plotting routines, store
        necessary parameters in an internal dictionary,
        _plot_params. To start over, set reset=True.

        """
        d = self._plot_params
        if reset:
            d = dic
        else:
            d.update(dic)

        if 'xmin' not in d:
            d['xmin'] = self.x.min()
        if 'xmax' not in d:
            d['xmax'] = self.x.max()
        if 'ymin' not in d:
            d['ymin'] = self.y.min()
        if 'ymax' not in d:
            d['ymax'] = self.y.max()

        if 'c' in d:
            d['color'] = d['c']
        elif 'color' not in d:
            d['color'] = 'b'

        self._plot_params = d
        return d

    def plot(self, **kwargs):
        """
        Generate a plot showing the data, fit, residuals, and
        normalized residuals.

        Optional kwargs:

        - residuals (boolean)
        - normalized_residuals (boolean)
        - title (str)
        - xlabel (str)
        - ylabel (str)
        - figsize (width, height)
        - msize (marker size, in pts)
        - logx (boolean)
        - logy (boolean)
        - xmin (float)
        - xmax (float)
        - ymin (float)
        - ymax (float)
        - color (or c)
        - data_alpha (alpha value for data used in fit; masked data are
          shown fainter)
        - fit_alpha (alpha value for fit curve)
        - fit_color (color of the fit curve; defaults to color)
        - xfit (array of positions at which to compute the fitted curve)
        - yfit (array of precomputed fit values). If yfit is not supplied,
          the fitted curve will be evaluated at xfit values.
        - npoints (int) defaults to 200 and specifies the number of points
          to use in computing the fitted curve
        - legend (xfrac, yfrac) specifies the anchor position of the
          annotation, expressed as a fraction of the width and height of the
          plot area. Defaults to (0.1, 0.1). Alternatively, can be a suitable
          combination of north, east, south, and west for position outside
          the plot area.

        Axis limits are applied to the data axes only when xmin/xmax or
        ymin/ymax were passed explicitly.
        """
        # Record the limits the caller actually passed before kwargs is
        # adjusted: _set_plot_params(kwargs, True) stores kwargs itself and
        # fills in xmin/xmax/ymin/ymax from the data.
        xlim = (kwargs.get('xmin'), kwargs.get('xmax'))
        ylim = (kwargs.get('ymin'), kwargs.get('ymax'))
        d = self._set_plot_params(kwargs, True)

        yresiduals, normalized_yresiduals = False, False
        axes_order = []
        if self.valid:
            # See if we should plot residuals
            normalized_yresiduals = d.get('normalized_residuals', "y" in self.weighting)
            if normalized_yresiduals:
                axes_order.append('axynorm')
            yresiduals = d.get('residuals', True)
            if yresiduals:
                axes_order.append('axyresiduals')
        axes_order.append('axdata')
        self.axes_order = axes_order

        self.make_gspec(**kwargs)
        self.top_axis.set_title(kwargs.get('title', ''))

        # Lay down the fit curve first, then plot the data
        self.plot_fit(self.axdata)
        self.plot_data(self.axdata)

        if yresiduals:
            self.plot_yresiduals(self.axyresiduals)

            # are there x residuals, too?
            if 'x' in self.weighting:
                self.plot_xresiduals(self.axxresiduals)

        if normalized_yresiduals:
            self.plot_ynormresiduals(self.axynorm)

        fig = self.fig
        fig.align_ylabels(fig.axes)
        try:
            gs = self.gs
            gs.tight_layout(fig)
            gs.update(hspace=0.06, wspace=0.04)
        except Exception:
            # layout tuning is best effort; the figure is usable without it
            pass

        # Handle any explicitly designated axis limits
        if any(bound is not None for bound in xlim):
            self.axdata.set_xlim(*xlim)
        if any(bound is not None for bound in ylim):
            self.axdata.set_ylim(*ylim)

    def plot_fit(self, ax, **kwargs):
        """
        Draw the fitted curve on the given axes.

        The plot parameters are obtained from self._set_plot_params(kwargs).
        The entries read here are:
          - xfit, yfit: optional precomputed abscissas / ordinates of the
            curve; whatever is missing is generated here
          - npoints: number of points in a generated xfit (default 200)
          - xmin, xmax: range spanned by a generated xfit
          - logx, logy: switch the corresponding axis to a log scale
          - fit_color (falls back to color) and fit_alpha (default 1)
        """
        d = self._set_plot_params(kwargs)

        xfit, yfit = d.get('xfit'), d.get('yfit')

        npoints = d.get('npoints', 200)
        if d.get('logx'):
            ax.set_xscale('log', nonpositive='clip')
            if not isinstance(xfit, np.ndarray):
                # space the generated points evenly in log10(x)
                xfit = np.power(
                    10,
                    np.linspace(np.log10(d['xmin']), np.log10(d['xmax']), npoints),
                )
        elif not isinstance(xfit, np.ndarray):
            xfit = np.linspace(d['xmin'], d['xmax'], npoints)

        # An explicit test is needed here: `not yfit` raises ValueError
        # when yfit is an array holding more than one element.
        if yfit is None or np.size(yfit) == 0:
            yfit = self.__call__(xfit)

        if d.get('logy'):
            ax.set_yscale('log')  # , nonposy='clip')

        # establish plot color and transparency
        alpha = d.get('fit_alpha', 1)
        color = d.get('fit_color', d['color'])

        ax.plot(xfit, yfit, color=color, alpha=alpha)

    def plot_data(self, ax, **kwargs):
        """
        Add plot of data to the given axes

        The plot parameters are obtained from self._set_plot_params(kwargs).
        The entries read here are:
          - msize: marker diameter in points (default 8 when there are
            fewer than 24 points in self.x, otherwise 6)
          - data_alpha: opacity of the data markers (default 0.6)
          - color: marker color
          - label: legend label, attached to the first pass only
          - xlabel, ylabel: axis labels, applied to ax when present

        When self.weighting is truthy the points are drawn with error
        bars; otherwise they are drawn as a plain scatter plot.
        """
        d = self._set_plot_params(kwargs)
        msize = d.get('msize', 8 if len(self.x) < 24 else 6)
        alpha = d.get('data_alpha', 0.6)

        # To handle showing masked data lighter, we need
        # to make two passes
        for complement in (False, True):
            try:
                label = None if complement else d.get('label')
                pass_alpha = alpha * (0.3 if complement else 1)
                if self.weighting:
                    ax.errorbar(
                        self.data('x', complement),
                        self.data('y', complement),
                        yerr=self.data('yunc', complement),
                        xerr=self.data('xunc', complement),
                        fmt='o',
                        c=d['color'],
                        alpha=pass_alpha,
                        markersize=msize,
                        label=label,
                    )
                else:
                    # scatter takes the marker *area* in points**2, whereas
                    # errorbar's markersize is a diameter in points; square
                    # msize so both branches draw markers of the same size
                    ax.scatter(
                        self.data('x', complement),
                        self.data('y', complement),
                        alpha=pass_alpha,
                        s=msize**2,
                        c=d['color'],
                        label=label,
                    )
            except Exception as eeps:
                # best effort: a pass that cannot be drawn is reported and
                # skipped rather than aborting the whole plot
                print(eeps)

        # Label the axes we were asked to draw on
        if 'xlabel' in d:
            ax.set_xlabel(d['xlabel'])
        if 'ylabel' in d:
            ax.set_ylabel(d['ylabel'])

    def plot_yresiduals(self, ax, **kwargs):
        """
        Plot the y residuals, with their y uncertainties, against x.

        A horizontal zero line is drawn from xmin to xmax first. When the
        weighting is 'xy', color-coded markers (from norm_res_colors('xy'))
        are laid down underneath smaller error-bar markers.
        """
        params = self._set_plot_params(kwargs)
        if params.get('logx'):
            ax.set_xscale('log', nonpositive='clip')

        # zero reference line
        ax.plot([params['xmin'], params['xmax']], [0, 0], 'k-', alpha=0.5)

        weighted_in_both = self.weighting == 'xy'
        if weighted_in_both:
            point_colors = np.array(self.norm_res_colors('xy'))
            ax.scatter(
                self.data('x'),
                self.data('yresiduals'),
                c=point_colors[self.mask],
                marker='o',
                s=25,
            )

        # errorbar cannot color point by point, so the colored markers
        # (if any) come from the scatter call above; shrink the error-bar
        # marker when it sits on top of one
        marker_size = 1 if weighted_in_both else 3
        ax.errorbar(
            self.data('x'),
            self.data('yresiduals'),
            yerr=self.data('yunc'),
            ls='None',
            marker='o',
            ms=marker_size,
        )
        ax.set_ylabel('Res.')

    def plot_xresiduals(self, ax, **kwargs):
        """
        Plot the x residuals, with their x uncertainties, against y.

        Nothing is drawn unless the weighting includes 'x'. Otherwise a
        vertical zero line is drawn from ymin to ymax, followed by colored
        markers and then the error bars.
        """
        params = self._set_plot_params(kwargs)
        weighting = self.weighting
        if 'x' not in weighting:
            return

        if params.get('logy'):
            ax.set_yscale('log', nonpositive='clip')  # nonposy?

        # zero reference line
        ax.plot([0, 0], [params['ymin'], params['ymax']], 'k-', alpha=0.5)

        # NOTE(review): the marker colors are read from the 'colors' entry
        # of the plot parameters -- confirm that _set_plot_params provides it
        ax.scatter(
            self.data('xresiduals'),
            self.data('y'),
            c=params['colors'][self.mask],
            s=25,
            marker='o',
        )

        # smaller error-bar markers when colored markers lie underneath
        marker_size = 1 if weighting == 'xy' else 3
        ax.errorbar(
            self.data('xresiduals'),
            self.data('y'),
            xerr=self.data('xunc'),
            marker='o',
            ls='None',
            ms=marker_size,
        )
        ax.set_xlabel('Res.')

    def plot_xnormresiduals(self, ax, **kwargs):
        """
        Plot the normalized x residuals against y on the given axes.

        A vertical zero line is drawn from ymin to ymax, then one marker
        per point, colored by norm_res_colors('x').
        """
        params = self._set_plot_params(kwargs)
        y_span = [params['ymin'], params['ymax']]

        # zero reference line
        ax.plot([0, 0], y_span, 'k-', alpha=0.5)

        marker_style = dict(s=25, marker='o', c=self.norm_res_colors('x'))
        ax.scatter(self.norm_xresiduals, self.y, **marker_style)
        ax.set_xlabel('N.R.')

    def plot_ynormresiduals(self, ax, **kwargs):
        """
        Plot the normalized y residuals against x on the given axes.

        A horizontal zero line is drawn from xmin to xmax, then one marker
        per unmasked point, colored by norm_res_colors('y'). When the
        weighting includes 'x', the normalized x residuals are also drawn,
        on self.axxnorm.
        """
        params = self._set_plot_params(kwargs)
        if params.get('logx'):
            ax.set_xscale('log', nonpositive='clip')

        # zero reference line
        ax.plot([params['xmin'], params['xmax']], [0, 0], 'k-', alpha=0.5)

        # one color per point, restricted below to the unmasked points
        point_colors = np.array(self.norm_res_colors('y'))
        ax.scatter(
            self.data('x'),
            self.data('norm_yresiduals'),
            s=25,
            marker='o',
            c=point_colors[self.mask],
        )
        ax.set_ylabel('N.R.')

        if 'x' not in self.weighting:
            return
        self.plot_xnormresiduals(self.axxnorm)

    def make_gspec(self, **kwargs):
        """
        Figure out how to layout the axes of the plot.
        Relevant issues:
          - do we have residuals and normalized residuals in y
          - in x?
          - is the legend outside the plot area?

        Keyword arguments read here:
          - legend: either an (x, y) tuple, which annotates the data axes
            at that axes-fraction position, or a string whose first word
            is 'north' or 'south' (legend in its own row above / below the
            plots) or 'west' (legend in its own column on the left); any
            other string places the legend column on the right, as 'east'
            does. Default 'south'.
          - figsize: (width, height) in inches; defaults to
            rcParams["figure.figsize"]
          - table: passed on to self.legend_tex (default True)

        Reads self.axes_order, self.valid, self.params and self.weighting.
        Sets self.fig, self.axdata and self.top_axis, plus whichever of
        self.axyresiduals, self.axynorm, self.axxresiduals, self.axxnorm
        and self.axlegend the layout calls for, and self._legend.

        Side effects: several global matplotlib rcParams are overwritten,
        and for a north/south legend 'axlegend' is inserted into
        self.axes_order.
        """
        # Will we be putting the legend inside the plot area or outside?
        legend = kwargs.get('legend', 'south')
        figsize = kwargs.get('figsize', rcParams["figure.figsize"])
        use_table = kwargs.get('table', True)

        rcParams["font.family"] = "serif"
        rcParams["font.size"] = 12.0
        rcParams["text.usetex"] = True
        rcParams["xtick.top"] = True
        rcParams["xtick.direction"] = "in"
        rcParams["ytick.right"] = True
        rcParams["ytick.direction"] = "in"
        rcParams["savefig.transparent"] = True

        # Read the font size only after it has been set above, so that the
        # legend height estimate is the same on the first call as on every
        # later one
        fontsize = rcParams['font.size']

        axes_order = self.axes_order
        width, height = figsize
        # make a preliminary dictionary
        gspec = dict(
            width_ratios=[1],
            height_ratios=[1],
            hspace=0.025,
            wspace=0.025,
            left=0.125,
            right=0.975,
        )

        # If we don't have a valid fit, make a simple plot
        if not self.valid:
            self.fig, self.axdata = plt.subplots(gridspec_kw=gspec)
            self.top_axis = self.axdata
            return

        # we will have a legend describing fitting parameters
        msg = self.legend_tex(use_table)
        num_lines = 4.5 + len(self.params)  # len(msg.split('\n')) + 2
        msg_height = num_lines * fontsize * 1.5 / 72  # should be improved

        if "tabular" in msg:
            msg = msg.replace("\n", "")

        self._legend = dict(
            text=msg,
            num_lines=num_lines,
            msg_height=msg_height,
            position=legend,
            x=0.1,
            y=0.1,
            ha='center',
            va='middle',  # correct?
            primary=None,
            secondary=None,
        )
        leg = self._legend

        # The opts dictionary will get passed to a call to annotate
        # in the axes used for the legend
        opts = dict(xy=(0.5, 0.5), xycoords='axes fraction', va='center', ha='center')
        legend_padding = " \n \n"

        if isinstance(legend, tuple):
            # a legend position of the form (x,y) puts the legend at that fraction
            # of the plot area. This part probably needs work
            opts['xy'] = legend
            if len(axes_order) == 3:
                gspec['height_ratios'] = [1, 1, 4]
            elif len(axes_order) == 2:
                gspec['height_ratios'] = [1.5, 4]

            # squeeze=False always hands back a 2-D array of axes; with the
            # default, a single row comes back as a bare Axes that cannot
            # be indexed
            self.fig, grid = plt.subplots(
                nrows=len(axes_order),
                sharex=True,
                figsize=figsize,
                squeeze=False,
                gridspec_kw=gspec,
            )
            self.axes = grid[:, 0]
            for name in ['axdata', 'axyresiduals', 'axynorm']:
                if name in axes_order:
                    setattr(self, name, self.axes[axes_order.index(name)])

            self.top_axis = self.axes[0]
            # self.axdata.legend
            self.axdata.annotate(msg, **opts)

        if isinstance(legend, str):
            # the legend will be placed outside the plot area
            # using frameon=False
            fields = legend.lower().split(' ')
            leg['primary'] = fields[0]
            leg['secondary'] = None if len(fields) == 1 else fields[1]

            if leg['primary'] in ('north', 'south'):
                # extend the plot height to accommodate the
                # extra axes, which will hold the annotation
                self.fig = plt.figure(
                    figsize=(width, height + msg_height + 0.25),
                    # constrained_layout=True,
                )
                ncols = 1
                legend_on_top = leg['primary'] == 'north'

                if legend_on_top:
                    axes_order.insert(0, 'axlegend')
                    msg += legend_padding
                else:
                    axes_order.append('axlegend')
                    msg = legend_padding + msg

                if len(axes_order) == 3:
                    # no normalized y residuals
                    plot_ratios = [1, 4]
                    legend_ratio = msg_height / 5.25 * height
                    if "x" in self.weighting:
                        ncols = 2
                        gspec['width_ratios'] = [4, 1]
                elif len(axes_order) == 4:
                    # normalized y residuals, too
                    plot_ratios = [1.5, 1.5, 4]
                    legend_ratio = msg_height * 7 / height
                    if "x" in self.weighting:
                        ncols = 3
                        gspec['width_ratios'] = [4, 1, 1]
                else:
                    # no residuals
                    plot_ratios = [4]
                    legend_ratio = msg_height / 4.25 * height

                # the legend's ratio has to sit at the same end of the list
                # as the legend row sits in axes_order
                if legend_on_top:
                    gspec['height_ratios'] = [legend_ratio] + plot_ratios
                else:
                    gspec['height_ratios'] = plot_ratios + [legend_ratio]

                gs = self.fig.add_gridspec(ncols=ncols, nrows=len(axes_order), **gspec)
                self.gs = gs
                gs.update(hspace=0.02)
                ndata = axes_order.index('axdata')
                self.axdata = self.fig.add_subplot(gs[ndata, 0])

                for n, which in enumerate(axes_order):
                    if which in ('axdata', 'axlegend'):
                        continue
                    a = self.fig.add_subplot(gs[n, 0], sharex=self.axdata)
                    # suppress x ticks on residuals
                    plt.setp(a.get_xticklabels(), visible=False)
                    plt.subplots_adjust(wspace=0, hspace=0)
                    setattr(self, which, a)

                # The title belongs on the uppermost plotting axes, which
                # is the data axes itself when there are no residual rows
                # and is never the legend row
                top_name = next(w for w in axes_order if w != 'axlegend')
                self.top_axis = getattr(self, top_name)

                # if there are x errors, add those plots
                if ncols > 1:
                    for n, name in zip([1, 2], ['axxresiduals', 'axxnorm']):
                        # valid column indices run from 0 to ncols - 1
                        if n >= ncols:
                            break
                        a = self.fig.add_subplot(gs[ndata, n], sharey=self.axdata)
                        plt.setp(a.get_yticklabels(), visible=False)
                        plt.subplots_adjust(wspace=0, hspace=0)
                        setattr(self, name, a)

                # add the legend
                self.axlegend = self.fig.add_subplot(
                    gs[axes_order.index('axlegend'), 0]
                )
                self.axlegend.axis('off')
                self.axlegend.annotate(msg, **opts)
            else:
                self.fig = plt.figure(figsize=(figsize[0] + 2.25, figsize[1]))
                legwidth = 6
                # Decide once which side the legend goes on, so that the
                # column widths and the column index always agree
                legend_on_west = leg['primary'] == 'west'
                gspec['width_ratios'] = (
                    [legwidth, width] if legend_on_west else [width, legwidth]
                )
                if len(axes_order) == 2:
                    gspec['height_ratios'] = [1, 4]
                elif len(axes_order) == 3:
                    gspec['height_ratios'] = [1, 1, 4]
                gs = self.fig.add_gridspec(ncols=2, nrows=len(axes_order), **gspec)

                col = 0 if legend_on_west else 1
                # set the legend
                self.axlegend = self.fig.add_subplot(gs[:, col])
                self.axlegend.annotate(msg, **opts)
                self.axlegend.set_axis_off()

                # To manage the axis sharing properly, we need to generate
                # the data plot first
                dindex = axes_order.index('axdata')
                self.axdata = self.fig.add_subplot(gs[dindex, 1 - col])

                # add axes to the object
                for n, which in enumerate(axes_order):
                    if which in ('axdata', 'axlegend'):
                        continue
                    a = self.fig.add_subplot(gs[n, 1 - col], sharex=self.axdata)
                    setattr(self, which, a)
                    plt.setp(a.get_xticklabels(), visible=False)

                # uppermost plotting axes (the data axes when it is alone)
                top_name = next(w for w in axes_order if w != 'axlegend')
                self.top_axis = getattr(self, top_name)

    def fit_results(self, fig, xbounds, ybounds, **kwargs):
        """
        Write the fit summary returned by self.legend_tex() onto a figure.

        Parameters:
          - fig: the figure to annotate
          - xbounds, ybounds: (min, max) pairs of data coordinates; used
            only when the legend position is an (x, y) pair
          - legend (keyword): either a string or an (x, y) pair of
            fractions (default (0.1, 0.1))

        A string legend must start with 'south' or 'north'. The figure is
        made taller by one third of an inch per line of the message, the
        existing axes are laid out in the rest of the figure, and the
        message is written along the bottom (south) or top (north) edge.
        An optional second word, 'west' or 'east', aligns the message to
        the left or right edge; otherwise it is centered.

        An (x, y) pair places the message in the last axes of the figure at
        that fraction of the way across xbounds and ybounds. Fractions
        below 0.3 / above 0.7 anchor the text by its left / right edge
        (bottom / top edge in y); anything in between is centered.

        Raises ValueError for a string legend that does not start with
        'north' or 'south'.
        """
        legend = kwargs.get('legend', (0.1, 0.1))
        msg = self.legend_tex()

        if isinstance(legend, str):
            # the legend string should be a cardinal direction
            fields = legend.lower().split(" ")
            primary = fields[0]
            secondary = fields[1] if len(fields) > 1 else None

            # Fail with a clear message before touching the figure, rather
            # than reaching fig.text with the position variables unset
            if primary not in ('north', 'south'):
                raise ValueError(
                    f"legend must start with 'north' or 'south', not {primary!r}"
                )

            num_lines = len(msg.split('\n'))
            legend_height = num_lines / 3  # (inches?)
            width, height = fig.get_size_inches()

            new_height = height + legend_height
            fig.set_figheight(new_height)
            # fraction of the enlarged figure reserved for the message
            strip = legend_height / new_height
            if primary == 'south':
                rect = (0, strip, 1, 1)
                ypos, valign = 0, 'bottom'
            else:
                rect = (0, 0, 1, 1 - strip)
                ypos, valign = 1, 'top'
            fig.tight_layout(rect=rect)

            if secondary == 'west':
                xpos, halign = 0, 'left'
            elif secondary == 'east':
                xpos, halign = 1, 'right'
            else:
                xpos, halign = 0.5, 'center'

            fig.text(
                xpos,
                ypos,
                msg,
                horizontalalignment=halign,
                verticalalignment=valign,
            )
        else:
            xmin, xmax = xbounds
            ymin, ymax = ybounds
            # interpolate between the bounds to get data coordinates
            xpos = (1 - legend[0]) * xmin + legend[0] * xmax
            ypos = (1 - legend[1]) * ymin + legend[1] * ymax
            halign = (
                'left'
                if legend[0] < 0.3
                else ('right' if legend[0] > 0.7 else 'center')
            )
            valign = (
                'bottom'
                if legend[1] < 0.3
                else ('top' if legend[1] > 0.7 else 'center')
            )

            ax = fig.axes[-1]
            ax.text(
                xpos,
                ypos,
                msg,
                horizontalalignment=halign,
                verticalalignment=valign,
            )

    def norm_res_colors(self, kind='y'):
        """
        return a color map for the desired residuals

        kind selects which normalized residuals are used: 'y' for
        self.norm_yresiduals, 'x' for self.norm_xresiduals, anything else
        for self.norm_residuals.

        Each residual is reduced to an integer level, the whole part of
        its absolute value capped at 3, and the levels are handed to
        self._res_colors, whose result is returned.
        """
        if kind == 'y':
            norm_res = self.norm_yresiduals
        elif kind == 'x':
            norm_res = self.norm_xresiduals
        else:
            norm_res = self.norm_residuals

        # Cap at 3 *before* converting to an unsigned 16-bit integer.
        # Converting first lets values of 65536 or more wrap around to a
        # small level, and the conversion of inf or nan is not well
        # defined. np.fmin ignores nan, so nan also lands on the top level.
        magnitude = np.abs(np.asarray(norm_res, dtype=float))
        sigmas = np.fmin(magnitude, 3).astype(np.uint16)
        return self._res_colors(sigmas)

    def _res_colors(self, sigmas):
        sigmas[sigmas > 3] = 3
        alpha = 0.75
        colormap = [
            (0, 0.7, 0, alpha),
            (0.9, 0.9, 0, alpha),
            (1.0, 0.7, 0, alpha),
            (1, 0, 0, alpha),
        ]
        return [colormap[x] for x in sigmas]


if __name__ == '__main__':
    # Demonstration: fit a cosine to synthetic data, first using every
    # point and then again with some of the points masked out.
    import matplotlib.pyplot as plt
    from numpy.random import default_rng

    def mycosine(x, A, phi0):
        "Fit a cosine y = A * cos(x + phi0), with angles in degrees"
        return A * np.cos(np.radians(x + phi0))

    # LaTeX for the parameter names and for the function itself
    mycosine.tex = r"A;\phi_0"
    mycosine.function_tex = r"$A \cos(\theta + \phi_0)$"

    # --- synthetic data ---
    rng = default_rng()
    N = 20
    I0 = 75        # peak current, in mA
    phi0 = -37     # phase offset
    theta = np.linspace(0, 60, N)    # the "x" values

    # each point gets its own noise amplitude, which doubles as its
    # uncertainty; the noise itself is normally distributed
    noise_amp = rng.uniform(0.05, 0.5, size=N)
    noise = rng.normal(size=N) * noise_amp
    Isc = mycosine(theta, I0, phi0) + noise    # the actual fake data

    # settings shared by both fits
    fit_options = dict(
        yunc=noise_amp,        # the uncertainties of the dependent variable
        function=mycosine,     # the fitting function
        p0=(45, -30),          # initial guesses for (A, phi0)
    )
    axis_labels = dict(xlabel=r"$\theta$", ylabel=r"$I_{\rm sc}$")

    # --- fit to all of the points ---
    f = Fit(theta, Isc, **fit_options)
    f.plot(**axis_labels)
    plt.show()

    # --- now check masking ---
    # displace two of the points, then mask out every odd index below 19
    Isc[3] -= 5
    Isc[11] += 7
    mask = np.ones(len(Isc), dtype=bool)
    mask[1:19:2] = False

    ffixed = Fit(theta, Isc, **fit_options, mask=mask)
    ffixed.plot(**axis_labels)
    plt.show()
