[Torch Dialect] support aten.glu (#2531)

pull/2535/head snapshot-20231026.1003
Yuanqiang Liu 2023-10-26 10:36:18 +08:00 committed by GitHub
parent b0f39ac966
commit e7282487ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 136 additions and 0 deletions

View File

@ -963,6 +963,7 @@ TOSA_PASS_SET = {
"ElementwiseMaximumIntModule_basic",
"ElementwiseMaxOtherIntModule_basic",
"ElementwiseMaxOtherModule_basic",
"GluStaticModule_basic",
"ViewDoubleMergeStaticModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewFiveTestStaticModule_basic",

View File

@ -4189,6 +4189,30 @@ def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
}];
}
// ODS definition of the `torch.aten.glu` op, mirroring the JIT operator
// signature `aten::glu : (Tensor, int) -> (Tensor)` quoted in the summary.
// The op carries the AllowsTypeRefinement, HasValueSemantics and ReadOnly
// traits, takes a tensor `self` plus a Torch int `dim`, and yields one tensor.
// NOTE(review): the summary says this op is generated; hand edits here are
// presumably overwritten when the ODS file is regenerated -- confirm.
def Torch_AtenGluOp : Torch_Op<"aten.glu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::glu : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
// Parsing and printing both defer to the shared default Torch op helpers.
// The trailing `2, 1` presumably are the operand and result counts (they
// match the two arguments and one result declared above) -- TODO confirm.
ParseResult AtenGluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenGluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -6382,6 +6382,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %13 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %14 = torch.aten.add.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %14 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.remainder.int %2, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.slice.t %arg0, %none, %1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list<int>\n"
" %9 = torch.aten.add.t %5, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %10 = torch.aten.add.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %11 = torch.aten.slice.t %arg0, %10, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %12 = torch.aten.add.t %9, %11 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %12 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -8863,6 +8896,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -361,6 +361,48 @@ public:
};
} // namespace
namespace {
// Decomposes `aten.glu(self, dim)` into simpler Torch ops.
//
// The input is cut into two equally sized pieces `a` and `b` along `dim`
// (via `aten.narrow`), and the op is replaced by `a * sigmoid(b)`.
// A `RuntimeAssertOp` is emitted to check that the size of `dim` is even.
//
// The pattern fails to match unless the result type is a `ValueTensorType`
// that reports both sizes and a dtype; that same type is reused as the result
// type of every intermediate op, since each half, its sigmoid and the product
// all have the shape of the final result.
class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenGluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    Value dim = op.getDim();

    // Bail out before creating any IR if the result type is not usable.
    auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>();
    if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
      return rewriter.notifyMatchFailure(
          op, "Expected output type having sizes and dtype");
    }

    Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
    Value two =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));

    // Guard at runtime: dimSize % 2 == 0, otherwise the halves are unequal.
    Value remainder = rewriter.create<AtenRemainderIntOp>(loc, dimSize, two);
    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, zero);
    rewriter.create<RuntimeAssertOp>(
        loc, eqOrNot,
        rewriter.getStringAttr(
            "AtenGluOp's dim size must be a multiple of 2"));

    // First half starts at 0, second half starts at dimSize / 2; both have
    // length dimSize / 2.
    Value splitLength = rewriter.create<AtenFloordivIntOp>(loc, dimSize, two);
    Value a = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim, zero,
                                            splitLength);
    Value b = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim,
                                            splitLength, splitLength);

    // result = a * sigmoid(b)
    Value sigmoidB = rewriter.create<AtenSigmoidOp>(loc, outputTy, b);
    Value result = rewriter.create<AtenMulTensorOp>(loc, outputTy, a, sigmoidB);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace
namespace {
class DecomposeAtenZeroOp
: public OpRewritePattern<AtenZeroOp> {
@ -5289,6 +5331,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);

View File

@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenHardsigmoidOp>();
target.addIllegalOp<AtenRelu6Op>();
target.addIllegalOp<AtenEluOp>();
target.addIllegalOp<AtenGluOp>();
target.addIllegalOp<AtenHardswishOp>();
target.addIllegalOp<AtenSoftplusOp>();
target.addIllegalOp<AtenSiluOp>();

View File

@ -167,6 +167,12 @@ def atenrelu6〡shape(self: List[int]) -> List[int]:
# Shape function for `aten.round`: forwards the input sizes to
# `upstream_shape_functions.unary` and returns whatever it produces.
# NOTE(review): `unary` presumably returns the input shape unchanged -- confirm
# against the upstream helper.
def atenround〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
# Shape function for `aten.glu`.
#
# Args:
#   self: sizes of the input tensor.
#   dim:  dimension that gets halved (default -1). A negative value is
#         offset once by len(self).
# Returns:
#   A new list equal to `self` except that entry `dim` is `self[dim] // 2`.
# Raises:
#   AssertionError if `self[dim]` is odd.
def atenglu〡shape(self: List[int], dim: int = -1) -> List[int]:
if dim < 0:
# Turn a negative index into the corresponding non-negative one.
dim += len(self)
# NOTE(review): the message should read "a multiple of 2". The same text also
# appears in the abstract-interp library string, which is presumably generated
# from this function, so both must be updated together -- confirm.
assert self[dim] % 2 == 0, "glu's dim size must be multiply of 2"
# Sizes before `dim`, then the halved size, then the sizes after `dim`.
return self[:dim] + [self[dim] // 2] + self[dim+1:]
# Shape function for `aten._softmax`: only `self` is forwarded to
# `upstream_shape_functions.unary`; `dim` and `half_to_float` are not used.
# NOTE(review): `unary` presumably returns the input shape unchanged -- confirm
# against the upstream helper.
def aten_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
return upstream_shape_functions.unary(self)
@ -1932,6 +1938,11 @@ def atenround〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
# Dtype function for `aten.glu`: the result dtype is the dtype of `self`.
#
# Args:
#   self_rank_dtype: (rank, dtype) pair describing the input tensor.
#   dim: accepted to mirror the op signature; not used here.
# Returns:
#   The dtype component of `self_rank_dtype`, unchanged.
#
# NOTE(review): the decorator presumably registers check invocations that use
# one tensor of shape (100,) with dim=0 -- confirm against
# `_check_tensors_with_the_same_dtype`.
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
def atenglu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
# The rank half of the pair is unpacked but intentionally ignored.
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES])
def atenscatter_reducetwo〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int:

View File

@ -354,6 +354,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
emit("aten::view_as_real : (Tensor) -> (Tensor)")
emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
emit("aten::glu : (Tensor, int) -> (Tensor)")
# Ops with dynamic number of outputs
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")

View File

@ -3685,3 +3685,21 @@ class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module())
def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
    """Run the module once on a 3x4 int8 tensor.

    The tensor is drawn by `tu.randint` with low=-1000 and high=1000 and then
    converted to `torch.int8` before being passed to `forward`.
    """
    int8_input = tu.randint(3, 4, low=-1000, high=1000).to(torch.int8)
    module.forward(int8_input)
# ==============================================================================
class GluStaticModule(torch.nn.Module):
    """End-to-end test module that applies `aten.glu` along dimension 1.

    The input is annotated as a float32 tensor of static shape [3, 24, 5].
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 24, 5], torch.float32, True),
    ])
    def forward(self, x):
        # Dimension 1 has size 24 in the annotated shape, so it halves evenly.
        gated = torch.ops.aten.glu(x, dim=1)
        return gated
@register_test_case(module_factory=lambda: GluStaticModule())
def GluStaticModule_basic(module, tu: TestUtils):
    """Run `GluStaticModule` once on a tensor of shape (3, 24, 5) from `tu.rand`."""
    sample = tu.rand(3, 24, 5)
    module.forward(sample)