#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Shift the detector by a given amount
#
# Copyright © 2016-2020 Deutsches Elektronen-Synchrotron DESY,
#                       a research centre of the Helmholtz Association.
#
# Author:
#    2016-2017 Thomas White <taw@physics.org>
#

import sys
import argparse
import os
import re

# Patterns for the geometry-file lines this script rewrites.  Each one needs
# whitespace on both sides of '=' and at least one whitespace character
# (normally the newline) after the number.
_RES_RE = re.compile(r"^\s*res\s+=\s+([0-9\.]+)\s")
_PANEL_RES_RE = re.compile(r"^\s*(.*)\/res\s+=\s+([0-9\.]+)\s")
_CORNER_X_RE = re.compile(r"^\s*(.*)\/corner_x\s+=\s+([0-9\.\-]+)\s")
_CORNER_Y_RE = re.compile(r"^\s*(.*)\/corner_y\s+=\s+([0-9\.\-]+)\s")
_PANEL_COFFSET_RE = re.compile(r"^\s*(.*)\/coffset\s+=\s+([0-9\.\-]+)\s")
# The top-level coffset may be negative, just like a per-panel one.
_COFFSET_RE = re.compile(r"^\s*coffset\s+=\s+([0-9\.\-]+)\s")


def _build_parser():
    """Return the command-line parser for this script."""
    op = argparse.ArgumentParser()

    op.add_argument('--output', dest='ofn', metavar='output.geom',
                    help="Output geometry file")

    # --pixels/--mm and --beam/--detector share a destination.  argparse takes
    # the default from the first action registered for a destination, so the
    # order below is what makes "mm" and "detector" the defaults.
    op.add_argument('--pixels', action='store_true', dest='px',
                    help="Indicates that the x and y shifts are in pixel units.  "
                         "The z shift is always in millimeters")

    op.add_argument('--mm', action='store_false', dest='px',
                    help="Indicates that the shifts are in millimeters (this is the default)")

    op.add_argument('--beam', action='store_true', dest='beam',
                    help="Shift the X-ray beam rather than the detector (reverses x and y shifts)")

    op.add_argument('--detector', action='store_false', dest='beam',
                    help="Shift the detector rather than the X-ray beam.  This is the default.")

    op.add_argument('ifn', metavar='input.geom',
                    help="The name of the geometry file to shift")

    op.add_argument('xshift', type=float, metavar='XXX',
                    help="The shift in the x direction")

    op.add_argument('yshift', type=float, metavar='YYY',
                    help="The shift in the y direction")

    op.add_argument('zshift', type=float, metavar='ZZZ',
                    help="The shift in the z direction")

    return op


def _corner_line(panel, axis, corner, shift, px, panel_resolutions, default_res):
    """Return the rewritten '<panel>/corner_<axis>' line.

    In pixel mode 'shift' is added to 'corner' directly.  Otherwise it is
    taken as millimetres and converted to pixels with the panel's resolution
    (px/m), falling back to 'default_res' (with a notice on stdout) when the
    panel has no resolution of its own.
    """
    if panel in panel_resolutions:
        res = panel_resolutions[panel]
    else:
        res = default_res
        if not px:
            print('Using default resolution ({} px/m) for panel {}'.format(res, panel))

    if px:
        offset = shift
    else:
        offset = shift * res * 1e-3   # mm -> m -> px

    return '%s/corner_%s = %f\n' % (panel, axis, corner + offset)


def shift_geometry(lines, xshift, yshift, zshift, px=False):
    """Apply a detector shift to the lines of a geometry file.

    Args:
        lines: iterable of input lines, each including its line terminator
            (an open text file works).
        xshift, yshift: shift added to every panel's corner_x / corner_y, in
            pixels if 'px' is true and in millimetres otherwise.
        zshift: shift added to every coffset value, always in millimetres
            (coffset values are written as zshift*1e-3, i.e. in metres).
        px: whether xshift/yshift are given in pixels.

    Returns:
        A list of output lines.  Lines that are not recognised are passed
        through unchanged.  The list always ends with an empty line followed
        by one '<panel>/coffset' line for each panel (in order of first
        corner_x appearance) that had neither its own coffset nor a preceding
        top-level coffset.

    Raises:
        ValueError: if a matched number cannot be parsed by float()
            (for example '1.2.3').
    """
    zshift_m = zshift * 1e-3          # every coffset write uses metres
    panel_resolutions = {}
    panels = []                       # panels in order of first corner_x line
    panels_have_coffset = set()
    default_res = 0
    default_coffset = None
    out = []

    # The order of the checks matters: the top-level 'res'/'coffset' patterns
    # are tried before the per-panel ones.
    for fline in lines:

        match = _RES_RE.match(fline)
        if match:
            default_res = float(match.group(1))
            out.append(fline)
            continue

        match = _COFFSET_RE.match(fline)
        if match:
            default_coffset = float(match.group(1))
            # zshift is in mm, so it needs the same conversion as the
            # per-panel coffset lines.
            out.append("coffset = %f\n" % (default_coffset + zshift_m))
            continue

        match = _PANEL_RES_RE.match(fline)
        if match:
            panel = match.group(1)
            panel_res = float(match.group(2))
            # NOTE(review): a per-panel resolution also replaces the default
            # used for later panels without one.  Kept as in the original;
            # confirm that this is intended.
            default_res = panel_res
            panel_resolutions[panel] = panel_res
            out.append(fline)
            continue

        match = _CORNER_X_RE.match(fline)
        if match:
            panel = match.group(1)
            if panel not in panels:
                panels.append(panel)
                # If a top-level coffset has been seen, this panel is covered
                # by it (and it has already been adjusted).
                if default_coffset is not None:
                    panels_have_coffset.add(panel)
            out.append(_corner_line(panel, 'x', float(match.group(2)), xshift,
                                    px, panel_resolutions, default_res))
            continue

        match = _CORNER_Y_RE.match(fline)
        if match:
            out.append(_corner_line(match.group(1), 'y', float(match.group(2)),
                                    yshift, px, panel_resolutions, default_res))
            continue

        match = _PANEL_COFFSET_RE.match(fline)
        if match:
            panel = match.group(1)
            out.append('%s/coffset = %f\n'
                       % (panel, float(match.group(2)) + zshift_m))
            panels_have_coffset.add(panel)
            continue

        out.append(fline)

    # Panels without any coffset implicitly sit at 0, so give them an
    # explicit coffset equal to the shift.
    out.append("\n")
    for p in panels:
        if p not in panels_have_coffset:
            out.append("%s/coffset = %f\n" % (p, zshift_m))

    return out


def main(argv=None):
    """Parse the command line, shift the geometry and write the result.

    Args:
        argv: argument list to parse; None means sys.argv[1:].
    """
    args = _build_parser().parse_args(argv)

    if args.ofn:
        out = args.ofn
    else:
        out = os.path.splitext(args.ifn)[0] + '-shifted.geom'

    units = 'px' if args.px else 'mm'
    bm = 'beam' if args.beam else 'detector'

    print('Input filename {}\nOutput filename {}'.format(args.ifn, out))
    print('Shifting {} by {},{},{} {}'.format(bm, args.xshift, args.yshift, args.zshift, units))

    # Moving the beam is the same as moving the detector the other way.
    xshift, yshift = args.xshift, args.yshift
    if args.beam:
        xshift, yshift = -xshift, -yshift

    # Read and transform everything before the output is opened, so that a
    # parse error (or an output path equal to the input) cannot leave a
    # truncated file behind.
    with open(args.ifn, 'r') as f:
        shifted = shift_geometry(f, xshift, yshift, args.zshift, px=args.px)

    with open(out, 'w') as h:
        h.writelines(shifted)


if __name__ == '__main__':
    main()

