#!/usr/bin/env python
"""Publish comm node poses as a PoseArray and a MarkerArray."""

import threading

import rospy

from core_msgs.msg import CommNodeStatus
from pose_graph_msgs.msg import PoseGraph
from geometry_msgs.msg import PoseArray, Pose, PoseStamped
from visualization_msgs.msg import Marker, MarkerArray

from tf import transformations as tfm
import tf2_ros
from tf2_geometry_msgs import *

from comm_node_manager.pose_graph import parse_key


class CommNodePosePublisher(object):
    """Publish the poses of dropped comm nodes.

    Comm node keys, and each node's pose relative to its pose-graph node, are
    collected from ``CommNodeStatus`` messages.  The pose of the anchoring
    pose-graph node comes from the merged pose graph.  Whenever either input
    changes, the composed poses are transformed into ``~target_frame`` and
    published as a latched ``PoseArray`` and a latched ``MarkerArray``.

    ROS parameters:
        ~target_frame (str, default 'world'): frame of the published poses.
        ~source_frame (str, default 'world'): frame in which the pose graph
            node poses are expressed.
    """

    def __init__(self):
        self.pose_graph = PoseGraph()
        self.comm_node_keys = set()
        self.relative_poses = {}
        # rospy runs each subscription's callback in its own thread, and
        # status_cb is shared by two subscriptions.  This lock serialises all
        # access to the state above and each publication.  It is re-entrant
        # because the callbacks call publish_poses while holding it.
        self._lock = threading.RLock()

        self.robot_name = rospy.get_namespace().split('/')[1]
        self.target_frame = rospy.get_param('~target_frame', 'world')
        # Frame of the pose graph node poses (was hard-coded to 'world').
        self.source_frame = rospy.get_param('~source_frame', 'world')
        self.pose_pub = rospy.Publisher('comm_node_manager/poses', PoseArray,
                                        latch=True, queue_size=1)
        self.marker_pub = rospy.Publisher('comm_node_manager/markers', MarkerArray,
                                          latch=True, queue_size=1)

        # The TF buffer must exist before any subscriber can fire, because a
        # callback may reach publish_poses immediately after subscribing.
        self.tf_buf = tf2_ros.Buffer()
        self.tf_listener = tf2_ros.TransformListener(self.tf_buf)

        self.status_sub = rospy.Subscriber('comm_node_manager/status',
                                           CommNodeStatus, self.status_cb)
        self.pg_sub = rospy.Subscriber('lamp/pose_graph', PoseGraph, self.pose_graph_cb)
        self.agg_status_sub = rospy.Subscriber('comm_node_manager/status_agg',
                                               CommNodeStatus, self.status_cb)

    def status_cb(self, msg):
        """Record the comm nodes listed in ``msg.dropped``.

        Keys with prefix 'z' are skipped because they have no position.  The
        stored relative pose is refreshed for every other listed key.  Poses
        are republished only when a previously unknown key appears.
        """
        with self._lock:
            # Keys are only ever added, so the set changed iff its size did.
            prev_count = len(self.comm_node_keys)
            for n in msg.dropped:
                key = n.pose_graph_key
                prefix, _ = parse_key(key)
                if prefix == 'z':
                    continue
                # Store the relative pose before the key, so a key is never
                # visible without its relative pose.
                self.relative_poses[key] = n.relative_pose
                self.comm_node_keys.add(key)

            if len(self.comm_node_keys) != prev_count:
                rospy.loginfo("New comm node dropped")
                self.publish_poses()

    def pose_graph_cb(self, msg):
        """Store the latest merged pose graph and republish comm node poses."""
        with self._lock:
            self.pose_graph = msg
            self.publish_poses()

    def publish_poses(self):
        """Compose, transform and publish the poses of all known comm nodes.

        Does nothing until at least one comm node key is known.  Keys not
        present in the current pose graph are skipped.  Poses that cannot be
        transformed into ``target_frame`` are skipped (with a warning).
        """
        with self._lock:
            if not self.comm_node_keys:
                return

            rospy.loginfo("Publishing poses for %d comm nodes", len(self.comm_node_keys))
            poses = self._compose_poses()
            rospy.loginfo("Found %d keys in merged pose graph", len(poses))

            msg = self._transform_poses(poses)
            self.pose_pub.publish(msg)
            rospy.loginfo("Published %d valid poses", len(msg.poses))

            self.marker_pub.publish(self._make_markers(msg))
            rospy.loginfo("Published markers")

    def _compose_poses(self):
        """Return a PoseStamped in ``source_frame`` for each key found in the pose graph."""
        # Index the graph once rather than scanning it for every key.  Keep
        # the first node for a repeated key, as a first-match search would.
        nodes_by_key = {}
        for node in self.pose_graph.nodes:
            nodes_by_key.setdefault(node.key, node)

        poses = []
        for key in self.comm_node_keys:
            node = nodes_by_key.get(key)
            if node is None:
                continue
            m_world_key = self.pose2matrix(node.pose)
            m_key_node = self.pose2matrix(self.relative_poses[key])
            m_world_node = tfm.concatenate_matrices(m_world_key, m_key_node)

            pose = PoseStamped()
            # Time(0): let TF use the latest available transform.
            pose.header.stamp = rospy.Time(0)
            pose.header.frame_id = self.source_frame
            pose.pose = self.matrix2pose(m_world_node)
            poses.append(pose)
        return poses

    def _transform_poses(self, poses):
        """Return a PoseArray of ``poses`` transformed into ``target_frame``.

        Poses whose transform fails are logged and omitted.
        """
        msg = PoseArray()
        msg.header.stamp = rospy.Time.now()
        msg.header.frame_id = self.target_frame
        for pose in poses:
            try:
                pose_tfm = self.tf_buf.transform(pose, self.target_frame,
                                                 timeout=rospy.Duration(1))
            except Exception as e:
                # Best effort: skip this pose and keep publishing the rest.
                rospy.logwarn("TF failure: %s", e)
                continue
            msg.poses.append(pose_tfm.pose)
        return msg

    def _make_markers(self, msg):
        """Return a MarkerArray with one cylinder per pose in ``msg`` (for the costmap layer)."""
        # NOTE(review): earlier markers with ids >= len(msg.poses) are never
        # deleted, so stale markers may persist after a smaller publication.
        # Confirm the consumer tolerates this.
        marker_msg = MarkerArray()
        for i, pose in enumerate(msg.poses):
            m = Marker()
            m.id = i
            m.header = msg.header
            m.type = Marker.CYLINDER
            m.action = Marker.ADD
            m.pose = pose
            m.scale.x = 0.5
            m.scale.y = 0.5
            m.scale.z = 0.45
            m.color.r = 0.3
            m.color.g = 0.5
            m.color.b = 0.3
            m.color.a = 1.0
            marker_msg.markers.append(m)
        return marker_msg

    def pose2matrix(self, pose):
        """Return the 4x4 homogeneous transform matrix for a geometry_msgs Pose."""
        q = [pose.orientation.x, pose.orientation.y,
             pose.orientation.z, pose.orientation.w]
        t = [pose.position.x, pose.position.y, pose.position.z]
        m = tfm.quaternion_matrix(q)
        m[:3, 3] = t
        return m

    def matrix2pose(self, m):
        """Return a geometry_msgs Pose built from a 4x4 homogeneous transform matrix."""
        t = m[:3, 3]
        q = tfm.quaternion_from_matrix(m)
        pose = Pose()
        pose.position.x = t[0]
        pose.position.y = t[1]
        pose.position.z = t[2]
        pose.orientation.x = q[0]
        pose.orientation.y = q[1]
        pose.orientation.z = q[2]
        pose.orientation.w = q[3]
        return pose


def main():
    """Initialise the ``comm_node_pose`` node and serve callbacks until shutdown."""
    rospy.init_node(name='comm_node_pose')
    # Bind the publisher to a name so it plainly lives as long as the node.
    publisher = CommNodePosePublisher()
    rospy.spin()


if __name__ == '__main__':
    main()