#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Check a stream for saturation
#
# Copyright © 2016-2020 Deutsches Elektronen-Synchrotron DESY,
#                       a research centre of the Helmholtz Association.
# Copyright © 2016      The Research Foundation for SUNY
#
# Authors:
#   2016-2017 Thomas White <taw@physics.org>
#   2014-2016 Thomas Grant <tgrant@hwi.buffalo.edu>
#
# This file is part of CrystFEL.
#
# CrystFEL is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# CrystFEL is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with CrystFEL.  If not, see <http://www.gnu.org/licenses/>.

import sys
import argparse
import math as m
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

def c2(a):
    """Return the square of the cosine of angle `a` (as accepted by math.cos)."""
    cosine = m.cos(a)
    return cosine * cosine

def s2(a):
    """Return the square of the sine of angle `a` (as accepted by math.sin)."""
    sine = m.sin(a)
    return sine * sine

# Return 1/d for hkl in cell, in 1/Angstroms
def resolution(scell, shkl):
    """Compute 1/d for one reflection from the cell parameters.

    scell -- indexable of six values convertible to float: the three cell
             edge lengths in nm, followed by the three cell angles
             (alpha, beta, gamma) in degrees.
    shkl  -- indexable of three values convertible to int: h, k, l.

    Returns 1/d in 1/Angstroms.  Conversion errors (ValueError), missing
    entries (IndexError) and degenerate cells (ZeroDivisionError, or a
    ValueError from the square root) propagate to the caller.
    """

    # Edge lengths: nm -> Angstroms
    a, b, c = (float(scell[i])*10.0 for i in range(3))

    # Angles are supplied in degrees; the trigonometry below needs radians
    alpha, beta, gamma = (m.radians(float(scell[i])) for i in range(3, 6))

    h, k, el = (int(shkl[i]) for i in range(3))

    cos_al = m.cos(alpha)
    cos_be = m.cos(beta)
    cos_ga = m.cos(gamma)

    # Denominator of the general (triclinic) 1/d^2 expression
    denominator = (1.0 - c2(alpha) - c2(beta) - c2(gamma)
                   + 2.0*cos_al*cos_be*cos_ga)

    # h^2, k^2, l^2 terms
    diagonal = (h*h*s2(alpha)/(a*a)
                + k*k*s2(beta)/(b*b)
                + el*el*s2(gamma)/(c*c))

    # Mixed kl, lh and hk terms
    cross_kl = 2.0*k*el*(cos_be*cos_ga - cos_al)/(b*c)
    cross_lh = 2.0*el*h*(cos_ga*cos_al - cos_be)/(c*a)
    cross_hk = 2.0*h*k*(cos_al*cos_be - cos_ga)/(a*b)

    return m.sqrt((diagonal + cross_kl + cross_lh + cross_hk) / denominator)


# Command-line options.  "-i" may be given several times; each occurrence
# appends one more stream filename to args.i.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    action="append",
    required=True,
    help="stream filename",
)
parser.add_argument(
    "-l",
    action="store_true",
    help="log scale y-axis",
)
# Optional cutoffs; each stays None unless given on the command line
parser.add_argument(
    "--rmin",
    type=float,
    help="minimum resolution cutoff (1/d in Angstroms^-1)",
)
parser.add_argument(
    "--rmax",
    type=float,
    help="maximum resolution cutoff (1/d in Angstroms^-1)",
)
parser.add_argument(
    "--imin",
    type=float,
    help="minimum peak intensity cutoff",
)
parser.add_argument(
    "--imax",
    type=float,
    help="maximum peak intensity cutoff",
)
# The infinite default means "no limit" when --nmax is not supplied
parser.add_argument(
    "--nmax",
    default=np.inf,
    type=int,
    help="maximum number of peaks to read",
)
# NOTE(review): the value of -o is not referenced by any code visible in
# this file -- confirm whether an output file was meant to be written.
parser.add_argument(
    "-o",
    default="peakogram",
    help="output file prefix",
)
args = parser.parse_args()

# One [1/d, peak value] row per reflection line that parsed cleanly
data = []
# Number of reflection lines encountered, including any that failed to parse
n = 0
# Parser state: 0 = outside a reflection list, 1 = the next line is the
# list's column header, 2 = inside the list proper
in_list = 0
# Most recently seen cell parameters (as strings), in the order expected by
# resolution(): three edge lengths followed by three angles
cell = []

for filename in args.i:

    # "-" selects standard input instead of a named file
    if filename == "-":
        f = sys.stdin
    else:
        f = open(filename)

    # try/finally so the input is closed even if reading is interrupted
    try:
        for line in f:

            if "Cell parameters" in line:
                # Tokens 2-4 and 6-8 of the line hold the six cell values;
                # token 5 sits between the two groups and is skipped
                fields = line.split()
                cell = fields[2:5] + fields[6:9]
                continue
            if "Reflections measured after indexing" in line:
                in_list = 1
                continue
            if "End of reflections" in line:
                in_list = 0
                continue
            if in_list == 1:
                # Skip the column-header line that opens the list
                in_list = 2
                continue
            if in_list != 2:
                continue

            # From here, we are definitely handling a reflection line

            # Add reflection to list
            columns = line.split()
            n += 1
            try:
                # Convert the peak value here so that a malformed number is
                # reported for this line rather than failing later when the
                # whole list is turned into an array
                data.append([resolution(cell, columns[0:3]),
                             float(columns[5])])
            except (ValueError, IndexError, ArithmeticError):
                # Only parse/geometry failures are reported and skipped;
                # KeyboardInterrupt and the like are allowed to propagate
                print("Error with line: "+line.rstrip("\r\n"))
                print("Cell: "+str(cell))

            if n%1000==0:
                sys.stdout.write("\r%i predicted reflections found" % n)
                sys.stdout.flush()

            if n >= args.nmax:
                break
    finally:
        f.close()

    # Honour --nmax across all inputs, not just within the current file
    if n >= args.nmax:
        break



data = np.asarray(data,dtype=float)

sys.stdout.write("\r%i predicted reflections found" % n)
sys.stdout.flush()

print("")

# With nothing parsed the array is one-dimensional and the column indexing
# below would fail, so stop with a readable message instead
if data.size == 0:
    sys.exit("No reflections could be read from the input")

x = data[:,0]  # 1/d of each reflection
y = data[:,1]  # peak value of each reflection

# Limits default to the span of the positive data; any cutoff given on the
# command line takes precedence
xmin = args.rmin if args.rmin is not None else np.min(x[x>0])
xmax = args.rmax if args.rmax is not None else np.max(x)
ymin = args.imin if args.imin is not None else np.min(y[y>0])
ymax = args.imax if args.imax is not None else np.max(y)

keepers = (x>=xmin) & (x<=xmax) & (y>=ymin) & (y<=ymax)
if args.l:
    # log10 is undefined for non-positive values, which can get through
    # when --imin is zero or negative
    keepers &= (y>0)

x = x[keepers]
y = y[keepers]

if x.size == 0:
    sys.exit("No reflections fall within the requested resolution/intensity range")

if args.l:
    if ymin <= 0:
        # A non-positive lower limit has no logarithm; use the smallest
        # intensity actually kept
        ymin = np.min(y)
    y = np.log10(y)
    ymin = np.log10(ymin)
    ymax = np.log10(ymax)

bins=300
# histogram2d returns the edges of its first argument first, so the
# intensity (y) edges come back before the 1/d (x) edges
H,y_edges,x_edges = np.histogram2d(y,x,bins=bins)

fig = plt.figure()
ax1 = plt.subplot(111)
# Rows of H follow the intensity axis and columns follow 1/d, which is the
# layout pcolormesh expects for (x edges, y edges, values)
plot = ax1.pcolormesh(x_edges,y_edges,H, norm=LogNorm())
cbar = plt.colorbar(plot)
plt.xlim([xmin,xmax])
plt.ylim([ymin,ymax])
plt.xlabel("1/d (A^-1)")
if args.l:
    plt.ylabel("Log(Reflection max intensity)")
else:
    plt.ylabel("Reflection max intensity")
# args.i is a list of filenames; join it rather than showing the list repr
plt.title(", ".join(args.i))
plt.show()

