// blob: 7dd424798d693d68dc2a0e12fc55f7b37d430d08 [file] [log] [blame]
#include <asm/hwcap.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sys/auxv.h>
#include <sys/prctl.h>
// Important details for this program:
// * Making a syscall will disable streaming mode if it is active.
// * Changing the vector length will make streaming mode and ZA inactive.
// * ZA can be active independent of streaming mode.
// * ZA's size is the streaming vector length squared.
// Fallback definitions of the SME prctl constants, used only when the
// included system headers do not already define them.
#ifndef PR_SME_SET_VL
#define PR_SME_SET_VL 63
#endif
#ifndef PR_SME_GET_VL
#define PR_SME_GET_VL 64
#endif
// NOTE(review): not referenced anywhere in the visible part of this file.
#ifndef PR_SME_VL_LEN_MASK
#define PR_SME_VL_LEN_MASK 0xffff
#endif
// Emits a single "msr s0_3_c4_c<c>_3, xzr" instruction, where <c> selects
// which of the streaming mode / ZA controls below is performed.
// NOTE(review): presumably written as a raw system register encoding so the
// file assembles without SME support in the assembler - confirm.
#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
// Going by their names: SMSTART/SMSTOP act on both streaming mode and ZA,
// the _SM variants on streaming mode only and the _ZA variants on ZA only.
#define SMSTART SM_INST(7)
#define SMSTART_SM SM_INST(3)
#define SMSTART_ZA SM_INST(5)
#define SMSTOP SM_INST(6)
#define SMSTOP_SM SM_INST(2)
#define SMSTOP_ZA SM_INST(4)
// Fallback for the AT_HWCAP2 bit that main() tests to decide whether ZT0 is
// available.
#ifndef HWCAP2_SME2
#define HWCAP2_SME2 (1UL << 37)
#endif
// Vector length taken from argv[3] in main(); applied with PR_SME_SET_VL at
// startup and again by expr_start_vl().
int start_vl = 0;
// Vector length taken from argv[4] in main(); applied by expr_other_vl().
int other_vl = 0;
// Set to true in main() when AT_HWCAP2 reports HWCAP2_SME2. When true,
// set_sme_registers() also loads ZT0.
bool has_zt0 = false;
// Load the initial, known pattern into FFR, P0-P15 and Z0-Z31.
//
// The predicates repeat in groups of five: all-true at byte, halfword, word
// and doubleword granularity, followed by an all-false predicate. P15 is
// all-true bytes. Z<n> is then filled with the byte value n + 1, using the
// all-true byte predicates P0, P5, P10 and P15 in rotation.
//
// Everything is emitted from a single asm statement so the instructions are
// contiguous and in exactly this order.
void write_sve_regs() {
  asm volatile(
      // We assume the smefa64 feature is present, which allows ffr access
      // in streaming mode.
      "setffr\n\t"
      // Predicate registers.
      "ptrue p0.b\n\t"
      "ptrue p1.h\n\t"
      "ptrue p2.s\n\t"
      "ptrue p3.d\n\t"
      "pfalse p4.b\n\t"
      "ptrue p5.b\n\t"
      "ptrue p6.h\n\t"
      "ptrue p7.s\n\t"
      "ptrue p8.d\n\t"
      "pfalse p9.b\n\t"
      "ptrue p10.b\n\t"
      "ptrue p11.h\n\t"
      "ptrue p12.s\n\t"
      "ptrue p13.d\n\t"
      "pfalse p14.b\n\t"
      "ptrue p15.b\n\t"
      // Vector registers: Z<n> = n + 1 in every byte.
      "cpy z0.b, p0/z, #1\n\t"
      "cpy z1.b, p5/z, #2\n\t"
      "cpy z2.b, p10/z, #3\n\t"
      "cpy z3.b, p15/z, #4\n\t"
      "cpy z4.b, p0/z, #5\n\t"
      "cpy z5.b, p5/z, #6\n\t"
      "cpy z6.b, p10/z, #7\n\t"
      "cpy z7.b, p15/z, #8\n\t"
      "cpy z8.b, p0/z, #9\n\t"
      "cpy z9.b, p5/z, #10\n\t"
      "cpy z10.b, p10/z, #11\n\t"
      "cpy z11.b, p15/z, #12\n\t"
      "cpy z12.b, p0/z, #13\n\t"
      "cpy z13.b, p5/z, #14\n\t"
      "cpy z14.b, p10/z, #15\n\t"
      "cpy z15.b, p15/z, #16\n\t"
      "cpy z16.b, p0/z, #17\n\t"
      "cpy z17.b, p5/z, #18\n\t"
      "cpy z18.b, p10/z, #19\n\t"
      "cpy z19.b, p15/z, #20\n\t"
      "cpy z20.b, p0/z, #21\n\t"
      "cpy z21.b, p5/z, #22\n\t"
      "cpy z22.b, p10/z, #23\n\t"
      "cpy z23.b, p15/z, #24\n\t"
      "cpy z24.b, p0/z, #25\n\t"
      "cpy z25.b, p5/z, #26\n\t"
      "cpy z26.b, p10/z, #27\n\t"
      "cpy z27.b, p15/z, #28\n\t"
      "cpy z28.b, p0/z, #29\n\t"
      "cpy z29.b, p5/z, #30\n\t"
      "cpy z30.b, p10/z, #31\n\t"
      "cpy z31.b, p15/z, #32\n\t");
}
// Overwrite FFR, P0-P15 and Z0-Z31 with values that differ from the ones
// write_sve_regs() sets up, so that a failure to restore the original
// register state afterwards is detectable.
//
// FFR is cleared from an all-false P0. The predicate pattern is the inverse
// of write_sve_regs(): P4, P9 and P14 are all-true, the rest all-false.
//
// NOTE(review): the cpy instructions are governed by P0/P5/P10/P15, which are
// all-false here, and use zeroing (/z) predication, so the Z registers
// presumably end up zero rather than holding the immediates - confirm that
// this is the intended "different" value.
void write_sve_regs_expr() {
  asm volatile(
      // Predicate registers and FFR.
      "pfalse p0.b\n\t"
      "wrffr p0.b\n\t"
      "pfalse p1.b\n\t"
      "pfalse p2.b\n\t"
      "pfalse p3.b\n\t"
      "ptrue p4.b\n\t"
      "pfalse p5.b\n\t"
      "pfalse p6.b\n\t"
      "pfalse p7.b\n\t"
      "pfalse p8.b\n\t"
      "ptrue p9.b\n\t"
      "pfalse p10.b\n\t"
      "pfalse p11.b\n\t"
      "pfalse p12.b\n\t"
      "pfalse p13.b\n\t"
      "ptrue p14.b\n\t"
      "pfalse p15.b\n\t"
      // Vector registers: immediate for Z<n> is n + 2.
      "cpy z0.b, p0/z, #2\n\t"
      "cpy z1.b, p5/z, #3\n\t"
      "cpy z2.b, p10/z, #4\n\t"
      "cpy z3.b, p15/z, #5\n\t"
      "cpy z4.b, p0/z, #6\n\t"
      "cpy z5.b, p5/z, #7\n\t"
      "cpy z6.b, p10/z, #8\n\t"
      "cpy z7.b, p15/z, #9\n\t"
      "cpy z8.b, p0/z, #10\n\t"
      "cpy z9.b, p5/z, #11\n\t"
      "cpy z10.b, p10/z, #12\n\t"
      "cpy z11.b, p15/z, #13\n\t"
      "cpy z12.b, p0/z, #14\n\t"
      "cpy z13.b, p5/z, #15\n\t"
      "cpy z14.b, p10/z, #16\n\t"
      "cpy z15.b, p15/z, #17\n\t"
      "cpy z16.b, p0/z, #18\n\t"
      "cpy z17.b, p5/z, #19\n\t"
      "cpy z18.b, p10/z, #20\n\t"
      "cpy z19.b, p15/z, #21\n\t"
      "cpy z20.b, p0/z, #22\n\t"
      "cpy z21.b, p5/z, #23\n\t"
      "cpy z22.b, p10/z, #24\n\t"
      "cpy z23.b, p15/z, #25\n\t"
      "cpy z24.b, p0/z, #26\n\t"
      "cpy z25.b, p5/z, #27\n\t"
      "cpy z26.b, p10/z, #28\n\t"
      "cpy z27.b, p15/z, #29\n\t"
      "cpy z28.b, p0/z, #30\n\t"
      "cpy z29.b, p5/z, #31\n\t"
      "cpy z30.b, p10/z, #32\n\t"
      "cpy z31.b, p15/z, #33\n\t");
}
// Fill ZA, and ZT0 when has_zt0 is set, with a recognisable byte pattern.
//
// svl:          number of ZA rows to load; callers pass the streaming vector
//               length they have just configured.
// value_offset: added to the row index (ZA) or byte index (ZT0) to form the
//               byte value, so different callers leave different contents.
//
// ZA must already be enabled by the caller; this function only loads data.
void set_sme_registers(int svl, uint8_t value_offset) {
#define MAX_VL_BYTES 256
  uint8_t data[MAX_VL_BYTES];

  // ldr za will actually wrap the selected vector row, by the number of rows
  // you have. So setting one that didn't exist would actually set one that did.
  // That's why we need the streaming vector length here.
  for (int i = 0; i < svl; ++i) {
    // Every byte of row i holds i + value_offset (truncated to a byte).
    memset(data, i + value_offset, MAX_VL_BYTES);
    // Each one of these loads a VL sized row of ZA.
    // The "memory" clobber tells the compiler that the asm reads the buffer
    // behind the pointer operand, so the memset above must be completed and
    // cannot be treated as a dead store.
    asm volatile("mov w12, %w0\n\t"
                 "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
                 "r"(&data)
                 : "w12", "memory");
  }
#undef MAX_VL_BYTES

  if (has_zt0) {
#define ZT0_LEN (512 / 8)
    // ZT0 is loaded from a 64 byte buffer; byte i holds i + value_offset.
    uint8_t zt0_data[ZT0_LEN];
    for (unsigned i = 0; i < ZT0_LEN; ++i)
      zt0_data[i] = i + value_offset;
    // "memory" for the same reason as above: the asm reads zt0_data.
    asm volatile("ldr zt0, [%0]" ::"r"(&zt0_data) : "memory");
#undef ZT0_LEN
  }
}
// Executes SMSTOP_ZA and then overwrites the SVE registers with the
// alternate pattern from write_sve_regs_expr().
// NOTE(review): going by the expr_ prefix this is presumably meant to be
// called from a debugger expression rather than from main() - confirm.
void expr_disable_za() {
SMSTOP_ZA;
write_sve_regs_expr();
}
// Executes SMSTART_ZA, fills the SME registers for start_vl rows using value
// offset 2, then overwrites the SVE registers with the alternate pattern.
// The vector length itself is not changed here.
void expr_enable_za() {
SMSTART_ZA;
set_sme_registers(start_vl, 2);
// Written last so the SVE state also differs from what main() set up.
write_sve_regs_expr();
}
// Sets the streaming vector length to start_vl, enables ZA, fills the SME
// registers using value offset 4 and overwrites the SVE registers with the
// alternate pattern.
void expr_start_vl() {
  // prctl() is variadic and the kernel reads this argument as an
  // unsigned long, so pass it with that type instead of as a plain int.
  prctl(PR_SME_SET_VL, (unsigned long)start_vl);
  // Changing the vector length makes ZA inactive, so enable it afterwards.
  SMSTART_ZA;
  set_sme_registers(start_vl, 4);
  write_sve_regs_expr();
}
// Sets the streaming vector length to other_vl, enables ZA, fills the SME
// registers using value offset 5 and overwrites the SVE registers with the
// alternate pattern.
void expr_other_vl() {
  // prctl() is variadic and the kernel reads this argument as an
  // unsigned long, so pass it with that type instead of as a plain int.
  prctl(PR_SME_SET_VL, (unsigned long)other_vl);
  // Changing the vector length makes ZA inactive, so enable it afterwards.
  SMSTART_ZA;
  set_sme_registers(other_vl, 5);
  write_sve_regs_expr();
}
// Executes SMSTART_SM and then overwrites the SVE registers with the
// alternate pattern from write_sve_regs_expr().
void expr_enable_sm() {
SMSTART_SM;
write_sve_regs_expr();
}
// Executes SMSTOP_SM and then overwrites the SVE registers with the
// alternate pattern from write_sve_regs_expr().
void expr_disable_sm() {
SMSTOP_SM;
write_sve_regs_expr();
}
// Puts the process into the SME/SVE state requested on the command line and
// then reaches a fixed line where a breakpoint can be set.
//
// Returns 1 if the argument count is wrong or the vector length cannot be
// set, otherwise 0.
int main(int argc, char *argv[]) {
  // We expect to get:
  // * whether to enable streaming mode
  // * whether to enable ZA
  // * what the starting VL should be
  // * what the other VL should be
  if (argc != 5)
    return 1;

  // A flag is on only when its argument starts with '1'.
  bool ssve = argv[1][0] == '1';
  bool za = argv[2][0] == '1';
  start_vl = atoi(argv[3]);
  other_vl = atoi(argv[4]);

  // ZT0 is only loaded when the kernel reports SME2.
  if ((getauxval(AT_HWCAP2) & HWCAP2_SME2))
    has_zt0 = true;

  // Set the vector length before enabling anything: changing it makes
  // streaming mode and ZA inactive, and the syscall itself leaves streaming
  // mode. prctl() is variadic and the kernel reads the argument as an
  // unsigned long, so pass it with that type. Without the requested vector
  // length the state set up below would not be what was asked for, so give
  // up if the call fails.
  if (prctl(PR_SME_SET_VL, (unsigned long)start_vl) < 0)
    return 1;

  if (ssve)
    SMSTART_SM;

  if (za) {
    SMSTART_ZA;
    set_sme_registers(start_vl, 1);
  }

  // Done last, with no calls into libc afterwards, so the values are still
  // in place at the return below.
  write_sve_regs();

  return 0; // Set a break point here.
}