#!/usr/bin/env python

#
# Authors:
#     Stef Walter <stefw@redhat.com>
#
# Copyright (C) 2014 Red Hat
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

#
# Some parser code from GLib
#
# Copyright (C) 2008-2011 Red Hat, Inc.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General
# Public License along with this library; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place, Suite 330,
# Boston, MA 02111-1307, USA.
#
# Portions by: David Zeuthen <davidz@redhat.com>
#

#
# DBus interfaces are defined here:
#
# http://dbus.freedesktop.org/doc/dbus-specification.html#introspection-format
#
# The introspection data format has become the standard way to represent a
# DBus interface. For many examples see /usr/share/dbus-1/interfaces/ on a
# typical linux machine.
#
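# A word about annotations. These are extra flags or values that can be
# assigned to anything. So far, the codegen supports these annotations:
#
# org.freedesktop.DBus.GLib.CSymbol
#  - An annotation specified in the specification that tells us what C symbol
#    to generate for a given interface or method. By default the codegen will
#    build up a symbol name from the DBus name.
#
# org.freedesktop.sssd.RawHandler
#  - When set to 'true' on a method, or on its interface, the method is given
#    a raw sbus_msg_handler_fn handler instead of a typed one. Methods with
#    arguments that are not basic types must carry this annotation.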
#
from __future__ import print_function

import optparse
import os
import re
import sys
import xml.parsers.expat

if sys.version_info[0] > 2:
    import io as StringIO
else:
    import StringIO

# -----------------------------------------------------------------------------
# Objects

class DBusXmlException(Exception):
    """Problem found while reading or validating introspection XML.

    ``file`` and ``line`` start out unset and may be assigned on the
    instance after it was raised, once the location is known.
    """

    # Falsy values mean "location not known"
    file = None
    line = 0

    def __str__(self):
        """Render the message the way a compiler would: file:line: text."""
        text = Exception.__str__(self)
        if not self.file:
            return text
        if not self.line:
            return "%s: %s" % (self.file, text)
        return "%s:%d: %s" % (self.file, self.line, text)

class Base(object):
    """Shared behaviour of all parsed elements: a name and its annotations."""

    def __init__(self, name):
        """Remember *name*; raises DBusXmlException when it is empty."""
        if name:
            self.name = name
            self.annotations = {}
        else:
            raise DBusXmlException('No name on element')

    def validate(self):
        """Hook for subclasses to check consistency; nothing to check here."""
        return None

    def c_name(self):
        """Return the C symbol: the CSymbol annotation if set, else the name."""
        symbol_key = "org.freedesktop.DBus.GLib.CSymbol"
        if symbol_key in self.annotations:
            return self.annotations[symbol_key]
        return self.name

# The basic types that we support marshalling right now. These
# are the ones we can pass as basic arguments to libdbus directly.
# If the dbus and sssd types are identical we pass things directly,
# otherwise some copying is necessary.
#
# Each entry maps a DBus type code to a tuple of
#   (libdbus type constant, C type on the dbus side, C type on the sssd side)
BASIC_TYPES = {
 'y': ( "DBUS_TYPE_BYTE", "uint8_t", "uint8_t" ),
 'b': ( "DBUS_TYPE_BOOLEAN", "dbus_bool_t", "bool" ),
 'n': ( "DBUS_TYPE_INT16", "int16_t", "int16_t" ),
 'q': ( "DBUS_TYPE_UINT16", "uint16_t", "uint16_t" ),
 'i': ( "DBUS_TYPE_INT32", "int32_t", "int32_t" ),
 'u': ( "DBUS_TYPE_UINT32", "uint32_t", "uint32_t" ),
 'x': ( "DBUS_TYPE_INT64", "int64_t", "int64_t" ),
 't': ( "DBUS_TYPE_UINT64", "uint64_t", "uint64_t" ),
 'd': ( "DBUS_TYPE_DOUBLE", "double", "double" ),
 's': ( "DBUS_TYPE_STRING", "const char *", "const char *" ),
 'o': ( "DBUS_TYPE_OBJECT_PATH", "const char *", "const char *" ),
}

class Typed(Base):
    """A named element that carries a DBus type signature.

    Attributes set here: ``type`` (the full signature), ``is_array``,
    ``is_dictionary``, ``is_basic`` and, for element types found in
    BASIC_TYPES, ``dbus_constant``, ``dbus_type`` and ``sssd_type``
    (None otherwise).
    """

    def __init__(self, name, type):
        Base.__init__(self, name)
        self.type = type

        # One leading 'a' marks an array; what follows is the element type
        self.is_array = type[0] == 'a'
        element = type[1:] if self.is_array else type
        self.is_dictionary = "{" in element

        known = BASIC_TYPES.get(element)
        if known is None:
            self.dbus_constant = self.dbus_type = self.sssd_type = None
            self.is_basic = False
            return

        self.dbus_constant, self.dbus_type, self.sssd_type = known
        # If types are not identical, we can't do array (yet)
        self.is_basic = (not self.is_array) or (self.dbus_type == self.sssd_type)

class Arg(Typed):
    """One typed argument; ``method`` is the object the argument belongs to."""

    def __init__(self, method, name, type):
        self.method = method
        Typed.__init__(self, name, type)

class Method(Base):
    """A DBus method of an interface, with separate in and out argument lists."""

    def __init__(self, iface, name):
        Base.__init__(self, name)
        self.iface = iface
        self.in_args = []
        self.out_args = []

    def validate(self):
        """Raise DBusXmlException if non-basic arguments lack a raw handler."""
        if self.only_basic_args() or self.use_raw_handler():
            return
        raise DBusXmlException("Method has complex arguments and requires " +
                               "the 'org.freedesktop.sssd.RawHandler' annotation")

    def fq_c_name(self):
        """Return the C symbol qualified with the interface's C symbol."""
        owner = self.iface.c_name()
        own = self.c_name()
        return "%s_%s" % (owner, own)

    def use_raw_handler(self):
        """True if RawHandler is 'true' on the method, or else on its interface."""
        key = 'org.freedesktop.sssd.RawHandler'
        inherited = self.iface.annotations.get(key)
        return self.annotations.get(key, inherited) == 'true'

    def in_signature(self):
        """Return the concatenated DBus signature of the input arguments."""
        return "".join(arg.type for arg in self.in_args)

    def only_basic_args(self):
        """True if every in and out argument is a basic type."""
        return all(arg.is_basic for arg in self.in_args + self.out_args)

class Signal(Base):
    """A DBus signal of an interface, with its list of arguments."""

    def __init__(self, iface, name):
        Base.__init__(self, name)
        self.args = []
        self.iface = iface

    def fq_c_name(self):
        """Return the C symbol qualified with the interface's C symbol."""
        owner = self.iface.c_name()
        own = self.c_name()
        return "%s_%s" % (owner, own)

class Property(Typed):
    """A DBus property of an interface.

    ``readable`` and ``writable`` are derived from the *access* value given
    to the constructor.
    """

    def __init__(self, iface, name, type, access):
        """Create the property.

        Raises DBusXmlException if *access* is not one of 'read', 'write'
        or 'readwrite'.
        """
        Typed.__init__(self, name, type)
        self.iface = iface
        self.readable = False
        self.writable = False
        if access == 'readwrite':
            self.readable = True
            self.writable = True
        elif access == 'read':
            self.readable = True
        elif access == 'write':
            self.writable = True
        else:
            # Format the offending value itself: the instance has no
            # 'access' attribute, so self.access would raise AttributeError
            # and hide the real problem.
            raise DBusXmlException('Invalid access type %s' % (access,))

    def fq_c_name(self):
        """Return the C symbol qualified with the interface's C symbol."""
        return "%s_%s" % (self.iface.c_name(), self.c_name())

    def get_invoker_name(self):
        """Return the type signature with '{' and '}' spelled as DO and DE.

        The result is used as part of a C identifier, where braces are
        not allowed.
        """
        mangled = self.type
        mangled = mangled.replace("{", "DO")
        mangled = mangled.replace("}", "DE")
        return mangled

    def get_invoker_signature(self, name):
        """Return the C declaration of a getter function pointer called *name*.

        The last parameter(s) depend on the property type: a hash table for
        dictionaries, pointer plus length for arrays, a plain pointer
        otherwise.
        """
        sig = "void (*%s)(struct sbus_request *, void *data, " % (name)
        if self.is_dictionary:
            sig += "hash_table_t **"
        elif self.is_array:
            sig += "%s**, int *" % (self.sssd_type)
        else:
            sig += "%s*" % (self.sssd_type)
        sig += ")"
        return sig

    def getter_name(self):
        """Return the name of the getter slot: 'get_' plus the C symbol."""
        return "get_%s" % self.c_name()

    def getter_invoker_name(self):
        """Return the name of the sbus invoker function for this type."""
        return "sbus_invoke_get_%s" % self.get_invoker_name()

    def getter_signature(self):
        """Return the C declaration of the getter slot for this property."""
        return self.get_invoker_signature(self.getter_name())

class Interface(Base):
    """A DBus interface: the container for methods, signals and properties."""

    def __init__(self, name):
        Base.__init__(self, name)
        self.properties = []
        self.signals = []
        self.methods = []

    def c_name(self):
        """Return the CSymbol annotation, or the name with dots as underscores."""
        derived = self.name.replace(".", "_")
        return self.annotations.get("org.freedesktop.DBus.GLib.CSymbol", derived)

# -----------------------------------------------------------------------------
# Code Generation

def out(format, *args, **kwargs):
    """Write ``format % args`` to sys.stdout.

    Keyword arguments:
      new_line -- when true (the default) a newline is written afterwards

    Raises TypeError if any other keyword argument is passed. The check
    happens before anything is written.
    """
    # NOTE: Would like to use the following syntax for this function
    # but need to wait until python3 until it is supported:
    #  def out(format, *args, new_line=True)
    new_line = kwargs.pop("new_line", True)
    if kwargs:
        # The local used to be called 'str', which shadowed the builtin
        # and made building this very message fail.
        raise TypeError("unknown keyword argument(s): %s" %
                        ", ".join(sorted(kwargs)))

    text = format % args
    sys.stdout.write(text)
    if new_line:
        sys.stdout.write("\n")

def method_arg_types(args, with_names=False):
    """Return the C parameter list fragment for *args*.

    Every argument contributes ", <sssd type>"; array arguments add
    "[], int" for the length. With *with_names* the parameters are named
    arg_<symbol> and len_<symbol>. The result is empty for no arguments.
    """
    pieces = []
    for arg in args:
        decl = ", " + arg.sssd_type
        if with_names:
            # A pointer type ends in '*' and takes the name without a space
            separator = "" if decl.endswith('*') else " "
            decl += separator + "arg_" + arg.c_name()
        if arg.is_array:
            decl += "[], int"
            if with_names:
                decl += " len_" + arg.c_name()
        pieces.append(decl)
    return "".join(pieces)

def method_function_pointer(meth, name, with_names=False):
    """Return the C declaration of the handler slot *name* for *meth*.

    Raw-handler methods get an sbus_msg_handler_fn; all others get a typed
    function pointer built from the method's input arguments.
    """
    if meth.use_raw_handler():
        return "sbus_msg_handler_fn " + name

    req_name = "req" if with_names else ""
    data_name = "data" if with_names else ""
    extra_params = method_arg_types(meth.in_args, with_names)
    return "int (*%s)(struct sbus_request *%s, void *%s%s)" % \
        (name, req_name, data_name, extra_params)

def property_handlers(prop):
    """Return the C declaration(s) of the vtable slots for *prop*.

    Only the getter is declared.
    """
    getter_declaration = prop.getter_signature()
    return getter_declaration

def forward_method_invoker(signature, args):
    """Emit the forward declaration of the invoker for *signature*.

    *args* is accepted for symmetry with source_method_invoker() but is
    not used.
    """
    comment = "/* invokes a handler with a '%s' DBus signature */"
    prototype = "static int invoke_%s_method(struct sbus_request *dbus_req, void *function_ptr);"
    out("")
    for template in (comment, prototype):
        out(template, signature)

def source_method_invoker(signature, args):
    """Emit the C definition of the invoker for handlers with *signature*.

    The generated function declares one local per argument (plus a length
    for arrays), parses the request into them and calls the handler.
    Arguments are numbered by position rather than named.
    """
    out("")
    out("/* invokes a handler with a '%s' DBus signature */", signature)
    out("static int invoke_%s_method(struct sbus_request *dbus_req, void *function_ptr)", signature)
    out("{")

    # Locals that receive the parsed arguments
    for index, arg in enumerate(args):
        if not arg.is_array:
            out("    %s arg_%d;", arg.dbus_type, index)
            continue
        out("    %s *arg_%d;", arg.dbus_type, index)
        out("    int len_%d;", index)
    out("    int (*handler)(struct sbus_request *, void *%s) = function_ptr;", method_arg_types(args))
    out("")

    # Parse the message into the locals
    out("    if (!sbus_request_parse_or_finish(dbus_req,")
    for index, arg in enumerate(args):
        if not arg.is_array:
            out("                               %s, &arg_%d,", arg.dbus_constant, index)
            continue
        out("                               DBUS_TYPE_ARRAY, %s, &arg_%d, &len_%d,",
            arg.dbus_constant, index, index)
    out("                               DBUS_TYPE_INVALID)) {")
    out("         return EOK; /* request handled */")
    out("    }")
    out("")

    # Call the handler, one argument per line
    out("    return (handler)(dbus_req, dbus_req->intf->handler_data", new_line=False)
    for index, arg in enumerate(args):
        out(",\n                     arg_%d", index, new_line=False)
        if arg.is_array:
            out(",\n                     len_%d", index, new_line=False)
    out(");")
    out("}")

def source_prop_types(prop, type_prefix=False):
    """Emit C local declarations for a property value and its dbus copy.

    With *type_prefix* the variable names are prefixed with the property's
    type signature and an underscore.
    """
    if type_prefix:
        prefix = "%s_" % prop.type
    else:
        prefix = ""

    if not prop.is_array:
        out("    %s %sprop_val;", prop.sssd_type, prefix)
        out("    %s %sout_val;", prop.dbus_type, prefix)
        return

    out("    %s *%sprop_val;", prop.sssd_type, prefix)
    out("    int %sprop_len;", prefix)
    out("    %s *%sout_val;", prop.dbus_type, prefix)

def source_prop_handler(prop, type_prefix=False):
    """Emit the C declaration of a local getter function pointer for *prop*.

    The variable is called 'handler', prefixed with the property's type
    signature and an underscore when *type_prefix* is set.
    """
    prefix = "%s_" % prop.type if type_prefix else ""
    # getter_signature() takes no name argument; get_invoker_signature()
    # is the method that builds a declaration for an arbitrary name.
    out("    %s", prop.get_invoker_signature("%shandler" % prefix), new_line=False)
    out(";")

def forward_method_invokers(ifaces):
    """Emit one forward declaration per distinct input signature.

    Raw-handler methods and methods without input arguments need no
    invoker. Returns a dict mapping each signature to the first method
    seen with it.
    """
    seen = {}
    for iface in ifaces:
        for meth in iface.methods:
            if meth.use_raw_handler():
                continue
            if not meth.in_args:
                continue
            signature = meth.in_signature()
            if signature not in seen:
                forward_method_invoker(signature, meth.in_args)
                seen[signature] = meth
    return seen

def source_method_invokers(invokers):
    """Emit the invoker definitions for a signature-to-method dict."""
    for signature in invokers:
        representative = invokers[signature]
        source_method_invoker(signature, representative.in_args)

def source_finisher(meth):
    """Emit the C definition of the <method>_finish() function.

    The generated function forwards the out arguments to
    sbus_request_return_and_finish(). A non-array argument whose dbus and
    sssd types differ is first copied into a cast_<name> local.
    """
    out("")
    out("int %s_finish(struct sbus_request *req%s)",
        meth.fq_c_name(), method_arg_types(meth.out_args, with_names=True))
    out("{")

    # Locals for arguments that need converting to the dbus type
    for arg in meth.out_args:
        if arg.dbus_type == arg.sssd_type:
            continue
        symbol = arg.c_name()
        out("    %s cast_%s = arg_%s;", arg.dbus_type, symbol, symbol)

    out("   return sbus_request_return_and_finish(req,")
    for arg in meth.out_args:
        out("                                         ", new_line=False)
        if arg.is_array:
            out("DBUS_TYPE_ARRAY, %s, &arg_%s, len_%s,",
                arg.dbus_constant, arg.c_name(), arg.c_name())
            continue
        if arg.dbus_type == arg.sssd_type:
            out("%s, &arg_%s,", arg.dbus_constant, arg.c_name())
        else:
            out("%s, &cast_%s,", arg.dbus_constant, arg.c_name())
    out("                                         DBUS_TYPE_INVALID);")
    out("}")

def header_reply(meth):
    """Emit C member declarations for the out arguments of *meth*.

    Array arguments get a pointer member plus a <name>__len member.
    """
    for arg in meth.out_args:
        if arg.is_array:
            # Each declaration ends in ';', same as the scalar case below
            out("    %s *%s;", arg.dbus_type, arg.c_name())
            out("    int %s__len;", arg.c_name())
        else:
            out("    %s %s;", arg.dbus_type, arg.c_name())

def source_args(parent, args, suffix):
    """Emit the sbus_arg_meta array describing *args*.

    *parent* is the method or signal owning the arguments. The array is
    named after the parent's qualified C symbol plus *suffix* and ends
    with a NULL entry.
    """
    owner_name = parent.iface.name
    out("")
    out("/* arguments for %s.%s */", owner_name, parent.name)
    out("const struct sbus_arg_meta %s%s[] = {", parent.fq_c_name(), suffix)
    for arg in args:
        out('    { "%s", "%s" },', arg.name, arg.type)
    for closing in ("    { NULL, }", "};"):
        out(closing)

def source_methods(iface, methods):
    """Emit everything the source file needs for the methods of *iface*.

    First the argument tables and finish functions of each method, then
    the sbus_method_meta array that ties them together.
    """
    # Per-method pieces
    for meth in methods:
        if meth.in_args:
            source_args(meth, meth.in_args, "__in")
        if meth.out_args:
            source_args(meth, meth.out_args, "__out")
        if meth.use_raw_handler():
            continue
        source_finisher(meth)

    # The table itself, terminated by a NULL entry
    out("")
    out("/* methods for %s */", iface.name)
    out("const struct sbus_method_meta %s__methods[] = {", iface.c_name())
    for meth in methods:
        out("    {")
        out('        "%s", /* name */', meth.name)

        if not meth.in_args:
            out("        NULL, /* no in_args */")
        else:
            out("        %s__in,", meth.fq_c_name())

        if not meth.out_args:
            out("        NULL, /* no out_args */")
        else:
            out("        %s__out,", meth.fq_c_name())

        out("        offsetof(struct %s, %s),", iface.c_name(), meth.c_name())

        if not meth.use_raw_handler() and meth.in_args:
            out("        invoke_%s_method,", meth.in_signature())
        else:
            out("        NULL, /* no invoker */")
        out("    },")
    out("    { NULL, }")
    out("};")

def source_signals(iface, signals):
    """Emit the argument tables and the sbus_signal_meta array for *iface*.

    The argument tables are generated from ``iface.signals``, the entries
    of the array from *signals*.
    """
    for sig in iface.signals:
        if not sig.args:
            continue
        source_args(sig, sig.args, "__args")

    out("")
    out("/* signals for %s */", iface.name)
    out("const struct sbus_signal_meta %s__signals[] = {", iface.c_name())
    for sig in signals:
        out("    {")
        out('        "%s", /* name */', sig.name)
        if not sig.args:
            out("        NULL, /* no args */")
        else:
            out("        %s__args", sig.fq_c_name())
        out("    },")
    for closing in ("    { NULL, }", "};"):
        out(closing)

def source_properties(iface, properties):
    """Emit the sbus_property_meta array for *properties* of *iface*.

    Each entry holds name, type, access flags and the getter slot with
    its invoker. The setter slot is always emitted as "not writable".
    """
    out("")
    out("/* property info for %s */", iface.name)

    out("const struct sbus_property_meta %s__properties[] = {", iface.c_name())
    for prop in properties:
        out("    {")
        out('        "%s", /* name */', prop.name)
        out('        "%s", /* type */', prop.type)

        # Access flags
        if prop.readable:
            if prop.writable:
                out("        SBUS_PROPERTY_READABLE | SBUS_PROPERTY_WRITABLE,")
            else:
                out("        SBUS_PROPERTY_READABLE,")
        elif prop.writable:
            out("        SBUS_PROPERTY_WRITABLE,")
        else:
            assert False, "should not be reached"

        # Getter
        if not prop.readable:
            out("        0, /* not readable */")
            out("        NULL, /* no invoker */")
        else:
            out("        offsetof(struct %s, %s),", iface.c_name(), prop.getter_name())
            out("        %s,", prop.getter_invoker_name())

        # Setter
        out("        0, /* not writable */")
        out("        NULL, /* no invoker */")
        out("    },")
    out("    { NULL, }")
    out("};")

def header_interface(iface):
    """Emit the extern declaration of the interface's metadata struct."""
    comment = "/* interface info for %s */"
    declaration = "extern const struct sbus_interface_meta %s_meta;"
    out("")
    out(comment, iface.name)
    out(declaration, iface.c_name())

def source_interface(iface):
    """Emit the definition of the interface's sbus_interface_meta struct.

    Member tables the interface does not have are emitted as NULL.
    """
    out("")
    out("/* interface info for %s */", iface.name)
    out("const struct sbus_interface_meta %s_meta = {", iface.c_name())
    out('    "%s", /* name */', iface.name)

    # (members, line referencing their table, line used when there are none)
    tables = (
        (iface.methods, "    %s__methods,", "    NULL, /* no methods */"),
        (iface.signals, "    %s__signals,", "    NULL, /* no signals */"),
        (iface.properties, "    %s__properties,", "    NULL, /* no properties */"),
    )
    for members, reference, placeholder in tables:
        if members:
            out(reference, iface.c_name())
        else:
            out(placeholder)

    out("    sbus_invoke_get_all, /* GetAll invoker */")
    out("};")

def generate_source(ifaces, filename, include_header=None):
    """Write the C source for *ifaces* to stdout.

    *filename* is the XML file the interfaces came from; only its base
    name appears in the output. *include_header*, if given, is added as
    an extra #include (base name only).
    """
    out("/* The following definitions are auto-generated from %s */",
        os.path.basename(filename))

    fixed_lines = (
        "",
        "#include <stddef.h>",
        "",
        '#include "dbus/dbus-protocol.h"',
        '#include "util/util_errors.h"',
        '#include "sbus/sssd_dbus.h"',
        '#include "sbus/sssd_dbus_meta.h"',
        '#include "sbus/sssd_dbus_invokers.h"',
    )
    for line in fixed_lines:
        out(line)
    if include_header:
        out('#include "%s"', os.path.basename(include_header))

    # Invokers are declared up front and defined at the very end
    meth_invokers = forward_method_invokers(ifaces)

    for iface in ifaces:
        sections = (
            (iface.methods, source_methods),
            (iface.signals, source_signals),
            (iface.properties, source_properties),
        )
        for members, emit in sections:
            if members:
                emit(iface, members)

        # The sbus_interface structure
        source_interface(iface)

    source_method_invokers(meth_invokers)

def header_finisher(iface, meth):
    """Emit the prototype of the <method>_finish() function.

    Nothing is emitted for raw-handler methods. *iface* is not used.
    """
    if not meth.use_raw_handler():
        out("")
        out("/* finish function for %s */", meth.name)
        params = method_arg_types(meth.out_args, with_names=True)
        out("int %s_finish(struct sbus_request *req%s);", meth.fq_c_name(), params)

def header_vtable(iface, methods):
    """Emit the C struct holding the handler slots of *iface*.

    One slot per method and per property, taken from ``iface`` itself;
    the *methods* parameter is not used.
    """
    out("")
    out("/* vtable for %s */", iface.name)
    out("struct %s {", iface.c_name())
    out("    struct sbus_vtable vtable; /* derive from sbus_vtable */")

    member = "    %s;"
    for meth in iface.methods:
        out(member, method_function_pointer(meth, meth.c_name(), with_names=True))
    for prop in iface.properties:
        out(member, property_handlers(prop))

    out("};")

def header_constants(iface):
    """Emit #define constants for the interface and each of its members.

    The macro name is the upper-cased C symbol, the value the DBus name.
    """
    define = '#define %s "%s"'
    out("")
    out("/* constants for %s */", iface.name)
    out(define, iface.c_name().upper(), iface.name)
    for members in (iface.methods, iface.signals, iface.properties):
        for member in members:
            out(define, member.fq_c_name().upper(), member.name)

def generate_header(ifaces, filename):
    """Write the C header for *ifaces* to stdout.

    *filename* is the XML file the interfaces came from; its base name is
    mentioned in the output and used to build the include guard.
    """
    basename = os.path.basename(filename)
    # Anything that is not valid in a macro name becomes an underscore
    guard = "__%s__" % re.sub(r'([^_A-Z0-9])', "_", basename.upper())
    rule = "/* ------------------------------------------------------------------------"

    out("/* The following declarations are auto-generated from %s */", basename)
    out("")
    for directive in ("#ifndef %s", "#define %s"):
        out(directive, guard)
    for line in ("",
                 '#include "sbus/sssd_dbus.h"',
                 '#include "sbus/sssd_dbus_meta.h"'):
        out(line)

    constants_banner = (
        "",
        rule,
        " * DBus Constants",
        " *",
        " * Various constants of interface and method names mostly for use by clients",
        " */",
    )
    for line in constants_banner:
        out(line)
    for iface in ifaces:
        header_constants(iface)

    handlers_banner = (
        "",
        rule,
        " * DBus handlers",
        " *",
        " * These structures are filled in by implementors of the different",
        " * dbus interfaces to handle method calls.",
        " *",
        " * Handler functions of type sbus_msg_handler_fn accept raw messages,",
        " * other handlers are typed appropriately. If a handler that is",
        " * set to NULL is invoked it will result in a",
        " * org.freedesktop.DBus.Error.NotSupported error for the caller.",
        " *",
        " * Handlers have a matching xxx_finish() function (unless the method has",
        " * accepts raw messages). These finish functions the",
        " * sbus_request_return_and_finish() with the appropriate arguments to",
        " * construct a valid reply. Once a finish function has been called, the",
        " * @dbus_req it was called with is freed and no longer valid.",
        " */",
    )
    for line in handlers_banner:
        out(line)
    for iface in ifaces:
        # Interfaces with only signals get no vtable
        if not (iface.methods or iface.properties):
            continue
        header_vtable(iface, iface.methods)
        for meth in iface.methods:
            header_finisher(iface, meth)

    metadata_banner = (
        "",
        rule,
        " * DBus Interface Metadata",
        " *",
        " * These structure definitions are filled in with the information about",
        " * the interfaces, methods, properties and so on.",
        " *",
        " * The actual definitions are found in the accompanying C file next",
        " * to this header.",
        " */",
    )
    for line in metadata_banner:
        out(line)
    for iface in ifaces:
        header_interface(iface)

    out("")
    out("#endif /* %s */", guard)

# -----------------------------------------------------------------------------
# XML Interface Parsing

# Parser states. Apart from STATE_TOP and STATE_IGNORED, each value equals
# the XML element name that puts the parser into that state, which is why
# DBusXMLParser compares element names against these constants directly.
STATE_TOP = 'top'
STATE_NODE = 'node'
STATE_INTERFACE = 'interface'
STATE_METHOD = 'method'
STATE_SIGNAL = 'signal'
STATE_PROPERTY = 'property'
STATE_ARG = 'arg'
STATE_ANNOTATION = 'annotation'
# Used for any element the parser does not care about, and everything below it
STATE_IGNORED = 'ignored'

def expect_attr(attrs, name):
    """Return the value of the mandatory attribute *name* from *attrs*.

    Raises DBusXmlException if the attribute is absent or an empty string.
    """
    if name in attrs:
        value = attrs[name]
        if value == "":
            raise DBusXmlException("Empty attribute '%s'" % name)
        return value
    raise DBusXmlException("Missing attribute '%s'" % name)

class DBusXMLParser(object):
    """Parses one introspection XML file into Interface objects.

    All the work happens in the constructor. Afterwards the interfaces
    found in the file are available, in document order, in the
    ``parsed_interfaces`` list.
    """

    def __init__(self, filename):
        """Parse *filename*.

        Raises DBusXmlException, with ``file`` and ``line`` filled in, both
        for content problems detected by the handlers and for XML errors
        reported by expat.
        """
        parser = xml.parsers.expat.ParserCreate()
        parser.CommentHandler = self.handle_comment
        parser.CharacterDataHandler = self.handle_char_data
        parser.StartElementHandler = self.handle_start_element
        parser.EndElementHandler = self.handle_end_element

        self.parsed_interfaces = []
        self.cur_object = None

        # 'state' and 'cur_object' describe the element being parsed. The
        # two stacks hold the values to restore when that element ends;
        # they are pushed and popped together.
        self.state = STATE_TOP
        self.state_stack = []
        self.cur_object = None
        self.cur_object_stack = []
        # Reset whenever a method or signal starts; not read in this class
        self.arg_count = 0

        try:
            with open(filename, "rb") as f:
                parser.ParseFile(f)
        except DBusXmlException as ex:
            # Raised by one of our handlers: add the location and pass it on
            ex.line = parser.CurrentLineNumber
            ex.file = filename
            raise
        except xml.parsers.expat.ExpatError as ex:
            # Malformed XML: report it in the same form as our own errors
            exc = DBusXmlException(str(ex))
            exc.line = ex.lineno
            exc.file = filename
            raise exc

    def handle_comment(self, data):
        """Ignore XML comments."""
        pass

    def handle_char_data(self, data):
        """Ignore text content."""
        pass

    def handle_start_element(self, name, attrs):
        """Enter element *name*, updating the state and the current object.

        Elements that are unknown in the current state switch the parser
        to STATE_IGNORED, which also covers all their children. Whatever
        happens, the previous state and object are pushed so that
        handle_end_element() can restore them.
        """
        old_state = self.state
        old_cur_object = self.cur_object
        if self.state == STATE_IGNORED:
            self.state = STATE_IGNORED
        elif self.cur_object and name == STATE_ANNOTATION:
                # Annotations attach to the current object, which stays
                # current; a missing 'value' is stored as an empty string.
                val = attrs.get('value', '')
                self.cur_object.annotations[expect_attr(attrs, 'name')] = val
                self.state = STATE_IGNORED
        elif self.state == STATE_TOP:
            if name == STATE_NODE:
                self.state = STATE_NODE
            else:
                self.state = STATE_IGNORED
        elif self.state == STATE_NODE:
            if name == STATE_INTERFACE:
                self.state = STATE_INTERFACE
                iface = Interface(expect_attr(attrs, 'name'))
                self.cur_object = iface
                self.parsed_interfaces.append(iface)
            else:
                self.state = STATE_IGNORED

        elif self.state == STATE_INTERFACE:
            # Here cur_object is the Interface; each member is added to it
            # and then becomes the current object itself.
            if name == STATE_METHOD:
                self.state = STATE_METHOD
                method = Method(self.cur_object, expect_attr(attrs, 'name'))
                self.cur_object.methods.append(method)
                self.cur_object = method
                self.arg_count = 0
            elif name == STATE_SIGNAL:
                self.state = STATE_SIGNAL
                signal = Signal(self.cur_object, expect_attr(attrs, 'name'))
                self.cur_object.signals.append(signal)
                self.cur_object = signal
                self.arg_count = 0
            elif name == STATE_PROPERTY:
                self.state = STATE_PROPERTY
                prop = Property(self.cur_object,
                                expect_attr(attrs, 'name'),
                                expect_attr(attrs, 'type'),
                                expect_attr(attrs, 'access'))
                self.cur_object.properties.append(prop)
                self.cur_object = prop
            else:
                self.state = STATE_IGNORED

        elif self.state == STATE_METHOD:
            if name == STATE_ARG:
                self.state = STATE_ARG
                arg = Arg(self.cur_object,
                          expect_attr(attrs, 'name'),
                          expect_attr(attrs, 'type'))
                # Method arguments without a direction count as input
                direction = attrs.get('direction', 'in')
                if direction == 'in':
                    self.cur_object.in_args.append(arg)
                elif direction == 'out':
                    self.cur_object.out_args.append(arg)
                else:
                    raise DBusXmlException('Invalid direction "%s"' % direction)
                self.cur_object = arg
            else:
                self.state = STATE_IGNORED

        elif self.state == STATE_SIGNAL:
            if name == STATE_ARG:
                self.state = STATE_ARG
                # The 'direction' attribute is not looked at for signals
                arg = Arg(self.cur_object,
                          expect_attr(attrs, 'name'),
                          expect_attr(attrs, 'type'))
                self.cur_object.args.append(arg)
                self.cur_object = arg
            else:
                self.state = STATE_IGNORED

        elif self.state == STATE_PROPERTY:
            # Annotations were handled above; nothing else is expected here
            self.state = STATE_IGNORED

        elif self.state == STATE_ARG:
            # Annotations were handled above; nothing else is expected here
            self.state = STATE_IGNORED

        else:
            assert False, 'Unhandled state "%s" while entering element with name "%s"' % (self.state, name)

        self.state_stack.append(old_state)
        self.cur_object_stack.append(old_cur_object)

    def handle_end_element(self, name):
        """Leave an element: validate the current object, then restore.

        NOTE(review): validate() runs on whatever object is current when
        any element closes. Annotations and ignored children leave
        cur_object unchanged, so an object is also validated each time one
        of those closes, before its remaining children have been seen.
        """
        if self.cur_object:
            self.cur_object.validate()
        self.state = self.state_stack.pop()
        self.cur_object = self.cur_object_stack.pop()

def parse_options():
    """Parse the command line from sys.argv.

    Returns the tuple (options, args): ``options`` has the attributes
    ``mode``, ``output`` and ``include``, and ``args`` is the non-empty
    list of input files.

    Exits with status 2 if no input file is given or if --mode is neither
    'header' nor 'source'.
    """
    parser = optparse.OptionParser("usage: %prog [options] introspect.xml ...")
    parser.set_description("sbus_codegen generates sbus interface structures \
                            from standard XML Introspect data.")
    parser.add_option("--mode",
                      dest="mode", default="header",
                      help="'header' or 'source' (default: header)",
                      metavar="MODE")
    parser.add_option("--output",
                      dest="output", default=None,
                      help="Set output file name (default: stdout)",
                      metavar="FILE")
    parser.add_option("--include",
                      dest="include", default=None,
                      help="name of a header to #include",
                      metavar="HEADER")
    (options, args) = parser.parse_args()

    if not args:
        print("sbus_codegen: no input file specified", file=sys.stderr)
        sys.exit(2)

    if options.mode not in ["header", "source"]:
        print("sbus_codegen: specify --mode=header or --mode=source", file=sys.stderr)
        # Stop here like any other usage error instead of letting the
        # caller run into an unsupported mode.
        sys.exit(2)

    return options, args

def main():
    """Generate a header or source file for every input on the command line.

    Output goes to stdout, or to the file named by --output. In the latter
    case everything is collected in memory first and the file is written
    only after all inputs were processed successfully.
    """
    options, args = parse_options()

    buf = None
    saved_stdout = sys.stdout
    if options.output:
        # The generators write to sys.stdout, so redirect it for a while
        sys.stdout = buf = StringIO.StringIO()

    try:
        for filename in args:
            parser = DBusXMLParser(filename)

            if options.mode == "header":
                generate_header(parser.parsed_interfaces, filename)
            elif options.mode == "source":
                generate_source(parser.parsed_interfaces, filename, options.include)
            else:
                assert False, "should not be reached"
    finally:
        # Never leave the process with stdout pointing at our buffer
        sys.stdout = saved_stdout

    # Write output at end to be nice to 'make'
    if buf is not None:
        with open(options.output, "w") as output:
            output.write(buf.getvalue())

# Script entry point
if __name__ == "__main__":
    try:
        main()
    except DBusXmlException as ex:
        # Problems with the input are reported as a single diagnostic line
        # (see DBusXmlException.__str__) and a failing exit status.
        print(str(ex), file=sys.stderr)
        sys.exit(1)
