diff --git a/slothy/targets/aarch64/aarch64_neon.py b/slothy/targets/aarch64/aarch64_neon.py index 2baf5df69..4a8db16d6 100644 --- a/slothy/targets/aarch64/aarch64_neon.py +++ b/slothy/targets/aarch64/aarch64_neon.py @@ -806,6 +806,33 @@ class AArch64Instruction(Instruction): PARSERS = {} + @staticmethod + def _replace_duplicate_datatypes(src, mnemonic_key): + pattern = re.compile(rf"<{re.escape(mnemonic_key)}\d*>") + + matches = list(pattern.finditer(src)) + + if len(matches) > 1: + for i, match in enumerate(reversed(matches)): + start, end = match.span() + src = src[:start] + f"<{mnemonic_key}{len(matches)-1-i}>" + src[end:] + + return src + + @staticmethod + def _enforce_datatype_matching(pattern, res): + datatypes = {} + for i, m in enumerate(re.finditer(r"", pattern)): + dt = m.group(0) + val = res.get(f"datatype{i}", res.get("datatype")) + if dt in datatypes and datatypes[dt] != val: + raise FatalParsingException( + f"Inconsistent data type: {datatypes[dt]} vs {val}" + ) + elif dt not in datatypes and val in datatypes.values(): + raise FatalParsingException(f"Inconsistent dt: {dt}") + datatypes[dt] = val + @staticmethod def _unfold_pattern(src): @@ -883,6 +910,7 @@ def pattern_i(i): barrel_pattern = "(?i:lsl|ror|lsr|asr)\\\\s*" src = replace_placeholders(src, "imm", imm_pattern, "imm") + src = AArch64Instruction._replace_duplicate_datatypes(src, "dt") src = replace_placeholders(src, "dt", dt_pattern, "datatype") src = replace_placeholders(src, "index", index_pattern, "index") src = replace_placeholders(src, "flag", flag_pattern, "flag") @@ -1090,6 +1118,8 @@ def build(c, src): assert isinstance(src, dict) res = src + AArch64Instruction._enforce_datatype_matching(pattern, res) + obj = c( pattern, inputs=inputs, @@ -1134,6 +1164,7 @@ def t_default(x): return txt out = replace_pattern(out, "immediate", "imm", lambda x: f"#{x}") + out = AArch64Instruction._replace_duplicate_datatypes(out, "dt") out = replace_pattern(out, "datatype", "dt", lambda x: x.upper()) out = replace_pattern(out, "flag", "flag") out = replace_pattern(out, "index", "index", str) @@ -1216,13 +1247,13 @@ class nop(AArch64Instruction): class vadd(AArch64Instruction): - pattern = "add ., ., ." + pattern = "add .
, .
, .
" inputs = ["Vb", "Vc"] outputs = ["Va"] class vsub(AArch64Instruction): - pattern = "sub ., ., ." + pattern = "sub .
, .
, .
" inputs = ["Vb", "Vc"] outputs = ["Va"] @@ -1355,7 +1386,7 @@ class Q_Ld2_Lane_Post_Inc(AArch64Instruction): class q_ld2_lane_post_inc(Q_Ld2_Lane_Post_Inc): - pattern = "ld2 { ., . }[], [], " + pattern = "ld2 { .
, .
}[], [], " in_outs = ["Va", "Vb", "Xa"] @classmethod @@ -1372,7 +1403,7 @@ def write(self): class q_ld2_lane_post_inc_force_output(Q_Ld2_Lane_Post_Inc): - pattern = "ld2 { ., . }[], [], " + pattern = "ld2 { .
, .
}[], [], " # TODO: Model sp dependency in_outs = ["Xa"] outputs = ["Va", "Vb"] @@ -2942,25 +2973,25 @@ class Vzip(AArch64Instruction): class vzip1(Vzip): - pattern = "zip1 ., ., ." + pattern = "zip1 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vzip2(Vzip): - pattern = "zip2 ., ., ." + pattern = "zip2 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vuzp1(Vzip): - pattern = "uzp1 ., ., ." + pattern = "uzp1 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vuzp2(Vzip): - pattern = "uzp2 ., ., ." + pattern = "uzp2 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -2970,13 +3001,13 @@ class Vqdmulh(AArch64Instruction): class vqrdmulh(Vqdmulh): - pattern = "sqrdmulh ., ., ." + pattern = "sqrdmulh .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vqrdmulh_lane(Vqdmulh): - pattern = "sqrdmulh ., ., .[]" + pattern = "sqrdmulh ., ., .[]" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -2992,7 +3023,7 @@ def make(cls, src): class vqdmulh_lane(Vqdmulh): - pattern = "sqdmulh ., ., .[]" + pattern = "sqdmulh ., ., .[]" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3067,13 +3098,13 @@ class AArch64NeonCount(AArch64Instruction): class vcnt(AArch64NeonCount): - pattern = "cnt ., ." + pattern = "cnt .
, .
" inputs = ["Va"] outputs = ["Vd"] class vclz(AArch64NeonCount): - pattern = "clz ., ." + pattern = "clz .
, .
" inputs = ["Va"] outputs = ["Vd"] @@ -3112,25 +3143,25 @@ class SHA3Instruction( class vrax1(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name - pattern = "rax1 ., ., ." + pattern = "rax1 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class veor3(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name - pattern = "eor3 ., ., ., ." + pattern = "eor3 .
, .
, .
, .
" inputs = ["Va", "Vb", "Vc"] outputs = ["Vd"] class vbcax(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name - pattern = "bcax ., ., ., ." + pattern = "bcax .
, .
, .
, .
" inputs = ["Va", "Vb", "Vc"] outputs = ["Vd"] class vxar(SHA3Instruction): # pylint: disable=missing-docstring,invalid-name - pattern = "xar ., ., ., " + pattern = "xar .
, .
, .
, " inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3140,19 +3171,19 @@ class AArch64NeonLogical(AArch64Instruction): class vtbl(AArch64Instruction): - pattern = "tbl ., {.}, ." + pattern = "tbl .
, {.
}, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vand(AArch64NeonLogical): - pattern = "and ., ., ." + pattern = "and .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vbic(AArch64NeonLogical): - pattern = "bic ., ., ." + pattern = "bic .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3163,25 +3194,25 @@ class vbic_imm_shifted(AArch64NeonLogical): class vmvn(AArch64NeonLogical): - pattern = "mvn ., ." + pattern = "mvn .
, .
" inputs = ["Va"] outputs = ["Vd"] class vorr(AArch64NeonLogical): - pattern = "orr ., ., ." + pattern = "orr .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vorn(AArch64NeonLogical): - pattern = "orn ., ., ." + pattern = "orn .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class veor(AArch64NeonLogical): - pattern = "eor ., ., ." + pattern = "eor .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3199,7 +3230,7 @@ class vmov_d(AArch64Instruction): class vext(AArch64Instruction): - pattern = "ext ., ., ., " + pattern = "ext .
, .
, .
, " inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3209,13 +3240,13 @@ class Vmul(AArch64Instruction): class vmul(Vmul): - pattern = "mul ., ., ." + pattern = "mul .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class vmul_lane(Vmul): - pattern = "mul ., ., .[]" + pattern = "mul ., ., .[]" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3236,13 +3267,13 @@ class Vmla(AArch64Instruction): class vmla(Vmla): - pattern = "mla ., ., ." + pattern = "mla .
, .
, .
" inputs = ["Va", "Vb"] in_outs = ["Vd"] class vmla_lane(Vmla): - pattern = "mla ., ., .[]" + pattern = "mla ., ., .[]" inputs = ["Va", "Vb"] in_outs = ["Vd"] @@ -3258,13 +3289,13 @@ def make(cls, src): class vmls(Vmla): - pattern = "mls ., ., ." + pattern = "mls .
, .
, .
" inputs = ["Va", "Vb"] in_outs = ["Vd"] class vmls_lane(Vmla): - pattern = "mls ., ., .[]" + pattern = "mls ., ., .[]" inputs = ["Va", "Vb"] in_outs = ["Vd"] @@ -3290,25 +3321,25 @@ class Vmull(AArch64Instruction): class vumull(Vmull): - pattern = "umull ., ., ." + pattern = "umull ., ., ." inputs = ["Va", "Vb"] outputs = ["Vd"] class vumull2(Vmull): - pattern = "umull2 ., ., ." + pattern = "umull2 ., ., ." inputs = ["Va", "Vb"] outputs = ["Vd"] class vsmull(Vmull): - pattern = "smull ., ., ." + pattern = "smull ., ., ." inputs = ["Va", "Vb"] outputs = ["Vd"] class vsmull2(Vmull): - pattern = "smull2 ., ., ." + pattern = "smull2 ., ., ." inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -3382,25 +3413,25 @@ class Vmlal(AArch64Instruction): class vumlal(Vmlal): - pattern = "umlal ., ., ." + pattern = "umlal ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] class vumlal2(Vmlal): - pattern = "umlal2 ., ., ." + pattern = "umlal2 ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] class vsmlal(Vmlal): - pattern = "smlal ., ., ." + pattern = "smlal ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] class vsmlal2(Vmlal): - pattern = "smlal2 ., ., ." + pattern = "smlal2 ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] @@ -3470,25 +3501,25 @@ def make(cls, src): class vumlsl(Vmlal): - pattern = "umlsl ., ., ." + pattern = "umlsl ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] class vumlsl2(Vmlal): - pattern = "umlsl2 ., ., ." + pattern = "umlsl2 ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] class vsmlsl(Vmlal): - pattern = "smlsl ., ., ." + pattern = "smlsl ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] class vsmlsl2(Vmlal): - pattern = "smlsl2 ., ., ." + pattern = "smlsl2 ., ., ." inputs = ["Va", "Vb"] in_outs = ["Vd"] @@ -3562,7 +3593,7 @@ class VShiftImmediateBasic(AArch64Instruction): class vshl(VShiftImmediateBasic): - pattern = "shl ., ., " + pattern = "shl .
, .
, " inputs = ["Va"] outputs = ["Vd"] @@ -3580,13 +3611,13 @@ class vshrn(VShiftImmediateBasic): class vsshr(VShiftImmediateBasic): - pattern = "sshr ., ., " + pattern = "sshr .
, .
, " inputs = ["Va"] outputs = ["Vd"] class vushr(VShiftImmediateBasic): - pattern = "ushr ., ., " + pattern = "ushr .
, .
, " inputs = ["Va"] outputs = ["Vd"] @@ -3602,13 +3633,13 @@ class VShiftImmediateRounding(AArch64Instruction): class vsrshr(VShiftImmediateRounding): - pattern = "srshr ., ., " + pattern = "srshr .
, .
, " inputs = ["Va"] outputs = ["Vd"] class vurshr(VShiftImmediateRounding): - pattern = "urshr ., ., " + pattern = "urshr .
, .
, " inputs = ["Va"] outputs = ["Vd"] @@ -3618,13 +3649,13 @@ class AArch64NeonShiftInsert(AArch64Instruction): class vsli(AArch64NeonShiftInsert): - pattern = "sli ., ., " + pattern = "sli .
, .
, " inputs = ["Va"] in_outs = ["Vd"] class vsri(AArch64NeonShiftInsert): - pattern = "sri ., ., " + pattern = "sri .
, .
, " inputs = ["Va"] in_outs = ["Vd"] @@ -3702,13 +3733,13 @@ class Transpose(AArch64Instruction): class trn1(Transpose): - pattern = "trn1 ., ., ." + pattern = "trn1 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class trn2(Transpose): - pattern = "trn2 ., ., ." + pattern = "trn2 .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] @@ -4126,7 +4157,7 @@ class St4(AArch64Instruction): class st4_base(St4): - pattern = "st4 {., ., ., .}, []" + pattern = "st4 {.
, .
, .
, .
}, []" inputs = ["Xc", "Va", "Vb", "Vc", "Vd"] @classmethod @@ -4144,7 +4175,7 @@ def make(cls, src): class st4_with_inc(St4): - pattern = "st4 {., ., ., .}, [], " + pattern = "st4 {.
, .
, .
, .
}, [], " inputs = ["Va", "Vb", "Vc", "Vd"] in_outs = ["Xc"] @@ -4168,7 +4199,7 @@ class St3(AArch64Instruction): class st3_base(St3): - pattern = "st3 {., ., .}, []" + pattern = "st3 {.
, .
, .
}, []" inputs = ["Xc", "Va", "Vb", "Vc"] @classmethod @@ -4183,7 +4214,7 @@ def make(cls, src): class st3_with_inc(St3): - pattern = "st3 {., ., .}, [], " + pattern = "st3 {.
, .
, .
}, [], " inputs = ["Va", "Vb", "Vc"] in_outs = ["Xc"] @@ -4204,7 +4235,7 @@ class St2(AArch64Instruction): class st2_base(St2): - pattern = "st2 {., .}, []" + pattern = "st2 {.
, .
}, []" inputs = ["Xc", "Va", "Vb"] @classmethod @@ -4219,7 +4250,7 @@ def make(cls, src): class st2_with_inc(St2): - pattern = "st2 {., .}, [], " + pattern = "st2 {.
, .
}, [], " inputs = ["Va", "Vb"] in_outs = ["Xc"] @@ -4240,7 +4271,7 @@ class Ld4(AArch64Instruction): class ld4_base(Ld4): - pattern = "ld4 {., ., ., .}, []" + pattern = "ld4 {.
, .
, .
, .
}, []" inputs = ["Xc"] outputs = ["Va", "Vb", "Vc", "Vd"] @@ -4259,7 +4290,7 @@ def make(cls, src): class ld4_with_inc(Ld4): - pattern = "ld4 {., ., ., .}, [], " + pattern = "ld4 {.
, .
, .
, .
}, [], " in_outs = ["Xc"] outputs = ["Va", "Vb", "Vc", "Vd"] @@ -4283,7 +4314,7 @@ class Ld3(AArch64Instruction): class ld3_base(Ld3): - pattern = "ld3 {., ., .}, []" + pattern = "ld3 {.
, .
, .
}, []" inputs = ["Xc"] outputs = ["Va", "Vb", "Vc"] @@ -4299,7 +4330,7 @@ def make(cls, src): class ld3_with_inc(Ld3): - pattern = "ld3 {., ., .}, [], " + pattern = "ld3 {.
, .
, .
}, [], " in_outs = ["Xc"] outputs = ["Va", "Vb", "Vc"] @@ -4320,7 +4351,7 @@ class Ld2(AArch64Instruction): class ld2_base(Ld2): - pattern = "ld2 {., .}, []" + pattern = "ld2 {.
, .
}, []" inputs = ["Xc"] outputs = ["Va", "Vb"] @@ -4336,7 +4367,7 @@ def make(cls, src): class ld2_with_inc(Ld2): - pattern = "ld2 {., .}, [], " + pattern = "ld2 {.
, .
}, [], " in_outs = ["Xc"] outputs = ["Va", "Vb"] @@ -4357,43 +4388,43 @@ class ASimdCompare(AArch64Instruction): class cmge(ASimdCompare): - pattern = "cmge ., ., ." + pattern = "cmge .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class cmhi(ASimdCompare): - pattern = "cmhi ., ., ." + pattern = "cmhi .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class cmeq(ASimdCompare): - pattern = "cmeq ., ., ." + pattern = "cmeq .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class cmgt(ASimdCompare): - pattern = "cmgt ., ., ." + pattern = "cmgt .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class cmhs(ASimdCompare): - pattern = "cmhs ., ., ." + pattern = "cmhs .
, .
, .
" inputs = ["Va", "Vb"] outputs = ["Vd"] class cmle(ASimdCompare): - pattern = "cmle ., ., " + pattern = "cmle .
, .
, " inputs = ["Va"] outputs = ["Vd"] class cmlt(ASimdCompare): - pattern = "cmlt ., ., " + pattern = "cmlt .
, .
, " inputs = ["Va"] outputs = ["Vd"]