#!/usr/bin/env python3
# pylint: disable=C0114,C0115,C0116,C0209,C0302,R0902,R0911,R0912,R0914,R0915,E1101
#
# Copyright 2022-2024 by Wilson Snyder. Verilator is free software; you
# can redistribute it and/or modify it under the terms of either the GNU Lesser
# General Public License Version 3 or the Apache License 2.0.
# SPDX-License-Identifier: LGPL-3.0-only OR Apache-2.0

import argparse
import os
import sys
import shlex
from typing import Callable, Iterable, Optional, Union, TYPE_CHECKING
import dataclasses
from dataclasses import dataclass
import enum
from enum import Enum
import multiprocessing
import re
import tempfile

import clang.cindex
from clang.cindex import (
    Index,
    TranslationUnitSaveError,
    TranslationUnitLoadError,
    CompilationDatabase,
)

# At runtime, use libclang's real CursorKind. Type checkers take the `else`
# branch instead.
if not TYPE_CHECKING:
    from clang.cindex import CursorKind
else:
    # Workaround for missing support for members defined out-of-class in Pylance:
    # https://github.com/microsoft/pylance-release/issues/2365#issuecomment-1035803067
    # The metaclass forwards class-attribute lookups such as
    # `CursorKind.CXX_METHOD` to clang.cindex.CursorKind, so kinds that are
    # attached outside the class body still resolve during static analysis.

    class CursorKindMeta(type):

        # Invoked only for attributes not found normally on the class.
        def __getattr__(cls, name: str) -> clang.cindex.CursorKind:
            return getattr(clang.cindex.CursorKind, name)

    class CursorKind(clang.cindex.CursorKind, metaclass=CursorKindMeta):
        pass


def fully_qualified_name(node):
    """Return the qualified-name components of *node*, outermost first.

    Walks `semantic_parent` links up to (but excluding) the translation unit
    cursor. Ancestors with an empty `displayname` contribute nothing. A `None`
    node yields an empty list.
    """
    parts = []
    cursor = node
    while cursor is not None and cursor.kind != CursorKind.TRANSLATION_UNIT:
        display = cursor.displayname
        if display:
            parts.append(display)
        cursor = cursor.semantic_parent
    # Collected innermost-first; callers expect outermost-first.
    parts.reverse()
    return parts


def check_class_member_exists(class_node, member):
    """Return True if any direct child of *class_node* is spelled like *member*.

    Only the `spelling` strings are compared; child kinds are not checked.
    """
    wanted = member.spelling
    return any(child.spelling == wanted for child in class_node.get_children())


def get_base_class(class_node, base_type):
    """Return the base-class type of *class_node* spelled like *base_type*.

    Only CXX_BASE_SPECIFIER children of *class_node* are considered. The
    first one whose type spelling equals `base_type.spelling` is returned;
    None is returned when there is no match.
    """
    base_types = (child.type for child in class_node.get_children()
                  if child.kind is CursorKind.CXX_BASE_SPECIFIER)
    return next(
        (candidate for candidate in base_types
         if candidate.spelling == base_type.spelling),
        None,
    )


@dataclass
class VlAnnotations:
    """Verilator thread-safety / purity annotations found on one declaration.

    Every field is a flag defaulting to False. `from_nodes_list` sets the
    flags from the ANNOTATE_ATTR children of a cursor.
    """
    mt_start: bool = False
    mt_safe: bool = False
    stable_tree: bool = False
    mt_safe_postinit: bool = False
    mt_unsafe: bool = False
    mt_disabled: bool = False
    mt_unsafe_one: bool = False
    pure: bool = False
    guarded: bool = False
    requires: bool = False
    excludes: bool = False
    acquire: bool = False
    release: bool = False

    # Annotation display name -> flag it sets. This is not a dataclass field
    # (no type annotation), so it does not affect __init__, __eq__ or asdict().
    _ANNOTATION_TO_FIELD = {
        "MT_START": "mt_start",
        "MT_SAFE": "mt_safe",
        "MT_STABLE": "stable_tree",
        "MT_SAFE_POSTINIT": "mt_safe_postinit",
        "MT_UNSAFE": "mt_unsafe",
        "MT_UNSAFE_ONE": "mt_unsafe_one",
        "MT_DISABLED": "mt_disabled",
        "PURE": "pure",
        "ACQUIRE": "acquire",
        "ACQUIRE_SHARED": "acquire",
        "RELEASE": "release",
        "RELEASE_SHARED": "release",
        "REQUIRES": "requires",
        "EXCLUDES": "excludes",
        "MT_SAFE_EXCLUDES": "excludes",
        "GUARDED_BY": "guarded",
    }

    def is_mt_safe_context(self):
        # MT_SAFE marks the context safe unless an MT_UNSAFE(_ONE) overrides it.
        if self.mt_unsafe or self.mt_unsafe_one:
            return False
        return self.mt_safe

    def is_pure_context(self):
        return self.pure

    def is_stabe_tree_context(self):
        # In a stable-tree context every call must be MT_SAFE or MT_STABLE.
        # Functions annotated MT_START are held to the same requirement.
        return self.stable_tree or self.mt_start

    def is_mt_unsafe_call(self):
        return self.mt_unsafe or self.mt_unsafe_one or self.mt_disabled

    def is_mt_safe_call(self):
        # Any explicit "unsafe" marker wins over the "safe" ones below.
        if self.is_mt_unsafe_call():
            return False
        return (self.mt_safe or self.mt_safe_postinit or self.pure
                or self.requires or self.excludes or self.acquire
                or self.release)

    def is_pure_call(self):
        return self.pure

    def is_stabe_tree_call(self):
        return self.stable_tree

    def __or__(self, other: "VlAnnotations"):
        """Return a new instance whose flags are the field-wise OR of both."""
        merged = VlAnnotations()
        for fld in dataclasses.fields(self):
            setattr(merged, fld.name,
                    getattr(self, fld.name) | getattr(other, fld.name))
        return merged

    def is_empty(self):
        """True when no flag is set."""
        return not any(
            getattr(self, fld.name) for fld in dataclasses.fields(self))

    def __str__(self):
        """Comma-separated names of the set flags, in declaration order."""
        return ", ".join(fld.name for fld in dataclasses.fields(self)
                         if getattr(self, fld.name))

    @staticmethod
    def from_nodes_list(nodes: Iterable):
        """Collect annotations from the leading attribute cursors in *nodes*.

        Scanning stops at the first cursor that is not an attribute, because
        attributes are always listed first among a declaration's children.
        Unrecognised annotation names are ignored.
        """
        collected = VlAnnotations()
        for node in nodes:
            if node.kind != CursorKind.ANNOTATE_ATTR:
                if not node.kind.is_attribute():
                    break
                continue
            field_name = VlAnnotations._ANNOTATION_TO_FIELD.get(
                node.displayname)
            if field_name is not None:
                setattr(collected, field_name, True)
        return collected


class FunctionType(Enum):
    """Coarse classification of the callable a cursor refers to."""
    UNKNOWN = enum.auto()
    FUNCTION = enum.auto()
    METHOD = enum.auto()
    STATIC_METHOD = enum.auto()
    CONSTRUCTOR = enum.auto()

    @staticmethod
    def from_node(node: clang.cindex.Cursor):
        """Classify *node* by cursor kind; None or other kinds give UNKNOWN."""
        if node is None:
            return FunctionType.UNKNOWN
        kind = node.kind
        if kind == CursorKind.FUNCTION_DECL:
            return FunctionType.FUNCTION
        if kind == CursorKind.CXX_METHOD:
            # Static and non-static methods share a cursor kind.
            if node.is_static_method():
                return FunctionType.STATIC_METHOD
            return FunctionType.METHOD
        if kind == CursorKind.CONSTRUCTOR:
            return FunctionType.CONSTRUCTOR
        return FunctionType.UNKNOWN


@dataclass(eq=False)
class FunctionInfo:
    """A function (identified via its canonical cursor) at one source location.

    Identity is the triple (usr, file, line). `__eq__` and `__hash__` look at
    only those fields, so the same function seen at different lines gives
    distinct instances.
    """
    name_parts: list[str]
    usr: str
    file: str
    line: int
    annotations: VlAnnotations
    ftype: FunctionType

    # Lazily filled by __hash__. Because it is init=False,
    # dataclasses.replace() (used by copy()) resets it to None, so a copy with
    # a different file/line never reuses a stale hash.
    _hash: Optional[int] = dataclasses.field(default=None,
                                             init=False,
                                             repr=False,
                                             compare=False)

    @property
    def name(self):
        """Fully qualified name joined with `::`."""
        return "::".join(self.name_parts)

    def __str__(self):
        return f"[{self.name}@{self.file}:{self.line}]"

    def __hash__(self):
        # Test against None rather than falsiness so a hash of 0 is cached too.
        if self._hash is None:
            self._hash = hash(f"{self.usr}:{self.file}:{self.line}")
        return self._hash

    def __eq__(self, other):
        # Foreign types (e.g. None) compare unequal instead of raising
        # AttributeError on the missing `usr` attribute.
        if not isinstance(other, FunctionInfo):
            return NotImplemented
        return (self.usr == other.usr and self.file == other.file
                and self.line == other.line)

    def copy(self, /, **changes):
        """Return a shallow copy with *changes* applied (dataclasses.replace)."""
        return dataclasses.replace(self, **changes)

    @staticmethod
    def _from_refd(file: str, line: int, refd: clang.cindex.Cursor,
                   annotations: VlAnnotations):
        """Build an instance for *refd*'s canonical cursor at *file*:*line*."""
        canonical = refd.canonical
        assert canonical is not None
        return FunctionInfo(fully_qualified_name(canonical),
                            canonical.get_usr(), file, line, annotations,
                            FunctionType.from_node(canonical))

    @staticmethod
    def from_decl_file_line_and_refd_node(file: str, line: int,
                                          refd: clang.cindex.Cursor,
                                          annotations: VlAnnotations):
        """Describe *refd* as declared at the explicit *file*:*line*.

        *file* is made absolute. The name, USR and type come from *refd*'s
        canonical cursor.
        """
        # Check before dereferencing: `refd.canonical` on None would raise
        # AttributeError before any assertion could fire.
        assert refd is not None
        return FunctionInfo._from_refd(os.path.abspath(file), line, refd,
                                       annotations)

    @staticmethod
    def from_node(node: clang.cindex.Cursor,
                  refd: Optional[clang.cindex.Cursor] = None,
                  annotations: Optional[VlAnnotations] = None):
        """Describe the function referenced by *node*, located at *node*.

        *refd* defaults to `node.referenced`. *annotations* default to those
        collected from *node*'s own children.
        """
        file = os.path.abspath(node.location.file.name)
        line = node.location.line
        if annotations is None:
            annotations = VlAnnotations.from_nodes_list(node.get_children())
        if refd is None:
            refd = node.referenced
        assert refd is not None
        return FunctionInfo._from_refd(file, line, refd, annotations)


class DiagnosticKind(Enum):
    """Categories of annotation violations, ordered by declaration order."""
    ANNOTATIONS_DEF_DECL_MISMATCH = enum.auto()
    NON_PURE_CALL_IN_PURE_CTX = enum.auto()
    NON_MT_SAFE_CALL_IN_MT_SAFE_CTX = enum.auto()
    NON_STABLE_TREE_CALL_IN_STABLE_TREE_CTX = enum.auto()
    MISSING_MT_DISABLED_ANNOTATION = enum.auto()

    def __lt__(self, other):
        # Ordering makes kinds sortable. For foreign types, defer to Python's
        # comparison protocol (ending in TypeError) instead of raising
        # AttributeError on `other.value`.
        if not isinstance(other, DiagnosticKind):
            return NotImplemented
        return self.value < other.value


@dataclass
class Diagnostic:
    """One reported annotation violation.

    `target` is the offending callee (or mismatching declaration), `source`
    is the definition being analysed, and `source_ctx` is the location being
    checked inside it. Equality (generated by @dataclass) compares target,
    source, source_ctx and kind. The hash uses only target, source_ctx and
    kind, which is consistent with that equality.
    """
    target: FunctionInfo
    source: FunctionInfo
    source_ctx: FunctionInfo
    kind: DiagnosticKind

    # Hash cache. compare=False keeps it out of the generated __eq__; otherwise
    # a diagnostic that has been hashed would compare unequal to an identical
    # one that has not.
    _hash: Optional[int] = dataclasses.field(default=None,
                                             init=False,
                                             repr=False,
                                             compare=False)

    def __hash__(self):
        # Hash a tuple rather than XOR-ing the parts, so equal components
        # (e.g. target == source_ctx) do not cancel each other out.
        if self._hash is None:
            self._hash = hash((self.target, self.source_ctx, self.kind))
        return self._hash


class CallAnnotationsValidator:
    """Check Verilator thread-safety / purity annotations across libclang ASTs.

    Violations are not collected here; each one is passed as a `Diagnostic`
    to the *diagnostic_cb* callback given to the constructor.
    """

    def __init__(self, diagnostic_cb: Callable[[Diagnostic], None],
                 is_ignored_top_level: Callable[[clang.cindex.Cursor], bool],
                 is_ignored_def: Callable[
                     [clang.cindex.Cursor, clang.cindex.Cursor], bool],
                 is_ignored_call: Callable[[clang.cindex.Cursor], bool]):
        self._diagnostic_cb = diagnostic_cb
        self._is_ignored_top_level = is_ignored_top_level
        self._is_ignored_call = is_ignored_call
        self._is_ignored_def = is_ignored_def

        self._index = Index.create()

        # Maps a TU's initial defines (sorted "KEY=VALUE" lines joined by
        # newlines, from the command line and the source's lines before any
        # include) to the headers already processed under that define set.
        self._processed_headers: dict[str, set[str]] = {}
        # USR -> (file, line) of declarations seen outside the main source
        # file that are not annotated MT_DISABLED.
        self._external_decls: dict[str, set[tuple[str, int]]] = {}

        # Current context
        self._main_source_file: str = ""
        self._defines: dict[str, str] = {}
        self._call_location: Optional[FunctionInfo] = None
        self._caller: Optional[FunctionInfo] = None
        # Stack of constructors whose bodies are currently being traversed.
        self._constructor_context: list[clang.cindex.Cursor] = []
        # Nesting depth updated by iterate_children/process_translation_unit.
        self._level: int = 0

    def is_mt_disabled_code_unit(self):
        return "VL_MT_DISABLED_CODE_UNIT" in self._defines

    def is_constructor_context(self):
        return len(self._constructor_context) > 0

    @staticmethod
    def parse_initial_defines(source_file: str) -> dict[str, str]:
        """Return `#define KEY VALUE` lines located before any `#include` line.

        The parsing is deliberately simple: one define per line, with no
        support for line continuations. A define without a value maps to "1",
        like a bare `-DKEY` on the command line.
        """
        defs: dict[str, str] = {}
        with open(source_file, "r", encoding="utf-8") as file:
            for line in file:
                match = re.fullmatch(
                    r"^#\s*(define\s+(\w+)(?:\s+(.*))?|include\s+.*)$",
                    line.strip())
                if not match:
                    continue
                # The first #include ends the "initial" region.
                if match.group(1).startswith("include"):
                    break
                value = match.group(3)
                defs[match.group(2)] = "1" if value is None else value
        return defs

    @staticmethod
    def filter_out_unsupported_compiler_args(
            args: list[str]) -> tuple[list[str], dict[str, str]]:
        """Drop compiler arguments libclang should not see and collect defines.

        Returns `(filtered_args, defines)`. `defines` reflects the -D/-U
        options applied in order; a bare `-DKEY` maps to "1". If an option
        that takes a separate value comes last, with no value, scanning stops
        at that point.
        """
        # Each entry needs its own comma: adjacent string literals would
        # otherwise be concatenated (e.g. "-MF" "-L" -> "-MF-L").
        skipped_with_value = ("-o", "-T", "-MT", "-MQ", "-MF", "-L")
        preserved_with_value = ("-x", "-Xclang", "-I", "-isystem", "-iquote",
                                "-include", "-include-pch")
        filtered_args = []
        defines = {}
        args_iter = iter(args)
        try:
            for arg in args_iter:
                # Empty arguments carry nothing; skip them rather than letting
                # them end the scan early.
                if not arg:
                    continue

                # Skip positional arguments (input file name).
                if not arg.startswith("-") and arg.endswith(
                        (".cpp", ".c", ".h")):
                    continue

                # Skipped options with separate value argument.
                if arg in skipped_with_value:
                    next(args_iter)
                    continue

                # Skipped options without separate value argument.
                if arg == "-c" or arg.startswith(("-W", "-L")):
                    continue

                # Preserved options with separate value argument.
                if arg in preserved_with_value:
                    filtered_args += [arg, next(args_iter)]
                    continue

                if arg in ("-D", "-U"):
                    # Define/undefine with separate value argument.
                    filtered_args.append(arg)
                    kv_str = next(args_iter)
                    filtered_args.append(kv_str)
                elif arg[0:2] in ("-D", "-U"):
                    # Define/undefine without separate value argument.
                    filtered_args.append(arg)
                    kv_str = arg[2:]
                else:
                    # Preserve everything else.
                    filtered_args.append(arg)
                    continue

                # Keep track of defines for class' internal purposes.
                key, sep, val = kv_str.partition("=")
                if not sep:
                    val = "1"
                if arg[1] == "D":
                    defines[key] = val
                else:
                    defines.pop(key, None)

        except StopIteration:
            # A trailing option was missing its value argument.
            pass

        return (filtered_args, defines)

    def compile_and_analyze_file(self, source_file: str,
                                 compiler_args: list[str],
                                 build_dir: Optional[str]):
        """Parse *source_file* with libclang and validate its annotations.

        A load failure or any error-level diagnostic is reported on stderr,
        and the file is then skipped. When *build_dir* is given, parsing and
        analysis run with it as the working directory. The previous working
        directory and the per-TU state are restored even if analysis raises.
        """
        filename = os.path.abspath(source_file)

        filtered_args, defines = self.filter_out_unsupported_compiler_args(
            compiler_args)
        # Read before any chdir: `source_file` may be relative to the current
        # working directory.
        defines.update(self.parse_initial_defines(source_file))

        initial_cwd = None
        if build_dir:
            initial_cwd = os.getcwd()
            os.chdir(build_dir)
        try:
            try:
                translation_unit = self._index.parse(filename, filtered_args)
            except TranslationUnitLoadError as exc:
                # Report load failures like error diagnostics, so an
                # unparsable file is never skipped silently.
                print(f"%Error: parsing failed: {filename}", file=sys.stderr)
                print(f"        {exc}", file=sys.stderr)
                return

            errors = [
                str(diag) for diag in translation_unit.diagnostics
                if diag.severity >= clang.cindex.Diagnostic.Error
            ]
            if errors:
                print(f"%Error: parsing failed: {filename}", file=sys.stderr)
                for error in errors:
                    print(f"        {error}", file=sys.stderr)
                return

            self._defines = defines
            self._main_source_file = filename
            try:
                self.process_translation_unit(translation_unit)
            finally:
                self._main_source_file = ""
                self._defines = {}
        finally:
            if initial_cwd is not None:
                os.chdir(initial_cwd)

    def emit_diagnostic(self, target: Union[FunctionInfo, clang.cindex.Cursor],
                        kind: DiagnosticKind):
        """Report *kind* for *target* from the current caller / call location.

        A cursor *target* is first converted with `FunctionInfo.from_node`.
        """
        assert self._caller is not None
        assert self._call_location is not None
        if not isinstance(target, FunctionInfo):
            target = FunctionInfo.from_node(target)
        self._diagnostic_cb(
            Diagnostic(target, self._caller, self._call_location, kind))

    def iterate_children(self, children: Iterable[clang.cindex.Cursor],
                         handler: Callable[[clang.cindex.Cursor], None]):
        """Apply *handler* to each of *children*, one nesting level deeper."""
        if children:
            self._level += 1
            for child in children:
                handler(child)
            self._level -= 1

    @staticmethod
    def get_referenced_node_info(
        node: clang.cindex.Cursor
    ) -> tuple[bool, Optional[clang.cindex.Cursor], VlAnnotations,
               Iterable[clang.cindex.Cursor]]:
        """Resolve *node*'s canonical referenced cursor and its annotations.

        Returns `(supported, refd, annotations, refd_children)`. A node with
        neither spelling nor display name is unsupported and yields
        `(False, None, VlAnnotations(), [])`.

        Raises:
            ValueError: *node* is named but references nothing.
        """
        if not node.spelling and not node.displayname:
            return (False, None, VlAnnotations(), [])

        refd = node.referenced
        if refd is None:
            raise ValueError("The node does not specify referenced node.")

        refd = refd.canonical
        children = list(refd.get_children())

        annotations = VlAnnotations.from_nodes_list(children)
        return (True, refd, annotations, children)

    def check_mt_safe_call(self, node: clang.cindex.Cursor,
                           refd: clang.cindex.Cursor,
                           annotations: VlAnnotations):
        """Return True if the method call *node* (to *refd*) is MT-safe.

        The call is safe when the callee is annotated MT-safe. When the callee
        is not explicitly MT-unsafe, it is also safe if the object it is
        invoked on is judged safe: a GUARDED_BY member, a local variable or a
        non-pointer/non-reference local or argument, or (while a constructor
        is being traversed) a member of the object under construction. In
        constructor context the callee may be traversed as a side effect,
        which can emit further diagnostics.
        """
        is_mt_safe = False

        if annotations.is_mt_safe_call():
            is_mt_safe = True
        elif not annotations.is_mt_unsafe_call():
            # Check whether the object the method is called on is mt-safe
            def find_object_ref(node):
                try:
                    node = next(node.get_children())
                    if node.kind == CursorKind.DECL_REF_EXPR:
                        # Operator on an argument or local object
                        return node
                    if node.kind != CursorKind.MEMBER_REF_EXPR:
                        return None
                    if node.referenced and node.referenced.kind == CursorKind.FIELD_DECL:
                        # Operator on a member object
                        return node
                    node = next(node.get_children())
                    if node.kind == CursorKind.UNEXPOSED_EXPR:
                        node = next(node.get_children())
                    return node
                except StopIteration:
                    return None

            refn = find_object_ref(node)
            if self.is_constructor_context() and not refn:
                # we are in constructor and no object reference means
                # we are calling local method. It is MT safe
                # only if this method is also only calling local methods or
                # MT-safe methods
                self.iterate_children(refd.get_children(),
                                      self.dispatch_node_inside_definition)
                is_mt_safe = True
            # class/struct member
            elif refn and refn.kind == CursorKind.MEMBER_REF_EXPR and refn.referenced:
                refn = refn.referenced
                refna = VlAnnotations.from_nodes_list(refn.get_children())
                if refna.guarded:
                    is_mt_safe = True
                if self.is_constructor_context() and refn.semantic_parent:
                    # we are in constructor, so calling local members is MT_SAFE,
                    # make sure object that we are calling is local to the constructor
                    constructor_class = self._constructor_context[
                        -1].semantic_parent
                    if refn.semantic_parent.spelling == constructor_class.spelling:
                        if check_class_member_exists(constructor_class, refn):
                            is_mt_safe = True
                    else:
                        # check if this class inherits from some base class
                        base_class = get_base_class(constructor_class,
                                                    refn.semantic_parent)
                        if base_class:
                            if check_class_member_exists(
                                    base_class.get_declaration(), refn):
                                is_mt_safe = True
            # variable
            elif refn and refn.kind == CursorKind.DECL_REF_EXPR and refn.referenced:
                if refn.get_definition():
                    if refn.referenced.semantic_parent:
                        if refn.referenced.semantic_parent.kind in [
                                CursorKind.FUNCTION_DECL, CursorKind.CXX_METHOD
                        ]:
                            # This is a local or an argument.
                            # Calling methods on local pointers or references is MT-safe,
                            # but on argument pointers or references is not.
                            if "*" not in refn.type.spelling and "&" not in refn.type.spelling:
                                is_mt_safe = True
                            # local variable
                            if refn.referenced.kind == CursorKind.VAR_DECL:
                                is_mt_safe = True
                else:
                    # Global variable in different translation unit, unsafe
                    pass
            elif refn and refn.kind == CursorKind.CALL_EXPR:
                if self.is_constructor_context():
                    # call to local function from constructor context
                    # safe if this function also calling local methods or
                    # MT-safe methods
                    self.dispatch_call_node(refn)
                    is_mt_safe = True
        return is_mt_safe

    # Call handling

    def process_method_call(self, node: clang.cindex.Cursor,
                            refd: clang.cindex.Cursor,
                            annotations: VlAnnotations):
        """Check a non-static method call against the current context's rules."""
        assert self._call_location
        ctx = self._call_location.annotations

        # MT-safe context
        if ctx.is_mt_safe_context():
            if not self.check_mt_safe_call(node, refd, annotations):
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_MT_SAFE_CALL_IN_MT_SAFE_CTX)

        # stable tree context
        if ctx.is_stabe_tree_context():
            if annotations.is_mt_unsafe_call() or not (
                    annotations.is_stabe_tree_call()
                    or annotations.is_pure_call()
                    or self.check_mt_safe_call(node, refd, annotations)):
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_STABLE_TREE_CALL_IN_STABLE_TREE_CTX)

        # pure context
        if ctx.is_pure_context():
            if not annotations.is_pure_call():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX)

    def process_function_call(self, refd: clang.cindex.Cursor,
                              annotations: VlAnnotations):
        """Check a free-function / static-method / function-pointer call.

        Unlike method calls, only the callee's own annotations are consulted.
        """
        assert self._call_location
        ctx = self._call_location.annotations

        # MT-safe context
        if ctx.is_mt_safe_context():
            if not annotations.is_mt_safe_call():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_MT_SAFE_CALL_IN_MT_SAFE_CTX)

        # stable tree context
        if ctx.is_stabe_tree_context():
            if annotations.is_mt_unsafe_call() or not (
                    annotations.is_pure_call()
                    or annotations.is_mt_safe_call()
                    or annotations.is_stabe_tree_call()):
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_STABLE_TREE_CALL_IN_STABLE_TREE_CTX)

        # pure context
        if ctx.is_pure_context():
            if not annotations.is_pure_call():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX)

    def process_constructor_call(self, refd: clang.cindex.Cursor,
                                 annotations: VlAnnotations):
        """Check a constructor call against the current context's rules.

        In MT-safe, nested-constructor or stable-tree contexts, the
        constructor is pushed on `_constructor_context` while the children of
        its canonical cursor are traversed, so calls found there are checked.
        In a pure context, a non-pure, non-default constructor is reported.
        """
        assert self._call_location
        ctx = self._call_location.annotations

        # Constructors are OK in MT-safe context
        # only if they call local methods or MT-safe functions.
        if ctx.is_mt_safe_context() or self.is_constructor_context():
            self._constructor_context.append(refd)
            self.iterate_children(refd.get_children(),
                                  self.dispatch_node_inside_definition)
            self._constructor_context.pop()

        # stable tree context
        if ctx.is_stabe_tree_context():
            self._constructor_context.append(refd)
            self.iterate_children(refd.get_children(),
                                  self.dispatch_node_inside_definition)
            self._constructor_context.pop()

        # pure context
        if ctx.is_pure_context():
            if not annotations.is_pure_call(
            ) and not refd.is_default_constructor():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX)

    def dispatch_call_node(self, node: clang.cindex.Cursor):
        """Validate the CALL_EXPR *node* according to the callee's kind.

        Returns False only when the call was handled as a std::function-taking
        call, whose arguments have then already been traversed. Returns True
        otherwise.
        """
        [supported, refd, annotations, _] = self.get_referenced_node_info(node)

        if not supported:
            self.iterate_children(node.get_children(),
                                  self.dispatch_node_inside_definition)
            return True

        assert refd is not None
        if self._is_ignored_call(refd):
            return True

        if "std::function" in refd.displayname:
            # Workaroud for missing support for lambda annotations
            # in c++11.
            # If function takes std::function as argument,
            # assume, that this std::function will be called inside it.
            self.process_function_definition(node)
            return False

        assert self._call_location is not None
        # Point the context at this call site; the enclosing
        # process_function_definition restores the previous context.
        node_file = os.path.abspath(node.location.file.name)
        self._call_location = self._call_location.copy(file=node_file,
                                                       line=node.location.line)

        # Standalone functions and static class methods
        if (refd.kind == CursorKind.FUNCTION_DECL
                or (refd.kind == CursorKind.CXX_METHOD
                    and refd.is_static_method())):
            self.process_function_call(refd, annotations)
        # Function pointer
        elif refd.kind in [
                CursorKind.VAR_DECL, CursorKind.FIELD_DECL,
                CursorKind.PARM_DECL
        ]:
            self.process_function_call(refd, annotations)
        # Non-static class methods
        elif refd.kind == CursorKind.CXX_METHOD:
            self.process_method_call(node, refd, annotations)
        # Conversion method (e.g. `operator int()`)
        elif refd.kind == CursorKind.CONVERSION_FUNCTION:
            self.process_method_call(node, refd, annotations)
        # Constructors
        elif refd.kind == CursorKind.CONSTRUCTOR:
            self.process_constructor_call(refd, annotations)
        else:
            # Ignore other callables, but report them
            print("Unknown callable: "
                  f"{refd.location.file.name}:{refd.location.line}: "
                  f"{refd.displayname}    {refd.kind}\n"
                  f"    from: {node.location.file.name}:{node.location.line}")
        return True

    def process_function_declaration(self, node: clang.cindex.Cursor):
        """Record declarations outside the main .cpp file, then recurse.

        Declarations in other files that lack MT_DISABLED are remembered in
        `_external_decls`. They are reported later if the definition lives in
        a VL_MT_DISABLED_CODE_UNIT.
        """
        # Ignore declarations in main .cpp file
        if node.location.file.name != self._main_source_file:
            children = list(node.get_children())
            annotations = VlAnnotations.from_nodes_list(children)
            if not annotations.mt_disabled:
                self._external_decls.setdefault(node.get_usr(), set()).add(
                    (str(node.location.file.name), int(node.location.line)))
            return self.iterate_children(children, self.dispatch_node)

        return self.iterate_children(node.get_children(), self.dispatch_node)

    # Definition handling

    def dispatch_node_inside_definition(self, node: clang.cindex.Cursor):
        """Visit a cursor inside a function body: check calls, recurse into
        nested function definitions, and descend into everything else."""
        if node.kind == CursorKind.CALL_EXPR:
            if self.dispatch_call_node(node) is False:
                return None
        elif node.is_definition() and node.kind in [
                CursorKind.CXX_METHOD, CursorKind.FUNCTION_DECL,
                CursorKind.CONSTRUCTOR, CursorKind.CONVERSION_FUNCTION
        ]:
            self.process_function_definition(node)
            return None

        return self.iterate_children(node.get_children(),
                                     self.dispatch_node_inside_definition)

    def process_function_definition(self, node: clang.cindex.Cursor):
        """Validate one function definition and every call in its body.

        Reports definition/declaration annotation mismatches, then, in
        VL_MT_DISABLED_CODE_UNIT files, external declarations missing
        MT_DISABLED. Finally the body is checked using the union of the
        definition and declaration annotations. The enclosing caller and call
        location are restored afterwards, so nested definitions do not leak
        context into the function that contains them.
        """
        [supported, refd, annotations, _] = self.get_referenced_node_info(node)

        if refd and self._is_ignored_def(node, refd):
            return None

        node_children = list(node.get_children())

        if not supported:
            return self.iterate_children(node_children, self.dispatch_node)

        assert refd is not None

        # Snapshot the enclosing context before anything below overwrites it.
        prev_caller = self._caller
        prev_call_location = self._call_location
        try:
            def_annotations = VlAnnotations.from_nodes_list(node_children)
            # Implicitly mark definitions in VL_MT_DISABLED_CODE_UNIT .cpp
            # files as VL_MT_DISABLED. Existence of the annotation on
            # declarations in .h files is verified below.
            # Also sets VL_REQUIRES, as this annotation is added together with
            # explicit VL_MT_DISABLED.
            if self.is_mt_disabled_code_unit():
                if node.location.file.name == self._main_source_file:
                    annotations.mt_disabled = True
                    annotations.requires = True
                if refd.location.file.name == self._main_source_file:
                    def_annotations.mt_disabled = True
                    def_annotations.requires = True

            if not (def_annotations.is_empty()
                    or def_annotations == annotations):
                # Use definition's annotations for the diagnostic
                # source (i.e. the definition)
                self._caller = FunctionInfo.from_node(node, refd,
                                                      def_annotations)
                self._call_location = self._caller

                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.ANNOTATIONS_DEF_DECL_MISMATCH)

            # Use concatenation of definition and declaration annotations
            # for calls validation.
            self._caller = FunctionInfo.from_node(
                node, refd, def_annotations | annotations)
            self._call_location = self._caller

            if self.is_mt_disabled_code_unit():
                # Report declarations of this functions that don't have
                # MT_DISABLED annotation and are located in headers.
                if node.location.file.name == self._main_source_file:
                    usr = node.get_usr()
                    declarations = self._external_decls.get(usr, set())
                    for file, line in declarations:
                        self.emit_diagnostic(
                            FunctionInfo.from_decl_file_line_and_refd_node(
                                file, line, refd, def_annotations),
                            DiagnosticKind.MISSING_MT_DISABLED_ANNOTATION)
                    if declarations:
                        del self._external_decls[usr]

            self.iterate_children(node_children,
                                  self.dispatch_node_inside_definition)
        finally:
            self._caller = prev_caller
            self._call_location = prev_call_location

        return None

    # Nodes not located inside definition

    def dispatch_node(self, node: clang.cindex.Cursor):
        """Visit a cursor outside any function body."""
        if node.kind in [
                CursorKind.CXX_METHOD, CursorKind.FUNCTION_DECL,
                CursorKind.CONSTRUCTOR, CursorKind.CONVERSION_FUNCTION
        ]:
            if node.is_definition():
                return self.process_function_definition(node)
            # else:
            return self.process_function_declaration(node)

        return self.iterate_children(node.get_children(), self.dispatch_node)

    def process_translation_unit(
            self, translation_unit: clang.cindex.TranslationUnit):
        """Validate the top-level cursors of *translation_unit*.

        Top-level cursors located in headers already processed under the same
        set of initial defines are skipped. Afterwards, this TU's includes are
        added to that set.
        """
        self._level += 1
        kv_defines = sorted([f"{k}={v}" for k, v in self._defines.items()])
        concat_defines = '\n'.join(kv_defines)
        # List of headers already processed in a TU with specified set of defines.
        tu_processed_headers = self._processed_headers.setdefault(
            concat_defines, set())
        for child in translation_unit.cursor.get_children():
            if self._is_ignored_top_level(child):
                continue
            if tu_processed_headers:
                filename = os.path.abspath(child.location.file.name)
                if filename in tu_processed_headers:
                    continue
            self.dispatch_node(child)
        self._level -= 1

        tu_processed_headers.update([
            os.path.abspath(str(hdr.source))
            for hdr in translation_unit.get_includes()
        ])


@dataclass
class CompileCommand:
    """One file to compile: its path, compiler arguments and working dir.

    ``refid`` is an index assigned by the creator of the command (``main``
    uses the file's position on the command line). ``precompile_header``
    embeds it in the generated .pch file name.
    """
    # Caller-assigned index of this command.
    refid: int
    # Path of the source or header file to compile.
    filename: str
    # Compiler arguments, without the compiler executable itself.
    args: list[str]
    # Working directory for the compilation. Defaults to the process CWD at
    # the moment the command object is created.
    directory: str = dataclasses.field(default_factory=os.getcwd)


def get_filter_funcs(verilator_root: str):
    """Create the cursor filters that limit the analysis to Verilator code.

    Returns a tuple ``(is_ignored_top_level, is_ignored_def, is_ignored_call)``
    of predicates over ``clang.cindex.Cursor`` objects. Each predicate returns
    True when the node should be skipped. All of them ignore any cursor with
    no source file, or whose file lies outside ``verilator_root``.
    """
    # Keep a trailing separator so that a sibling directory such as
    # "<root>-old" does not match the root prefix. Unlike appending "/" by
    # hand, os.path.join(..., "") leaves a filesystem root ("/") unchanged
    # and uses the platform's separator.
    root_prefix = os.path.join(os.path.abspath(verilator_root), "")

    def is_outside_root(cursor: "clang.cindex.Cursor") -> bool:
        # True when the cursor has no file or its file is not under the root.
        if not cursor.location.file:
            return True
        filename = os.path.abspath(cursor.location.file.name)
        return not filename.startswith(root_prefix)

    def is_ignored_top_level(node: "clang.cindex.Cursor") -> bool:
        # Anything defined in a header outside Verilator root
        return is_outside_root(node)

    def is_ignored_def(node: "clang.cindex.Cursor",
                       refd: "clang.cindex.Cursor") -> bool:
        # __*
        if str(refd.spelling).startswith("__"):
            return True

        # Anything defined in a header outside Verilator root
        return is_outside_root(node)

    def is_ignored_call(refd: "clang.cindex.Cursor") -> bool:
        # __*
        if str(refd.spelling).startswith("__"):
            return True

        # std::*
        fqn = fully_qualified_name(refd)
        if fqn and fqn[0] == "std":
            return True

        # Anything declared in a header outside Verilator root
        return is_outside_root(refd)

    return (is_ignored_top_level, is_ignored_def, is_ignored_call)


def precompile_header(compile_command: CompileCommand, tmp_dir: str) -> str:
    """Precompile one header into ``tmp_dir`` and return the .pch file path.

    The header is parsed with ``compile_command.args`` after temporarily
    switching the process CWD to ``compile_command.directory``. The previous
    CWD is always restored. If parsing reports error-level diagnostics, or
    clang raises a load/save error, or an ``OSError`` occurs, warnings (and
    the collected error diagnostics) are printed to stderr and "" is
    returned.
    """
    previous_cwd = os.getcwd()
    error_lines = []
    try:
        os.chdir(compile_command.directory)

        clang_index = Index.create()
        unit = clang_index.parse(compile_command.filename,
                                 compile_command.args)
        error_lines.extend(
            str(diagnostic) for diagnostic in unit.diagnostics
            if diagnostic.severity >= clang.cindex.Diagnostic.Error)

        if not error_lines:
            # refid keeps names unique when two headers share a basename.
            pch_name = (f"{compile_command.refid:02}_"
                        f"{os.path.basename(compile_command.filename)}.pch")
            pch_path = os.path.join(tmp_dir, pch_name)
            unit.save(pch_path)
            return pch_path

    except (TranslationUnitSaveError, TranslationUnitLoadError,
            OSError) as err:
        print(f"%Warning: {err}", file=sys.stderr)

    finally:
        os.chdir(previous_cwd)

    print(
        f"%Warning: Precompilation failed, skipping: {compile_command.filename}",
        file=sys.stderr)
    for error_line in error_lines:
        print(f"          {error_line}", file=sys.stderr)
    return ""


# Compile and analyze inputs in a single process.
def run_analysis(ccl: Iterable[CompileCommand], pccl: Iterable[CompileCommand],
                 diagnostic_cb: Callable[[Diagnostic], None],
                 verilator_root: str):
    """Precompile headers, then analyze every source file in this process.

    Headers from ``pccl`` that precompile successfully are passed to every
    compilation in ``ccl`` as ``-include-pch`` arguments, in ``pccl`` order
    and ahead of each command's own arguments. Each diagnostic found is
    passed to ``diagnostic_cb``.
    """
    (top_level_filter, def_filter,
     call_filter) = get_filter_funcs(verilator_root)

    with tempfile.TemporaryDirectory(
            prefix="verilator_clang_check_attributes_") as tmp_dir:
        # Headers that fail to precompile are reported and then skipped.
        pch_args = []
        for header_command in pccl:
            pch_path = precompile_header(header_command, tmp_dir)
            if pch_path:
                pch_args.extend(("-include-pch", pch_path))

        validator = CallAnnotationsValidator(diagnostic_cb, top_level_filter,
                                             def_filter, call_filter)
        for command in ccl:
            validator.compile_and_analyze_file(command.filename,
                                               pch_args + command.args,
                                               command.directory)


@dataclass
class ParallelAnalysisProcess:
    """Per-worker state and task entry points for ``run_parallel_analysis``.

    All state is stored on the class itself, not on instances. The pool
    initializer (``init_data``) fills it once per worker process, and the
    static task functions read it.
    """
    # Validator used by all tasks in this worker; set by init_data().
    cav: Optional[CallAnnotationsValidator] = None
    # Diagnostics collected for the file currently being analyzed.
    # NOTE: since this field uses default_factory, @dataclass deletes the
    # class-level attribute. init_data() and analyze_cpp_file() recreate it
    # so that _diagnostic_handler always has a set to add to.
    diags: set[Diagnostic] = dataclasses.field(default_factory=set)
    # Directory for generated precompiled headers; set by init_data().
    tmp_dir: str = ""

    @staticmethod
    def init_data(verilator_root: str, tmp_dir: str):
        """Pool initializer: build this worker's validator and reset state."""
        (is_ignored_top_level, is_ignored_def,
         is_ignored_call) = get_filter_funcs(verilator_root)

        ParallelAnalysisProcess.cav = CallAnnotationsValidator(
            ParallelAnalysisProcess._diagnostic_handler, is_ignored_top_level,
            is_ignored_def, is_ignored_call)
        ParallelAnalysisProcess.tmp_dir = tmp_dir
        # Make sure the class-level set exists before any callback can fire.
        ParallelAnalysisProcess.diags = set()

    @staticmethod
    def _diagnostic_handler(diag: Diagnostic):
        # Validator callback: record the diagnostic for the current file.
        ParallelAnalysisProcess.diags.add(diag)

    @staticmethod
    def analyze_cpp_file(compile_command: CompileCommand) -> set[Diagnostic]:
        """Pool task: analyze one file and return the diagnostics it produced."""
        # Start from an empty set so results from earlier tasks in this
        # worker are not returned again.
        ParallelAnalysisProcess.diags = set()
        assert ParallelAnalysisProcess.cav is not None
        ParallelAnalysisProcess.cav.compile_and_analyze_file(
            compile_command.filename, compile_command.args,
            compile_command.directory)
        return ParallelAnalysisProcess.diags

    @staticmethod
    def precompile_header(compile_command: CompileCommand) -> str:
        """Pool task: precompile one header into this worker's tmp_dir."""
        return precompile_header(compile_command,
                                 ParallelAnalysisProcess.tmp_dir)


# Compile and analyze inputs in multiple processes.
def run_parallel_analysis(ccl: Iterable[CompileCommand],
                          pccl: Iterable[CompileCommand],
                          diagnostic_cb: Callable[[Diagnostic], None],
                          jobs_count: int, verilator_root: str):
    """Precompile headers and analyze source files on a process pool.

    Successfully precompiled headers from ``pccl`` are appended to every
    command in ``ccl`` as ``-include-pch`` arguments. Each diagnostic
    returned by the workers is passed to ``diagnostic_cb`` in the parent
    process. ``ccl`` is iterated exactly once, and the caller's
    ``CompileCommand`` objects are left unmodified.
    """
    prefix = "verilator_clang_check_attributes_"
    with tempfile.TemporaryDirectory(prefix=prefix) as tmp_dir:
        with multiprocessing.Pool(
                processes=jobs_count,
                initializer=ParallelAnalysisProcess.init_data,
                initargs=[verilator_root, tmp_dir]) as pool:
            extra_args = []
            # Ordered imap: headers still precompile in parallel, but the
            # resulting -include-pch arguments follow pccl order, as in
            # run_analysis, rather than completion order.
            for pch_file in pool.imap(
                    ParallelAnalysisProcess.precompile_header, pccl):
                if pch_file:
                    extra_args += ["-include-pch", pch_file]

            # Build new command objects in one pass over ccl. The original
            # code looped over ccl twice, which silently analyzed nothing
            # when ccl was a one-shot iterator.
            commands = [
                dataclasses.replace(command, args=command.args + extra_args)
                for command in ccl
            ]

            for diags in pool.imap_unordered(
                    ParallelAnalysisProcess.analyze_cpp_file, commands, 1):
                for diag in diags:
                    diagnostic_cb(diag)


class TopDownSummaryPrinter():
    """Aggregate diagnostics per function and print them as a grouped report.

    ``handle_diagnostic`` is intended as the diagnostic callback.
    ``print_summary`` prints the collected report to stdout.
    """

    @dataclass
    class FunctionCallees:
        # Function the diagnostics were reported for.
        info: FunctionInfo
        # Offending callees (field name spelling kept for compatibility).
        calees: set[FunctionInfo]
        # Mismatching declaration, for ANNOTATIONS_DEF_DECL_MISMATCH.
        mismatch: Optional[FunctionInfo] = None
        # Kind of the most recently handled diagnostic for this function.
        reason: Optional[DiagnosticKind] = None

    def __init__(self):
        self._is_first_group = True

        # Collected entries, keyed by the USR of the reported function.
        self._funcs: dict[str, TopDownSummaryPrinter.FunctionCallees] = {}
        # USRs of every function reported as an offending callee.
        self._unsafe_in_safe: set[str] = set()

    def begin_group(self, label):
        """Print a group header, with a blank line before every group but the first."""
        if not self._is_first_group:
            print()

        print(f"%Error: {label}")

        self._is_first_group = False

    def handle_diagnostic(self, diag: Diagnostic):
        """Record ``diag`` under the function it was reported for."""
        usr = diag.source.usr
        func = self._funcs.get(usr, None)
        if func is None:
            func = TopDownSummaryPrinter.FunctionCallees(diag.source, set())
            self._funcs[usr] = func
        func.reason = diag.kind
        if diag.kind == DiagnosticKind.ANNOTATIONS_DEF_DECL_MISMATCH:
            func.mismatch = diag.target
        else:
            func.calees.add(diag.target)
            self._unsafe_in_safe.add(diag.target.usr)

    def print_summary(self, root_dir: str):
        """Print every collected function group, then the unsafe-callee count.

        File paths are shown relative to ``root_dir``. Groups are ordered by
        label. The first two columns are padded to a common width across the
        whole report.
        """
        # A list of (label, rows) rather than a dict keyed by label. Two
        # different functions can produce the same label (same name and
        # reason), and a dict would silently keep only the last of them.
        row_groups: list[tuple[str, list[list[str]]]] = []
        column_widths = [0, 0]
        for func in sorted(self._funcs.values(),
                           key=lambda entry:
                           (entry.info.file, entry.info.line, entry.info.usr)):
            func_info = func.info
            relfile = os.path.relpath(func_info.file, root_dir)

            row_group = []
            name = f"\"{func_info.name}\" "
            if func.reason == DiagnosticKind.ANNOTATIONS_DEF_DECL_MISMATCH:
                name += "declaration does not match definition"
            elif func.reason == DiagnosticKind.NON_MT_SAFE_CALL_IN_MT_SAFE_CTX:
                name += "is mtsafe but calls non-mtsafe function(s)"
            elif func.reason == DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX:
                name += "is pure but calls non-pure function(s)"
            elif func.reason == DiagnosticKind.NON_STABLE_TREE_CALL_IN_STABLE_TREE_CTX:
                name += "is stable_tree but calls non-stable_tree or non-mtsafe"
            elif func.reason == DiagnosticKind.MISSING_MT_DISABLED_ANNOTATION:
                name += ("defined in a file marked as " +
                         "VL_MT_DISABLED_CODE_UNIT has declaration(s) " +
                         "without VL_MT_DISABLED annotation")
            else:
                name += "for unknown reason (please add description)"

            if func.mismatch:
                mrelfile = os.path.relpath(func.mismatch.file, root_dir)
                row_group.append([
                    f"{mrelfile}:{func.mismatch.line}:",
                    f"[{func.mismatch.annotations}]",
                    func.mismatch.name + " [declaration]"
                ])

            row_group.append([
                f"{relfile}:{func_info.line}:", f"[{func_info.annotations}]",
                func_info.name
            ])

            for callee in sorted(func.calees,
                                 key=lambda info:
                                 (info.file, info.line, info.usr)):
                crelfile = os.path.relpath(callee.file, root_dir)
                row_group.append([
                    f"{crelfile}:{callee.line}:", f"[{callee.annotations}]",
                    "  " + callee.name
                ])

            row_groups.append((name, row_group))

            for row in row_group:
                # The last column is not padded, so it is not measured.
                for row_id, value in enumerate(row[0:-1]):
                    column_widths[row_id] = max(column_widths[row_id],
                                                len(value))

        # sorted() is stable: groups that share a label stay in file/line order.
        for label, rows in sorted(row_groups, key=lambda group: group[0]):
            self.begin_group(label)
            for row in rows:
                print(f"{row[0]:<{column_widths[0]}}  "
                      f"{row[1]:<{column_widths[1]}}    "
                      f"{row[2]}")
        print(
            f"Number of functions reported unsafe: {len(self._unsafe_in_safe)}"
        )


def main():
    """Command-line entry point: parse arguments, run the analysis, report.

    Builds one ``CompileCommand`` per ``--precompile`` header and per input
    file, runs the analysis (in-process for ``--jobs 1``, otherwise on a
    process pool), and prints the grouped summary.
    """
    default_verilator_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."))

    parser = argparse.ArgumentParser(
        allow_abbrev=False,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""Check function annotations for correctness""",
        epilog=
        """Copyright 2022-2024 by Wilson Snyder. Verilator is free software;
    you can redistribute it and/or modify it under the terms of either the GNU
    Lesser General Public License Version 3 or the Apache License 2.0.
    SPDX-License-Identifier: LGPL-3.0-only OR Apache-2.0""")

    parser.add_argument("--verilator-root",
                        type=str,
                        default=default_verilator_root,
                        help="Path to Verilator sources root directory.")
    parser.add_argument("--jobs",
                        "-j",
                        type=int,
                        default=0,
                        help="Number of parallel jobs to use.")
    parser.add_argument(
        "--compile-commands-dir",
        type=str,
        default=None,
        help="Path to directory containing compile_commands.json.")
    parser.add_argument("--cxxflags",
                        type=str,
                        default=None,
                        help="Extra flags passed to clang++.")
    parser.add_argument(
        "--compilation-root",
        type=str,
        default=os.getcwd(),
        help="Directory used as CWD when compiling source files.")
    parser.add_argument(
        "-c",
        "--precompile",
        action="append",
        help="Header file to be precompiled and cached at the start.")
    parser.add_argument("file",
                        type=str,
                        nargs="+",
                        help="Source file to analyze.")

    cmdline = parser.parse_args()

    if cmdline.jobs == 0:
        # os.sched_getaffinity() is not available on every platform (e.g.
        # macOS, Windows). Fall back to the total CPU count there.
        if hasattr(os, "sched_getaffinity"):
            cmdline.jobs = max(1, len(os.sched_getaffinity(0)))
        else:
            cmdline.jobs = max(1, os.cpu_count() or 1)

    if not cmdline.compilation_root:
        cmdline.compilation_root = cmdline.verilator_root

    verilator_root = os.path.abspath(cmdline.verilator_root)
    default_compilation_root = os.path.abspath(cmdline.compilation_root)

    compdb: Optional[CompilationDatabase] = None
    if cmdline.compile_commands_dir:
        compdb = CompilationDatabase.fromDirectory(
            cmdline.compile_commands_dir)

    if cmdline.cxxflags is not None:
        common_cxxflags = shlex.split(cmdline.cxxflags)
    else:
        common_cxxflags = []

    precompile_commands_list = []

    if cmdline.precompile:
        hdr_cxxflags = ['-xc++-header'] + common_cxxflags
        for refid, file in enumerate(cmdline.precompile):
            filename = os.path.abspath(file)
            # Each command gets its own copy so the lists are never shared.
            compile_command = CompileCommand(refid, filename,
                                             hdr_cxxflags[:],
                                             default_compilation_root)
            precompile_commands_list.append(compile_command)

    compile_commands_list = []
    for refid, file in enumerate(cmdline.file):
        filename = os.path.abspath(file)
        root = default_compilation_root
        # Used when there is no compilation database, or it has no entry
        # for this file: --cxxflags must still apply in that case.
        cxxflags = common_cxxflags[:]
        if compdb is not None:
            # getCompileCommands() returns None (not an empty iterable) when
            # the database has no entry for the file.
            entry = compdb.getCompileCommands(filename)
            entry_list = list(entry) if entry is not None else []
            # Compilation database can contain multiple entries for single file,
            # e.g. when it has been updated by appending new entries.
            # Use last entry for the file, if it exists, as it is the newest one.
            if entry_list:
                last_entry = entry_list[-1]
                root = last_entry.directory
                entry_args = list(last_entry.arguments)
                # First argument in compile_commands.json arguments list is
                # compiler executable name/path. CIndex (libclang) always
                # implicitly prepends executable name, so it shouldn't be passed
                # here.
                cxxflags = common_cxxflags + entry_args[1:]

        compile_command = CompileCommand(refid, filename, cxxflags, root)
        compile_commands_list.append(compile_command)

    summary_printer = TopDownSummaryPrinter()

    if cmdline.jobs == 1:
        run_analysis(compile_commands_list, precompile_commands_list,
                     summary_printer.handle_diagnostic, verilator_root)
    else:
        run_parallel_analysis(compile_commands_list, precompile_commands_list,
                              summary_printer.handle_diagnostic, cmdline.jobs,
                              verilator_root)

    summary_printer.print_summary(verilator_root)


if __name__ == '__main__':
    main()
