#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Plot post-refinement results
#
# Copyright © 2017-2020 Deutsches Elektronen-Synchrotron DESY,
#                       a research centre of the Helmholtz Association.
#
# Author:
#    2017-2018 Thomas White <taw@physics.org>
#

import sys
import matplotlib
import numpy as np
from matplotlib import pyplot as plt
from operator import truediv, itemgetter
from matplotlib.widgets import RadioButtons,SpanSelector,Button

# Shared selection state, read and modified by the widget callbacks below.

# Index of the crystal whose spectrum file is currently displayed
crystal = 0

# Resolution window (same units as the "Resolution / m^-1" selector axis);
# only points strictly inside (resmin, resmax) are plotted
resmin = 0
resmax = 100e9

# Label of the selected iteration, kept as a string because it is compared
# directly with the text read from the log files
inum = "0"

# True until update_graph() has completed once
first = True

def update_graph():
    """Redraw both plots for the current selection.

    Reads the module-level selection (inum, resmin, resmax, crystal) and the
    loaded data (pgraph, specgraph).  The pobs/pcalc scatter axes are cleared
    and rebuilt; the two spectrum lines are updated in place.  Only rows whose
    iteration label equals inum and whose resolution lies strictly between
    resmin and resmax are shown.
    """
    global first

    # Capture the view before cla() throws it away
    saved_xlim = pgraphE.get_xlim()
    saved_ylim = pgraphE.get_ylim()

    # pgraph rows are [pcalc, pobs, iteration, resolution, Ipart]
    chosen = [row for row in pgraph
              if row[2] == inum and row[3] < resmax and row[3] > resmin]

    # Update pobs/pcalc scatter plot
    pgraphE.cla()
    pgraphE.scatter([row[0] for row in chosen],
                    [row[1] for row in chosen], marker="+")
    pgraphE.grid(which='major', color='k', linestyle='--')
    if not first:
        # Keep whatever zoom the user had; the very first draw autoscales
        pgraphE.set_xlim(saved_xlim)
        pgraphE.set_ylim(saved_ylim)
    pgraphE.set_xlabel("pcalc")
    pgraphE.set_ylabel("pobs")
    pgraphE.set_title("Observed and calculated partialities")

    # specgraph rows are [kval, pcalc, pobs, iteration, resolution]
    spectrum = [row for row in specgraph
                if row[3] == inum and row[4] < resmax and row[4] > resmin]
    kvals = [row[0] for row in spectrum]

    # Update spectrum plot
    pobsE.set_data(kvals, [row[2] for row in spectrum])
    pcalcE.set_data(kvals, [row[1] for row in spectrum])
    specgraphE.set_title("Spectrum plot for crystal %i" % crystal)

    first = False
    plt.draw()


def iteration_click(label):
    """Radio-button callback: select the iteration named in *label* and redraw.

    *label* is the button text, "Iteration <tag>", optionally followed by a
    newline and an explanatory note.  Only <tag> is stored in inum.
    """
    global inum
    # Second space-separated word, cut at the first newline
    second_word = label.split(" ")[1]
    inum = second_word.split("\n")[0]
    update_graph()


def resolutionchange(rmin,rmax):
    """Span-selector callback: store the new resolution window and redraw.

    rmin and rmax are the two ends of the selected span; they become the
    exclusive bounds used by update_graph().
    """
    global resmin, resmax
    resmin, resmax = rmin, rmax
    update_graph()


def read_spectrum(filename):
    """Load a per-crystal spectrum log and return its rows sorted by k value.

    The first two lines of the file are a header and are skipped.  Every
    following non-blank line must contain at least five whitespace-separated
    columns, in this order: kval, res, pcalc, pobs, iteration label.

    Returns a list of [kval, pcalc, pobs, inum, res] lists in ascending kval
    order.  The numeric columns are floats; inum is left as a string.

    Raises IOError (the exception produced by open(), unchanged) if the file
    cannot be opened, after printing a short message.  Raises ValueError or
    IndexError if a data line is malformed.
    """
    try:
        fh = open(filename, 'r')
    except IOError:
        print("Failed to read "+filename)
        raise  # Re-raise as-is so errno and path are not lost

    specgraph = []

    with fh:  # Close the handle even if a line fails to parse
        next(fh, None)
        next(fh, None)  # Ignore header (two lines)

        for fline in fh:
            fields = fline.split()
            if not fields:
                continue  # Tolerate blank lines, e.g. at the end of the file
            kval = float(fields[0])
            res = float(fields[1])
            pcalc = float(fields[2])
            pobs = float(fields[3])
            iteration = fields[4]
            specgraph.append([kval, pcalc, pobs, iteration, res])

    return sorted(specgraph, key=itemgetter(0))


def next_click(w):
    """Button callback: move 20 crystals forward, load that spectrum, redraw.

    *w* is the argument supplied by the widget and is not used.  The crystal
    index is advanced before the file is read, so it stays advanced even if
    read_spectrum() raises.
    NOTE(review): the step of 20 presumably matches how often spectrum logs
    are written -- confirm against the program that produces them.
    """
    global specgraph, crystal
    crystal = crystal + 20
    spectrum_file = "%s/specgraph-crystal%i.dat" % (folder, crystal)
    specgraph = read_spectrum(spectrum_file)
    update_graph()


def prev_click(w):
    """Button callback: move 20 crystals back, load that spectrum, redraw.

    *w* is the argument supplied by the widget and is not used.  The crystal
    index is decreased before the file is read, so it stays decreased even if
    read_spectrum() raises.
    NOTE(review): the step of 20 presumably matches how often spectrum logs
    are written -- confirm against the program that produces them.
    """
    global specgraph, crystal
    crystal = crystal - 20
    spectrum_file = "%s/specgraph-crystal%i.dat" % (folder, crystal)
    specgraph = read_spectrum(spectrum_file)
    update_graph()


def set_sensible_axes(w):
    """Button callback: reset both plots to fixed axis ranges.

    The scatter plot is set to x in [-0.1, 1.1] and y in [-0.2, 2.0]; the
    spectrum plot is set to y in [-5.0, 5.0] with its x range left alone.
    *w* is the argument supplied by the widget and is not used.
    """
    pgraphE.set_xlim(-0.1, 1.1)
    pgraphE.set_ylim(-0.2, 2.0)
    specgraphE.set_ylim(-5.0, 5.0)
    plt.draw()


# Read the pgraph file
pgraph = []

# The log folder may be given as the only command-line argument
if len(sys.argv) == 2:
    folder = sys.argv[1]
else:
    folder = "pr-logs"

print("Using folder '%s'" % folder)

# Name the path once so the error message can report it
filename = "%s/pgraph.dat" % folder
try:
    fh = open(filename, 'r')
except IOError:
    print("Failed to read "+filename)
    raise  # Re-raise as-is so errno and path are not lost

with fh:  # Close the handle even if a line fails to parse
    next(fh, None)  # Ignore header

    for fline in fh:
        fields = fline.split()
        if not fields:
            continue  # Tolerate blank lines, e.g. at the end of the file
        # Columns 4..8 are: resolution, Ipart, pcalc, pobs, iteration label.
        # Each row is stored as [pcalc, pobs, iteration, resolution, Ipart].
        # The values are not bound to module-level names, so the 'inum'
        # selection state is not overwritten while loading.
        pgraph.append([float(fields[6]), float(fields[7]), fields[8],
                       float(fields[4]), float(fields[5])])


# Load the spectrum log for the initially selected crystal
spectrum_path = "%s/specgraph-crystal%i.dat" % (folder, crystal)
specgraph = read_spectrum(spectrum_path)

# The two plots occupy the figure up to x = 0.65; the control widgets are
# placed to the right of that
fig = plt.figure(figsize=(15, 6))
fig.subplots_adjust(left=0.05, right=0.65)

# Left panel: pobs/pcalc scatter, filled in by update_graph()
pgraphE = fig.add_subplot(1, 2, 1)

# Right panel: spectrum.  The two lines start with placeholder data and are
# given real data by update_graph() through set_data()
specgraphE = fig.add_subplot(1, 2, 2)
pobsE, = specgraphE.plot([0, 1], [0, 1], label="pobs")
pcalcE, = specgraphE.plot([0, 1], [0, 1], label="pcalc")
specgraphE.set_xlabel("khalf / m^-1")
specgraphE.set_ylabel("Partiality")
specgraphE.grid(which='major', color='k', linestyle='--')
specgraphE.legend()

# Iteration selector
ax = plt.axes([0.75, 0.50, 0.2, 0.40], facecolor="lightgoldenrodyellow")


def _iteration_order(tag):
    """Sort key: numeric iteration tags in numeric order, other tags after."""
    try:
        return (0, int(tag), tag)
    except ValueError:
        return (1, 0, tag)


# One button per distinct iteration label found in pgraph.  Sorting by
# numeric value keeps e.g. "Iteration 10" after "Iteration 9" instead of
# between "Iteration 1" and "Iteration 2".
iterations = ["Iteration "+tag
              for tag in sorted(set(str(x[2]) for x in pgraph),
                                key=_iteration_order)]
# An empty pgraph is reported the same way as a missing iteration 0
if not iterations or iterations[0] != "Iteration 0":
    print("Whoops, couldn't find iteration 0!")
    sys.exit(1)
iterations[0] = "Iteration 0\n(initial scaling only)"
if iterations[-1] != "Iteration F":
    print("Whoops, couldn't find final iteration!")
else:
    iterations[-1] = "Iteration F\n(after final merge)"
iteration = RadioButtons(ax, iterations)
iteration.on_clicked(iteration_click)
# Selecting the last entry runs iteration_click(), which draws the plots
iteration.set_active(len(iterations)-1)

# Resolution selector: a histogram of all pgraph resolutions, weighted by
# the Ipart column, on which a horizontal span can be dragged
resolution_ax = plt.axes([0.75, 0.25, 0.2, 0.15],
                         facecolor="lightgoldenrodyellow")
all_resolutions = [row[3] for row in pgraph]
all_weights = [row[4] for row in pgraph]
resolution_ax.hist(all_resolutions, bins=100, weights=all_weights)
resolution_ax.set_xlabel("Resolution / m^-1")
resolution_ax.set_ylabel("Frequency")
resolution_ax.set_title("Resolution selector")
# Keep a module-level reference so the widget stays alive
resolution = SpanSelector(resolution_ax, resolutionchange, 'horizontal',
                          useblit=True, interactive=True)

# Crystal switcher: two buttons side by side under the resolution selector
prev_ax = plt.axes([0.75, 0.10, 0.09, 0.05])
crystal_prev = Button(prev_ax, "Previous crystal")
crystal_prev.on_clicked(prev_click)

next_ax = plt.axes([0.86, 0.10, 0.09, 0.05])
crystal_next = Button(next_ax, "Next crystal")
crystal_next.on_clicked(next_click)

# Button that resets both plots to fixed axis ranges
sensible_ax = plt.axes([0.75, 0.01, 0.09, 0.05])
sensible_axes = Button(sensible_ax, "Set sensible axes")
sensible_axes.on_clicked(set_sensible_axes)

# Recompute the spectrum axes' data limits from the lines' current data and
# rescale both axes tightly to them before the window is shown
specgraphE.relim()
specgraphE.autoscale_view(tight=True, scalex=True, scaley=True)

# Hand control to the GUI event loop
plt.show()
