PyPTO Pass
算子代码
def add_mul_kernel(
    a: pypto.Tensor([], pypto.DT_FP32),
    b: pypto.Tensor([], pypto.DT_FP32),
    d: pypto.Tensor([], pypto.DT_FP32),
    c: pypto.Tensor([], pypto.DT_FP32),
    e: pypto.Tensor([], pypto.DT_FP32),
):
    """Compute ``c = a + b`` and then ``e = d * c``, writing results in place.

    All five tensors are annotated as ``pypto.Tensor([], pypto.DT_FP32)``.
    The empty shape list presumably leaves the shape unspecified/dynamic
    (TODO confirm against the PyPTO documentation).

    Args:
        a: First addend (input; only read).
        b: Second addend (input; only read).
        d: Multiplicand for the second stage (input; only read).
        c: Output of the first stage. It is filled via ``c[:] = ...`` and
            then read as an operand of the multiplication.
        e: Output of the second stage, filled via ``e[:] = ...``.

    Returns:
        The function object ``add_mul_kernel`` itself (see the review note
        on the ``return`` line). Results are delivered through ``c`` and
        ``e``, not through the return value.
    """
    # Vector tile shape 16x16 for the element-wise add/mul below.
    # (The original note says the tensors are 32x32; that is not enforced
    # by the annotations above.)
    pypto.set_vec_tile_shapes(16, 16)
    # Cube tile shapes. Original note: "M=16, K=16, N=16 — declared per spec".
    # NOTE(review): no matmul-style op is issued in this kernel, so these
    # appear unused here. Confirm whether the spec actually requires them.
    pypto.set_cube_tile_shapes([16, 16], [16, 16], [16, 16])
    c[:] = pypto.add(a, b)  # c = a + b, written into the output tensor c
    # e = d * c; this reads c after it was written on the line above.
    e[:] = pypto.mul(d, c)
    # NOTE(review): returning this function object from inside its own body
    # looks like a remnant of an enclosing factory function
    # (e.g. `def make(): ...; return add_mul_kernel`). Confirm the intended
    # placement of this line.
    return add_mul_kernel
张量计算图
Loading computation graph...
Tile切分
Loading computation graph...
优化与染色
Loading computation graph...