diff --git a/slothy/targets/aarch64/aarch64_neon.py b/slothy/targets/aarch64/aarch64_neon.py
index 2baf5df69..4a8db16d6 100644
--- a/slothy/targets/aarch64/aarch64_neon.py
+++ b/slothy/targets/aarch64/aarch64_neon.py
@@ -806,6 +806,33 @@ class AArch64Instruction(Instruction):
PARSERS = {}
+ @staticmethod
+ def _replace_duplicate_datatypes(src, mnemonic_key):
+ pattern = re.compile(rf"<{re.escape(mnemonic_key)}\d*>")
+
+ matches = list(pattern.finditer(src))
+
+ if len(matches) > 1:
+ for i, match in enumerate(reversed(matches)):
+ start, end = match.span()
+ src = src[:start] + f"<{mnemonic_key}{len(matches)-1-i}>" + src[end:]
+
+ return src
+
+ @staticmethod
+ def _enforce_datatype_matching(pattern, res):
+ datatypes = {}
+ for i, m in enumerate(re.finditer(r"
", pattern)):
+ dt = m.group(0)
+ val = res.get(f"datatype{i}", res.get("datatype"))
+ if dt in datatypes and datatypes[dt] != val:
+ raise FatalParsingException(
+ f"Inconsistent data type: {datatypes[dt]} vs {val}"
+ )
+ elif dt not in datatypes and val in datatypes.values():
+ raise FatalParsingException(f"Inconsistent dt: {dt}")
+ datatypes[dt] = val
+
@staticmethod
def _unfold_pattern(src):
@@ -883,6 +910,7 @@ def pattern_i(i):
barrel_pattern = "(?i:lsl|ror|lsr|asr)\\\\s*"
src = replace_placeholders(src, "imm", imm_pattern, "imm")
+ src = AArch64Instruction._replace_duplicate_datatypes(src, "dt")
src = replace_placeholders(src, "dt", dt_pattern, "datatype")
src = replace_placeholders(src, "index", index_pattern, "index")
src = replace_placeholders(src, "flag", flag_pattern, "flag")
@@ -1090,6 +1118,8 @@ def build(c, src):
assert isinstance(src, dict)
res = src
+ AArch64Instruction._enforce_datatype_matching(pattern, res)
+
obj = c(
pattern,
inputs=inputs,
@@ -1134,6 +1164,7 @@ def t_default(x):
return txt
out = replace_pattern(out, "immediate", "imm", lambda x: f"#{x}")
+ out = AArch64Instruction._replace_duplicate_datatypes(out, "dt")
out = replace_pattern(out, "datatype", "dt", lambda x: x.upper())
out = replace_pattern(out, "flag", "flag")
out = replace_pattern(out, "index", "index", str)
@@ -1216,13 +1247,13 @@ class nop(AArch64Instruction):
class vadd(AArch64Instruction):
- pattern = "add ., ., ."
+ pattern = "add ., ., ."
inputs = ["Vb", "Vc"]
outputs = ["Va"]
class vsub(AArch64Instruction):
- pattern = "sub ., ., ."
+ pattern = "sub ., ., ."
inputs = ["Vb", "Vc"]
outputs = ["Va"]
@@ -1355,7 +1386,7 @@ class Q_Ld2_Lane_Post_Inc(AArch64Instruction):
class q_ld2_lane_post_inc(Q_Ld2_Lane_Post_Inc):
- pattern = "ld2 { ., . }[], [], "
+ pattern = "ld2 { ., . }[], [], "
in_outs = ["Va", "Vb", "Xa"]
@classmethod
@@ -1372,7 +1403,7 @@ def write(self):
class q_ld2_lane_post_inc_force_output(Q_Ld2_Lane_Post_Inc):
- pattern = "ld2 { ., . }[], [], "
+ pattern = "ld2 { ., . }[], [], "
# TODO: Model sp dependency
in_outs = ["Xa"]
outputs = ["Va", "Vb"]
@@ -2942,25 +2973,25 @@ class Vzip(AArch64Instruction):
class vzip1(Vzip):
- pattern = "zip1 ., ., ."
+ pattern = "zip1 ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class vzip2(Vzip):
- pattern = "zip2 ., ., ."
+ pattern = "zip2 ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class vuzp1(Vzip):
- pattern = "uzp1 ., ., ."
+ pattern = "uzp1 ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class vuzp2(Vzip):
- pattern = "uzp2 ., ., ."
+ pattern = "uzp2 ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
@@ -2970,13 +3001,13 @@ class Vqdmulh(AArch64Instruction):
class vqrdmulh(Vqdmulh):
- pattern = "sqrdmulh ., ., ."
+ pattern = "sqrdmulh ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class vqrdmulh_lane(Vqdmulh):
- pattern = "sqrdmulh ., ., .[]"
+ pattern = "sqrdmulh ., ., .[]"
inputs = ["Va", "Vb"]
outputs = ["Vd"]
@@ -2992,7 +3023,7 @@ def make(cls, src):
class vqdmulh_lane(Vqdmulh):
- pattern = "sqdmulh ., ., .[]"
+ pattern = "sqdmulh ., ., .[]"
inputs = ["Va", "Vb"]
outputs = ["Vd"]
@@ -3067,13 +3098,13 @@ class AArch64NeonCount(AArch64Instruction):
class vcnt(AArch64NeonCount):
- pattern = "cnt ., ."
+ pattern = "cnt ., ."
inputs = ["Va"]
outputs = ["Vd"]
class vclz(AArch64NeonCount):
- pattern = "clz ., ."
+ pattern = "clz ., ."
inputs = ["Va"]
outputs = ["Vd"]
@@ -3112,25 +3143,25 @@ class SHA3Instruction(
class vrax1(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name
- pattern = "rax1 ., ., ."
+ pattern = "rax1 ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class veor3(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name
- pattern = "eor3 ., ., ., ."
+ pattern = "eor3 ., ., ., ."
inputs = ["Va", "Vb", "Vc"]
outputs = ["Vd"]
class vbcax(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name
- pattern = "bcax ., ., ., ."
+ pattern = "bcax ., ., ., ."
inputs = ["Va", "Vb", "Vc"]
outputs = ["Vd"]
class vxar(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name
- pattern = "xar ., ., ., "
+ pattern = "xar ., ., ., "
inputs = ["Va", "Vb"]
outputs = ["Vd"]
@@ -3140,19 +3171,19 @@ class AArch64NeonLogical(AArch64Instruction):
class vtbl(AArch64Instruction):
- pattern = "tbl ., {.}, ."
+ pattern = "tbl ., {.}, ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class vand(AArch64NeonLogical):
- pattern = "and ., ., ."
+ pattern = "and ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
class vbic(AArch64NeonLogical):
- pattern = "bic ., ., ."
+ pattern = "bic ., ., ."
inputs = ["Va", "Vb"]
outputs = ["Vd"]
@@ -3163,25 +3194,25 @@ class vbic_imm_shifted(AArch64NeonLogical):
class vmvn(AArch64NeonLogical):
- pattern = "mvn .,